Update CMake min to 3.12 STRING JOIN was introduced in 3.12 TEST=builds with cmake 3.12 fails with 3.10.2 (stock in ubuntu) Latest Ubuntu PPAs are here: https://apt.kitware.com/ -- e65d4dfd0278185962071d30359fc0ce1232be9e by Anush Elangovan <anush@nod-labs.com>: Fix suprious reference to DLOG TEST:builds Closes #89 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/iree/pull/89 from powderluv:fix-cmake-ver e65d4dfd0278185962071d30359fc0ce1232be9e PiperOrigin-RevId: 276298047
diff --git a/BUILD.bazel b/BUILD.bazel deleted file mode 100644 index 2b86f73..0000000 --- a/BUILD.bazel +++ /dev/null
@@ -1,4 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -)
diff --git a/BUILD.bazel.oss b/BUILD.bazel.oss new file mode 100644 index 0000000..f1f9d53 --- /dev/null +++ b/BUILD.bazel.oss
@@ -0,0 +1,3 @@ +package( + licenses = ["notice"], # Apache 2.0 +)
diff --git a/BUILD.oss b/BUILD.oss new file mode 100644 index 0000000..f5c6905 --- /dev/null +++ b/BUILD.oss
@@ -0,0 +1,30 @@ +# Main IREE build file. +# Note that project-wide, bazel repo aliases are used: +# "//third_party/absl/python" +# "//third_party/absl" +# "//third_party/benchmark" +# "//third_party/llvm/llvm/projects/google_mlir" +# "//third_party/llvm/llvm" +# "//third_party/flatbuffers" +# "//third_party/tensorflow" +# +# Various scripts and helpers operate on these prefixes textually, so +# avoid doing any systematic construction that would break the matching. + +package( + licenses = ["notice"], # Apache 2.0 +) + +# Enables the debug service and other profiling features. +# $ bazel build --define=IREE_DEBUG=1 :some_target +config_setting( + name = "debug", + define_values = {"IREE_DEBUG": "1"}, +) + +# Marker library which can be extended to provide flags for things that +# need to know the platform target. +cc_library( + name = "target_config", + defines = ["IREE_UNSPECIFIED_TARGET=1"], +)
diff --git a/CMakeLists.txt b/CMakeLists.txt index fbf93d0..ac138f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.5) +cmake_minimum_required(VERSION 3.12) cmake_policy(SET CMP0077 NEW) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
diff --git a/base/BUILD b/base/BUILD new file mode 100644 index 0000000..80bce9f --- /dev/null +++ b/base/BUILD
@@ -0,0 +1,362 @@ +# Common types and utilities used in the IREE codebase. + +load("//:build_defs.google.bzl", "platform_trampoline_deps") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "api", + srcs = ["api.cc"], + hdrs = ["api.h"], + visibility = ["//visibility:public"], + deps = [ + ":api_hdrs", + ":api_util", + ":file_mapping", + ":tracing", + ], +) + +cc_library( + name = "api_hdrs", + hdrs = ["api.h"], +) + +cc_library( + name = "api_util", + hdrs = ["api_util.h"], + deps = [ + ":api_hdrs", + ":logging", + ":shape", + ":status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "arena", + srcs = ["arena.cc"], + hdrs = ["arena.h"], + deps = [ + ":logging", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "arena_test", + srcs = ["arena_test.cc"], + deps = [ + ":arena", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "bitfield", + hdrs = ["bitfield.h"], + deps = [ + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "bitfield_test", + srcs = ["bitfield_test.cc"], + deps = [ + ":bitfield", + "@com_google_absl//absl/base:core_headers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "file_io", + hdrs = ["file_io.h"], + deps = [ + ":status", + ":target_platform", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ] + platform_trampoline_deps("file_io"), +) + +cc_library( + name = "file_io_hdrs", + hdrs = ["file_io.h"], + deps = [":status"], +) + +cc_library( + name = "file_mapping", + hdrs = ["file_mapping.h"], + deps = [ + ":ref_ptr", + ":status", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ] + platform_trampoline_deps("file_mapping"), +) + +cc_library( + name = "file_mapping_hdrs", + hdrs = ["file_mapping.h"], + deps = [ + ":ref_ptr", + ":status", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "file_path", + srcs = ["file_path.cc"], + hdrs = ["file_path.h"], + deps = [ + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "flatbuffer_util", + srcs = ["flatbuffer_util.cc"], + hdrs = ["flatbuffer_util.h"], + deps = [ + ":file_mapping", + ":memory", + ":source_location", + ":status", + ":tracing", + "@com_github_google_flatbuffers//:flatbuffers", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "init", + hdrs = ["init.h"], + deps = platform_trampoline_deps("init"), +) + +cc_library( + name = "intrusive_list", + hdrs = [ + "intrusive_list.h", + "intrusive_list_ref_ptr.inc", + "intrusive_list_unique_ptr.inc", + ], + deps = [ + ":logging", + ":ref_ptr", + ], +) + +cc_test( + name = "intrusive_list_test", + srcs = [ + "intrusive_list_ref_ptr_test.cc", + "intrusive_list_test.cc", + "intrusive_list_unique_ptr_test.cc", + ], + deps = [ + ":intrusive_list", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "logging", + hdrs = ["logging.h"], + deps = platform_trampoline_deps("logging"), +) + +cc_library( + name = "math", + hdrs = ["math.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "memory", + hdrs = ["memory.h"], + deps = [ + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "ref_ptr", + hdrs = ["ref_ptr.h"], + deps = [ + ":logging", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_test( + name = "ref_ptr_test", + size = "small", + srcs = ["ref_ptr_test.cc"], + deps = [ + ":ref_ptr", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "shape", + srcs = ["shape.cc"], + hdrs = ["shape.h"], + deps = [ + ":logging", + ":source_location", + ":status", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape", + ":status", + ":status_matchers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "source_location", + hdrs = ["source_location.h"], + deps = platform_trampoline_deps("source_location"), +) + +cc_library( + name = "status", + hdrs = ["status.h"], + deps = [ + ":source_location", + ] + platform_trampoline_deps("status"), +) + +cc_library( + name = "status_matchers", + testonly = 1, + hdrs = ["status_matchers.h"], + deps = platform_trampoline_deps("status_matchers"), +) + +cc_library( + name = "target_platform", + hdrs = ["target_platform.h"], +) + +cc_library( + name = "time", + hdrs = ["time.h"], + deps = [ + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "tracing", + hdrs = ["tracing.h"], + deps = [ + "//:target_config", + "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp", + ] + select({ + "@com_google_tracing_framework_cpp//:wtf_enable": [":tracing_enabled"], + "//conditions:default": [":tracing_disabled"], + }), +) + +cc_library( + name = "tracing_disabled", + srcs = [ + "tracing.h", + "tracing_disabled.cc", + ], + visibility = ["//visibility:private"], + deps = [ + ":init", + ":logging", + "@com_google_absl//absl/flags:flag", + "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp", + ], + alwayslink = 1, +) + +cc_library( + name = "tracing_enabled", + srcs = [ + "tracing.cc", + "tracing.h", + ], + visibility = ["//visibility:private"], + deps = [ + ":file_io", + ":file_path", + ":init", + ":logging", + ":status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp", + ], + alwayslink = 1, +) + +# Dependent code has been removed and wait_handle is currently incompatible +# with Windows, so excluding entirely. +# See google/iree/65 +# cc_library( +# name = "wait_handle", +# srcs = ["wait_handle.cc"], +# hdrs = ["wait_handle.h"], +# deps = [ +# ":logging", +# ":ref_ptr", +# ":source_location", +# ":status", +# ":time", +# "@com_google_absl//absl/base:core_headers", +# "@com_google_absl//absl/container:fixed_array", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/time", +# "@com_google_absl//absl/types:span", +# ], +# ) + +# cc_test( +# name = "wait_handle_test", +# srcs = ["wait_handle_test.cc"], +# deps = [ +# ":status", +# ":status_matchers", +# ":wait_handle", +# "@com_google_absl//absl/time", +# "@com_google_googletest//:gtest_main", +# ], +# )
diff --git a/iree/base/CMakeLists.txt b/base/CMakeLists.txt similarity index 100% rename from iree/base/CMakeLists.txt rename to base/CMakeLists.txt
diff --git a/base/api.cc b/base/api.cc new file mode 100644 index 0000000..c187a34 --- /dev/null +++ b/base/api.cc
@@ -0,0 +1,124 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/api.h" + +#include <cstdlib> +#include <string> + +#include "base/api_util.h" +#include "base/file_mapping.h" +#include "base/tracing.h" + +namespace iree { + +//===----------------------------------------------------------------------===// +// iree Core API +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_api_version_check(iree_api_version_t expected_version, + iree_api_version_t* out_actual_version) { + iree_api_version_t actual_version = IREE_API_VERSION_0; + *out_actual_version = actual_version; + return expected_version == actual_version ? IREE_STATUS_OK + : IREE_STATUS_OUT_OF_RANGE; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_allocator_alloc(void* self, iree_host_size_t byte_length, void** out_ptr) { + IREE_TRACE_SCOPE0("iree_allocator_alloc"); + + if (!out_ptr) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_ptr = nullptr; + + if (byte_length <= 0) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + *out_ptr = std::malloc(byte_length); + if (!*out_ptr) { + return IREE_STATUS_RESOURCE_EXHAUSTED; + } + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_free(void* self, + void* ptr) { + IREE_TRACE_SCOPE0("iree_allocator_free"); + if (ptr) { + std::free(ptr); + } + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::FileMapping +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_file_mapping_open_read(iree_string_view_t path, iree_allocator_t allocator, + iree_file_mapping_t** out_file_mapping) { + IREE_TRACE_SCOPE0("iree_file_mapping_open_read"); + + if (!out_file_mapping) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_file_mapping = nullptr; + + IREE_API_ASSIGN_OR_RETURN( + auto file_mapping, + FileMapping::OpenRead(std::string(path.data, path.size))); + + *out_file_mapping = + reinterpret_cast<iree_file_mapping_t*>(file_mapping.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_file_mapping_retain(iree_file_mapping_t* file_mapping) { + IREE_TRACE_SCOPE0("iree_file_mapping_retain"); + auto* handle = reinterpret_cast<FileMapping*>(file_mapping); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_file_mapping_release(iree_file_mapping_t* file_mapping) { + IREE_TRACE_SCOPE0("iree_file_mapping_release"); + auto* handle = reinterpret_cast<FileMapping*>(file_mapping); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_byte_span_t IREE_API_CALL +iree_file_mapping_data(iree_file_mapping_t* file_mapping) { + IREE_TRACE_SCOPE0("iree_file_mapping_data"); + auto* handle = reinterpret_cast<FileMapping*>(file_mapping); + CHECK(handle) << "NULL file_mapping handle"; + auto data = handle->data(); + return {const_cast<uint8_t*>(data.data()), data.size()}; +} + +} // namespace iree
diff --git a/iree/base/api.h b/base/api.h similarity index 100% rename from iree/base/api.h rename to base/api.h
diff --git a/base/api_util.h b/base/api_util.h new file mode 100644 index 0000000..1f750ad --- /dev/null +++ b/base/api_util.h
@@ -0,0 +1,125 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_API_UTIL_H_ +#define IREE_BASE_API_UTIL_H_ + +#include "absl/base/macros.h" +#include "absl/time/time.h" +#include "base/api.h" +#include "base/logging.h" +#include "base/shape.h" +#include "base/status.h" + +namespace iree { + +inline iree_status_t ToApiStatus(Status status) { + return static_cast<iree_status_t>(status.code()); +} + +inline Status FromApiStatus(iree_status_t status_code, SourceLocation loc) { + return StatusBuilder(static_cast<StatusCode>(status_code), loc); +} + +// Internal helper for concatenating macro values. +#define IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define IREE_API_STATUS_MACROS_IMPL_CONCAT_(x, y) \ + IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +// clang-format off +#define IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT +// clang-format on + +namespace status_macro_internal { +class StatusAdaptorForApiMacros { + public: + StatusAdaptorForApiMacros(const Status& status) : status_(status) {} + StatusAdaptorForApiMacros(Status&& status) : status_(std::move(status)) {} + StatusAdaptorForApiMacros(const StatusAdaptorForApiMacros&) = delete; + StatusAdaptorForApiMacros& operator=(const StatusAdaptorForApiMacros&) = + delete; + explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); } + Status&& Consume() { return std::move(status_); } + + private: + Status status_; +}; +} // namespace status_macro_internal + +#define IREE_API_RETURN_IF_ERROR(expr) \ + IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (::iree::status_macro_internal::StatusAdaptorForApiMacros \ + status_adaptor = {expr}) { \ + } else /* NOLINT */ \ + return ::iree::ToApiStatus(status_adaptor.Consume()) + +#define IREE_API_RETURN_IF_API_ERROR(expr) \ + IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (iree_status_t status = (expr)) { \ + return status; \ + } + +#define IREE_API_ASSIGN_OR_RETURN(...) \ + IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_( \ + (__VA_ARGS__, IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ + IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \ + (__VA_ARGS__) + +#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, \ + ...) \ + NAME +#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ + IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args + +#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_)) +#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \ + error_expression) \ + IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ + IREE_API_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \ + rexpr, error_expression) +#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + return ::iree::ToApiStatus(std::move(statusor).status()); \ + } \ + lhs = std::move(statusor).ValueOrDie() + +// Converts an iree_time_t to its equivalent absl::Time. +inline absl::Time ToAbslTime(iree_time_t time) { + if (time == IREE_TIME_INFINITE_PAST) { + return absl::InfinitePast(); + } else if (time == IREE_TIME_INFINITE_FUTURE) { + return absl::InfiniteFuture(); + } else { + return absl::FromUnixNanos(time); + } +} + +// Converts a Shape to an iree_shape_t. +inline iree_status_t ToApiShape(const Shape& shape, iree_shape_t* out_shape) { + out_shape->rank = shape.size(); + if (shape.size() > ABSL_ARRAYSIZE(out_shape->dims)) { + return IREE_STATUS_OUT_OF_RANGE; + } + for (int i = 0; i < out_shape->rank; ++i) { + out_shape->dims[i] = shape[i]; + } + return IREE_STATUS_OK; +} + +} // namespace iree + +#endif // IREE_BASE_API_UTIL_H_
diff --git a/base/arena.cc b/base/arena.cc new file mode 100644 index 0000000..e51aab4 --- /dev/null +++ b/base/arena.cc
@@ -0,0 +1,125 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/arena.h" + +#include <memory> + +#include "absl/base/attributes.h" +#include "base/logging.h" + +namespace iree { + +namespace { + +// Rounds up to the next alignment value, if it is not already aligned. +template <typename T> +ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr T RoundToAlignment( + T value, T alignment) noexcept { + return ((value + alignment - 1) / alignment) * alignment; +} + +} // namespace + +Arena::Arena(size_t block_size) : block_size_(block_size) {} + +Arena::~Arena() { Clear(); } + +void Arena::Clear() { + // Deallocate all memory. + auto block_header = block_list_head_; + while (block_header) { + auto next_block = block_header->next_block; + std::free(block_header); + block_header = next_block; + } + block_list_head_ = nullptr; + block_header = unused_block_list_head_; + while (block_header) { + auto next_block = block_header->next_block; + std::free(block_header); + block_header = next_block; + } + unused_block_list_head_ = nullptr; + + bytes_allocated_ = 0; + block_bytes_allocated_ = 0; +} + +void Arena::Reset() { + // Move all blocks to the unused list and reset allocation count only. + auto block_header = block_list_head_; + while (block_header) { + auto next_block = block_header->next_block; + block_header->bytes_allocated = 0; + block_header->next_block = unused_block_list_head_; + unused_block_list_head_ = block_header; + block_header = next_block; + } + block_list_head_ = nullptr; + + bytes_allocated_ = 0; +} + +uint8_t* Arena::AllocateBytes(size_t length) { + if (!length) { + // Guarantee zero-length allocations return nullptr. + return nullptr; + } + + // Pad length allocated so we are machine word aligned. + // This ensures the next allocation starts at the right boundary. + size_t aligned_length = RoundToAlignment(length, sizeof(uintptr_t)); + + if (aligned_length > block_size_) { + // This allocation is larger than an entire block. That's bad. + // We could allocate this with malloc (and then keep track of those to free + // things), but for now let's just die. + CHECK(false); + return nullptr; + } + + if (!block_list_head_ || + block_list_head_->bytes_allocated + aligned_length > block_size_) { + // Check to see if we have an existing unused block we can use. + if (unused_block_list_head_) { + // Move block from unused list to main list. + auto block_header = unused_block_list_head_; + unused_block_list_head_ = block_header->next_block; + block_header->next_block = block_list_head_; + block_header->bytes_allocated = 0; + block_list_head_ = block_header; + } else { + // Allocate a new block. + auto block_ptr = reinterpret_cast<uint8_t*>( + std::malloc(sizeof(BlockHeader) + block_size_)); + auto block_header = reinterpret_cast<BlockHeader*>(block_ptr); + block_header->next_block = block_list_head_; + block_header->bytes_allocated = 0; + block_list_head_ = block_header; + block_bytes_allocated_ += sizeof(BlockHeader) + block_size_; + } + } + + BlockHeader* target_block = block_list_head_; + auto data_ptr = reinterpret_cast<uint8_t*>(target_block) + + sizeof(BlockHeader) + target_block->bytes_allocated; + target_block->bytes_allocated += aligned_length; + + bytes_allocated_ += length; + + return data_ptr; +} + +} // namespace iree
diff --git a/iree/base/arena.h b/base/arena.h similarity index 100% rename from iree/base/arena.h rename to base/arena.h
diff --git a/base/arena_test.cc b/base/arena_test.cc new file mode 100644 index 0000000..95b557f --- /dev/null +++ b/base/arena_test.cc
@@ -0,0 +1,148 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/arena.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace { + +// Tests basic block allocations. +TEST(ArenaTest, BasicAllocation) { + Arena arena(64); + EXPECT_EQ(64, arena.block_size()); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + // Zero byte allocations should return nullptr and not allocate bytes. + auto zero_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(0)); + EXPECT_EQ(0, zero_ptr); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + arena.Clear(); + + // Allocations must be machine word aligned. + auto one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1)); + EXPECT_NE(0, one_ptr); + EXPECT_EQ(0, one_ptr % sizeof(uintptr_t)); + one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1)); + EXPECT_NE(0, one_ptr); + EXPECT_EQ(0, one_ptr % sizeof(uintptr_t)); + EXPECT_EQ(2, arena.bytes_allocated()); + EXPECT_LT(2, arena.block_bytes_allocated()); + + arena.Clear(); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); +} + +// Tests typed allocations. +TEST(ArenaTest, TypedAllocations) { + Arena arena(64); + + EXPECT_NE(nullptr, arena.Allocate<int>()); + EXPECT_EQ(4, arena.bytes_allocated()); + EXPECT_EQ(64 + Arena::kBlockOverhead, arena.block_bytes_allocated()); + arena.Clear(); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + struct MyType { + MyType() {} + explicit MyType(int initial_value) : value(initial_value) {} + + int value = 5; + }; + auto my_type_ptr = arena.Allocate<MyType>(); + EXPECT_NE(nullptr, my_type_ptr); + EXPECT_EQ(sizeof(MyType), arena.bytes_allocated()); + EXPECT_EQ(5, my_type_ptr->value); // Default ctor must be called. + arena.Clear(); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + my_type_ptr = arena.Allocate<MyType>(10); + EXPECT_NE(nullptr, my_type_ptr); + EXPECT_EQ(sizeof(MyType), arena.bytes_allocated()); + EXPECT_EQ(10, my_type_ptr->value); // Ctor should have been called. + arena.Clear(); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); +} + +// Tests multiple blocks. +TEST(ArenaTest, MultipleBlocks) { + Arena arena(16); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + // Allocate one entire block. + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(16, arena.bytes_allocated()); + EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated()); + + // Allocate into the next block. + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(32, arena.bytes_allocated()); + EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); + + // Clear. + arena.Clear(); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + // Allocate again. + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(16, arena.bytes_allocated()); + EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated()); + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(32, arena.bytes_allocated()); + EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); +} + +// Tests fast reset. +TEST(ArenaTest, FastReset) { + Arena arena(16); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(0, arena.block_bytes_allocated()); + + // Allocate one entire block. + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(16, arena.bytes_allocated()); + EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated()); + + // Allocate into the next block. + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(32, arena.bytes_allocated()); + EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); + + // Reset (without deallocating). + arena.Reset(); + EXPECT_EQ(0, arena.bytes_allocated()); + EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); + + // Allocate again. + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(16, arena.bytes_allocated()); + EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); + EXPECT_NE(nullptr, arena.AllocateBytes(16)); + EXPECT_EQ(32, arena.bytes_allocated()); + EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); +} + +} // namespace +} // namespace iree
diff --git a/iree/base/bitfield.h b/base/bitfield.h similarity index 100% rename from iree/base/bitfield.h rename to base/bitfield.h
diff --git a/base/bitfield_test.cc b/base/bitfield_test.cc new file mode 100644 index 0000000..80ead1a --- /dev/null +++ b/base/bitfield_test.cc
@@ -0,0 +1,82 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/bitfield.h" + +#include <cstdint> +#include <vector> + +#include "gtest/gtest.h" + +namespace iree { + +// NOTE: define here so that we don't get internal linkage warnings. +enum class MyValue : uint32_t { + kNone = 0, + kA = 1 << 0, + kB = 1 << 1, + kAll = kA | kB, +}; +IREE_BITFIELD(MyValue); + +namespace { + +// Tests general usage. +TEST(BitfieldTest, FormatBitfieldValue) { + std::vector<std::pair<MyValue, const char *>> mappings = { + {MyValue::kA, "kA"}, + {MyValue::kB, "kB"}, + }; + EXPECT_EQ("", + FormatBitfieldValue(MyValue::kNone, absl::MakeConstSpan(mappings))); + EXPECT_EQ("kA", + FormatBitfieldValue(MyValue::kA, absl::MakeConstSpan(mappings))); + EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB, + absl::MakeConstSpan(mappings))); +} + +// Tests that empty mapping tables are fine. +TEST(BitfieldTest, FormatBitfieldValueEmpty) { + EXPECT_EQ("", FormatBitfieldValue(MyValue::kNone, {})); +} + +// Tests that values not found in the mappings are still displayed. +TEST(BitfieldTest, FormatBitfieldValueUnhandledValues) { + EXPECT_EQ("kA|2h", FormatBitfieldValue(MyValue::kA | MyValue::kB, + { + {MyValue::kA, "kA"}, + })); +} + +// Tests priority order in the mapping table. +TEST(BitfieldTest, FormatBitfieldValuePriority) { + // No priority, will do separate. + EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB, + { + {MyValue::kA, "kA"}, + {MyValue::kB, "kB"}, + {MyValue::kAll, "kAll"}, + })); + + // Priority on the combined flag, use that instead. + EXPECT_EQ("kAll", FormatBitfieldValue(MyValue::kA | MyValue::kB, + { + {MyValue::kAll, "kAll"}, + {MyValue::kA, "kA"}, + {MyValue::kB, "kB"}, + })); +} + +} // namespace +} // namespace iree
diff --git a/base/file_io.h b/base/file_io.h new file mode 100644 index 0000000..4bf1b72 --- /dev/null +++ b/base/file_io.h
@@ -0,0 +1,48 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_FILE_IO_H_ +#define IREE_BASE_FILE_IO_H_ + +#include <string> + +#include "base/status.h" + +namespace iree { +namespace file_io { + +// Checks if a file exists at the provided path. +// +// Returns an OK status if the file definitely exists. +// Errors can include PermissionDeniedError, NotFoundError, etc. +Status FileExists(const std::string& path); + +// Synchronously reads a file's contents into a string. +StatusOr<std::string> GetFileContents(const std::string& path); + +// Deletes the file at the provided path. +Status DeleteFile(const std::string& path); + +// Moves a file from 'source_path' to 'destination_path'. +// +// This may simply rename the file, but may fall back to a full copy and delete +// of the original if renaming is not possible (for example when moving between +// physical storage locations). +Status MoveFile(const std::string& source_path, + const std::string& destination_path); + +} // namespace file_io +} // namespace iree + +#endif // IREE_BASE_FILE_IO_H_
diff --git a/base/file_mapping.h b/base/file_mapping.h new file mode 100644 index 0000000..23a2769 --- /dev/null +++ b/base/file_mapping.h
@@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_FILE_MAPPING_H_ +#define IREE_BASE_FILE_MAPPING_H_ + +#include <cstdint> +#include <string> + +#include "absl/types/span.h" +#include "base/ref_ptr.h" +#include "base/status.h" + +namespace iree { + +// A memory-mapped file handle. +class FileMapping : public RefObject<FileMapping> { + public: + // Opens a file and maps it into the calling process memory. + // The file will be opened for shared read access. + static StatusOr<ref_ptr<FileMapping>> OpenRead(std::string path); + + virtual ~FileMapping() = default; + + // Read-only contents of the file. + inline absl::Span<const uint8_t> data() const noexcept { return data_; } + + protected: + explicit FileMapping(absl::Span<const uint8_t> data) : data_(data) {} + + absl::Span<const uint8_t> data_; + + private: + FileMapping(const FileMapping&) = delete; + FileMapping& operator=(const FileMapping&) = delete; +}; + +} // namespace iree + +#endif // IREE_BASE_FILE_MAPPING_H_
diff --git a/base/file_path.cc b/base/file_path.cc new file mode 100644 index 0000000..60384f7 --- /dev/null +++ b/base/file_path.cc
@@ -0,0 +1,83 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/file_path.h" + +#include "absl/strings/str_cat.h" + +namespace iree { +namespace file_path { + +namespace { + +std::pair<absl::string_view, absl::string_view> SplitPath( + absl::string_view path) { + size_t pos = path.find_last_of('/'); + // Handle the case with no '/' in 'path'. + if (pos == absl::string_view::npos) { + return std::make_pair(path.substr(0, 0), path); + } + // Handle the case with a single leading '/' in 'path'. + if (pos == 0) { + return std::make_pair(path.substr(0, 1), absl::ClippedSubstr(path, 1)); + } + return std::make_pair(path.substr(0, pos), + absl::ClippedSubstr(path, pos + 1)); +} + +// Return the parts of the basename of path, split on the final ".". +// If there is no "." in the basename or "." is the final character in the +// basename, the second value will be empty. +std::pair<absl::string_view, absl::string_view> SplitBasename( + absl::string_view path) { + path = Basename(path); + size_t pos = path.find_last_of('.'); + if (pos == absl::string_view::npos) + return std::make_pair(path, absl::ClippedSubstr(path, path.size(), 0)); + return std::make_pair(path.substr(0, pos), + absl::ClippedSubstr(path, pos + 1)); +} + +} // namespace + +std::string JoinPaths(absl::string_view path1, absl::string_view path2) { + if (path1.empty()) return std::string(path2); + if (path2.empty()) return std::string(path1); + if (path1.back() == '/') { + if (path2.front() == '/') + return absl::StrCat(path1, absl::ClippedSubstr(path2, 1)); + } else { + if (path2.front() != '/') return absl::StrCat(path1, "/", path2); + } + return absl::StrCat(path1, path2); +} + +absl::string_view DirectoryName(absl::string_view path) { + return SplitPath(path).first; +} + +absl::string_view Basename(absl::string_view path) { + return SplitPath(path).second; +} + +absl::string_view Stem(absl::string_view path) { + return SplitBasename(path).first; +} + +absl::string_view Extension(absl::string_view path) { + return SplitBasename(path).second; +} + +} // namespace file_path +} // namespace iree
diff --git a/iree/base/file_path.h b/base/file_path.h similarity index 100% rename from iree/base/file_path.h rename to base/file_path.h
diff --git a/base/flatbuffer_util.cc b/base/flatbuffer_util.cc new file mode 100644 index 0000000..72a9e7c --- /dev/null +++ b/base/flatbuffer_util.cc
@@ -0,0 +1,145 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/flatbuffer_util.h" + +#include <cerrno> +#include <cstring> + +#include "absl/memory/memory.h" +#include "base/file_mapping.h" +#include "base/memory.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" + +namespace iree { + +FlatBufferFileBase::~FlatBufferFileBase() { + if (deleter_) { + deleter_(); + deleter_ = []() {}; + } +} + +Status FlatBufferFileBase::Create(const void* root_ptr, + std::function<void()> deleter) { + IREE_TRACE_SCOPE0("FlatBufferFileBase::Create"); + + root_ptr_ = root_ptr; + deleter_ = std::move(deleter); + + return OkStatus(); +} + +Status FlatBufferFileBase::CreateWithBackingBuffer( + const void* root_ptr, ::flatbuffers::DetachedBuffer backing_buffer) { + IREE_TRACE_SCOPE0("FlatBufferFileBase::Create"); + + root_ptr_ = root_ptr; + + // Pass along the buffer provided so we keep it alive until the + // FlatBufferFileBase is destructed. + auto backing_buffer_baton = IreeMoveToLambda(backing_buffer); + deleter_ = [backing_buffer_baton]() { (void)backing_buffer_baton.value; }; + + return OkStatus(); +} + +Status FlatBufferFileBase::Wrap(const void* root_ptr) { + IREE_TRACE_SCOPE0("FlatBufferFileBase::Wrap"); + return Create(root_ptr, []() {}); +} + +Status FlatBufferFileBase::FromBuffer(Identifier identifier, + absl::Span<const uint8_t> buffer_data, + std::function<void()> deleter, + size_t root_type_size, + VerifierFn verifier_fn) { + IREE_TRACE_SCOPE("FlatBufferFileBase::FromBuffer:size", int) + (static_cast<int>(buffer_data.size())); + + // Sanity check buffer for the minimum size as FlatBuffers doesn't. + if (buffer_data.size() < 16) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Provided serialized flatbuffer buffer is too small to be legit " + "at size=" + << buffer_data.size(); + } + + // Ensure the buffer has the BIPE magic bytes. + if (identifier.has_value() && !::flatbuffers::BufferHasIdentifier( + buffer_data.data(), identifier.value())) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Provided serialized buffer does not contain the expected type; " + "magic bytes mismatch (expected " + << identifier.value() << ")"; + } + + // Verify the FlatBuffer contains valid offsets and won't try to read out of + // bounds of the buffer. We inline a bit of VerifyBufferFromStart so this code + // can stay generic. + { + IREE_TRACE_SCOPE0("FlatBufferFileBase::FromBufferVerification"); + ::flatbuffers::Verifier verifier{buffer_data.data(), buffer_data.size()}; + if (!verifier_fn(identifier.value_or(nullptr), &verifier)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "FlatBuffer failed to verify as expected type; possibly " + "corrupt input"; + } + } + + // Resolve the root pointer in the buffer. + // This is GetMutableRoot such that we don't need to know T. + root_ptr_ = buffer_data.data() + + ::flatbuffers::EndianScalar( + *reinterpret_cast<const ::flatbuffers::uoffset_t*>( + buffer_data.data())); + if (!root_ptr_) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Unable to resolve root table"; + } + deleter_ = std::move(deleter); + + return OkStatus(); +} + +Status FlatBufferFileBase::WrapBuffer(Identifier identifier, + absl::Span<const uint8_t> buffer_data, + size_t root_type_size, + VerifierFn verifier_fn) { + IREE_TRACE_SCOPE0("FlatBufferFileBase::WrapBuffer"); + return FromBuffer( + identifier, buffer_data, []() {}, root_type_size, verifier_fn); +} + +Status FlatBufferFileBase::LoadFile(Identifier identifier, std::string path, + size_t root_type_size, + VerifierFn verifier_fn) { + IREE_TRACE_SCOPE0("FlatBufferFileBase::LoadFile"); + + ASSIGN_OR_RETURN(auto file_mapping, FileMapping::OpenRead(path)); + auto buffer_data = file_mapping->data(); + + auto handle_baton = IreeMoveToLambda(file_mapping); + return FromBuffer( + identifier, buffer_data, + [handle_baton]() { + // Keeping the mmap handle alive. + (void)handle_baton.value; + }, + root_type_size, verifier_fn); +} + +} // namespace iree
diff --git a/base/flatbuffer_util.h b/base/flatbuffer_util.h new file mode 100644 index 0000000..79b8fe4 --- /dev/null +++ b/base/flatbuffer_util.h
@@ -0,0 +1,321 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_FLATBUFFER_UTIL_H_ +#define IREE_BASE_FLATBUFFER_UTIL_H_ + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/memory.h" +#include "base/status.h" +#include "flatbuffers/flatbuffers.h" + +namespace iree { + +// Wraps a FlatBuffer String in an absl::string_view. +// Returns empty-string ("") for nullptr values. +inline absl::string_view WrapString(const ::flatbuffers::String* value) { + return value ? absl::string_view{value->data(), value->size()} : ""; +} + +// Base type for FlatBufferFile<T>. See below. +class FlatBufferFileBase { + public: + using Identifier = absl::optional<const char*>; + + virtual ~FlatBufferFileBase(); + + protected: + template <typename T> + friend class FlatBufferFile; + + using VerifierFn = bool (*)(const char* identifier, + ::flatbuffers::Verifier* verifier); + + FlatBufferFileBase() = default; + + const void* root_ptr() const { return root_ptr_; } + + // Redirections of template static methods on FlatBufferFile so we can put the + // implementations in a shared compilation unit. + // See FlatBufferFile<T> for doc comments. + Status Create(const void* root_ptr, std::function<void()> deleter); + Status CreateWithBackingBuffer(const void* root_ptr, + ::flatbuffers::DetachedBuffer backing_buffer); + Status Wrap(const void* root); + Status FromBuffer(Identifier identifier, + absl::Span<const uint8_t> buffer_data, + std::function<void()> deleter, size_t root_type_size, + VerifierFn verifier_fn); + // Initializes from an STL byte based container (string and vector of + // char/byte should be compatible). + template <typename Container> + Status FromContainer(Identifier identifier, Container container, + size_t root_type_size, VerifierFn verifier_fn); + Status WrapBuffer(Identifier identifier, + absl::Span<const uint8_t> buffer_data, + size_t root_type_size, VerifierFn verifier_fn); + Status LoadFile(Identifier identifier, std::string path, + size_t root_type_size, VerifierFn verifier_fn); + + private: + const void* root_ptr_ = nullptr; + std::function<void()> deleter_; +}; + +// Immutable root FlatBuffer type wrapper with support for loading and backing +// buffer management. +// +// Immutable and thread-safe. +template <typename T> +class FlatBufferFile final : public FlatBufferFileBase { + public: + // Creates a FlatBufferFile from an in-memory root pointer. + // The provided |deleter| will be called when the FlatBufferFile is destructed + // and can be used to deallocate/clean up resources. + // + // This assumes that the root pointer has already been verified as valid. + // If verification is required instead use FromBuffer on the original buffer. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Create( + const T* root, std::function<void()> deleter); + + // Creates a FlatBufferFile from an in-memory root pointer and the detached + // backing buffer storing it. + // + // Example: + // FlatBufferBuilder fbb; + // MyTypeBuilder mtb(fbb); + // fbb.Finish(mtb.Finish()); + // auto my_type = FlatBufferFile<MyType>::CreateWithBackingBuffer( + // fbb.Release()); + // my_type->foo(); + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> CreateWithBackingBuffer( + ::flatbuffers::DetachedBuffer backing_buffer); + + // Wraps a caller-owned in-memory root pointer. + // The provided |root| must remain valid for the lifetime of the returned + // FlatBufferFile. + // + // This assumes that the root pointer has already been verified as valid. + // If verification is required instead use FromBuffer on the original buffer. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Wrap(const T* root); + + // Creates a FlatBufferFile wrapping an external data buffer with a deleter + // function that will be called when the FlatBufferFile is destructed. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer( + Identifier identifier, absl::Span<const uint8_t> buffer_data, + std::function<void()> deleter); + + // Creates a FlatBufferFile from a serialized data buffer. + // The FlatBufferFile takes ownership of the vector. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer( + Identifier identifier, std::vector<uint8_t> buffer_data); + + // Loads a FlatBufferFile from an external buffer owned by the caller. + // The buffer must remain valid until the Pipeline is destroyed. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> WrapBuffer( + Identifier identifier, absl::Span<const uint8_t> buffer_data); + + // Loads the FlatBufferFile from a serialized byte-based STL container. + template <typename Container> + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromContainer( + Identifier identifier, Container buffer_data); + + // Loads a FlatBufferFile from a serialized string. + // The FlatBufferFile takes ownership of the string. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromString( + Identifier identifier, std::string buffer_data) { + return FromContainer(identifier, std::move(buffer_data)); + } + + // Loads a FlatBufferFile from a serialized byte vector. + // The FlatBufferFile takes ownership of the vector. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromVector( + Identifier identifier, std::vector<uint8_t> buffer_data) { + return FromContainer(identifier, std::move(buffer_data)); + } + + // Loads a FlatBufferFile from a serialized file on the file system. + // This will attempt to mmap the file and is the preferred way of loading as + // only those pages that contain requested tables will be read. + static StatusOr<std::unique_ptr<FlatBufferFile<T>>> LoadFile( + Identifier identifier, std::string path); + + // Returns a vector of file references that share the same underlying data + // buffer. The buffer will be kept alive until the last file is released. + static StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>> + CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file, int count); + + ~FlatBufferFile() override = default; + + // Typed root pointer of the file. + const T* root() const { return reinterpret_cast<const T*>(root_ptr()); } + + private: + FlatBufferFile() = default; + + // Conforms to VerifierFn. + static bool VerifierFnT(const char* identifier, + ::flatbuffers::Verifier* verifier) { + return verifier->VerifyBuffer<T>(identifier); + } +}; + +template <typename Container> +Status FlatBufferFileBase::FromContainer(Identifier identifier, + Container container, + size_t root_type_size, + VerifierFn verifier_fn) { + static_assert(sizeof(*container.data()) == 1, + "Expected container of byte sized elements"); + auto buffer_data = absl::MakeConstSpan( + // Double static_cast through void is safer than reinterpret_cast. + static_cast<const uint8_t*>(static_cast<const void*>(container.data())), + container.size()); + // Use a baton to keep the container alive until the FlatBufferFileBase is + // destroyed. + auto buffer_data_baton = IreeMoveToLambda(container); + return FromBuffer( + identifier, buffer_data, + [buffer_data_baton]() { + // Keeping the container alive. + (void)buffer_data_baton.value; + }, + root_type_size, verifier_fn); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Create( + const T* root, std::function<void()> deleter) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + RETURN_IF_ERROR(base_file->Create(root, std::move(deleter))); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> +FlatBufferFile<T>::CreateWithBackingBuffer( + ::flatbuffers::DetachedBuffer backing_buffer) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + auto* root_ptr = ::flatbuffers::GetRoot<T>(backing_buffer.data()); + RETURN_IF_ERROR( + base_file->CreateWithBackingBuffer(root_ptr, std::move(backing_buffer))); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Wrap( + const T* root) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + RETURN_IF_ERROR(base_file->Wrap(root)); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer( + Identifier identifier, absl::Span<const uint8_t> buffer_data, + std::function<void()> deleter) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + RETURN_IF_ERROR(base_file->FromBuffer( + identifier, buffer_data, std::move(deleter), sizeof(T), VerifierFnT)); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer( + Identifier identifier, std::vector<uint8_t> buffer_data) { + auto* buffer_data_ptr = new decltype(buffer_data); + (*buffer_data_ptr) = std::move(buffer_data); + return FromBuffer(identifier, absl::MakeConstSpan(*buffer_data_ptr), + [buffer_data_ptr]() { delete buffer_data_ptr; }); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::WrapBuffer( + Identifier identifier, absl::Span<const uint8_t> buffer_data) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + RETURN_IF_ERROR( + base_file->WrapBuffer(identifier, buffer_data, sizeof(T), VerifierFnT)); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +template <typename Container> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromContainer( + Identifier identifier, Container buffer_data) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + RETURN_IF_ERROR(base_file->FromContainer(identifier, std::move(buffer_data), + sizeof(T), VerifierFnT)); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::LoadFile( + Identifier identifier, std::string path) { + std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; + auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); + RETURN_IF_ERROR( + base_file->LoadFile(identifier, std::move(path), sizeof(T), VerifierFnT)); + return std::move(flat_buffer_file); +} + +// static +template <typename T> +StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>> +FlatBufferFile<T>::CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file, + int count) { + // Create a shared_ptr wrapper for the base file that will be. + std::shared_ptr<FlatBufferFile<T>> shared_file{file.release()}; + + // Create N files. We wrap and keep the shared_ptr alive in the deleter + // capture. By wrapping we avoid reverifying the entire buffer. + std::vector<std::unique_ptr<FlatBufferFile<T>>> list; + for (int i = 0; i < count; ++i) { + ASSIGN_OR_RETURN(auto new_file, FlatBufferFile<T>::Create( + shared_file->root(), [shared_file]() { + // Each new file keeps a reference to + // the shared file to keep it alive. + (void)shared_file; + })); + list.push_back(std::move(new_file)); + } + return std::move(list); +} + +} // namespace iree + +#endif // IREE_BASE_FLATBUFFER_UTIL_H_
diff --git a/base/init.h b/base/init.h new file mode 100644 index 0000000..dab07f4 --- /dev/null +++ b/base/init.h
@@ -0,0 +1,52 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INIT_H_ +#define IREE_BASE_INIT_H_ + +// Initializer macros are defined in separate files: +// IREE_DECLARE_MODULE_INITIALIZER(name) +// IREE_REGISTER_MODULE_INITIALIZER(name, body) +// IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) +// IREE_REQUIRE_MODULE_INITIALIZED(name) +// IREE_RUN_MODULE_INITIALIZERS() +// IREE_REQUIRE_MODULE_LINKED(name) +// +// These macros allow for arranging pieces of initialization code to be +// executed at a well-defined time and in a well-defined order. +// +// Initialization happens automatically during InitializeEnvironment(), which +// should be called early in main(), before other code runs. + +#ifdef IREE_CONFIG_GOOGLE_INTERNAL +#include "base/google/init_google.h" +#else +#include "base/internal/init_internal.h" +#endif // IREE_CONFIG_GOOGLE_INTERNAL + +namespace iree { + +// Initializes the system environment in a binary. +// +// This first parses command line flags, then resolves module initializers +// by calling IREE_RUN_MODULE_INITIALIZERS(). +// +// 'argc' and 'argv' are the command line flags to parse. +// +// This should typically be called early in main(), before other code runs. +void InitializeEnvironment(int* argc, char*** argv); + +} // namespace iree + +#endif // IREE_BASE_INIT_H_
diff --git a/base/internal/BUILD b/base/internal/BUILD new file mode 100644 index 0000000..3e91f20 --- /dev/null +++ b/base/internal/BUILD
@@ -0,0 +1,118 @@ +# Implementations for iree/base/ + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "file_handle_win32", + srcs = ["file_handle_win32.cc"], + hdrs = ["file_handle_win32.h"], + deps = [ + "///base:status", + "///base:target_platform", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "file_io_internal", + srcs = [ + "file_io_posix.cc", + "file_io_win32.cc", + ], + deps = [ + ":file_handle_win32", + "///base:file_io_hdrs", + "///base:status", + "///base:target_platform", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "file_mapping_internal", + srcs = [ + "file_mapping_posix.cc", + "file_mapping_win32.cc", + ], + deps = [ + ":file_handle_win32", + "///base:file_mapping_hdrs", + "///base:target_platform", + "///base:tracing", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "init_internal", + srcs = ["init_internal.cc"], + hdrs = ["init_internal.h"], + deps = [ + "///base:target_platform", + "@com_google_absl//absl/flags:parse", + ], +) + +cc_library( + name = "logging_internal", + srcs = ["logging.cc"], + hdrs = ["logging.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/flags:flag", + ], +) + +cc_library( + name = "source_location_internal", + hdrs = ["source_location.h"], +) + +cc_library( + name = "status_internal", + srcs = [ + "status.cc", + "status_builder.cc", + "status_errno.cc", + "status_errors.cc", + "status_win32_errors.cc", + "statusor.cc", + ], + hdrs = [ + "status.h", + "status_builder.h", + "status_errno.h", + "status_errors.h", + "status_macros.h", + "status_win32_errors.h", + "statusor.h", + ], + deps = [ + ":logging_internal", + "///base:source_location", + "///base:target_platform", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/debugging:stacktrace", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "status_matchers_internal", + testonly = 1, + hdrs = ["status_matchers.h"], + deps = [ + "///base:status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest", + ], +)
diff --git a/base/internal/file_handle_win32.cc b/base/internal/file_handle_win32.cc new file mode 100644 index 0000000..37c3352 --- /dev/null +++ b/base/internal/file_handle_win32.cc
@@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/file_handle_win32.h" + +#include "absl/memory/memory.h" +#include "base/target_platform.h" + +#if defined(IREE_PLATFORM_WINDOWS) + +#include <windows.h> + +namespace iree { + +// static +StatusOr<std::unique_ptr<FileHandle>> FileHandle::OpenRead(std::string path, + DWORD file_flags) { + HANDLE handle = ::CreateFileA( + /*lpFileName=*/path.c_str(), /*dwDesiredAccess=*/GENERIC_READ, + /*dwShareMode=*/FILE_SHARE_READ, /*lpSecurityAttributes=*/nullptr, + /*dwCreationDisposition=*/OPEN_EXISTING, + /*dwFlagsAndAttributes=*/FILE_ATTRIBUTE_NORMAL | file_flags, + /*hTemplateFile=*/nullptr); + if (handle == INVALID_HANDLE_VALUE) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Unable to open file " << path; + } + + BY_HANDLE_FILE_INFORMATION file_info; + if (::GetFileInformationByHandle(handle, &file_info) == FALSE) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Unable to query file info for " << path; + } + + uint64_t file_size = (static_cast<uint64_t>(file_info.nFileSizeHigh) << 32) | + file_info.nFileSizeLow; + return absl::make_unique<FileHandle>(handle, file_size); +} + +FileHandle::~FileHandle() { ::CloseHandle(handle_); } + +} // namespace iree + +#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/file_handle_win32.h b/base/internal/file_handle_win32.h new file mode 100644 index 0000000..099e16b --- /dev/null +++ b/base/internal/file_handle_win32.h
@@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_ +#define IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_ + +#include <memory> +#include <string> + +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" +#include "base/status.h" +#include "base/target_platform.h" + +#if defined(IREE_PLATFORM_WINDOWS) + +#include <windows.h> + +namespace iree { + +class FileHandle { + public: + static StatusOr<std::unique_ptr<FileHandle>> OpenRead(std::string path, + DWORD file_flags); + + FileHandle(HANDLE handle, size_t size) : handle_(handle), size_(size) {} + ~FileHandle(); + + absl::string_view path() const { return path_; } + HANDLE handle() const { return handle_; } + size_t size() const { return size_; } + + private: + FileHandle(const FileHandle&) = delete; + FileHandle& operator=(const FileHandle&) = delete; + + std::string path_; + HANDLE handle_; + size_t size_; +}; + +} // namespace iree + +#endif // IREE_PLATFORM_WINDOWS + +#endif // IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
diff --git a/base/internal/file_io_posix.cc b/base/internal/file_io_posix.cc new file mode 100644 index 0000000..95de126 --- /dev/null +++ b/base/internal/file_io_posix.cc
@@ -0,0 +1,89 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cstdio> + +#include "base/file_io.h" +#include "base/status.h" +#include "base/target_platform.h" + +#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \ + defined(IREE_PLATFORM_LINUX) + +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +namespace iree { +namespace file_io { + +Status FileExists(const std::string& path) { + struct stat stat_buf; + return stat(path.c_str(), &stat_buf) == 0 ? OkStatus() + : NotFoundErrorBuilder(IREE_LOC); +} + +StatusOr<std::string> GetFileContents(const std::string& path) { + std::unique_ptr<FILE, void (*)(FILE*)> file = {std::fopen(path.c_str(), "r"), + +[](FILE* file) { + if (file) fclose(file); + }}; + if (file == nullptr) { + return ErrnoToCanonicalStatusBuilder(errno, "Failed to open file", + IREE_LOC); + } + if (std::fseek(file.get(), 0, SEEK_END) == -1) { + return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file", + IREE_LOC); + } + size_t file_size = std::ftell(file.get()); + if (file_size == -1L) { + return ErrnoToCanonicalStatusBuilder(errno, "Failed to read file length", + IREE_LOC); + } + if (std::fseek(file.get(), 0, SEEK_SET) == -1) { + return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file", + IREE_LOC); + } + std::string contents; + contents.resize(file_size); + if (std::fread(const_cast<char*>(contents.data()), file_size, 1, + file.get()) != file_size) { + return UnavailableErrorBuilder(IREE_LOC) + << "Unable to read entire file contents"; + } + return contents; +} + +Status DeleteFile(const std::string& path) { + if (::remove(path.c_str()) == -1) { + return ErrnoToCanonicalStatusBuilder(errno, "Failed to delete file", + IREE_LOC); + } + return OkStatus(); +} + +Status MoveFile(const std::string& source_path, + const std::string& destination_path) { + if (::rename(source_path.c_str(), destination_path.c_str()) == -1) { + return ErrnoToCanonicalStatusBuilder(errno, "Failed to rename file", + IREE_LOC); + } + return OkStatus(); +} + +} // namespace file_io +} // namespace iree + +#endif // IREE_PLATFORM_*
diff --git a/base/internal/file_io_win32.cc b/base/internal/file_io_win32.cc new file mode 100644 index 0000000..fc16c20 --- /dev/null +++ b/base/internal/file_io_win32.cc
@@ -0,0 +1,76 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "base/file_io.h" +#include "base/internal/file_handle_win32.h" +#include "base/target_platform.h" + +#if defined(IREE_PLATFORM_WINDOWS) + +#include <windows.h> + +namespace iree { +namespace file_io { + +Status FileExists(const std::string& path) { + DWORD attrs = ::GetFileAttributesA(path.c_str()); + if (attrs == INVALID_FILE_ATTRIBUTES) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Unable to find/access file: " << path; + } + return OkStatus(); +} + +StatusOr<std::string> GetFileContents(const std::string& path) { + ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path), + FILE_FLAG_SEQUENTIAL_SCAN)); + std::string result; + result.resize(file->size()); + DWORD bytes_read = 0; + if (::ReadFile(file->handle(), const_cast<char*>(result.data()), + result.size(), &bytes_read, nullptr) == FALSE) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Unable to read file span of " << result.size() << " bytes"; + } else if (bytes_read != file->size()) { + return ResourceExhaustedErrorBuilder(IREE_LOC) + << "Unable to read all " << file->size() + << " bytes from the file (got " << bytes_read << ")"; + } + return result; +} + +Status DeleteFile(const std::string& path) { + if (::DeleteFileA(path.c_str()) == FALSE) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Unable to delete/access file: " << path; + } + return OkStatus(); +} + +Status MoveFile(const std::string& source_path, + const std::string& destination_path) { + if (::MoveFileA(source_path.c_str(), destination_path.c_str()) == FALSE) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Unable to move file " << source_path << " to " + << destination_path; + } + return OkStatus(); +} + +} // namespace file_io +} // namespace iree + +#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/file_mapping_posix.cc b/base/internal/file_mapping_posix.cc new file mode 100644 index 0000000..fd4e1e6 --- /dev/null +++ b/base/internal/file_mapping_posix.cc
@@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/file_mapping.h" +#include "base/target_platform.h" +#include "base/tracing.h" + +#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \ + defined(IREE_PLATFORM_LINUX) + +#include <fcntl.h> +#include <sys/mman.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include <cerrno> + +namespace iree { + +namespace { + +class FileDescriptor { + public: + static StatusOr<std::unique_ptr<FileDescriptor>> OpenRead(std::string path) { + struct stat buf; + if (::lstat(path.c_str(), &buf) == -1) { + return NotFoundErrorBuilder(IREE_LOC) + << "Unable to stat file " << path << ": " << ::strerror(errno); + } + uint64_t file_size = static_cast<size_t>(buf.st_size); + + int fd = ::open(path.c_str(), O_RDONLY); + if (fd == -1) { + return UnavailableErrorBuilder(IREE_LOC) + << "Unable to open file " << path << ": " << ::strerror(errno); + } + + return absl::make_unique<FileDescriptor>(std::move(path), fd, file_size); + } + + FileDescriptor(std::string path, int fd, size_t size) + : path_(std::move(path)), fd_(fd), size_(size) {} + ~FileDescriptor() { ::close(fd_); } + + absl::string_view path() const { return path_; } + int fd() const { return fd_; } + size_t size() const { return size_; } + + private: + FileDescriptor(const FileDescriptor&) = delete; + FileDescriptor& operator=(const FileDescriptor&) = delete; + + std::string path_; + int fd_; + size_t size_; +}; + +class MMapMapping : public FileMapping { + public: + MMapMapping(void* data, size_t data_size) + : FileMapping( + absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)) {} + + ~MMapMapping() override { + if (::munmap(const_cast<uint8_t*>(data_.data()), data_.size()) != 0) { + LOG(WARNING) << "Unable to unmap file: " << strerror(errno); + } + } +}; + +} // namespace + +// static +StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) { + IREE_TRACE_SCOPE0("FileMapping::Open"); + + // Open the file for reading. Note that we only need to keep it open long + // enough to map it and we can close the descriptor after that. + ASSIGN_OR_RETURN(auto file, FileDescriptor::OpenRead(std::move(path))); + + // Map the file from the file descriptor. + void* data = + ::mmap(nullptr, file->size(), PROT_READ, MAP_SHARED, file->fd(), 0); + if (data == MAP_FAILED) { + return UnavailableErrorBuilder(IREE_LOC) + << "Mapping failed on file (ensure uncompressed): " << file->path(); + } + + return make_ref<MMapMapping>(data, file->size()); +} + +} // namespace iree + +#endif // IREE_PLATFORM_*
diff --git a/base/internal/file_mapping_win32.cc b/base/internal/file_mapping_win32.cc new file mode 100644 index 0000000..9cdead8 --- /dev/null +++ b/base/internal/file_mapping_win32.cc
@@ -0,0 +1,98 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "base/file_mapping.h" +#include "base/internal/file_handle_win32.h" +#include "base/target_platform.h" +#include "base/tracing.h" + +#if defined(IREE_PLATFORM_WINDOWS) + +#include <windows.h> + +namespace iree { + +namespace { + +class Win32FileMapping : public FileMapping { + public: + Win32FileMapping(HANDLE mapping_handle, void* data, size_t data_size) + : FileMapping( + absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)), + mapping_handle_(mapping_handle) {} + + ~Win32FileMapping() override { + if (!data_.empty()) { + if (::UnmapViewOfFile(data_.data()) == FALSE) { + LOG(WARNING) << "Unable to unmap file: " << GetLastError(); + } + data_ = {}; + } + if (mapping_handle_) { + ::CloseHandle(mapping_handle_); + mapping_handle_ = nullptr; + } + } + + private: + HANDLE mapping_handle_; +}; + +} // namespace + +// static +StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) { + IREE_TRACE_SCOPE0("FileMapping::Open"); + + // Open the file for reading. Note that we only need to keep it open long + // enough to map it and we can close the descriptor after that. + ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path), + FILE_FLAG_RANDOM_ACCESS)); + + HANDLE mapping_handle = ::CreateFileMappingA( + /*hFile=*/file->handle(), /*lpFileMappingAttributes=*/nullptr, + /*flProtect=*/PAGE_READONLY, /*dwMaximumSizeHigh=*/0, + /*dwMaximumSizeLow=*/0, /*lpName=*/nullptr); + if (!mapping_handle) { + return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) + << "Failed to create mapping on file (ensure uncompressed): " + << file->path(); + } + + void* data = + ::MapViewOfFileEx(/*hFileMappingObject=*/mapping_handle, + /*dwDesiredAccess=*/FILE_MAP_READ, + /*dwFileOffsetHigh=*/0, /*dwFileOffsetLow=*/0, + /*dwNumberOfBytesToMap=*/0, /*lpBaseAddress=*/nullptr); + if (!data) { + DWORD map_view_error = GetLastError(); + ::CloseHandle(mapping_handle); + return Win32ErrorToCanonicalStatusBuilder(map_view_error, IREE_LOC) + << "Failed to map view of file: " << file->path(); + } + + auto result = make_ref<Win32FileMapping>(mapping_handle, data, file->size()); + + // NOTE: file mappings hold references to the file, so we don't need to keep + // the file around any longer than this function. + file.reset(); + + return result; +} + +} // namespace iree + +#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/init_internal.cc b/base/internal/init_internal.cc new file mode 100644 index 0000000..0f0ddeb --- /dev/null +++ b/base/internal/init_internal.cc
@@ -0,0 +1,110 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/init_internal.h" + +#include <string.h> + +#include <set> + +#include "absl/flags/parse.h" + +namespace iree { + +static Initializer::NameMap* static_name_map = nullptr; + +struct Initializer::InitializerData { + Initializer* initializer_obj; + std::set<std::string> dependency_names; + + InitializerData() : initializer_obj(nullptr) {} + explicit InitializerData(Initializer* i) : initializer_obj(i) {} +}; + +Initializer::DependencyRegisterer::DependencyRegisterer( + const char* name, Initializer* initializer, const Dependency& dependency) { + NameMap* name_map = InitializerNameMap(); + + // Insert 'dependency' into the 'dependency_names' set for 'initializer'. + InitializerData* initializer_data = &(*name_map)[name]; + initializer_data->dependency_names.insert(dependency.name); + + // Ensure that 'dependency' exists in the map. + InitializerData* dependency_data = &(*name_map)[dependency.name]; + dependency_data->initializer_obj = dependency.initializer; +} + +Initializer::Initializer(const char* name, InitializerFunc function) + : name_(name), function_(function), done_(false) { + // Register this Initializer instance (wrapped by an InitializerData) within + // the static name map. + NameMap* name_map = InitializerNameMap(); + InitializerData* initializer_data = &(*name_map)[name]; + initializer_data->initializer_obj = this; +} + +void Initializer::RunInitializers() { + // Run each registered Initializer, in lexicographic order of their names. + // Initializer dependencies will be run first as needed. + NameMap* name_map = InitializerNameMap(); + for (auto& p : *name_map) { + RunInitializer(&p.second); + } +} + +void Initializer::Require() { + NameMap* name_map = InitializerNameMap(); + InitializerData* initializer_data = &(name_map->find(name_)->second); + RunInitializer(initializer_data); +} + +Initializer::NameMap* Initializer::InitializerNameMap() { + if (static_name_map == nullptr) { + static_name_map = new Initializer::NameMap; + } + return static_name_map; +} + +void Initializer::RunInitializer(InitializerData* initializer_data) { + if (initializer_data->initializer_obj->done_) { + return; + } + + // Run Initializer dependencies first. + NameMap* name_map = InitializerNameMap(); + for (const auto& dependency_name : initializer_data->dependency_names) { + auto dep_init = name_map->find(dependency_name); + RunInitializer(&dep_init->second); + } + + // Finally run the Initializer itself. + initializer_data->initializer_obj->function_(); + initializer_data->initializer_obj->done_ = true; +} + +void InitializeEnvironment(int* argc, char*** argv) { + auto positional_args = absl::ParseCommandLine(*argc, *argv); + if (positional_args.size() < *argc) { + // Edit the passed argument refs to only include positional args. + *argc = positional_args.size(); + for (int i = 0; i < *argc; ++i) { + (*argv)[i] = positional_args[i]; + } + (*argv)[*argc + 1] = nullptr; + } + + IREE_RUN_MODULE_INITIALIZERS(); +} + +} // namespace iree
diff --git a/base/internal/init_internal.h b/base/internal/init_internal.h new file mode 100644 index 0000000..bbf2591 --- /dev/null +++ b/base/internal/init_internal.h
@@ -0,0 +1,110 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_INIT_INTERNAL_H_ +#define IREE_BASE_INTERNAL_INIT_INTERNAL_H_ + +#include <map> +#include <string> + +#include "base/target_platform.h" + +namespace iree { + +// A static instance of this class is declared for each piece of initialization +// code using the initializer macros. +class Initializer { + public: + typedef void (*InitializerFunc)(); + + Initializer(const char* name, InitializerFunc function); + + // Runs all registered initializers that have not yet run. + // The initializers are invoked in lexicographically increasing order by name, + // except as necessary to satisfy dependencies. + // + // This is normally called by InitializeEnvironment(), so application code + // typically should not call it directly. + static void RunInitializers(); + + // Runs this initializer if it has not yet run, including any dependencies. + void Require(); + + struct Dependency { + Dependency(const char* n, Initializer* i) : name(n), initializer(i) {} + const char* const name; + Initializer* const initializer; + }; + + // A static instance of this class is declared for each piece of + // initializer ordering definition. + struct DependencyRegisterer { + DependencyRegisterer(const char* name, Initializer* initializer, + const Dependency& dependency); + }; + + struct InitializerData; + typedef std::map<std::string, InitializerData> NameMap; + + private: + static NameMap* InitializerNameMap(); + static void RunInitializer(InitializerData* initializer_data); + + const std::string name_; + InitializerFunc function_; + bool done_; +}; + +// In iree/base/init.h: +void InitializeEnvironment(int* argc, char*** argv); + +} // namespace iree + +#define IREE_DECLARE_MODULE_INITIALIZER(name) \ + extern ::iree::Initializer iree_initializer_##name + +#define IREE_REGISTER_MODULE_INITIALIZER(name, body) \ + static void iree_init_##name() { body; } \ + ::iree::Initializer iree_initializer_##name(#name, iree_init_##name) + +#define IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) \ + namespace { \ + static ::iree::Initializer::DependencyRegisterer \ + iree_initializer_dependency_##name1##_##name2( \ + #name2, &iree_initializer_##name2, \ + ::iree::Initializer::Dependency(#name1, &iree_initializer_##name1)); \ + } + +#define IREE_REQUIRE_MODULE_INITIALIZED(name) \ + do { \ + IREE_DECLARE_MODULE_INITIALIZER(name); \ + iree_initializer_##name.Require(); \ + } while (0) + +#define IREE_RUN_MODULE_INITIALIZERS() \ + do { \ + ::iree::Initializer::RunInitializers(); \ + } while (0) + +#if !defined(IREE_COMPILER_MSVC) +#define IREE_ATTRIBUTE_USED __attribute__((used)) +#else +#define IREE_ATTRIBUTE_USED +#endif // IREE_COMPILER_MSVC + +#define IREE_REQUIRE_MODULE_LINKED(name) \ + IREE_ATTRIBUTE_USED static ::iree::Initializer* iree_module_ref_##name = \ + &iree_initializer_##name + +#endif // IREE_BASE_INTERNAL_INIT_INTERNAL_H_
diff --git a/base/internal/logging.cc b/base/internal/logging.cc new file mode 100644 index 0000000..5e4c52e --- /dev/null +++ b/base/internal/logging.cc
@@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/logging.h" + +#include <string> + +#include "absl/flags/flag.h" + +ABSL_FLAG(int, iree_minloglevel, 0, + "Minimum logging level. 0 = INFO and above."); +ABSL_FLAG(int, iree_v, 0, + "Verbosity level maximum. 1 = VLOG(0-1), 2 = VLOG(0-2)."); +ABSL_FLAG(bool, iree_logtostderr, false, "Logs to stderr instead of stdout"); + +namespace iree { +namespace internal { + +namespace { + +// Parse log level (int64_t) from environment variable (char*). +// Returns true if the value was present and parsed successfully. +bool LogLevelStrToInt(const char* iree_env_var_val, int64_t* out_level) { + *out_level = 0; + if (iree_env_var_val == nullptr) { + return false; + } + + std::string min_log_level(iree_env_var_val); + std::istringstream ss(min_log_level); + int64_t level; + if (!(ss >> level)) { + // Invalid vlog level setting, set level to default (0). + return false; + } + + *out_level = level; + return true; +} + +int64_t MinLogLevelFromEnv() { + const char* iree_env_var_val = getenv("IREE_MIN_LOG_LEVEL"); + int64_t level = 0; + if (LogLevelStrToInt(iree_env_var_val, &level)) { + return level; + } + return absl::GetFlag(FLAGS_iree_minloglevel); +} + +int64_t MinVLogLevelFromEnv() { + const char* iree_env_var_val = getenv("IREE_MIN_VLOG_LEVEL"); + int64_t level = 0; + if (LogLevelStrToInt(iree_env_var_val, &level)) { + return level; + } + return absl::GetFlag(FLAGS_iree_v); +} + +} // namespace + +LogMessage::LogMessage(const char* file_name, int line, int severity) + : file_name_(file_name), line_(line), severity_(severity) {} + +LogMessage::~LogMessage() { + // Read the min log level once during the first call to logging. + static int64_t min_log_level = MinLogLevelFromEnv(); + if (ABSL_PREDICT_TRUE(severity_ >= min_log_level)) { + EmitLogMessage(); + } +} + +int64_t LogMessage::MinVLogLevel() { + static int64_t min_vlog_level = MinVLogLevelFromEnv(); + return min_vlog_level; +} + +void LogMessage::EmitLogMessage() { + // TODO(scotttodd): Include current system time + fprintf(absl::GetFlag(FLAGS_iree_logtostderr) ? stderr : stdout, + "%c %s:%d] %s\n", "IWEF"[severity_], file_name_, line_, + str().c_str()); +} + +LogMessageFatal::LogMessageFatal(const char* file, int line) + : LogMessage(file, line, FATAL) {} + +LogMessageFatal::~LogMessageFatal() { + EmitLogMessage(); + + // abort() ensures we don't return (as promised via ATTRIBUTE_NORETURN). + abort(); +} + +} // namespace internal +} // namespace iree
diff --git a/iree/base/internal/logging.h b/base/internal/logging.h similarity index 100% rename from iree/base/internal/logging.h rename to base/internal/logging.h
diff --git a/iree/base/internal/source_location.h b/base/internal/source_location.h similarity index 100% rename from iree/base/internal/source_location.h rename to base/internal/source_location.h
diff --git a/base/internal/status.cc b/base/internal/status.cc new file mode 100644 index 0000000..447ab52 --- /dev/null +++ b/base/internal/status.cc
@@ -0,0 +1,178 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/status.h" + +#include <atomic> +#include <memory> + +#include "absl/base/attributes.h" +#include "absl/debugging/stacktrace.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" + +ABSL_FLAG(bool, iree_status_save_stack_trace, false, + "Save and display the full stack trace of the point of error") + .OnUpdate([]() { + iree::StatusSavesStackTrace( + absl::GetFlag(FLAGS_iree_status_save_stack_trace)); + }); + +namespace iree { + +namespace status_internal { + +ABSL_CONST_INIT std::atomic<bool> iree_save_stack_trace{false}; + +} // namespace status_internal + +bool DoesStatusSaveStackTrace() { + return status_internal::iree_save_stack_trace.load(std::memory_order_relaxed); +} +void StatusSavesStackTrace(bool on_off) { + status_internal::iree_save_stack_trace.store(on_off, + std::memory_order_relaxed); +} + +std::string StatusCodeToString(StatusCode code) { + switch (code) { + case StatusCode::kOk: + return "OK"; + case StatusCode::kCancelled: + return "CANCELLED"; + case StatusCode::kUnknown: + return "UNKNOWN"; + case StatusCode::kInvalidArgument: + return "INVALID_ARGUMENT"; + case StatusCode::kDeadlineExceeded: + return "DEADLINE_EXCEEDED"; + case StatusCode::kNotFound: + return "NOT_FOUND"; + case StatusCode::kAlreadyExists: + return "ALREADY_EXISTS"; + case StatusCode::kPermissionDenied: + return "PERMISSION_DENIED"; + case StatusCode::kUnauthenticated: + return "UNAUTHENTICATED"; + case StatusCode::kResourceExhausted: + return "RESOURCE_EXHAUSTED"; + case StatusCode::kFailedPrecondition: + return "FAILED_PRECONDITION"; + case StatusCode::kAborted: + return "ABORTED"; + case StatusCode::kOutOfRange: + return "OUT_OF_RANGE"; + case StatusCode::kUnimplemented: + return "UNIMPLEMENTED"; + case StatusCode::kInternal: + return "INTERNAL"; + case StatusCode::kUnavailable: + return "UNAVAILABLE"; + case StatusCode::kDataLoss: + return "DATA_LOSS"; + default: + return ""; + } +} + +Status::Status() {} + +Status::Status(StatusCode code, absl::string_view message) { + state_ = absl::make_unique<State>(); + state_->code = code; + state_->message = std::string(message); +} + +Status::Status(const Status& x) { + if (x.ok()) return; + + state_ = absl::make_unique<State>(); + state_->code = x.state_->code; + state_->message = x.state_->message; +} + +Status& Status::operator=(const Status& x) { + if (x.ok()) { + state_ = nullptr; + } else { + state_ = absl::make_unique<State>(); + state_->code = x.state_->code; + state_->message = x.state_->message; + } + return *this; +} + +Status::~Status() {} + +bool Status::ok() const { return state_ == nullptr; } + +StatusCode Status::code() const { + return ok() ? StatusCode::kOk : state_->code; +} + +absl::string_view Status::message() const { + return ok() ? absl::string_view() : absl::string_view(state_->message); +} + +std::string Status::ToString() const { + if (ok()) { + return "OK"; + } + + std::string text; + absl::StrAppend(&text, StatusCodeToString(state_->code), ": ", + state_->message); + // TODO(scotttodd): Payloads (stack traces) + return text; +} + +void Status::IgnoreError() const { + // no-op +} + +bool Status::EqualsSlow(const Status& a, const Status& b) { + if (a.code() != b.code()) return false; + if (a.message() != b.message()) return false; + // TODO(scotttodd): Payloads + return true; +} + +bool operator==(const Status& lhs, const Status& rhs) { + return lhs.state_ == rhs.state_ || Status::EqualsSlow(lhs, rhs); +} + +bool operator!=(const Status& lhs, const Status& rhs) { return !(lhs == rhs); } + +std::ostream& operator<<(std::ostream& os, const Status& x) { + os << x.ToString(); + return os; +} + +Status OkStatus() { return Status(); } + +Status Annotate(const Status& s, absl::string_view msg) { + if (s.ok() || msg.empty()) return s; + + absl::string_view new_msg = msg; + std::string annotated; + if (!s.message().empty()) { + absl::StrAppend(&annotated, s.message(), "; ", msg); + new_msg = annotated; + } + Status result(s.code(), new_msg); + // TODO(scotttodd): Copy payload(s) into the new Status + return result; +} + +} // namespace iree
diff --git a/base/internal/status.h b/base/internal/status.h new file mode 100644 index 0000000..cc1d206 --- /dev/null +++ b/base/internal/status.h
@@ -0,0 +1,130 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_H_ +#define IREE_BASE_INTERNAL_STATUS_H_ + +#include <atomic> +#include <string> + +#include "absl/base/attributes.h" +#include "absl/flags/flag.h" +#include "absl/strings/string_view.h" +#include "base/internal/logging.h" + +ABSL_DECLARE_FLAG(bool, iree_status_save_stack_trace); + +namespace iree { + +// True if Status objects will capture stack traces on init for non-ok Statuses. +bool DoesStatusSaveStackTrace(); + +// Enables/disables status stack trace saving. This is global for the process. +// While useful for debugging, stack traces can impact performance severely. +void StatusSavesStackTrace(bool on_off); + +enum class StatusCode : int { + kOk = 0, + kCancelled = 1, + kUnknown = 2, + kInvalidArgument = 3, + kDeadlineExceeded = 4, + kNotFound = 5, + kAlreadyExists = 6, + kPermissionDenied = 7, + kResourceExhausted = 8, + kFailedPrecondition = 9, + kAborted = 10, + kOutOfRange = 11, + kUnimplemented = 12, + kInternal = 13, + kUnavailable = 14, + kDataLoss = 15, + kUnauthenticated = 16, + kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 +}; + +std::string StatusCodeToString(StatusCode code); + +class ABSL_MUST_USE_RESULT Status; + +// A Status value can be either OK or not-OK +// * OK indicates that the operation succeeded. +// * A not-OK value indicates that the operation failed and contains details +// about the error. +class Status final { + public: + // Creates an OK status with no message. + Status(); + + // Creates a status with the specified code and error message. + Status(StatusCode code, absl::string_view message); + + Status(const Status&); + Status& operator=(const Status& x); + + ~Status(); + + // Returns true if the Status is OK. + ABSL_MUST_USE_RESULT bool ok() const; + + // Returns the error code. + StatusCode code() const; + + // Returns the error message. Note: prefer ToString() for debug logging. + // This message rarely describes the error code. It is not unusual for the + // error message to be the empty string. + absl::string_view message() const; + + // Return a combination of the error code name and message. + std::string ToString() const; + + // Compatibility with upstream API. Equiv to ToString(). + std::string error_message() const { return ToString(); } + + friend bool operator==(const Status&, const Status&); + friend bool operator!=(const Status&, const Status&); + + // Ignores any errors, potentially suppressing complaints from any tools. + void IgnoreError() const; + + private: + static bool EqualsSlow(const Status& a, const Status& b); + + struct State { + StatusCode code; + std::string message; + }; + // OK status has a nullptr state_. Otherwise, 'state_' points to + // a 'State' structure containing the error code and message(s). + std::unique_ptr<State> state_; +}; + +// Returns an OK status, equivalent to a default constructed instance. +Status OkStatus(); + +// Prints a human-readable representation of `x` to `os`. +std::ostream& operator<<(std::ostream& os, const Status& x); + +// Returns a Status that is identical to `s` except that the message() +// has been augmented by adding `msg` to the end of the original message. +Status Annotate(const Status& s, absl::string_view msg); + +#define CHECK_OK(val) CHECK_EQ(::iree::OkStatus(), (val)) +#define QCHECK_OK(val) QCHECK_EQ(::iree::OkStatus(), (val)) +#define DCHECK_OK(val) DCHECK_EQ(::iree::OkStatus(), (val)) + +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUS_H_
diff --git a/base/internal/status_builder.cc b/base/internal/status_builder.cc new file mode 100644 index 0000000..d5bf924 --- /dev/null +++ b/base/internal/status_builder.cc
@@ -0,0 +1,140 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/status_builder.h" + +#include <cstdio> + +#include "base/internal/status_errors.h" + +namespace iree { + +StatusBuilder::StatusBuilder(const Status& original_status, + SourceLocation location) + : status_(original_status), loc_(location) {} + +StatusBuilder::StatusBuilder(Status&& original_status, SourceLocation location) + : status_(original_status), loc_(location) {} + +StatusBuilder::StatusBuilder(const StatusBuilder& sb) + : status_(sb.status_), loc_(sb.loc_), message_(sb.message_) {} + +StatusBuilder::StatusBuilder(StatusCode code, SourceLocation location) + : status_(code, ""), loc_(location) {} + +StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { + status_ = sb.status_; + loc_ = sb.loc_; + message_ = sb.message_; + return *this; +} + +StatusBuilder::operator Status() const& { + return StatusBuilder(*this).CreateStatus(); +} +StatusBuilder::operator Status() && { return std::move(*this).CreateStatus(); } + +bool StatusBuilder::ok() const { return status_.ok(); } + +StatusCode StatusBuilder::code() const { return status_.code(); } + +SourceLocation StatusBuilder::source_location() const { return loc_; } + +Status StatusBuilder::CreateStatus() && { + Status result = JoinMessageToStatus(status_, message_); + + // Reset the status after consuming it. + status_ = UnknownError(""); + message_ = ""; + return result; +} + +Status StatusBuilder::JoinMessageToStatus(Status s, absl::string_view msg) { + if (msg.empty()) return s; + return Annotate(s, msg); +} + +std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder) { + return os << static_cast<Status>(builder); +} + +std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder) { + return os << static_cast<Status>(std::move(builder)); +} + +StatusBuilder AbortedErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kAborted, location); +} + +StatusBuilder AlreadyExistsErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kAlreadyExists, location); +} + +StatusBuilder CancelledErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kCancelled, location); +} + +StatusBuilder DataLossErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kDataLoss, location); +} + +StatusBuilder DeadlineExceededErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kDeadlineExceeded, location); +} + +StatusBuilder FailedPreconditionErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kFailedPrecondition, location); +} + +StatusBuilder InternalErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kInternal, location); +} + +StatusBuilder InvalidArgumentErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kInvalidArgument, location); +} + +StatusBuilder NotFoundErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kNotFound, location); +} + +StatusBuilder OutOfRangeErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kOutOfRange, location); +} + +StatusBuilder PermissionDeniedErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kPermissionDenied, location); +} + +StatusBuilder UnauthenticatedErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kUnauthenticated, location); +} + +StatusBuilder ResourceExhaustedErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kResourceExhausted, location); +} + +StatusBuilder UnavailableErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kUnavailable, location); +} + +StatusBuilder UnimplementedErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kUnimplemented, location); +} + +StatusBuilder UnknownErrorBuilder(SourceLocation location) { + return StatusBuilder(StatusCode::kUnknown, location); +} + +} // namespace iree
diff --git a/base/internal/status_builder.h b/base/internal/status_builder.h new file mode 100644 index 0000000..60f3221 --- /dev/null +++ b/base/internal/status_builder.h
@@ -0,0 +1,137 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_BUILDER_H_ +#define IREE_BASE_INTERNAL_STATUS_BUILDER_H_ + +#include "base/internal/status.h" +#include "base/source_location.h" + +namespace iree { + +// Creates a status based on an original_status, but enriched with additional +// information. The builder implicitly converts to Status and StatusOr<T> +// allowing for it to be returned directly. +class ABSL_MUST_USE_RESULT StatusBuilder { + public: + // Creates a `StatusBuilder` based on an original status. + explicit StatusBuilder(const Status& original_status, + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); + explicit StatusBuilder(Status&& original_status, + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); + + // Creates a `StatusBuilder` from a status code. + // A typical user will not specify `location`, allowing it to default to the + // current location. + explicit StatusBuilder(StatusCode code, + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); + + StatusBuilder(const StatusBuilder& sb); + StatusBuilder& operator=(const StatusBuilder& sb); + StatusBuilder(StatusBuilder&&) = default; + StatusBuilder& operator=(StatusBuilder&&) = default; + + // Appends to the extra message that will be added to the original status. + template <typename T> + StatusBuilder& operator<<(const T& value) &; + template <typename T> + StatusBuilder&& operator<<(const T& value) &&; + + // No-op functions that may be added later. + StatusBuilder& LogError() & { return *this; } + StatusBuilder&& LogError() && { return std::move(LogError()); } + StatusBuilder& LogWarning() & { return *this; } + StatusBuilder&& LogWarning() && { return std::move(LogWarning()); } + StatusBuilder& LogInfo() & { return *this; } + StatusBuilder&& LogInfo() && { return std::move(LogInfo()); } + + // Returns true if the Status created by this builder will be ok(). + bool ok() const; + + // Returns the error code for the Status created by this builder. + StatusCode code() const; + + // Returns the source location used to create this builder. + SourceLocation source_location() const; + + // Implicit conversion to Status. + operator Status() const&; + operator Status() &&; + + private: + Status CreateStatus() &&; + + static Status JoinMessageToStatus(Status s, absl::string_view msg); + + // The status that the result will be based on. + Status status_; + + // The location to record if this status is logged. + SourceLocation loc_; + + // The message that will be added to the original status. + std::string message_; +}; + +template <typename T> +StatusBuilder& StatusBuilder::operator<<(const T& value) & { + return *this; +} +template <typename T> +StatusBuilder&& StatusBuilder::operator<<(const T& value) && { + return std::move(operator<<(value)); +} + +// Implicitly converts `builder` to `Status` and write it to `os`. +std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder); +std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder); + +// Each of the functions below creates StatusBuilder with a canonical error. +// The error code of the StatusBuilder matches the name of the function. +StatusBuilder AbortedErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder AlreadyExistsErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder CancelledErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder DataLossErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder DeadlineExceededErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder FailedPreconditionErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder InternalErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder InvalidArgumentErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder NotFoundErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder OutOfRangeErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder PermissionDeniedErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder UnauthenticatedErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder ResourceExhaustedErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder UnavailableErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder UnimplementedErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); +StatusBuilder UnknownErrorBuilder( + SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); + +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUS_BUILDER_H_
diff --git a/base/internal/status_errno.cc b/base/internal/status_errno.cc new file mode 100644 index 0000000..3788bc3 --- /dev/null +++ b/base/internal/status_errno.cc
@@ -0,0 +1,175 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/status_errno.h" + +#include <cerrno> + +#include "absl/strings/str_cat.h" + +namespace iree { + +StatusCode ErrnoToCanonicalCode(int error_number) { + switch (error_number) { + case 0: + return StatusCode::kOk; + case EINVAL: // Invalid argument + case ENAMETOOLONG: // Filename too long + case E2BIG: // Argument list too long + case EDESTADDRREQ: // Destination address required + case EDOM: // Mathematics argument out of domain of function + case EFAULT: // Bad address + case EILSEQ: // Illegal byte sequence + case ENOPROTOOPT: // Protocol not available + case ENOSTR: // Not a STREAM + case ENOTSOCK: // Not a socket + case ENOTTY: // Inappropriate I/O control operation + case EPROTOTYPE: // Protocol wrong type for socket + case ESPIPE: // Invalid seek + return StatusCode::kInvalidArgument; + case ETIMEDOUT: // Connection timed out + case ETIME: // Timer expired + return StatusCode::kDeadlineExceeded; + case ENODEV: // No such device + case ENOENT: // No such file or directory +#ifdef ENOMEDIUM + case ENOMEDIUM: // No medium found +#endif + case ENXIO: // No such device or address + case ESRCH: // No such process + return StatusCode::kNotFound; + case EEXIST: // File exists + case EADDRNOTAVAIL: // Address not available + case EALREADY: // Connection already in progress +#ifdef ENOTUNIQ + case ENOTUNIQ: // Name not unique on network +#endif + return StatusCode::kAlreadyExists; + case EPERM: // Operation not permitted + case EACCES: // Permission denied +#ifdef ENOKEY + case ENOKEY: // Required key not available +#endif + case EROFS: // Read only file system + return StatusCode::kPermissionDenied; + case ENOTEMPTY: // Directory not empty + case EISDIR: // Is a directory + case ENOTDIR: // Not a directory + case EADDRINUSE: // Address already in use + case EBADF: // Invalid file descriptor +#ifdef EBADFD + case EBADFD: // File descriptor in bad state +#endif + case EBUSY: // Device or resource busy + case ECHILD: // No child processes + case EISCONN: // Socket is connected +#ifdef EISNAM + case EISNAM: // Is a named type file +#endif +#ifdef ENOTBLK + case ENOTBLK: // Block device required +#endif + case ENOTCONN: // The socket is not connected + case EPIPE: // Broken pipe +#ifdef ESHUTDOWN + case ESHUTDOWN: // Cannot send after transport endpoint shutdown +#endif + case ETXTBSY: // Text file busy +#ifdef EUNATCH + case EUNATCH: // Protocol driver not attached +#endif + return StatusCode::kFailedPrecondition; + case ENOSPC: // No space left on device +#ifdef EDQUOT + case EDQUOT: // Disk quota exceeded +#endif + case EMFILE: // Too many open files + case EMLINK: // Too many links + case ENFILE: // Too many open files in system + case ENOBUFS: // No buffer space available + case ENODATA: // No message is available on the STREAM read queue + case ENOMEM: // Not enough space + case ENOSR: // No STREAM resources +#ifdef EUSERS + case EUSERS: // Too many users +#endif + return StatusCode::kResourceExhausted; +#ifdef ECHRNG + case ECHRNG: // Channel number out of range +#endif + case EFBIG: // File too large + case EOVERFLOW: // Value too large to be stored in data type + case ERANGE: // Result too large + return StatusCode::kOutOfRange; +#ifdef ENOPKG + case ENOPKG: // Package not installed +#endif + case ENOSYS: // Function not implemented + case ENOTSUP: // Operation not supported + case EAFNOSUPPORT: // Address family not supported +#ifdef EPFNOSUPPORT + case EPFNOSUPPORT: // Protocol family not supported +#endif + case EPROTONOSUPPORT: // Protocol not supported +#ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: // Socket type not supported +#endif + case EXDEV: // Improper link + return StatusCode::kUnimplemented; + case EAGAIN: // Resource temporarily unavailable +#ifdef ECOMM + case ECOMM: // Communication error on send +#endif + case ECONNREFUSED: // Connection refused + case ECONNABORTED: // Connection aborted + case ECONNRESET: // Connection reset + case EINTR: // Interrupted function call +#ifdef EHOSTDOWN + case EHOSTDOWN: // Host is down +#endif + case EHOSTUNREACH: // Host is unreachable + case ENETDOWN: // Network is down + case ENETRESET: // Connection aborted by network + case ENETUNREACH: // Network unreachable + case ENOLCK: // No locks available + case ENOLINK: // Link has been severed +#ifdef ENONET + case ENONET: // Machine is not on the network +#endif + return StatusCode::kUnavailable; + case EDEADLK: // Resource deadlock avoided +#ifdef ESTALE + case ESTALE: // Stale file handle +#endif + return StatusCode::kAborted; + case ECANCELED: // Operation cancelled + return StatusCode::kCancelled; + default: + return StatusCode::kUnknown; + } +} + +Status ErrnoToCanonicalStatus(int error_number, absl::string_view message) { + // TODO(scotttodd): convert error number to a string + return Status(ErrnoToCanonicalCode(error_number), + absl::StrCat(message, ": ", error_number)); +} + +StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number, + absl::string_view message, + SourceLocation location) { + return StatusBuilder(ErrnoToCanonicalStatus(error_number, message), location); +} + +} // namespace iree
diff --git a/base/internal/status_errno.h b/base/internal/status_errno.h new file mode 100644 index 0000000..651e80d --- /dev/null +++ b/base/internal/status_errno.h
@@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_ERRNO_H_ +#define IREE_BASE_INTERNAL_STATUS_ERRNO_H_ + +#include "absl/strings/string_view.h" +#include "base/internal/status.h" +#include "base/internal/statusor.h" +#include "base/source_location.h" + +namespace iree { + +// Returns the code for |error_number|, which should be an |errno| value. +// See https://en.cppreference.com/w/cpp/error/errno_macros and similar refs. +StatusCode ErrnoToCanonicalCode(int error_number); + +// Returns a Status, using a code of `ErrnoToCode(error_number)`, and a +// |message| with the result of `StrError(error_number)` appended. +Status ErrnoToCanonicalStatus(int error_number, absl::string_view message); + +// Returns a StatusBuilder using a status of +// `ErrnoToCanonicalStatus(error_number, message)` and |location|. +StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number, + absl::string_view message, + SourceLocation location); + +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUS_ERRNO_H_
diff --git a/base/internal/status_errors.cc b/base/internal/status_errors.cc new file mode 100644 index 0000000..ae695e5 --- /dev/null +++ b/base/internal/status_errors.cc
@@ -0,0 +1,147 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/status_errors.h" + +namespace iree { + +Status AbortedError(absl::string_view message) { + return Status(StatusCode::kAborted, message); +} + +Status AlreadyExistsError(absl::string_view message) { + return Status(StatusCode::kAlreadyExists, message); +} + +Status CancelledError(absl::string_view message) { + return Status(StatusCode::kCancelled, message); +} + +Status DataLossError(absl::string_view message) { + return Status(StatusCode::kDataLoss, message); +} + +Status DeadlineExceededError(absl::string_view message) { + return Status(StatusCode::kDeadlineExceeded, message); +} + +Status FailedPreconditionError(absl::string_view message) { + return Status(StatusCode::kFailedPrecondition, message); +} + +Status InternalError(absl::string_view message) { + return Status(StatusCode::kInternal, message); +} + +Status InvalidArgumentError(absl::string_view message) { + return Status(StatusCode::kInvalidArgument, message); +} + +Status NotFoundError(absl::string_view message) { + return Status(StatusCode::kNotFound, message); +} + +Status OutOfRangeError(absl::string_view message) { + return Status(StatusCode::kOutOfRange, message); +} + +Status PermissionDeniedError(absl::string_view message) { + return Status(StatusCode::kPermissionDenied, message); +} + +Status ResourceExhaustedError(absl::string_view message) { + return Status(StatusCode::kResourceExhausted, message); +} + +Status UnauthenticatedError(absl::string_view message) { + return Status(StatusCode::kUnauthenticated, message); +} + +Status UnavailableError(absl::string_view message) { + return Status(StatusCode::kUnavailable, message); +} + +Status UnimplementedError(absl::string_view message) { + return Status(StatusCode::kUnimplemented, message); +} + +Status UnknownError(absl::string_view message) { + return Status(StatusCode::kUnknown, message); +} + +bool IsAborted(const Status& status) { + return status.code() == StatusCode::kAborted; +} + +bool IsAlreadyExists(const Status& status) { + return status.code() == StatusCode::kAlreadyExists; +} + +bool IsCancelled(const Status& status) { + return status.code() == StatusCode::kCancelled; +} + +bool IsDataLoss(const Status& status) { + return status.code() == StatusCode::kDataLoss; +} + +bool IsDeadlineExceeded(const Status& status) { + return status.code() == StatusCode::kDeadlineExceeded; +} + +bool IsFailedPrecondition(const Status& status) { + return status.code() == StatusCode::kFailedPrecondition; +} + +bool IsInternal(const Status& status) { + return status.code() == StatusCode::kInternal; +} + +bool IsInvalidArgument(const Status& status) { + return status.code() == StatusCode::kInvalidArgument; +} + +bool IsNotFound(const Status& status) { + return status.code() == StatusCode::kNotFound; +} + +bool IsOutOfRange(const Status& status) { + return status.code() == StatusCode::kOutOfRange; +} + +bool IsPermissionDenied(const Status& status) { + return status.code() == StatusCode::kPermissionDenied; +} + +bool IsResourceExhausted(const Status& status) { + return status.code() == StatusCode::kResourceExhausted; +} + +bool IsUnauthenticated(const Status& status) { + return status.code() == StatusCode::kUnauthenticated; +} + +bool IsUnavailable(const Status& status) { + return status.code() == StatusCode::kUnavailable; +} + +bool IsUnimplemented(const Status& status) { + return status.code() == StatusCode::kUnimplemented; +} + +bool IsUnknown(const Status& status) { + return status.code() == StatusCode::kUnknown; +} + +} // namespace iree
diff --git a/base/internal/status_errors.h b/base/internal/status_errors.h new file mode 100644 index 0000000..fee3f0f --- /dev/null +++ b/base/internal/status_errors.h
@@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_ERRORS_H_ +#define IREE_BASE_INTERNAL_STATUS_ERRORS_H_ + +#include "absl/base/attributes.h" +#include "absl/strings/string_view.h" +#include "base/internal/status.h" + +namespace iree { + +Status AbortedError(absl::string_view message); +Status AlreadyExistsError(absl::string_view message); +Status CancelledError(absl::string_view message); +Status DataLossError(absl::string_view message); +Status DeadlineExceededError(absl::string_view message); +Status FailedPreconditionError(absl::string_view message); +Status InternalError(absl::string_view message); +Status InvalidArgumentError(absl::string_view message); +Status NotFoundError(absl::string_view message); +Status OutOfRangeError(absl::string_view message); +Status PermissionDeniedError(absl::string_view message); +Status ResourceExhaustedError(absl::string_view message); +Status UnauthenticatedError(absl::string_view message); +Status UnavailableError(absl::string_view message); +Status UnimplementedError(absl::string_view message); +Status UnknownError(absl::string_view message); + +ABSL_MUST_USE_RESULT bool IsAborted(const Status& status); +ABSL_MUST_USE_RESULT bool IsAlreadyExists(const Status& status); +ABSL_MUST_USE_RESULT bool IsCancelled(const Status& status); +ABSL_MUST_USE_RESULT bool IsDataLoss(const Status& status); +ABSL_MUST_USE_RESULT bool IsDeadlineExceeded(const Status& status); +ABSL_MUST_USE_RESULT bool IsFailedPrecondition(const Status& status); +ABSL_MUST_USE_RESULT bool IsInternal(const Status& status); +ABSL_MUST_USE_RESULT bool IsInvalidArgument(const Status& status); +ABSL_MUST_USE_RESULT bool IsNotFound(const Status& status); +ABSL_MUST_USE_RESULT bool IsOutOfRange(const Status& status); +ABSL_MUST_USE_RESULT bool IsPermissionDenied(const Status& status); +ABSL_MUST_USE_RESULT bool IsResourceExhausted(const Status& status); +ABSL_MUST_USE_RESULT bool IsUnauthenticated(const Status& status); +ABSL_MUST_USE_RESULT bool IsUnavailable(const Status& status); +ABSL_MUST_USE_RESULT bool IsUnimplemented(const Status& status); +ABSL_MUST_USE_RESULT bool IsUnknown(const Status& status); + +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUS_ERRORS_H_
diff --git a/base/internal/status_macros.h b/base/internal/status_macros.h new file mode 100644 index 0000000..bcae07f --- /dev/null +++ b/base/internal/status_macros.h
@@ -0,0 +1,108 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_MACROS_H_ +#define IREE_BASE_INTERNAL_STATUS_MACROS_H_ + +#include "base/internal/status.h" +#include "base/internal/status_builder.h" +#include "base/internal/statusor.h" +#include "base/source_location.h" + +// Evaluates an expression that produces a `iree::Status`. If the status is not +// ok, returns it from the current function. +#define RETURN_IF_ERROR(expr) \ + STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ + if (iree::status_macro_internal::StatusAdaptorForMacros \ + status_macro_internal_adaptor = {(expr), IREE_LOC}) { \ + } else /* NOLINT */ \ + return status_macro_internal_adaptor.Consume() + +// Executes an expression `rexpr` that returns a `iree::StatusOr<T>`. On OK, +// moves its value into the variable defined by `lhs`, otherwise returns +// from the current function. +#define ASSIGN_OR_RETURN(...) \ + STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \ + (__VA_ARGS__) + +// ================================================================= +// == Implementation details, do not rely on anything below here. == +// ================================================================= + +// MSVC incorrectly expands variadic macros, splice together a macro call to +// work around the bug. +#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME +#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ + STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args + +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_)) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ + STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ + STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ + error_expression) +#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ + error_expression) \ + auto statusor = (rexpr); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + iree::StatusBuilder _(std::move(statusor).status(), IREE_LOC); \ + (void)_; /* error_expression is allowed to not use this variable */ \ + return (error_expression); \ + } \ + lhs = std::move(statusor).ValueOrDie() + +// Internal helper for concatenating macro values. +#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +// clang-format off +#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT +// clang-format on + +namespace iree { +namespace status_macro_internal { + +// Provides a conversion to bool so that it can be used inside an if statement +// that declares a variable. +class StatusAdaptorForMacros { + public: + StatusAdaptorForMacros(const Status& status, SourceLocation loc) + : builder_(status, loc) {} + + StatusAdaptorForMacros(Status&& status, SourceLocation loc) + : builder_(std::move(status), loc) {} + + StatusAdaptorForMacros(const StatusBuilder& builder, SourceLocation loc) + : builder_(builder) {} + + StatusAdaptorForMacros(StatusBuilder&& builder, SourceLocation loc) + : builder_(std::move(builder)) {} + + StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; + StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; + + explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); } + + StatusBuilder&& Consume() { return std::move(builder_); } + + private: + StatusBuilder builder_; +}; + +} // namespace status_macro_internal +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUS_MACROS_H_
diff --git a/base/internal/status_matchers.h b/base/internal/status_matchers.h new file mode 100644 index 0000000..f951e17 --- /dev/null +++ b/base/internal/status_matchers.h
@@ -0,0 +1,299 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_MATCHERS_H_ +#define IREE_BASE_INTERNAL_STATUS_MATCHERS_H_ + +#include <memory> + +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +#undef EXPECT_OK +#undef ASSERT_OK +#undef ASSERT_OK_AND_ASSIGN + +namespace iree { + +namespace internal { + +// Implements a gMock matcher that checks that an iree::StaturOr<T> has an OK +// status and that the contained T value matches another matcher. +template <typename T> +class IsOkAndHoldsMatcher + : public ::testing::MatcherInterface<const StatusOr<T> &> { + public: + template <typename MatcherT> + IsOkAndHoldsMatcher(MatcherT &&value_matcher) + : value_matcher_(::testing::SafeMatcherCast<const T &>(value_matcher)) {} + + // From testing::MatcherInterface. + void DescribeTo(std::ostream *os) const override { + *os << "is OK and contains a value that "; + value_matcher_.DescribeTo(os); + } + + // From testing::MatcherInterface. + void DescribeNegationTo(std::ostream *os) const override { + *os << "is not OK or contains a value that "; + value_matcher_.DescribeNegationTo(os); + } + + // From testing::MatcherInterface. + bool MatchAndExplain( + const StatusOr<T> &status_or, + ::testing::MatchResultListener *listener) const override { + if (!status_or.ok()) { + *listener << "which is not OK"; + return false; + } + + ::testing::StringMatchResultListener value_listener; + bool is_a_match = + value_matcher_.MatchAndExplain(status_or.ValueOrDie(), &value_listener); + std::string value_explanation = value_listener.str(); + if (!value_explanation.empty()) { + *listener << absl::StrCat("which contains a value ", value_explanation); + } + + return is_a_match; + } + + private: + const ::testing::Matcher<const T &> value_matcher_; +}; + +// A polymorphic IsOkAndHolds() matcher. +// +// IsOkAndHolds() returns a matcher that can be used to process an IsOkAndHolds +// expectation. However, the value type T is not provided when IsOkAndHolds() is +// invoked. The value type is only inferable when the gUnit framework invokes +// the matcher with a value. Consequently, the IsOkAndHolds() function must +// return an object that is implicitly convertible to a matcher for StatusOr<T>. +// gUnit refers to such an object as a polymorphic matcher, since it can be used +// to match with more than one type of value. +template <typename ValueMatcherT> +class IsOkAndHoldsGenerator { + public: + explicit IsOkAndHoldsGenerator(ValueMatcherT value_matcher) + : value_matcher_(std::move(value_matcher)) {} + + template <typename T> + operator ::testing::Matcher<const StatusOr<T> &>() const { + return ::testing::MakeMatcher(new IsOkAndHoldsMatcher<T>(value_matcher_)); + } + + private: + const ValueMatcherT value_matcher_; +}; + +// Implements a gMock matcher for checking error-code expectations on +// iree::Status objects. +template <typename Enum> +class StatusMatcher : public ::testing::MatcherInterface<const Status &> { + public: + StatusMatcher(Enum code, absl::optional<absl::string_view> message) + : code_(code), message_(message) {} + + // From testing::MatcherInterface. + // + // Describes the expected error code. + void DescribeTo(std::ostream *os) const override { + *os << "error code " << StatusCodeToString(code_); + if (message_.has_value()) { + *os << "::'" << message_.value() << "'"; + } + } + + // From testing::MatcherInterface. + // + // Tests whether |status| has an error code that meets this matcher's + // expectation. If an error message string is specified in this matcher, it + // also tests that |status| has an error message that matches that + // expectation. + bool MatchAndExplain( + const Status &status, + ::testing::MatchResultListener *listener) const override { + if (status.code() != code_) { + *listener << "whose error code is " << StatusCodeToString(status.code()); + return false; + } + if (message_.has_value() && status.message() != message_.value()) { + *listener << "whose error message is '" << status.message() << "'"; + return false; + } + return true; + } + + private: + // Expected error code. + const Enum code_; + + // Expected error message (empty if none expected and verified). + const absl::optional<std::string> message_; +}; + +// Implements a gMock matcher that checks whether a status container (e.g. +// iree::Status or iree::StatusOr<T>) has an OK status. +template <class T> +class IsOkMatcherImpl : public ::testing::MatcherInterface<T> { + public: + IsOkMatcherImpl() = default; + + // From testing::MatcherInterface. + // + // Describes the OK expectation. + void DescribeTo(std::ostream *os) const override { *os << "is OK"; } + + // From testing::MatcherInterface. + // + // Describes the negative OK expectation. + void DescribeNegationTo(std::ostream *os) const override { + *os << "is not OK"; + } + + // From testing::MatcherInterface. + // + // Tests whether |status_container|'s OK value meets this matcher's + // expectation. + bool MatchAndExplain( + const T &status_container, + ::testing::MatchResultListener *listener) const override { + if (!status_container.ok()) { + *listener << "which is not OK"; + return false; + } + return true; + } +}; + +// IsOkMatcherGenerator is an intermediate object returned by iree::IsOk(). +// It implements implicit type-cast operators to supported matcher types: +// Matcher<const Status &> and Matcher<const StatusOr<T> &>. These typecast +// operators create gMock matchers that test OK expectations on a status +// container. +class IsOkMatcherGenerator { + public: + // Type-cast operator for Matcher<const iree::Status &>. + operator ::testing::Matcher<const Status &>() const { + return ::testing::MakeMatcher( + new internal::IsOkMatcherImpl<const Status &>()); + } + + // Type-cast operator for Matcher<const iree::StatusOr<T> &>. + template <class T> + operator ::testing::Matcher<const StatusOr<T> &>() const { + return ::testing::MakeMatcher( + new internal::IsOkMatcherImpl<const StatusOr<T> &>()); + } +}; + +} // namespace internal + +// Returns a gMock matcher that expects an iree::StatusOr<T> object to have an +// OK status and for the contained T object to match |value_matcher|. +// +// Example: +// +// StatusOr<string> raven_speech_result = raven.Speak(); +// EXPECT_THAT(raven_speech_result, IsOkAndHolds(HasSubstr("nevermore"))); +// +// If foo is an object of type T and foo_result is an object of type +// StatusOr<T>, you can write: +// +// EXPECT_THAT(foo_result, IsOkAndHolds(foo)); +// +// instead of: +// +// EXPECT_THAT(foo_result, IsOkAndHolds(Eq(foo))); +template <typename ValueMatcherT> +internal::IsOkAndHoldsGenerator<ValueMatcherT> IsOkAndHolds( + ValueMatcherT value_matcher) { + return internal::IsOkAndHoldsGenerator<ValueMatcherT>(value_matcher); +} + +// Returns a gMock matcher that expects an iree::Status object to have the +// given |code|. +template <typename Enum> +::testing::Matcher<const Status &> StatusIs(Enum code) { + return ::testing::MakeMatcher( + new internal::StatusMatcher<Enum>(code, absl::nullopt)); +} + +// Returns a gMock matcher that expects an iree::Status object to have the +// given |code| and |message|. +template <typename Enum> +::testing::Matcher<const Status &> StatusIs(Enum code, + absl::string_view message) { + return ::testing::MakeMatcher( + new internal::StatusMatcher<Enum>(code, message)); +} + +// Returns an internal::IsOkMatcherGenerator, which may be typecast to a +// Matcher<iree::Status> or Matcher<iree::StatusOr<T>>. These gMock +// matchers test that a given status container has an OK status. +inline internal::IsOkMatcherGenerator IsOk() { + return internal::IsOkMatcherGenerator(); +} + +// Macros for testing the results of functions that return iree::Status or +// iree::StatusOr<T> (for any type T). +#define EXPECT_OK(rexpr) EXPECT_THAT(rexpr, ::iree::IsOk()) +#define ASSERT_OK(rexpr) ASSERT_THAT(rexpr, ::iree::IsOk()) + +// Executes an expression that returns an iree::StatusOr<T>, and assigns the +// contained variable to lhs if the error code is OK. +// If the Status is non-OK, generates a test failure and returns from the +// current function, which must have a void return type. +// +// Example: Assigning to an existing value +// ASSERT_OK_AND_ASSIGN(ValueType value, MaybeGetValue(arg)); +// +// The value assignment example might expand into: +// StatusOr<ValueType> status_or_value = MaybeGetValue(arg); +// ASSERT_OK(status_or_value.status()); +// ValueType value = status_or_value.ValueOrDie(); +#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ + IREE_ASSERT_OK_AND_ASSIGN_IMPL( \ + IREE_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ + rexpr); + +#define IREE_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ + auto statusor = (rexpr); \ + ASSERT_OK(statusor.status()) << statusor.status(); \ + lhs = std::move(statusor.ValueOrDie()) +#define IREE_STATUS_MACROS_CONCAT_NAME(x, y) \ + IREE_STATUS_MACROS_CONCAT_IMPL(x, y) +#define IREE_STATUS_MACROS_CONCAT_IMPL(x, y) x##y + +// Implements the PrintTo() method for iree::StatusOr<T>. This method is +// used by gUnit to print iree::StatusOr<T> objects for debugging. The +// implementation relies on gUnit for printing values of T when a +// iree::StatusOr<T> object is OK and contains a value. +template <typename T> +void PrintTo(const StatusOr<T> &statusor, std::ostream *os) { + if (!statusor.ok()) { + *os << statusor.status(); + } else { + *os << absl::StrCat("OK: ", + ::testing::PrintToString(statusor.ValueOrDie())); + } +} + +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
diff --git a/base/internal/status_win32_errors.cc b/base/internal/status_win32_errors.cc new file mode 100644 index 0000000..f09009c --- /dev/null +++ b/base/internal/status_win32_errors.cc
@@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/status_win32_errors.h" + +#include "absl/strings/str_cat.h" + +#if defined(IREE_PLATFORM_WINDOWS) + +#include <windows.h> + +namespace iree { + +StatusCode Win32ErrorToCanonicalCode(uint32_t error) { + switch (error) { + case ERROR_SUCCESS: + return StatusCode::kOk; + case ERROR_FILE_NOT_FOUND: + case ERROR_PATH_NOT_FOUND: + return StatusCode::kNotFound; + case ERROR_TOO_MANY_OPEN_FILES: + case ERROR_OUTOFMEMORY: + case ERROR_HANDLE_DISK_FULL: + case ERROR_HANDLE_EOF: + return StatusCode::kResourceExhausted; + case ERROR_ACCESS_DENIED: + return StatusCode::kPermissionDenied; + case ERROR_INVALID_HANDLE: + return StatusCode::kInvalidArgument; + case ERROR_NOT_READY: + case ERROR_READ_FAULT: + return StatusCode::kUnavailable; + case ERROR_WRITE_FAULT: + return StatusCode::kDataLoss; + case ERROR_NOT_SUPPORTED: + return StatusCode::kUnimplemented; + default: + return StatusCode::kUnknown; + } +} + +StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error, + SourceLocation location) { + // TODO(benvanik): use FormatMessage; or defer until required? + return StatusBuilder( + Status(Win32ErrorToCanonicalCode(error), absl::StrCat("<TBD>", error)), + location); +} + +} // namespace iree + +#endif // IREE_PLATFORM_WINDOWS
diff --git a/base/internal/status_win32_errors.h b/base/internal/status_win32_errors.h new file mode 100644 index 0000000..33ce74f --- /dev/null +++ b/base/internal/status_win32_errors.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_ +#define IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_ + +#include "absl/strings/string_view.h" +#include "base/internal/statusor.h" +#include "base/source_location.h" +#include "base/target_platform.h" + +#if defined(IREE_PLATFORM_WINDOWS) + +namespace iree { + +// Returns the code for |error| which should be a Win32 error dword. +StatusCode Win32ErrorToCanonicalCode(uint32_t error); + +// Returns a StatusBuilder with a status describing the |error| and |location|. +StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error, + SourceLocation location); + +} // namespace iree + +#endif // IREE_PLATFORM_WINDOWS + +#endif // IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
diff --git a/base/internal/statusor.cc b/base/internal/statusor.cc new file mode 100644 index 0000000..0550241 --- /dev/null +++ b/base/internal/statusor.cc
@@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/internal/statusor.h" + +#include "base/internal/status_errors.h" + +namespace iree { + +namespace internal_statusor { + +void Helper::HandleInvalidStatusCtorArg(Status* status) { + const char* kMessage = + "An OK status is not a valid constructor argument to StatusOr<T>"; + LOG(ERROR) << kMessage; + *status = InternalError(kMessage); + abort(); +} + +void Helper::Crash(const Status& status) { + LOG(FATAL) << "Attempting to fetch value instead of handling error " + << status; + abort(); +} + +} // namespace internal_statusor + +} // namespace iree
diff --git a/base/internal/statusor.h b/base/internal/statusor.h new file mode 100644 index 0000000..895a8dd --- /dev/null +++ b/base/internal/statusor.h
@@ -0,0 +1,699 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_INTERNAL_STATUSOR_H_ +#define IREE_BASE_INTERNAL_STATUSOR_H_ + +#include "absl/base/attributes.h" +#include "base/internal/status.h" +#include "base/internal/status_builder.h" + +namespace iree { + +template <typename T> +class ABSL_MUST_USE_RESULT StatusOr; + +namespace internal_statusor { + +template <typename T, typename U> +using IsStatusOrConversionAmbiguous = + absl::disjunction<std::is_constructible<T, StatusOr<U>&>, + std::is_constructible<T, const StatusOr<U>&>, + std::is_constructible<T, StatusOr<U>&&>, + std::is_constructible<T, const StatusOr<U>&&>, + std::is_convertible<StatusOr<U>&, T>, + std::is_convertible<const StatusOr<U>&, T>, + std::is_convertible<StatusOr<U>&&, T>, + std::is_convertible<const StatusOr<U>&&, T>>; + +template <typename T, typename U> +using IsStatusOrConversionAssigmentAmbiguous = + absl::disjunction<IsStatusOrConversionAmbiguous<T, U>, + std::is_assignable<T&, StatusOr<U>&>, + std::is_assignable<T&, const StatusOr<U>&>, + std::is_assignable<T&, StatusOr<U>&&>, + std::is_assignable<T&, const StatusOr<U>&&>>; + +template <typename T, typename U> +struct IsAmbiguousStatusOrForInitialization + : // Strip const-value refs from type and check again, else false_type. + public absl::conditional_t< + std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>, + U>::value, + std::false_type, + IsAmbiguousStatusOrForInitialization< + T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {}; + +template <typename T, typename U> +struct IsAmbiguousStatusOrForInitialization<T, StatusOr<U>> + : public IsStatusOrConversionAmbiguous<T, U> {}; + +template <typename T, typename U> +using IsStatusOrDirectInitializationAmbiguous = absl::disjunction< + std::is_same<StatusOr<T>, absl::remove_cv_t<absl::remove_reference_t<U>>>, + std::is_same<Status, absl::remove_cv_t<absl::remove_reference_t<U>>>, + std::is_same<StatusBuilder, absl::remove_cv_t<absl::remove_reference_t<U>>>, + std::is_same<absl::in_place_t, + absl::remove_cv_t<absl::remove_reference_t<U>>>, + IsAmbiguousStatusOrForInitialization<T, U>>; + +template <typename T, typename U> +using IsStatusOrDirectInitializationValid = absl::disjunction< + // The is_same allows nested status ors to ignore this check iff same type. + std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>, + absl::negation<IsStatusOrDirectInitializationAmbiguous<T, U>>>; + +class Helper { + public: + ABSL_ATTRIBUTE_NORETURN static void HandleInvalidStatusCtorArg(Status*); + ABSL_ATTRIBUTE_NORETURN static void Crash(const Status& status); +}; + +// Construct an instance of T in `p` through placement new, passing Args... to +// the constructor. +// This abstraction is here mostly for the gcc performance fix. +template <typename T, typename... Args> +void PlacementNew(void* p, Args&&... args) { +#if defined(__GNUC__) && !defined(__clang__) + // Teach gcc that 'p' cannot be null, fixing code size issues. + if (p == nullptr) __builtin_unreachable(); +#endif + new (p) T(std::forward<Args>(args)...); +} + +// Helper base class to hold the data and all operations. +// We move all this to a base class to allow mixing with the appropriate +// TraitsBase specialization. +template <typename T> +class StatusOrData { + template <typename U> + friend class StatusOrData; + + public: + StatusOrData() = delete; + + StatusOrData(const StatusOrData& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + StatusOrData(StatusOrData&& other) noexcept { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template <typename U> + explicit StatusOrData(const StatusOrData<U>& other) { + if (other.ok()) { + MakeValue(other.data_); + MakeStatus(); + } else { + MakeStatus(other.status_); + } + } + + template <typename U> + explicit StatusOrData(StatusOrData<U>&& other) { + if (other.ok()) { + MakeValue(std::move(other.data_)); + MakeStatus(); + } else { + MakeStatus(std::move(other.status_)); + } + } + + template <typename... Args> + explicit StatusOrData(absl::in_place_t, Args&&... args) + : data_(std::forward<Args>(args)...) { + MakeStatus(); + } + + explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } + explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } + + explicit StatusOrData(const Status& status) : status_(status) { + EnsureNotOk(); + } + explicit StatusOrData(Status&& status) : status_(status) { EnsureNotOk(); } + + explicit StatusOrData(const StatusBuilder& builder) : status_(builder) { + EnsureNotOk(); + } + explicit StatusOrData(StatusBuilder&& builder) : status_(std::move(builder)) { + EnsureNotOk(); + } + + StatusOrData& operator=(const StatusOrData& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(other.data_); + else + Assign(other.status_); + return *this; + } + + StatusOrData& operator=(StatusOrData&& other) { + if (this == &other) return *this; + if (other.ok()) + Assign(std::move(other.data_)); + else + Assign(std::move(other.status_)); + return *this; + } + + ~StatusOrData() { + if (ok()) { + status_.~Status(); + data_.~T(); + } else { + status_.~Status(); + } + } + + void Assign(const T& value) { + if (ok()) { + data_.~T(); + MakeValue(value); + } else { + MakeValue(value); + status_ = OkStatus(); + } + } + + void Assign(T&& value) { + if (ok()) { + data_.~T(); + MakeValue(std::move(value)); + } else { + MakeValue(std::move(value)); + status_ = OkStatus(); + } + } + + void Assign(const Status& status) { + Clear(); + status_ = status; + EnsureNotOk(); + } + + void Assign(Status&& status) { + Clear(); + status_ = std::move(status); + EnsureNotOk(); + } + + bool ok() const { return status_.ok(); } + + protected: + // status_ will always be active after the constructor. + // Union to be able to initialize exactly how we need without waste. + // Eg. in the copy constructor we use the default constructor of Status in + // the ok() path to avoid an extra Ref call. + union { + Status status_; + }; + + // data_ is active iff status_.ok()==true + struct Dummy {}; + union { + // When T is const, we need some non-const object we can cast to void* for + // the placement new. dummy_ is that object. + Dummy dummy_; + T data_; + }; + + void Clear() { + if (ok()) data_.~T(); + } + + void EnsureOk() const { + if (!ok()) Helper::Crash(status_); + } + + void EnsureNotOk() { + if (ok()) Helper::HandleInvalidStatusCtorArg(&status_); + } + + // Construct the value (data_) through placement new with the passed arg. + template <typename Arg> + void MakeValue(Arg&& arg) { + internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg)); + } + + // Construct the status (status_) through placement new with the passed arg. + template <typename... Args> + void MakeStatus(Args&&... args) { + internal_statusor::PlacementNew<Status>(&status_, + std::forward<Args>(args)...); + } +}; + +// Helper base class to allow implicitly deleted constructors and assignment +// operations in StatusOr. +// TraitsBase will explicitly delete what it can't support and StatusOr will +// inherit that behavior implicitly. +template <bool Copy, bool Move> +struct TraitsBase { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = default; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = default; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase<false, true> { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = default; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = default; +}; + +template <> +struct TraitsBase<false, false> { + TraitsBase() = default; + TraitsBase(const TraitsBase&) = delete; + TraitsBase(TraitsBase&&) = delete; + TraitsBase& operator=(const TraitsBase&) = delete; + TraitsBase& operator=(TraitsBase&&) = delete; +}; + +} // namespace internal_statusor + +// StatusOr<T> is the union of a Status object and a T object. +// +// A StatusOr object either holds a usable value, or an error Status explaining +// why such a value is not present. +template <typename T> +class StatusOr : private internal_statusor::StatusOrData<T>, + private internal_statusor::TraitsBase< + std::is_copy_constructible<T>::value, + std::is_move_constructible<T>::value> { + template <typename U> + friend class StatusOr; + + typedef internal_statusor::StatusOrData<T> Base; + + public: + typedef T element_type; + + // Constructs a new StatusOr with StatusCode::kUnknown status. + explicit StatusOr(); + + // StatusOr<T> is copy constructible/assignable if T is copy constructible. + StatusOr(const StatusOr&) = default; + StatusOr& operator=(const StatusOr&) = default; + + // StatusOr<T> is move constructible/assignable if T is move constructible. + StatusOr(StatusOr&&) = default; + StatusOr& operator=(StatusOr&&) = default; + + // Converting constructors from StatusOr<U>, when T is constructible from U. + // To avoid ambiguity, they are disabled if T is also constructible from + // StatusOr<U>. Explicit iff the corresponding construction of T from U is + // explicit. + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, + std::is_constructible<T, const U&>, + std::is_convertible<const U&, T>, + absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< + T, U>>>::value, + int> = 0> + StatusOr(const StatusOr<U>& other) // NOLINT + : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, + std::is_constructible<T, const U&>, + absl::negation<std::is_convertible<const U&, T>>, + absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< + T, U>>>::value, + int> = 0> + explicit StatusOr(const StatusOr<U>& other) + : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} + + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, + std::is_convertible<U&&, T>, + absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< + T, U>>>::value, + int> = 0> + StatusOr(StatusOr<U>&& other) // NOLINT + : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, + absl::negation<std::is_convertible<U&&, T>>, + absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< + T, U>>>::value, + int> = 0> + explicit StatusOr(StatusOr<U>&& other) + : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} + + // Conversion copy/move assignment operator, T must be constructible and + // assignable from U. Only enable if T cannot be directly assigned from + // StatusOr<U>. + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, + std::is_constructible<T, const U&>, + std::is_assignable<T, const U&>, + absl::negation< + internal_statusor::IsStatusOrConversionAssigmentAmbiguous< + T, U>>>::value, + int> = 0> + StatusOr& operator=(const StatusOr<U>& other) { + this->Assign(other); + return *this; + } + template < + typename U, + absl::enable_if_t< + absl::conjunction< + absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, + std::is_assignable<T, U&&>, + absl::negation< + internal_statusor::IsStatusOrConversionAssigmentAmbiguous< + T, U>>>::value, + int> = 0> + StatusOr& operator=(StatusOr<U>&& other) { + this->Assign(std::move(other)); + return *this; + } + + // Constructs a new StatusOr with the given value. After calling this + // constructor, this->ok() will be true and the contained value may be + // retrieved with ValueOrDie(), operator*(), or operator->(). + StatusOr(const T& value); + + // Constructs a new StatusOr with the given non-ok status. After calling this + // constructor, this->ok() will be false and calls to ValueOrDie() will + // CHECK-fail. + StatusOr(const Status& status); + StatusOr& operator=(const Status& status); + StatusOr(const StatusBuilder& builder); + StatusOr& operator=(const StatusBuilder& builder); + + // Similar to the `const T&` overload. + // + // REQUIRES: T is move constructible. + StatusOr(T&& value); + + // RValue versions of the operations declared above. + StatusOr(Status&& status); + StatusOr& operator=(Status&& status); + StatusOr(StatusBuilder&& builder); + StatusOr& operator=(StatusBuilder&& builder); + + // Constructs the inner value T in-place using the provided args, using the + // T(args...) constructor. + template <typename... Args> + explicit StatusOr(absl::in_place_t, Args&&... args); + template <typename U, typename... Args> + explicit StatusOr(absl::in_place_t, std::initializer_list<U> ilist, + Args&&... args); + + // Constructs the inner value T in-place using the provided args, using the + // T(U) (direct-initialization) constructor. Only valid if T can be + // constructed from a U. Can accept move or copy constructors. Explicit it + // U is not convertible to T. To avoid ambiguity, this is disabled if U is + // a StatusOr<J>, where J is convertible to T. + template < + typename U = T, + absl::enable_if_t< + absl::conjunction< + internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>, + std::is_constructible<T, U&&>, + std::is_convertible<U&&, T>>::value, + int> = 0> + StatusOr(U&& u) // NOLINT + : StatusOr(absl::in_place, std::forward<U>(u)) {} + + template < + typename U = T, + absl::enable_if_t< + absl::conjunction< + internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>, + std::is_constructible<T, U&&>, + absl::negation<std::is_convertible<U&&, T>>>::value, + int> = 0> + explicit StatusOr(U&& u) // NOLINT + : StatusOr(absl::in_place, std::forward<U>(u)) {} + + // Returns this->ok() + explicit operator bool() const { return ok(); } + + // Returns this->status().ok() + ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); } + + // Returns a reference to our status. If this contains a T, then + // returns OkStatus(). + const Status& status() const&; + Status status() &&; + + // Returns a reference to our current value, or CHECK-fails if !this->ok(). If + // you have already checked the status using this->ok() or operator bool(), + // then you probably want to use operator*() or operator->() to access the + // current value instead of ValueOrDie(). + const T& ValueOrDie() const&; + T& ValueOrDie() &; + const T&& ValueOrDie() const&&; + T&& ValueOrDie() &&; + + // Returns a reference to the current value. + // + // REQUIRES: this->ok() == true, otherwise the behavior is undefined. + const T& operator*() const&; + T& operator*() &; + const T&& operator*() const&&; + T&& operator*() &&; + + // Returns a pointer to the current value. + // + // REQUIRES: this->ok() == true, otherwise the behavior is undefined. + const T* operator->() const; + T* operator->(); + + // Returns a copy of the current value if this->ok() == true. Otherwise + // returns a default value. + template <typename U> + T value_or(U&& default_value) const&; + template <typename U> + T value_or(U&& default_value) &&; + + // Ignores any errors. This method does nothing except potentially suppress + // complaints from any tools that are checking that errors are not dropped on + // the floor. + void IgnoreError() const; + + private: + using internal_statusor::StatusOrData<T>::Assign; + template <typename U> + void Assign(const StatusOr<U>& other); + template <typename U> + void Assign(StatusOr<U>&& other); +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation details for StatusOr<T> + +template <typename T> +StatusOr<T>::StatusOr() : Base(Status(StatusCode::kUnknown, "")) {} + +template <typename T> +StatusOr<T>::StatusOr(const T& value) : Base(value) {} + +template <typename T> +StatusOr<T>::StatusOr(const Status& status) : Base(status) {} + +template <typename T> +StatusOr<T>::StatusOr(const StatusBuilder& builder) : Base(builder) {} + +template <typename T> +StatusOr<T>& StatusOr<T>::operator=(const Status& status) { + this->Assign(status); + return *this; +} + +template <typename T> +StatusOr<T>& StatusOr<T>::operator=(const StatusBuilder& builder) { + return *this = static_cast<Status>(builder); +} + +template <typename T> +StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {} + +template <typename T> +StatusOr<T>::StatusOr(Status&& status) : Base(std::move(status)) {} + +template <typename T> +StatusOr<T>::StatusOr(StatusBuilder&& builder) : Base(std::move(builder)) {} + +template <typename T> +StatusOr<T>& StatusOr<T>::operator=(Status&& status) { + this->Assign(std::move(status)); + return *this; +} + +template <typename T> +StatusOr<T>& StatusOr<T>::operator=(StatusBuilder&& builder) { + return *this = static_cast<Status>(std::move(builder)); +} + +template <typename T> +template <typename U> +inline void StatusOr<T>::Assign(const StatusOr<U>& other) { + if (other.ok()) { + this->Assign(other.ValueOrDie()); + } else { + this->Assign(other.status()); + } +} + +template <typename T> +template <typename U> +inline void StatusOr<T>::Assign(StatusOr<U>&& other) { + if (other.ok()) { + this->Assign(std::move(other).ValueOrDie()); + } else { + this->Assign(std::move(other).status()); + } +} +template <typename T> +template <typename... Args> +StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args) + : Base(absl::in_place, std::forward<Args>(args)...) {} + +template <typename T> +template <typename U, typename... Args> +StatusOr<T>::StatusOr(absl::in_place_t, std::initializer_list<U> ilist, + Args&&... args) + : Base(absl::in_place, ilist, std::forward<Args>(args)...) {} + +template <typename T> +const Status& StatusOr<T>::status() const& { + return this->status_; +} +template <typename T> +Status StatusOr<T>::status() && { + return ok() ? OkStatus() : std::move(this->status_); +} + +template <typename T> +const T& StatusOr<T>::ValueOrDie() const& { + this->EnsureOk(); + return this->data_; +} + +template <typename T> +T& StatusOr<T>::ValueOrDie() & { + this->EnsureOk(); + return this->data_; +} + +template <typename T> +const T&& StatusOr<T>::ValueOrDie() const&& { + this->EnsureOk(); + return std::move(this->data_); +} + +template <typename T> +T&& StatusOr<T>::ValueOrDie() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template <typename T> +const T& StatusOr<T>::operator*() const& { + this->EnsureOk(); + return this->data_; +} + +template <typename T> +T& StatusOr<T>::operator*() & { + this->EnsureOk(); + return this->data_; +} + +template <typename T> +const T&& StatusOr<T>::operator*() const&& { + this->EnsureOk(); + return std::move(this->data_); +} + +template <typename T> +T&& StatusOr<T>::operator*() && { + this->EnsureOk(); + return std::move(this->data_); +} + +template <typename T> +const T* StatusOr<T>::operator->() const { + this->EnsureOk(); + return &this->data_; +} + +template <typename T> +T* StatusOr<T>::operator->() { + this->EnsureOk(); + return &this->data_; +} + +template <typename T> +template <typename U> +T StatusOr<T>::value_or(U&& default_value) const& { + if (ok()) { + return this->data_; + } + return std::forward<U>(default_value); +} + +template <typename T> +template <typename U> +T StatusOr<T>::value_or(U&& default_value) && { + if (ok()) { + return std::move(this->data_); + } + return std::forward<U>(default_value); +} + +template <typename T> +void StatusOr<T>::IgnoreError() const { + // no-op +} + +} // namespace iree + +#endif // IREE_BASE_INTERNAL_STATUSOR_H_
diff --git a/base/intrusive_list.h b/base/intrusive_list.h new file mode 100644 index 0000000..506c37b --- /dev/null +++ b/base/intrusive_list.h
@@ -0,0 +1,758 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Doubly linked list using element interior storage. +// This has the performance of std::list (that means O(1) on insert and remove) +// but performs no allocations and has better caching behavior. +// +// Elements are maintained in lists by way of IntrusiveListLinks, with each link +// allowing the element to exist in one list simultaneously. In the most simple +// case subclassing IntrusiveLinkBase will let the type be added to a list with +// little boilerplate. If an element must be in more than one list +// simultaneously IntrusiveListLinks can be added as members. +// +// Usage (simple): +// class MySimpleElement : public IntrusiveLinkBase {}; +// IntrusiveList<MySimpleElement> list; +// list.push_back(new MySimpleElement()); +// for (auto element : list) { ... } +// +// Usage (multiple lists): +// class MultiElement { +// public: +// IntrusiveListLink list_link_a; +// IntrusiveListLink list_link_b; +// }; +// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_a)> list_a; +// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_b)> list_b; +// +// By default elements in the list are not retained and must be kept alive +// externally. For automatic memory management there are specializations for +// std::unique_ptr. +// +// Usage (unique_ptr): +// IntrusiveList<std::unique_ptr<MyElement>> list; +// list.push_back(absl::make_unique<MyElement>()); +// std::unique_ptr<MyElement> elm = list.take(list.front()); +// +// This type is thread-unsafe. + +#ifndef IREE_BASE_INTRUSIVE_LIST_H_ +#define IREE_BASE_INTRUSIVE_LIST_H_ + +#include <cstddef> +#include <cstdint> +#include <functional> +#include <iterator> +#include <limits> + +#include "base/logging.h" + +namespace iree { + +// Define to enable extensive checks after each mutation of the intrusive list. +// #define IREE_PARANOID_INTRUSIVE_LIST + +// Storage for the doubly-linked list. +// This is embedded within all elements in an intrusive list. +struct IntrusiveListLink { + IntrusiveListLink* prev = nullptr; + IntrusiveListLink* next = nullptr; + + IntrusiveListLink() = default; + + // Prevent copies. + IntrusiveListLink(const IntrusiveListLink&) = delete; + IntrusiveListLink& operator=(const IntrusiveListLink&) = delete; +}; + +template <class T> +struct IntrusiveLinkBase : public T { + public: + IntrusiveListLink link; +}; + +template <> +struct IntrusiveLinkBase<void> { + public: + IntrusiveListLink link; +}; + +// Base type for intrusive lists. +// This is either used directly when the list is on naked pointers or +// specialized to std::unique_ptr. +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +class IntrusiveListBase { + public: + using self_type = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; + + IntrusiveListBase() = default; + virtual ~IntrusiveListBase() { clear(); } + + // Prevent copies. + IntrusiveListBase(const IntrusiveListBase&) = delete; + IntrusiveListBase& operator=(const IntrusiveListBase&) = delete; + + // Returns true if the list is empty. + // Performance: O(1) + constexpr bool empty() const { return head_ == nullptr; } + + // Returns the total number of items in the list. + // Performance: O(1) + constexpr size_t size() const { return count_; } + + // Returns true if the given item is contained within the list. + // Performance: O(n) + bool contains(T* value) const; + + // Appends the contents of the given list to this one. + // The |other_list| is cleared. + // Performance: O(1) + void merge_from(self_type* other_list); + + // Removes all items from the list. + // Performance: O(n) + void clear(); + + IteratorT begin() const { return IteratorT(head_); } + IteratorT end() const { return IteratorT(nullptr); } + ReverseIteratorT rbegin() const { return ReverseIteratorT(tail_); } + ReverseIteratorT rend() const { return ReverseIteratorT(nullptr); } + + // Returns the next item in the list relative to the given item. + // |value| must exist in the list. + // Performance: O(1) + T* next(T* value) const; + + // Returns the previous item in the list relative to the given item. + // |value| must exist in the list. + // Performance: O(1) + T* previous(T* value) const; + + // Returns the item at the front of the list, if any. + // Performance: O(1) + T* front() const; + + // Inserts an item at the front of the list. + // Performance: O(1) + void push_front(T* value); + + // Removes the item at the front of the list. + // Performance: O(1) + void pop_front(); + + // Returns the item at the back of the list, if any. + // Performance: O(1) + T* back() const; + + // Inserts an item at the back of the list. + // Performance: O(1) + void push_back(T* value); + + // Removes the item at the back of the list. + // Performance: O(1) + void pop_back(); + + // Inserts an item into the list before the given iterator. + // Performance: O(1) + void insert(const IteratorT& it, T* value) { return insert(*it, value); } + void insert(T* position, T* value); + + // Erases the given item from the list. + // Returns the item following the erased item, if any. + // Performance: O(1) + T* erase(T* value); + + // Erases the item from the list at the given iterator. + // Performance: O(1) + IteratorT erase(const IteratorT& it); + ReverseIteratorT erase(const ReverseIteratorT& it); + + // Replaces the item with a new item at the same position. + // |new_value| must not be contained in any list. + // Performance: O(1) + void replace(T* old_value, T* new_value); + + // Sorts the list with the given comparison function. + // The sort function is the same as used by std::sort. + // + // Uses merge sort O(N log N) using the algorithm described here: + // http://www.chiark.greenend.org.uk/~sgtatham/algorithms/listsort.html + void sort(bool (*compare_fn)(T* a, T* b)); + + protected: + // Called when an item is added to the list. + virtual void OnAdd(T* value) {} + // Called when an item is removed from the list. + virtual void OnRemove(T* value) {} + // Called when an item is removed and deallocated. + virtual void OnDeallocate(T* value) {} + + // Performs expensive correctness checks on the list structure. It's too slow + // to use in normal builds (even dbg), so it should only be used when there's + // a suspected issue with an intrusive list. Define + // IREE_PARANOID_INTRUSIVE_LIST to enable. + void CheckCorrectness() const; + + IntrusiveListLink* head_ = nullptr; + IntrusiveListLink* tail_ = nullptr; + size_t count_ = 0; +}; + +// Basic iterator for an IntrusiveList. +template <typename T, size_t kOffset, bool kForward> +class IntrusiveListIterator + : public std::iterator<std::input_iterator_tag, int> { + public: + using self_type = IntrusiveListIterator<T, kOffset, kForward>; + + explicit IntrusiveListIterator(IntrusiveListLink* current) + : current_(current) {} + IntrusiveListIterator& operator++(); + self_type operator++(int); + self_type& operator--(); + self_type operator--(int); + bool operator==(const self_type& rhs) const; + bool operator!=(const self_type& rhs) const; + T* operator*() const; + + protected: + IntrusiveListLink* current_; +}; + +// Specialized IntrusiveListBase used for unreferenced naked pointers. +// This very thinly wraps the base type and does no special memory management. +template <typename T, size_t kOffset> +class IntrusiveListUnrefBase + : public IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>, + IntrusiveListIterator<T, kOffset, false>, + kOffset> { + public: + using IteratorT = IntrusiveListIterator<T, kOffset, true>; + using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>; + using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; + + using base_list::clear; + + // Removes all items from the list and calls the given deleter function for + // each of them. The built-in OnDeallocate will not be used. + // Performance: O(n) + void clear(const std::function<void(T*)>& deleter); + + private: + using base_list::count_; + using base_list::head_; + using base_list::tail_; +}; + +constexpr size_t kUseDefaultLinkOffset = std::numeric_limits<size_t>::max(); + +// IntrusiveList for raw pointers with a specified offset. +// Use this if there are multiple links within a type. +// +// Usage: +// struct MyType { +// IntrusiveListLink link_a; +// IntrusiveListLink link_b; +// }; +// IntrusiveList<MyType, offsetof(MyType, link_a)> list_a; +// IntrusiveList<MyType, offsetof(MyType, link_b)> list_b; +template <typename T, size_t kOffset = kUseDefaultLinkOffset> +class IntrusiveList : public IntrusiveListUnrefBase<T, kOffset> {}; + +// IntrusiveList for raw pointers. +// Items added to the list will not be owned by the list and must be freed by +// the caller. +// +// Usage: +// struct MyType : public IntrusiveListBase<void> {}; +// IntrusiveList<MyType> list; +// auto* p = new MyType(); +// list.push_back(p); // p is not retained and won't be freed! +// delete p; +template <typename T> +class IntrusiveList<T, kUseDefaultLinkOffset> + : public IntrusiveListUnrefBase<T, offsetof(T, link)> {}; + +// -- implementation -- + +namespace impl { + +// Maps an IntrusiveListLink to its containing type T. +template <typename T, size_t kOffset> +static inline T* LinkToT(IntrusiveListLink* link) { + if (link) { + return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(link) - kOffset); + } else { + return nullptr; + } +} + +// Maps a containing type T to its IntrusiveListLink. +template <typename T, size_t kOffset> +static inline IntrusiveListLink* TToLink(T* value) { + if (value) { + return reinterpret_cast<IntrusiveListLink*>( + reinterpret_cast<uintptr_t>(value) + kOffset); + } else { + return nullptr; + } +} + +} // namespace impl + +template <typename T, size_t kOffset, bool kForward> +IntrusiveListIterator<T, kOffset, kForward>& +IntrusiveListIterator<T, kOffset, kForward>::operator++() { + if (current_) { + current_ = kForward ? current_->next : current_->prev; + } + return *this; +} + +template <typename T, size_t kOffset, bool kForward> +IntrusiveListIterator<T, kOffset, kForward> +IntrusiveListIterator<T, kOffset, kForward>::operator++(int) { + self_type tmp(current_); + operator++(); + return tmp; +} + +template <typename T, size_t kOffset, bool kForward> +IntrusiveListIterator<T, kOffset, kForward>& +IntrusiveListIterator<T, kOffset, kForward>::operator--() { + if (current_) { + current_ = kForward ? current_->prev : current_->next; + } + return *this; +} + +template <typename T, size_t kOffset, bool kForward> +IntrusiveListIterator<T, kOffset, kForward> +IntrusiveListIterator<T, kOffset, kForward>::operator--(int) { + self_type tmp(current_); + operator--(); + return tmp; +} + +template <typename T, size_t kOffset, bool kForward> +bool IntrusiveListIterator<T, kOffset, kForward>::operator==( + const self_type& rhs) const { + return rhs.current_ == current_; +} + +template <typename T, size_t kOffset, bool kForward> +bool IntrusiveListIterator<T, kOffset, kForward>::operator!=( + const self_type& rhs) const { + return !operator==(rhs); +} + +template <typename T, size_t kOffset, bool kForward> +T* IntrusiveListIterator<T, kOffset, kForward>::operator*() const { + return impl::LinkToT<T, kOffset>(current_); +} + +template <typename T, size_t kOffset> +void IntrusiveListUnrefBase<T, kOffset>::clear( + const std::function<void(T*)>& deleter) { + auto* link = head_; + while (link) { + auto* next = link->next; + link->prev = link->next = nullptr; + deleter(impl::LinkToT<T, kOffset>(link)); + link = next; + } + head_ = tail_ = nullptr; + count_ = 0; +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, + kOffset>::CheckCorrectness() const { +#if defined(IREE_PARANOID_INTRUSIVE_LIST) + auto* link = head_; + IntrusiveListLink* previous = nullptr; + size_t actual_count = 0; + while (link) { + ++actual_count; + if (!link->prev) { + DCHECK_EQ(link, head_); + } + if (!link->next) { + DCHECK_EQ(link, tail_); + } + DCHECK_EQ(link->prev, previous); + previous = link; + link = link->next; + } + DCHECK_EQ(actual_count, count_); +#endif // IREE_PARANOID_INTRUSIVE_LIST +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +bool IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::contains( + T* value) const { + if (!value) return false; + // TODO(benvanik): faster way of checking? requires list ptr in link? + auto* needle = impl::TToLink<T, kOffset>(value); + auto* link = head_; + while (link) { + if (link == needle) { + return true; + } + link = link->next; + } + return false; +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::merge_from( + self_type* other_list) { + if (tail_) { + tail_->next = other_list->head_; + } + if (other_list->head_) { + other_list->head_->prev = tail_; + } + if (!head_) { + head_ = other_list->head_; + } + tail_ = other_list->tail_; + + other_list->head_ = nullptr; + other_list->tail_ = nullptr; + + count_ += other_list->count_; + other_list->count_ = 0; +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::clear() { + auto* link = head_; + while (link) { + auto* next = link->next; + link->prev = link->next = nullptr; + OnDeallocate(impl::LinkToT<T, kOffset>(link)); + link = next; + } + head_ = tail_ = nullptr; + count_ = 0; +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::next( + T* value) const { + if (!value) { + return nullptr; + } + auto* link = impl::TToLink<T, kOffset>(value); + return impl::LinkToT<T, kOffset>(link->next); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::previous( + T* value) const { + if (!value) { + return nullptr; + } + auto* link = impl::TToLink<T, kOffset>(value); + return impl::LinkToT<T, kOffset>(link->prev); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::front() + const { + return impl::LinkToT<T, kOffset>(head_); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_front( + T* value) { + DCHECK(value); + auto* link = impl::TToLink<T, kOffset>(value); + DCHECK(!link->next); + DCHECK(!link->prev); + link->next = head_; + link->prev = nullptr; + head_ = link; + if (link->next) { + link->next->prev = link; + } + if (!tail_) { + tail_ = link; + } + ++count_; + OnAdd(value); + CheckCorrectness(); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_front() { + DCHECK(head_); + auto* link = head_; + if (link) { + head_ = head_->next; + link->next = link->prev = nullptr; + if (head_) { + head_->prev = nullptr; + } + if (link == tail_) { + tail_ = nullptr; + } + --count_; + OnDeallocate(impl::LinkToT<T, kOffset>(link)); + } + CheckCorrectness(); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::back() + const { + return impl::LinkToT<T, kOffset>(tail_); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_back( + T* value) { + DCHECK(value); + auto* link = impl::TToLink<T, kOffset>(value); + DCHECK(!link->next); + DCHECK(!link->prev); + link->prev = tail_; + link->next = nullptr; + tail_ = link; + if (link->prev) { + link->prev->next = link; + } + if (!head_) { + head_ = link; + } + ++count_; + OnAdd(value); + CheckCorrectness(); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_back() { + DCHECK(tail_); + auto* link = tail_; + if (link) { + tail_ = tail_->prev; + link->next = link->prev = nullptr; + if (tail_) { + tail_->next = nullptr; + } + if (link == head_) { + head_ = nullptr; + } + --count_; + OnDeallocate(impl::LinkToT<T, kOffset>(link)); + } + CheckCorrectness(); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::insert( + T* position, T* value) { + DCHECK(value); + auto* link = impl::TToLink<T, kOffset>(value); + auto* position_link = impl::TToLink<T, kOffset>(position); + DCHECK(!link->next); + DCHECK(!link->prev); + + if (position_link == head_) { + push_front(value); + } else if (position_link == nullptr) { + push_back(value); + } else { + link->next = position_link; + link->prev = position_link->prev; + position_link->prev->next = link; + position_link->prev = link; + ++count_; + OnAdd(value); + } + CheckCorrectness(); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase(T* value) { + if (!value) { + return nullptr; + } + auto* link = impl::TToLink<T, kOffset>(value); + if (link->prev) { + DCHECK_NE(link, head_); + link->prev->next = link->next; + } else { + DCHECK_EQ(link, head_); + head_ = link->next; + } + if (link->next) { + DCHECK_NE(link, tail_); + link->next->prev = link->prev; + } else { + DCHECK_EQ(link, tail_); + tail_ = link->prev; + } + auto* next = link->next; + link->next = link->prev = nullptr; + --count_; + OnDeallocate(value); + CheckCorrectness(); + return impl::LinkToT<T, kOffset>(next); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +IteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase( + const IteratorT& it) { + return IteratorT(impl::TToLink<T, kOffset>(erase(*it))); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +ReverseIteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT, + kOffset>::erase(const ReverseIteratorT& it) { + return ReverseIteratorT(impl::TToLink<T, kOffset>(erase(*it))); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::replace( + T* old_value, T* new_value) { + DCHECK(old_value); + DCHECK(new_value); + DCHECK_NE(old_value, new_value); + auto* old_link = impl::TToLink<T, kOffset>(old_value); + auto* new_link = impl::TToLink<T, kOffset>(new_value); + new_link->next = old_link->next; + new_link->prev = old_link->prev; + if (new_link->prev) { + new_link->prev->next = new_link; + } else { + head_ = new_link; + } + if (new_link->next) { + new_link->next->prev = new_link; + } else { + tail_ = new_link; + } + old_link->next = old_link->prev = nullptr; + OnAdd(new_value); + OnDeallocate(old_value); + CheckCorrectness(); +} + +template <typename T, typename IteratorT, typename ReverseIteratorT, + size_t kOffset> +void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::sort( + bool (*compare_fn)(T* a, T* b)) { + if (empty()) { + // Empty list no-op. + return; + } + // Repeatedly run until the list is sorted. + int in_size = 1; + while (true) { + IntrusiveListLink* p = head_; + IntrusiveListLink* q = nullptr; + IntrusiveListLink* e = nullptr; + IntrusiveListLink* tail = nullptr; + head_ = nullptr; + tail_ = nullptr; + // Repeatedly merge sublists. + int merge_count = 0; + do { + ++merge_count; + q = p; + // Determine the size of the first part and find the second. + int p_size = 0; + for (int i = 0; i < in_size; ++i) { + ++p_size; + q = q->next; + if (!q) { + break; + } + } + // Merge the two lists (if we have two). + int q_size = in_size; + while (p_size > 0 || (q_size > 0 && q)) { + if (p_size == 0) { + // p is empty; e must come from q. + e = q; + q = q->next; + --q_size; + } else if (q_size == 0 || !q) { + // q is empty; e must come from p. + e = p; + p = p->next; + --p_size; + } else if (compare_fn(impl::LinkToT<T, kOffset>(p), + impl::LinkToT<T, kOffset>(q))) { + // p <= q; e must come from p. + e = p; + p = p->next; + --p_size; + } else { + // q < p; e must come from q. + e = q; + q = q->next; + --q_size; + } + // Append e to the merged list. + if (tail) { + tail->next = e; + } else { + head_ = e; + } + e->prev = tail; + tail = e; + } + p = q; + } while (p); + tail->next = nullptr; + if (merge_count <= 1) { + // List is now sorted; stash and return. + tail_ = tail; + CheckCorrectness(); + return; + } + // Run merge again with larger lists. + in_size *= 2; + } +} + +} // namespace iree + +// Specializations: +#include "base/intrusive_list_ref_ptr.inc" +#include "base/intrusive_list_unique_ptr.inc" + +#endif // IREE_BASE_INTRUSIVE_LIST_H_
diff --git a/base/intrusive_list_ref_ptr.inc b/base/intrusive_list_ref_ptr.inc new file mode 100644 index 0000000..962a1fe --- /dev/null +++ b/base/intrusive_list_ref_ptr.inc
@@ -0,0 +1,174 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/intrusive_list.h" + +#ifndef IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_ +#define IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_ + +#include <cstddef> +#include <iterator> + +#include "base/intrusive_list.h" +#include "base/ref_ptr.h" + +namespace iree { + +// Iterator for an IntrusiveList specialized to ref_ptr. +template <typename T, size_t kOffset, bool kForward> +class IntrusiveListRefPtrIterator + : public std::iterator<std::input_iterator_tag, int> { + public: + using self_type = IntrusiveListRefPtrIterator<T, kOffset, kForward>; + + explicit IntrusiveListRefPtrIterator(IntrusiveListLink* current) + : current_(current) {} + self_type& operator++() { + if (current_) { + current_ = kForward ? current_->next : current_->prev; + } + return *this; + } + self_type operator++(int) { + self_type tmp(current_); + operator++(); + return tmp; + } + self_type& operator--() { + if (current_) { + current_ = kForward ? current_->prev : current_->next; + } + return *this; + } + self_type operator--(int) { + self_type tmp(current_); + operator--(); + return tmp; + } + bool operator==(const self_type& rhs) const { + return rhs.current_ == current_; + } + bool operator!=(const self_type& rhs) const { return !operator==(rhs); } + ref_ptr<T> operator*() const { + return add_ref(impl::LinkToT<T, kOffset>(current_)); + } + + protected: + IntrusiveListLink* current_; +}; + +// Specialized IntrusiveListBase for ref_ptr types. +// This makes the list methods accept/return ref_ptrs and iterate with +// a ref_ptr iterator. +template <typename T, size_t kOffset> +class IntrusiveListRefPtrBase + : private IntrusiveListBase< + T, IntrusiveListRefPtrIterator<T, kOffset, true>, + IntrusiveListRefPtrIterator<T, kOffset, false>, kOffset> { + public: + using IteratorT = IntrusiveListRefPtrIterator<T, kOffset, true>; + using ReverseIteratorT = IntrusiveListRefPtrIterator<T, kOffset, false>; + using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; + + IntrusiveListRefPtrBase() = default; + + using base_list::empty; + using base_list::size; + + using base_list::contains; + bool contains(const ref_ptr<T>& value) const { + return base_list::contains(value.get()); + } + + using base_list::clear; + + using base_list::begin; + using base_list::end; + using base_list::rbegin; + using base_list::rend; + + inline ref_ptr<T> next(const ref_ptr<T>& value) const { + return add_ref(base_list::next(value.get())); + } + inline ref_ptr<T> next(T* value) const { + return add_ref(base_list::next(value)); + } + + inline ref_ptr<T> previous(const ref_ptr<T>& value) const { + return add_ref(base_list::previous(value.get())); + } + inline ref_ptr<T> previous(T* value) const { + return add_ref(base_list::previous(value)); + } + + // Performance: O(1) + inline ref_ptr<T> front() const { + return add_ref(impl::LinkToT<T, kOffset>(head_)); + } + + void push_front(const ref_ptr<T>& value) { + base_list::push_front(value.get()); + } + + using base_list::pop_front; + + // Performance: O(1) + inline ref_ptr<T> back() const { + return add_ref(impl::LinkToT<T, kOffset>(tail_)); + } + + void push_back(const ref_ptr<T>& value) { base_list::push_back(value.get()); } + + using base_list::pop_back; + + void insert(const IteratorT& it, const ref_ptr<T>& value) { + base_list::insert(it, value.get()); + } + + using base_list::erase; + + ref_ptr<T> erase(const ref_ptr<T>& value) { + return add_ref(base_list::erase(value.get())); + } + + void replace(const ref_ptr<T>& old_value, const ref_ptr<T>& new_value) { + base_list::replace(old_value.get(), new_value.get()); + } + void replace(T* old_value, const ref_ptr<T>& new_value) { + base_list::replace(old_value, new_value.get()); + } + + using base_list::sort; + + private: + void OnAdd(T* value) override { value->AddReference(); } + void OnRemove(T* value) override { value->ReleaseReference(); } + void OnDeallocate(T* value) override { value->ReleaseReference(); } + + using base_list::count_; + using base_list::head_; + using base_list::tail_; +}; + +template <typename U, size_t kOffset> +class IntrusiveList<ref_ptr<U>, kOffset> + : public IntrusiveListRefPtrBase<U, kOffset> {}; + +template <typename U> +class IntrusiveList<ref_ptr<U>, kUseDefaultLinkOffset> + : public IntrusiveListRefPtrBase<U, offsetof(U, link)> {}; + +} // namespace iree + +#endif // IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
diff --git a/base/intrusive_list_ref_ptr_test.cc b/base/intrusive_list_ref_ptr_test.cc new file mode 100644 index 0000000..81d4296 --- /dev/null +++ b/base/intrusive_list_ref_ptr_test.cc
@@ -0,0 +1,100 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/intrusive_list.h" +#include "gtest/gtest.h" + +namespace iree { +namespace { + +static int alloc_count = 0; +struct RefCountedType : public RefObject<RefCountedType> { + IntrusiveListLink link; + RefCountedType() { ++alloc_count; } + ~RefCountedType() { --alloc_count; } + static void Deallocate(RefCountedType* value) { delete value; } + using RefObject<RefCountedType>::counter_; +}; + +TEST(IntrusiveListRefPtrTest, PushAndClear) { + alloc_count = 0; + IntrusiveList<ref_ptr<RefCountedType>> list; + EXPECT_EQ(0, alloc_count); + list.push_back(make_ref<RefCountedType>()); + EXPECT_EQ(1, alloc_count); + EXPECT_NE(nullptr, list.front()); + EXPECT_EQ(2, list.front()->counter_); + list.clear(); + EXPECT_EQ(0, alloc_count); +} + +TEST(IntrusiveListRefPtrTest, PushPop) { + alloc_count = 0; + IntrusiveList<ref_ptr<RefCountedType>> list; + list.push_back(make_ref<RefCountedType>()); + EXPECT_EQ(1, alloc_count); + list.push_back(make_ref<RefCountedType>()); + EXPECT_EQ(2, alloc_count); + EXPECT_NE(list.front(), list.back()); + list.pop_back(); + EXPECT_EQ(1, alloc_count); + list.pop_front(); + EXPECT_EQ(0, alloc_count); +} + +TEST(IntrusiveListRefPtrTest, PushErase) { + alloc_count = 0; + IntrusiveList<ref_ptr<RefCountedType>> list; + list.push_back(make_ref<RefCountedType>()); + EXPECT_EQ(1, alloc_count); + EXPECT_NE(nullptr, list.front()); + EXPECT_EQ(2, list.front()->counter_); + auto item = list.front(); + EXPECT_NE(nullptr, item.get()); + EXPECT_EQ(3, list.front()->counter_); + EXPECT_EQ(1, alloc_count); + list.erase(item); + EXPECT_EQ(1, alloc_count); + item.reset(); + EXPECT_EQ(0, alloc_count); +} + +TEST(IntrusiveListRefPtrTest, PushReplace) { + alloc_count = 0; + IntrusiveList<ref_ptr<RefCountedType>> list; + list.push_back(make_ref<RefCountedType>()); + EXPECT_EQ(1, alloc_count); + list.replace(list.front(), make_ref<RefCountedType>()); + EXPECT_EQ(1, alloc_count); + list.clear(); + EXPECT_EQ(0, alloc_count); +} + +TEST(IntrusiveListRefPtrTest, Iteration) { + alloc_count = 0; + IntrusiveList<ref_ptr<RefCountedType>> list; + list.push_back(make_ref<RefCountedType>()); + list.push_back(make_ref<RefCountedType>()); + list.push_back(make_ref<RefCountedType>()); + EXPECT_EQ(3, alloc_count); + for (auto item : list) { + const ref_ptr<RefCountedType>& item_ref = item; + EXPECT_NE(nullptr, item_ref.get()); + } + list.clear(); + EXPECT_EQ(0, alloc_count); +} + +} // namespace +} // namespace iree
diff --git a/base/intrusive_list_test.cc b/base/intrusive_list_test.cc new file mode 100644 index 0000000..e189105 --- /dev/null +++ b/base/intrusive_list_test.cc
@@ -0,0 +1,523 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/intrusive_list.h" + +#include <algorithm> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace { + +using ::testing::ElementsAre; + +struct Item { + size_t some_data_0; + IntrusiveListLink list_a; + size_t some_data_1; + IntrusiveListLink list_b; + size_t some_data_2; + int value; + + static const size_t kToken = 0xDEADBEEF; + explicit Item(int value) + : some_data_0(kToken), + some_data_1(kToken), + some_data_2(kToken), + value(value) {} + bool is_valid() { + return some_data_0 == kToken && some_data_1 == kToken && + some_data_2 == kToken; + } +}; + +template <typename T, size_t V> +std::vector<T*> ExtractItems(const IntrusiveList<T, V>& list) { + std::vector<T*> items; + for (auto* item : list) { + items.push_back(item); + } + return items; +} + +template <typename T, size_t V> +std::vector<int> ExtractValues(const IntrusiveList<T, V>& list) { + std::vector<int> values; + for (auto* item : list) { + values.push_back(item->value); + } + return values; +} + +template <typename T, size_t V> +std::vector<int> ExtractValuesMutable(const IntrusiveList<T, V>& list) { + std::vector<int> values; + for (auto* item : list) { + values.push_back(item->value); + } + return values; +} + +TEST(IntrusiveListTest, PushPopItems) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + EXPECT_TRUE(items.empty()); + EXPECT_EQ(items.size(), 0u); + EXPECT_EQ(items.front(), nullptr); + EXPECT_EQ(items.back(), nullptr); + EXPECT_TRUE(items.begin() == items.end()); + items.push_front(&item1); + EXPECT_FALSE(items.empty()); + EXPECT_EQ(items.size(), 1u); + EXPECT_EQ(items.front(), &item1); + EXPECT_EQ(items.back(), &item1); + EXPECT_FALSE(items.begin() == items.end()); + items.push_front(&item2); + EXPECT_EQ(items.size(), 2u); + EXPECT_EQ(items.front(), &item2); + EXPECT_EQ(items.back(), &item1); + items.push_front(&item3); + EXPECT_EQ(items.size(), 3u); + EXPECT_EQ(items.front(), &item3); + EXPECT_EQ(items.back(), &item1); + EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1)); + + items.push_back(&item4); + EXPECT_EQ(items.size(), 4u); + EXPECT_EQ(items.front(), &item3); + EXPECT_EQ(items.back(), &item4); + EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1, 4)); + + items.pop_front(); + EXPECT_EQ(items.size(), 3u); + EXPECT_EQ(items.front(), &item2); + EXPECT_EQ(items.back(), &item4); + EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1, 4)); + + items.pop_back(); + EXPECT_EQ(items.size(), 2u); + EXPECT_EQ(items.front(), &item2); + EXPECT_EQ(items.back(), &item1); + EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1)); + + items.pop_back(); + items.pop_front(); + EXPECT_TRUE(items.empty()); + EXPECT_EQ(items.size(), 0u); + EXPECT_EQ(items.front(), nullptr); + EXPECT_EQ(items.back(), nullptr); + EXPECT_TRUE(items.begin() == items.end()); + + EXPECT_TRUE(item1.is_valid()); + EXPECT_TRUE(item2.is_valid()); + EXPECT_TRUE(item3.is_valid()); + EXPECT_TRUE(item4.is_valid()); +} + +TEST(IntrusiveListTest, Contains) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + items.push_back(&item1); + items.push_back(&item2); + items.push_back(&item3); + // item4 omitted. + + EXPECT_TRUE(items.contains(&item1)); + EXPECT_TRUE(items.contains(&item2)); + EXPECT_TRUE(items.contains(&item3)); + EXPECT_FALSE(items.contains(&item4)); + + EXPECT_FALSE(items.contains(nullptr)); +} + +TEST(IntrusiveListTest, MergeFrom) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items0; + items0.push_back(&item1); + items0.push_back(&item2); + items0.push_back(&item3); + + IntrusiveList<Item, offsetof(Item, list_a)> items1; + items1.push_back(&item4); + + items0.merge_from(&items1); + EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4)); + EXPECT_TRUE(items1.empty()); +} + +TEST(IntrusiveListTest, MergeFromEmpty) { + IntrusiveList<Item, offsetof(Item, list_a)> items0; + IntrusiveList<Item, offsetof(Item, list_a)> items1; + items0.merge_from(&items1); +} + +TEST(IntrusiveListTest, MergeFromAll) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + IntrusiveList<Item, offsetof(Item, list_a)> items0; + items0.push_back(&item1); + items0.push_back(&item2); + items0.push_back(&item3); + items0.push_back(&item4); + IntrusiveList<Item, offsetof(Item, list_a)> items1; + + // Merge all items from items1 into items0. Shouldn't change anything. + items0.merge_from(&items1); + EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4)); + EXPECT_TRUE(items1.empty()); + + // Merge all items from items0 into items1. Should move everything. + items1.merge_from(&items0); + EXPECT_TRUE(items0.empty()); + EXPECT_THAT(ExtractValues(items1), ElementsAre(1, 2, 3, 4)); +} + +TEST(IntrusiveListTest, Erase) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + items.push_back(&item1); + items.push_back(&item2); + items.push_back(&item3); + items.push_back(&item4); + + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); + items.erase(&item3); + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 4)); + items.erase(&item1); + EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4)); + items.erase(&item4); + EXPECT_THAT(ExtractValues(items), ElementsAre(2)); + items.erase(&item2); + EXPECT_TRUE(items.empty()); + + items.push_back(&item1); + items.push_back(&item2); + items.push_back(&item3); + items.push_back(&item4); + + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); + auto it = items.begin(); + items.erase(it); + EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4)); + it = items.end(); + items.erase(it); + EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4)); + it = items.begin(); + ++it; + items.erase(it); + EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4)); + + it = items.begin(); + it = items.erase(it); + EXPECT_EQ(4, (*it)->value); + EXPECT_THAT(ExtractValues(items), ElementsAre(4)); + it = items.erase(it); + EXPECT_TRUE(items.empty()); + EXPECT_EQ(items.end(), it); +} + +TEST(IntrusiveListTest, MultipleLists) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items_a; + IntrusiveList<Item, offsetof(Item, list_b)> items_b; + items_a.push_back(&item1); + items_a.push_back(&item2); + items_a.push_back(&item3); + items_a.push_back(&item4); + items_b.push_front(&item1); + items_b.push_front(&item2); + items_b.push_front(&item3); + items_b.push_front(&item4); + EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 3, 2, 1)); + items_b.erase(&item3); + EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4)); + EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1)); + items_a.pop_back(); + EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3)); + EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1)); +} + +TEST(IntrusiveListTest, MutableIterator) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + items.push_back(&item4); + items.push_front(&item1); + items.push_front(&item2); + items.push_front(&item3); + + EXPECT_THAT(ExtractValuesMutable(items), ElementsAre(3, 2, 1, 4)); +} + +struct BaseType { + explicit BaseType(int value) : value(value) {} + int value; + IntrusiveListLink base_link; +}; +struct SubType : public BaseType { + explicit SubType(int value) : BaseType(value) {} + IntrusiveListLink sub_link; +}; +TEST(IntrusiveListTest, SimpleType) { + SubType item1(1); + SubType item2(2); + SubType item3(3); + SubType item4(4); + + IntrusiveList<BaseType, offsetof(BaseType, base_link)> items_a; + items_a.push_front(&item1); + items_a.push_front(&item2); + items_a.push_front(&item3); + items_a.push_front(&item4); + EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1)); + + IntrusiveList<SubType, offsetof(SubType, sub_link)> items_b; + items_b.push_back(&item1); + items_b.push_back(&item2); + items_b.push_back(&item3); + items_b.push_back(&item4); + EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4)); +} + +struct AbstractType { + explicit AbstractType(int value) : value(value) {} + virtual ~AbstractType() = default; + virtual int DoSomething() = 0; + int value; + IntrusiveListLink base_link; +}; +struct ImplType : public AbstractType { + explicit ImplType(int value) : AbstractType(value) {} + int DoSomething() override { return value; } + IntrusiveListLink sub_link; +}; + +TEST(IntrusiveListTest, ComplexType) { + ImplType item1(1); + ImplType item2(2); + ImplType item3(3); + ImplType item4(4); + + IntrusiveList<AbstractType, offsetof(AbstractType, base_link)> items_a; + items_a.push_front(&item1); + items_a.push_front(&item2); + items_a.push_front(&item3); + items_a.push_front(&item4); + EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1)); + + IntrusiveList<ImplType, offsetof(ImplType, sub_link)> items_b; + items_b.push_back(&item1); + items_b.push_back(&item2); + items_b.push_back(&item3); + items_b.push_back(&item4); + EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4)); +} + +bool Comparison(Item* a, Item* b) { return a->value < b->value; } + +TEST(IntrusiveListTest, Inserting) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + items.insert(items.end(), &item3); + items.insert(items.begin(), &item1); + items.insert(items.end(), &item4); + + auto pos = std::upper_bound(items.begin(), items.end(), &item2, Comparison); + items.insert(pos, &item2); + + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); +} + +// TODO(benvanik): test reverse iteration. + +TEST(IntrusiveListTest, NextPrevious) { + Item item1(1); + Item item2(2); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + EXPECT_EQ(nullptr, items.previous(nullptr)); + EXPECT_EQ(nullptr, items.next(nullptr)); + + items.push_back(&item1); + EXPECT_EQ(nullptr, items.previous(&item1)); + EXPECT_EQ(nullptr, items.next(&item1)); + + items.push_back(&item2); + EXPECT_EQ(nullptr, items.previous(&item1)); + EXPECT_EQ(&item2, items.next(&item1)); + EXPECT_EQ(&item1, items.previous(&item2)); + EXPECT_EQ(nullptr, items.next(&item2)); +} + +TEST(IntrusiveListTest, Clear) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + + // Empty clear. + items.clear(); + EXPECT_TRUE(items.empty()); + + // 1 item clear. + items.push_back(&item1); + items.clear(); + EXPECT_TRUE(items.empty()); + + // Multi-item clear. + items.push_back(&item1); + items.push_back(&item2); + items.push_back(&item3); + items.push_back(&item4); + items.clear(); + EXPECT_TRUE(items.empty()); +} + +TEST(IntrusiveListTest, ClearDeleter) { + Item item1(1); + Item item2(2); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + + // No-op first. + int delete_count = 0; + items.clear([&](Item* item) { ++delete_count; }); + EXPECT_EQ(0, delete_count); + + // Now with items. + items.push_back(&item1); + items.push_back(&item2); + items.clear([&](Item* item) { ++delete_count; }); + EXPECT_EQ(2, delete_count); + EXPECT_TRUE(items.empty()); +} + +TEST(IntrusiveListTest, Replace) { + Item item1(1); + Item item2(2); + Item item3(3); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + items.push_back(&item1); + items.push_back(&item2); + + items.replace(&item1, &item3); + EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2)); + EXPECT_FALSE(items.contains(&item1)); + items.replace(&item2, &item1); + EXPECT_THAT(ExtractValues(items), ElementsAre(3, 1)); + EXPECT_FALSE(items.contains(&item2)); +} + +TEST(IntrusiveListTest, Sort) { + Item item1(1); + Item item2(2); + Item item3(3); + Item item4(4); + + IntrusiveList<Item, offsetof(Item, list_a)> items; + + // Empty sort. + items.sort([](Item* a, Item* b) { return a->value < b->value; }); + + // Single item sort. + items.clear(); + items.push_back(&item1); + items.sort([](Item* a, Item* b) { return a->value < b->value; }); + EXPECT_THAT(ExtractValues(items), ElementsAre(1)); + + // Already sorted. + items.clear(); + items.push_back(&item1); + items.push_back(&item2); + items.push_back(&item3); + items.push_back(&item4); + items.sort([](Item* a, Item* b) { return a->value < b->value; }); + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); + + // Reverse. + items.clear(); + items.push_back(&item4); + items.push_back(&item3); + items.push_back(&item2); + items.push_back(&item1); + items.sort([](Item* a, Item* b) { return a->value < b->value; }); + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); + + // Random. + items.clear(); + items.push_back(&item2); + items.push_back(&item4); + items.push_back(&item1); + items.push_back(&item3); + items.sort([](Item* a, Item* b) { return a->value < b->value; }); + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); + + // Stability. + Item item1a(1); + Item item2a(2); + items.clear(); + items.push_back(&item2); + items.push_back(&item4); + items.push_back(&item1); + items.push_back(&item3); + items.push_back(&item1a); + items.push_back(&item2a); + items.sort([](Item* a, Item* b) { return a->value <= b->value; }); + EXPECT_THAT(ExtractValues(items), ElementsAre(1, 1, 2, 2, 3, 4)); + auto items_vector = ExtractItems(items); + EXPECT_EQ(&item1, items_vector[0]); + EXPECT_EQ(&item1a, items_vector[1]); + EXPECT_EQ(&item2, items_vector[2]); + EXPECT_EQ(&item2a, items_vector[3]); + items.clear(); +} + +} // namespace +} // namespace iree
diff --git a/base/intrusive_list_unique_ptr.inc b/base/intrusive_list_unique_ptr.inc new file mode 100644 index 0000000..4397265 --- /dev/null +++ b/base/intrusive_list_unique_ptr.inc
@@ -0,0 +1,140 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "base/intrusive_list.h" + +#ifndef IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_ +#define IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_ + +#include <cstddef> +#include <memory> + +#include "base/intrusive_list.h" +#include "base/logging.h" + +namespace iree { + +// Specialized IntrusiveListBase for std::unique_ptr types. +// This makes the list methods accept std::unique_ptrs and contains a special +// take() method that takes ownership of a list item. +template <typename T, size_t kOffset> +class IntrusiveListUniquePtrBase + : private IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>, + IntrusiveListIterator<T, kOffset, false>, + kOffset> { + public: + using IteratorT = IntrusiveListIterator<T, kOffset, true>; + using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>; + using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; + + IntrusiveListUniquePtrBase() = default; + + using base_list::empty; + using base_list::size; + + using base_list::contains; + + using base_list::clear; + + using base_list::begin; + using base_list::end; + using base_list::rbegin; + using base_list::rend; + + using base_list::next; + + using base_list::previous; + + using base_list::front; + + void push_front(std::unique_ptr<T> value) { + base_list::push_front(value.release()); + } + + using base_list::pop_front; + + using base_list::back; + + void push_back(std::unique_ptr<T> value) { + base_list::push_back(value.release()); + } + + using base_list::pop_back; + + void insert(const IteratorT& it, std::unique_ptr<T> value) { + base_list::insert(it, value.release()); + } + + using base_list::erase; + + // Removes an item from the list at the given iterator and transfers ownership + // to the caller. + // Performance: O(1) + std::unique_ptr<T> take(IteratorT& it) { // NOLINT(runtime/references) + return take(*it); + } + + // Removes an item from the list and transfers ownership to the caller. + // Performance: O(1) + std::unique_ptr<T> take(T* value) { + if (!value) { + return {nullptr}; + } + auto* link = impl::TToLink<T, kOffset>(value); + if (link->prev) { + DCHECK_NE(link, head_); + link->prev->next = link->next; + } else { + DCHECK_EQ(link, head_); + head_ = link->next; + } + if (link->next) { + DCHECK_NE(link, tail_); + link->next->prev = link->prev; + } else { + DCHECK_EQ(link, tail_); + tail_ = link->prev; + } + link->next = link->prev = nullptr; + --count_; + base_list::OnRemove(value); + base_list::CheckCorrectness(); + return std::unique_ptr<T>(value); + } + + void replace(T* old_value, std::unique_ptr<T> new_value) { + base_list::replace(old_value, new_value.release()); + } + + using base_list::sort; + + private: + void OnDeallocate(T* value) override { delete value; } + + using base_list::count_; + using base_list::head_; + using base_list::tail_; +}; + +template <typename U, size_t kOffset> +class IntrusiveList<std::unique_ptr<U>, kOffset> + : public IntrusiveListUniquePtrBase<U, kOffset> {}; + +template <typename U> +class IntrusiveList<std::unique_ptr<U>, kUseDefaultLinkOffset> + : public IntrusiveListUniquePtrBase<U, offsetof(U, link)> {}; + +} // namespace iree + +#endif // IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
diff --git a/base/intrusive_list_unique_ptr_test.cc b/base/intrusive_list_unique_ptr_test.cc new file mode 100644 index 0000000..c5e9421 --- /dev/null +++ b/base/intrusive_list_unique_ptr_test.cc
@@ -0,0 +1,84 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/memory/memory.h" +#include "base/intrusive_list.h" +#include "gtest/gtest.h" + +namespace iree { +namespace { + +struct AllocatedType : public IntrusiveLinkBase<void> { + AllocatedType() { ++alloc_count; } + ~AllocatedType() { --alloc_count; } + static int alloc_count; +}; +int AllocatedType::alloc_count = 0; + +TEST(IntrusiveListUniquePtrTest, UniquePtr) { + AllocatedType::alloc_count = 0; + + // Push/clear. + IntrusiveList<std::unique_ptr<AllocatedType>> list; + EXPECT_EQ(0, AllocatedType::alloc_count); + list.push_back(absl::make_unique<AllocatedType>()); + EXPECT_EQ(1, AllocatedType::alloc_count); + EXPECT_NE(nullptr, list.front()); + list.clear(); + EXPECT_EQ(0, AllocatedType::alloc_count); + + // Push/pop. + list.push_back(absl::make_unique<AllocatedType>()); + EXPECT_EQ(1, AllocatedType::alloc_count); + EXPECT_NE(nullptr, list.front()); + for (auto item : list) { + EXPECT_EQ(item, list.front()); + } + list.pop_back(); + EXPECT_EQ(0, AllocatedType::alloc_count); + + // Push/take. + list.push_back(absl::make_unique<AllocatedType>()); + EXPECT_EQ(1, AllocatedType::alloc_count); + EXPECT_NE(nullptr, list.front()); + auto item = list.take(list.front()); + EXPECT_TRUE(list.empty()); + EXPECT_NE(nullptr, item.get()); + EXPECT_EQ(1, AllocatedType::alloc_count); + item.reset(); + EXPECT_EQ(0, AllocatedType::alloc_count); + + // Push/replace. + list.push_back(absl::make_unique<AllocatedType>()); + EXPECT_EQ(1, AllocatedType::alloc_count); + list.replace(list.front(), absl::make_unique<AllocatedType>()); + EXPECT_EQ(1, AllocatedType::alloc_count); + list.clear(); + EXPECT_EQ(0, AllocatedType::alloc_count); + + // Iteration. + list.push_back(absl::make_unique<AllocatedType>()); + list.push_back(absl::make_unique<AllocatedType>()); + list.push_back(absl::make_unique<AllocatedType>()); + EXPECT_EQ(3, AllocatedType::alloc_count); + for (auto item : list) { + AllocatedType* item_ptr = item; + EXPECT_NE(nullptr, item_ptr); + } + list.clear(); + EXPECT_EQ(0, AllocatedType::alloc_count); +} + +} // namespace +} // namespace iree
diff --git a/base/logging.h b/base/logging.h new file mode 100644 index 0000000..b256cd6 --- /dev/null +++ b/base/logging.h
@@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_LOGGING_H_ +#define IREE_BASE_LOGGING_H_ + +// Logging macros live in their own file so that we can use external versions +// as required. +// +// LOG(severity) << ...; +// Logs a message at the given severity. +// Severity: +// INFO Logs information text. +// WARNING Logs a warning. +// ERROR Logs an error. +// FATAL Logs an error and exit(1). +// +// VLOG(level) << ...; +// Logs a verbose message at the given verbosity level. +// +// DVLOG(level) << ...; +// Behaves like `VLOG` in debug mode (i.e. `#ifndef NDEBUG`). +// Otherwise, it compiles away and does nothing. +// +// CHECK(condition) << ...; +// Runtime asserts that the given condition is true even in release builds. +// It's recommended that DCHECK is used instead as too many CHECKs +// can impact performance. +// +// CHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...; +// Runtime assert the specified operation with the given values. +// +// DCHECK(condition) << ...; +// Runtime asserts that the given condition is true only in non-opt builds. +// +// DCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...; +// Runtime assert the specified operation with the given values in non-opt +// builds. +// +// QCHECK(condition) << ...; +// QCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...; +// These behave like `CHECK` but do not print a full stack trace. +// They are useful when problems are definitely unrelated to program flow, +// e.g. when validating user input. + +#ifdef IREE_CONFIG_GOOGLE_INTERNAL +#include "base/google/logging_google.h" +#else +#include "base/internal/logging.h" +#endif // IREE_CONFIG_GOOGLE_INTERNAL + +#endif // IREE_BASE_LOGGING_H_
diff --git a/iree/base/math.h b/base/math.h similarity index 100% rename from iree/base/math.h rename to base/math.h
diff --git a/iree/base/memory.h b/base/memory.h similarity index 100% rename from iree/base/memory.h rename to base/memory.h
diff --git a/base/ref_ptr.h b/base/ref_ptr.h new file mode 100644 index 0000000..996331d --- /dev/null +++ b/base/ref_ptr.h
@@ -0,0 +1,364 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_REF_PTR_H_ +#define IREE_BASE_REF_PTR_H_ + +#include <atomic> +#include <cstdint> +#include <type_traits> +#include <utility> + +#include "absl/base/attributes.h" +#include "base/logging.h" + +namespace iree { + +// Use this to get really verbose refptr logging: +// #define IREE_VERBOSE_REF_PTR + +template <class T> +class ref_ptr; + +// Allocates a new ref_ptr type. +// Like make_unique, but for ref_ptr. +// +// Usage: +// ref_ptr<MyType> p = make_ref<MyType>(1, 2, 3); +template <typename T, typename... Args> +ref_ptr<T> make_ref(Args&&... args) { + return ref_ptr<T>(new T(std::forward<Args>(args)...)); +} + +// Assigns a raw pointer to a ref_ptr without adding a reference. +// +// Usage: +// ref_ptr<MyType> p = assign_ref(new MyType()); +template <typename T> +inline ref_ptr<T> assign_ref(T* value) { + return ref_ptr<T>(value); +} + +// Adds a reference to the given raw pointer. +// +// Usage: +// MyType* raw_ptr = AcquirePointerFromSomewhere(); +// ref_ptr<MyType> p = add_ref(raw_ptr); +template <typename T> +inline ref_ptr<T> add_ref(T* value) { + if (value) ref_ptr_add_ref(value); + return ref_ptr<T>(value); +} + +// Adds a reference to the given ref_ptr. +// +// Usage: +// ref_ptr<MyType> a = make_ref<MyType>(); +// ref_ptr<MyType> p = add_ref(a); +template <typename T> +inline ref_ptr<T> add_ref(const ref_ptr<T>& value) { + if (value.get()) ref_ptr_add_ref(value.get()); + return ref_ptr<T>(value.get()); +} + +// Reference counted pointer container. +// This is modeled on boost::instrusive_ptr in that it requires no +// extra storage over the pointer type and should compile to almost +// no additional code. It also allows us to round-trip object pointers +// through regular pointers, which is critical when having to round-trip +// them through JNI/etc where we can't use things like unique_ptr/shared_ptr. +// +// ref_ptr<Foo> p1(new Foo()); // ref count 1 +// ref_ptr<Foo> p2(p1); // ref count 2 +// p1.reset(); // ref count 1 +// p2.reset(); // ref count 0, deleted +// +// When round-tripping the pointer through external APIs, use release(): +// ref_ptr<Foo> p1(new Foo()); // ref count 1 +// Foo* raw_p = p1.release(); // ref count 1 +// // pass to API +// ref_ptr<Foo> p2(raw_p); // ref count 1 (don't add ref) +// p2.reset(); // ref count 0, deleted +// +// See the boost intrusive_ptr docs for details of behavior: +// http://www.boost.org/doc/libs/1_55_0/libs/smart_ptr/intrusive_ptr.html +// +// ref_ptr manages the target objects in a thread-safe way, though you'll want +// to take care with objects that may have pinned threads for deallocation. If +// you release the last reference to an object on a thread other than what it +// was expecting you're gonna have a bad time. +// +// Compatible only with types that subclass RefObject or implement the following +// methods: +// ref_ptr_add_ref +// ref_ptr_release_ref +template <class T> +class ref_ptr { + private: + typedef ref_ptr this_type; + typedef T* this_type::*unspecified_bool_type; + + public: + // Initializes with nullptr. + ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr() noexcept = default; + + // Initializes with nullptr so that there is no way to create an + // uninitialized ref_ptr. + ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr(std::nullptr_t) noexcept {} // NOLINT + + // Initializes the pointer to the given value. + // The value will not have its reference count incremented (as it is with + // unique_ptr). Use Retain to add to the reference count. + ABSL_ATTRIBUTE_ALWAYS_INLINE explicit ref_ptr(T* p) noexcept : px_(p) {} + + // Decrements the reference count of the owned pointer. + ABSL_ATTRIBUTE_ALWAYS_INLINE ~ref_ptr() noexcept { + if (px_) ref_ptr_release_ref(px_); + } + + // No implicit ref_ptr copying allowed; use add_ref instead. + ref_ptr(const ref_ptr&) noexcept = delete; + ref_ptr& operator=(const ref_ptr&) noexcept = delete; + + // Move support to transfer ownership from one ref_ptr to another. + ref_ptr(ref_ptr&& rhs) noexcept : px_(rhs.release()) {} + ref_ptr& operator=(ref_ptr&& rhs) noexcept { + if (px_ != rhs.px_) { + if (px_) ref_ptr_release_ref(px_); + px_ = rhs.release(); + } + return *this; + } + + // Move support from another compatible type. + template <typename U> + ref_ptr(ref_ptr<U>&& rhs) noexcept : px_(rhs.release()) {} // NOLINT + template <typename U> + ref_ptr& operator=(ref_ptr<U>&& rhs) noexcept { + if (px_ != rhs.get()) { + if (px_) ref_ptr_release_ref(px_); + px_ = rhs.release(); + } + return *this; + } + + // Resets the object to nullptr and decrements the reference count, possibly + // deleting it. + void reset() noexcept { + if (px_) { + ref_ptr_release_ref(px_); + px_ = nullptr; + } + } + + // Releases a pointer. + // Returns the current pointer held by this object without having + // its reference count decremented and resets the ref_ptr to empty. + // Returns nullptr if the ref_ptr holds no value. + // To re-wrap in a ref_ptr use either ref_ptr<T>(value) or assign(). + ABSL_ATTRIBUTE_ALWAYS_INLINE T* release() noexcept { + T* p = px_; + px_ = nullptr; + return p; + } + + // Assigns a pointer. + // The pointer will be accepted by the ref_ptr and its reference count will + // not be incremented. + ABSL_ATTRIBUTE_ALWAYS_INLINE void assign(T* value) noexcept { + reset(); + px_ = value; + } + + // Gets the pointer referenced by this instance. + // operator* and operator-> will assert() if there is no current object. + constexpr T* get() const noexcept { return px_; } + constexpr T& operator*() const noexcept { return *px_; } + constexpr T* operator->() const noexcept { return px_; } + + // Support boolean expression evaluation ala unique_ptr/shared_ptr: + // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool + constexpr operator unspecified_bool_type() const noexcept { + return px_ ? &this_type::px_ : nullptr; + } + // Supports unary expression evaluation. + constexpr bool operator!() const noexcept { return !px_; } + + // Swap support. + void swap(ref_ptr& rhs) { std::swap(px_, rhs.px_); } + + private: + T* px_ = nullptr; +}; + +// Base class for reference counted objects. +// Reference counted objects should be used with the ref_ptr pointer type. +// As reference counting can be tricky always prefer to use unique_ptr and +// avoid this type. Only use this when unique_ptr is not possible, such as +// when round-tripping objects through marshaling boundaries (v8/Java) or +// any objects that may have their lifetime tied to a garbage collected +// object. +// +// Subclasses should protect their dtor so that reference counting must +// be used. +// +// This is designed to avoid the need for extra vtable space or for adding +// methods to the vtable of subclasses. This differs from the boost Pointable +// version of this object. +// Inspiration for this comes from Peter Weinert's Dr. Dobb's article: +// http://www.drdobbs.com/cpp/a-base-class-for-intrusively-reference-c/229218807 +// +// RefObjects are thread safe and may be used with ref_ptrs from multiple +// threads. +// +// Subclasses may implement a custom Delete operator to handle their +// deallocation. It should be thread safe as it may be called from any thread. +// +// Usage: +// class MyRefObject : public RefObject<MyRefObject> { +// public: +// MyRefObject() = default; +// // Optional; can be used to return to pool/etc - must be public: +// static void Delete(MyRefObject* ptr) { +// ::operator delete(ptr); +// } +// }; +template <class T> +class RefObject { + static_assert(!std::is_array<T>::value, "T must not be an array"); + + // value is true if a static Delete(T*) function is present. + struct has_custom_deleter { + template <typename C> + static auto Test(C* p) -> decltype(C::Delete(nullptr), std::true_type()); + template <typename> + static std::false_type Test(...); + static constexpr bool value = + std::is_same<std::true_type, decltype(Test<T>(nullptr))>::value; + }; + + template <typename V, bool has_custom_deleter> + struct delete_thunk { + static void Delete(V* p) { + auto ref_obj = static_cast<RefObject<V>*>(p); + int previous_count = ref_obj->counter_.fetch_sub(1); +#ifdef IREE_VERBOSE_REF_PTR + LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now " + << previous_count - 1 + << (previous_count == 1 ? " DEAD (CUSTOM)" : ""); +#endif // IREE_VERBOSE_REF_PTR + if (previous_count == 1) { + // We delete type T pointer here to avoid the need for a virtual dtor. + V::Delete(p); + } + } + }; + + template <typename V> + struct delete_thunk<V, false> { + static void Delete(V* p) { + auto ref_obj = static_cast<RefObject<V>*>(p); + int previous_count = ref_obj->counter_.fetch_sub(1); +#ifdef IREE_VERBOSE_REF_PTR + LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now " + << previous_count - 1 << (previous_count == 1 ? " DEAD" : ""); +#endif // IREE_VERBOSE_REF_PTR + if (previous_count == 1) { + // We delete type T pointer here to avoid the need for a virtual dtor. + delete p; + } + } + }; + + public: + // Adds a reference; used by ref_ptr. + friend void ref_ptr_add_ref(T* p) { + auto ref_obj = static_cast<RefObject*>(p); + ++ref_obj->counter_; + +#ifdef IREE_VERBOSE_REF_PTR + LOG(INFO) << "ro++ " << typeid(T).name() << " " << p << " now " + << ref_obj->counter_; +#endif // IREE_VERBOSE_REF_PTR + } + + // Releases a reference, potentially deleting the object; used by ref_ptr. + friend void ref_ptr_release_ref(T* p) { + delete_thunk<T, has_custom_deleter::value>::Delete(p); + } + + // Adds a reference. + // ref_ptr should be used instead of this in most cases. This is required + // for when interoperating with marshaling APIs. + void AddReference() { ref_ptr_add_ref(static_cast<T*>(this)); } + + // Releases a reference, potentially deleting the object. + // ref_ptr should be used instead of this in most cases. This is required + // for when interoperating with marshaling APIs. + void ReleaseReference() { ref_ptr_release_ref(static_cast<T*>(this)); } + + protected: + RefObject() { ref_ptr_add_ref(static_cast<T*>(this)); } + RefObject(const RefObject&) = default; + RefObject& operator=(const RefObject&) { return *this; } + + std::atomic<intptr_t> counter_{0}; +}; + +// Various comparison operator overloads. + +template <class T, class U> +inline bool operator==(ref_ptr<T> const& a, ref_ptr<U> const& b) { + return a.get() == b.get(); +} + +template <class T, class U> +inline bool operator!=(ref_ptr<T> const& a, ref_ptr<U> const& b) { + return a.get() != b.get(); +} + +template <class T, class U> +inline bool operator==(ref_ptr<T> const& a, U* b) { + return a.get() == b; +} + +template <class T, class U> +inline bool operator!=(ref_ptr<T> const& a, U* b) { + return a.get() != b; +} + +template <class T, class U> +inline bool operator==(T* a, ref_ptr<U> const& b) { + return a == b.get(); +} + +template <class T, class U> +inline bool operator!=(T* a, ref_ptr<U> const& b) { + return a != b.get(); +} + +template <class T> +inline bool operator<(ref_ptr<T> const& a, ref_ptr<T> const& b) { + return a.get() < b.get(); +} + +// Swaps the pointers of two ref_ptrs. +template <class T> +void swap(ref_ptr<T>& lhs, ref_ptr<T>& rhs) { + lhs.swap(rhs); +} + +} // namespace iree + +#endif // IREE_BASE_REF_PTR_H_
diff --git a/base/ref_ptr_test.cc b/base/ref_ptr_test.cc new file mode 100644 index 0000000..5642eef --- /dev/null +++ b/base/ref_ptr_test.cc
@@ -0,0 +1,330 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/ref_ptr.h" + +#include "gtest/gtest.h" + +namespace iree { +namespace { + +class MyType : public RefObject<MyType> { + public: + int x = 5; + + using RefObject<MyType>::counter_; // Expose for testing. +}; + +TEST(RefPtrTest, Construction) { + // Empty. + ref_ptr<MyType> n1; + EXPECT_EQ(nullptr, n1.get()); + ref_ptr<MyType> n2(nullptr); + EXPECT_EQ(nullptr, n2.get()); + + // Assign a new ptr and add ref. + MyType* a_ptr = new MyType(); + EXPECT_EQ(1, a_ptr->counter_); + ref_ptr<MyType> a(a_ptr); + EXPECT_EQ(1, a->counter_); + + // Assign existing ptr without adding a ref. + ref_ptr<MyType> b(a_ptr); + EXPECT_EQ(1, b->counter_); + + // Add a new ref. + ref_ptr<MyType> c = add_ref(b); + EXPECT_EQ(2, c->counter_); + + b.release(); +} + +TEST(RefPtrTest, Assign) { + // Ok to assign nothing. + ref_ptr<MyType> n1 = assign_ref<MyType>(nullptr); + EXPECT_EQ(nullptr, n1.get()); + + ref_ptr<MyType> mt = make_ref<MyType>(); + EXPECT_EQ(1, mt->counter_); + ref_ptr<MyType> n2 = assign_ref(mt.get()); + EXPECT_EQ(1, mt->counter_); + mt.release(); // must release, as we assigned to n2. + EXPECT_EQ(1, n2->counter_); + n2.reset(); +} + +TEST(RefPtrTest, Retain) { + // Ok to retain nothing. + ref_ptr<MyType> n1 = add_ref<MyType>(nullptr); + EXPECT_EQ(nullptr, n1.get()); + + ref_ptr<MyType> mt = make_ref<MyType>(); + EXPECT_EQ(1, mt->counter_); + ref_ptr<MyType> n2 = add_ref(mt.get()); + EXPECT_EQ(2, mt->counter_); + mt.reset(); + EXPECT_EQ(1, n2->counter_); + n2.reset(); +} + +TEST(RefPtrTest, Reset) { + ref_ptr<MyType> a(new MyType()); + ref_ptr<MyType> b(new MyType()); + + // Reset to drop reference. + ref_ptr<MyType> a_copy = add_ref(a); + EXPECT_EQ(2, a_copy->counter_); + a.reset(); + EXPECT_EQ(1, a_copy->counter_); + + // Reset via = operator. + a = nullptr; + EXPECT_EQ(1, a_copy->counter_); + a = add_ref(a_copy); + EXPECT_EQ(2, a_copy->counter_); + + // No-op on empty ptrs. + ref_ptr<MyType> n; + n.reset(); + n.assign(nullptr); +} + +TEST(RefPtrTest, ReleaseAssign) { + ref_ptr<MyType> a(new MyType()); + + // Release a's pointer. + MyType* a_raw_ptr = a.get(); + MyType* a_ptr = a.release(); + EXPECT_EQ(a_raw_ptr, a_ptr); + EXPECT_EQ(nullptr, a.get()); + EXPECT_EQ(1, a_ptr->counter_); + + // Re-wrap in a ref_ptr. + a.assign(a_ptr); + EXPECT_EQ(1, a->counter_); + + // No-op on empty ptrs. + ref_ptr<MyType> n; + EXPECT_EQ(nullptr, n.release()); +} + +TEST(RefPtrTest, Accessors) { + ref_ptr<MyType> a(new MyType()); + EXPECT_EQ(5, a->x); + a->x = 100; + EXPECT_EQ(100, a->x); + + MyType& ra = *a; + ra.x = 200; + EXPECT_EQ(200, ra.x); + + const MyType& cra = *a; + EXPECT_EQ(200, cra.x); +} + +TEST(RefPtrTest, BooleanExpressions) { + ref_ptr<MyType> a(new MyType()); + ref_ptr<MyType> n; + + EXPECT_NE(nullptr, a.get()); + EXPECT_TRUE(a); + EXPECT_FALSE(!a); + EXPECT_EQ(true, static_cast<bool>(a)); + + EXPECT_EQ(nullptr, n.get()); + EXPECT_FALSE(n); + EXPECT_TRUE(!n); + EXPECT_EQ(false, static_cast<bool>(n)); +} + +TEST(RefPtrTest, Comparisons) { + ref_ptr<MyType> a(new MyType()); + ref_ptr<MyType> b(new MyType()); + ref_ptr<MyType> n; + + EXPECT_TRUE(a == a); + EXPECT_TRUE(a == a.get()); + EXPECT_TRUE(a.get() == a); + EXPECT_FALSE(a != a); + EXPECT_FALSE(a != a.get()); + EXPECT_FALSE(a.get() != a); + + EXPECT_FALSE(a == b); + EXPECT_FALSE(a == b.get()); + EXPECT_FALSE(a.get() == b); + EXPECT_TRUE(a != b); + EXPECT_TRUE(a != b.get()); + EXPECT_TRUE(a.get() != b); + + EXPECT_TRUE(n == n); + EXPECT_TRUE(n == n.get()); + EXPECT_TRUE(n.get() == n); + EXPECT_FALSE(n != n); + EXPECT_FALSE(n != n.get()); + EXPECT_FALSE(n.get() != n); + + EXPECT_FALSE(a < a); + EXPECT_TRUE(n < a); +} + +TEST(RefPtrTest, Swap) { + ref_ptr<MyType> a(new MyType()); + ref_ptr<MyType> b(new MyType()); + MyType* a_ptr = a.get(); + MyType* b_ptr = b.get(); + + swap(a, a); + EXPECT_EQ(a_ptr, a); + + swap(a, b); + EXPECT_EQ(a_ptr, b.get()); + EXPECT_EQ(b_ptr, a.get()); + + swap(a, b); + EXPECT_EQ(a_ptr, a.get()); + EXPECT_EQ(b_ptr, b.get()); + + ref_ptr<MyType> c; + swap(a, c); + EXPECT_EQ(a_ptr, c.get()); + EXPECT_EQ(nullptr, a.get()); +} + +TEST(RefPtrTest, Move) { + auto a = make_ref<MyType>(); + auto b = make_ref<MyType>(); + ref_ptr<MyType> c; + EXPECT_EQ(nullptr, c.get()); + + c = std::move(a); + EXPECT_NE(nullptr, c.get()); + + b = std::move(c); + EXPECT_NE(nullptr, b.get()); +} + +TEST(RefPtrTest, MoveCompatible) { + struct MyBaseType : public RefObject<MyBaseType> { + int x = 5; + using RefObject<MyBaseType>::counter_; // Expose for testing. + }; + struct MyTypeA : public MyBaseType { + int a = 6; + }; + struct MyTypeB : public MyBaseType { + int b = 7; + }; + + ref_ptr<MyTypeA> a = make_ref<MyTypeA>(); + EXPECT_EQ(1, a->counter_); + ref_ptr<MyBaseType> base = add_ref(a); + EXPECT_EQ(a.get(), base.get()); + EXPECT_EQ(2, a->counter_); + + base = make_ref<MyTypeB>(); + EXPECT_EQ(1, a->counter_); + EXPECT_EQ(1, base->counter_); +} + +TEST(RefPtrTest, StackAllocation) { + static int alloc_count = 0; + class StackAllocationType : public RefObject<StackAllocationType> { + public: + StackAllocationType() { ++alloc_count; } + ~StackAllocationType() { --alloc_count; } + }; + { + StackAllocationType a; + EXPECT_EQ(1, alloc_count); + } + EXPECT_EQ(0, alloc_count); +} + +TEST(RefPtrTest, DefaultDeleter) { + static int alloc_count = 0; + class DefaultDeleterType : public RefObject<DefaultDeleterType> { + public: + DefaultDeleterType() { ++alloc_count; } + ~DefaultDeleterType() { --alloc_count; } + }; + + // Empty is ok. + ref_ptr<DefaultDeleterType> n; + n.reset(); + + // Lifecycle. + EXPECT_EQ(0, alloc_count); + ref_ptr<DefaultDeleterType> a = make_ref<DefaultDeleterType>(); + EXPECT_EQ(1, alloc_count); + a.reset(); + EXPECT_EQ(0, alloc_count); +} + +TEST(RefPtrTest, InlineDeallocator) { + static int alloc_count = 0; + class CustomDeleterType : public RefObject<CustomDeleterType> { + public: + CustomDeleterType() { ++alloc_count; } + static void Delete(CustomDeleterType* ptr) { + --alloc_count; + ::operator delete(ptr); + } + }; + + // Empty is ok. + ref_ptr<CustomDeleterType> n; + n.reset(); + + // Lifecycle. + EXPECT_EQ(0, alloc_count); + auto a = make_ref<CustomDeleterType>(); + EXPECT_EQ(1, alloc_count); + a.reset(); + EXPECT_EQ(0, alloc_count); +} + +class VirtualDtorTypeA : public RefObject<VirtualDtorTypeA> { + public: + VirtualDtorTypeA() { ++alloc_count_a; } + virtual ~VirtualDtorTypeA() { --alloc_count_a; } + static int alloc_count_a; +}; +int VirtualDtorTypeA::alloc_count_a = 0; + +class VirtualDtorTypeB : public VirtualDtorTypeA { + public: + VirtualDtorTypeB() { ++alloc_count_b; } + ~VirtualDtorTypeB() override { --alloc_count_b; } + static int alloc_count_b; +}; +int VirtualDtorTypeB::alloc_count_b = 0; + +TEST(RefPtrTest, VirtualDestructor) { + // Empty is ok. + ref_ptr<VirtualDtorTypeB> n; + n.reset(); + + // Lifecycle. + EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a); + EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b); + ref_ptr<VirtualDtorTypeB> a = make_ref<VirtualDtorTypeB>(); + EXPECT_EQ(1, VirtualDtorTypeA::alloc_count_a); + EXPECT_EQ(1, VirtualDtorTypeB::alloc_count_b); + a.reset(); + EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a); + EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b); +} + +} // namespace +} // namespace iree
diff --git a/base/shape.cc b/base/shape.cc new file mode 100644 index 0000000..611d295 --- /dev/null +++ b/base/shape.cc
@@ -0,0 +1,100 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/shape.h" + +#include <cstddef> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "base/source_location.h" +#include "base/status.h" + +namespace iree { + +Shape::Shape(const int* values, int size) : rank_(size) { + QCHECK_LE(size, kMaxRank) + << "Max rank of " << kMaxRank << ", shape has " << size; + std::memcpy(value_, values, size * sizeof(int)); +} + +std::string Shape::DebugString() const { + return absl::StrCat("[", absl::StrJoin(subspan(), ","), "]"); +} + +absl::Span<const int> Shape::subspan(size_type pos, size_type len) const { + if (len == npos) { + len = rank_ - pos; + } + return absl::MakeConstSpan(&value_[pos], len); +} + +void Shape::push_back(int dim) { + DCHECK_LE(rank_ + 1, kMaxRank); + value_[rank_++] = dim; +} + +void Shape::insert(iterator pos, int dim) { + int axis = static_cast<int>(pos - value_); + DCHECK_GE(axis, 0); + DCHECK_LE(axis, rank_); + DCHECK_LE(rank_ + 1, kMaxRank); + ++rank_; + for (int i = rank_ - 1; i > axis; --i) { + value_[i] = value_[i - 1]; + } + value_[axis] = dim; +} + +void Shape::erase(iterator pos) { + int axis = static_cast<int>(pos - value_); + DCHECK_GE(axis, 0); + DCHECK_LE(axis, rank_); + for (int i = axis; i < rank_ - 1; ++i) { + value_[i] = value_[i + 1]; + } + --rank_; +} + +int Shape::element_count() const { + size_t element_count = 1; + for (int i = 0; i < rank_; ++i) { + int dim = value_[i]; + if (dim == -1) { + return 0; + } + element_count *= dim; + } + return element_count; +} + +StatusOr<int> Shape::ResolveAxis(int axis) const { + if (rank_ == 0 && (axis == -1 || axis == 0)) { + // Scalar axes resolves to 0. + return 0; + } + + int new_axis = axis; + if (new_axis < 0) { + new_axis += rank_; + } + if (new_axis < 0 || new_axis >= rank_) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Axis " << new_axis << " (orig " << axis + << ") out of bounds of rank " << rank_; + } + return new_axis; +} + +} // namespace iree
diff --git a/base/shape.h b/base/shape.h new file mode 100644 index 0000000..b0e3ddd --- /dev/null +++ b/base/shape.h
@@ -0,0 +1,156 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_SHAPE_H_ +#define IREE_BASE_SHAPE_H_ + +#include <array> +#include <cstring> +#include <initializer_list> +#include <iterator> +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/meta/type_traits.h" +#include "absl/types/span.h" +#include "base/logging.h" +#include "base/status.h" + +namespace iree { + +// For simplicity we limit our shapes to a max of rank-N (shape.size() == N) as +// this prevents dynamic allocations and rarely are there greater ranks. +constexpr int kMaxRank = 5; + +// Represent indices and lengths of tensors. +using Index = std::array<int, kMaxRank>; +using Length = std::array<int, kMaxRank>; + +// Represents the number of elements in multiple dimensions. +// Can be rank-0 (scalar) to rank-kMaxRank. Tries to match the API of +// std::vector and can be converted to a Span via subspan(). +// +// https://www.tensorflow.org/guide/tensors#shape +class Shape { + public: + using size_type = int; + static constexpr size_type npos = ~(size_type(0)); // NOLINT + using iterator = int*; + using const_iterator = const int*; + + Shape() = default; + Shape(const int* values, int size); + Shape(std::initializer_list<int> values) + : Shape(values.begin(), values.size()) {} + explicit Shape(absl::Span<const int> values) + : Shape(values.data(), values.size()) {} + + template <typename Iterator> + using EnableIfForwardIterator = absl::enable_if_t<std::is_convertible< + typename std::iterator_traits<Iterator>::iterator_category, + std::forward_iterator_tag>::value>; + template <typename Iterator, EnableIfForwardIterator<Iterator>* = nullptr> + Shape(Iterator first, Iterator last) { + rank_ = std::distance(first, last); + QCHECK_LE(rank_, kMaxRank); + for (int i = 0; first != last; ++i, static_cast<void>(++first)) { + value_[i] = *first; + } + } + + // Returns a string representation of the given shape. + std::string DebugString() const; + + // Size (aka 'rank') of the shape, counting the number of dimensions. + constexpr size_type size() const noexcept { return rank_; } + + // Whether the shape is rank-0 (scalar). + constexpr bool empty() const noexcept { return rank_ == 0; } + + // Returns the total elements in the tensor shape. + // Returns 0 if the tensor shape is not complete and 1 if the shape is a + // scalar value. + int element_count() const; + + // Resolves an axis in [-R,R) to the real axis value and verifies the range. + StatusOr<int> ResolveAxis(int axis) const; + + // Compares two shapes for equality. + inline static bool Equal(const Shape& a, const Shape& b) { + return a.rank_ == b.rank_ && + std::memcmp(a.value_, b.value_, a.rank_ * sizeof(value_[0])) == 0; + } + + int& operator[](size_type i) noexcept { + DCHECK_GE(i, 0); + DCHECK_LT(i, rank_); + return value_[i]; + } + + const int& operator[](size_type i) const noexcept { + DCHECK_GE(i, 0); + DCHECK_LT(i, rank_); + return value_[i]; + } + + int front() const noexcept { + DCHECK_GE(rank_, 1); + return value_[0]; + } + + int back() const noexcept { + DCHECK_GE(rank_, 1); + return value_[rank_ - 1]; + } + + constexpr iterator begin() const noexcept { + return const_cast<iterator>(&value_[0]); + } + constexpr iterator end() const noexcept { + return const_cast<iterator>(&value_[rank_]); + } + constexpr const_iterator cbegin() const noexcept { return &value_[0]; } + constexpr const_iterator cend() const noexcept { return &value_[rank_]; } + + absl::Span<const int> subspan(size_type pos = 0, size_type len = npos) const; + absl::Span<const int> data() const { return subspan(); } + + void push_back(int dim); + + void insert(iterator pos, int dim); + + void erase(iterator pos); + + void clear() { rank_ = 0; } + + private: + size_type rank_ = 0; + int value_[kMaxRank]; +}; + +inline bool operator==(const Shape& a, const Shape& b) { + return Shape::Equal(a, b); +} + +inline bool operator!=(const Shape& a, const Shape& b) { return !(a == b); } + +inline std::ostream& operator<<(std::ostream& stream, const Shape& shape) { + stream << shape.DebugString(); + return stream; +} + +} // namespace iree + +#endif // IREE_BASE_SHAPE_H_
diff --git a/base/shape_test.cc b/base/shape_test.cc new file mode 100644 index 0000000..ac7b071 --- /dev/null +++ b/base/shape_test.cc
@@ -0,0 +1,222 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/shape.h" + +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace { + +using ::testing::ElementsAre; + +// Tests shapes that represent 0-D scalar values. +TEST(ShapeTest, Scalar) { + Shape shape; + EXPECT_EQ(0, shape.size()); + EXPECT_TRUE(shape.empty()); + EXPECT_EQ(1, shape.element_count()); + EXPECT_EQ(shape, shape); + EXPECT_EQ(0, shape.subspan().size()); + for (const int dim : shape) { + FAIL() << "Should have no dimensions, have: " << dim; + } + EXPECT_EQ(shape.begin(), shape.end()); + EXPECT_EQ(shape.cbegin(), shape.cend()); + shape.clear(); + EXPECT_EQ(0, shape.size()); +} + +// Tests the various ways of constructing a 1+D shape. +TEST(ShapeTest, NonScalarConstruction) { + EXPECT_EQ(0, Shape().size()); + EXPECT_EQ(0, Shape({}).size()); + EXPECT_EQ(1, Shape({10}).size()); + EXPECT_EQ(4, Shape({10, 20, 30, 40}).size()); + + std::vector<int> empty_data = {}; + EXPECT_EQ(0, Shape(empty_data.data(), empty_data.size()).size()); + EXPECT_EQ(0, Shape(empty_data.begin(), empty_data.end()).size()); + EXPECT_EQ(0, Shape(absl::MakeConstSpan(empty_data)).size()); + + EXPECT_THAT(Shape({}).subspan(), ElementsAre()); + EXPECT_THAT(Shape({10}).subspan(), ElementsAre(10)); + EXPECT_THAT(Shape({10, 20, 30, 40}).subspan(), ElementsAre(10, 20, 30, 40)); + + std::vector<int> valid_data = {10, 20, 30, 40}; + EXPECT_THAT(Shape(valid_data.begin(), valid_data.end()).subspan(), + ElementsAre(10, 20, 30, 40)); + EXPECT_THAT(Shape(absl::MakeConstSpan(valid_data)).subspan(), + ElementsAre(10, 20, 30, 40)); +} + +// Tests shapes that represent 1+D multidimensional values. +TEST(ShapeTest, NonScalarAccess) { + Shape shape = {1, 2, 3, 4}; + EXPECT_EQ(4, shape.size()); + EXPECT_FALSE(shape.empty()); + EXPECT_EQ(1 * 2 * 3 * 4, shape.element_count()); + EXPECT_EQ(shape, shape); + EXPECT_NE(shape, Shape({4, 3, 2, 1})); + EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4)); + std::vector<int> readout; + for (const int dim : shape) { + readout.push_back(dim); + } + EXPECT_THAT(readout, ElementsAre(1, 2, 3, 4)); + EXPECT_EQ(1, shape[0]); + EXPECT_EQ(2, shape[1]); + EXPECT_EQ(3, shape[2]); + EXPECT_EQ(4, shape[3]); + EXPECT_EQ(1, shape.front()); + EXPECT_EQ(4, shape.back()); +} + +TEST(ShapeTest, PushBack) { + Shape shape; + EXPECT_EQ(0, shape.size()); + + shape.push_back(10); + EXPECT_EQ(1, shape.size()); + EXPECT_EQ(10, shape.front()); + EXPECT_EQ(10, shape.back()); + EXPECT_EQ(10, shape[0]); + EXPECT_THAT(shape.subspan(), ElementsAre(10)); + + shape.push_back(20); + EXPECT_EQ(2, shape.size()); + EXPECT_EQ(10, shape.front()); + EXPECT_EQ(20, shape.back()); + EXPECT_EQ(10, shape[0]); + EXPECT_EQ(20, shape[1]); + EXPECT_THAT(shape.subspan(), ElementsAre(10, 20)); +} + +TEST(ShapeTest, Insert) { + Shape shape; + EXPECT_EQ(0, shape.size()); + + shape.insert(shape.begin(), 20); + EXPECT_THAT(shape.subspan(), ElementsAre(20)); + shape.insert(shape.begin(), 10); + EXPECT_THAT(shape.subspan(), ElementsAre(10, 20)); + shape.insert(shape.end(), 40); + EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 40)); + shape.insert(shape.begin() + 2, 30); + EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 30, 40)); + + Shape ex_shape{72, 4}; + ex_shape.insert(ex_shape.begin(), 144); + EXPECT_THAT(ex_shape.subspan(), ElementsAre(144, 72, 4)); +} + +TEST(ShapeTest, Erase) { + Shape shape = {1, 2, 3, 4}; + EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4)); + shape.erase(shape.begin()); + EXPECT_THAT(shape.subspan(), ElementsAre(2, 3, 4)); + shape.erase(shape.end()); + EXPECT_THAT(shape.subspan(), ElementsAre(2, 3)); + shape.erase(shape.begin() + 1); + EXPECT_THAT(shape.subspan(), ElementsAre(2)); + shape.erase(shape.end()); + EXPECT_THAT(shape.subspan(), ElementsAre()); +} + +TEST(ShapeTest, Clear) { + Shape shape; + EXPECT_EQ(0, shape.size()); + shape.clear(); + EXPECT_EQ(0, shape.size()); + + shape = Shape({1}); + shape.clear(); + EXPECT_EQ(0, shape.size()); + + shape = Shape({1, 2, 3, 4}); + shape.clear(); + EXPECT_EQ(0, shape.size()); +} + +TEST(ShapeTest, DebugString) { + EXPECT_EQ("[]", Shape({}).DebugString()); + EXPECT_EQ("[1]", Shape({1}).DebugString()); + EXPECT_EQ("[1,2]", Shape({1, 2}).DebugString()); +} + +TEST(ShapeTest, ElementCount) { + EXPECT_EQ(1, Shape({}).element_count()); + EXPECT_EQ(0, Shape({0}).element_count()); + EXPECT_EQ(1, Shape({1}).element_count()); + EXPECT_EQ(2, Shape({2, 1}).element_count()); + EXPECT_EQ(10, Shape({2, 5}).element_count()); + EXPECT_EQ(9216, Shape({72, 1, 128}).element_count()); + EXPECT_EQ(9216, Shape({1, 72, 128}).element_count()); + + // Partial shaping should yield no elements. + EXPECT_EQ(0, Shape({1, -1, 2, 3}).element_count()); +} + +TEST(ShapeTest, ResolveAxis) { + int axis; + ASSERT_OK_AND_ASSIGN(axis, Shape({0}).ResolveAxis(0)); + EXPECT_EQ(0, axis); + ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(1)); + EXPECT_EQ(1, axis); + ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(2)); + EXPECT_EQ(2, axis); + + EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(3).status())); +} + +TEST(ShapeTest, ResolveAxisNegative) { + int axis; + ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-3)); + EXPECT_EQ(0, axis); + ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-2)); + EXPECT_EQ(1, axis); + ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-1)); + EXPECT_EQ(2, axis); + + EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(-4).status())); +} + +TEST(ShapeTest, ResolveAxisScalar) { + int axis; + ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(0)); + EXPECT_EQ(0, axis); + ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(-1)); + EXPECT_EQ(0, axis); + + EXPECT_TRUE(IsInvalidArgument(Shape({}).ResolveAxis(1).status())); +} + +TEST(ShapeTest, Equality) { + EXPECT_EQ(Shape({}), Shape({})); + EXPECT_EQ(Shape({0}), Shape({0})); + EXPECT_EQ(Shape({1}), Shape({1})); + EXPECT_EQ(Shape({1, 2}), Shape({1, 2})); + + EXPECT_NE(Shape({}), Shape({1})); + EXPECT_NE(Shape({-1}), Shape({1})); + EXPECT_NE(Shape({1}), Shape({})); + EXPECT_NE(Shape({1}), Shape({2})); + EXPECT_NE(Shape({1, 2}), Shape({3, 4})); +} + +} // namespace +} // namespace iree
diff --git a/base/source_location.h b/base/source_location.h new file mode 100644 index 0000000..c93db4a --- /dev/null +++ b/base/source_location.h
@@ -0,0 +1,24 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_SOURCE_LOCATION_H_ +#define IREE_BASE_SOURCE_LOCATION_H_ + +#ifdef IREE_CONFIG_GOOGLE_INTERNAL +#include "base/google/source_location_google.h" +#else +#include "base/internal/source_location.h" +#endif // IREE_CONFIG_GOOGLE_INTERNAL + +#endif // IREE_BASE_SOURCE_LOCATION_H_
diff --git a/base/status.h b/base/status.h new file mode 100644 index 0000000..feadf75 --- /dev/null +++ b/base/status.h
@@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_STATUS_H_ +#define IREE_BASE_STATUS_H_ + +#ifdef IREE_CONFIG_GOOGLE_INTERNAL +#include "base/google/status_google.h" +#else +#include "base/internal/status.h" +#include "base/internal/status_builder.h" +#include "base/internal/status_errno.h" +#include "base/internal/status_errors.h" +#include "base/internal/status_macros.h" +#include "base/internal/status_win32_errors.h" +#include "base/internal/statusor.h" +#endif // IREE_CONFIG_GOOGLE_INTERNAL + +#include "base/source_location.h" // IWYU pragma: export + +#endif // IREE_BASE_STATUS_H_
diff --git a/base/status_matchers.h b/base/status_matchers.h new file mode 100644 index 0000000..533f4a4 --- /dev/null +++ b/base/status_matchers.h
@@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_STATUS_MATCHERS_H_ +#define IREE_BASE_STATUS_MATCHERS_H_ + +#ifdef IREE_CONFIG_GOOGLE_INTERNAL + +#include "base/google/status_matchers_google.h" // IWYU pragma: export + +#else + +#include "base/internal/status_matchers.h" // IWYU pragma: export + +#endif // IREE_CONFIG_GOOGLE_INTERNAL + +#endif // IREE_BASE_STATUS_MATCHERS_H_
diff --git a/iree/base/target_platform.h b/base/target_platform.h similarity index 100% rename from iree/base/target_platform.h rename to base/target_platform.h
diff --git a/iree/base/time.h b/base/time.h similarity index 100% rename from iree/base/time.h rename to base/time.h
diff --git a/base/tracing.cc b/base/tracing.cc new file mode 100644 index 0000000..0708532 --- /dev/null +++ b/base/tracing.cc
@@ -0,0 +1,188 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Force the header to detect WTF_ENABLE so that this library builds +// (for when building recursively). +#if !defined(WTF_ENABLE) +#define WTF_ENABLE +#endif + +#include "base/tracing.h" + +#include <thread> // NOLINT: Fiber doesn't work during startup on Android. + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/flags/flag.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "base/file_io.h" +#include "base/file_path.h" +#include "base/init.h" +#include "base/logging.h" +#include "base/status.h" + +ABSL_FLAG(int32_t, iree_trace_file_period, 5, + "Seconds between automatic flushing of WTF trace files. 0 to " + "disable auto-flush."); +ABSL_FLAG(std::string, iree_trace_file, "/dev/null", + "wtf-trace file to save if --define=GLOBAL_WTF_ENABLE=1 was used " + "when building."); + +namespace iree { +namespace { + +// Guards global WTF state (like the flush fiber and IO). +ABSL_CONST_INIT absl::Mutex global_tracing_mutex(absl::kConstInit); + +// True when tracing has been enabled and initialized. +bool global_tracing_initialized ABSL_GUARDED_BY(global_tracing_mutex) = false; + +// If there is an existing file at the given path back it up by moving it aside. +// Only kMaxBackups will be kept to avoid unbounded growth. +void RollTraceFiles(const std::string& path) { + std::string path_stem = file_path::JoinPaths(file_path::DirectoryName(path), + file_path::Stem(path)); + const int kMaxBackups = 5; + for (int i = kMaxBackups; i >= 0; i--) { + std::string source_name; + if (i > 0) { + source_name = absl::StrCat(path_stem, ".", i, ".wtf-trace"); + } else { + source_name = path; + } + if (!file_io::FileExists(source_name).ok()) { + continue; + } + + Status status; + if (i == kMaxBackups) { + status = file_io::DeleteFile(source_name); + } else { + std::string backup_name = + absl::StrCat(path_stem, ".", (i + 1), ".wtf-trace"); + status = file_io::MoveFile(source_name, backup_name); + } + if (!status.ok()) { + LOG(WARNING) << "Could not remove backup trace file " << source_name + << ": " << status; + } + } +} + +// Flushes all recorded trace data since the last flush. +void FlushTraceFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(global_tracing_mutex) { + if (!global_tracing_initialized) return; + + const auto& trace_path = absl::GetFlag(FLAGS_iree_trace_file); + + static ::wtf::Runtime::SaveCheckpoint checkpoint; + static bool is_first_flush = true; + + if (is_first_flush && trace_path != "/dev/null") { + // Backup existing any existing trace files at the specified path. + RollTraceFiles(trace_path); + } + + auto save_options = + ::wtf::Runtime::SaveOptions::ForStreamingFile(&checkpoint); + if (is_first_flush) { + // On the first time, truncate the file. All subsequent flushes append. + save_options.open_mode = std::ios_base::trunc; + } + + is_first_flush = false; + + auto* runtime = ::wtf::Runtime::GetInstance(); + if (!runtime->SaveToFile(trace_path, save_options)) { + LOG(ERROR) << "Error saving WTF file: " << trace_path; + return; + } + + VLOG(1) << "Flushed WTF trace to: " << trace_path; +} + +} // namespace + +void InitializeTracing() { + if (!::wtf::kMasterEnable) { + if (!absl::GetFlag(FLAGS_iree_trace_file).empty()) { + LOG(WARNING) << "WTF trace save requested but WTF is not compiled in. " + << "Enable by building with --define=GLOBAL_WTF_ENABLE=1."; + } + return; + } + + absl::MutexLock lock(&global_tracing_mutex); + if (global_tracing_initialized) return; + global_tracing_initialized = true; + + LOG(INFO) << "Tracing enabled and streaming to: " + << absl::GetFlag(FLAGS_iree_trace_file); + + // Enable tracing on this thread, which we know is main. + IREE_TRACE_THREAD_ENABLE("main"); + + // Register atexit callback to stop tracking. + atexit(StopTracing); + + // Launch a thread to periodically flush the trace. + if (absl::GetFlag(FLAGS_iree_trace_file_period) > 0) { + auto flush_thread = std::thread(+[]() { + absl::Duration period = + absl::Seconds(absl::GetFlag(FLAGS_iree_trace_file_period)); + while (true) { + absl::SleepFor(period); + absl::MutexLock lock(&global_tracing_mutex); + if (!global_tracing_initialized) { + return; + } + FlushTraceFile(); + } + }); + flush_thread.detach(); + } +} + +// Stops tracing if currently initialized. +void StopTracing() { + if (!::wtf::kMasterEnable) return; + absl::MutexLock lock(&global_tracing_mutex); + if (!global_tracing_initialized) return; + + // Flush any pending trace data. + FlushTraceFile(); + + // Mark WTF as uninitialized to kill the flush thread. + global_tracing_initialized = false; + + LOG(INFO) << "Tracing stopped and flushed to file: " + << absl::GetFlag(FLAGS_iree_trace_file); +} + +void FlushTrace() { + if (!::wtf::kMasterEnable) return; + absl::MutexLock lock(&global_tracing_mutex); + if (!global_tracing_initialized) return; + FlushTraceFile(); +} + +} // namespace iree + +IREE_DECLARE_MODULE_INITIALIZER(iree_tracing); + +IREE_REGISTER_MODULE_INITIALIZER(iree_tracing, ::iree::InitializeTracing());
diff --git a/iree/base/tracing.h b/base/tracing.h similarity index 100% rename from iree/base/tracing.h rename to base/tracing.h
diff --git a/base/tracing_disabled.cc b/base/tracing_disabled.cc new file mode 100644 index 0000000..2360d54 --- /dev/null +++ b/base/tracing_disabled.cc
@@ -0,0 +1,29 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is linked in only when WTF is not enabled. It allows us to keep the +// same flags and functions without needing to do a bunch of ifdef hackery or +// undefok mangling. + +#include <cstdint> +#include <string> + +#include "absl/flags/flag.h" +#include "base/tracing.h" + +// TODO(benvanik): remove this when disabled so that we don't dep on flags. +ABSL_FLAG(int32_t, iree_trace_file_period, 0, + "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF."); +ABSL_FLAG(std::string, iree_trace_file, "", + "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF.");
diff --git a/base/wait_handle.cc b/base/wait_handle.cc new file mode 100644 index 0000000..c8a99ce --- /dev/null +++ b/base/wait_handle.cc
@@ -0,0 +1,532 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/wait_handle.h" + +#include <errno.h> +#include <fcntl.h> +#include <poll.h> +#include <time.h> +#include <unistd.h> + +#include <type_traits> +#include <utility> + +#include "absl/container/fixed_array.h" +#include "absl/strings/str_cat.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "base/source_location.h" +#include "base/status.h" + +// TODO(benvanik): organize these macros - they are terrible. + +#if !defined(__ANDROID__) && !defined(OS_IOS) && !defined(__EMSCRIPTEN__) +#define IREE_HAS_PPOLL 1 +#endif // !__ANDROID__ && !__EMSCRIPTEN__ +#define IREE_HAS_POLL 1 + +#if !defined(OS_IOS) && !defined(OS_MACOSX) && !defined(__EMSCRIPTEN__) +#define IREE_HAS_EVENTFD 1 +#endif +#define IREE_HAS_PIPE 1 +// #define IREE_HAS_SYNC_FILE 1 + +#if defined(IREE_HAS_EVENTFD) +#include <sys/eventfd.h> +#endif // IREE_HAS_EVENTFD + +namespace iree { + +namespace { + +constexpr int kInvalidFd = WaitableObject::kInvalidFd; +constexpr int kSignaledFd = WaitableObject::kSignaledFd; + +// Retries a syscall until it succeeds or fails for a real reason. +template <typename SyscallT, typename... ParamsT> +StatusOr<typename std::result_of<SyscallT(ParamsT...)>::type> Syscall( + SyscallT syscall, ParamsT&&... params) { + while (true) { + const auto rv = syscall(std::forward<ParamsT>(params)...); + if (rv >= 0) return rv; + if (errno == EINTR) { + // Retry on EINTR. + continue; + } else { + return ErrnoToCanonicalStatus(errno, ""); + } + } +} + +#if defined(IREE_HAS_PPOLL) + +// ppoll(), present on Linux. +// ppoll is preferred as it has a much better timing mechanism; poll can have a +// large slop on the deadline. +// Documentation: https://linux.die.net/man/2/poll +StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) { + // Convert the deadline into a tmo_p struct for ppoll that controls whether + // the call is blocking or non-blocking. Note that we must do this every + // iteration of the loop as a previous ppoll may have taken some of the + // time. + // + // See the ppoll docs for more information as to what the expected value is: + // http://man7.org/linux/man-pages/man2/poll.2.html + timespec timeout_spec; + timespec* tmo_p; + if (deadline == absl::InfinitePast()) { + // 0 for non-blocking. + timeout_spec = {0}; + tmo_p = &timeout_spec; + } else if (deadline == absl::InfiniteFuture()) { + // nullptr to ppoll() to block forever. + tmo_p = nullptr; + } else { + // Wait only for as much time as we have before the deadline is exceeded. + absl::Duration remaining_time = deadline - absl::Now(); + if (remaining_time < absl::ZeroDuration()) { + // Note: we likely have already bailed before getting here with a negative + // duration. + return DeadlineExceededErrorBuilder(IREE_LOC); + } + timeout_spec = absl::ToTimespec(remaining_time); + tmo_p = &timeout_spec; + } + return Syscall(::ppoll, poll_fds.data(), poll_fds.size(), tmo_p, nullptr); +} + +#elif defined(IREE_HAS_POLL) + +// poll(), present pretty much everywhere. +// Documentation: https://linux.die.net/man/2/poll +StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) { + int timeout; + if (deadline == absl::InfinitePast()) { + // Don't block. + timeout = 0; + } else if (deadline == absl::InfiniteFuture()) { + // Block forever. + timeout = -1; + } else { + absl::Duration remaining_time = deadline - absl::Now(); + if (remaining_time < absl::ZeroDuration()) { + return DeadlineExceededErrorBuilder(IREE_LOC); + } + timeout = static_cast<int>(absl::ToInt64Milliseconds(remaining_time)); + } + return Syscall(::poll, poll_fds.data(), poll_fds.size(), timeout); +} + +#else +#error "No SystemPoll implementation" +#endif // IREE_HAS_PPOLL / IREE_HAS_POLL / etc + +// Builds the list of pollfds to for ppoll wait on and will perform any +// required wait handle callbacks. +// +// The provided deadline will be observed if any of the wait handles needs to +// block for acquiring an fd. +StatusOr<absl::FixedArray<pollfd>> AcquireWaitHandles( + WaitHandle::WaitHandleSpan wait_handles, absl::Time deadline) { + absl::FixedArray<pollfd> poll_fds{wait_handles.size()}; + for (int i = 0; i < wait_handles.size(); ++i) { + poll_fds[i].events = POLLIN | POLLPRI | POLLERR | POLLHUP | POLLNVAL; + poll_fds[i].revents = 0; + // NOTE: poll will ignore any negative fds and our kInvalidFd == -1 so we + // can still put them in the list and it'll just skip them. + if (!wait_handles[i] || !wait_handles[i]->object()) { + poll_fds[i].fd = kInvalidFd; + continue; + } + + // Acquire the file descriptor for waiting. + // This may block (if |deadline| allows it) if the fd is not yet available. + // This is like a pre-wait for the actual poll operation. It can be bad with + // WaitAny, though we could handle that better here. + ASSIGN_OR_RETURN(auto fd_info, + wait_handles[i]->object()->AcquireFdForWait(deadline)); + poll_fds[i].fd = fd_info.second; + + // Abort if deadline exceeded. + if (deadline != absl::InfinitePast() && deadline < absl::Now()) { + return DeadlineExceededErrorBuilder(IREE_LOC) + << "Deadline exceeded acquiring for fds"; + } + } + return poll_fds; +} + +Status ClearFd(WaitableObject::FdType fd_type, int fd) { + // Read in a loop until the read would block. + // Depending on how the users setup the fd the act of reading may reset the + // entire handle (such as with the default eventfd mode) or multiple reads + // may be required (such as with semaphores). + while (true) { +#if defined(IREE_HAS_EVENTFD) + eventfd_t val = 0; + int rv = ::eventfd_read(fd, &val); +#elif defined(IREE_HAS_PIPE) + char buf; + int rv = ::read(fd, &buf, 1); +#else + return UnimplementedErrorBuilder(IREE_LOC) << "fd_type cannot be cleared"; +#endif // IREE_HAS_EVENTFD + if (rv != -1) { + // Success! Keep going. + continue; + } else { + if (errno == EWOULDBLOCK) { + // The read would have blocked meaning that we've hit the end and + // successfully cleared the fd. + return OkStatus(); + } else if (errno == EINTR) { + // Retry. + continue; + } else { + return ErrnoToCanonicalStatus(errno, "ClearFd failed"); + } + } + } +} + +// Performs a single poll on multiple fds and returns information about the +// signaled fds, if any. +Status MultiPoll(WaitHandle::WaitHandleSpan wait_handles, + absl::Span<pollfd> poll_fds, absl::Time deadline, + int* out_any_signaled_index, int* out_unsignaled_count) { + *out_any_signaled_index = -1; + *out_unsignaled_count = 0; + + // poll has a nasty behavior where it allows -1 for fds... except for at [0]. + // To keep the rest of the code sane we correct for that here as epoll doesn't + // have that behavior and we may want to special case this later. + bool any_valid_fds = true; + int swapped_zero_index = -1; + if (poll_fds[0].fd < 0) { + // Find a valid handle. + for (int i = 1; i < poll_fds.size(); ++i) { + if (poll_fds[i].fd > 0) { + swapped_zero_index = i; + std::swap(poll_fds[0], poll_fds[i]); + break; + } + } + if (swapped_zero_index == -1) { + // No valid handles found, meaning that all handles are invalid. + // We'll skip the wait below so we can share the processing code for any + // fds that may be kSignaledFd. + any_valid_fds = false; + } + } + + // Pass handles to ppoll. + // http://man7.org/linux/man-pages/man2/poll.2.html + if (any_valid_fds) { + ASSIGN_OR_RETURN(int rv, SystemPoll(poll_fds, deadline)); + if (rv == 0) { + // Call timed out and no descriptors were ready. + // If this was just a poll then that's fine. + return DeadlineExceededErrorBuilder(IREE_LOC); + } + } + + // If we had swapped fds[0] above we need to correct for that now. + if (swapped_zero_index != -1) { + std::swap(poll_fds[0], poll_fds[swapped_zero_index]); + } + + // |rv| denotes the number of fds that were ready. Run through the list and + // find the ones that were ready and mark them as completed. + for (int i = 0; i < poll_fds.size(); ++i) { + if (poll_fds[i].fd == kSignaledFd || poll_fds[i].revents == POLLIN) { + // First attempt any resolve actions. If these fail we can't consider the + // fd as having been signaled. + ASSIGN_OR_RETURN( + bool resolved, + wait_handles[i]->object()->TryResolveWakeOnFd(poll_fds[i].fd)); + if (!resolved) { + ++(*out_unsignaled_count); + continue; + } + + // Successful wait. Kill the fd so it is ignored on the next poll. + poll_fds[i].fd = kInvalidFd; + *out_any_signaled_index = i; + } else if (poll_fds[i].revents) { + if (poll_fds[i].revents & POLLERR) { + return InternalErrorBuilder(IREE_LOC); + } else if (poll_fds[i].revents & POLLHUP) { + return CancelledErrorBuilder(IREE_LOC); + } else if (poll_fds[i].revents & POLLNVAL) { + return InvalidArgumentErrorBuilder(IREE_LOC); + } else { + return UnknownErrorBuilder(IREE_LOC); + } + } else if (poll_fds[i].fd != kInvalidFd) { + ++(*out_unsignaled_count); + } + } + + return OkStatus(); +} + +} // namespace + +// static +std::atomic<uint64_t> WaitHandle::next_unique_id_{1}; + +// static +WaitHandle WaitHandle::AlwaysSignaling() { + class AlwaysSignalingObject : public WaitableObject { + public: + std::string DebugString() const override { return "signal"; } + StatusOr<std::pair<FdType, int>> AcquireFdForWait( + absl::Time deadline) override { + return std::make_pair(FdType::kPermanent, kSignaledFd); + } + StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; } + }; + static auto* obj = new AlwaysSignalingObject(); + return WaitHandle(add_ref(obj)); +} + +// static +WaitHandle WaitHandle::AlwaysFailing() { + class AlwaysFailingObject : public WaitableObject { + public: + std::string DebugString() const override { return "fail"; } + StatusOr<std::pair<FdType, int>> AcquireFdForWait( + absl::Time deadline) override { + return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject"; + } + StatusOr<bool> TryResolveWakeOnFd(int fd) override { + return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject"; + } + }; + static auto* obj = new AlwaysFailingObject(); + return WaitHandle(add_ref(obj)); +} + +// static +Status WaitHandle::WaitAll(WaitHandleSpan wait_handles, absl::Time deadline) { + if (wait_handles.empty()) return OkStatus(); + + // Build the list of pollfds to wait on. + ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline)); + + // Loop until all handles have been signaled or the deadline is exceeded. + int unsignaled_count = 0; + do { + int any_signaled_index = 0; + RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline, + &any_signaled_index, &unsignaled_count)); + } while (unsignaled_count > 0 && absl::Now() < deadline); + + if (unsignaled_count == 0) { + // All waits resolved. + return OkStatus(); + } else { + // One or more were unsignaled. + return DeadlineExceededErrorBuilder(IREE_LOC); + } +} + +// static +StatusOr<bool> WaitHandle::TryWaitAll(WaitHandleSpan wait_handles) { + auto status = WaitAll(wait_handles, absl::InfinitePast()); + if (status.ok()) { + return true; + } else if (IsDeadlineExceeded(status)) { + return false; + } + return status; +} + +// static +StatusOr<int> WaitHandle::WaitAny(WaitHandleSpan wait_handles, + absl::Time deadline) { + if (wait_handles.empty()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "At least one wait handle is required for WaitAny"; + } + + // Build the list of pollfds to wait on. + ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline)); + + // Poll once; this makes a WaitAny just a WaitMulti that doesn't loop. + int any_signaled_index = -1; + int unsignaled_count = 0; + RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline, + &any_signaled_index, &unsignaled_count)); + if (any_signaled_index == -1) { + // No wait handles were valid. Pretend 0 was signaled. + return 0; + } + return any_signaled_index; +} + +// static +StatusOr<int> WaitHandle::TryWaitAny(WaitHandleSpan wait_handles) { + auto status_or = WaitAny(wait_handles, absl::InfinitePast()); + return IsDeadlineExceeded(status_or.status()) ? -1 : status_or; +} + +// Storage for static class variables; these won't be needed when we can use +// c++17 everywhere. +constexpr int WaitableObject::kInvalidFd; +constexpr int WaitableObject::kSignaledFd; + +WaitHandle::WaitHandle(ref_ptr<WaitableObject> object) + : unique_id_(++next_unique_id_), object_(std::move(object)) {} + +WaitHandle::~WaitHandle() { Dispose(); } + +void WaitHandle::Dispose() { object_.reset(); } + +WaitHandle::WaitHandle(WaitHandle&& other) + : unique_id_(other.unique_id_), object_(std::move(other.object_)) { + other.unique_id_ = 0; +} + +WaitHandle& WaitHandle::operator=(WaitHandle&& other) { + if (this != std::addressof(other)) { + // Close current handle. + Dispose(); + + // Take ownership of handle and resources. + object_ = std::move(other.object_); + + other.unique_id_ = ++next_unique_id_; + } + return *this; +} + +std::string WaitHandle::DebugString() const { + return object_ ? object_->DebugString() : absl::StrCat("wh_", unique_id_); +} + +StatusOr<bool> WaitHandle::TryWait() { + auto status = WaitAll({this}, absl::InfinitePast()); + if (status.ok()) { + return true; + } else if (IsDeadlineExceeded(status)) { + return false; + } + return status; +} + +ManualResetEvent::ManualResetEvent(const char* debug_name) + : debug_name_(debug_name) { + Initialize(); +} + +ManualResetEvent::~ManualResetEvent() { Dispose(); } + +void ManualResetEvent::Initialize() { +#if defined(IREE_HAS_EVENTFD) + // Create with an eventfd by default when we support it. + // eventfd has lower overhead than pipes (the syscalls are cheap). + // This usually will only fail if the system is completely out of handles. + // + // Docs: http://man7.org/linux/man-pages/man2/eventfd.2.html + fd_type_ = FdType::kEventFd; + fd_ = Syscall(::eventfd, 0, EFD_CLOEXEC | EFD_NONBLOCK).ValueOrDie(); +#elif defined(IREE_HAS_PIPE) + // Android/Linux/iOS-compatible POSIX pipe handle. + // Two handles are generated: one for transmitting and one for receiving. + // + // Docs: http://man7.org/linux/man-pages/man2/pipe.2.html + fd_type_ = FdType::kPipe; + int pipefd[2]; + Syscall(::pipe, pipefd).ValueOrDie(); + Syscall(::fcntl, pipefd[0], F_SETFL, O_NONBLOCK).ValueOrDie(); + fd_ = pipefd[0]; + write_fd_ = pipefd[1]; +#else +// NOTE: sync_file does not use Notifier as they come from the kernel. +#error "No fd-based sync primitive on this platform" +#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE / etc +} + +void ManualResetEvent::Dispose() { + if (fd_ != kInvalidFd) { + // Always signal, as we need to ensure waiters are woken. + CHECK_OK(Set()); + Syscall(::close, fd_).ValueOrDie(); + fd_ = kInvalidFd; + } + if (write_fd_ != kInvalidFd) { + Syscall(::close, write_fd_).ValueOrDie(); + write_fd_ = kInvalidFd; + } +} + +ManualResetEvent::ManualResetEvent(ManualResetEvent&& other) + : fd_type_(other.fd_type_), + fd_(other.fd_), + write_fd_(other.write_fd_), + debug_name_(other.debug_name_) { + other.fd_type_ = FdType::kPermanent; + other.fd_ = kInvalidFd; + other.write_fd_ = kInvalidFd; + other.debug_name_ = nullptr; +} + +ManualResetEvent& ManualResetEvent::operator=(ManualResetEvent&& other) { + if (this != std::addressof(other)) { + Dispose(); + fd_type_ = other.fd_type_; + fd_ = other.fd_; + write_fd_ = other.write_fd_; + debug_name_ = other.debug_name_; + other.fd_type_ = FdType::kPermanent; + other.fd_ = kInvalidFd; + other.write_fd_ = kInvalidFd; + other.debug_name_ = nullptr; + other.Initialize(); + } + return *this; +} + +std::string ManualResetEvent::DebugString() const { + if (debug_name_) { + return debug_name_; + } +#if defined(IREE_HAS_EVENTFD) + return absl::StrCat("eventfd_", fd_); +#elif defined(IREE_HAS_PIPE) + return absl::StrCat("pipe_", fd_, "_", write_fd_); +#else + return absl::StrCat("unknown_", fd_, "_", write_fd_); +#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE +} + +Status ManualResetEvent::Set() { +#if defined(IREE_HAS_EVENTFD) + return Syscall(::eventfd_write, fd_, 1ull).status(); +#elif defined(IREE_HAS_PIPE) + char buf = '\n'; + return Syscall(::write, write_fd_, &buf, 1).status(); +#else + return UnimplementedErrorBuilder(IREE_LOC) + << "No fd-based sync primitive on this platform"; +#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE +} + +Status ManualResetEvent::Reset() { return ClearFd(fd_type_, fd_); } + +WaitHandle ManualResetEvent::OnSet() { return WaitHandle(add_ref(this)); } + +} // namespace iree
diff --git a/base/wait_handle.h b/base/wait_handle.h new file mode 100644 index 0000000..8cf8946 --- /dev/null +++ b/base/wait_handle.h
@@ -0,0 +1,321 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BASE_WAIT_HANDLE_H_ +#define IREE_BASE_WAIT_HANDLE_H_ + +#include <atomic> +#include <cstdint> +#include <string> +#include <utility> + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "base/ref_ptr.h" +#include "base/status.h" +#include "base/time.h" + +namespace iree { + +// Interfaces for waitable objects that can produce WaitHandles. +// WaitableObjects are much like ::thread::Selectable, only they support both +// the classic locking style as well as file descriptors for use with select(). +// +// Usage: +// class MyWaitableObject : public WaitableObject { +// public: +// std::string DebugString() const override { return "something useful"; } +// WaitHandle OnAsyncTask() { +// return WaitHandle(retain_ref(this)); +// } +// private: +// StatusOr<std::pair<FdType, int>> AcquireFdForWait( +// absl::Time deadline) override { +// // If blocking traditionally do so now and then return this: +// return std::make_pair(FdType::kPermanent, kSignaledFd); +// // Otherwise, see ManualResetEvent for an example using fds. +// } +// StatusOr<bool> TryResolveWakeOnFd(int fd) override { +// // Return true iff the object is really acquired, such as the semaphore +// // being decremented. +// return true; +// } +// }; +class WaitableObject : public RefObject<WaitableObject> { + public: + // Indicates that a file descriptor is invalid. It will not block when waited + // upon. + constexpr static int kInvalidFd = -1; + // Indicates that a file descriptor should be treated as signaled. + // Waiting on this fd should return as if it has already been signaled. + constexpr static int kSignaledFd = -2; + + // Defines the type of the native handle used for synchronization. + enum class FdType : uint16_t { + // Event has no handle and should be treated as permanently signaled. + kPermanent, + + // Android/Linux/iOS-compatible POSIX pipe handle. + // Two handles are generated: one for transmitting and one for receiving. + // + // More information: + // http://man7.org/linux/man-pages/man2/pipe.2.html + kPipe, + + // Android/Linux eventfd handle. + // These are akin to pipe() but require only a single handle and have + // significantly lower overhead (equivalent if not slightly better than + // pthreads condvars). + // + // eventfds support acting as both semaphores and auto reset events. + // + // More information: + // http://man7.org/linux/man-pages/man2/eventfd.2.html + kEventFd, + + // Android/Linux sync_file handle (aka 'sync fence'). + // The handle is allocated indirectly by the device driver via the + // <linux/sync_file.h> API. It may be waited upon with poll(), select(), or + // epoll() and must be closed with close() when no longer required. If + // waiting on multiple sync_files the caller should first merge them + // together. + // + // A sync_file must only be used as fences (one-shot manual reset events). + // + // More information: + // https://www.kernel.org/doc/Documentation/sync_file.txt + // https://lwn.net/Articles/702339/ + // https://source.android.com/devices/graphics/implement-vsync#explicit_synchronization + kSyncFile, + }; + + virtual ~WaitableObject() = default; + + // Returns a string representing the object, either specified as a debug_name + // or a unique ID. + virtual std::string DebugString() const = 0; + + // Attempts to acquire a file descriptor for the waitable objects by the given + // |deadline|. In many cases this will return immediately with a valid fd. + // + // In cases where the file descriptor may not be available the call may block + // until either it is available or the |deadline| has elapsed. Use + // absl::InfinitePast() to prevent blocking. + // + // Returns a valid file descriptor or kInvalidFd as an indication that the + // object should not be waited on (already signaled, etc). Can return + // kSignaledFd to indicate that it's already known that the handle has been + // signaled and the caller should resolve as if it caused a wake normally. + virtual StatusOr<std::pair<FdType, int>> AcquireFdForWait( + absl::Time deadline) = 0; + + // Tries to resolve the object with the given |fd|. + // In many cases this will no-op, however some types may require additional + // checks to ensure that the wait operation succeeded (such as semaphores + // that may need to query a count). If resolution fails the waitable object + // must not be considered signaled. This call will never block. + virtual StatusOr<bool> TryResolveWakeOnFd(int fd) = 0; +}; + +// Handle to waitable objects. +// WaitHandles are created by a particular synchronization primitive, such as +// Fence, as a way for one or more observers to poll or wait for notification. +// +// External synchronization primitives can be wrapped in WaitHandles to enable +// other libraries or languages to be waited on alongside WaitHandles created +// by the IREE primitives like Fence. See the notes on WaitHandleType for a list +// of handle types that are supported. +// +// Wait handles are thread-safe in that multiple threads may be waiting on them +// concurrently. +class WaitHandle { + public: + // Returns a WaitHandle that when waited on will never block. + static WaitHandle AlwaysSignaling(); + + // Returns a WaitHandle that when waited on will always fail. + static WaitHandle AlwaysFailing(); + + using WaitHandleSpan = absl::Span<WaitHandle* const>; + + // Blocks the caller until all passed |wait_handles| are signaled or the + // |deadline| elapses. + // + // Returns success if the wait is successful and all events have been + // signaled. + // + // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all handles + // having been signaled. Note that a subset of the |wait_handles| may have + // been signaled and each can be queried to see which one. + static Status WaitAll(WaitHandleSpan wait_handles, absl::Time deadline); + static Status WaitAll(WaitHandleSpan wait_handles, absl::Duration timeout) { + return WaitAll(wait_handles, RelativeTimeoutToDeadline(timeout)); + } + static Status WaitAll(WaitHandleSpan wait_handles) { + return WaitAll(wait_handles, absl::InfiniteFuture()); + } + + // Tries waiting on the handles and returns immediately if it would have + // blocked. The caller will not be blocked even if a handle has not yet been + // signaled. + // + // Returns true if all handles have been signaled. + static StatusOr<bool> TryWaitAll(WaitHandleSpan wait_handles); + + // Blocks the caller until at least one of the |wait_handles| is signaled or + // the |deadline| elapses. + // + // Returns the index into |wait_handles| of a handle that was signaled. Note + // that more than one handle may have been signaled and all of the other + // |wait_handles| should be queried or waited on again until waits for them + // succeed. + // + // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any handles + // having been signaled. + static StatusOr<int> WaitAny(WaitHandleSpan wait_handles, + absl::Time deadline); + static StatusOr<int> WaitAny(WaitHandleSpan wait_handles, + absl::Duration timeout) { + return WaitAny(wait_handles, RelativeTimeoutToDeadline(timeout)); + } + static StatusOr<int> WaitAny(WaitHandleSpan wait_handles) { + return WaitAny(wait_handles, absl::InfiniteFuture()); + } + + // Tries waiting for at least one handle to complete and returns immediately + // if none have been. The caller will not be blocked even if a handle has not + // yet been signaled. + // + // Returns the index into |wait_handles| of a handle that was signaled. Note + // that more than one handle may have been signaled and all of the other + // |wait_handles| should be queried or waited on again until waits for them + // succeed. + // + // Returns -1 if no handles were signaled. + static StatusOr<int> TryWaitAny(WaitHandleSpan wait_handles); + + // Default constructor creates a permanently signaled handle. + // Waiting on this handle will never block. + WaitHandle() = default; + + // Wraps an existing sync file descriptor. + // Ownership of the file descriptor is transferred to the WaitHandle and must + // be duplicated by the caller if they want to continue using it. + explicit WaitHandle(ref_ptr<WaitableObject> object); + + ~WaitHandle(); + + // Copying not supported. Create a new WaitHandle from the source. + WaitHandle(const WaitHandle&) = delete; + WaitHandle& operator=(const WaitHandle&) = delete; + + // Moving supported; sync primitive ownership is transferred. + WaitHandle(WaitHandle&& other); + WaitHandle& operator=(WaitHandle&& other); + + // Unique ID for the WaitHandle instance. + // Two wait handles, even if waiting on the same underlying primitive, will + // have differing unique_ids. This can be used for deduping the handles or + // storing handles in a map. + uint64_t unique_id() const { return unique_id_; } + + // Returns a unique string representing the handle. + std::string DebugString() const; + + // Blocks the caller until the handle is signaled or the |deadline| elapses. + // + // If waiting on multiple wait handles use WaitAll or WaitAny instead of + // multiple calls to Wait as they can significantly reduce overhead. + // + // Returns success if the wait is successful and the |wait_handle| was + // signaled. Returns DEADLINE_EXCEEDED if the timeout elapses without the + // handle having been signaled. + Status Wait(absl::Time deadline) { return WaitAll({this}, deadline); } + Status Wait(absl::Duration timeout) { + return WaitAll({this}, RelativeTimeoutToDeadline(timeout)); + } + Status Wait() { return WaitAll({this}, absl::InfiniteFuture()); } + + // Tries waiting on the handle and returns immediately if it would have + // waited. The caller will not be blocked even if the handle has not yet been + // signaled. + // + // Returns true if the handle has been signaled. + StatusOr<bool> TryWait(); + + // These accessors should generally be considered opaque but may be useful to + // code trying to interop with other runtimes. + const ref_ptr<WaitableObject>& object() const { return object_; } + + private: + // Disposes the handle by closing the fd and issuing callbacks. + void Dispose(); + + static std::atomic<uint64_t> next_unique_id_; + + uint64_t unique_id_ = 0; + ref_ptr<WaitableObject> object_; +}; + +// A manually-resettable event primitive. +// Effectively a binary semaphore with a maximum_count of 1 when running in +// auto-reset mode but also provides a sticky manual reset mode. +class ManualResetEvent : public WaitableObject { + public: + explicit ManualResetEvent(const char* debug_name = nullptr); + + ~ManualResetEvent() override; + + // Copying not supported. + ManualResetEvent(const ManualResetEvent&) = delete; + ManualResetEvent& operator=(const ManualResetEvent&) = delete; + + // Moving supported; sync primitive ownership is transferred. + ManualResetEvent(ManualResetEvent&& other); + ManualResetEvent& operator=(ManualResetEvent&& other); + + std::string DebugString() const override; + + // Sets the specified event object to the signaled state. + // The event stays signaled until Reset is called. Multiple waiters will be + // woken. + Status Set(); + + // Resets the specified event object to the nonsignaled state. + // Resetting an event that is already reset has no effect. + Status Reset(); + + // Returns a WaitHandle that will be signaled when the event is set. + WaitHandle OnSet(); + + protected: + void Initialize(); + void Dispose(); + + StatusOr<std::pair<FdType, int>> AcquireFdForWait( + absl::Time deadline) override { + return std::make_pair(fd_type_, fd_); + } + StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; } + + FdType fd_type_ = FdType::kPermanent; + int fd_ = kInvalidFd; + int write_fd_ = kInvalidFd; // Used only for fd_type_ == kPipe. + const char* debug_name_ = nullptr; +}; + +} // namespace iree + +#endif // IREE_BASE_WAIT_HANDLE_H_
diff --git a/base/wait_handle_test.cc b/base/wait_handle_test.cc new file mode 100644 index 0000000..ec1f28e --- /dev/null +++ b/base/wait_handle_test.cc
@@ -0,0 +1,555 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/wait_handle.h" + +#include <unistd.h> + +#include <string> +#include <thread> // NOLINT +#include <type_traits> + +#include "absl/time/time.h" +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +// StatusOr<bool> will be true if the status is ok, which is bad. +#define ASSERT_STATUSOR_TRUE(x) ASSERT_TRUE(x.ValueOrDie()) +#define ASSERT_STATUSOR_FALSE(x) ASSERT_FALSE(x.ValueOrDie()) + +namespace iree { +namespace { + +using ::testing::_; +using ::testing::Return; + +// Tests the AlwaysSignaling helper. +TEST(WaitHandleTest, AlwaysSignaling) { + ASSERT_OK(WaitHandle::AlwaysSignaling().Wait()); + EXPECT_FALSE(WaitHandle::AlwaysSignaling().DebugString().empty()); +} + +// Tests the AlwaysFailing helper. +TEST(WaitHandleTest, AlwaysFailing) { + ASSERT_FALSE(WaitHandle::AlwaysFailing().Wait().ok()); + EXPECT_FALSE(WaitHandle::AlwaysFailing().DebugString().empty()); +} + +// Tests the basic lifecycle of a permanently signaled wait handle. +TEST(WaitHandleTest, LifecyclePermanentSignaled) { + // Just to be sure it's ok to safely no-op a WaitHandle value. + WaitHandle wh_never_used; + (void)wh_never_used; + + // Try waiting; should return immediately. + WaitHandle wh0; + ASSERT_OK(wh0.Wait()); + + // Waits on multiple permanent handles should be ok. + WaitHandle wh1; + ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1})); +} + +// Tests moving permanent WaitHandles around. +TEST(WaitHandleTest, MovePermanent) { + WaitHandle wh0; + WaitHandle wh1{std::move(wh0)}; + WaitHandle wh2 = std::move(wh1); + wh1 = std::move(wh2); +} + +// Tests moving around real handles (that may require closing). +TEST(WaitHandleTest, MoveRealHandle) { + ManualResetEvent fence0; + WaitHandle wh0 = fence0.OnSet(); + WaitHandle wh1{std::move(wh0)}; + WaitHandle wh2 = std::move(wh1); + wh1 = std::move(wh2); + + // Now overwrite the handle value to force a close. + ManualResetEvent fence1; + WaitHandle wh3 = fence1.OnSet(); + wh1 = std::move(wh3); + wh1 = WaitHandle(); // Ensure handle dies first. +} + +// Tests the various forms of waiting on a single WaitHandle. +// Since these just call WaitAll we leave the involved testing to those. +TEST(WaitHandleTest, SingleWait) { + WaitHandle wh; + ASSERT_OK(wh.Wait()); + ASSERT_OK(wh.Wait(absl::Now() + absl::Seconds(1))); + ASSERT_OK(wh.Wait(absl::Seconds(1))); + ASSERT_STATUSOR_TRUE(wh.TryWait()); +} + +// Tests using WaitAll with no valid handles. This should no-op. +TEST(WaitHandleTest, WaitAllNop) { + ASSERT_OK(WaitHandle::WaitAll({})); + ASSERT_OK(WaitHandle::WaitAll({nullptr})); + ASSERT_OK(WaitHandle::WaitAll({nullptr, nullptr})); +} + +// Tests polling with WaitAll with multiple wait handles. +TEST(WaitHandleTest, WaitAllPoll) { + ManualResetEvent fence0; + WaitHandle wh0 = fence0.OnSet(); + ManualResetEvent fence1; + WaitHandle wh1 = fence1.OnSet(); + + // Poll; should return immediately with timeout. + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()))); + + // Notify fence1. + ASSERT_OK(fence1.Set()); + + // Poll; should return immediately with timeout as fence1 is not signaled. + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()))); + + // Notify fence0. + ASSERT_OK(fence0.Set()); + + // Poll again; should return immediately with success. + ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())); +} + +// Tests waiting when the first file handle is invalid. This is to verify a +// workaround for bad poll() behavior with fds[0] == -1. +TEST(WaitHandleTest, WaitAllWithInvalid0) { + ManualResetEvent fence; + WaitHandle wh = fence.OnSet(); + + // Poll; should return immediately with timeout as fence is not signaled. + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast()))); + + // Notify fence. + ASSERT_OK(fence.Set()); + + // Poll again; should return immediately with success. + ASSERT_OK(WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast())); +} + +// Tests exceeding the timeout deadline with WaitAll. +TEST(WaitHandleTest, WaitAllTimeout) { + ManualResetEvent fence; + WaitHandle wh = fence.OnSet(); + + // Wait with timeout on the unsignaled fence: + // Via polling (should never block): + ASSERT_TRUE( + IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::InfinitePast()))); + ASSERT_STATUSOR_FALSE(WaitHandle::TryWaitAll({&wh})); + // Via time in the near future (should block): + ASSERT_TRUE( + IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250)))); + // Via time in the past, should exceed deadline. + ASSERT_TRUE( + IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250)))); + + // Notify and ensure no more timeouts. + ASSERT_OK(fence.Set()); + ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::InfinitePast())); + ASSERT_STATUSOR_TRUE(WaitHandle::TryWaitAll({&wh})); + ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250))); + + // Via time in the past, should exceed deadline even if signaled. + ASSERT_TRUE( + IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250)))); +} + +// Tests using WaitAll to wait on other threads. +TEST(WaitHandleTest, WaitAllThreaded) { + // Spin up two threads. + ManualResetEvent fence0; + std::thread t0{[&]() { + ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); + ASSERT_OK(fence0.Set()); + }}; + ManualResetEvent fence1; + std::thread t1{[&]() { + ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); + ASSERT_OK(fence1.Set()); + }}; + + // Wait on both threads to complete. + WaitHandle wh0 = fence0.OnSet(); + WaitHandle wh1 = fence1.OnSet(); + ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1})); + + t0.join(); + t1.join(); +} + +// Tests using WaitAll with multiple wait handles from the same fence. +TEST(WaitHandleTest, WaitAllSameSource) { + ManualResetEvent fence; + WaitHandle wh0 = fence.OnSet(); + WaitHandle wh1 = fence.OnSet(); + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()))); + ASSERT_OK(fence.Set()); + ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1})); +} + +// Tests using WaitAll with literally the same wait handles. +TEST(WaitHandleTest, WaitAllSameHandle) { + ManualResetEvent fence; + WaitHandle wh = fence.OnSet(); + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAll({&wh, &wh}, absl::InfinitePast()))); + ASSERT_OK(fence.Set()); + ASSERT_OK(WaitHandle::WaitAll({&wh, &wh})); +} + +// Tests WaitAll when a wait handle fails. +TEST(WaitHandleTest, WaitAllFailure) { + WaitHandle good_wh; + // Create a purposefully bad handle to induce an error. + WaitHandle bad_wh = WaitHandle::AlwaysFailing(); + // Should fail with some posixy error. + ASSERT_FALSE(WaitHandle::WaitAll({&good_wh, &bad_wh}).ok()); +} + +// Tests using WaitAny with no valid handles. This should no-op. +TEST(WaitHandleTest, WaitAnyNop) { + ASSERT_TRUE(IsInvalidArgument(WaitHandle::WaitAny({}).status())); + ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({nullptr})); + ASSERT_EQ(0, index); + ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({nullptr, nullptr})); + ASSERT_EQ(0, index); +} + +// Tests polling with WaitAny with multiple wait handles. +TEST(WaitHandleTest, WaitAnyPoll) { + ManualResetEvent fence0; + WaitHandle wh0 = fence0.OnSet(); + ManualResetEvent fence1; + WaitHandle wh1 = fence1.OnSet(); + + // Poll; should return immediately with timeout. + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status())); + + // Notify fence1. + ASSERT_OK(fence1.Set()); + + // Poll; should return immediately with fence1 signaled. + ASSERT_OK_AND_ASSIGN(int index, + WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast())); + EXPECT_EQ(1, index); + + // Notify fence0. + ASSERT_OK(fence0.Set()); + + // Poll again; should return immediately; which one is signaled is undefined. + ASSERT_OK_AND_ASSIGN(index, + WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast())); + ASSERT_TRUE(index == 0 || index == 1); +} + +// Tests exceeding the timeout deadline with WaitAny. +TEST(WaitHandleTest, WaitAnyTimeout) { + ManualResetEvent fence0; + WaitHandle wh0 = fence0.OnSet(); + ManualResetEvent fence1; + WaitHandle wh1 = fence1.OnSet(); + + // Wait with timeout on the unsignaled fences: + // Via polling (should never block): + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status())); + ASSERT_OK_AND_ASSIGN(int index, WaitHandle::TryWaitAny({&wh0, &wh1})); + ASSERT_EQ(-1, index); + // Via time in the near future (should block): + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250)).status())); + + // Notify one of the fences. Should return immediately. + ASSERT_OK(fence1.Set()); + ASSERT_OK_AND_ASSIGN(index, + WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast())); + ASSERT_EQ(1, index); + ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0, &wh1})); + ASSERT_EQ(1, index); + ASSERT_OK_AND_ASSIGN( + index, WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250))); + ASSERT_EQ(1, index); + + // The unnotified fence should still timeout. + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh0}, absl::InfinitePast()).status())); + ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0})); + ASSERT_EQ(-1, index); + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250)).status())); + + // Notify last fence and ensure complete. + ASSERT_OK(fence0.Set()); + ASSERT_OK_AND_ASSIGN(index, + WaitHandle::WaitAny({&wh0}, absl::InfinitePast())); + ASSERT_EQ(0, index); + ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0})); + ASSERT_EQ(0, index); + ASSERT_OK_AND_ASSIGN(index, + WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250))); + ASSERT_EQ(0, index); +} + +// Tests using WaitAny to wait on other threads. +TEST(WaitHandleTest, WaitAnyThreaded) { + // Spin up two threads. + // t1 will wait on t0 such that they will act in sequence. + ManualResetEvent fence0; + std::thread t0{[&]() { + ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); + ASSERT_OK(fence0.Set()); + }}; + ManualResetEvent fence1; + std::thread t1{[&]() { + ASSERT_OK(fence0.OnSet().Wait()); + ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); + ASSERT_OK(fence1.Set()); + }}; + + // Wait on both threads. We expect 0 to complete first. + WaitHandle wh0 = fence0.OnSet(); + WaitHandle wh1 = fence1.OnSet(); + ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1})); + ASSERT_EQ(0, index); + + // Now wait for thread 1. + ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({&wh1})); + ASSERT_EQ(0, index); + + t0.join(); + t1.join(); +} + +// Tests using WaitAny with multiple wait handles from the same fence. +TEST(WaitHandleTest, WaitAnySameSource) { + ManualResetEvent fence; + WaitHandle wh0 = fence.OnSet(); + WaitHandle wh1 = fence.OnSet(); + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status())); + ASSERT_OK(fence.Set()); + ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1})); + ASSERT_TRUE(index == 0 || index == 1); +} + +// Tests using WaitAny with literally the same wait handles. +TEST(WaitHandleTest, WaitAnySameHandle) { + ManualResetEvent fence; + WaitHandle wh = fence.OnSet(); + ASSERT_TRUE(IsDeadlineExceeded( + WaitHandle::WaitAny({&wh, &wh}, absl::InfinitePast()).status())); + ASSERT_OK(fence.Set()); + ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh, &wh})); + ASSERT_TRUE(index == 0 || index == 1); +} + +// Tests WaitAny when a wait handle fails. +TEST(WaitHandleTest, WaitAnyFailure) { + WaitHandle good_wh; + // Create a purposefully bad handle to induce an error. + WaitHandle bad_wh = WaitHandle::AlwaysFailing(); + // Should fail with some posixy error. + ASSERT_FALSE(WaitHandle::WaitAny({&good_wh, &bad_wh}).ok()); +} + +// ManualResetEvent with innards exposed. Meh. +class ExposedManualResetEvent : public ManualResetEvent { + public: + using ManualResetEvent::AcquireFdForWait; + using ManualResetEvent::TryResolveWakeOnFd; +}; + +// Mock type for the WaitableObject methods. +class MockWaitableObject : public ::testing::StrictMock<WaitableObject> { + public: + MockWaitableObject() : ::testing::StrictMock<WaitableObject>() {} + + MOCK_CONST_METHOD0(DebugString, std::string()); + MOCK_METHOD1(AcquireFdForWait, + StatusOr<std::pair<FdType, int>>(absl::Time deadline)); + MOCK_METHOD1(TryResolveWakeOnFd, StatusOr<bool>(int fd)); + + WaitHandle OnSomething() { return WaitHandle(add_ref(this)); } +}; + +// Tests normal AcquireFdForWait + TryResolveWakeOnFd use. +TEST(WaitableObjectTest, AcquireAndResolve) { + MockWaitableObject mwo; + WaitHandle wh = mwo.OnSomething(); + + // Use a MRE for testing, as we can just use its fd. + ExposedManualResetEvent mre; + + // Try waiting; we should see the AcquireFdForWait and then return because + // the fd has not been resolved. + EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) { + // Return the valid FD from the MRE. + return mre.AcquireFdForWait(deadline); + }); + ASSERT_STATUSOR_FALSE(wh.TryWait()); + + // Signal the MRE. + ASSERT_OK(mre.Set()); + + // Try waiting again; we should get the AcquireFdForWait and then also get + // the TryResolveWakeOnFd. + EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) { + // Return the valid (and now signaled) FD from the MRE. + return mre.AcquireFdForWait(deadline); + }); + EXPECT_CALL(mwo, TryResolveWakeOnFd(_)).WillOnce(Return(true)); + ASSERT_STATUSOR_TRUE(wh.TryWait()); +} + +// Tests timing out in AcquireFdForWait. +TEST(WaitableObjectTest, AcquireFdForWaitTimeout) { + ManualResetEvent mre; + WaitHandle always_wait = mre.OnSet(); + WaitHandle always_signal = WaitHandle::AlwaysSignaling(); + MockWaitableObject mwo; + WaitHandle wh = mwo.OnSomething(); + + // Make the AcquireFdForWait take longer than the timeout. We should hit + // deadline exceeded even though always_wait hasn't be signaled. + EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([](absl::Time deadline) { + ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(10))); + return std::make_pair(WaitableObject::FdType::kPermanent, + WaitableObject::kInvalidFd); + }); + ASSERT_TRUE(IsDeadlineExceeded(WaitHandle::WaitAll( + {&wh, &always_signal}, absl::Now() - absl::Milliseconds(250)))); +} + +// Tests TryResolveWakeOnFd when a handle is a permanent kSignaledFd. +TEST(WaitableObjectTest, SignaledFd) { + MockWaitableObject mwo; + WaitHandle wh = mwo.OnSomething(); + + // Return the kSignaledFd handle and expect that we still get our notify call. + // We can do this multiple times. + for (int i = 0; i < 4; ++i) { + EXPECT_CALL(mwo, AcquireFdForWait(_)) + .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent, + WaitableObject::kSignaledFd))); + EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd)) + .WillOnce(Return(true)); + ASSERT_STATUSOR_TRUE(wh.TryWait()); + } +} + +// Tests that waiting will not resolve if TryResolveWakeOnFd returns false. +TEST(WaitableObjectTest, UnresolvedWake) { + MockWaitableObject mwo; + WaitHandle wh = mwo.OnSomething(); + + // Fail to resolve the first time. + // Since we are only trying to wait it should bail. + EXPECT_CALL(mwo, AcquireFdForWait(_)) + .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent, + WaitableObject::kSignaledFd))); + EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd)) + .WillOnce(Return(false)); + ASSERT_STATUSOR_FALSE(wh.TryWait()); + + // Resolve on the next try. + EXPECT_CALL(mwo, AcquireFdForWait(_)) + .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent, + WaitableObject::kSignaledFd))); + EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd)) + .WillOnce(Return(true)); + ASSERT_STATUSOR_TRUE(wh.TryWait()); +} + +// Tests the normal lifecycle of a ManualResetEvent. +TEST(ManualResetEventTest, Lifecycle) { + ManualResetEvent ev; + EXPECT_FALSE(ev.DebugString().empty()); + WaitHandle wh0 = ev.OnSet(); + EXPECT_EQ(ev.DebugString(), wh0.DebugString()); + WaitHandle wh1 = ev.OnSet(); + EXPECT_EQ(ev.DebugString(), wh1.DebugString()); + // Should not be set. + ASSERT_STATUSOR_FALSE(wh0.TryWait()); + ASSERT_STATUSOR_FALSE(wh1.TryWait()); + // Set should be sticky. + ASSERT_OK(ev.Set()); + ASSERT_STATUSOR_TRUE(wh0.TryWait()); + ASSERT_STATUSOR_TRUE(wh1.TryWait()); + // Reset should clear. + ASSERT_OK(ev.Reset()); + ASSERT_STATUSOR_FALSE(wh0.TryWait()); + ASSERT_STATUSOR_FALSE(wh1.TryWait()); + // Setting again should enable the previous WaitHandles to be signaled. + ASSERT_OK(ev.Set()); + ASSERT_STATUSOR_TRUE(wh0.TryWait()); + ASSERT_STATUSOR_TRUE(wh1.TryWait()); +} + +// Tests moving ManualResetEvents around. +TEST(ManualResetEventTest, Move) { + ManualResetEvent ev0; + WaitHandle wh = ev0.OnSet(); + ManualResetEvent ev1{std::move(ev0)}; + ManualResetEvent ev2 = std::move(ev1); + ev1 = std::move(ev2); + ASSERT_OK(ev1.Set()); + ASSERT_STATUSOR_TRUE(wh.TryWait()); +} + +// Tests redundantly setting and resetting ManualResetEvents. +TEST(ManualResetEventTest, RedundantUse) { + ManualResetEvent ev; + ASSERT_OK(ev.Reset()); + ASSERT_OK(ev.Reset()); + ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie()); + ASSERT_OK(ev.Set()); + ASSERT_OK(ev.Set()); + ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie()); + ASSERT_OK(ev.Reset()); + ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie()); +} + +// Tests waiting on an initially-set ManualResetEvent; +TEST(ManualResetEventTest, SetThenWait) { + ManualResetEvent ev; + ASSERT_OK(ev.Set()); + ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie()); +} + +// Tests that dangling an event will not wake waiters. +// This is intentional (for now); we could with a bit of wrangling make it so +// that WaitableObjects tracked their waiters and ensured they were all cleaned +// up, but that seems hard. Don't drop your objects. +TEST(ManualResetEventTest, NeverSet) { + ManualResetEvent ev; + WaitHandle wh = ev.OnSet(); + ASSERT_STATUSOR_FALSE(wh.TryWait()); + // Kill event to unblock waiters. + ev = ManualResetEvent(); + // Waiter should not have woken. + ASSERT_STATUSOR_FALSE(wh.TryWait()); +} + +} // namespace +} // namespace iree
diff --git a/bindings/python/BUILD b/bindings/python/BUILD new file mode 100644 index 0000000..793b93d --- /dev/null +++ b/bindings/python/BUILD
@@ -0,0 +1,25 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//:build_defs.google.bzl", "iree_py_library") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_py_library( + name = "pathsetup", + imports = ["."], +)
diff --git a/iree/bindings/python/README.md b/bindings/python/README.md similarity index 100% rename from iree/bindings/python/README.md rename to bindings/python/README.md
diff --git a/bindings/python/pyiree/BUILD b/bindings/python/pyiree/BUILD new file mode 100644 index 0000000..151f5ee --- /dev/null +++ b/bindings/python/pyiree/BUILD
@@ -0,0 +1,103 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//:build_defs.google.bzl", "NUMPY_DEPS", "iree_py_extension") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +COMPILER_DEPS = [ + "///compiler/Translation/Sequencer", + "///compiler/Translation/Interpreter", + "///compiler/Translation/SPIRV", +] + +DRIVER_DEPS = [ + "///hal/interpreter:interpreter_driver_module", + "///hal/vulkan:vulkan_driver_module", +] + +iree_py_extension( + name = "binding", + srcs = [ + "binding.cc", + "binding.h", + "compiler.cc", + "compiler.h", + "hal.cc", + "hal.h", + "initialize.cc", + "initialize.h", + "rt.cc", + "rt.h", + "status_utils.cc", + "status_utils.h", + "vm.cc", + "vm.h", + ], + copts = [ + "-fexceptions", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "///base:api", + "///base:init", + "///base:status", + "///hal:api", + "///rt:api", + "///schemas", + "///vm:api", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", + "@iree_pybind11//:pybind11", + ] + COMPILER_DEPS + DRIVER_DEPS, +) + +py_test( + name = "compiler_test", + srcs = ["compiler_test.py"], + python_version = "PY3", + deps = [ + ":binding", + "///bindings/python:pathsetup", + "@absl_py//absl/testing:absltest", + ], +) + +py_test( + name = "hal_test", + srcs = ["hal_test.py"], + python_version = "PY3", + deps = [ + ":binding", + "///bindings/python:pathsetup", + "@absl_py//absl/testing:absltest", + ], +) + +py_test( + name = "runtime_test", + srcs = ["runtime_test.py"], + python_version = "PY3", + deps = NUMPY_DEPS + [ + ":binding", + "@absl_py//absl/testing:absltest", + ], +)
diff --git a/bindings/python/pyiree/binding.cc b/bindings/python/pyiree/binding.cc new file mode 100644 index 0000000..1cbdef7 --- /dev/null +++ b/bindings/python/pyiree/binding.cc
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/binding.h" + +#include "bindings/python/pyiree/compiler.h" +#include "bindings/python/pyiree/hal.h" +#include "bindings/python/pyiree/initialize.h" +#include "bindings/python/pyiree/rt.h" +#include "bindings/python/pyiree/status_utils.h" +#include "bindings/python/pyiree/vm.h" + +namespace iree { +namespace python { + +PYBIND11_MODULE(binding, m) { + m.doc() = "IREE Binding Backend Helpers"; + py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob"); + m.def("initialize_extension", &InitializeExtension); + + auto compiler_m = m.def_submodule("compiler", "IREE compiler support"); + SetupCompilerBindings(compiler_m); + + auto hal_m = m.def_submodule("hal", "IREE HAL support"); + SetupHalBindings(hal_m); + + auto rt_m = m.def_submodule("rt", "IREE RT api"); + SetupRtBindings(rt_m); + + auto vm_m = m.def_submodule("vm", "IREE VM api"); + SetupVmBindings(vm_m); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/binding.h b/bindings/python/pyiree/binding.h new file mode 100644 index 0000000..e470ee6 --- /dev/null +++ b/bindings/python/pyiree/binding.h
@@ -0,0 +1,147 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "base/api.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace pybind11 { +namespace detail { +#if !defined(ABSL_HAVE_STD_OPTIONAL) +// Make absl::optional act like the future C++17 optional for pybind11. +// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional +// and the default type caster is sufficient. +template <typename T> +struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {}; +#endif +} // namespace detail +} // namespace pybind11 + +namespace iree { +namespace python { + +namespace py = pybind11; + +// Wrapper around a blob of memory. +// Used to transport blobs back and forth between C++ and Python. +class OpaqueBlob { + public: + OpaqueBlob() : data_(nullptr), size_(0) {} + OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {} + virtual ~OpaqueBlob() = default; + + void* data() { return data_; } + const void* data() const { return data_; } + size_t size() const { return size_; } + + // Create a free function from the OpaqueBlob shared pointer. + using BufferFreeFn = void (*)(void* self, iree_byte_span_t); + static std::pair<BufferFreeFn, void*> CreateFreeFn( + std::shared_ptr<OpaqueBlob> blob) { + // Note that there are more efficient ways to write this which + // don't bounce through an extra heap alloc, but this is not + // intended to be a high impact code path. + struct Holder { + std::shared_ptr<OpaqueBlob> blob; + }; + Holder* holder = new Holder{std::move(blob)}; + auto free_fn = +([](void* self, iree_byte_span_t) { + Holder* self_holder = static_cast<Holder*>(self); + delete self_holder; + }); + return {free_fn, holder}; + } + + protected: + void* data_; + size_t size_; +}; + +// Opaque blob that owns a vector. +class OpaqueByteVectorBlob : public OpaqueBlob { + public: + OpaqueByteVectorBlob(std::vector<uint8_t> v) + : OpaqueBlob(), v_(std::move(v)) { + data_ = v_.data(); + size_ = v_.size(); + } + + private: + std::vector<uint8_t> v_; +}; + +template <typename T> +struct ApiPtrAdapter {}; + +template <typename Self, typename T> +class ApiRefCounted { + public: + ApiRefCounted() : instance_(nullptr) {} + ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) { + other.instance_ = nullptr; + } + void operator=(const ApiRefCounted&) = delete; + + ~ApiRefCounted() { Release(); } + + // Creates an instance of the ref counted wrapper based on an instance + // that has already been retained. Ownership is transferred to the + // wrapper. + static Self CreateRetained(T* retained_inst) { + auto self = Self(); + self.instance_ = retained_inst; + return self; + } + + // Creates a new instance, retaining the underlying object. + static Self RetainAndCreate(T* non_retained_inst) { + auto self = Self(); + self.instance_ = non_retained_inst; + if (non_retained_inst) { + ApiPtrAdapter<T>::Retain(non_retained_inst); + } + return self; + } + + T* raw_ptr() { + if (!instance_) { + throw std::invalid_argument("API object is null"); + } + return instance_; + } + void Retain() { + if (instance_) { + ApiPtrAdapter<T>::Retain(instance_); + } + } + void Release() { + if (instance_) { + ApiPtrAdapter<T>::Release(instance_); + } + } + + private: + T* instance_; +}; + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/bindings/python/pyiree/compiler.cc b/bindings/python/pyiree/compiler.cc new file mode 100644 index 0000000..562f1c9 --- /dev/null +++ b/bindings/python/pyiree/compiler.cc
@@ -0,0 +1,93 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/compiler.h" + +#include <stdexcept> + +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/initialize.h" +#include "bindings/python/pyiree/status_utils.h" +#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "schemas/module_def_generated.h" + +namespace py = pybind11; + +using namespace mlir; +using namespace mlir::iree_compiler; + +using llvm::MemoryBuffer; +using llvm::MemoryBufferRef; +using llvm::StringRef; + +namespace iree { +namespace python { + +namespace { + +OwningModuleRef parseMLIRModuleFromString(StringRef contents, + MLIRContext* context) { + std::unique_ptr<MemoryBuffer> contents_buffer; + if (contents.back() == 0) { + // If it has a nul terminator, just use as-is. + contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back()); + } else { + // Otherwise, make a copy. + contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED"); + } + + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc()); + OwningModuleRef mlir_module = parseSourceFile(source_mgr, context); + return mlir_module; +} + +} // namespace + +std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) { + InitializeExtension({}); + + MLIRContext context; + + // Arrange to get a view that includes a terminating null to avoid additional + // copy. + const char* moduleAsmChars = moduleAsm.c_str(); + StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1); + + // TODO(laurenzo): This error handling is super hoaky. Hook into the MLIR + // error reporter and plumb through properly. + OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context); + if (!mlirModule) { + throw std::runtime_error("Failed to parse MLIR asm"); + } + + auto moduleBlob = + mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get()); + if (moduleBlob.empty()) { + throw std::runtime_error("Failed to translate MLIR module"); + } + return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob)); +} + +void SetupCompilerBindings(pybind11::module m) { + m.def("compile_module_from_asm", CompileModuleFromAsm); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/compiler.h b/bindings/python/pyiree/compiler.h new file mode 100644 index 0000000..0bd6624 --- /dev/null +++ b/bindings/python/pyiree/compiler.h
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_ + +#include <string> + +#include "bindings/python/pyiree/binding.h" + +namespace iree { +namespace python { + +void SetupCompilerBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
diff --git a/iree/bindings/python/pyiree/compiler_test.py b/bindings/python/pyiree/compiler_test.py similarity index 100% rename from iree/bindings/python/pyiree/compiler_test.py rename to bindings/python/pyiree/compiler_test.py
diff --git a/bindings/python/pyiree/hal.cc b/bindings/python/pyiree/hal.cc new file mode 100644 index 0000000..e7a59a2 --- /dev/null +++ b/bindings/python/pyiree/hal.cc
@@ -0,0 +1,135 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/hal.h" + +#include "hal/api.h" + +namespace iree { +namespace python { + +namespace { + +class HalMappedMemory { + public: + HalMappedMemory(iree_hal_mapped_memory_t mapped_memory, + iree_hal_buffer_view_t* bv) + : mapped_memory_(mapped_memory), bv_(bv) { + iree_hal_buffer_view_retain(bv_); + } + ~HalMappedMemory() { + if (bv_) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); + CHECK_EQ(iree_hal_buffer_unmap(buffer, &mapped_memory_), IREE_STATUS_OK); + iree_hal_buffer_view_release(bv_); + } + } + HalMappedMemory(HalMappedMemory&& other) + : mapped_memory_(other.mapped_memory_), bv_(other.bv_) { + other.bv_ = nullptr; + } + + static HalMappedMemory Create(HalBufferView& bv) { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr()); + iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); + iree_hal_mapped_memory_t mapped_memory; + CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, + 0 /* element_offset */, byte_length, + &mapped_memory), + "Could not map memory"); + return HalMappedMemory(mapped_memory, bv.raw_ptr()); + } + + py::buffer_info ToBufferInfo() { + iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); + iree_shape_t shape; + CheckApiStatus(iree_hal_buffer_view_shape(bv_, &shape), + "Error getting buffer view shape"); + int8_t element_size = iree_hal_buffer_view_element_size(bv_); + iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); + absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> dims; + dims.resize(shape.rank); + for (int i = 0; i < shape.rank; ++i) { + dims[i] = shape.dims[i]; + } + absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> strides; + strides.resize(shape.rank); + for (int i = 1; i < shape.rank; ++i) { + strides[i - 1] = shape.dims[i] * element_size; + } + if (!strides.empty()) { + strides.back() = 1 * element_size; + } + + // TODO(laurenzo): We need to figure out how to propagate dtype in the + // buffer view. + return py::buffer_info( + mapped_memory_.contents.data, element_size, + py::format_descriptor<float>::format(), // TODO(laurenzo): DTYPE! + shape.rank, dims, strides); + } + + private: + iree_hal_mapped_memory_t mapped_memory_; + iree_hal_buffer_view_t* bv_; +}; + +} // namespace + +void SetupHalBindings(pybind11::module m) { + // Enums. + py::enum_<iree_hal_memory_type_t>(m, "MemoryType") + .value("NONE", IREE_HAL_MEMORY_TYPE_NONE) + .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT) + .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) + .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT) + .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED) + .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL) + .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) + .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) + .export_values(); + py::enum_<iree_hal_buffer_usage_t>(m, "BufferUsage") + .value("NONE", IREE_HAL_BUFFER_USAGE_NONE) + .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT) + .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER) + .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING) + .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH) + .value("ALL", IREE_HAL_BUFFER_USAGE_ALL) + .export_values(); + py::enum_<iree_hal_memory_access_t>(m, "MemoryAccess") + .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE) + .value("READ", IREE_HAL_MEMORY_ACCESS_READ) + .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE) + .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD) + .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) + .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL) + .export_values(); + + py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector)); + py::class_<HalBufferView>(m, "BufferView") + .def("map", HalMappedMemory::Create); + py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol()) + .def_buffer(&HalMappedMemory::ToBufferInfo); + py::class_<HalBuffer>(m, "Buffer") + .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer, + py::arg("memory_type"), py::arg("usage"), + py::arg("allocation_size")) + .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), + py::arg("byte_length")) + .def("create_view", &HalBuffer::CreateView, py::arg("shape"), + py::arg("element_size")); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/hal.h b/bindings/python/pyiree/hal.h new file mode 100644 index 0000000..d26bcf0 --- /dev/null +++ b/bindings/python/pyiree/hal.h
@@ -0,0 +1,97 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_HAL_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_HAL_H_ + +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/status_utils.h" +#include "hal/api.h" + +namespace iree { +namespace python { + +template <> +struct ApiPtrAdapter<iree_hal_buffer_t> { + static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); } + static void Release(iree_hal_buffer_t* b) { iree_hal_buffer_release(b); } +}; + +template <> +struct ApiPtrAdapter<iree_hal_buffer_view_t> { + static void Retain(iree_hal_buffer_view_t* bv) { + iree_hal_buffer_view_retain(bv); + } + static void Release(iree_hal_buffer_view_t* bv) { + iree_hal_buffer_view_release(bv); + } +}; + +struct HalShape { + public: + static HalShape FromIntVector(std::vector<int32_t> indices) { + if (indices.size() > IREE_SHAPE_MAX_RANK) { + throw RaiseValueError("Shape exceeded maximum rank"); + } + HalShape s; + s.s.rank = indices.size(); + for (size_t i = 0, e = indices.size(); i < e; ++i) { + s.s.dims[i] = indices[i]; + } + return s; + } + + iree_shape_t s; +}; + +class HalBufferView + : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> { + public: +}; + +class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> { + public: + static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage, + iree_host_size_t allocation_size) { + iree_hal_buffer_t* buffer = nullptr; + CheckApiStatus( + iree_hal_heap_buffer_allocate( + static_cast<iree_hal_memory_type_t>(memory_type), + static_cast<iree_hal_buffer_usage_t>(usage), allocation_size, + IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &buffer), + "Error allocating heap buffer"); + return HalBuffer::CreateRetained(buffer); + } + + void FillZero(iree_device_size_t byte_offset, + iree_device_size_t byte_length) { + CheckApiStatus(iree_hal_buffer_zero(raw_ptr(), byte_offset, byte_length), + "Error zero filling buffer"); + } + + HalBufferView CreateView(HalShape& shape, size_t element_size) { + iree_hal_buffer_view_t* bv; + CheckApiStatus(iree_hal_buffer_view_create(raw_ptr(), shape.s, element_size, + IREE_ALLOCATOR_DEFAULT, &bv), + "Error creating buffer view"); + return HalBufferView::CreateRetained(bv); + } +}; + +void SetupHalBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
diff --git a/iree/bindings/python/pyiree/hal_test.py b/bindings/python/pyiree/hal_test.py similarity index 100% rename from iree/bindings/python/pyiree/hal_test.py rename to bindings/python/pyiree/hal_test.py
diff --git a/bindings/python/pyiree/initialize.cc b/bindings/python/pyiree/initialize.cc new file mode 100644 index 0000000..acf5cf0 --- /dev/null +++ b/bindings/python/pyiree/initialize.cc
@@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/initialize.h" + +#include <string.h> + +#include <mutex> // NOLINT + +#include "base/init.h" + +namespace iree { +namespace python { + +namespace { + +void InternalInitialize(const std::vector<std::string>& arguments) { + int argc = arguments.size() + 1; // plus one for program name. + char** argv = static_cast<char**>( + malloc(sizeof(char*) * (argc + 1))); // plus one for null terminator. + char** orig_argv = argv; + argv[0] = strdup("<python_extension>"); + for (int i = 1; i < argc; ++i) { + argv[i] = strdup(arguments[i - 1].c_str()); + } + argv[argc] = nullptr; + InitializeEnvironment(&argc, &argv); + for (int i = 0; i < argc; ++i) { + free(argv[i]); + } + free(orig_argv); +} + +} // namespace + +void InitializeExtension(const std::vector<std::string>& arguments) { + static std::once_flag init_once; + std::call_once(init_once, InternalInitialize, arguments); +} + +} // namespace python +} // namespace iree
diff --git a/iree/bindings/python/pyiree/initialize.h b/bindings/python/pyiree/initialize.h similarity index 100% rename from iree/bindings/python/pyiree/initialize.h rename to bindings/python/pyiree/initialize.h
diff --git a/bindings/python/pyiree/rt.cc b/bindings/python/pyiree/rt.cc new file mode 100644 index 0000000..b683d02 --- /dev/null +++ b/bindings/python/pyiree/rt.cc
@@ -0,0 +1,150 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/rt.h" + +#include "base/api.h" +#include "bindings/python/pyiree/status_utils.h" +#include "hal/api.h" + +namespace iree { +namespace python { + +HalBufferView RtContext::WrapPyBufferForInput(py::buffer py_buffer) { + auto py_buffer_info = py_buffer.request(false /* writable */); + if (py_buffer_info.ndim > IREE_SHAPE_MAX_RANK || py_buffer_info.ndim < 0) { + RaiseValueError("Unsupported buffer rank"); + } + if (py_buffer_info.size < 0) { + RaiseValueError("Illegal buffer size"); + } + + // For the moment, allocate a device visible buffer of equivalent size and + // copy into it. + // TODO(laurenzo): Once sequencer is in place, switch to HeapBuffer, wrap + // and retain the original buffer. + iree_host_size_t byte_size = py_buffer_info.size * py_buffer_info.itemsize; + HalBuffer buffer = + AllocateDeviceVisible(byte_size, IREE_HAL_BUFFER_USAGE_CONSTANT | + IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_DISPATCH); + CheckApiStatus(iree_hal_buffer_write_data(buffer.raw_ptr(), 0, + py_buffer_info.ptr, byte_size), + "Error writing to input buffer"); + + // Create the buffer view. + // TODO(laurenzo): This does no validation on dtype and only cares if the + // elementsize matches. Figure out where to enforce actual dtype. + iree_shape_t shape; + shape.rank = py_buffer_info.ndim; + + // Verify strides are row-major. + // TODO(laurenzo): Test this with rank>1. + for (int i = 1; i < shape.rank; ++i) { + if ((py_buffer_info.strides[i - 1] * py_buffer_info.itemsize) != + py_buffer_info.shape[i]) { + RaiseValueError("Expected row-major layout"); + } + } + if (!py_buffer_info.strides.empty()) { + if (py_buffer_info.strides.back() != 1) { + RaiseValueError("Expected row-major layout"); + } + } + + // Populate shape. + for (int i = 0; i < shape.rank; ++i) { + ssize_t dim = py_buffer_info.shape[i]; + if (dim < 0) { + RaiseValueError("Unsupported negative dim"); + } + shape.dims[i] = dim; + } + + iree_hal_buffer_view_t* bv; + CheckApiStatus(iree_hal_buffer_view_create(buffer.raw_ptr(), shape, + py_buffer_info.itemsize, + IREE_ALLOCATOR_DEFAULT, &bv), + "Error allocating buffer view"); + + return HalBufferView::CreateRetained(bv); +} + +void SetupRtBindings(pybind11::module m) { + // BufferPlacement. + py::enum_<BufferPlacement>(m, "BufferPlacement") + .value("HEAP", BufferPlacement::kHeap) + .value("DEVICE_VISIBLE", BufferPlacement::kDeviceVisible) + .value("DEVICE_LOCAL", BufferPlacement::kDeviceLocal) + .export_values(); + + // RtModule. + py::class_<RtModule>(m, "Module") + .def_property_readonly("name", &RtModule::name) + .def("lookup_function_by_ordinal", &RtModule::lookup_function_by_ordinal) + .def("lookup_function_by_name", &RtModule::lookup_function_by_name); + // RtFunction. + py::class_<RtFunction>(m, "Function") + .def_property_readonly("name", &RtFunction::name) + .def_property_readonly("signature", &RtFunction::signature); + py::class_<iree_rt_function_signature_t>(m, "FunctionSignature") + .def_readonly("argument_count", + &iree_rt_function_signature_t::argument_count) + .def_readonly("result_count", + &iree_rt_function_signature_t::result_count); + + // RtPolicy. + py::class_<RtPolicy>(m, "Policy").def(py::init(&RtPolicy::Create)); + + // RtInstance. + py::class_<RtInstance>(m, "Instance") + .def(py::init(&RtInstance::Create), + py::arg_v("driver_name", absl::optional<std::string>())); + + // RtContext. + py::class_<RtContext>(m, "Context") + .def(py::init(&RtContext::Create), py::arg("instance"), py::arg("policy")) + .def_property_readonly("context_id", &RtContext::context_id) + .def("register_modules", &RtContext::RegisterModules, py::arg("modules")) + .def("register_module", &RtContext::RegisterModule, py::arg("module")) + .def("lookup_module_by_name", &RtContext::LookupModuleByName, + py::arg("name")) + .def("resolve_function", &RtContext::ResolveFunction, + py::arg("full_name")) + .def("allocate", &RtContext::Allocate, py::arg("allocation_size"), + py::arg("placement") = BufferPlacement::kHeap, + py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) + .def("allocate_device_local", &RtContext::AllocateDeviceLocal, + py::arg("allocation_size"), + py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) + .def("allocate_device_visible", &RtContext::AllocateDeviceVisible, + py::arg("allocation_size"), + py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) + .def("wrap_for_input", &RtContext::WrapPyBufferForInput, py::arg("v")) + .def("invoke", &RtContext::Invoke, py::arg("f"), py::arg("policy"), + py::arg("arguments"), + py::arg("results") = absl::optional<std::vector<HalBufferView*>>()); + + // RtInvocation. + py::class_<RtInvocation>(m, "Invocation") + .def("query_status", &RtInvocation::QueryStatus) + .def("await", &RtInvocation::Await, + py::arg("deadline") = IREE_TIME_INFINITE_FUTURE) + .def("await_optional", &RtInvocation::AwaitOptional, + py::arg("deadline") = IREE_TIME_INFINITE_FUTURE) + .def_property_readonly("results", &RtInvocation::ConsumeResults); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/rt.h b/bindings/python/pyiree/rt.h new file mode 100644 index 0000000..85e85da --- /dev/null +++ b/bindings/python/pyiree/rt.h
@@ -0,0 +1,390 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_RT_H_ + +#include "absl/container/inlined_vector.h" +#include "base/api.h" +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/hal.h" +#include "bindings/python/pyiree/initialize.h" +#include "bindings/python/pyiree/status_utils.h" +#include "hal/api.h" +#include "rt/api.h" + +namespace iree { +namespace python { + +// When creating a buffer via the context, switch between the different +// allocation entry-points via an enum (these are separate functions in the +// C API). +enum class BufferPlacement { + kHeap, + kDeviceVisible, + kDeviceLocal, +}; + +// Adapts API pointer access to retain/release API calls. +template <> +struct ApiPtrAdapter<iree_rt_module_t> { + static void Retain(iree_rt_module_t* m) { iree_rt_module_retain(m); } + static void Release(iree_rt_module_t* m) { iree_rt_module_release(m); } +}; + +template <> +struct ApiPtrAdapter<iree_rt_instance_t> { + static void Retain(iree_rt_instance_t* inst) { + iree_rt_instance_retain(inst); + } + static void Release(iree_rt_instance_t* inst) { + iree_rt_instance_release(inst); + } +}; + +template <> +struct ApiPtrAdapter<iree_rt_policy_t> { + static void Retain(iree_rt_policy_t* p) { iree_rt_policy_retain(p); } + static void Release(iree_rt_policy_t* p) { iree_rt_policy_release(p); } +}; + +template <> +struct ApiPtrAdapter<iree_rt_context_t> { + static void Retain(iree_rt_context_t* c) { iree_rt_context_retain(c); } + static void Release(iree_rt_context_t* c) { iree_rt_context_release(c); } +}; + +template <> +struct ApiPtrAdapter<iree_rt_invocation_t> { + static void Retain(iree_rt_invocation_t* inv) { + iree_rt_invocation_retain(inv); + } + static void Release(iree_rt_invocation_t* inv) { + iree_rt_invocation_release(inv); + } +}; + +// Wrapper classes. These mirror the Python declarations. +class RtFunction { + public: + // Note that this will retain the module. + RtFunction(iree_rt_function_t function) : function_(function) { + iree_rt_module_retain(function_.module); + } + ~RtFunction() { + if (function_.module) iree_rt_module_release(function_.module); + } + RtFunction(RtFunction&& other) : function_(other.function_) { + other.function_.module = nullptr; + } + void operator=(const RtFunction&) = delete; + + std::string name() { + auto sv = iree_rt_function_name(&function_); + return std::string(sv.data, sv.size); + } + + iree_rt_function_signature_t signature() { + iree_rt_function_signature_t sig; + CheckApiStatus(iree_rt_function_signature(&function_, &sig), + "Error getting function signature"); + return sig; + } + + iree_rt_function_t& raw_function() { return function_; } + + private: + iree_rt_function_t function_; +}; + +class RtModule : public ApiRefCounted<RtModule, iree_rt_module_t> { + public: + std::string name() { + auto sv = iree_rt_module_name(raw_ptr()); + return std::string(sv.data, sv.size); + } + + absl::optional<RtFunction> lookup_function_by_ordinal(int32_t ordinal) { + iree_rt_function_t f; + // TODO(laurenzo): Support an optional linkage argument. + auto module = raw_ptr(); + auto status = iree_rt_module_lookup_function_by_ordinal( + module, IREE_RT_FUNCTION_LINKAGE_EXPORT, ordinal, &f); + if (status == IREE_STATUS_NOT_FOUND) { + return absl::optional<RtFunction>(); + } + CheckApiStatus(status, "Error looking up function"); + return RtFunction(f); + } + + absl::optional<RtFunction> lookup_function_by_name(const std::string& name) { + iree_rt_function_t f; + // TODO(laurenzo): Support an optional linkage argument. + auto module = raw_ptr(); + iree_string_view_t name_sv{name.data(), name.size()}; + auto status = iree_rt_module_lookup_function_by_name( + module, IREE_RT_FUNCTION_LINKAGE_EXPORT, name_sv, &f); + if (status == IREE_STATUS_NOT_FOUND) { + return absl::optional<RtFunction>(); + } + CheckApiStatus(status, "Error looking up function"); + return RtFunction(f); + } +}; + +class RtInstance : public ApiRefCounted<RtInstance, iree_rt_instance_t> { + public: + // TODO(laurenzo): Support optional allocator argument. + static RtInstance Create(absl::optional<std::string> driver_name) { + InitializeExtension({}); + iree_rt_instance_t* raw_inst; + CheckApiStatus(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &raw_inst), + "Error creating instance"); + RtInstance inst = RtInstance::CreateRetained(raw_inst); + + if (!driver_name) { + driver_name = "interpreter"; + } + CheckApiStatus(iree_rt_instance_register_driver_ex( + raw_inst, iree_string_view_t{driver_name->c_str(), + driver_name->size()}), + "Error registering drivers"); + + return inst; + } +}; + +class RtPolicy : public ApiRefCounted<RtPolicy, iree_rt_policy_t> { + public: + // TODO(laurenzo): Support optional allocator argument. + static RtPolicy Create() { + iree_rt_policy_t* policy; + CheckApiStatus(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &policy), + "Error creating policy"); + return RtPolicy::CreateRetained(policy); + } +}; + +class RtInvocation : public ApiRefCounted<RtInvocation, iree_rt_invocation_t> { + public: + // Returns whether ready. + // Raises exception on error. + bool QueryStatus() { + auto status = iree_rt_invocation_query_status(raw_ptr()); + if (status == IREE_STATUS_OK) { + return true; + } else if (status == IREE_STATUS_UNAVAILABLE) { + return false; + } else { + CheckApiStatus(status, "Error in function invocation"); + return false; + } + } + + // TODO(laurenzo): Convert to the pybind chrono support. + // Returns whether the invocation is ready. + bool AwaitOptional(iree_time_t epoch_nanos_deadline) { + auto status = iree_rt_invocation_await(raw_ptr(), epoch_nanos_deadline); + if (status == IREE_STATUS_OK) { + return true; + } else if (status == IREE_STATUS_DEADLINE_EXCEEDED) { + return false; + } else { + CheckApiStatus(status, "Error in invocation"); + return false; + } + } + + // Similar to AwaitOptional but will raise an error unless if the status + // is ready. + void Await(iree_time_t epoch_nanos_deadline) { + if (!AwaitOptional(epoch_nanos_deadline)) { + RaiseValueError("Deadline expired"); + } + } + + std::vector<HalBufferView> ConsumeResults() { + static constexpr size_t kInlineSize = 8; + iree_host_size_t result_count; + absl::InlinedVector<iree_hal_buffer_view_t*, kInlineSize> result_bvs; + result_bvs.resize(kInlineSize); + auto status = iree_rt_invocation_consume_results( + raw_ptr(), kInlineSize, IREE_ALLOCATOR_DEFAULT, &result_bvs[0], + &result_count); + if (status == IREE_STATUS_OUT_OF_RANGE) { + // Resize/retry. + result_bvs.resize(result_count); + status = iree_rt_invocation_consume_results( + raw_ptr(), result_count, IREE_ALLOCATOR_DEFAULT, &result_bvs[0], + &result_count); + } + CheckApiStatus(status, "Error consuming invocation results"); + result_bvs.resize(result_count); + std::vector<HalBufferView> results; + for (auto* raw_bv : result_bvs) { + results.push_back(HalBufferView::CreateRetained(raw_bv)); + } + return results; + } +}; + +class RtContext : public ApiRefCounted<RtContext, iree_rt_context_t> { + public: + static RtContext Create(RtInstance* instance, RtPolicy* policy) { + iree_rt_context_t* context; + // TODO(laurenzo): Support optional allocator argument. + CheckApiStatus( + iree_rt_context_create(instance->raw_ptr(), policy->raw_ptr(), + IREE_ALLOCATOR_DEFAULT, &context), + "Error creating instance"); + return RtContext::CreateRetained(context); + } + + int context_id() { return iree_rt_context_id(raw_ptr()); } + + void RegisterModules(std::vector<RtModule*> modules) { + std::vector<iree_rt_module_t*> module_raw_ptrs; + module_raw_ptrs.resize(modules.size()); + for (size_t i = 0, e = modules.size(); i < e; ++i) { + auto module_raw_ptr = modules[i]->raw_ptr(); + module_raw_ptrs[i] = module_raw_ptr; + } + CheckApiStatus( + iree_rt_context_register_modules(raw_ptr(), module_raw_ptrs.data(), + module_raw_ptrs.size()), + "Error registering modules"); + } + + void RegisterModule(RtModule* module) { + iree_rt_module_t* module_raw_ptr = module->raw_ptr(); + CheckApiStatus( + iree_rt_context_register_modules(raw_ptr(), &module_raw_ptr, 1), + "Error registering module"); + } + + absl::optional<RtModule> LookupModuleByName(const std::string& name) { + iree_rt_module_t* module = iree_rt_context_lookup_module_by_name( + raw_ptr(), {name.data(), name.size()}); + if (!module) { + return absl::optional<RtModule>(); + } + return RtModule::RetainAndCreate(module); + } + + absl::optional<RtFunction> ResolveFunction(const std::string& full_name) { + iree_rt_function_t f; + auto status = iree_rt_context_resolve_function( + raw_ptr(), {full_name.data(), full_name.size()}, &f); + if (status == IREE_STATUS_NOT_FOUND) { + return absl::optional<RtFunction>(); + } + CheckApiStatus(status, "Error resolving function"); + return RtFunction(f); + } + + // Convenience method to allocate host, device-visible or device-local + // buffers. + HalBuffer Allocate(iree_host_size_t allocation_size, + BufferPlacement placement, int32_t usage) { + iree_hal_buffer_t* raw_buffer = nullptr; + switch (placement) { + case BufferPlacement::kHeap: + // Even though allocating a heap buffer does not require the context, + // provide it here to make the API easier to navigate. + CheckApiStatus( + iree_hal_heap_buffer_allocate( + IREE_HAL_MEMORY_TYPE_HOST_LOCAL, + static_cast<iree_hal_buffer_usage_t>(usage), allocation_size, + IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &raw_buffer), + "Error allocating heap buffer"); + break; + case BufferPlacement::kDeviceLocal: + CheckApiStatus( + iree_rt_context_allocate_device_local_buffer( + raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage), + allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer), + "Error allocating device local buffer"); + break; + case BufferPlacement::kDeviceVisible: + CheckApiStatus( + iree_rt_context_allocate_device_visible_buffer( + raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage), + allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer), + "Error allocating device visible buffer"); + break; + default: + throw RaiseValueError("Unknown BufferPlacement"); + } + + return HalBuffer::CreateRetained(raw_buffer); + } + + HalBuffer AllocateHeap(iree_host_size_t allocation_size, int32_t usage) { + return Allocate(allocation_size, BufferPlacement::kHeap, usage); + } + + HalBuffer AllocateDeviceLocal(iree_host_size_t allocation_size, + int32_t usage) { + return Allocate(allocation_size, BufferPlacement::kDeviceLocal, usage); + } + + HalBuffer AllocateDeviceVisible(iree_host_size_t allocation_size, + int32_t usage) { + return Allocate(allocation_size, BufferPlacement::kDeviceVisible, usage); + } + + // One stop convenience method for wrapping a python buffer protocol buffer + // for input to a function. At the runtime's discretion, this may make a copy + // or do something smarter, meaning the data in the backing python buffer + // will either be accessed immediately or at some future point. + HalBufferView WrapPyBufferForInput(py::buffer py_buffer); + + RtInvocation Invoke(RtFunction& f, RtPolicy& policy, + std::vector<HalBufferView*> arguments, + absl::optional<std::vector<HalBufferView*>> results) { + absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_arguments; + raw_arguments.resize(arguments.size()); + for (size_t i = 0, e = arguments.size(); i < e; ++i) { + auto inst = arguments[i]; + CheckApiNotNull(inst, "Argument buffer view cannot be None"); + raw_arguments[i] = inst->raw_ptr(); + } + absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_results; + if (results) { + raw_results.resize(results->size()); + for (size_t i = 0, e = results->size(); i < e; ++i) { + auto inst = (*results)[i]; + CheckApiNotNull(inst, "Result buffer view cannot be None"); + raw_results[i] = inst->raw_ptr(); + } + } + + iree_rt_invocation_t* invocation; + CheckApiStatus(iree_rt_invocation_create( + raw_ptr(), &f.raw_function(), policy.raw_ptr(), + nullptr /* dependencies */, raw_arguments.data(), + raw_arguments.size(), raw_results.data(), + raw_results.size(), IREE_ALLOCATOR_DEFAULT, &invocation), + "Error invoking function"); + + return RtInvocation::CreateRetained(invocation); + } +}; + +void SetupRtBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_H_
diff --git a/iree/bindings/python/pyiree/runtime_test.py b/bindings/python/pyiree/runtime_test.py similarity index 100% rename from iree/bindings/python/pyiree/runtime_test.py rename to bindings/python/pyiree/runtime_test.py
diff --git a/bindings/python/pyiree/status_utils.cc b/bindings/python/pyiree/status_utils.cc new file mode 100644 index 0000000..63f2131 --- /dev/null +++ b/bindings/python/pyiree/status_utils.cc
@@ -0,0 +1,72 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/status_utils.h" + +#include "absl/strings/str_cat.h" + +namespace iree { +namespace python { + +namespace { + +PyObject* StatusToPyExcClass(const Status& status) { + switch (status.code()) { + case StatusCode::kInvalidArgument: + return PyExc_ValueError; + case StatusCode::kOutOfRange: + return PyExc_IndexError; + case StatusCode::kUnimplemented: + return PyExc_NotImplementedError; + default: + return PyExc_RuntimeError; + } +} + +PyObject* ApiStatusToPyExcClass(iree_status_t status) { + switch (status) { + case IREE_STATUS_INVALID_ARGUMENT: + return PyExc_ValueError; + case IREE_STATUS_OUT_OF_RANGE: + return PyExc_IndexError; + case IREE_STATUS_UNIMPLEMENTED: + return PyExc_NotImplementedError; + default: + return PyExc_RuntimeError; + } +} + +} // namespace + +pybind11::error_already_set StatusToPyExc(const Status& status) { + assert(!status.ok()); + PyErr_SetString(StatusToPyExcClass(status), status.error_message().c_str()); + return pybind11::error_already_set(); +} + +pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, + const char* message) { + assert(status != IREE_STATUS_OK); + auto full_message = absl::StrCat(message, ": ", static_cast<int>(status)); + PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str()); + return pybind11::error_already_set(); +} + +pybind11::error_already_set RaiseValueError(const char* message) { + PyErr_SetString(PyExc_ValueError, message); + return pybind11::error_already_set(); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/status_utils.h b/bindings/python/pyiree/status_utils.h new file mode 100644 index 0000000..ef89ba8 --- /dev/null +++ b/bindings/python/pyiree/status_utils.h
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_ + +#include "base/api.h" +#include "base/status.h" +#include "pybind11/pytypes.h" + +namespace iree { +namespace python { + +// Converts a failing status to a throwable exception, setting Python +// error information. +// Correct usage is something like: +// if (!status.ok()) { +// throw StatusToPyExc(status); +// } +pybind11::error_already_set StatusToPyExc(const Status& status); + +// Raises a value error with the given message. +// Correct usage: +// throw RaiseValueError("Foobar'd"); +pybind11::error_already_set RaiseValueError(const char* message); + +// Consumes a StatusOr<T>, returning an rvalue reference to the T if the +// status is ok(). Otherwise, throws an exception. +template <typename T> +T&& PyConsumeStatusOr(iree::StatusOr<T>&& sor) { + if (sor.ok()) { + return std::move(*sor); + } + throw StatusToPyExc(sor.status()); +} + +pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, + const char* message); + +static void CheckApiStatus(iree_status_t status, const char* message) { + if (status == IREE_STATUS_OK) { + return; + } + throw ApiStatusToPyExc(status, message); +} + +static void CheckApiNotNull(const void* p, const char* message) { + if (!p) { + throw RaiseValueError(message); + } +} + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
diff --git a/bindings/python/pyiree/vm.cc b/bindings/python/pyiree/vm.cc new file mode 100644 index 0000000..3d03d97 --- /dev/null +++ b/bindings/python/pyiree/vm.cc
@@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "bindings/python/pyiree/vm.h" + +#include "bindings/python/pyiree/status_utils.h" + +namespace iree { +namespace python { + +RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) { + iree_rt_module_t* module; + auto free_fn = OpaqueBlob::CreateFreeFn(blob); + auto status = iree_vm_bytecode_module_create_from_buffer( + {static_cast<const uint8_t*>(blob->data()), blob->size()}, free_fn.first, + free_fn.second, IREE_ALLOCATOR_DEFAULT, &module); + CheckApiStatus(status, "Error creating vm module from blob"); + return RtModule::CreateRetained(module); +} + +void SetupVmBindings(pybind11::module m) { + m.def("create_module_from_blob", CreateModuleFromBlob); +} + +} // namespace python +} // namespace iree
diff --git a/bindings/python/pyiree/vm.h b/bindings/python/pyiree/vm.h new file mode 100644 index 0000000..e7338cc --- /dev/null +++ b/bindings/python/pyiree/vm.h
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_ +#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_ + +#include "bindings/python/pyiree/binding.h" +#include "bindings/python/pyiree/rt.h" +#include "vm/api.h" + +namespace iree { +namespace python { + +void SetupVmBindings(pybind11::module m); + +} // namespace python +} // namespace iree + +#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_
diff --git a/build_defs.oss.bzl b/build_defs.oss.bzl new file mode 100644 index 0000000..92b002a --- /dev/null +++ b/build_defs.oss.bzl
@@ -0,0 +1,130 @@ +"""Common Bazel definitions for IREE.""" + +load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("@iree_native_python//:build_defs.bzl", "py_extension") +load("@iree_core///build_tools/third_party/glslang:build_defs.bzl", "glsl_vulkan") +load("@rules_python//python:defs.bzl", "py_library") + +NUMPY_DEPS = [] + +def platform_trampoline_deps(basename): + """Produce a list of deps for the given `basename` platform target. + + Example: + "file_mapping" -> ["///base/internal/file_mapping_internal"] + + This is used for compatibility with various methods of including the + library in foreign source control systems. + + Args: + basename: Library name prefix for a library in base/internal. + Returns: + A list of dependencies for depending on the library in a platform + sensitive way. + """ + return [ + "///base/internal:%s_internal" % basename, + ] + +# A platform-sensitive list of copts for the Vulkan loader. +PLATFORM_VULKAN_LOADER_COPTS = select({ + "///hal/vulkan:native_vk": [], + "///hal/vulkan:swiftshader_vk": [], + "//conditions:default": [], +}) + +# A platform-sensitive list of dependencies for non-test targets using Vulkan. +PLATFORM_VULKAN_DEPS = select({ + "///hal/vulkan:native_vk": [], + "///hal/vulkan:swiftshader_vk": [], + "//conditions:default": [], +}) + +# A platform-sensitive list of dependencies for tests using Vulkan. +PLATFORM_VULKAN_TEST_DEPS = [ + "@com_google_googletest//:gtest_main", +] + +def iree_py_library(**kwargs): + """Compatibility py_library which has bazel compatible args.""" + + # This is used when args are needed that are incompatible with upstream. + # Presently, this includes: + # imports + py_library(**kwargs) + +def iree_py_extension(deps = [], **kwargs): + """Delegates to the real py_extension.""" + py_extension( + deps = ["@iree_native_python//:python_headers"] + deps, + **kwargs + ) + +def iree_build_test(name, targets): + """Dummy rule to ensure that targets build. + + This is currently undefined in bazel and is preserved for compatibility. + """ + pass + +def iree_setup_lit_package(data): + """Should be called once per test package that contains globbed lit tests. + + Args: + data: Additional, project specific data deps to add. + """ + + # Bundle together all of the test utilities that are used by tests. + native.filegroup( + name = "lit_test_utilities", + testonly = True, + data = data + [ + "@llvm//:FileCheck", + ], + ) + +def iree_glob_lit_tests( + data = [":lit_test_utilities"], + driver = "///tools:run_lit.sh", + test_file_exts = ["mlir"]): + """Globs lit test files into tests for a package. + + For most packages, the defaults suffice. Packages that include this must + also include a call to iree_setup_lit_package(). + + Args: + data: Data files to include/build. + driver: Test driver. + test_file_exts: File extensions to glob. + """ + for test_file_ext in test_file_exts: + test_files = native.glob([ + "*.%s" % (test_file_ext,), + "**/*.%s" % (test_file_ext,), + ]) + for test_file in test_files: + test_file_location = "$(location %s)" % (test_file,) + native.sh_test( + name = "%s.test" % (test_file,), + size = "small", + srcs = [driver], + data = data + [test_file], + args = [test_file_location], + ) + +# The OSS build currently has issues with generating flatbuffer reflections. +# It is hard-coded to disabled here (and in iree_flatbuffer_cc_library) until triaged/fixed. +FLATBUFFER_SUPPORTS_REFLECTIONS = False + +def iree_flatbuffer_cc_library(**kwargs): + """Wrapper for the flatbuffer_cc_library.""" + + # TODO(laurenzo): The bazel rule for reflections seems broken in OSS + # builds. Fix it and enable by default. + flatbuffer_cc_library( + gen_reflections = False, + **kwargs + ) + +def iree_glsl_vulkan(**kwargs): + glsl_vulkan(**kwargs)
diff --git a/build_tools/embed_data/build_defs.bzl b/build_tools/embed_data/build_defs.bzl index e48b22a..b4dc8ac 100644 --- a/build_tools/embed_data/build_defs.bzl +++ b/build_tools/embed_data/build_defs.bzl
@@ -53,7 +53,7 @@ identifier: The identifier to use in generated names (defaults to name). **kwargs: Args to pass to the cc_library. """ - generator = "//build_tools/embed_data:generate_cc_embed_data" + generator = "///build_tools/embed_data:generate_cc_embed_data" generator_location = "$(location %s)" % generator if identifier == None: identifier = name
diff --git a/build_tools/embed_data/testembed1_test.cc b/build_tools/embed_data/testembed1_test.cc index f34d19b..a2b9f01 100644 --- a/build_tools/embed_data/testembed1_test.cc +++ b/build_tools/embed_data/testembed1_test.cc
@@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "iree/build_tools/embed_data/testembed1.h" +#include "build_tools/embed_data/testembed1.h" #include "gmock/gmock.h" #include "gtest/gtest.h"
diff --git a/build_tools/python/configure.bzl b/build_tools/python/configure.bzl index 0e5e7b0..d215b4c 100644 --- a/build_tools/python/configure.bzl +++ b/build_tools/python/configure.bzl
@@ -56,11 +56,11 @@ ], attrs = { "_generate_script": attr.label( - default = Label("//build_tools/python:generate_build.py"), + default = Label("///build_tools/python:generate_build.py"), allow_single_file = True, ), "_build_defs": attr.label( - default = Label("//build_tools/python:build_defs.bzl"), + default = Label("///build_tools/python:build_defs.bzl"), allow_single_file = True, ), },
diff --git a/iree/compiler/BUILD b/compiler/BUILD similarity index 100% rename from iree/compiler/BUILD rename to compiler/BUILD
diff --git a/iree/compiler/CMakeLists.txt b/compiler/CMakeLists.txt similarity index 100% rename from iree/compiler/CMakeLists.txt rename to compiler/CMakeLists.txt
diff --git a/iree/compiler/IR/BUILD b/compiler/IR/BUILD similarity index 100% rename from iree/compiler/IR/BUILD rename to compiler/IR/BUILD
diff --git a/iree/compiler/IR/CMakeLists.txt b/compiler/IR/CMakeLists.txt similarity index 100% rename from iree/compiler/IR/CMakeLists.txt rename to compiler/IR/CMakeLists.txt
diff --git a/compiler/IR/ConfigOps.cpp b/compiler/IR/ConfigOps.cpp new file mode 100644 index 0000000..0837e69 --- /dev/null +++ b/compiler/IR/ConfigOps.cpp
@@ -0,0 +1,110 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/ConfigOps.h" + +#include "compiler/IR/StructureOps.h" +#include "compiler/IR/Types.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +//===----------------------------------------------------------------------===// +// Generic printers and parsers. +//===----------------------------------------------------------------------===// + +// Parses an op that has no inputs and no outputs. +static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { + if (failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + return success(); +} + +// Prints an op that has no inputs and no outputs. +static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { + printer << op->getName(); + printer.printOptionalAttrDict(op->getAttrs()); +} + +//===----------------------------------------------------------------------===// +// iree.target_config +//===----------------------------------------------------------------------===// + +void ExecutableTargetConfigOp::build(Builder *builder, OperationState &state, + std::string backend) { + state.addAttribute("backend", builder->getStringAttr(backend)); + ensureTerminator(*state.addRegion(), *builder, state.location); +} + +static ParseResult parseExecutableTargetConfigOp(OpAsmParser &parser, + OperationState &state) { + llvm::SMLoc backendLoc; + StringAttr backendAttr; + if (failed(parser.parseLParen()) || + failed(parser.getCurrentLocation(&backendLoc)) || + failed(parser.parseAttribute(backendAttr, "backend", state.attributes))) { + return failure(); + } + + Region *body = state.addRegion(); + if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) { + return failure(); + } + if (succeeded(parser.parseOptionalKeyword("attributes"))) { + if (failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + } + + ExecutableTargetConfigOp::ensureTerminator(*body, parser.getBuilder(), + state.location); + + return success(); +} + +static void printExecutableTargetConfigOp(OpAsmPrinter &printer, + ExecutableTargetConfigOp op) { + printer << op.getOperationName() << "(" << op.backend() << ")"; + + printer.printRegion(op.body(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + + // Print out executable attributes, if present. + SmallVector<StringRef, 1> ignoredAttrs = { + "backend", + }; + if (op.getAttrs().size() > ignoredAttrs.size()) { + printer << "\n attributes "; + printer.printOptionalAttrDict(op.getAttrs(), ignoredAttrs); + } +} + +#define GET_OP_CLASSES +#include "compiler/IR/ConfigOps.cpp.inc" + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/ConfigOps.h b/compiler/IR/ConfigOps.h new file mode 100644 index 0000000..de806ae --- /dev/null +++ b/compiler/IR/ConfigOps.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_CONFIGOPS_H_ +#define IREE_COMPILER_IR_CONFIGOPS_H_ + +#include <cstdint> + +#include "compiler/IR/Types.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +#define GET_OP_CLASSES +#include "compiler/IR/ConfigOps.h.inc" + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_CONFIGOPS_H_
diff --git a/compiler/IR/ConfigOps.td b/compiler/IR/ConfigOps.td new file mode 100644 index 0000000..12dbce6 --- /dev/null +++ b/compiler/IR/ConfigOps.td
@@ -0,0 +1,44 @@ +// Ops used to declare configuration used by the IREE compiler. +// These allow inline config that follows along the IR they are associated with. +// Multiple config ops are allowed within a single scope to indicate that the +// parent IR node should be processed for multiple targets. + +#ifdef IREE_CONFIG_OPS +#else +#define IREE_CONFIG_OPS + +include "compiler/IR/OpBase.td" + +class IREE_ConfigOp<string mnemonic, list<OpTrait> traits = []> : + Op<IREE_Dialect, mnemonic, traits> { + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ print$cppClass(p, *this); }]; +} + +//===----------------------------------------------------------------------===// +// iree.executable configuration +//===----------------------------------------------------------------------===// + +def IREE_ExecutableTargetConfigOp : IREE_ConfigOp<"target_config", [ + IREE_ExecutableOnly, + SingleBlockImplicitTerminator<"ExecutableTargetConfigEndOp"> +]> { + let arguments = (ins + StrAttr:$backend + ); + + let regions = (region SizedRegion<1>:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, std::string backend">, + ]; +} + +def IREE_ExecutableTargetConfigEndOp : + IREE_ConfigOp<"_target_config_end", [Terminator, IREE_ExecutableTargetConfigOnly]> { + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; +} + +#endif // IREE_CONFIG_OPS
diff --git a/compiler/IR/Dialect.cpp b/compiler/IR/Dialect.cpp new file mode 100644 index 0000000..5968ade --- /dev/null +++ b/compiler/IR/Dialect.cpp
@@ -0,0 +1,90 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" + +#include "compiler/IR/ConfigOps.h" +#include "compiler/IR/Ops.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/IR/Types.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace iree_compiler { + +static DialectRegistration<IREEDialect> iree_dialect; + +IREEDialect::IREEDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { +#define IREE_ADD_TYPE(NAME, KIND, TYPE) addTypes<TYPE>(); + IREE_TYPE_TABLE(IREE_ADD_TYPE); + +#define GET_OP_LIST + addOperations< +#include "compiler/IR/Ops.cpp.inc" + >(); +#define GET_OP_LIST + addOperations< +#include "compiler/IR/ConfigOps.cpp.inc" + >(); +#define GET_OP_LIST + addOperations< +#include "compiler/IR/StructureOps.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Type Parsing +//===----------------------------------------------------------------------===// + +#define IREE_TYPE_PARSER(NAME, KIND, TYPE) \ + static Type parse##TYPE(IREEDialect const &dialect, StringRef spec, \ + Location loc) { \ + spec.consume_front(NAME); \ + return TYPE::get(dialect.getContext()); \ + } +IREE_TYPE_TABLE(IREE_TYPE_PARSER); + +#define IREE_PARSE_TYPE(NAME, KIND, TYPE) \ + if (spec.startswith(NAME)) { \ + return parse##TYPE(*this, spec, loc); \ + } +Type IREEDialect::parseType(StringRef spec, Location loc) const { + IREE_TYPE_TABLE(IREE_PARSE_TYPE); + emitError(loc, "unknown IREE type: ") << spec; + return Type(); +} + +//===----------------------------------------------------------------------===// +// Type Printing +//===----------------------------------------------------------------------===// + +#define IREE_TYPE_PRINTER(NAME, KIND, TYPE) \ + static void print##TYPE(TYPE type, llvm::raw_ostream &os) { os << NAME; } +IREE_TYPE_TABLE(IREE_TYPE_PRINTER); + +#define IREE_PRINT_TYPE(NAME, KIND, TYPE) \ + case KIND: \ + print##TYPE(type.cast<TYPE>(), os); \ + return; +void IREEDialect::printType(Type type, llvm::raw_ostream &os) const { + switch (type.getKind()) { + IREE_TYPE_TABLE(IREE_PRINT_TYPE); + default: + llvm_unreachable("unhandled IREE type"); + } +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/IR/Dialect.h b/compiler/IR/Dialect.h similarity index 100% rename from iree/compiler/IR/Dialect.h rename to compiler/IR/Dialect.h
diff --git a/compiler/IR/Interpreter/BUILD b/compiler/IR/Interpreter/BUILD new file mode 100644 index 0000000..8dc1cf3 --- /dev/null +++ b/compiler/IR/Interpreter/BUILD
@@ -0,0 +1,75 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +load("@local_config_mlir//:tblgen.bzl", "gentbl") + +filegroup( + name = "td_files", + srcs = glob(["*.td"]), +) + +cc_library( + name = "Interpreter", + srcs = [ + "HLDialect.cpp", + "HLOps.cpp", + "HLOps.cpp.inc", + "LLDialect.cpp", + "LLOps.cpp", + "LLOps.cpp.inc", + "OpWriters.cpp", + ], + hdrs = [ + "HLDialect.h", + "HLOps.h", + "HLOps.h.inc", + "LLDialect.h", + "LLOps.h", + "LLOps.h.inc", + "OpWriters.h", + ], + deps = [ + ":HLOpsGen", + ":LLOpsGen", + "///compiler/IR", + "///compiler/Serialization", + "///compiler/Utils", + "///schemas/bytecode:interpreter_bytecode_v0", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + ], + alwayslink = 1, +) + +gentbl( + name = "HLOpsGen", + tbl_outs = [ + ("-gen-op-decls", "HLOps.h.inc"), + ("-gen-op-defs", "HLOps.cpp.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "HLOps.td", + td_srcs = [ + ":td_files", + "@local_config_mlir//:include/mlir/IR/OpBase.td", + "///compiler/IR:OpBase.td", + ], +) + +gentbl( + name = "LLOpsGen", + tbl_outs = [ + ("-gen-op-decls", "LLOps.h.inc"), + ("-gen-op-defs", "LLOps.cpp.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "LLOps.td", + td_srcs = [ + ":td_files", + "@local_config_mlir//:include/mlir/IR/OpBase.td", + "///compiler/IR:OpBase.td", + ], +)
diff --git a/iree/compiler/IR/Interpreter/CMakeLists.txt b/compiler/IR/Interpreter/CMakeLists.txt similarity index 100% rename from iree/compiler/IR/Interpreter/CMakeLists.txt rename to compiler/IR/Interpreter/CMakeLists.txt
diff --git a/compiler/IR/Interpreter/HLDialect.cpp b/compiler/IR/Interpreter/HLDialect.cpp new file mode 100644 index 0000000..d34f44b --- /dev/null +++ b/compiler/IR/Interpreter/HLDialect.cpp
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Interpreter/HLDialect.h" + +#include "compiler/IR/Interpreter/HLOps.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace iree_compiler { + +IREEHLInterpreterDialect::IREEHLInterpreterDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { +#define GET_OP_LIST + addOperations< +#include "compiler/IR/Interpreter/HLOps.cpp.inc" + >(); +} + +static DialectRegistration<IREEHLInterpreterDialect> iree_hl_interp_dialect; + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/HLDialect.h b/compiler/IR/Interpreter/HLDialect.h similarity index 100% rename from iree/compiler/IR/Interpreter/HLDialect.h rename to compiler/IR/Interpreter/HLDialect.h
diff --git a/compiler/IR/Interpreter/HLOps.cpp b/compiler/IR/Interpreter/HLOps.cpp new file mode 100644 index 0000000..6a4b4ff --- /dev/null +++ b/compiler/IR/Interpreter/HLOps.cpp
@@ -0,0 +1,246 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Interpreter/HLOps.h" + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace iree_compiler { +namespace IREEInterp { +namespace HL { + +//===----------------------------------------------------------------------===// +// iree_hl_interp.call +//===----------------------------------------------------------------------===// + +static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(state.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), state.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + state.operands)) { + return failure(); + } + return success(); +} + +static void printCallOp(OpAsmPrinter &p, CallOp op) { + p << "iree_hl_interp.call " << op.getAttr("callee") << '('; + p.printOperands(op.getOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : "; + p.printType(op.getCalleeType()); +} + +FunctionType CallOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_hl_interp.call_indirect +//===----------------------------------------------------------------------===// + +static ParseResult parseCallIndirectOp(OpAsmParser &parser, + OperationState &result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector<OpAsmParser::OperandType, 4> operands; + return failure( + parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(calleeType) || + parser.resolveOperand(callee, calleeType, result.operands) || + parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result.operands) || + parser.addTypesToList(calleeType.getResults(), result.types)); +} + +static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { + p << "iree_hl_interp.call_indirect "; + p.printOperand(op.getCallee()); + p << '('; + auto operandRange = op.getOperands(); + p.printOperands(++operandRange.begin(), operandRange.end()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : " << op.getCallee()->getType(); +} + +//===----------------------------------------------------------------------===// +// iree_hl_interp.return +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, state.operands)); +} + +static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { + p << "iree_hl_interp.return"; + if (op.getNumOperands() > 0) { + p << ' '; + p.printOperands(op.operand_begin(), op.operand_end()); + p << " : "; + interleaveComma(op.getOperandTypes(), p); + } +} + +//===----------------------------------------------------------------------===// +// iree_hl_interp.br +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { + Block *dest; + SmallVector<Value *, 4> destOperands; + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + return success(); +} + +static void printBranchOp(OpAsmPrinter &p, BranchOp op) { + p << "iree_hl_interp.br "; + p.printSuccessorAndUseList(op.getOperation(), 0); +} + +Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } + +void BranchOp::setDest(Block *block) { + return getOperation()->setSuccessor(block, 0); +} + +void BranchOp::eraseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(0, index); +} + +//===----------------------------------------------------------------------===// +// iree_hl_interp.cond_br +//===----------------------------------------------------------------------===// + +static ParseResult parseCondBranchOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<Value *, 4> destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser.getBuilder().getI1Type(); + if (parser.parseOperand(condInfo) || parser.parseComma() || + parser.resolveOperand(condInfo, int1Ty, result.operands)) { + return parser.emitError(parser.getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser.parseComma() || + parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + + return success(); +} + +static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { + p << "iree_hl_interp.cond_br "; + p.printOperand(op.getCondition()); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); +} + +//===----------------------------------------------------------------------===// +// iree_hl_interp.clone +//===----------------------------------------------------------------------===// + +OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { + // If this is the only usage, we know the clone is unnecessary. + // TODO(b/135053584) More sophisticated analysis. + if (src()->hasOneUse()) return src(); + return {}; +} + +//===----------------------------------------------------------------------===// +// iree_hl_interp.concat +//===----------------------------------------------------------------------===// + +namespace { +struct ConcatToCopies : public OpRewritePattern<ConcatOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(ConcatOp concatOp, + PatternRewriter &rewriter) const override { + auto finalType = concatOp.getResult()->getType().cast<ShapedType>(); + auto loc = concatOp.getLoc(); + std::vector<Value *> dimPieces; + auto dst = + rewriter.create<IREEInterp::HL::AllocHeapOp>(loc, finalType, dimPieces); + + llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0); + auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset); + + auto concatDimension = concatOp.dimension().getZExtValue(); + llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0); + for (auto *src : concatOp.srcs()) { + auto srcShape = src->getType().cast<ShapedType>().getShape(); + auto lengths = createArrayConstant(rewriter, loc, srcShape); + auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices); + rewriter.create<IREEInterp::HL::CopyOp>(loc, src, srcIndices, dst, + dstIndicesOp, lengths); + dstIndices[concatDimension] += srcShape[concatDimension]; + } + + concatOp.replaceAllUsesWith(dst.getResult()); + + return matchSuccess(); + } +}; +} // namespace + +void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<ConcatToCopies>(context); +} + +#define GET_OP_CLASSES +#include "compiler/IR/Interpreter/HLOps.cpp.inc" + +} // namespace HL +} // namespace IREEInterp +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Interpreter/HLOps.h b/compiler/IR/Interpreter/HLOps.h new file mode 100644 index 0000000..b02fbed --- /dev/null +++ b/compiler/IR/Interpreter/HLOps.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_INTERPRETER_HLOPS_H_ +#define IREE_COMPILER_IR_INTERPRETER_HLOPS_H_ + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace iree_compiler { +namespace IREEInterp { +namespace HL { + +#define GET_OP_CLASSES +#include "compiler/IR/Interpreter/HLOps.h.inc" + +} // namespace HL +} // namespace IREEInterp +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
diff --git a/compiler/IR/Interpreter/HLOps.td b/compiler/IR/Interpreter/HLOps.td new file mode 100644 index 0000000..63fa5bf --- /dev/null +++ b/compiler/IR/Interpreter/HLOps.td
@@ -0,0 +1,658 @@ +// IREE high-level interpreter op definitions. +// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in +// normal SSA form. +// +// Through lowering these high-level ops are converted to low-level ops in the +// LLOps.td (iree_ll_interp.*). These map 1:1 with the bytecode, +// accept only MemRef types, and generally use output parameters instead of +// return types. +// +// The source of truth for bytecode opcodes is: +// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h + +#ifdef IREE_INTERPRETER_HL_OPS +#else +#define IREE_INTERPRETER_HL_OPS + +#ifdef IREE_OP_BASE +#else +include "compiler/IR/OpBase.td" +#endif // IREE_OP_BASE + +def IREEInterpHL_Dialect : Dialect { + let name = "iree_hl_interp"; + let cppNamespace = "IREEInterp::HL"; +} + +//===----------------------------------------------------------------------===// +// Base op classes +//===----------------------------------------------------------------------===// + +class IREEInterpHL_Op<string mnemonic, list<OpTrait> traits = []> : + Op<IREEInterpHL_Dialect, mnemonic, traits>; + +class IREEInterpHL_PureOp<string mnemonic, list<OpTrait> traits = []> : + IREEInterpHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +//===----------------------------------------------------------------------===// +// High-level interpreter ops +//===----------------------------------------------------------------------===// + +def IREEInterpHL_CallOp : IREEInterpHL_Op<"call"> { + let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>); + let results = (outs Variadic<IREEHL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpHL_CallIndirectOp : IREEInterpHL_Op<"call_indirect"> { + let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands); + let results = (outs Variadic<IREEHL_MemRef>); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *callee," + "ArrayRef<Value *> operands = {}", [{ + result.operands.push_back(callee); + result.addOperands(operands); + result.addTypes(callee->getType().cast<FunctionType>().getResults()); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + Value *getCallee() { return getOperand(0); } + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpHL_ReturnOp : IREEInterpHL_Op<"return", [Terminator]> { + let arguments = (ins Variadic<IREEHL_MemRef>:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpHL_BranchOp : IREEInterpHL_Op<"br", [Terminator]> { + let arguments = (ins Variadic<IREEHL_MemRef>:$operands); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Block *dest," + "ArrayRef<Value *> operands = {}", [{ + result.addSuccessor(dest, operands); + }]>]; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpHL_CondBranchOp : IREEInterpHL_Op<"cond_br", [Terminator]> { + let arguments = (ins + IREEHL_BoolScalar:$condition, + Variadic<IREEHL_MemRef>:$branchOperands + ); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *condition," + "Block *trueDest, ArrayRef<Value *> trueOperands," + "Block *falseDest, ArrayRef<Value *> falseOperands", [{ + result.addOperands(condition); + result.addSuccessor(trueDest, trueOperands); + result.addSuccessor(falseDest, falseOperands); + }]>]; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getOperation()->getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getOperation()->getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getOperation()->getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getOperation()->getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpHL_CmpIOp : + IREEInterpHL_PureOp<"cmp_i", [SameOperandsAndResultShape, + AllTypesMatch<["lhs", "rhs"]>]> { + let arguments = (ins + I32Attr:$predicate, + IREEHL_IntMemRef:$lhs, + IREEHL_IntMemRef:$rhs + ); + let results = (outs IREEHL_BoolMemRef); +} + +def IREEInterpHL_CmpFOp : + IREEInterpHL_PureOp<"cmp_f", [SameOperandsAndResultShape, + AllTypesMatch<["lhs", "rhs"]>]> { + let arguments = (ins + I32Attr:$predicate, + IREEHL_FloatMemRef:$lhs, + IREEHL_FloatMemRef:$rhs + ); + let results = (outs IREEHL_BoolMemRef); +} + +// TODO(b/142012496): Add trait that enables DCE but not CSE. +def IREEInterpHL_AllocHeapOp : IREEInterpHL_Op<"alloc_heap"> { + // TODO(benvanik): attributes and args. + let arguments = (ins + Variadic<IREEHL_MemRef>:$dim_pieces + ); + let results = (outs + IREEHL_MemRef + ); +} + +def IREEInterpHL_DiscardOp : IREEInterpHL_Op<"discard"> { + let arguments = (ins IREEHL_MemRef); +} + +def IREEInterpHL_RankOp : IREEInterpHL_PureOp<"rank"> { + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_IntScalar); +} + +def IREEInterpHL_DimOp : IREEInterpHL_PureOp<"dim"> { + // TODO(benvanik) add dim attr (I32Attr:$dim) + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_IntScalar); +} + +def IREEInterpHL_ShapeOp : IREEInterpHL_PureOp<"shape"> { + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_1DIntMemRef); +} + +def IREEInterpHL_LengthOp : IREEInterpHL_PureOp<"length"> { + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_IndexScalar); +} + +def IREEInterpHL_SliceOp : + IREEInterpHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>, + AllTypesMatch<["srcIndices", "lengths"]>]> { + let arguments = (ins + IREEHL_MemRef:$src, + IREEHL_1DIndexMemRef:$srcIndices, + IREEHL_1DIndexMemRef:$lengths + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_CopyOp : IREEInterpHL_Op<"copy", [ + AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, + AllRanksMatch<["src", "dst"]>, + // The checks above are redundant with this one, but they give more specific + // error messages. + AllMatch<[ + Rank<"src">.result, + Rank<"dst">.result, + ElementCount<"srcIndices">.result, + ElementCount<"dstIndices">.result, + ElementCount<"lengths">.result + ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">, + AllElementTypesMatch<["src", "dst"]> +]> { + let arguments = (ins + IREEHL_MemRef:$src, + IREEHL_1DIndexMemRef:$srcIndices, + IREEHL_MemRef:$dst, + IREEHL_1DIndexMemRef:$dstIndices, + IREEHL_1DIndexMemRef:$lengths + ); +} + +def IREEInterpHL_CloneOp : + IREEInterpHL_PureOp<"clone", [SameOperandsAndResultType]> { + let arguments = (ins IREEHL_MemRef:$src); + let results = (outs IREEHL_MemRef); + + let hasFolder = 1; +} + +// A pseudo op provided for convenience. This gets canonicalized to a series of +// copies. +def IREEInterpHL_ConcatOp : IREEInterpHL_PureOp<"concat"> { + let arguments = (ins + Variadic<IREEHL_MemRef>:$srcs, + I32Attr:$dimension + ); + let results = (outs IREEHL_MemRef); + + let hasCanonicalizer = 1; +} + +// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops? +def IREEInterpHL_SplitOp : + IREEInterpHL_PureOp<"split", [SameOperandsAndResultElementType]> { + let arguments = (ins IREEHL_MemRef:$src); + let results = (outs Variadic<IREEHL_MemRef>); +} + +def IREEInterpHL_AssignOp : + IREEInterpHL_PureOp<"assign", [SameOperandsAndResultType]> { + let arguments = (ins IREEHL_MemRef:$src); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_CondAssignOp : + IREEInterpHL_PureOp<"cond_assign", + [AllTypesMatch<["lhs", "rhs", "result"]>]> { + let arguments = (ins + IREEHL_BoolScalar:$cond, + IREEHL_MemRef:$lhs, + IREEHL_MemRef:$rhs + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_ReshapeOp : IREEInterpHL_PureOp<"reshape"> { + let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape); + let results = (outs IREEHL_MemRef); +} + +def IREEInterpHL_SelectOp : + IREEInterpHL_PureOp<"select", [AllTypesMatch<["lhs", "rhs", "result"]>]> { + let arguments = (ins + IREEHL_BoolMemRef:$cond, + IREEHL_MemRef:$lhs, + IREEHL_MemRef:$rhs + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_BroadcastOp : + IREEInterpHL_PureOp<"broadcast", + [AllElementTypesMatch<["operand", "result"]>]> { + let arguments = (ins + IREE_ScalarMemRefOf<[AnyType]>:$operand, + IREEHL_1DIntMemRef:$shape + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_PadOp : + IREEInterpHL_PureOp< + "pad", [AllElementTypesMatch<["src", "result", "padding_value"]>]> { + let arguments = (ins + IREEHL_MemRef:$src, + IREEHL_AnyScalar:$padding_value, + IREEHL_1DIndexMemRef:$edge_padding_low, + IREEHL_1DIndexMemRef:$edge_padding_high, + IREEHL_1DIndexMemRef:$interior_padding + ); + + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_TileOp : + IREEInterpHL_PureOp<"tile", [AllElementTypesMatch<["operand", "result"]>]> { + let arguments = (ins + IREEHL_MemRef:$operand, + IREEHL_1DIntMemRef:$shape + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_TransposeOp : + IREEInterpHL_PureOp<"transpose", [ + AllElementTypesMatch<["operand", "result"]>, + AllRanksMatch<["operand", "result"]>, + AllElementCountsMatch<["operand", "result"]> + ]> { + let arguments = (ins + IREEHL_MemRef:$operand, + IREEHL_1DIntMemRef:$permutation + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREEInterpHL_ReverseOp : + IREEInterpHL_PureOp<"reverse", [AllTypesMatch<["operand", "result"]>]> { + let arguments = (ins + IREEHL_MemRef:$operand, + IREEHL_1DIntMemRef:$dims + ); + let results = (outs IREEHL_MemRef:$result); +} + +class IREEInterpHL_UnaryElementwiseOp<string mnemonic, Type type, + list<OpTrait> traits = []> : + IREEInterpHL_PureOp<mnemonic, + !listconcat(traits, [SameOperandsAndResultType])> { + let arguments = (ins type); + let results = (outs type); +} + +class IREEInterpHL_UnaryElementwiseFloatOp<string mnemonic, + list<OpTrait> traits = []> : + IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_FloatMemRef, traits>; + +class IREEInterpHL_UnaryElementwiseIntOp<string mnemonic, + list<OpTrait> traits = []> : + IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_IntMemRef, traits>; + +class IREEInterpHL_BinaryElementwiseOp<string mnemonic, Type type, + list<OpTrait> traits> : + IREEInterpHL_PureOp<mnemonic, + !listconcat(traits, [SameOperandsAndResultType])> { + let arguments = (ins type:$lhs, type:$rhs); + let results = (outs type); +} + +class IREEInterpHL_BinaryElementwiseFloatOp<string mnemonic, + list<OpTrait> traits = []> : + IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_FloatMemRef, + traits>; + +class IREEInterpHL_BinaryElementwiseIntOp<string mnemonic, + list<OpTrait> traits = []> : + IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_IntMemRef, + traits>; + +class IREEInterpHL_TernaryOp<string mnemonic, + Type type = IREEHL_MemRef, + list<OpTrait> traits = []> : + IREEInterpHL_PureOp<mnemonic, traits> { + let arguments = (ins type:$a, type:$b, type:$c); + let results = (outs type); +} + +// TODO(benvanik): add traits for broadcasting support. + +def IREEInterpHL_NotOp : IREEInterpHL_UnaryElementwiseIntOp<"not">; +def IREEInterpHL_AndOp : IREEInterpHL_BinaryElementwiseIntOp<"and">; +def IREEInterpHL_OrOp : IREEInterpHL_BinaryElementwiseIntOp<"or">; +def IREEInterpHL_XorOp : IREEInterpHL_BinaryElementwiseIntOp<"xor">; +def IREEInterpHL_ShiftLeftOp : IREEInterpHL_BinaryElementwiseIntOp<"sll">; +def IREEInterpHL_ShiftRightLogicalOp : IREEInterpHL_BinaryElementwiseIntOp<"srl">; +def IREEInterpHL_ShiftRightArithmeticOp : IREEInterpHL_BinaryElementwiseIntOp<"sra">; + +def IREEInterpHL_AddIOp : IREEInterpHL_BinaryElementwiseIntOp<"add_i">; +def IREEInterpHL_AddFOp : IREEInterpHL_BinaryElementwiseFloatOp<"add_f">; +def IREEInterpHL_SubIOp : IREEInterpHL_BinaryElementwiseIntOp<"sub_i">; +def IREEInterpHL_SubFOp : IREEInterpHL_BinaryElementwiseFloatOp<"sub_f">; +def IREEInterpHL_AbsIOp : IREEInterpHL_UnaryElementwiseIntOp<"abs_i">; +def IREEInterpHL_AbsFOp : IREEInterpHL_UnaryElementwiseFloatOp<"abs_f">; +def IREEInterpHL_MulIOp : IREEInterpHL_BinaryElementwiseIntOp<"mul_i">; +def IREEInterpHL_MulFOp : IREEInterpHL_BinaryElementwiseFloatOp<"mul_f">; +def IREEInterpHL_DivISOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_s">; +def IREEInterpHL_DivIUOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_u">; +def IREEInterpHL_DivFOp : IREEInterpHL_BinaryElementwiseFloatOp<"div_f">; +def IREEInterpHL_MulAddIOp : IREEInterpHL_TernaryOp<"madd_i", IREEHL_IntMemRef>; +def IREEInterpHL_MulAddFOp : IREEInterpHL_TernaryOp<"madd_f", IREEHL_FloatMemRef>; +def IREEInterpHL_ExpFOp : IREEInterpHL_UnaryElementwiseFloatOp<"exp_f">; +def IREEInterpHL_LogFOp : IREEInterpHL_UnaryElementwiseFloatOp<"log_f">; +def IREEInterpHL_RsqrtFOp : IREEInterpHL_UnaryElementwiseFloatOp<"rsqrt_f">; +def IREEInterpHL_CosFOp : IREEInterpHL_UnaryElementwiseFloatOp<"cos_f">; +def IREEInterpHL_SinFOp : IREEInterpHL_UnaryElementwiseFloatOp<"sin_f">; +def IREEInterpHL_TanhFOp : IREEInterpHL_UnaryElementwiseFloatOp<"tanh_f">; +def IREEInterpHL_Atan2FOp : IREEInterpHL_UnaryElementwiseFloatOp<"atan2_f">; + +def IREEInterpHL_MinISOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_s">; +def IREEInterpHL_MinIUOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_u">; +def IREEInterpHL_MinFOp : IREEInterpHL_BinaryElementwiseFloatOp<"min_f">; +def IREEInterpHL_MaxISOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_s">; +def IREEInterpHL_MaxIUOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_u">; +def IREEInterpHL_MaxFOp : IREEInterpHL_BinaryElementwiseFloatOp<"max_f">; +def IREEInterpHL_ClampFOp : IREEInterpHL_TernaryOp<"clamp_f", IREEHL_FloatMemRef>; +def IREEInterpHL_FloorFOp : IREEInterpHL_UnaryElementwiseFloatOp<"floor_f">; +def IREEInterpHL_CeilFOp : IREEInterpHL_UnaryElementwiseFloatOp<"ceil_f">; + +class IREEInterpHL_ConversionOp<string mnemonic, Type inputType, + Type outputType> : + IREEInterpHL_PureOp<mnemonic, [SameOperandsAndResultShape]> { + let arguments = (ins inputType); + let results = (outs outputType); +} + +def IREEInterpHL_ConvertSSOp : + IREEInterpHL_ConversionOp<"convert_s_s", IREEHL_IntMemRef, + IREEHL_IntMemRef>; +def IREEInterpHL_ConvertSUOp : + IREEInterpHL_ConversionOp<"convert_s_u", IREEHL_IntMemRef, + IREEHL_IntMemRef>; +def IREEInterpHL_ConvertSFOp : + IREEInterpHL_ConversionOp<"convert_s_f", IREEHL_IntMemRef, + IREEHL_FloatMemRef>; + +def IREEInterpHL_ConvertUSOp : + IREEInterpHL_ConversionOp<"convert_u_s", IREEHL_IntMemRef, + IREEHL_IntMemRef>; +def IREEInterpHL_ConvertUUOp : + IREEInterpHL_ConversionOp<"convert_u_u", IREEHL_IntMemRef, + IREEHL_IntMemRef>; +def IREEInterpHL_ConvertUFOp : + IREEInterpHL_ConversionOp<"convert_u_f", IREEHL_IntMemRef, + IREEHL_FloatMemRef>; + +def IREEInterpHL_ConvertFSOp : + IREEInterpHL_ConversionOp<"convert_f_s", IREEHL_FloatMemRef, + IREEHL_IntMemRef>; +def IREEInterpHL_ConvertFUOp : + IREEInterpHL_ConversionOp<"convert_f_u", IREEHL_FloatMemRef, + IREEHL_IntMemRef>; +def IREEInterpHL_ConvertFFOp : + IREEInterpHL_ConversionOp<"convert_f_f", IREEHL_FloatMemRef, + IREEHL_FloatMemRef>; + +def IREEInterpHL_MatMulIOp : + IREEInterpHL_PureOp<"matmul_i", + [AllElementTypesMatch<["lhs", "rhs", "result"]>]> { + let arguments = (ins + IREEHL_IntMemRef:$lhs, + IREEHL_IntMemRef:$rhs, + IREEHL_IntMemRef:$multiplier_mantissa, + IREEHL_IntMemRef:$multiplier_exponent + ); + let results = (outs IREEHL_IntMemRef:$result); +} +def IREEInterpHL_MatMulFOp : + IREEInterpHL_PureOp<"matmul_f", [SameOperandsAndResultElementType]> { + let arguments = (ins + IREEHL_FloatMemRef:$lhs, + IREEHL_FloatMemRef:$rhs + ); + let results = (outs IREEHL_FloatMemRef); +} + +def IREEInterpHL_ReduceSumIOp : + IREEInterpHL_PureOp<"reduce_sum_i", + [AllElementTypesMatch<["src", "result", "init"]>]> { + let arguments = (ins + IREEHL_IntMemRef:$src, + IREEHL_IntMemRef:$init, + I32Attr:$dimension + ); + let results = (outs IREEHL_IntMemRef:$result); +} +def IREEInterpHL_ReduceSumFOp : + IREEInterpHL_PureOp<"reduce_sum_f", + [AllElementTypesMatch<["src", "result", "init"]>]> { + let arguments = (ins + IREEHL_FloatMemRef:$src, + IREEHL_FloatMemRef:$init, + I32Attr:$dimension + ); + let results = (outs IREEHL_FloatMemRef:$result); +} +def IREEInterpHL_ReduceMinIOp : + IREEInterpHL_PureOp<"reduce_min_i", + [AllElementTypesMatch<["src", "result", "init"]>]> { + let arguments = (ins + IREEHL_IntMemRef:$src, + IREEHL_IntMemRef:$init, + I32Attr:$dimension + ); + let results = (outs IREEHL_IntMemRef:$result); +} +def IREEInterpHL_ReduceMinFOp : + IREEInterpHL_PureOp<"reduce_min_f", + [AllElementTypesMatch<["src", "result", "init"]>]> { + let arguments = (ins + IREEHL_FloatMemRef:$src, + IREEHL_FloatMemRef:$init, + I32Attr:$dimension + ); + let results = (outs IREEHL_FloatMemRef:$result); +} +def IREEInterpHL_ReduceMaxIOp : + IREEInterpHL_PureOp<"reduce_max_i", + [AllElementTypesMatch<["src", "result", "init"]>]> { + let arguments = (ins + IREEHL_IntMemRef:$src, + IREEHL_IntMemRef:$init, + I32Attr:$dimension + ); + let results = (outs IREEHL_IntMemRef:$result); +} +def IREEInterpHL_ReduceMaxFOp : + IREEInterpHL_PureOp<"reduce_max_f", + [AllElementTypesMatch<["src", "result", "init"]>]> { + let arguments = (ins + IREEHL_FloatMemRef:$src, + IREEHL_FloatMemRef:$init, + I32Attr:$dimension + ); + let results = (outs IREEHL_FloatMemRef:$result); +} + +def IREEInterpHL_TraceOp : IREEInterpHL_Op<"trace"> { + let arguments = (ins Variadic<IREEHL_MemRef>:$srcs); +} + +def IREEInterpHL_CondBreakOp : IREEInterpHL_Op<"cond_break"> { + let arguments = (ins IREEHL_BoolScalar:$cond); +} + +def IREEInterpHL_BreakOp : IREEInterpHL_Op<"break">; + +#endif // IREE_INTERPRETER_HL_OPS
diff --git a/compiler/IR/Interpreter/LLDialect.cpp b/compiler/IR/Interpreter/LLDialect.cpp new file mode 100644 index 0000000..0b76111 --- /dev/null +++ b/compiler/IR/Interpreter/LLDialect.cpp
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Interpreter/LLDialect.h" + +#include "compiler/IR/Interpreter/LLOps.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace iree_compiler { + +IREELLInterpreterDialect::IREELLInterpreterDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { +#define GET_OP_LIST + addOperations< +#include "compiler/IR/Interpreter/LLOps.cpp.inc" + >(); +} + +static DialectRegistration<IREELLInterpreterDialect> iree_ll_interp_dialect; + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/LLDialect.h b/compiler/IR/Interpreter/LLDialect.h similarity index 100% rename from iree/compiler/IR/Interpreter/LLDialect.h rename to compiler/IR/Interpreter/LLDialect.h
diff --git a/compiler/IR/Interpreter/LLOps.cpp b/compiler/IR/Interpreter/LLOps.cpp new file mode 100644 index 0000000..b5acace --- /dev/null +++ b/compiler/IR/Interpreter/LLOps.cpp
@@ -0,0 +1,228 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Interpreter/LLOps.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace iree_compiler { +namespace IREEInterp { +namespace LL { + +//===----------------------------------------------------------------------===// +// iree_ll_interp.call +//===----------------------------------------------------------------------===// + +static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(state.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), state.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + state.operands)) { + return failure(); + } + return success(); +} + +static void printCallOp(OpAsmPrinter &p, CallOp op) { + p << "iree_ll_interp.call " << op.getAttr("callee") << '('; + p.printOperands(op.getOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : "; + p.printType(op.getCalleeType()); +} + +FunctionType CallOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_ll_interp.call_import +//===----------------------------------------------------------------------===// + +static ParseResult parseCallImportOp(OpAsmParser &parser, + OperationState &state) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(state.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), state.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + state.operands)) { + return failure(); + } + return success(); +} + +static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) { + p << "iree_ll_interp.call_import " << op.getAttr("callee") << '('; + p.printOperands(op.getOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : "; + p.printType(op.getCalleeType()); +} + +FunctionType CallImportOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_ll_interp.call_indirect +//===----------------------------------------------------------------------===// + +static ParseResult parseCallIndirectOp(OpAsmParser &parser, + OperationState &result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector<OpAsmParser::OperandType, 4> operands; + return failure( + parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(calleeType) || + parser.resolveOperand(callee, calleeType, result.operands) || + parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result.operands) || + parser.addTypesToList(calleeType.getResults(), result.types)); +} + +static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { + p << "iree_ll_interp.call_indirect "; + p.printOperand(op.getCallee()); + p << '('; + auto operandRange = op.getOperands(); + p.printOperands(++operandRange.begin(), operandRange.end()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : " << op.getCallee()->getType(); +} + +//===----------------------------------------------------------------------===// +// iree_ll_interp.return +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, state.operands)); +} + +static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { + p << "iree_ll_interp.return"; + if (op.getNumOperands() > 0) { + p << ' '; + p.printOperands(op.operand_begin(), op.operand_end()); + p << " : "; + interleaveComma(op.getOperandTypes(), p); + } +} + +//===----------------------------------------------------------------------===// +// iree_ll_interp.br +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { + Block *dest; + SmallVector<Value *, 4> destOperands; + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + return success(); +} + +static void printBranchOp(OpAsmPrinter &p, BranchOp op) { + p << "iree_ll_interp.br "; + p.printSuccessorAndUseList(op.getOperation(), 0); +} + +Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } + +void BranchOp::setDest(Block *block) { + return getOperation()->setSuccessor(block, 0); +} + +void BranchOp::eraseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(0, index); +} + +//===----------------------------------------------------------------------===// +// iree_ll_interp.cond_br +//===----------------------------------------------------------------------===// + +static ParseResult parseCondBranchOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<Value *, 4> destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser.getBuilder().getI1Type(); + if (parser.parseOperand(condInfo) || parser.parseComma() || + parser.resolveOperand(condInfo, int1Ty, result.operands)) { + return parser.emitError(parser.getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser.parseComma() || + parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + + return success(); +} + +static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { + p << "iree_ll_interp.cond_br "; + p.printOperand(op.getCondition()); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); +} + +#define GET_OP_CLASSES +#include "compiler/IR/Interpreter/LLOps.cpp.inc" + +} // namespace LL +} // namespace IREEInterp +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Interpreter/LLOps.h b/compiler/IR/Interpreter/LLOps.h new file mode 100644 index 0000000..5c99ce8 --- /dev/null +++ b/compiler/IR/Interpreter/LLOps.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_INTERPRETER_LLOPS_H_ +#define IREE_COMPILER_IR_INTERPRETER_LLOPS_H_ + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace iree_compiler { +namespace IREEInterp { +namespace LL { + +#define GET_OP_CLASSES +#include "compiler/IR/Interpreter/LLOps.h.inc" + +} // namespace LL +} // namespace IREEInterp +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
diff --git a/compiler/IR/Interpreter/LLOps.td b/compiler/IR/Interpreter/LLOps.td new file mode 100644 index 0000000..7453835 --- /dev/null +++ b/compiler/IR/Interpreter/LLOps.td
@@ -0,0 +1,633 @@ +// IREE low-level interpreter op definitions. +// These map 1:1 with the bytecode, accept only MemRef types and generally use +// output parameters instead of return types. +// +// The source of truth for bytecode opcodes is: +// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h + +#ifdef IREE_INTERPRETER_LL_OPS +#else +#define IREE_INTERPRETER_LL_OPS + +#ifdef IREE_OP_BASE +#else +include "compiler/IR/OpBase.td" +#endif // IREE_OP_BASE + +def IREEInterpLL_Dialect : Dialect { + let name = "iree_ll_interp"; + let cppNamespace = "IREEInterp::LL"; +} + +//===----------------------------------------------------------------------===// +// Base op classes +//===----------------------------------------------------------------------===// + +class IREEInterpLL_Op<string mnemonic, list<OpTrait> traits = []> : + Op<IREEInterpLL_Dialect, mnemonic, traits>; + +class IREEInterpLL_PureOp<string mnemonic, list<OpTrait> traits = []> : + IREEInterpLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +class IREEInterpLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef, + list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> { + let arguments = (ins type:$input, type:$dst); +} + +class IREEInterpLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef, + list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> { + let arguments = (ins type:$lhs, type:$rhs, type:$dst); +} + +class IREEInterpLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef, + list<OpTrait> traits = []> + : IREEInterpLL_Op<mnemonic, traits> { + let arguments = (ins type : $a, type : $b, type : $c, type : $dst); +} + +//===----------------------------------------------------------------------===// +// Low-level interpreter ops +//===----------------------------------------------------------------------===// + +// TODO(benvanik): value attribute. +def IREEInterpLL_ConstantOp : IREEInterpLL_PureOp<"constant"> { + let results = (outs IREELL_MemRef); +} + +def IREEInterpLL_CallOp : IREEInterpLL_Op<"call"> { + let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +// TODO(benvanik): add verifier that target isExternal. +def IREEInterpLL_CallImportOp : IREEInterpLL_Op<"call_import"> { + let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpLL_CallIndirectOp : IREEInterpLL_Op<"call_indirect"> { + let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *callee," + "ArrayRef<Value *> operands = {}", [{ + result.operands.push_back(callee); + result.addOperands(operands); + result.addTypes(callee->getType().cast<FunctionType>().getResults()); + }]>]; + + let extraClassDeclaration = [{ + Value *getCallee() { return getOperand(0); } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpLL_ReturnOp : IREEInterpLL_Op<"return", [Terminator]> { + let arguments = (ins Variadic<IREELL_MemRef>:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpLL_BranchOp : IREEInterpLL_Op<"br", [Terminator]> { + let arguments = (ins Variadic<IREELL_MemRef>:$operands); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Block *dest," + "ArrayRef<Value *> operands = {}", [{ + result.addSuccessor(dest, operands); + }]>]; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpLL_CondBranchOp : IREEInterpLL_Op<"cond_br", [Terminator]> { + let arguments = (ins + IREELL_BoolScalar:$condition, + Variadic<IREELL_MemRef>:$branchOperands + ); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *condition," + "Block *trueDest, ArrayRef<Value *> trueOperands," + "Block *falseDest, ArrayRef<Value *> falseOperands", [{ + result.addOperands(condition); + result.addSuccessor(trueDest, trueOperands); + result.addSuccessor(falseDest, falseOperands); + }]>]; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getOperation()->getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getOperation()->getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getOperation()->getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getOperation()->getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREEInterpLL_CmpIOp : IREEInterpLL_Op<"cmp_i"> { + let arguments = (ins + I32Attr:$predicate, + IREELL_IntMemRef:$lhs, + IREELL_IntMemRef:$rhs, + IREELL_BoolMemRef:$dst + ); +} + +def IREEInterpLL_CmpFOp : IREEInterpLL_Op<"cmp_f"> { + let arguments = (ins + I32Attr:$predicate, + IREELL_FloatMemRef:$lhs, + IREELL_FloatMemRef:$rhs, + IREELL_BoolMemRef:$dst + ); +} + +def IREEInterpLL_AllocStaticOp : IREEInterpLL_PureOp<"alloc_static"> { + // TODO(benvanik): attributes and args. + let results = (outs IREELL_MemRef); +} + +def IREEInterpLL_AllocStackOp : IREEInterpLL_PureOp<"alloc_stack"> { + // TODO(benvanik): atributes and args. + let arguments = (ins + Variadic<IREELL_MemRef>:$dim_pieces + ); + let results = (outs + IREELL_MemRef + ); +} + +def IREEInterpLL_AllocStackInitOp : IREEInterpLL_PureOp<"alloc_stack_init"> { + // TODO(benvanik): attributes and args. + let arguments = (ins + Variadic<IREELL_MemRef>:$dim_pieces + ); + let results = (outs + IREELL_MemRef + ); +} + +// TODO(b/142012496): Add trait that enables DCE but not CSE. +def IREEInterpLL_AllocHeapOp : IREEInterpLL_Op<"alloc_heap"> { + // TODO(benvanik): attributes and args. + let arguments = (ins + Variadic<IREELL_MemRef>:$dim_pieces + ); + let results = (outs + IREELL_MemRef + ); +} + +def IREEInterpLL_DiscardOp : IREEInterpLL_Op<"discard"> { + let arguments = (ins IREELL_MemRef); +} + +def IREEInterpLL_RankOp : IREEInterpLL_Op<"rank"> { + let arguments = (ins + IREELL_MemRef:$input, + IREELL_I32Scalar:$dst + ); +} + +def IREEInterpLL_DimOp : IREEInterpLL_Op<"dim"> { + // TODO(benvanik) add dim attr (I32Attr:$dim) + let arguments = (ins + IREELL_MemRef:$input, + IREELL_I32Scalar:$dst + ); +} + +def IREEInterpLL_ShapeOp : IREEInterpLL_Op<"shape"> { + let arguments = (ins + IREELL_MemRef:$input, + IREELL_I32MemRef:$dst + ); +} + +def IREEInterpLL_LengthOp : IREEInterpLL_Op<"length"> { + let arguments = (ins + IREELL_MemRef:$input, + IREELL_I32Scalar:$dst + ); +} + + +def IREEInterpLL_DynamicSliceOp : IREEInterpLL_PureOp<"dynamic_slice"> { + let arguments = (ins + IREELL_MemRef:$src, + IREELL_1DIndexMemRef:$srcIndices, + IREELL_1DIndexMemRef:$lengths + ); + let results = (outs + IREELL_MemRef + ); +} + +// TODO(benvanik): add attribute requirements/types. +def IREEInterpLL_StaticSliceOp : + IREEInterpLL_PureOp<"static_slice", [SameOperandsAndResultElementType]> { + let arguments = (ins IREELL_MemRef:$src); + let results = (outs IREELL_MemRef); +} + +def IREEInterpLL_DynamicCopyOp : IREEInterpLL_Op<"dynamic_copy", [ + AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, +]> { + let arguments = (ins + IREELL_MemRef:$src, + IREELL_1DIndexMemRef:$srcIndices, + IREELL_MemRef:$dst, + IREELL_1DIndexMemRef:$dstIndices, + IREELL_1DIndexMemRef:$lengths + ); +} + +def IREEInterpLL_StaticCopyOp : IREEInterpLL_Op<"static_copy", [ + AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, +]> { + let arguments = (ins + IREELL_MemRef:$src, + I32ElementsAttr:$srcIndices, + IREELL_MemRef:$dst, + I32ElementsAttr:$dstIndices, + I32ElementsAttr:$lengths + ); +} + +def IREEInterpLL_CloneOp : + IREEInterpLL_PureOp<"clone", [SameOperandsAndResultType]> { + let arguments = (ins IREELL_MemRef:$src); + let results = (outs IREELL_MemRef); +} + +// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops? +def IREEInterpLL_SplitOp : IREEInterpLL_PureOp<"split"> { + let arguments = (ins + IREELL_MemRef:$src + ); + let results = (outs + Variadic<IREELL_MemRef> + ); +} + +def IREEInterpLL_AssignOp : + IREEInterpLL_Op<"assign", [SameOperandsAndResultType]> { + let arguments = (ins IREELL_MemRef:$src); + let results = (outs IREELL_MemRef); +} + +def IREEInterpLL_CondAssignOp : IREEInterpLL_Op<"cond_assign"> { + let arguments = (ins + IREELL_BoolScalar:$cond, + IREELL_MemRef:$lhs, + IREELL_MemRef:$rhs + ); + let results = (outs + IREELL_MemRef + ); +} + +def IREEInterpLL_ReshapeOp : IREEInterpLL_Op<"reshape"> { + let arguments = (ins + IREELL_MemRef:$input, + IREELL_1DIntMemRef:$shape + ); + let results = (outs + IREELL_MemRef + ); +} + +def IREEInterpLL_SelectOp : IREEInterpLL_Op<"select"> { + let arguments = (ins + IREELL_MemRef:$cond, + IREELL_MemRef:$lhs, + IREELL_MemRef:$rhs, + IREELL_MemRef:$dst + ); +} + +def IREEInterpLL_PadOp : + IREEInterpLL_Op< + "pad", [AllElementTypesMatch<["src", "dst", "padding_value"]>]> { + let arguments = (ins + IREELL_MemRef:$src, + IREELL_ElementScalar:$padding_value, + IREELL_1DIndexMemRef:$edge_padding_low, + IREELL_1DIndexMemRef:$edge_padding_high, + IREELL_1DIndexMemRef:$interior_padding, + IREELL_MemRef:$dst + ); +} + +def IREEInterpLL_TransposeOp : IREEInterpLL_BinaryOp<"transpose">; + +def IREEInterPLL_ReverseOp : IREEInterpLL_BinaryOp<"reverse">; + +def IREEInterpLL_BroadcastOp : IREEInterpLL_BinaryOp<"broadcast">; + +def IREEInterpLL_TileOp : IREEInterpLL_BinaryOp<"tile">; + +// TODO(benvanik): add traits for broadcasting support. + +def IREEInterpLL_NotOp : IREEInterpLL_UnaryOp<"not">; +def IREEInterpLL_AndOp : IREEInterpLL_BinaryOp<"and">; +def IREEInterpLL_OrOp : IREEInterpLL_BinaryOp<"or">; +def IREEInterpLL_XorOp : IREEInterpLL_BinaryOp<"xor">; +def IREEInterpLL_ShiftLeftOp : IREEInterpLL_BinaryOp<"sll">; +def IREEInterpLL_ShiftRightLogicalOp : IREEInterpLL_BinaryOp<"srl">; +def IREEInterpLL_ShiftRightArithmeticOp : IREEInterpLL_BinaryOp<"sra">; + +def IREEInterpLL_AddIOp : IREEInterpLL_BinaryOp<"add_i", IREELL_IntMemRef>; +def IREEInterpLL_AddFOp : IREEInterpLL_BinaryOp<"add_f", IREELL_FloatMemRef>; +def IREEInterpLL_SubIOp : IREEInterpLL_BinaryOp<"sub_i", IREELL_IntMemRef>; +def IREEInterpLL_SubFOp : IREEInterpLL_BinaryOp<"sub_f", IREELL_FloatMemRef>; +def IREEInterpLL_AbsIOp : IREEInterpLL_UnaryOp<"abs_i", IREELL_IntMemRef>; +def IREEInterpLL_AbsFOp : IREEInterpLL_UnaryOp<"abs_f", IREELL_FloatMemRef>; +def IREEInterpLL_MulIOp : IREEInterpLL_BinaryOp<"mul_i", IREELL_IntMemRef>; +def IREEInterpLL_MulFOp : IREEInterpLL_BinaryOp<"mul_f", IREELL_FloatMemRef>; +def IREEInterpLL_DivISOp : IREEInterpLL_BinaryOp<"div_i_s", IREELL_IntMemRef>; +def IREEInterpLL_DivIUOp : IREEInterpLL_BinaryOp<"div_i_u", IREELL_IntMemRef>; +def IREEInterpLL_DivFOp : IREEInterpLL_BinaryOp<"div_f", IREELL_FloatMemRef>; +def IREEInterpLL_MulAddIOp : IREEInterpLL_BinaryOp<"madd_i", IREELL_IntMemRef>; +def IREEInterpLL_MulAddFOp : IREEInterpLL_BinaryOp<"madd_f", IREELL_FloatMemRef>; +def IREEInterpLL_ExpFOp : IREEInterpLL_UnaryOp<"exp_f", IREELL_FloatMemRef>; +def IREEInterpLL_LogFOp : IREEInterpLL_UnaryOp<"log_f", IREELL_FloatMemRef>; +def IREEInterpLL_RsqrtFOp : IREEInterpLL_UnaryOp<"rsqrt_f", IREELL_FloatMemRef>; +def IREEInterpLL_CosFOp : IREEInterpLL_UnaryOp<"cos_f", IREELL_FloatMemRef>; +def IREEInterpLL_SinFOp : IREEInterpLL_UnaryOp<"sin_f", IREELL_FloatMemRef>; +def IREEInterpLL_TanhFOp : IREEInterpLL_UnaryOp<"tanh_f", IREELL_FloatMemRef>; +def IREEInterpLL_Atan2FOp : IREEInterpLL_UnaryOp<"atan2_f", IREELL_FloatMemRef>; + +def IREEInterpLL_MinISOp : IREEInterpLL_BinaryOp<"min_i_s", IREELL_IntMemRef>; +def IREEInterpLL_MinIUOp : IREEInterpLL_BinaryOp<"min_i_u", IREELL_IntMemRef>; +def IREEInterpLL_MinFOp : IREEInterpLL_BinaryOp<"min_f", IREELL_FloatMemRef>; +def IREEInterpLL_MaxISOp : IREEInterpLL_BinaryOp<"max_i_s", IREELL_IntMemRef>; +def IREEInterpLL_MaxIUOp : IREEInterpLL_BinaryOp<"max_i_u", IREELL_IntMemRef>; +def IREEInterpLL_MaxFOp : IREEInterpLL_BinaryOp<"max_f", IREELL_FloatMemRef>; +def IREEInterpLL_ClampFOp : IREEInterpLL_TernaryOp<"clamp_f", IREELL_FloatMemRef>; +def IREEInterpLL_FloorFOp : IREEInterpLL_UnaryOp<"floor_f", IREELL_FloatMemRef>; +def IREEInterpLL_CeilFOp : IREEInterpLL_UnaryOp<"ceil_f", IREELL_FloatMemRef>; + +def IREEInterpLL_ConvertSSOp : IREEInterpLL_UnaryOp<"convert_s_s", IREELL_MemRef>; +def IREEInterpLL_ConvertSUOp : IREEInterpLL_UnaryOp<"convert_s_u", IREELL_MemRef>; +def IREEInterpLL_ConvertSFOp : IREEInterpLL_UnaryOp<"convert_s_f", IREELL_MemRef>; + +def IREEInterpLL_ConvertUSOp : IREEInterpLL_UnaryOp<"convert_u_s", IREELL_MemRef>; +def IREEInterpLL_ConvertUUOp : IREEInterpLL_UnaryOp<"convert_u_u", IREELL_MemRef>; +def IREEInterpLL_ConvertUFOp : IREEInterpLL_UnaryOp<"convert_u_f", IREELL_MemRef>; + +def IREEInterpLL_ConvertFSOp : IREEInterpLL_UnaryOp<"convert_f_s", IREELL_MemRef>; +def IREEInterpLL_ConvertFUOp : IREEInterpLL_UnaryOp<"convert_f_u", IREELL_MemRef>; +def IREEInterpLL_ConvertFFOp : IREEInterpLL_UnaryOp<"convert_f_f", IREELL_MemRef>; + +def IREEInterpLL_MatMulIOp : IREEInterpLL_Op<"matmul_i"> { + let arguments = (ins + IREELL_IntMemRef:$lhs, + IREELL_IntMemRef:$rhs, + IREELL_IntMemRef:$multiplier_mantissa, + IREELL_IntMemRef:$multiplier_exponent, + IREELL_IntMemRef:$dst + ); +} +def IREEInterpLL_MatMulFOp : IREEInterpLL_Op<"matmul_f"> { + let arguments = (ins + IREELL_FloatMemRef:$lhs, + IREELL_FloatMemRef:$rhs, + IREELL_FloatMemRef:$dst + ); +} + +def IREEInterpLL_ReduceSumIOp : IREEInterpLL_Op<"reduce_sum_i"> { + let arguments = (ins + IREELL_IntMemRef:$src, + IREELL_IntMemRef:$init, + I32Attr:$dimension, + IREELL_IntMemRef:$dst + ); +} +def IREEInterpLL_ReduceSumFOp : IREEInterpLL_Op<"reduce_sum_f"> { + let arguments = (ins + IREELL_FloatMemRef:$src, + IREELL_FloatMemRef:$init, + I32Attr:$dimension, + IREELL_FloatMemRef:$dst + ); +} + +def IREEInterpLL_ReduceMinIOp : IREEInterpLL_Op<"reduce_min_i"> { + let arguments = (ins + IREELL_IntMemRef:$src, + IREELL_IntMemRef:$init, + I32Attr:$dimension, + IREELL_IntMemRef:$dst + ); +} +def IREEInterpLL_ReduceMinFOp : IREEInterpLL_Op<"reduce_min_f"> { + let arguments = (ins + IREELL_FloatMemRef:$src, + IREELL_FloatMemRef:$init, + I32Attr:$dimension, + IREELL_FloatMemRef:$dst + ); +} + +def IREEInterpLL_ReduceMaxIOp : IREEInterpLL_Op<"reduce_max_i"> { + let arguments = (ins + IREELL_IntMemRef:$src, + IREELL_IntMemRef:$init, + I32Attr:$dimension, + IREELL_IntMemRef:$dst + ); +} +def IREEInterpLL_ReduceMaxFOp : IREEInterpLL_Op<"reduce_max_f"> { + let arguments = (ins + IREELL_FloatMemRef:$src, + IREELL_FloatMemRef:$init, + I32Attr:$dimension, + IREELL_FloatMemRef:$dst + ); +} + +def IREEInterpLL_TraceOp : IREEInterpLL_Op<"trace"> { + let arguments = (ins + Variadic<IREELL_MemRef>:$srcs + ); +} + +def IREEInterpLL_CondBreakOp : IREEInterpLL_Op<"cond_break"> { + let arguments = (ins + IREELL_BoolScalar:$cond + ); +} + +def IREEInterpLL_BreakOp : IREEInterpLL_Op<"break">; + +#endif // IREE_INTERPRETER_LL_OPS
diff --git a/compiler/IR/Interpreter/OpWriters.cpp b/compiler/IR/Interpreter/OpWriters.cpp new file mode 100644 index 0000000..bc2a118 --- /dev/null +++ b/compiler/IR/Interpreter/OpWriters.cpp
@@ -0,0 +1,261 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Interpreter/OpWriters.h" + +#include "compiler/IR/Interpreter/LLOps.h" +#include "compiler/Serialization/BytecodeWriter.h" +#include "compiler/Utils/Macros.h" +#include "schemas/bytecode/interpreter_bytecode_v0.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +//===----------------------------------------------------------------------===// +// Sequencer ops +//===----------------------------------------------------------------------===// + +LogicalResult writeOp(IREEInterp::LL::ConstantOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConstant)); + auto memrefType = op.getType().dyn_cast<MemRefType>(); + if (!memrefType) { + return op.emitOpError() + << "Constant has an unsupported type; must be a memref: " + << op.getType(); + } + RETURN_IF_FAILURE(writer->WriteConstant(memrefType, op.getAttr("value"))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::CallOp op, BytecodeWriter *writer) { + auto module = op.getOperation()->getParentOfType<ModuleOp>(); + auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); + // TODO(benvanik): transforms to convert Call->CallImport. + // TODO(benvanik): switch with kCallTail if attr exists. + if (callee.isExternal()) { + RETURN_IF_FAILURE( + writer->WriteOpcode(iree::InterpreterOpcode::kCallImport)); + } else { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCall)); + } + RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::CallImportOp op, BytecodeWriter *writer) { + auto module = op.getOperation()->getParentOfType<ModuleOp>(); + auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); + // TODO(benvanik): switch with kCallTail if attr exists. + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCallImport)); + RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::CallIndirectOp op, + BytecodeWriter *writer) { + RETURN_IF_FAILURE( + writer->WriteOpcode(iree::InterpreterOpcode::kCallIndirect)); + RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType())); + RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult WriteConvertOperands(Operation *op, BytecodeWriter *writer) { + auto *src = op->getOperand(0); + RETURN_IF_FAILURE( + writer->WriteTypeIndex(getElementTypeOrSelf(src->getType()))); + RETURN_IF_FAILURE(writer->WriteLocal(src)); + auto *dst = op->getOperand(1); + RETURN_IF_FAILURE( + writer->WriteTypeIndex(getElementTypeOrSelf(dst->getType()))); + RETURN_IF_FAILURE(writer->WriteLocal(dst)); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::ConvertSSOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSS)); + return WriteConvertOperands(op, writer); +} + +LogicalResult writeOp(IREEInterp::LL::ConvertUUOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUU)); + return WriteConvertOperands(op, writer); +} + +LogicalResult writeOp(IREEInterp::LL::ConvertSUOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSU)); + return WriteConvertOperands(op, writer); +} + +LogicalResult writeOp(IREEInterp::LL::ConvertUSOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUS)); + return WriteConvertOperands(op, writer); +} + +LogicalResult writeOp(IREEInterp::LL::BranchOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kBranch)); + RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest())); + RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands())); + for (int i = 0; i < op.getNumOperands(); ++i) { + // Copy src->dst. + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i))); + } + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::CondBranchOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCondBranch)); + RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition())); + RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest())); + RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands())); + for (int i = 0; i < op.getNumTrueOperands(); ++i) { + // Copy src->dst. + RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i))); + } + RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest())); + RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands())); + for (int i = 0; i < op.getNumFalseOperands(); ++i) { + // Copy src->dst. + RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i))); + } + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::CmpIOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpI)); + RETURN_IF_FAILURE( + writer->WriteUint8(static_cast<uint8_t>(op.predicate().getZExtValue()))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(0))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(1))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(2))); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::CmpFOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpF)); + RETURN_IF_FAILURE( + writer->WriteUint8(static_cast<uint8_t>(op.predicate().getZExtValue()))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(0))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(1))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(2))); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::AllocHeapOp op, BytecodeWriter *writer) { + auto memrefType = op.getType().cast<MemRefType>(); + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kAllocHeap)); + RETURN_IF_FAILURE(writer->WriteInt32(0)); + RETURN_IF_FAILURE(writer->WriteTypeIndex(memrefType.getElementType())); + RETURN_IF_FAILURE(writer->WriteShapePieces(memrefType)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands())); + RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::StaticCopyOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kStaticCopy)); + RETURN_IF_FAILURE(writer->WriteLocal(op.src())); + RETURN_IF_FAILURE(writer->WriteShapePieces(op.srcIndices())); + RETURN_IF_FAILURE(writer->WriteLocal(op.dst())); + RETURN_IF_FAILURE(writer->WriteShapePieces(op.dstIndices())); + RETURN_IF_FAILURE(writer->WriteShapePieces(op.lengths())); + return success(); +} + +LogicalResult writeReduceOperands(Operation *op, BytecodeWriter *writer, + APInt dimension) { + RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(0))); + RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(1))); + RETURN_IF_FAILURE(writer->WriteInt32(dimension.getZExtValue())); + RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(2))); + return success(); +} + +LogicalResult writeOp(IREEInterp::LL::ReduceSumIOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumI)); + return writeReduceOperands(op, writer, op.dimension()); +} + +LogicalResult writeOp(IREEInterp::LL::ReduceSumFOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumF)); + return writeReduceOperands(op, writer, op.dimension()); +} + +LogicalResult writeOp(IREEInterp::LL::ReduceMinIOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinI)); + return writeReduceOperands(op, writer, op.dimension()); +} + +LogicalResult writeOp(IREEInterp::LL::ReduceMinFOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinF)); + return writeReduceOperands(op, writer, op.dimension()); +} + +LogicalResult writeOp(IREEInterp::LL::ReduceMaxIOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxI)); + return writeReduceOperands(op, writer, op.dimension()); +} + +LogicalResult writeOp(IREEInterp::LL::ReduceMaxFOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxF)); + return writeReduceOperands(op, writer, op.dimension()); +} + +} // namespace + +void registerInterpreterCustomWriters(VMFunctionBuilder *builder) { +#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \ + builder->RegisterCustomWriter( \ + op_type::getOperationName(), \ + +[](Operation *op, BytecodeWriter *writer) { \ + return writeOp(cast<op_type>(op), writer); \ + }); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConstantOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallImportOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallIndirectOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::BranchOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CondBranchOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSSOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUUOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSUOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUSOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpIOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpFOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::AllocHeapOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::StaticCopyOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumIOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumFOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinIOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinFOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxIOp); + REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxFOp); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Interpreter/OpWriters.h b/compiler/IR/Interpreter/OpWriters.h new file mode 100644 index 0000000..a02c57f --- /dev/null +++ b/compiler/IR/Interpreter/OpWriters.h
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_ +#define IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_ + +#include "compiler/Serialization/VMFunctionBuilder.h" + +namespace mlir { +namespace iree_compiler { + +// Registers custom op writers with the builder. +// Ops not registered will use the generic writer. +void registerInterpreterCustomWriters(VMFunctionBuilder *builder); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
diff --git a/compiler/IR/Interpreter/test/BUILD b/compiler/IR/Interpreter/test/BUILD new file mode 100644 index 0000000..44a5820 --- /dev/null +++ b/compiler/IR/Interpreter/test/BUILD
@@ -0,0 +1,15 @@ +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + "///tools:iree-run-mlir", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/IR/Interpreter/test/concat.mlir b/compiler/IR/Interpreter/test/concat.mlir similarity index 100% rename from iree/compiler/IR/Interpreter/test/concat.mlir rename to compiler/IR/Interpreter/test/concat.mlir
diff --git a/iree/compiler/IR/Interpreter/test/invalid_types_hl.mlir b/compiler/IR/Interpreter/test/invalid_types_hl.mlir similarity index 100% rename from iree/compiler/IR/Interpreter/test/invalid_types_hl.mlir rename to compiler/IR/Interpreter/test/invalid_types_hl.mlir
diff --git a/iree/compiler/IR/Interpreter/test/invalid_types_ll.mlir b/compiler/IR/Interpreter/test/invalid_types_ll.mlir similarity index 100% rename from iree/compiler/IR/Interpreter/test/invalid_types_ll.mlir rename to compiler/IR/Interpreter/test/invalid_types_ll.mlir
diff --git a/iree/compiler/IR/OpBase.td b/compiler/IR/OpBase.td similarity index 100% rename from iree/compiler/IR/OpBase.td rename to compiler/IR/OpBase.td
diff --git a/compiler/IR/Ops.cpp b/compiler/IR/Ops.cpp new file mode 100644 index 0000000..5179f41 --- /dev/null +++ b/compiler/IR/Ops.cpp
@@ -0,0 +1,639 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Ops.h" + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/SMLoc.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/STLExtras.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +//===----------------------------------------------------------------------===// +// iree.constant +//===----------------------------------------------------------------------===// + +static ParseResult parseConstantOp(OpAsmParser &parser, + OperationState &result) { + Attribute valueAttr; + Type type; + if (parser.parseLSquare() || + parser.parseAttribute(valueAttr, "value", result.attributes) || + parser.parseRSquare() || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(type)) + return failure(); + + return parser.addTypeToList(type, result.types); +} + +static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) { + p << "iree.constant["; + p.printAttribute(op.getValue()); + p << "] "; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); + + p << " : "; + p.printType(op.getType()); +} + +namespace { + +// TODO(gcmn) this is duplicated from MemRefUtils to avoid a circular +// dependency. Extract op-dependent parts of memref utils to allow reuse. +MemRefType convertTypeToMemRef(Type type) { + if (type.isIntOrIndexOrFloat()) { + return MemRefType::get({}, type, {}, 0); + } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + } else if (auto memRefType = type.dyn_cast<MemRefType>()) { + return MemRefType::get(memRefType.getShape(), memRefType.getElementType()); + } else { + llvm_unreachable("Unconvertable type"); + } +} + +} // namespace + +void ConstantOp::build(Builder *builder, OperationState &state, + ElementsAttr value) { + auto type = convertTypeToMemRef(value.getType()); + return build(builder, state, type, value); +} + +// TODO(b/134575149): enable folder when we store the correct type. +// OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { +// assert(operands.empty() && "constant has no operands"); +// return getValue(); +// } + +//===----------------------------------------------------------------------===// +// iree.tensor_to_memref +//===----------------------------------------------------------------------===// + +static ParseResult parseTensorToMemRefOp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operand; + Type operandType; + Type resultType; + if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || + failed(parser.parseColonType(operandType)) || + failed(parser.resolveOperand(operand, operandType, state.operands)) || + failed(parser.parseRParen()) || + failed(parser.parseColonType(resultType)) || + failed(parser.addTypeToList(resultType, state.types))) { + return failure(); + } + return success(); +} + +static void printTensorToMemRefOp(OpAsmPrinter &p, TensorToMemRefOp &op) { + p << "iree.tensor_to_memref("; + p.printOperand(op.getOperand()); + p << " : "; + p.printType(op.getOperand()->getType()); + p << ") : "; + p.printType(op.getType()); +} + +OpFoldResult TensorToMemRefOp::fold(ArrayRef<Attribute> operands) { + if (auto memrefToTensorOp = dyn_cast_or_null<IREE::MemRefToTensorOp>( + getOperand()->getDefiningOp())) { + return memrefToTensorOp.getOperand(); + } + + return {}; +} + +void TensorToMemRefOp::build(Builder *builder, OperationState &state, + Value *arg) { + build(builder, state, convertTypeToMemRef(arg->getType()), arg); +} + +//===----------------------------------------------------------------------===// +// iree.memref_to_tensor +//===----------------------------------------------------------------------===// + +static ParseResult parseMemRefToTensorOp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operand; + Type operandType; + Type resultType; + if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || + failed(parser.parseColonType(operandType)) || + failed(parser.resolveOperand(operand, operandType, state.operands)) || + failed(parser.parseRParen()) || + failed(parser.parseColonType(resultType)) || + failed(parser.addTypeToList(resultType, state.types))) { + return failure(); + } + return success(); +} + +static void printMemRefToTensorOp(OpAsmPrinter &p, MemRefToTensorOp &op) { + p << "iree.memref_to_tensor("; + p.printOperand(op.getOperand()); + p << " : "; + p.printType(op.getOperand()->getType()); + p << ") : "; + p.printType(op.getType()); +} + +OpFoldResult MemRefToTensorOp::fold(ArrayRef<Attribute> operands) { + if (auto tensorToMemRefOp = dyn_cast_or_null<IREE::TensorToMemRefOp>( + getOperand()->getDefiningOp())) { + return tensorToMemRefOp.getOperand(); + } + + return {}; +} + +void MemRefToTensorOp::build(Builder *builder, OperationState &state, + Value *arg) { + // TODO(gcmn) Use getTensorType from MemRefUtils when circular dependency can + // be avoided. + auto memRefType = arg->getType().cast<MemRefType>(); + auto tensorType = + RankedTensorType::get(memRefType.getShape(), memRefType.getElementType()); + build(builder, state, tensorType, arg); +} + +//===----------------------------------------------------------------------===// +// iree.scalar_to_memref +//===----------------------------------------------------------------------===// + +static ParseResult parseScalarToMemRefOp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operand; + Type operandType; + Type resultType; + if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || + failed(parser.parseColonType(operandType)) || + failed(parser.resolveOperand(operand, operandType, state.operands)) || + failed(parser.parseRParen()) || + failed(parser.parseColonType(resultType)) || + failed(parser.addTypeToList(resultType, state.types))) { + return failure(); + } + return success(); +} + +static void printScalarToMemRefOp(OpAsmPrinter &p, ScalarToMemRefOp &op) { + p << "iree.scalar_to_memref("; + p.printOperand(op.getOperand()); + p << " : "; + p.printType(op.getOperand()->getType()); + p << ") : "; + p.printType(op.getType()); +} + +OpFoldResult ScalarToMemRefOp::fold(ArrayRef<Attribute> operands) { + if (auto memrefToScalarOp = dyn_cast_or_null<IREE::MemRefToScalarOp>( + getOperand()->getDefiningOp())) { + return memrefToScalarOp.getOperand(); + } + + return {}; +} + +void ScalarToMemRefOp::build(Builder *builder, OperationState &state, + Value *arg) { + build(builder, state, convertTypeToMemRef(arg->getType()), arg); +} + +//===----------------------------------------------------------------------===// +// iree.memref_to_scalar +//===----------------------------------------------------------------------===// + +static ParseResult parseMemRefToScalarOp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operand; + Type operandType; + Type resultType; + if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || + failed(parser.parseColonType(operandType)) || + failed(parser.resolveOperand(operand, operandType, state.operands)) || + failed(parser.parseRParen()) || + failed(parser.parseColonType(resultType)) || + failed(parser.addTypeToList(resultType, state.types))) { + return failure(); + } + return success(); +} + +static void printMemRefToScalarOp(OpAsmPrinter &p, MemRefToScalarOp &op) { + p << "iree.memref_to_scalar("; + p.printOperand(op.getOperand()); + p << " : "; + p.printType(op.getOperand()->getType()); + p << ") : "; + p.printType(op.getType()); +} + +OpFoldResult MemRefToScalarOp::fold(ArrayRef<Attribute> operands) { + if (auto scalarToMemRefOp = dyn_cast_or_null<IREE::ScalarToMemRefOp>( + getOperand()->getDefiningOp())) { + return scalarToMemRefOp.getOperand(); + } + + return {}; +} + +void MemRefToScalarOp::build(Builder *builder, OperationState &state, + Value *arg) { + build(builder, state, getElementTypeOrSelf(arg), arg); +} + +//===----------------------------------------------------------------------===// +// iree.dispatch_region +//===----------------------------------------------------------------------===// + +void DispatchRegionOp::build(Builder *builder, OperationState &state, + ArrayRef<Type> resultTypes, Value *workload, + ArrayRef<Value *> operands, + ArrayRef<NamedAttribute> attributes) { + state.addTypes(resultTypes); + state.addOperands({workload}); + state.addOperands(operands); + state.addAttributes(attributes); + state.addRegion(); + state.setOperandListToResizable(); +} + +ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState &state) { + // Parse required workload. + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + state.operands))) { + return failure(); + } + + // Parse (optional) args. + SmallVector<OpAsmParser::OperandType, 16> regionArgs; + SmallVector<Type, 16> regionArgTypes; + if (failed(parser.parseLParen())) { + return failure(); + } + if (failed(parser.parseOptionalRParen())) { + SmallVector<OpAsmParser::OperandType, 16> regionOperands; + auto argsLoc = parser.getCurrentLocation(); + do { + // Reserve entries in the lists. + regionArgs.emplace_back(); + regionOperands.emplace_back(); + regionArgTypes.emplace_back(); + if (failed(parser.parseRegionArgument(regionArgs.back())) || + failed(parser.parseEqual()) || + failed(parser.parseOperand(regionOperands.back())) || + failed(parser.parseColonType(regionArgTypes.back()))) { + return failure(); + } + } while (succeeded(parser.parseOptionalComma())); + if (failed(parser.parseRParen()) || + failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc, + state.operands))) { + return failure(); + } + } + state.setOperandListToResizable(); + + // Parse (optional) results. + if (failed(parser.parseOptionalColonTypeList(state.types))) { + return failure(); + } + + // Parse region body. + Region *body = state.addRegion(); + if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || + failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + return success(); +} + +void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { + p << "iree.dispatch_region"; + + // Print the workload argument. + p << "["; + p.printOperand(op.getWorkload()); + p << " : "; + p.printType(op.getWorkload()->getType()); + p << "]"; + + // Print the data argument remapping. + p << "("; + interleaveComma( + llvm::zip(op.getBody().front().getArguments(), op.getArgOperands()), p, + [&](std::tuple<BlockArgument *, Value *> it) { + p << *std::get<0>(it) << " = " << *std::get<1>(it); + p << " : "; + p << std::get<1>(it)->getType(); + }); + p << ")"; + + // Print the result types, if any. + if (op.getNumResults() > 0) { + p << " : "; + interleaveComma(op.getResultTypes(), p); + } + + p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{}); +} + +//===----------------------------------------------------------------------===// +// iree.reduction_region +//===----------------------------------------------------------------------===// + +void ReductionRegionOp::build(Builder *builder, OperationState &state, + ArrayRef<Type> resultTypes, Value *workload, + ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, + ArrayRef<int64_t> dimensions, + ArrayRef<NamedAttribute> attributes) { + state.addTypes(resultTypes); + state.addOperands({workload}); + state.addOperands(operands); + state.addOperands(initialValues); + state.addAttribute( + "dimensions", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int64_t>(dimensions.size())}, + builder->getIntegerType(64)), + dimensions)); + state.addAttributes(attributes); + state.addRegion(); + state.setOperandListToResizable(); +} + +void ReductionRegionOp::build( + Builder *builder, OperationState &state, ArrayRef<Type> resultTypes, + Value *workload, ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, ArrayRef<int64_t> windowDimensions, + ArrayRef<int64_t> windowStrides, ArrayRef<int64_t> baseDilations, + ArrayRef<int64_t> windowDilations, PaddingMode paddingMode, + ArrayRef<NamedAttribute> attributes) { + state.addTypes(resultTypes); + state.addOperands({workload}); + state.addOperands(operands); + state.addOperands(initialValues); + state.addAttribute( + "window_dimensions", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int64_t>(windowDimensions.size())}, + builder->getIntegerType(64)), + windowDimensions)); + state.addAttribute( + "window_strides", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int64_t>(windowStrides.size())}, + builder->getIntegerType(64)), + windowStrides)); + state.addAttribute( + "base_dilations", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int64_t>(baseDilations.size())}, + builder->getIntegerType(64)), + baseDilations)); + state.addAttribute( + "window_dilations", + DenseIntElementsAttr::get( + RankedTensorType::get({static_cast<int64_t>(windowDilations.size())}, + builder->getIntegerType(64)), + windowDilations)); + state.addAttribute("padding_mode", builder->getI32IntegerAttr( + static_cast<int32_t>(paddingMode))); + state.addAttributes(attributes); + state.addRegion(); + state.setOperandListToResizable(); +} + +ParseResult parseReductionRegionOp(OpAsmParser &parser, OperationState &state) { + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + state.operands))) { + return failure(); + } + + SmallVector<OpAsmParser::OperandType, 8> reductionOperands; + Type reductionType; + auto operandsLoc = parser.getCurrentLocation(); + if (failed(parser.parseLParen()) || + failed(parser.parseOperandList(reductionOperands)) || + failed(parser.parseRParen()) || + failed(parser.parseColonType(reductionType)) || + failed(parser.resolveOperands( + reductionOperands, reductionType.cast<FunctionType>().getInputs(), + operandsLoc, state.operands))) { + return failure(); + } + for (auto type : reductionType.cast<FunctionType>().getResults()) { + state.types.push_back(type); + } + state.setOperandListToResizable(); + + SmallVector<OpAsmParser::OperandType, 8> regionArgs; + SmallVector<Type, 8> regionArgTypes; + if (failed(parser.parseKeyword("invocation")) || + failed(parser.parseLParen())) { + return failure(); + } + do { + Type argType; + SmallVector<OpAsmParser::OperandType, 2> reductionRegionArgs; + OpAsmParser::OperandType initialValue; + if (failed(parser.parseLParen()) || + failed(parser.parseOperandList(reductionRegionArgs, 2)) || + failed(parser.parseRParen()) || failed(parser.parseEqual()) || + failed(parser.parseOperand(initialValue)) || + failed(parser.parseColonType(argType)) || + failed(parser.resolveOperand(initialValue, argType, state.operands))) { + return failure(); + } + regionArgs.push_back(reductionRegionArgs[0]); + regionArgTypes.push_back(argType); + regionArgs.push_back(reductionRegionArgs[1]); + regionArgTypes.push_back(argType); + } while (succeeded(parser.parseOptionalComma())); + if (failed(parser.parseRParen())) { + return failure(); + } + + // Parse region body. + Region *body = state.addRegion(); + if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || + failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + + return success(); +} + +void printReductionRegionOp(OpAsmPrinter &p, ReductionRegionOp op) { + p << "iree.reduction_region"; + + // Print the workload argument. + p << "["; + p.printOperand(op.getWorkload()); + p << " : "; + p.printType(op.getWorkload()->getType()); + p << "]"; + + p << "("; + p.printOperands(op.getODSOperands(1)); + p << ")"; + if (op.getNumResults() > 0) { + p << " : ("; + interleaveComma(op.getODSOperands(1), p, + [&](Value *operand) { p.printType(operand->getType()); }); + p << ")"; + p << " -> ("; + interleaveComma(op.getResultTypes(), p); + p << ")"; + } + p << "\n"; + + p << " invocation("; + auto &entryBlock = op.getBody().getBlocks().front(); + int regionArgIndex = 0; + interleaveComma(op.getODSOperands(2), p, [&](Value *operand) { + p << "("; + p.printOperand(entryBlock.getArgument(regionArgIndex++)); + p << ", "; + p.printOperand(entryBlock.getArgument(regionArgIndex++)); + p << ") = "; + p.printOperand(operand); + p << " : "; + p.printType(operand->getType()); + }); + p << ") "; + + p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{}); +} + +//===----------------------------------------------------------------------===// +// iree.return +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, state.operands)); +} + +static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { + p << "iree.return"; + if (op.getNumOperands() > 0) { + p << ' '; + p.printOperands(op.operand_begin(), op.operand_end()); + p << " : "; + interleaveComma(op.getOperandTypes(), p); + } +} + +//===----------------------------------------------------------------------===// +// iree.load_input +//===----------------------------------------------------------------------===// + +ParseResult parseLoadInputOp(OpAsmParser &parser, OperationState &state) { + OpAsmParser::OperandType operand; + Type argType; + if (parser.parseLParen() || parser.parseOperand(operand) || + parser.parseColonType(argType) || parser.parseRParen() || + parser.resolveOperand(operand, argType, state.operands)) { + return failure(); + } + Type outputType; + if (parser.parseColonType(outputType) || + parser.addTypeToList(outputType, state.types)) { + return failure(); + } + return success(); +} + +void printLoadInputOp(OpAsmPrinter &printer, Operation *op) { + auto *inputValue = op->getOperand(0); + auto *outputValue = op->getResult(0); + printer << op->getName() << '('; + printer.printOperand(inputValue); + printer << " : "; + printer.printType(inputValue->getType()); + printer << ") : "; + printer.printType(outputValue->getType()); +} + +//===----------------------------------------------------------------------===// +// iree.store_output +//===----------------------------------------------------------------------===// + +ParseResult parseStoreOutputOp(OpAsmParser &parser, OperationState &state) { + OpAsmParser::OperandType op0, op1; + Type argType0, argType1; + if (parser.parseLParen() || parser.parseOperand(op0) || + parser.parseColonType(argType0) || parser.parseComma() || + parser.resolveOperand(op0, argType0, state.operands) || + parser.parseOperand(op1) || parser.parseColonType(argType1) || + parser.parseRParen() || + parser.resolveOperand(op1, argType1, state.operands)) { + return failure(); + } + return success(); +} + +void printStoreOutputOp(OpAsmPrinter &printer, Operation *op) { + auto *inputValue = op->getOperand(0); + auto *outputValue = op->getOperand(1); + printer << op->getName() << '('; + printer.printOperand(inputValue); + printer << " : "; + printer.printType(inputValue->getType()); + printer << ", "; + printer.printOperand(outputValue); + printer << " : "; + printer.printType(outputValue->getType()); + printer << ")"; +} + +#define GET_OP_CLASSES +#include "compiler/IR/Ops.cpp.inc" + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Ops.h b/compiler/IR/Ops.h new file mode 100644 index 0000000..f8a1580 --- /dev/null +++ b/compiler/IR/Ops.h
@@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_OPS_H_ +#define IREE_COMPILER_IR_OPS_H_ + +#include "compiler/IR/Types.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +#define GET_OP_CLASSES +#include "compiler/IR/Ops.h.inc" + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_OPS_H_
diff --git a/compiler/IR/Ops.td b/compiler/IR/Ops.td new file mode 100644 index 0000000..a304532 --- /dev/null +++ b/compiler/IR/Ops.td
@@ -0,0 +1,200 @@ +// IREE ops for working with buffers and buffer views. +// These are used by common transforms between the sequencer and interpreter and +// allow us to share some of the common lowering passes from other dialects. + +#ifdef IREE_OPS +#else +#define IREE_OPS + +#ifdef IREE_OP_BASE +#else +include "compiler/IR/OpBase.td" +#endif // IREE_OP_BASE + +class IREE_Op<string mnemonic, list<OpTrait> traits = []> : + Op<IREE_Dialect, mnemonic, traits> { + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ print$cppClass(p, *this); }]; +} + +class IREE_PureOp<string mnemonic, list<OpTrait> traits = []> : + IREE_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +// TODO(b/134575149): determine if we want multiple constant op types. +def IREE_ConstantOp : IREE_PureOp<"constant", [ + AllShapesMatch<["value", "result"]>, + AllElementTypesMatch<["value", "result"]> +]> { + let arguments = (ins ElementsAttr:$value); + let results = (outs IREEHL_MemRef:$result); + + // TODO(b/132296600): make tablegen follow the style guide. + let extraClassDeclaration = [{ + Attribute getValue() { return value(); } + }]; + + let builders = [OpBuilder<"Builder*, OperationState&, ElementsAttr">]; + + // TODO(b/134575149): enable folder when we store the correct type. + // let hasFolder = 1; +} + +// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor. +def IREE_TensorToMemRefOp : IREE_PureOp<"tensor_to_memref", [ + SameOperandsAndResultShape, SameOperandsAndResultElementType +]> { + let arguments = (ins AnyTensor); + let results = (outs IREEHL_MemRef); + + let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; + + let hasFolder = 1; +} + +// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor. +def IREE_MemRefToTensorOp : IREE_PureOp<"memref_to_tensor", [ + SameOperandsAndResultShape, SameOperandsAndResultElementType +]> { + let arguments = (ins IREEHL_MemRef); + let results = (outs AnyTensor); + let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; + + let hasFolder = 1; +} + +def IREE_ScalarToMemRefOp : IREE_PureOp<"scalar_to_memref", [ + SameOperandsAndResultElementType +]> { + let arguments = (ins IREEHL_Element); + let results = (outs IREEHL_AnyScalar); + + let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; + + let hasFolder = 1; +} + +def IREE_MemRefToScalarOp : IREE_PureOp<"memref_to_scalar", [ + SameOperandsAndResultElementType +]> { + let arguments = (ins IREEHL_AnyScalar); + let results = (outs IREEHL_Element); + + let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; + + let hasFolder = 1; +} + +def IREE_Workload : TensorOf<[AnyInteger]>; + +def IREE_DispatchRegionOp : IREE_PureOp<"dispatch_region"> { + let arguments = (ins + IREE_Workload:$workload, + Variadic<AnyType>:$args + ); + let results = (outs Variadic<AnyType>); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + Value *getWorkload() { return workload(); } + Region& getBody() { return body(); } + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; } + unsigned getNumArgOperands() { return getNumOperands() - 1; } + Value *getArgOperand(unsigned i) { + return getOperand(mapArgOperandToOpOperand(i)); + } + void setArgOperand(unsigned i, Value *arg) { + setOperand(mapArgOperandToOpOperand(i), arg); + } + + operand_iterator arg_operand_begin() { + return operand_begin() + mapArgOperandToOpOperand(0); + } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &state," + "ArrayRef<Type> resultTypes, Value *workload," + "ArrayRef<Value *> args," + "ArrayRef<NamedAttribute> attributes = {}">, + ]; +} + +def IREE_ReductionRegionOp : IREE_PureOp<"reduction_region", [ + SameVariadicOperandSize, +]> { + let arguments = (ins + IREE_Workload:$workload, + Variadic<AnyType>:$operands, + Variadic<AnyType>:$initial_values, + OptionalAttr<I64ElementsAttr>:$dimensions, + OptionalAttr<I64ElementsAttr>:$window_dimensions, + OptionalAttr<I64ElementsAttr>:$window_strides, + OptionalAttr<I64ElementsAttr>:$base_dilations, + OptionalAttr<I64ElementsAttr>:$window_dilations, + OptionalAttr<IREE_PaddingModeAttr>:$padding_mode + ); + let results = (outs Variadic<AnyType>); + let regions = (region AnyRegion:$body); + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + Value *getWorkload() { return workload(); } + Region& getBody() { return body(); } + + bool isWindowed() { + return window_dimensions().hasValue(); + } + + PaddingMode getPaddingMode() { + return static_cast<PaddingMode>(padding_mode().getValue()); + } + + unsigned getNumReductionOperands() { return (getNumOperands() - 1) / 2; } + operand_range getReductionOperands() { return getODSOperands(1); } + operand_range getInitialValueOperands() { return getODSOperands(2); } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &state," + "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands," + "ArrayRef<Value *> initialValues," + "ArrayRef<int64_t> dimensions," + "ArrayRef<NamedAttribute> attributes = {}">, + OpBuilder<"Builder *builder, OperationState &state," + "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands," + "ArrayRef<Value *> initialValues," + "ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides," + "ArrayRef<int64_t> baseDilations, ArrayRef<int64_t> windowDilations," + "PaddingMode paddingMode," + "ArrayRef<NamedAttribute> attributes = {}">, + ]; +} + +def IREE_ReturnOp : IREE_Op<"return", [Terminator]> { + let arguments = (ins Variadic<AnyType>:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; +} + +def IREE_LoadInputOp : IREE_PureOp<"load_input"> { + let arguments = (ins IREEHL_MemRef:$src); + let results = (outs AnyType); +} + +def IREE_StoreOutputOp : IREE_Op<"store_output"> { + let arguments = (ins AnyType:$src, IREEHL_MemRef:$dst); +} + +#endif // IREE_OPS
diff --git a/compiler/IR/Sequencer/BUILD b/compiler/IR/Sequencer/BUILD new file mode 100644 index 0000000..4ca9ba2 --- /dev/null +++ b/compiler/IR/Sequencer/BUILD
@@ -0,0 +1,76 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +load("@local_config_mlir//:tblgen.bzl", "gentbl") + +filegroup( + name = "td_files", + srcs = glob(["*.td"]), +) + +cc_library( + name = "Sequencer", + srcs = [ + "HLDialect.cpp", + "HLOps.cpp", + "HLOps.cpp.inc", + "LLDialect.cpp", + "LLOps.cpp", + "LLOps.cpp.inc", + "OpWriters.cpp", + ], + hdrs = [ + "HLDialect.h", + "HLOps.h", + "HLOps.h.inc", + "LLDialect.h", + "LLOps.h", + "LLOps.h.inc", + "OpWriters.h", + ], + deps = [ + ":HLOpsGen", + ":LLOpsGen", + "///compiler/IR", + "///compiler/Serialization", + "///compiler/Utils", + "///schemas/bytecode:sequencer_bytecode_v0", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + ], + alwayslink = 1, +) + +gentbl( + name = "HLOpsGen", + tbl_outs = [ + ("-gen-op-decls", "HLOps.h.inc"), + ("-gen-op-defs", "HLOps.cpp.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "HLOps.td", + td_srcs = [ + ":td_files", + "@local_config_mlir//:include/mlir/IR/OpBase.td", + "///compiler/IR:OpBase.td", + ], +) + +gentbl( + name = "LLOpsGen", + tbl_outs = [ + ("-gen-op-decls", "LLOps.h.inc"), + ("-gen-op-defs", "LLOps.cpp.inc"), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "LLOps.td", + td_srcs = [ + ":td_files", + "@local_config_mlir//:include/mlir/IR/OpBase.td", + "///compiler/IR:OpBase.td", + ], +)
diff --git a/iree/compiler/IR/Sequencer/CMakeLists.txt b/compiler/IR/Sequencer/CMakeLists.txt similarity index 100% rename from iree/compiler/IR/Sequencer/CMakeLists.txt rename to compiler/IR/Sequencer/CMakeLists.txt
diff --git a/compiler/IR/Sequencer/HLDialect.cpp b/compiler/IR/Sequencer/HLDialect.cpp new file mode 100644 index 0000000..2bc46bc --- /dev/null +++ b/compiler/IR/Sequencer/HLDialect.cpp
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/HLDialect.h" + +#include "compiler/IR/Sequencer/HLOps.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace iree_compiler { + +IREEHLSequencerDialect::IREEHLSequencerDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { +#define GET_OP_LIST + addOperations< +#include "compiler/IR/Sequencer/HLOps.cpp.inc" + >(); +} + +static DialectRegistration<IREEHLSequencerDialect> iree_hl_seq_dialect; + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/HLDialect.h b/compiler/IR/Sequencer/HLDialect.h similarity index 100% rename from iree/compiler/IR/Sequencer/HLDialect.h rename to compiler/IR/Sequencer/HLDialect.h
diff --git a/compiler/IR/Sequencer/HLOps.cpp b/compiler/IR/Sequencer/HLOps.cpp new file mode 100644 index 0000000..6840804 --- /dev/null +++ b/compiler/IR/Sequencer/HLOps.cpp
@@ -0,0 +1,379 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/HLOps.h" + +#include "compiler/IR/Ops.h" +#include "compiler/IR/Types.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREESeq { +namespace HL { + +namespace { + +static LogicalResult verifyWorkload(Operation *op, Value *workload) { + if (auto workloadType = workload->getType().dyn_cast<MemRefType>()) { + if (workloadType.getNumElements() != 3) { + return op->emitOpError("workload must be specified as (x,y,z) but has ") + << workloadType.getNumElements() + << " elements (type=" << workload->getType() << ")"; + } + return success(); + } + return op->emitOpError( + "workload must be specified as an (x,y,z) memref but has type ") + << workload->getType(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// iree_hl_seq.call +//===----------------------------------------------------------------------===// + +static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(state.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), state.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + state.operands)) { + return failure(); + } + return success(); +} + +static void printCallOp(OpAsmPrinter &p, CallOp op) { + p << "iree_hl_seq.call " << op.getAttr("callee") << '('; + p.printOperands(op.getOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : "; + p.printType(op.getCalleeType()); +} + +FunctionType CallOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.call_indirect +//===----------------------------------------------------------------------===// + +static ParseResult parseCallIndirectOp(OpAsmParser &parser, + OperationState &result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector<OpAsmParser::OperandType, 4> operands; + return failure( + parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(calleeType) || + parser.resolveOperand(callee, calleeType, result.operands) || + parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result.operands) || + parser.addTypesToList(calleeType.getResults(), result.types)); +} + +static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { + p << "iree_hl_seq.call_indirect "; + p.printOperand(op.getCallee()); + p << '('; + auto operandRange = op.getOperands(); + p.printOperands(++operandRange.begin(), operandRange.end()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : " << op.getCallee()->getType(); +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.return +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, state.operands)); +} + +static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { + p << "iree_hl_seq.return"; + if (op.getNumOperands() > 0) { + p << ' '; + p.printOperands(op.operand_begin(), op.operand_end()); + p << " : "; + interleaveComma(op.getOperandTypes(), p); + } +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.br +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { + Block *dest; + SmallVector<Value *, 4> destOperands; + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + return success(); +} + +static void printBranchOp(OpAsmPrinter &p, BranchOp op) { + p << "iree_hl_seq.br "; + p.printSuccessorAndUseList(op.getOperation(), 0); +} + +Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } + +void BranchOp::setDest(Block *block) { + return getOperation()->setSuccessor(block, 0); +} + +void BranchOp::eraseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(0, index); +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.cond_br +//===----------------------------------------------------------------------===// + +static ParseResult parseCondBranchOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<Value *, 4> destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser.getBuilder().getI1Type(); + if (parser.parseOperand(condInfo) || parser.parseComma() || + parser.resolveOperand(condInfo, int1Ty, result.operands)) { + return parser.emitError(parser.getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser.parseComma() || + parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + + return success(); +} + +static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { + p << "iree_hl_seq.cond_br "; + p.printOperand(op.getCondition()); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.dispatch +//===----------------------------------------------------------------------===// + +static ParseResult parseDispatchOp(OpAsmParser &parser, OperationState &state) { + auto executableLoc = parser.getNameLoc(); + + SymbolRefAttr executableAttr; + SymbolRefAttr entryPointAttr; + FunctionType entryPointType; + if (failed(parser.parseAttribute(executableAttr, "executable", + state.attributes)) || + failed(parser.parseColon()) || failed(parser.parseColon()) || + failed(parser.parseAttribute(entryPointAttr, "entry_point", + state.attributes))) { + return failure(); + } + + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + state.operands))) { + return failure(); + } + + SmallVector<OpAsmParser::OperandType, 4> operands; + if (failed( + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || + failed(parser.parseOptionalAttributeDict(state.attributes)) || + failed(parser.parseColonType(entryPointType)) || + failed(parser.addTypesToList(entryPointType.getResults(), state.types)) || + failed(parser.resolveOperands(operands, entryPointType.getInputs(), + executableLoc, state.operands))) { + return failure(); + } + return success(); +} + +static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) { + p << "iree_hl_seq.dispatch " << op.getExecutable() + << "::" << op.getEntryPoint(); + p << "["; + p.printOperand(op.getWorkload()); + p << " : "; + p.printType(op.getWorkload()->getType()); + p << "]("; + p.printOperands(op.getArgOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ + "executable", + "entry_point", + }); + p << " : "; + p.printType(op.getEntryPointType()); +} + +static LogicalResult verifyDispatchOp(DispatchOp op) { + if (failed(verifyWorkload(op, op.getWorkload()))) { + return failure(); + } + return success(); +} + +FunctionType DispatchOp::getEntryPointType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getArgOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.rank +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { + Builder builder(getContext()); + if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) { + return builder.getIntegerAttr(builder.getIntegerType(32), + op0.getType().getRank()); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.shape +//===----------------------------------------------------------------------===// + +void ShapeOp::build(Builder *builder, OperationState &state, Value *operand) { + state.addOperands(operand); + int64_t rank = 0; + if (auto shapedType = operand->getType().dyn_cast<ShapedType>()) { + rank = shapedType.getRank(); + } + state.addTypes(MemRefType::get({rank}, builder->getIntegerType(32))); +} + +OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) { + Builder builder(getContext()); + if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) { + return DenseIntElementsAttr::get( + RankedTensorType::get({op0.getType().getRank()}, + builder.getIntegerType(32)), + op0.getType().getShape()); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.length +//===----------------------------------------------------------------------===// + +OpFoldResult LengthOp::fold(ArrayRef<Attribute> operands) { + Builder builder(getContext()); + if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) { + return builder.getIntegerAttr(builder.getIntegerType(32), + op0.getNumElements()); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// iree_hl_seq.concat +//===----------------------------------------------------------------------===// + +namespace { +struct ConcatToCopies : public OpRewritePattern<ConcatOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(ConcatOp concatOp, + PatternRewriter &rewriter) const override { + auto finalType = concatOp.getResult()->getType().cast<ShapedType>(); + auto loc = concatOp.getLoc(); + std::vector<Value *> dimPieces; + auto dst = + rewriter.create<IREESeq::HL::AllocHeapOp>(loc, finalType, dimPieces); + + llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0); + auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset); + + auto concatDimension = concatOp.dimension().getZExtValue(); + llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0); + for (auto *src : concatOp.srcs()) { + auto srcShape = src->getType().cast<ShapedType>().getShape(); + auto lengths = createArrayConstant(rewriter, loc, srcShape); + auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices); + rewriter.create<IREESeq::HL::CopyOp>(loc, src, srcIndices, dst, + dstIndicesOp, lengths); + dstIndices[concatDimension] += srcShape[concatDimension]; + } + + concatOp.replaceAllUsesWith(dst.getResult()); + + return matchSuccess(); + } +}; +} // namespace + +void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<ConcatToCopies>(context); +} + +#define GET_OP_CLASSES +#include "compiler/IR/Sequencer/HLOps.cpp.inc" + +} // namespace HL +} // namespace IREESeq +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Sequencer/HLOps.h b/compiler/IR/Sequencer/HLOps.h new file mode 100644 index 0000000..9661021 --- /dev/null +++ b/compiler/IR/Sequencer/HLOps.h
@@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_SEQUENCER_HLOPS_H_ +#define IREE_COMPILER_IR_SEQUENCER_HLOPS_H_ + +#include "compiler/IR/Types.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace iree_compiler { +namespace IREESeq { +namespace HL { + +#define GET_OP_CLASSES +#include "compiler/IR/Sequencer/HLOps.h.inc" + +} // namespace HL +} // namespace IREESeq +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
diff --git a/compiler/IR/Sequencer/HLOps.td b/compiler/IR/Sequencer/HLOps.td new file mode 100644 index 0000000..6ca114c --- /dev/null +++ b/compiler/IR/Sequencer/HLOps.td
@@ -0,0 +1,429 @@ +// IREE high-level sequencer op definitions. +// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in +// normal SSA form. +// +// Through lowering these high-level ops are converted to low-level ops in the +// LLOps.td (iree_ll_seq.*). These map 1:1 with the bytecode, accept +// only MemRef types, and generally use output parameters instead of return +// types. +// +// The source of truth for bytecode opcodes is: +// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h + +#ifdef IREE_SEQUENCER_HL_OPS +#else +#define IREE_SEQUENCER_HL_OPS + +#ifdef IREE_OP_BASE +#else +include "compiler/IR/OpBase.td" +#endif // IREE_OP_BASE + +def IREESeqHL_Dialect : Dialect { + let name = "iree_hl_seq"; + let cppNamespace = "IREESeq::HL"; +} + +//===----------------------------------------------------------------------===// +// Base op classes +//===----------------------------------------------------------------------===// + +class IREESeqHL_Op<string mnemonic, list<OpTrait> traits = []> : + Op<IREESeqHL_Dialect, mnemonic, traits>; + +class IREESeqHL_PureOp<string mnemonic, list<OpTrait> traits = []> : + IREESeqHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +//===----------------------------------------------------------------------===// +// High-level sequencer ops +//===----------------------------------------------------------------------===// + +def IREESeqHL_CallOp : IREESeqHL_PureOp<"call"> { + let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>); + let results = (outs Variadic<IREEHL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqHL_CallIndirectOp : IREESeqHL_Op<"call_indirect"> { + let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands); + let results = (outs Variadic<IREEHL_MemRef>); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *callee," + "ArrayRef<Value *> operands = {}", [{ + result.operands.push_back(callee); + result.addOperands(operands); + result.addTypes(callee->getType().cast<FunctionType>().getResults()); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + Value *getCallee() { return getOperand(0); } + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqHL_ReturnOp : IREESeqHL_Op<"return", [Terminator]> { + let arguments = (ins Variadic<IREEHL_MemRef>:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqHL_BranchOp : IREESeqHL_Op<"br", [Terminator]> { + let arguments = (ins Variadic<IREEHL_MemRef>:$operands); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder< + "Builder *, OperationState &result, Block *dest," + "ArrayRef<Value *> operands = {}", [{ + result.addSuccessor(dest, operands); + }]>]; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqHL_CondBranchOp : IREESeqHL_Op<"cond_br", [Terminator]> { + let arguments = (ins + IREEHL_BoolScalar:$condition, + Variadic<IREEHL_MemRef>:$branchOperands + ); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *condition, " + "Block *trueDest, ArrayRef<Value *> trueOperands, " + "Block *falseDest, ArrayRef<Value *> falseOperands", [{ + result.addOperands(condition); + result.addSuccessor(trueDest, trueOperands); + result.addSuccessor(falseDest, falseOperands); + }]>]; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getOperation()->getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getOperation()->getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getOperation()->getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getOperation()->getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqHL_DispatchOp : IREESeqHL_Op<"dispatch"> { + let arguments = (ins + SymbolRefAttr:$executable, + SymbolRefAttr:$entry_point, + IREEHL_IntMemRef:$workload, + Variadic<IREEHL_MemRef>:$operands + ); + let results = (outs Variadic<IREEHL_MemRef>); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, StringRef executable," + "StringRef entry_point, Value *workload," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands({workload}); + result.addOperands(operands); + result.addAttribute("executable", builder->getSymbolRefAttr(executable)); + result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getExecutable() { return executable(); } + StringRef getEntryPoint() { return entry_point(); } + FunctionType getEntryPointType(); + + Value *getWorkload() { return getOperand(0); } + + // TODO(b/133879130): make tablegen support variadic operand accessors. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + operand_iterator arg_operand_begin() { return operand_begin() + 1; } + operand_iterator arg_operand_end() { return operand_end(); } + + operand_type_range getArgOperandTypes() { + return {arg_operand_type_begin(), arg_operand_type_end()}; + } + operand_type_iterator arg_operand_type_begin() { + return operand_type_iterator(arg_operand_begin()); + } + operand_type_iterator arg_operand_type_end() { + return operand_type_iterator(arg_operand_end()); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; + let verifier = [{ return verify$cppClass(*this); }]; +} + +// TODO(b/142012496): Add trait that enables DCE but not CSE. +def IREESeqHL_AllocHeapOp : IREESeqHL_Op<"alloc_heap"> { + // TODO(benvanik): attributes and args. + let arguments = (ins Variadic<IREEHL_IntMemRef>:$dim_pieces); + let results = (outs IREEHL_MemRef); +} + +def IREESeqHL_DiscardOp : IREESeqHL_Op<"discard"> { + let arguments = (ins IREEHL_MemRef); +} + +def IREESeqHL_RankOp : IREESeqHL_PureOp<"rank"> { + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_IntScalar); + + let hasFolder = 1; +} + +def IREESeqHL_DimOp : IREESeqHL_PureOp<"dim"> { + // TODO(benvanik) add dim attr (I32Attr:$dim) + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_IntScalar); +} + +def IREESeqHL_ShapeOp : IREESeqHL_PureOp<"shape"> { + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_1DIntMemRef); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder< + "Builder *builder, OperationState &result, Value *operand">]; + + let hasFolder = 1; +} + +def IREESeqHL_LengthOp : IREESeqHL_PureOp<"length"> { + let arguments = (ins IREEHL_MemRef); + let results = (outs IREEHL_IndexScalar); + + let hasFolder = 1; +} + +def IREESeqHL_SliceOp : + IREESeqHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>, + AllTypesMatch<["indices", "lengths"]>]> { + let arguments = (ins + IREEHL_MemRef:$src, + IREEHL_1DIndexMemRef:$indices, + IREEHL_1DIndexMemRef:$lengths + ); + let results = (outs IREEHL_MemRef:$result); +} + +def IREESeqHL_CopyOp : IREESeqHL_Op<"copy", [ + AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, + AllRanksMatch<["src", "dst"]>, + // The checks above are redundant with this one, but they give more specific + // error messages. + AllMatch<[ + Rank<"src">.result, + Rank<"dst">.result, + ElementCount<"srcIndices">.result, + ElementCount<"dstIndices">.result, + ElementCount<"lengths">.result + ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">, + AllElementTypesMatch<["src", "dst"]> +]> { + let arguments = (ins + IREEHL_MemRef:$src, + IREEHL_1DIndexMemRef:$srcIndices, + IREEHL_MemRef:$dst, + IREEHL_1DIndexMemRef:$dstIndices, + IREEHL_1DIndexMemRef:$lengths + ); +} + +def IREESeqHL_FillOp : IREESeqHL_Op<"fill"> { + let arguments = (ins + IREEHL_I32Scalar:$value, + IREEHL_MemRef:$dst, + IREEHL_1DIndexMemRef:$dstIndices, + IREEHL_1DIndexMemRef:$lengths + ); +} + +def IREESeqHL_CloneOp : IREESeqHL_PureOp<"clone", [SameOperandsAndResultType]> { + let arguments = (ins IREEHL_MemRef:$src); + let results = (outs IREEHL_MemRef); +} + +// A pseudo op provided for convenience. This gets canonicalized to a series of +// copies. +def IREESeqHL_ConcatOp : IREESeqHL_PureOp<"concat"> { + // TODO(b/135032064) Add type constraints when they support variadic + let arguments = (ins + Variadic<IREEHL_MemRef>:$srcs, + I32Attr:$dimension + ); + let results = (outs IREEHL_MemRef); + + let hasCanonicalizer = 1; +} + +def IREESeqHL_AssignOp : + IREESeqHL_PureOp<"assign", [SameOperandsAndResultType]> { + let arguments = (ins IREEHL_MemRef:$src); + let results = (outs IREEHL_MemRef); +} + +def IREESeqHL_CondAssignOp : IREESeqHL_PureOp<"cond_assign"> { + let arguments = (ins + IREEHL_BoolScalar:$cond, + IREEHL_MemRef:$lhs, + IREEHL_MemRef:$rhs + ); + let results = (outs IREEHL_MemRef); +} + +def IREESeqHL_ReshapeOp : IREESeqHL_PureOp<"reshape"> { + let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape); + let results = (outs IREEHL_MemRef); +} + +def IREESeqHL_TraceOp : IREESeqHL_Op<"trace"> { + let arguments = (ins Variadic<IREEHL_MemRef>:$srcs); +} + +def IREESeqHL_CondBreakOp : IREESeqHL_Op<"cond_break"> { + let arguments = (ins IREEHL_BoolScalar:$cond); +} + +def IREESeqHL_BreakOp : IREESeqHL_Op<"break">; + +#endif // IREE_SEQUENCER_HL_OPS
diff --git a/compiler/IR/Sequencer/LLDialect.cpp b/compiler/IR/Sequencer/LLDialect.cpp new file mode 100644 index 0000000..7fb2368 --- /dev/null +++ b/compiler/IR/Sequencer/LLDialect.cpp
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/LLDialect.h" + +#include "compiler/IR/Sequencer/LLOps.h" +#include "llvm/Support/SourceMgr.h" + +namespace mlir { +namespace iree_compiler { + +IREELLSequencerDialect::IREELLSequencerDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { +#define GET_OP_LIST + addOperations< +#include "compiler/IR/Sequencer/LLOps.cpp.inc" + >(); +} + +static DialectRegistration<IREELLSequencerDialect> iree_ll_seq_dialect; + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/LLDialect.h b/compiler/IR/Sequencer/LLDialect.h similarity index 100% rename from iree/compiler/IR/Sequencer/LLDialect.h rename to compiler/IR/Sequencer/LLDialect.h
diff --git a/compiler/IR/Sequencer/LLOps.cpp b/compiler/IR/Sequencer/LLOps.cpp new file mode 100644 index 0000000..a5205f0 --- /dev/null +++ b/compiler/IR/Sequencer/LLOps.cpp
@@ -0,0 +1,671 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/LLOps.h" + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/OpUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/STLExtras.h" + +namespace mlir { +namespace iree_compiler { +namespace IREESeq { +namespace LL { + +namespace { + +static LogicalResult verifyWorkload(Operation *op, Value *workload) { + if (auto workloadType = workload->getType().dyn_cast<MemRefType>()) { + if (workloadType.getNumElements() != 3) { + return op->emitOpError("workload must be specified as (x,y,z) but has ") + << workloadType.getNumElements() + << " elements (type=" << workload->getType() << ")"; + } + return success(); + } + return op->emitOpError( + "workload must be specified as an (x,y,z) memref but has type ") + << workload->getType(); +} + +static LogicalResult verifyWorkload(Operation *op, ElementsAttr workload) { + if (workload.getNumElements() != 3) { + return op->emitOpError("workload must be specified as (x,y,z) but has ") + << workload.getNumElements() << " elements (value=" << workload + << ")"; + } + return success(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// iree_ll_seq.constant +//===----------------------------------------------------------------------===// + +OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { + return getValue(); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.call +//===----------------------------------------------------------------------===// + +static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(state.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), state.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + state.operands)) { + return failure(); + } + return success(); +} + +static void printCallOp(OpAsmPrinter &p, CallOp op) { + p << "iree_ll_seq.call " << op.getAttr("callee") << '('; + p.printOperands(op.getOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : "; + p.printType(op.getCalleeType()); +} + +FunctionType CallOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.call_import +//===----------------------------------------------------------------------===// + +static ParseResult parseCallImportOp(OpAsmParser &parser, + OperationState &state) { + SymbolRefAttr calleeAttr; + FunctionType calleeType; + SmallVector<OpAsmParser::OperandType, 4> operands; + auto calleeLoc = parser.getNameLoc(); + if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(state.attributes) || + parser.parseColonType(calleeType) || + parser.addTypesToList(calleeType.getResults(), state.types) || + parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, + state.operands)) { + return failure(); + } + return success(); +} + +static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) { + p << "iree_ll_seq.call_import " << op.getAttr("callee") << '('; + p.printOperands(op.getOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : "; + p.printType(op.getCalleeType()); +} + +FunctionType CallImportOp::getCalleeType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.call_indirect +//===----------------------------------------------------------------------===// + +static ParseResult parseCallIndirectOp(OpAsmParser &parser, + OperationState &result) { + FunctionType calleeType; + OpAsmParser::OperandType callee; + llvm::SMLoc operandsLoc; + SmallVector<OpAsmParser::OperandType, 4> operands; + return failure( + parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser.parseOptionalAttributeDict(result.attributes) || + parser.parseColonType(calleeType) || + parser.resolveOperand(callee, calleeType, result.operands) || + parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, + result.operands) || + parser.addTypesToList(calleeType.getResults(), result.types)); +} + +static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { + p << "iree_ll_seq.call_indirect "; + p.printOperand(op.getCallee()); + p << '('; + auto operandRange = op.getOperands(); + p.printOperands(++operandRange.begin(), operandRange.end()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); + p << " : " << op.getCallee()->getType(); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.return +//===----------------------------------------------------------------------===// + +static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { + SmallVector<OpAsmParser::OperandType, 2> opInfo; + SmallVector<Type, 2> types; + llvm::SMLoc loc = parser.getCurrentLocation(); + return failure(parser.parseOperandList(opInfo) || + (!opInfo.empty() && parser.parseColonTypeList(types)) || + parser.resolveOperands(opInfo, types, loc, state.operands)); +} + +static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { + p << "iree_ll_seq.return"; + if (op.getNumOperands() > 0) { + p << ' '; + p.printOperands(op.operand_begin(), op.operand_end()); + p << " : "; + interleaveComma(op.getOperandTypes(), p); + } +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.br +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { + Block *dest; + SmallVector<Value *, 4> destOperands; + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + return success(); +} + +static void printBranchOp(OpAsmPrinter &p, BranchOp op) { + p << "iree_ll_seq.br "; + p.printSuccessorAndUseList(op.getOperation(), 0); +} + +Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } + +void BranchOp::setDest(Block *block) { + return getOperation()->setSuccessor(block, 0); +} + +void BranchOp::eraseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(0, index); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.cond_br +//===----------------------------------------------------------------------===// + +static ParseResult parseCondBranchOp(OpAsmParser &parser, + OperationState &result) { + SmallVector<Value *, 4> destOperands; + Block *dest; + OpAsmParser::OperandType condInfo; + + // Parse the condition. + Type int1Ty = parser.getBuilder().getI1Type(); + if (parser.parseOperand(condInfo) || parser.parseComma() || + parser.resolveOperand(condInfo, int1Ty, result.operands)) { + return parser.emitError(parser.getNameLoc(), + "expected condition type was boolean (i1)"); + } + + // Parse the true successor. + if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); + result.addSuccessor(dest, destOperands); + + // Parse the false successor. + destOperands.clear(); + if (parser.parseComma() || + parser.parseSuccessorAndUseList(dest, destOperands)) + return failure(); + result.addSuccessor(dest, destOperands); + + return success(); +} + +static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { + p << "iree_ll_interp.cond_br "; + p.printOperand(op.getCondition()); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); + p << ", "; + p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.dynamic_dispatch +//===----------------------------------------------------------------------===// + +static ParseResult parseDynamicDispatchOp(OpAsmParser &parser, + OperationState &state) { + auto executableLoc = parser.getNameLoc(); + + SymbolRefAttr executableAttr; + SymbolRefAttr entryPointAttr; + FunctionType entryPointType; + if (failed(parser.parseAttribute(executableAttr, "executable", + state.attributes)) || + failed(parser.parseColon()) || failed(parser.parseColon()) || + failed(parser.parseAttribute(entryPointAttr, "entry_point", + state.attributes))) { + return failure(); + } + + OpAsmParser::OperandType workloadArg; + Type workloadArgType; + if (failed(parser.parseLSquare()) || + failed(parser.parseOperand(workloadArg)) || + failed(parser.parseColonType(workloadArgType)) || + failed(parser.parseRSquare()) || + failed(parser.resolveOperand(workloadArg, workloadArgType, + state.operands))) { + return failure(); + } + + SmallVector<OpAsmParser::OperandType, 4> operands; + if (failed( + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || + failed(parser.parseOptionalAttributeDict(state.attributes)) || + failed(parser.parseColonType(entryPointType)) || + failed(parser.addTypesToList(entryPointType.getResults(), state.types)) || + failed(parser.resolveOperands(operands, entryPointType.getInputs(), + executableLoc, state.operands))) { + return failure(); + } + return success(); +} + +static void printDynamicDispatchOp(OpAsmPrinter &p, DynamicDispatchOp op) { + p << "iree_ll_seq.dynamic_dispatch " << op.getExecutable() + << "::" << op.getEntryPoint(); + p << "["; + p.printOperand(op.getWorkload()); + p << " : "; + p.printType(op.getWorkload()->getType()); + p << "]("; + p.printOperands(op.getArgOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ + "executable", + "entry_point", + }); + p << " : "; + p.printType(op.getEntryPointType()); +} + +static LogicalResult verifyDynamicDispatchOp(DynamicDispatchOp op) { + if (failed(verifyWorkload(op, op.getWorkload()))) { + return failure(); + } + return success(); +} + +FunctionType DynamicDispatchOp::getEntryPointType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getArgOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +namespace { +struct MakeDynamicDispatchOpStatic + : public OpRewritePattern<DynamicDispatchOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(DynamicDispatchOp dynamicDispatchOp, + PatternRewriter &rewriter) const override { + ElementsAttr workloadAttr; + if (!matchPattern(dynamicDispatchOp.getWorkload(), + m_Constant(&workloadAttr))) { + return matchFailure(); + } + + SmallVector<Type, 8> resultTypes{dynamicDispatchOp.getResultTypes()}; + SmallVector<Value *, 8> operands{dynamicDispatchOp.getArgOperands()}; + rewriter.replaceOpWithNewOp<IREESeq::LL::StaticDispatchOp>( + dynamicDispatchOp, dynamicDispatchOp.getExecutable(), + dynamicDispatchOp.getEntryPoint(), workloadAttr, resultTypes, operands); + return matchSuccess(); + } +}; +} // namespace + +void DynamicDispatchOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<MakeDynamicDispatchOpStatic>(context); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.static_dispatch +//===----------------------------------------------------------------------===// + +static ParseResult parseStaticDispatchOp(OpAsmParser &parser, + OperationState &state) { + auto executableLoc = parser.getNameLoc(); + + SymbolRefAttr executableAttr; + SymbolRefAttr entryPointAttr; + FunctionType entryPointType; + if (failed(parser.parseAttribute(executableAttr, "executable", + state.attributes)) || + failed(parser.parseColon()) || failed(parser.parseColon()) || + failed(parser.parseAttribute(entryPointAttr, "entry_point", + state.attributes))) { + return failure(); + } + + ElementsAttr workloadAttr; + if (failed(parser.parseLSquare()) || + failed( + parser.parseAttribute(workloadAttr, "workload", state.attributes)) || + failed(parser.parseRSquare())) { + return failure(); + } + + SmallVector<OpAsmParser::OperandType, 4> operands; + if (failed( + parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || + failed(parser.parseOptionalAttributeDict(state.attributes)) || + failed(parser.parseColonType(entryPointType)) || + failed(parser.addTypesToList(entryPointType.getResults(), state.types)) || + failed(parser.resolveOperands(operands, entryPointType.getInputs(), + executableLoc, state.operands))) { + return failure(); + } + return success(); +} + +static void printStaticDispatchOp(OpAsmPrinter &p, StaticDispatchOp op) { + p << "iree_ll_seq.static_dispatch " << op.getExecutable() + << "::" << op.getEntryPoint(); + p << "["; + p.printAttribute(op.getWorkload()); + p << "]("; + p.printOperands(op.getArgOperands()); + p << ')'; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ + "executable", + "entry_point", + "workload", + }); + p << " : "; + p.printType(op.getEntryPointType()); +} + +static LogicalResult verifyStaticDispatchOp(StaticDispatchOp op) { + if (failed(verifyWorkload(op, op.getWorkload()))) { + return failure(); + } + return success(); +} + +FunctionType StaticDispatchOp::getEntryPointType() { + SmallVector<Type, 4> resultTypes(getResultTypes()); + SmallVector<Type, 8> argTypes(getArgOperandTypes()); + return FunctionType::get(argTypes, resultTypes, getContext()); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.shape +//===----------------------------------------------------------------------===// + +namespace { +struct FoldShapeOp : public OpRewritePattern<ShapeOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(ShapeOp shapeOp, + PatternRewriter &rewriter) const override { + auto memRefType = shapeOp.input()->getType().cast<MemRefType>(); + if (memRefType.hasStaticShape()) { + auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>( + shapeOp.getLoc(), + MemRefType::get({memRefType.getRank()}, rewriter.getIntegerType(64)), + DenseIntElementsAttr::get( + RankedTensorType::get({memRefType.getRank()}, + rewriter.getIntegerType(64)), + memRefType.getShape())); + replaceSubsequentUses(shapeOp, shapeOp.dst(), constantOp.getResult()); + rewriter.eraseOp(shapeOp); + return matchSuccess(); + } + return matchFailure(); + } +}; +} // namespace + +void ShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<FoldShapeOp>(context); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.length +//===----------------------------------------------------------------------===// + +namespace { +struct FoldLengthOp : public OpRewritePattern<LengthOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(LengthOp lengthOp, + PatternRewriter &rewriter) const override { + auto memRefType = lengthOp.input()->getType().cast<MemRefType>(); + if (memRefType.hasStaticShape()) { + auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>( + lengthOp.getLoc(), MemRefType::get({}, rewriter.getIntegerType(64)), + DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getIntegerType(64)), + {memRefType.getNumElements()})); + replaceSubsequentUses(lengthOp, lengthOp.dst(), constantOp.getResult()); + rewriter.eraseOp(lengthOp); + return matchSuccess(); + } + return matchFailure(); + } +}; +} // namespace + +void LengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<FoldLengthOp>(context); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.compute_offset +//===----------------------------------------------------------------------===// + +namespace { +struct FoldComputeOffsetOp : public OpRewritePattern<ComputeOffsetOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(ComputeOffsetOp computeOffsetOp, + PatternRewriter &rewriter) const override { + ElementsAttr shapeAttr; + ElementsAttr indicesAttr; + if (!matchPattern(computeOffsetOp.shape(), m_Constant(&shapeAttr)) || + !matchPattern(computeOffsetOp.indices(), m_Constant(&indicesAttr))) { + return matchFailure(); + } + + int64_t offset = 0; + for (unsigned i = 0; i < indicesAttr.getNumElements(); ++i) { + int64_t axisOffset = + indicesAttr.getValue({i}).cast<IntegerAttr>().getInt(); + for (unsigned j = i + 1; j < shapeAttr.getNumElements(); ++j) { + axisOffset *= shapeAttr.getValue({j}).cast<IntegerAttr>().getInt(); + } + offset += axisOffset; + } + offset *= computeOffsetOp.elementSize().getZExtValue(); + + auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>( + computeOffsetOp.getLoc(), + MemRefType::get({}, rewriter.getIntegerType(64)), + DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset})); + replaceSubsequentUses(computeOffsetOp, computeOffsetOp.dst(), + constantOp.getResult()); + rewriter.eraseOp(computeOffsetOp); + return matchSuccess(); + } +}; +} // namespace + +void ComputeOffsetOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<FoldComputeOffsetOp>(context); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.compute_range +//===----------------------------------------------------------------------===// + +namespace { +struct FoldComputeRangeOp : public OpRewritePattern<ComputeRangeOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(ComputeRangeOp computeRangeOp, + PatternRewriter &rewriter) const override { + ElementsAttr shapeAttr; + ElementsAttr indicesAttr; + ElementsAttr lengthsAttr; + if (!matchPattern(computeRangeOp.shape(), m_Constant(&shapeAttr)) || + !matchPattern(computeRangeOp.indices(), m_Constant(&indicesAttr)) || + !matchPattern(computeRangeOp.lengths(), m_Constant(&lengthsAttr))) { + return matchFailure(); + } + + int64_t offset = 0; + int64_t length = computeRangeOp.elementSize().getZExtValue(); + for (unsigned i = 0; i < indicesAttr.getNumElements(); ++i) { + int64_t axisOffset = + indicesAttr.getValue({i}).cast<IntegerAttr>().getInt(); + for (unsigned j = i + 1; j < shapeAttr.getNumElements(); ++j) { + axisOffset *= shapeAttr.getValue({j}).cast<IntegerAttr>().getInt(); + } + offset += axisOffset; + length *= lengthsAttr.getValue({i}).cast<IntegerAttr>().getInt(); + } + offset *= computeRangeOp.elementSize().getZExtValue(); + + auto offsetConstantOp = rewriter.create<IREESeq::LL::ConstantOp>( + computeRangeOp.getLoc(), + MemRefType::get({}, rewriter.getIntegerType(64)), + DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset})); + replaceSubsequentUses(computeRangeOp, computeRangeOp.dstOffset(), + offsetConstantOp.getResult()); + auto lengthConstantOp = rewriter.create<IREESeq::LL::ConstantOp>( + computeRangeOp.getLoc(), + MemRefType::get({}, rewriter.getIntegerType(64)), + DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getIntegerType(64)), {length})); + replaceSubsequentUses(computeRangeOp, computeRangeOp.dstLength(), + lengthConstantOp.getResult()); + rewriter.eraseOp(computeRangeOp); + return matchSuccess(); + } +}; +} // namespace + +void ComputeRangeOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<FoldComputeRangeOp>(context); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.dynamic_copy +//===----------------------------------------------------------------------===// + +namespace { +struct MakeDynamicCopyOpStatic : public OpRewritePattern<DynamicCopyOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(DynamicCopyOp dynamicCopyOp, + PatternRewriter &rewriter) const override { + ElementsAttr srcOffsetAttr; + ElementsAttr dstOffsetAttr; + ElementsAttr lengthAttr; + if (!matchPattern(dynamicCopyOp.srcOffset(), m_Constant(&srcOffsetAttr)) || + !matchPattern(dynamicCopyOp.dstOffset(), m_Constant(&dstOffsetAttr)) || + !matchPattern(dynamicCopyOp.length(), m_Constant(&lengthAttr))) { + return matchFailure(); + } + + rewriter.replaceOpWithNewOp<IREESeq::LL::StaticCopyOp>( + dynamicCopyOp, dynamicCopyOp.src(), + srcOffsetAttr.getValue({}).cast<IntegerAttr>(), dynamicCopyOp.dst(), + dstOffsetAttr.getValue({}).cast<IntegerAttr>(), + lengthAttr.getValue({}).cast<IntegerAttr>()); + return matchSuccess(); + } +}; +} // namespace + +void DynamicCopyOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<MakeDynamicCopyOpStatic>(context); +} + +//===----------------------------------------------------------------------===// +// iree_ll_seq.dynamic_fill +//===----------------------------------------------------------------------===// + +namespace { +struct MakeDynamicFillOpStatic : public OpRewritePattern<DynamicFillOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(DynamicFillOp dynamicFillOp, + PatternRewriter &rewriter) const override { + ElementsAttr valueAttr; + ElementsAttr dstOffsetAttr; + ElementsAttr lengthAttr; + if (!matchPattern(dynamicFillOp.value(), m_Constant(&valueAttr)) || + !matchPattern(dynamicFillOp.dstOffset(), m_Constant(&dstOffsetAttr)) || + !matchPattern(dynamicFillOp.length(), m_Constant(&lengthAttr))) { + return matchFailure(); + } + + rewriter.replaceOpWithNewOp<IREESeq::LL::StaticFillOp>( + dynamicFillOp, valueAttr.getValue({}).cast<IntegerAttr>(), + dynamicFillOp.dst(), dstOffsetAttr.getValue({}).cast<IntegerAttr>(), + lengthAttr.getValue({}).cast<IntegerAttr>()); + return matchSuccess(); + } +}; +} // namespace + +void DynamicFillOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert<MakeDynamicFillOpStatic>(context); +} + +#define GET_OP_CLASSES +#include "compiler/IR/Sequencer/LLOps.cpp.inc" + +} // namespace LL +} // namespace IREESeq +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Sequencer/LLOps.h b/compiler/IR/Sequencer/LLOps.h new file mode 100644 index 0000000..c5e83bd --- /dev/null +++ b/compiler/IR/Sequencer/LLOps.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_SEQUENCER_LLOPS_H_ +#define IREE_COMPILER_IR_SEQUENCER_LLOPS_H_ + +#include "compiler/IR/Types.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREESeq { +namespace LL { + +#define GET_OP_CLASSES +#include "compiler/IR/Sequencer/LLOps.h.inc" + +} // namespace LL +} // namespace IREESeq +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
diff --git a/compiler/IR/Sequencer/LLOps.td b/compiler/IR/Sequencer/LLOps.td new file mode 100644 index 0000000..1c3324d --- /dev/null +++ b/compiler/IR/Sequencer/LLOps.td
@@ -0,0 +1,579 @@ +// IREE low-level sequencer op definitions. +// These map 1:1 with the bytecode, accept only MemRef types and generally use +// output parameters instead of return types. +// +// The source of truth for bytecode opcodes is: +// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h +// +// Note that in this dialect we cannot use folders: they require that all +// operands are possible to make constants where we use output arguments that +// will never be constant. Instead we can use canonicalization patterns to +// match constant input operands and do the folding by replacing output operands +// with the new values. + +#ifdef IREE_SEQUENCER_LL_OPS +#else +#define IREE_SEQUENCER_LL_OPS + +#ifdef IREE_OP_BASE +#else +include "compiler/IR/OpBase.td" +#endif // IREE_OP_BASE + +def IREESeqLL_Dialect : Dialect { + let name = "iree_ll_seq"; + let cppNamespace = "IREESeq::LL"; +} + +//===----------------------------------------------------------------------===// +// Base op classes +//===----------------------------------------------------------------------===// + +class IREESeqLL_Op<string mnemonic, list<OpTrait> traits = []> : + Op<IREESeqLL_Dialect, mnemonic, traits> { + bit hasCustomSerializer = 0; +} + +class IREESeqLL_PureOp<string mnemonic, list<OpTrait> traits = []> : + IREESeqLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; + +class IREESeqLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef, + list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> { + let arguments = (ins type:$input, type:$dst); +} + +class IREESeqLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef, + list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> { + let arguments = (ins type:$lhs, type:$rhs, type:$dst); +} + +class IREESeqLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef, + list<OpTrait> traits = []> + : IREESeqLL_Op<mnemonic, traits> { + let arguments = (ins type : $a, type : $b, type : $c, type : $dst); +} + +//===----------------------------------------------------------------------===// +// Low-level sequencer ops +//===----------------------------------------------------------------------===// + +def IREESeqLL_ConstantOp : IREESeqLL_PureOp<"constant"> { + let arguments = (ins ElementsAttr:$value); + let results = (outs IREELL_MemRef); + + // TODO(b/132296600): make tablegen follow the style guide. + let extraClassDeclaration = [{ + Attribute getValue() { return value(); } + }]; + + let hasFolder = 1; +} + +def IREESeqLL_CallOp : IREESeqLL_Op<"call"> { + let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +// TODO(benvanik): add verifier that target isExternal. +def IREESeqLL_CallImportOp : IREESeqLL_Op<"call_import"> { + let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, FuncOp callee," + "ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(callee.getType().getResults()); + }]>, OpBuilder< + "Builder *builder, OperationState &result, StringRef callee," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands(operands); + result.addAttribute("callee", builder->getSymbolRefAttr(callee)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getCallee() { return callee(); } + FunctionType getCalleeType(); + + // TODO(b/133879130): make tablegen support variadic operand accessors. + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqLL_CallIndirectOp : IREESeqLL_Op<"call_indirect"> { + let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *callee," + "ArrayRef<Value *> operands = {}", [{ + result.operands.push_back(callee); + result.addOperands(operands); + result.addTypes(callee->getType().cast<FunctionType>().getResults()); + }]>]; + + let extraClassDeclaration = [{ + Value *getCallee() { return getOperand(0); } + + /// Get the argument operands to the called function. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + operand_iterator arg_operand_begin() { return ++operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqLL_ReturnOp : IREESeqLL_Op<"return", [Terminator]> { + let arguments = (ins Variadic<IREELL_MemRef>:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqLL_BranchOp : IREESeqLL_Op<"br", [Terminator]> { + let arguments = (ins Variadic<IREELL_MemRef>:$operands); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder< + "Builder *, OperationState &result, Block *dest, " + "ArrayRef<Value *> operands = {}", [{ + result.addSuccessor(dest, operands); + }]>]; + + let extraClassDeclaration = [{ + Block *getDest(); + void setDest(Block *block); + + /// Erase the operand at 'index' from the operand list. + void eraseOperand(unsigned index); + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqLL_CondBranchOp : IREESeqLL_Op<"cond_br", [Terminator]> { + let arguments = (ins + IREELL_BoolScalar:$condition, + Variadic<IREELL_MemRef>:$branchOperands + ); + + let skipDefaultBuilders = 1; + let builders = [OpBuilder< + "Builder *, OperationState &result, Value *condition," + "Block *trueDest, ArrayRef<Value *> trueOperands," + "Block *falseDest, ArrayRef<Value *> falseOperands", [{ + result.addOperands(condition); + result.addSuccessor(trueDest, trueOperands); + result.addSuccessor(falseDest, falseOperands); + }]>]; + + let extraClassDeclaration = [{ + // These are the indices into the dests list. + enum { trueIndex = 0, falseIndex = 1 }; + + // The condition operand is the first operand in the list. + Value *getCondition() { return getOperand(0); } + + /// Return the destination if the condition is true. + Block *getTrueDest() { + return getOperation()->getSuccessor(trueIndex); + } + + /// Return the destination if the condition is false. + Block *getFalseDest() { + return getOperation()->getSuccessor(falseIndex); + } + + // Accessors for operands to the 'true' destination. + Value *getTrueOperand(unsigned idx) { + assert(idx < getNumTrueOperands()); + return getOperand(getTrueDestOperandIndex() + idx); + } + + void setTrueOperand(unsigned idx, Value *value) { + assert(idx < getNumTrueOperands()); + setOperand(getTrueDestOperandIndex() + idx, value); + } + + operand_iterator true_operand_begin() { + return operand_begin() + getTrueDestOperandIndex(); + } + operand_iterator true_operand_end() { + return true_operand_begin() + getNumTrueOperands(); + } + operand_range getTrueOperands() { + return {true_operand_begin(), true_operand_end()}; + } + + unsigned getNumTrueOperands() { + return getOperation()->getNumSuccessorOperands(trueIndex); + } + + /// Erase the operand at 'index' from the true operand list. + void eraseTrueOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(trueIndex, index); + } + + // Accessors for operands to the 'false' destination. + Value *getFalseOperand(unsigned idx) { + assert(idx < getNumFalseOperands()); + return getOperand(getFalseDestOperandIndex() + idx); + } + void setFalseOperand(unsigned idx, Value *value) { + assert(idx < getNumFalseOperands()); + setOperand(getFalseDestOperandIndex() + idx, value); + } + + operand_iterator false_operand_begin() { return true_operand_end(); } + operand_iterator false_operand_end() { + return false_operand_begin() + getNumFalseOperands(); + } + operand_range getFalseOperands() { + return {false_operand_begin(), false_operand_end()}; + } + + unsigned getNumFalseOperands() { + return getOperation()->getNumSuccessorOperands(falseIndex); + } + + /// Erase the operand at 'index' from the false operand list. + void eraseFalseOperand(unsigned index) { + getOperation()->eraseSuccessorOperand(falseIndex, index); + } + + private: + /// Get the index of the first true destination operand. + unsigned getTrueDestOperandIndex() { return 1; } + + /// Get the index of the first false destination operand. + unsigned getFalseDestOperandIndex() { + return getTrueDestOperandIndex() + getNumTrueOperands(); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; +} + +def IREESeqLL_DynamicDispatchOp : IREESeqLL_Op<"dynamic_dispatch"> { + let arguments = (ins + SymbolRefAttr:$executable, + SymbolRefAttr:$entry_point, + IREELL_IntMemRef:$workload, + Variadic<IREELL_MemRef>:$operands + ); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, StringRef executable," + "StringRef entry_point, Value *workload," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addOperands({workload}); + result.addOperands(operands); + result.addAttribute("executable", builder->getSymbolRefAttr(executable)); + result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getExecutable() { return executable(); } + StringRef getEntryPoint() { return entry_point(); } + FunctionType getEntryPointType(); + + Value *getWorkload() { return getOperand(0); } + + // TODO(b/133879130): make tablegen support variadic operand accessors. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + operand_iterator arg_operand_begin() { return operand_begin() + 1; } + operand_iterator arg_operand_end() { return operand_end(); } + + operand_type_range getArgOperandTypes() { + return {arg_operand_type_begin(), arg_operand_type_end()}; + } + operand_type_iterator arg_operand_type_begin() { + return operand_type_iterator(arg_operand_begin()); + } + operand_type_iterator arg_operand_type_end() { + return operand_type_iterator(arg_operand_end()); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; + let verifier = [{ return verify$cppClass(*this); }]; + let hasCanonicalizer = 1; +} + +def IREESeqLL_StaticDispatchOp : IREESeqLL_Op<"static_dispatch"> { + let arguments = (ins + SymbolRefAttr:$executable, + SymbolRefAttr:$entry_point, + I32ElementsAttr:$workload, + Variadic<IREELL_MemRef>:$operands + ); + let results = (outs Variadic<IREELL_MemRef>); + + let builders = [OpBuilder< + "Builder *builder, OperationState &result, StringRef executable," + "StringRef entry_point, ElementsAttr workload," + "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ + result.addAttribute("workload", workload); + result.addOperands(operands); + result.addAttribute("executable", builder->getSymbolRefAttr(executable)); + result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point)); + result.addTypes(results); + }]>]; + + let extraClassDeclaration = [{ + // TODO(b/132296600): make tablegen follow the style guide. + StringRef getExecutable() { return executable(); } + StringRef getEntryPoint() { return entry_point(); } + FunctionType getEntryPointType(); + + ElementsAttr getWorkload() { return workload(); } + + // TODO(b/133879130): make tablegen support variadic operand accessors. + operand_range getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + operand_type_range getArgOperandTypes() { + return {arg_operand_type_begin(), arg_operand_type_end()}; + } + operand_type_iterator arg_operand_type_begin() { + return operand_type_iterator(arg_operand_begin()); + } + operand_type_iterator arg_operand_type_end() { + return operand_type_iterator(arg_operand_end()); + } + }]; + + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ return print$cppClass(p, *this); }]; + let verifier = [{ return verify$cppClass(*this); }]; +} + +def IREESeqLL_AllocStaticOp : IREESeqLL_PureOp<"alloc_static"> { + // TODO(benvanik): attributes and args. + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_AllocStackOp : IREESeqLL_PureOp<"alloc_stack"> { + // TODO(benvanik): attributes and args. + let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces); + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_AllocStackInitOp : IREESeqLL_PureOp<"alloc_stack_init"> { + // TODO(benvanik): attributes and args. + let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces); + let results = (outs IREELL_MemRef); +} + +// TODO(b/142012496): Add trait that enables DCE but not CSE. +def IREESeqLL_AllocHeapOp : IREESeqLL_Op<"alloc_heap"> { + // TODO(benvanik): attributes and args. + let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces); + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_DiscardOp : IREESeqLL_Op<"discard"> { + let arguments = (ins IREELL_MemRef); +} + +def IREESeqLL_ShapeOp : IREESeqLL_Op<"shape"> { + let arguments = (ins IREELL_MemRef:$input, IREELL_I32MemRef:$dst); + + let hasCanonicalizer = 1; +} + +def IREESeqLL_LengthOp : IREESeqLL_Op<"length"> { + let arguments = (ins IREELL_MemRef:$input, IREELL_I32Scalar:$dst); + + let hasCanonicalizer = 1; +} + +def IREESeqLL_ComputeOffsetOp : IREESeqLL_Op<"compute_offset"> { + let arguments = (ins + IREELL_1DIntMemRef:$shape, + I8Attr:$elementSize, + IREELL_1DIntMemRef:$indices, + IREELL_I32Scalar:$dst + ); + + let hasCanonicalizer = 1; +} + +def IREESeqLL_ComputeRangeOp : IREESeqLL_Op<"compute_range"> { + let arguments = (ins + IREELL_1DIntMemRef:$shape, + I8Attr:$elementSize, + IREELL_1DIntMemRef:$indices, + IREELL_1DIntMemRef:$lengths, + IREELL_I32Scalar:$dstOffset, + IREELL_I32Scalar:$dstLength + ); + + let hasCanonicalizer = 1; +} + +def IREESeqLL_DynamicSliceOp : IREESeqLL_PureOp<"dynamic_slice", [ + AllElementTypesMatch<["src", "result"]> +]> { + let arguments = (ins + IREELL_MemRef:$src, + IREELL_IntScalar:$offset, + IREELL_IntScalar:$length + ); + let results = (outs IREELL_MemRef:$result); +} + +def IREESeqLL_StaticSliceOp : IREESeqLL_PureOp<"static_slice", [ + AllElementTypesMatch<["src", "result"]> +]> { + let arguments = (ins + IREELL_MemRef:$src, + I64Attr:$offset, + I64Attr:$length + ); + let results = (outs IREELL_MemRef:$result); +} + +def IREESeqLL_DynamicCopyOp : IREESeqLL_Op<"dynamic_copy"> { + let arguments = (ins + IREELL_MemRef:$src, + IREELL_IndexScalar:$srcOffset, + IREELL_MemRef:$dst, + IREELL_IndexScalar:$dstOffset, + IREELL_IndexScalar:$length + ); + + let hasCanonicalizer = 1; +} + +def IREESeqLL_StaticCopyOp : IREESeqLL_Op<"static_copy"> { + let arguments = (ins + IREELL_MemRef:$src, + I64Attr:$srcOffset, + IREELL_MemRef:$dst, + I64Attr:$dstOffset, + I64Attr:$length + ); +} + +def IREESeqLL_DynamicFillOp : IREESeqLL_Op<"dynamic_fill"> { + let arguments = (ins + IREELL_I32Scalar:$value, + IREELL_MemRef:$dst, + IREELL_IndexScalar:$dstOffset, + IREELL_IndexScalar:$length + ); + + let hasCanonicalizer = 1; +} + +def IREESeqLL_StaticFillOp : IREESeqLL_Op<"static_fill"> { + let arguments = (ins + I32Attr:$value, + IREELL_MemRef:$dst, + I64Attr:$dstOffset, + I64Attr:$length + ); +} + +def IREESeqLL_CloneOp : + IREESeqLL_PureOp<"clone", [SameOperandsAndResultType]> { + let arguments = (ins IREELL_MemRef:$src); + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_AssignOp : + IREESeqLL_Op<"assign", [SameOperandsAndResultType]> { + let arguments = (ins IREELL_MemRef:$src); + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_CondAssignOp : IREESeqLL_Op<"cond_assign"> { + let arguments = (ins + IREELL_BoolScalar:$cond, + IREELL_MemRef:$lhs, + IREELL_MemRef:$rhs + ); + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_ReshapeOp : IREESeqLL_Op<"reshape"> { + let arguments = (ins IREELL_MemRef:$input, IREELL_1DIntMemRef:$shape); + let results = (outs IREELL_MemRef); +} + +def IREESeqLL_TraceOp : IREESeqLL_Op<"trace"> { + let arguments = (ins Variadic<IREELL_MemRef>:$srcs); +} + +def IREESeqLL_CondBreakOp : IREESeqLL_Op<"cond_break"> { + let arguments = (ins IREELL_BoolScalar:$cond); +} + +def IREESeqLL_BreakOp : IREESeqLL_Op<"break">; + +#endif // IREE_SEQUENCER_LL_OPS
diff --git a/compiler/IR/Sequencer/OpWriters.cpp b/compiler/IR/Sequencer/OpWriters.cpp new file mode 100644 index 0000000..9d2be02 --- /dev/null +++ b/compiler/IR/Sequencer/OpWriters.cpp
@@ -0,0 +1,266 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/OpWriters.h" + +#include "compiler/IR/Sequencer/LLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Serialization/BytecodeWriter.h" +#include "compiler/Utils/Macros.h" +#include "schemas/bytecode/sequencer_bytecode_v0.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +//===----------------------------------------------------------------------===// +// Sequencer ops +//===----------------------------------------------------------------------===// + +LogicalResult writeOp(IREESeq::LL::ConstantOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kConstant)); + auto memRefType = op.getType().dyn_cast<MemRefType>(); + if (!memRefType) { + return op.emitError() + << "Constant has an unsupported type; must be a memref: " + << op.getType(); + } + RETURN_IF_FAILURE(writer->WriteConstant(memRefType, op.getAttr("value"))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::CallOp op, BytecodeWriter *writer) { + auto module = op.getOperation()->getParentOfType<ModuleOp>(); + auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); + // TODO(benvanik): switch with kCallTail if attr exists. + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall)); + RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::CallImportOp op, BytecodeWriter *writer) { + auto module = op.getOperation()->getParentOfType<ModuleOp>(); + auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); + // TODO(benvanik): transforms to convert Call->CallImport. + // TODO(benvanik): switch with kCallTail if attr exists. + if (callee.isExternal()) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallImport)); + } else { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall)); + } + RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::CallIndirectOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallIndirect)); + RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType())); + RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::BranchOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kBranch)); + RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest())); + RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands())); + for (int i = 0; i < op.getNumOperands(); ++i) { + // Copy src->dst. + RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i))); + } + return success(); +} + +LogicalResult writeOp(IREESeq::LL::CondBranchOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCondBranch)); + RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition())); + RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest())); + RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands())); + for (int i = 0; i < op.getNumTrueOperands(); ++i) { + // Copy src->dst. + RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i))); + } + RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest())); + RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands())); + for (int i = 0; i < op.getNumFalseOperands(); ++i) { + // Copy src->dst. + RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i))); + RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i))); + } + return success(); +} + +LogicalResult writeDispatchOpExecutableRef(Operation *op, StringRef executable, + StringRef entryPoint, + BytecodeWriter *writer) { + auto module = op->getParentOfType<ModuleOp>(); + auto multiArchExecutableOp = + module.lookupSymbol<IREE::MultiArchExecutableOp>(executable); + if (!multiArchExecutableOp) { + return op->emitError() << "Executable @" << executable.str() + << " not found in module"; + } + + auto executableOrdinalAttr = multiArchExecutableOp.getAttr("iree.ordinal") + .dyn_cast_or_null<IntegerAttr>(); + if (!executableOrdinalAttr) { + return op->emitError() << "No ordinal assigned to executable"; + } + int executableOrdinal = executableOrdinalAttr.getInt(); + + // TODO(benvanik): move an export table to the MAE to make this cleaner. + auto executableOp = + cast<IREE::ExecutableOp>(multiArchExecutableOp.getBlock().front()); + auto entryPointOp = + executableOp.getInnerModule().lookupSymbol<FuncOp>(entryPoint); + if (!entryPointOp) { + return op->emitError() << "Entry point @" << entryPoint.str() + << " not found in executable @" << executable.str(); + } + if (!entryPointOp.getAttr("iree.ordinal")) { + return op->emitError() << "No ordinal assigned to entry point"; + } + int entryPointOrdinal = + entryPointOp.getAttr("iree.ordinal").cast<IntegerAttr>().getInt(); + + RETURN_IF_FAILURE(writer->WriteUint32(executableOrdinal)); + RETURN_IF_FAILURE(writer->WriteUint16(entryPointOrdinal)); + + return success(); +} + +LogicalResult writeOp(IREESeq::LL::DynamicDispatchOp op, + BytecodeWriter *writer) { + RETURN_IF_FAILURE( + writer->WriteOpcode(iree::SequencerOpcode::kDynamicDispatch)); + RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(), + op.getEntryPoint(), writer)); + RETURN_IF_FAILURE(writer->WriteLocal(op.getWorkload())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + // TODO(benvanik): support output arg group (or change to tags). + RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::StaticDispatchOp op, + BytecodeWriter *writer) { + RETURN_IF_FAILURE( + writer->WriteOpcode(iree::SequencerOpcode::kStaticDispatch)); + RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(), + op.getEntryPoint(), writer)); + auto workloadAttr = op.getWorkload(); + RETURN_IF_FAILURE( + writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({0}).getInt())); + RETURN_IF_FAILURE( + writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({1}).getInt())); + RETURN_IF_FAILURE( + writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({2}).getInt())); + RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); + // TODO(benvanik): support output arg group (or change to tags). + RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::AllocHeapOp op, BytecodeWriter *writer) { + auto memRefType = op.getType().cast<MemRefType>(); + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kAllocHeap)); + RETURN_IF_FAILURE(writer->WriteInt32(0)); + RETURN_IF_FAILURE(writer->WriteTypeIndex(memRefType.getElementType())); + RETURN_IF_FAILURE(writer->WriteShapePieces(memRefType)); + RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands())); + RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::ComputeRangeOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kComputeRange)); + RETURN_IF_FAILURE(writer->WriteLocal(op.shape())); + RETURN_IF_FAILURE(writer->WriteUint8(op.elementSize().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteLocal(op.indices())); + RETURN_IF_FAILURE(writer->WriteLocal(op.lengths())); + RETURN_IF_FAILURE(writer->WriteLocal(op.dstOffset())); + RETURN_IF_FAILURE(writer->WriteLocal(op.dstLength())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::StaticSliceOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticSlice)); + RETURN_IF_FAILURE(writer->WriteLocal(op.src())); + RETURN_IF_FAILURE(writer->WriteInt32(op.offset().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getResult()->getType())); + RETURN_IF_FAILURE( + writer->WriteShapePieces(op.getResult()->getType().cast<ShapedType>())); + RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::StaticCopyOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticCopy)); + RETURN_IF_FAILURE(writer->WriteLocal(op.src())); + RETURN_IF_FAILURE(writer->WriteInt32(op.srcOffset().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteLocal(op.dst())); + RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue())); + return success(); +} + +LogicalResult writeOp(IREESeq::LL::StaticFillOp op, BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticFill)); + RETURN_IF_FAILURE(writer->WriteInt32(op.value().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteLocal(op.dst())); + RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue())); + RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue())); + return success(); +} + +} // namespace + +void registerSequencerCustomWriters(VMFunctionBuilder *builder) { +#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \ + builder->RegisterCustomWriter( \ + op_type::getOperationName(), \ + +[](Operation *op, BytecodeWriter *writer) { \ + return writeOp(cast<op_type>(op), writer); \ + }); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ConstantOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallImportOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallIndirectOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::BranchOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CondBranchOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::DynamicDispatchOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticDispatchOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::AllocHeapOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ComputeRangeOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticSliceOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticCopyOp); + REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticFillOp); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Sequencer/OpWriters.h b/compiler/IR/Sequencer/OpWriters.h new file mode 100644 index 0000000..d9b638f --- /dev/null +++ b/compiler/IR/Sequencer/OpWriters.h
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_ +#define IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_ + +#include "compiler/Serialization/VMFunctionBuilder.h" + +namespace mlir { +namespace iree_compiler { + +// Registers custom op writers with the builder. +// Ops not registered will use the generic writer. +void registerSequencerCustomWriters(VMFunctionBuilder *builder); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
diff --git a/compiler/IR/Sequencer/test/BUILD b/compiler/IR/Sequencer/test/BUILD new file mode 100644 index 0000000..44a5820 --- /dev/null +++ b/compiler/IR/Sequencer/test/BUILD
@@ -0,0 +1,15 @@ +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + "///tools:iree-run-mlir", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/IR/Sequencer/test/concat.mlir b/compiler/IR/Sequencer/test/concat.mlir similarity index 100% rename from iree/compiler/IR/Sequencer/test/concat.mlir rename to compiler/IR/Sequencer/test/concat.mlir
diff --git a/compiler/IR/StructureOps.cpp b/compiler/IR/StructureOps.cpp new file mode 100644 index 0000000..0778c76 --- /dev/null +++ b/compiler/IR/StructureOps.cpp
@@ -0,0 +1,250 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/StructureOps.h" + +#include "compiler/IR/Types.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/STLExtras.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +//===----------------------------------------------------------------------===// +// Generic printers and parsers. +//===----------------------------------------------------------------------===// + +// Parses an op that has no inputs and no outputs. +static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { + if (failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + return success(); +} + +// Prints an op that has no inputs and no outputs. +static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { + printer << op->getName(); + printer.printOptionalAttrDict(op->getAttrs()); +} + +//===----------------------------------------------------------------------===// +// iree.module +//===----------------------------------------------------------------------===// + +void ModuleOp::build(Builder *builder, OperationState &state) { + ensureTerminator(*state.addRegion(), *builder, state.location); +} + +static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { + Region *body = state.addRegion(); + if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) { + return failure(); + } + if (parser.parseOptionalAttributeDict(state.attributes)) { + return failure(); + } + ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location); + return success(); +} + +static void printModuleOp(OpAsmPrinter &printer, Operation *op) { + printer << op->getName(); + printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + printer.printOptionalAttrDict(op->getAttrs()); +} + +//===----------------------------------------------------------------------===// +// iree.multi_arch_executable +//===----------------------------------------------------------------------===// + +void MultiArchExecutableOp::build(Builder *builder, OperationState &state, + StringRef name) { + state.addAttribute(SymbolTable::getSymbolAttrName(), + builder->getStringAttr(name)); + ensureTerminator(*state.addRegion(), *builder, state.location); +} + +static ParseResult parseMultiArchExecutableOp(OpAsmParser &parser, + OperationState &state) { + auto &builder = parser.getBuilder(); + + // Parse the name as a symbol reference attr and then convert to a string. + SymbolRefAttr nameAttr; + if (failed(parser.parseAttribute(nameAttr, SymbolTable::getSymbolAttrName(), + state.attributes))) { + return failure(); + } + state.attributes.back().second = builder.getStringAttr(nameAttr.getValue()); + + if (succeeded(parser.parseOptionalLSquare())) { + IntegerAttr ordinalAttr; + if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32), + "iree.ordinal", state.attributes)) || + failed(parser.parseRSquare())) { + return failure(); + } + } + + if (failed(parser.parseLParen()) || failed(parser.parseRParen())) { + return failure(); + } + + Region *body = state.addRegion(); + if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) { + return failure(); + } + if (succeeded(parser.parseOptionalKeyword("attributes"))) { + if (failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + } + + MultiArchExecutableOp::ensureTerminator(*body, builder, state.location); + + return success(); +} + +static void printMultiArchExecutableOp(OpAsmPrinter &printer, + MultiArchExecutableOp op) { + printer << op.getOperationName() << " @" << op.sym_name(); + if (auto ordinalAttr = + op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) { + printer << "[" << ordinalAttr.getInt() << "]"; + } + printer << "()"; + + printer.printRegion(op.body(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + + // Print out executable attributes, if present. + SmallVector<StringRef, 2> ignoredAttrs = { + SymbolTable::getSymbolAttrName(), + "iree.ordinal", + }; + SmallVector<NamedAttribute, 4> attrs( + llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) { + return llvm::count(ignoredAttrs, attr.first) == 0; + })); + if (!attrs.empty()) { + printer << "\n attributes "; + printer.printOptionalAttrDict(attrs); + } +} + +//===----------------------------------------------------------------------===// +// iree.executable +//===----------------------------------------------------------------------===// + +void ExecutableOp::build(Builder *builder, OperationState &state, + IREE::ExecutableFormat format) { + state.addAttribute("format", + builder->getI32IntegerAttr(static_cast<uint32_t>(format))); + ensureTerminator(*state.addRegion(), *builder, state.location); +} + +static ParseResult parseExecutableOp(OpAsmParser &parser, + OperationState &state) { + auto &builder = parser.getBuilder(); + + if (succeeded(parser.parseOptionalLSquare())) { + IntegerAttr ordinalAttr; + if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32), + "iree.ordinal", state.attributes)) || + failed(parser.parseRSquare())) { + return failure(); + } + } + + IntegerAttr executableOrdinalAttr; + StringAttr formatAttr; + llvm::SMLoc formatLoc; + if (failed(parser.parseLParen()) || + failed(parser.getCurrentLocation(&formatLoc)) || + failed(parser.parseAttribute(formatAttr, "format", state.attributes))) { + return failure(); + } + auto format = symbolizeExecutableFormat(formatAttr.getValue()); + if (!format.hasValue()) { + return parser.emitError(formatLoc) + << "Unknown executable format " << formatAttr.getValue(); + } + state.attributes.back().second = + builder.getI32IntegerAttr(static_cast<int32_t>(format.getValue())); + + Region *body = state.addRegion(); + if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) { + return failure(); + } + if (succeeded(parser.parseOptionalKeyword("attributes"))) { + if (failed(parser.parseOptionalAttributeDict(state.attributes))) { + return failure(); + } + } + + ExecutableOp::ensureTerminator(*body, parser.getBuilder(), state.location); + + return success(); +} + +static void printExecutableOp(OpAsmPrinter &printer, ExecutableOp op) { + printer << op.getOperationName(); + if (auto ordinalAttr = + op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) { + printer << "[" << ordinalAttr.getInt() << "]"; + } + printer << "("; + auto format = symbolizeExecutableFormat(op.format()); + if (format.hasValue()) { + printer << stringifyExecutableFormat(format.getValue()); + } else { + printer << "INVALID FORMAT"; + } + printer << ")"; + + printer.printRegion(op.body(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); + + // Print out executable attributes, if present. + SmallVector<StringRef, 2> ignoredAttrs = { + "iree.ordinal", + "format", + }; + SmallVector<NamedAttribute, 4> attrs( + llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) { + return llvm::count(ignoredAttrs, attr.first) == 0; + })); + if (!attrs.empty()) { + printer << "\n attributes "; + printer.printOptionalAttrDict(attrs); + } +} + +#define GET_OP_CLASSES +#include "compiler/IR/StructureOps.cpp.inc" + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/StructureOps.h b/compiler/IR/StructureOps.h new file mode 100644 index 0000000..f187b71 --- /dev/null +++ b/compiler/IR/StructureOps.h
@@ -0,0 +1,40 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_STRUCTUREOPS_H_ +#define IREE_COMPILER_IR_STRUCTUREOPS_H_ + +#include <cstdint> + +#include "compiler/IR/Types.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/FunctionSupport.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +namespace IREE { + +#define GET_OP_CLASSES +#include "compiler/IR/StructureOps.h.inc" + +} // namespace IREE +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_STRUCTUREOPS_H_
diff --git a/compiler/IR/StructureOps.td b/compiler/IR/StructureOps.td new file mode 100644 index 0000000..afed1a0 --- /dev/null +++ b/compiler/IR/StructureOps.td
@@ -0,0 +1,122 @@ +// Structural ops such as 'module' and 'executable'. +// These are used to organize IREE IR into regions representing ops that act at +// the sequencer level (coarse control flow/scheduling) and ops that perform +// actual work (math/etc) on runtime execution backends. + +#ifdef IREE_STRUCTURE_OPS +#else +#define IREE_STRUCTURE_OPS + +#ifdef IREE_OP_BASE +#else +include "compiler/IR/OpBase.td" +#endif // IREE_OP_BASE + +class IREE_StructureOp<string mnemonic, list<OpTrait> traits = []> : + Op<IREE_Dialect, mnemonic, traits> { + let parser = [{ return parse$cppClass(parser, result); }]; + let printer = [{ print$cppClass(p, *this); }]; +} + +def IREE_ModuleOp : + IREE_StructureOp<"module", [ + SingleBlockImplicitTerminator<"ModuleEndOp">, + NativeOpTrait<"SymbolTable"> + ]> { + let regions = (region SizedRegion<1>:$body); + let extraClassDeclaration = [{ + Block& getBlock() { + return this->getOperation()->getRegion(0).front(); + } + }]; + + let skipDefaultBuilders = 1; + let builders = [OpBuilder<"Builder *, OperationState &state">]; +} + +def IREE_ModuleEndOp : + IREE_StructureOp<"_module_end", [ + IREE_ModuleOnly, + Terminator + ]> { + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; +} + +def IREE_MultiArchExecutableOp : + IREE_StructureOp<"multi_arch_executable", [ + // TODO(benvanik): make iree.module work and make this IREE_ModuleOnly. + SingleBlockImplicitTerminator<"MultiArchExecutableEndOp"> + ]> { + let arguments = (ins + StrAttr:$sym_name, + OptionalAttr<I32Attr>:$ordinal + ); + + let regions = (region SizedRegion<1>:$body); + let extraClassDeclaration = [{ + StringRef getName() { + return this->getOperation()->template getAttrOfType<StringAttr>( + ::mlir::SymbolTable::getSymbolAttrName()).getValue(); + } + + Region& getBody() { + return this->getOperation()->getRegion(0); + } + Block& getBlock() { + return this->getOperation()->getRegion(0).front(); + } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &state, StringRef name">, + ]; +} + +def IREE_MultiArchExecutableEndOp : + IREE_StructureOp<"_multi_arch_executable_end", [ + IREE_MultiArchExecutableOnly, + Terminator + ]> { + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; +} + +def IREE_ExecutableOp : + IREE_StructureOp<"executable", [ + SingleBlockImplicitTerminator<"ExecutableEndOp">, + NativeOpTrait<"SymbolTable"> + ]> { + let arguments = (ins + IREE_ExecutableFormatAttr:$format, + OptionalAttr<I32Attr>:$ordinal + ); + + let regions = (region SizedRegion<1>:$body); + let extraClassDeclaration = [{ + Region& getBody() { + return this->getOperation()->getRegion(0); + } + Block& getBlock() { + return this->getOperation()->getRegion(0).front(); + } + ::mlir::ModuleOp getInnerModule() { + return *getBlock().getOps<::mlir::ModuleOp>().begin(); + } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<[{Builder *builder, OperationState &state, + ExecutableFormat executable_format}]>, + ]; +} + +def IREE_ExecutableEndOp : + IREE_StructureOp<"_executable_end", [Terminator, IREE_ExecutableOnly]> { + let parser = [{ return parseNoIOOp(parser, result); }]; + let printer = [{ printNoIOOp(getOperation(), p); }]; +} + +#endif // IREE_STRUCTURE_OPS
diff --git a/compiler/IR/Traits.cpp b/compiler/IR/Traits.cpp new file mode 100644 index 0000000..bc8a3fb --- /dev/null +++ b/compiler/IR/Traits.cpp
@@ -0,0 +1,23 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Traits.h" + +namespace mlir { +namespace iree_compiler { + +// TODO(benvanik): traits. + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/IR/Traits.h b/compiler/IR/Traits.h similarity index 100% rename from iree/compiler/IR/Traits.h rename to compiler/IR/Traits.h
diff --git a/compiler/IR/Types.cpp b/compiler/IR/Types.cpp new file mode 100644 index 0000000..3e65bb5 --- /dev/null +++ b/compiler/IR/Types.cpp
@@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Types.h" + +#include "compiler/IR/Enums.cpp.inc" + +namespace mlir { +namespace iree_compiler { + +// static +DeviceType DeviceType::get(MLIRContext *context) { + return Base::get(context, TypeKind::Device); +} + +// static +DeviceGroupType DeviceGroupType::get(MLIRContext *context) { + return Base::get(context, TypeKind::DeviceGroup); +} + +// static +CommandBufferType CommandBufferType::get(MLIRContext *context) { + return Base::get(context, TypeKind::CommandBuffer); +} + +// static +EventType EventType::get(MLIRContext *context) { + return Base::get(context, TypeKind::Event); +} + +// static +SemaphoreType SemaphoreType::get(MLIRContext *context) { + return Base::get(context, TypeKind::Semaphore); +} + +// static +FenceType FenceType::get(MLIRContext *context) { + return Base::get(context, TypeKind::Fence); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/IR/Types.h b/compiler/IR/Types.h new file mode 100644 index 0000000..9379877 --- /dev/null +++ b/compiler/IR/Types.h
@@ -0,0 +1,115 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_IR_TYPES_H_ +#define IREE_COMPILER_IR_TYPES_H_ + +#include <cstdint> + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" + +// Order matters. +#include "compiler/IR/Enums.h.inc" + +namespace mlir { +namespace iree_compiler { + +namespace TypeKind { +enum Kind { + Device = Type::FIRST_IREE_TYPE, + DeviceGroup, + CommandBuffer, + Event, + Semaphore, + Fence, +}; +} // namespace TypeKind + +// clang-format off +#define IREE_TYPE_TABLE(map) \ + map("device", TypeKind::Device, DeviceType) \ + map("device_group", TypeKind::DeviceGroup, DeviceGroupType) \ + map("command_buffer", TypeKind::CommandBuffer, CommandBufferType) \ + map("event", TypeKind::Event, EventType) \ + map("semaphore", TypeKind::Semaphore, SemaphoreType) \ + map("fence", TypeKind::Fence, FenceType) +// clang-format on + +// iree.device mapping to a runtime-resolved device type. +class DeviceType : public Type::TypeBase<DeviceType, Type> { + public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Device; } + + static DeviceType get(MLIRContext *context); +}; + +// iree.device_group relating multiple iree.device requirements with each other. +class DeviceGroupType : public Type::TypeBase<DeviceGroupType, Type> { + public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::DeviceGroup; } + + static DeviceGroupType get(MLIRContext *context); +}; + +// iree.command_buffer mapping to an iree::hal::CommandBuffer. +class CommandBufferType : public Type::TypeBase<CommandBufferType, Type> { + public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::CommandBuffer; } + + static CommandBufferType get(MLIRContext *context); +}; + +// iree.event mapping to an iree::hal::Event. +class EventType : public Type::TypeBase<EventType, Type> { + public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Event; } + + static EventType get(MLIRContext *context); +}; + +// iree.semaphore mapping to an iree::hal::Semaphore. +class SemaphoreType : public Type::TypeBase<SemaphoreType, Type> { + public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Semaphore; } + + static SemaphoreType get(MLIRContext *context); +}; + +// iree.fence mapping to an iree::hal::Fence. +class FenceType : public Type::TypeBase<FenceType, Type> { + public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::Fence; } + + static FenceType get(MLIRContext *context); +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_IR_TYPES_H_
diff --git a/compiler/IR/test/BUILD b/compiler/IR/test/BUILD new file mode 100644 index 0000000..44a5820 --- /dev/null +++ b/compiler/IR/test/BUILD
@@ -0,0 +1,15 @@ +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + "///tools:iree-run-mlir", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/IR/test/bindings.mlir b/compiler/IR/test/bindings.mlir similarity index 100% rename from iree/compiler/IR/test/bindings.mlir rename to compiler/IR/test/bindings.mlir
diff --git a/iree/compiler/IR/test/constant.mlir b/compiler/IR/test/constant.mlir similarity index 100% rename from iree/compiler/IR/test/constant.mlir rename to compiler/IR/test/constant.mlir
diff --git a/iree/compiler/IR/test/dispatch_regions.mlir b/compiler/IR/test/dispatch_regions.mlir similarity index 100% rename from iree/compiler/IR/test/dispatch_regions.mlir rename to compiler/IR/test/dispatch_regions.mlir
diff --git a/iree/compiler/IR/test/reduction_regions.mlir b/compiler/IR/test/reduction_regions.mlir similarity index 100% rename from iree/compiler/IR/test/reduction_regions.mlir rename to compiler/IR/test/reduction_regions.mlir
diff --git a/iree/compiler/IR/test/scalar_memref.mlir b/compiler/IR/test/scalar_memref.mlir similarity index 100% rename from iree/compiler/IR/test/scalar_memref.mlir rename to compiler/IR/test/scalar_memref.mlir
diff --git a/iree/compiler/IR/test/tensor_memref.mlir b/compiler/IR/test/tensor_memref.mlir similarity index 100% rename from iree/compiler/IR/test/tensor_memref.mlir rename to compiler/IR/test/tensor_memref.mlir
diff --git a/compiler/Serialization/BUILD b/compiler/Serialization/BUILD new file mode 100644 index 0000000..523684e --- /dev/null +++ b/compiler/Serialization/BUILD
@@ -0,0 +1,43 @@ +# Serialization for the VM bytecode. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Serialization", + srcs = [ + "BytecodeTables.cpp", + "BytecodeWriter.cpp", + "VMDeviceTableBuilder.cpp", + "VMExecutableTableBuilder.cpp", + "VMFunctionBuilder.cpp", + "VMFunctionTableBuilder.cpp", + "VMModuleBuilder.cpp", + "VMSourceMapBuilder.cpp", + ], + hdrs = [ + "BytecodeTables.h", + "BytecodeWriter.h", + "VMDeviceTableBuilder.h", + "VMExecutableTableBuilder.h", + "VMFunctionBuilder.h", + "VMFunctionTableBuilder.h", + "VMModuleBuilder.h", + "VMSourceMapBuilder.h", + ], + deps = [ + "///compiler/IR", + "///compiler/Utils", + "///schemas", + "///schemas/bytecode:bytecode_v0", + "///schemas/bytecode:interpreter_bytecode_v0", + "///schemas/bytecode:sequencer_bytecode_v0", + "@com_github_google_flatbuffers//:flatbuffers", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + ], +)
diff --git a/compiler/Serialization/BytecodeTables.cpp b/compiler/Serialization/BytecodeTables.cpp new file mode 100644 index 0000000..80f63c9 --- /dev/null +++ b/compiler/Serialization/BytecodeTables.cpp
@@ -0,0 +1,73 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/BytecodeTables.h" + +#include "llvm/ADT/STLExtras.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Info tables mapping 1:1 with bytecode ops. +// +// Note that we ensure the table is 256 elements long exactly to make sure +// that unused opcodes are handled gracefully. +#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \ + { \ + name, \ + flags, \ + {operand_encodings}, \ + }, + +static const OpcodeInfo kInterpreterInfoTable[256] = { + IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)}; + +static const OpcodeInfo kSequencerInfoTable[256] = { + IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)}; + +#undef DECLARE_INFO + +} // namespace + +llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName( + StringRef name) { + for (int i = 0; i < llvm::array_lengthof(kInterpreterInfoTable); ++i) { + if (name == kInterpreterInfoTable[i].mnemonic) { + return static_cast<iree::InterpreterOpcode>(i); + } + } + return llvm::None; +} + +const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode) { + return kInterpreterInfoTable[static_cast<uint8_t>(opcode)]; +} + +llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name) { + for (int i = 0; i < llvm::array_lengthof(kSequencerInfoTable); ++i) { + if (name == kSequencerInfoTable[i].mnemonic) { + return static_cast<iree::SequencerOpcode>(i); + } + } + return llvm::None; +} + +const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode) { + return kSequencerInfoTable[static_cast<uint8_t>(opcode)]; +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/BytecodeTables.h b/compiler/Serialization/BytecodeTables.h new file mode 100644 index 0000000..a649c97 --- /dev/null +++ b/compiler/Serialization/BytecodeTables.h
@@ -0,0 +1,52 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_ +#define IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_ + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Support/LLVM.h" +#include "schemas/bytecode/interpreter_bytecode_v0.h" +#include "schemas/bytecode/sequencer_bytecode_v0.h" + +namespace mlir { +namespace iree_compiler { + +struct OpcodeInfo { + const char* mnemonic = nullptr; + iree::OpcodeFlagBitfield flags = iree::OpcodeFlagBitfield::kDefault; + union { + const char operands_value[8] = {0}; + const iree::OperandEncoding operands[8]; + }; +}; + +// Returns an opcode - if found - for the given interpreter op. +llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName( + StringRef name); + +// Returns the info for the given interpreter opcode. +const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode); + +// Returns an opcode - if found - for the given sequencer op. +llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name); + +// Returns the info for the given sequencer opcode. +const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
diff --git a/compiler/Serialization/BytecodeWriter.cpp b/compiler/Serialization/BytecodeWriter.cpp new file mode 100644 index 0000000..beb1e1e --- /dev/null +++ b/compiler/Serialization/BytecodeWriter.cpp
@@ -0,0 +1,334 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/BytecodeWriter.h" + +#include <algorithm> + +#include "compiler/Utils/Macros.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { + +LogicalResult BytecodeWriter::WriteCount(int count) { + if (count > UINT8_MAX) { + // TODO(benvanik): varints? + llvm::errs() << "Too many items: " << count + << "; only 0-UINT8_MAX are supported"; + return failure(); + } + return WriteUint8(static_cast<uint8_t>(count)); +} + +LogicalResult BytecodeWriter::WriteTypeIndex(Type type) { + iree::BuiltinType type_index; + if (type.isInteger(8)) { + type_index = iree::BuiltinType::kI8; + } else if (type.isInteger(16)) { + type_index = iree::BuiltinType::kI16; + } else if (type.isInteger(32)) { + type_index = iree::BuiltinType::kI32; + } else if (type.isInteger(64)) { + type_index = iree::BuiltinType::kI64; + } else if (type.isF16()) { + type_index = iree::BuiltinType::kF16; + } else if (type.isF32()) { + type_index = iree::BuiltinType::kF32; + } else if (type.isF64()) { + type_index = iree::BuiltinType::kF64; + } else { + // TODO(benvanik): support unknown types as BuiltinType::kOpaque? + return emitError(UnknownLoc::get(type.getContext())) + << "Type " << type << " cannot be represented by a builtin type"; + } + return WriteUint8(static_cast<uint8_t>(type_index)); +} + +LogicalResult BytecodeWriter::WriteFunctionOrdinal(FuncOp function) { + auto functionOrdinal = function.getAttrOfType<IntegerAttr>("iree.ordinal"); + if (!functionOrdinal) { + return function.emitError() << "Ordinal not assigned to function"; + } + RETURN_IF_FAILURE(WriteUint32(functionOrdinal.getInt())); + return success(); +} + +LogicalResult BytecodeWriter::WriteImportOrdinal(FuncOp function) { + // For now this is the same as internal function ordinals, though we could + // probably shrink it. + return WriteFunctionOrdinal(function); +} + +LogicalResult BytecodeWriter::WriteConstant(MemRefType memRefType, + Attribute baseAttr) { + // All types are memrefs, so we only need the element type. + RETURN_IF_FAILURE(WriteTypeIndex(memRefType.getElementType())); + + // Write shape (we could optimize this for cases of scalars and such). + RETURN_IF_FAILURE(WriteCount(memRefType.getRank())); + for (int i = 0; i < memRefType.getRank(); ++i) { + RETURN_IF_FAILURE(WriteInt32(memRefType.getDimSize(i))); + } + + if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) { + RETURN_IF_FAILURE( + WriteUint8(static_cast<uint8_t>(iree::ConstantEncoding::kSplat))); + return WriteAttributeData(attr.getSplatValue()); + } + RETURN_IF_FAILURE( + WriteUint8(static_cast<uint8_t>(iree::ConstantEncoding::kDense))); + return WriteAttributeData(baseAttr); +} + +LogicalResult BytecodeWriter::WriteAttributeData(Attribute baseAttr) { + if (auto attr = baseAttr.dyn_cast<BoolAttr>()) { + return WriteUint8(attr.getValue() ? 1 : 0); + } else if (auto attr = baseAttr.dyn_cast<IntegerAttr>()) { + if (attr.getType().isIndex()) { + int32_t value = static_cast<int32_t>(attr.getInt()); + return WriteBytes(&value, 4); + } else { + int bitWidth = attr.getValue().getBitWidth(); + switch (bitWidth) { + case 8: + case 16: + case 32: + case 64: + return WriteBytes(attr.getValue().getRawData(), bitWidth / 8); + default: + return emitError(UnknownLoc::get(baseAttr.getContext())) + << "Bit width for integers must be one of 8,16,32,64; others " + "not implemented: " + << bitWidth; + } + } + } else if (auto attr = baseAttr.dyn_cast<FloatAttr>()) { + int bitWidth = attr.getType().getIntOrFloatBitWidth(); + auto bitcastValue = attr.getValue().bitcastToAPInt(); + switch (bitWidth) { + case 16: + case 32: + case 64: + return WriteBytes(bitcastValue.getRawData(), bitWidth / 8); + default: + return emitError(UnknownLoc::get(baseAttr.getContext())) + << "Bit width for floats must be one of 16,32,64; others " + "not implemented: " + << bitWidth; + } + } else if (auto attr = baseAttr.dyn_cast<StringAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<ArrayAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<AffineMapAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<IntegerSetAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<TypeAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<SymbolRefAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) { + return WriteAttributeData(attr.getSplatValue()); + } else if (auto attr = baseAttr.dyn_cast<DenseIntElementsAttr>()) { + int elementCount = attr.getType().getNumElements(); + if (elementCount == 0) { + return success(); + } + int bitWidth = attr.getType().getElementTypeBitWidth(); + int byteWidth = bitWidth / 8; + auto dst = ReserveBytes(elementCount * byteWidth); + if (dst.empty()) return failure(); + uint8_t *dstPtr = dst.data(); + for (auto element : attr) { + assert(element.getBitWidth() == bitWidth); + std::memcpy(dstPtr, element.getRawData(), byteWidth); + dstPtr += byteWidth; + } + return success(); + } else if (auto attr = baseAttr.dyn_cast<DenseFPElementsAttr>()) { + int elementCount = attr.getType().getNumElements(); + if (elementCount == 0) { + return success(); + } + int bitWidth = attr.getType().getElementTypeBitWidth(); + auto dst = ReserveBytes(elementCount * bitWidth / 8); + if (dst.empty()) return failure(); + uint8_t *dstPtr = dst.data(); + for (auto element : attr) { + auto bitcastValue = element.bitcastToAPInt(); + std::memcpy(dstPtr, bitcastValue.getRawData(), + bitcastValue.getBitWidth() / 8); + dstPtr += bitWidth / 8; + } + return success(); + } else if (auto attr = baseAttr.dyn_cast<DenseElementsAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<OpaqueElementsAttr>()) { + // TODO(benvanik): other attribute encodings. + } else if (auto attr = baseAttr.dyn_cast<SparseElementsAttr>()) { + // TODO(benvanik): other attribute encodings. + } + return emitError(UnknownLoc::get(baseAttr.getContext())) + << "Serializer for attribute kind " + << static_cast<int>(baseAttr.getKind()) << " not implemented"; +} + +Optional<int> BytecodeWriter::LookupLocalOrdinal(Value *value) { + int ordinal; + auto it = localMap_.find(value); + if (it != localMap_.end()) { + ordinal = it->second; + } else { + ordinal = localMap_.size(); + localMap_.insert({value, ordinal}); + } + if (ordinal > UINT16_MAX) { + // TODO(benvanik): varints? + emitError(UnknownLoc::get(value->getContext())) + << "Too many ordinals: " << ordinal + << "; only 0-UINT16_MAX are supported"; + return llvm::None; + } + return ordinal; +} + +LogicalResult BytecodeWriter::PrepareLocal(Value *value) { + if (!LookupLocalOrdinal(value).hasValue()) return failure(); + return success(); +} + +LogicalResult BytecodeWriter::WriteLocal(Value *value) { + auto ordinal = LookupLocalOrdinal(value); + if (!ordinal.hasValue()) { + return failure(); + } + if (ordinal.getValue() > UINT16_MAX) { + // TODO(benvanik): varints? + return emitError(UnknownLoc::get(value->getContext())) + << "Too many locals: " << ordinal.getValue() + << "; only 0-UINT16_MAX are supported"; + } + return WriteUint16(static_cast<uint16_t>(ordinal.getValue())); +} + +LogicalResult BytecodeWriter::WriteLocals( + llvm::iterator_range<Operation::operand_iterator> values) { + int count = std::distance(values.begin(), values.end()); + RETURN_IF_FAILURE(WriteCount(count)); + for (auto *value : values) { + RETURN_IF_FAILURE(WriteLocal(value)); + } + return success(); +} + +LogicalResult BytecodeWriter::WriteLocals( + llvm::iterator_range<Operation::result_iterator> values) { + int count = std::distance(values.begin(), values.end()); + RETURN_IF_FAILURE(WriteCount(count)); + for (auto *value : values) { + RETURN_IF_FAILURE(WriteLocal(value)); + } + return success(); +} + +MutableArrayRef<uint8_t> BytecodeWriter::ReserveBytes(size_t dataLength) { + int offset = bytecode_.size(); + bytecode_.resize(offset + dataLength); + return MutableArrayRef<uint8_t>( + reinterpret_cast<uint8_t *>(bytecode_.data()) + offset, dataLength); +} + +LogicalResult BytecodeWriter::WriteBytes(const void *data, size_t dataLength) { + auto dst = ReserveBytes(dataLength); + if (dataLength != dst.size()) { + return failure(); + } + std::memcpy(dst.data(), data, dst.size()); + return success(); +} + +LogicalResult BytecodeWriter::WriteUint8(uint8_t value) { + return WriteBytes(&value, sizeof(value)); +} + +LogicalResult BytecodeWriter::WriteUint16(uint16_t value) { + return WriteBytes(&value, sizeof(value)); +} + +LogicalResult BytecodeWriter::WriteInt32(int32_t value) { + return WriteBytes(&value, sizeof(value)); +} + +LogicalResult BytecodeWriter::WriteUint32(uint32_t value) { + return WriteBytes(&value, sizeof(value)); +} + +LogicalResult BytecodeWriter::WriteElementsAttrInt32(ElementsAttr attr) { + int elementCount = attr.getType().getNumElements(); + RETURN_IF_FAILURE(WriteCount(elementCount)); + for (auto value : attr.getValues<int32_t>()) { + RETURN_IF_FAILURE(WriteInt32(value)); + } + return success(); +} + +LogicalResult BytecodeWriter::WriteShapePieces(const ShapedType &type) { + RETURN_IF_FAILURE(WriteCount(type.getRank())); + for (int64_t dim : type.getShape()) { + RETURN_IF_FAILURE(WriteInt32(dim)); + } + return success(); +} + +LogicalResult BytecodeWriter::WriteShapePieces(ElementsAttr pieces) { + return WriteElementsAttrInt32(pieces); +} + +LogicalResult BytecodeWriter::MarkBlockOffset(Block *block) { + blockOffsets_[block] = bytecode_.size(); + return success(); +} + +LogicalResult BytecodeWriter::WriteBlockOffset(Block *targetBlock) { + // Reserve space for the offset and stash for later fixup. + blockOffsetFixups_.push_back({targetBlock, bytecode_.size()}); + bytecode_.resize(bytecode_.size() + sizeof(int32_t)); + return success(); +} + +LogicalResult BytecodeWriter::FixupOffsets() { + for (const auto &fixup : blockOffsetFixups_) { + auto it = blockOffsets_.find(fixup.first); + if (it == blockOffsets_.end()) { + llvm::errs() << "Block offset not found: " << fixup.first; + return failure(); + } + std::memcpy(bytecode_.data() + fixup.second, &it->second, sizeof(int32_t)); + } + blockOffsetFixups_.clear(); + return success(); +} + +std::vector<uint8_t> BytecodeWriter::Finish() { + localMap_.clear(); + return std::move(bytecode_); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/BytecodeWriter.h b/compiler/Serialization/BytecodeWriter.h new file mode 100644 index 0000000..c46e196 --- /dev/null +++ b/compiler/Serialization/BytecodeWriter.h
@@ -0,0 +1,96 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_ +#define IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_ + +#include <cstddef> +#include <utility> +#include <vector> + +#include "compiler/IR/StructureOps.h" +#include "llvm/ADT/Optional.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "schemas/bytecode/bytecode_v0.h" + +namespace mlir { +namespace iree_compiler { + +class BytecodeWriter { + public: + int offset() const { return bytecode_.size(); } + + int local_count() const { return localMap_.size(); } + + template <typename T> + LogicalResult WriteOpcode(T value) { + static_assert(sizeof(T) == sizeof(uint8_t), "Opcode enum size mismatch"); + return WriteUint8(static_cast<uint8_t>(value)); + } + + LogicalResult WriteCount(int count); + + LogicalResult WriteTypeIndex(Type type); + + LogicalResult WriteFunctionOrdinal(FuncOp function); + LogicalResult WriteImportOrdinal(FuncOp function); + + LogicalResult WriteConstant(MemRefType memRefType, Attribute baseAttr); + LogicalResult WriteAttributeData(Attribute baseAttr); + + llvm::Optional<int> LookupLocalOrdinal(Value *value); + LogicalResult PrepareLocal(Value *value); + LogicalResult WriteLocal(Value *value); + LogicalResult WriteLocals( + llvm::iterator_range<Operation::operand_iterator> values); + LogicalResult WriteLocals( + llvm::iterator_range<Operation::result_iterator> values); + + LogicalResult WriteBytes(const void *data, size_t dataLength); + MutableArrayRef<uint8_t> ReserveBytes(size_t dataLength); + LogicalResult WriteUint8(uint8_t value); + LogicalResult WriteUint16(uint16_t value); + LogicalResult WriteInt32(int32_t value); + LogicalResult WriteUint32(uint32_t value); + + LogicalResult WriteElementsAttrInt32(ElementsAttr attr); + + LogicalResult WriteShapePieces(const ShapedType &type); + LogicalResult WriteShapePieces(ElementsAttr pieces); + + LogicalResult MarkBlockOffset(Block *block); + LogicalResult WriteBlockOffset(Block *targetBlock); + LogicalResult FixupOffsets(); + + std::vector<uint8_t> Finish(); + + private: + std::vector<uint8_t> bytecode_; + + llvm::DenseMap<Value *, int> localMap_; + + llvm::DenseMap<Block *, size_t> blockOffsets_; + std::vector<std::pair<Block *, size_t>> blockOffsetFixups_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
diff --git a/iree/compiler/Serialization/CMakeLists.txt b/compiler/Serialization/CMakeLists.txt similarity index 100% rename from iree/compiler/Serialization/CMakeLists.txt rename to compiler/Serialization/CMakeLists.txt
diff --git a/compiler/Serialization/VMDeviceTableBuilder.cpp b/compiler/Serialization/VMDeviceTableBuilder.cpp new file mode 100644 index 0000000..4c95ab9 --- /dev/null +++ b/compiler/Serialization/VMDeviceTableBuilder.cpp
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/VMDeviceTableBuilder.h" + +namespace mlir { +namespace iree_compiler { + +VMDeviceTableBuilder::VMDeviceTableBuilder( + ::flatbuffers::FlatBufferBuilder *fbb) + : fbb_(fbb) {} + +LogicalResult VMDeviceTableBuilder::AddDevice( + ::flatbuffers::Offset<iree::DeviceDef> deviceDef) { + deviceDefs_.push_back(deviceDef); + return success(); +} + +LogicalResult VMDeviceTableBuilder::AddDeviceGroup( + ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef) { + deviceGroupDefs_.push_back(deviceGroupDef); + return success(); +} + +::flatbuffers::Offset<iree::DeviceTableDef> VMDeviceTableBuilder::Finish() { + auto devicesOffset = fbb_->CreateVector(deviceDefs_); + auto deviceGroupsOffset = fbb_->CreateVector(deviceGroupDefs_); + iree::DeviceTableDefBuilder dtdb(*fbb_); + dtdb.add_devices(devicesOffset); + dtdb.add_device_groups(deviceGroupsOffset); + return dtdb.Finish(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/VMDeviceTableBuilder.h b/compiler/Serialization/VMDeviceTableBuilder.h new file mode 100644 index 0000000..1cd9147 --- /dev/null +++ b/compiler/Serialization/VMDeviceTableBuilder.h
@@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_ +#define IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_ + +#include "flatbuffers/flatbuffers.h" +#include "mlir/Support/LogicalResult.h" +#include "schemas/device_table_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +class VMDeviceTableBuilder { + public: + explicit VMDeviceTableBuilder(::flatbuffers::FlatBufferBuilder *fbb); + + LogicalResult AddDevice(::flatbuffers::Offset<iree::DeviceDef> deviceDef); + + LogicalResult AddDeviceGroup( + ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef); + + ::flatbuffers::Offset<iree::DeviceTableDef> Finish(); + + private: + ::flatbuffers::FlatBufferBuilder *fbb_; + std::vector<::flatbuffers::Offset<iree::DeviceDef>> deviceDefs_; + std::vector<::flatbuffers::Offset<iree::DeviceGroupDef>> deviceGroupDefs_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
diff --git a/compiler/Serialization/VMExecutableTableBuilder.cpp b/compiler/Serialization/VMExecutableTableBuilder.cpp new file mode 100644 index 0000000..ad71958 --- /dev/null +++ b/compiler/Serialization/VMExecutableTableBuilder.cpp
@@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/VMExecutableTableBuilder.h" + +namespace mlir { +namespace iree_compiler { + +VMExecutableTableBuilder::VMExecutableTableBuilder( + ::flatbuffers::FlatBufferBuilder *fbb) + : fbb_(fbb) {} + +LogicalResult VMExecutableTableBuilder::AddMultiArchExecutable( + ::flatbuffers::Offset<iree::MultiArchExecutableDef> + multiArchExecutableDef) { + multiArchExecutableDefs_.push_back(multiArchExecutableDef); + return success(); +} + +::flatbuffers::Offset<iree::ExecutableTableDef> +VMExecutableTableBuilder::Finish() { + auto multiArchExecutablesOffset = + fbb_->CreateVector(multiArchExecutableDefs_); + iree::ExecutableTableDefBuilder etdb(*fbb_); + etdb.add_multi_arch_executables(multiArchExecutablesOffset); + return etdb.Finish(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/VMExecutableTableBuilder.h b/compiler/Serialization/VMExecutableTableBuilder.h new file mode 100644 index 0000000..0106c09 --- /dev/null +++ b/compiler/Serialization/VMExecutableTableBuilder.h
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_ +#define IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_ + +#include "flatbuffers/flatbuffers.h" +#include "mlir/Support/LogicalResult.h" +#include "schemas/executable_table_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +class VMExecutableTableBuilder { + public: + explicit VMExecutableTableBuilder(::flatbuffers::FlatBufferBuilder *fbb); + + LogicalResult AddMultiArchExecutable( + ::flatbuffers::Offset<iree::MultiArchExecutableDef> + multiArchExecutableDef); + + ::flatbuffers::Offset<iree::ExecutableTableDef> Finish(); + + private: + ::flatbuffers::FlatBufferBuilder *fbb_; + std::vector<::flatbuffers::Offset<iree::MultiArchExecutableDef>> + multiArchExecutableDefs_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
diff --git a/compiler/Serialization/VMFunctionBuilder.cpp b/compiler/Serialization/VMFunctionBuilder.cpp new file mode 100644 index 0000000..23eda88 --- /dev/null +++ b/compiler/Serialization/VMFunctionBuilder.cpp
@@ -0,0 +1,359 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/VMFunctionBuilder.h" + +#include "flatbuffers/flatbuffers.h" +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Types.h" +#include "compiler/Serialization/BytecodeTables.h" +#include "compiler/Utils/Macros.h" +#include "schemas/bytecode/bytecode_v0.h" +#include "schemas/type_def_generated.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Module.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +LogicalResult WriteGenericIreeOp(Block *block, Operation *op, + BytecodeWriter *writer) { + // Strip the dialect name from the op name and lookup the opcode. + // TODO(benvanik): adjust for supporting sequencer opcodes. + + auto opName = op->getName().getStringRef(); + auto dialect = op->getDialect(); + if (!dialect) { + return op->emitOpError() << "Op does not belong to a registered dialect"; + } + + auto dialectNamespace = dialect->getNamespace(); + std::unique_ptr<OpcodeInfo> operandInfo; + auto strippedOpName = opName.substr(opName.find('.') + 1).str(); + if (dialectNamespace == "iree_ll_seq") { + auto opcode = GetSequencerOpcodeByName(strippedOpName); + if (!opcode.hasValue()) { + return op->emitOpError() + << "No sequencer opcode found for op; is it a pseudo op?"; + } + RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue())); + operandInfo = + std::make_unique<OpcodeInfo>(GetSequencerOpcodeInfo(opcode.getValue())); + } else if (dialectNamespace == "iree_ll_interp" || + // TODO(gcmn) remove special case for IREE dialect? + dialectNamespace == IREEDialect::getDialectNamespace()) { + auto opcode = GetInterpreterOpcodeByName(strippedOpName); + if (!opcode.hasValue()) { + return op->emitOpError() + << "No interpreter opcode found for op; is it a pseudo op?"; + } + RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue())); + operandInfo = std::make_unique<OpcodeInfo>( + GetInterpreterOpcodeInfo(opcode.getValue())); + } else { + return op->emitOpError() + << "Op belongs to unknown dialect " << dialectNamespace.str(); + } + // Write inputs and outputs based on the bytecode encoding. + int operandIndex = 0; + int resultIndex = 0; + for (int i = 0; i < llvm::array_lengthof(operandInfo->operands); ++i) { + auto op_encoding = operandInfo->operands[i]; + if (op_encoding == iree::OperandEncoding::kNone) break; + switch (op_encoding) { + case iree::OperandEncoding::kInputSlot: + case iree::OperandEncoding::kOutputSlot: { + auto *value = op->getOperand(operandIndex++); + RETURN_IF_FAILURE(writer->WriteLocal(value)); + break; + } + case iree::OperandEncoding::kVariadicInputSlots: + case iree::OperandEncoding::kVariadicOutputSlots: { + int count = op->getNumOperands() - operandIndex; + RETURN_IF_FAILURE(writer->WriteCount(count)); + for (; count; --count) { + auto *value = op->getOperand(operandIndex++); + RETURN_IF_FAILURE(writer->WriteLocal(value)); + } + break; + } + case iree::OperandEncoding::kResultSlot: { + auto *value = op->getResult(resultIndex++); + RETURN_IF_FAILURE(writer->WriteLocal(value)); + break; + } + case iree::OperandEncoding::kVariadicResultSlots: { + int count = op->getNumResults() - resultIndex; + RETURN_IF_FAILURE(writer->WriteCount(count)); + for (; count; --count) { + auto *value = op->getResult(resultIndex++); + RETURN_IF_FAILURE(writer->WriteLocal(value)); + } + break; + } + case iree::OperandEncoding::kConstant: + case iree::OperandEncoding::kFunctionOrdinal: + case iree::OperandEncoding::kBlockOffset: + case iree::OperandEncoding::kTypeIndex: + case iree::OperandEncoding::kIndex: + case iree::OperandEncoding::kIndexList: + case iree::OperandEncoding::kCmpIPredicate: + case iree::OperandEncoding::kCmpFPredicate: + return op->emitOpError() + << "Operand encoding " << static_cast<char>(op_encoding) + << " not supported by generic writer for " << opName.str(); + return failure(); + default: + return op->emitOpError() + << "Operand encoding " << static_cast<char>(op_encoding) << " (" + << static_cast<int>(op_encoding) << ") not recognized (typo?)"; + } + } + + return success(); +} + +} // namespace + +VMFunctionBuilder::VMFunctionBuilder(FuncOp function, + VMFunctionTableBuilder *functionTable, + ::flatbuffers::FlatBufferBuilder *fbb) + : context_(function.getContext()), + function_(function), + functionTable_(functionTable), + fbb_(fbb) {} + +void VMFunctionBuilder::RegisterCustomWriter(StringRef operationName, + CustomWriterFn writerFn) { + customWriters_.insert({operationName, writerFn}); +} + +LogicalResult VMFunctionBuilder::ConvertBytecode() { + BytecodeWriter writer; + sourceMap_ = {}; + + RETURN_IF_FAILURE(BeginFunction(function_, &writer)); + for (auto &block : function_.getBlocks()) { + RETURN_IF_FAILURE(BeginBlock(&block, &writer)); + for (auto &op : block.getOperations()) { + if (failed(WriteOperation(&block, &op, &writer))) { + op.emitError() << "Unable to serialize operation"; + return failure(); + } + } + RETURN_IF_FAILURE(EndBlock(&block, block.getTerminator(), &writer)); + } + RETURN_IF_FAILURE(EndFunction(function_, &writer)); + + int localCount = writer.local_count(); + auto bodyBytes = writer.Finish(); + auto bodyOffset = fbb_->CreateVector( + reinterpret_cast<const int8_t *>(bodyBytes.data()), bodyBytes.size()); + iree::BytecodeDefBuilder bdb(*fbb_); + bdb.add_local_count(localCount); + bdb.add_contents(bodyOffset); + bytecodeDef_ = bdb.Finish(); + + return success(); +} + +::flatbuffers::Offset<iree::FunctionDef> VMFunctionBuilder::Finish() { + using TypeDefVector = + ::flatbuffers::Vector<::flatbuffers::Offset<iree::TypeDef>>; + + const auto &functionType = function_.getType(); + std::vector<::flatbuffers::Offset<iree::TypeDef>> inputs; + for (const auto &type : functionType.getInputs()) { + auto typeOffset = SerializeType(type, fbb_); + if (typeOffset.IsNull()) return {}; + inputs.push_back(typeOffset); + } + ::flatbuffers::Offset<TypeDefVector> inputsOffset; + if (!inputs.empty()) { + inputsOffset = fbb_->CreateVector(inputs); + } + + std::vector<::flatbuffers::Offset<iree::TypeDef>> results; + for (const auto &type : functionType.getResults()) { + auto typeOffset = SerializeType(type, fbb_); + if (typeOffset.IsNull()) return {}; + results.push_back(typeOffset); + } + ::flatbuffers::Offset<TypeDefVector> resultsOffset; + if (!results.empty()) { + resultsOffset = fbb_->CreateVector(results); + } + iree::FunctionTypeDefBuilder ftb(*fbb_); + ftb.add_inputs(inputsOffset); + ftb.add_results(resultsOffset); + auto functionTypeOffset = ftb.Finish(); + + // TODO(benvanik): strip names of internal functions. + auto nameOffset = fbb_->CreateString(function_.getName().str()); + iree::FunctionDefBuilder fdb(*fbb_); + fdb.add_name(nameOffset); + fdb.add_type(functionTypeOffset); + fdb.add_bytecode(bytecodeDef_); + return fdb.Finish(); +} + +LogicalResult VMFunctionBuilder::BeginFunction(FuncOp function, + BytecodeWriter *writer) { + // Assign value slots for all arguments and results. + // Keeping them at the front will make it easier to find during debugging + // and makes spans easier to compute at runtime. + for (auto argument : function.getArguments()) { + RETURN_IF_FAILURE(writer->PrepareLocal(argument)); + } + return success(); +} + +LogicalResult VMFunctionBuilder::EndFunction(FuncOp function, + BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->FixupOffsets()); + return success(); +} + +LogicalResult VMFunctionBuilder::BeginBlock(Block *block, + BytecodeWriter *writer) { + RETURN_IF_FAILURE(writer->MarkBlockOffset(block)); + return success(); +} + +LogicalResult VMFunctionBuilder::EndBlock(Block *block, Operation *op, + BytecodeWriter *writer) { + return success(); +} + +LogicalResult VMFunctionBuilder::WriteOperation(Block *block, Operation *baseOp, + BytecodeWriter *writer) { + if (!baseOp->getLoc().isa<UnknownLoc>()) { + sourceMap_.locations.push_back({writer->offset(), baseOp->getLoc()}); + } + + // Check registered writers first to allow overrides. + auto writerIt = customWriters_.find(baseOp->getName().getStringRef()); + if (writerIt != customWriters_.end()) { + return writerIt->second(baseOp, writer); + } + + // Fallback to using the generic writer. + if (baseOp->getAbstractOperation()->dialect.getNamespace().startswith( + "iree")) { + RETURN_IF_FAILURE(WriteGenericIreeOp(block, baseOp, writer)); + } else { + return baseOp->emitError() + << "Unsupported op " << baseOp->getName().getStringRef().str() + << "; incorrectly outlined or not yet implemented"; + } + return success(); +} + +::flatbuffers::Offset<iree::TypeDef> VMFunctionBuilder::SerializeType( + Type type, ::flatbuffers::FlatBufferBuilder *fbb) { + ::flatbuffers::Offset<void> typeDefUnion; + iree::TypeDefUnion typeUnionType; + if (auto memRefType = type.dyn_cast<MemRefType>()) { + auto memRefTypeOffset = SerializeMemRefType(memRefType, fbb_); + if (memRefTypeOffset.IsNull()) return {}; + typeDefUnion = memRefTypeOffset.Union(); + typeUnionType = iree::TypeDefUnion::MemRefTypeDef; + } else if (auto deviceType = type.dyn_cast<DeviceType>()) { + typeDefUnion = iree::CreateDeviceTypeDef(*fbb).Union(); + typeUnionType = iree::TypeDefUnion::DeviceTypeDef; + } else if (auto commandBufferType = type.dyn_cast<CommandBufferType>()) { + typeDefUnion = iree::CreateCommandBufferTypeDef(*fbb).Union(); + typeUnionType = iree::TypeDefUnion::CommandBufferTypeDef; + } else if (auto eventType = type.dyn_cast<EventType>()) { + typeDefUnion = iree::CreateEventTypeDef(*fbb).Union(); + typeUnionType = iree::TypeDefUnion::EventTypeDef; + } else if (auto semaphoreType = type.dyn_cast<SemaphoreType>()) { + typeDefUnion = iree::CreateSemaphoreTypeDef(*fbb).Union(); + typeUnionType = iree::TypeDefUnion::SemaphoreTypeDef; + } else if (auto fenceType = type.dyn_cast<FenceType>()) { + typeDefUnion = iree::CreateFenceTypeDef(*fbb).Union(); + typeUnionType = iree::TypeDefUnion::FenceTypeDef; + } else { + function_.emitError() << "Function " << function_.getName().str() + << " has unsupported I/O with type " << type; + return {}; + } + + iree::TypeDefBuilder tdb(*fbb); + tdb.add_type_union_type(typeUnionType); + tdb.add_type_union(typeDefUnion); + return tdb.Finish(); +} + +::flatbuffers::Offset<iree::MemRefTypeDef> +VMFunctionBuilder::SerializeMemRefType(const MemRefType &type, + ::flatbuffers::FlatBufferBuilder *fbb) { + auto elementTypeOffset = SerializeElementType(type.getElementType(), fbb); + if (elementTypeOffset.IsNull()) return {}; + std::vector<int> shape; + for (int dim : type.getShape()) { + shape.push_back(dim); + } + auto shapeOffset = fbb->CreateVector(shape); + iree::MemRefTypeDefBuilder tb(*fbb); + tb.add_element_type(elementTypeOffset); + tb.add_shape(shapeOffset); + tb.add_memory_space(type.getMemorySpace()); + return tb.Finish(); +} + +::flatbuffers::Offset<iree::ElementTypeDef> +VMFunctionBuilder::SerializeElementType(const Type &genericType, + ::flatbuffers::FlatBufferBuilder *fbb) { + ::flatbuffers::Offset<void> typeDefUnion; + iree::ElementTypeDefUnion typeUnionType; + if (auto type = genericType.dyn_cast<FloatType>()) { + iree::FloatTypeDefBuilder tb(*fbb); + tb.add_width(type.getWidth()); + typeDefUnion = tb.Finish().Union(); + typeUnionType = iree::ElementTypeDefUnion::FloatTypeDef; + } else if (auto type = genericType.dyn_cast<IntegerType>()) { + iree::IntegerTypeDefBuilder tb(*fbb); + tb.add_width(type.getWidth()); + typeDefUnion = tb.Finish().Union(); + typeUnionType = iree::ElementTypeDefUnion::IntegerTypeDef; + } else if (auto type = genericType.dyn_cast<OpaqueType>()) { + auto dialectOffset = fbb->CreateString(type.getDialectNamespace().c_str()); + auto typeDataOffset = fbb->CreateString(type.getTypeData().data()); + iree::UnknownTypeDefBuilder tb(*fbb); + tb.add_dialect(dialectOffset); + tb.add_type_data(typeDataOffset); + typeDefUnion = tb.Finish().Union(); + typeUnionType = iree::ElementTypeDefUnion::UnknownTypeDef; + } else { + function_.emitError() + << "Unimplemented type encoding: " << genericType + << "; ensure IREE lowering passes are converting types to the IREE " + "set"; + return {}; + } + + iree::ElementTypeDefBuilder tdb(*fbb); + tdb.add_type_union_type(typeUnionType); + tdb.add_type_union(typeDefUnion); + return tdb.Finish(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/VMFunctionBuilder.h b/compiler/Serialization/VMFunctionBuilder.h new file mode 100644 index 0000000..fcfefde --- /dev/null +++ b/compiler/Serialization/VMFunctionBuilder.h
@@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_ +#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_ + +#include "compiler/Serialization/BytecodeWriter.h" +#include "compiler/Serialization/VMFunctionTableBuilder.h" +#include "compiler/Serialization/VMSourceMapBuilder.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "schemas/bytecode_def_generated.h" +#include "schemas/function_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +class VMFunctionBuilder { + public: + using CustomWriterFn = + std::function<LogicalResult(Operation *, BytecodeWriter *writer)>; + + VMFunctionBuilder(FuncOp function, VMFunctionTableBuilder *functionTable, + ::flatbuffers::FlatBufferBuilder *fbb); + ~VMFunctionBuilder() = default; + + void RegisterCustomWriter(StringRef operationName, CustomWriterFn writerFn); + + const VMFunctionSourceMap &source_map() const { return sourceMap_; } + + LogicalResult ConvertBytecode(); + + ::flatbuffers::Offset<iree::FunctionDef> Finish(); + + ::flatbuffers::Offset<iree::TypeDef> SerializeType( + Type type, ::flatbuffers::FlatBufferBuilder *fbb); + ::flatbuffers::Offset<iree::MemRefTypeDef> SerializeMemRefType( + const MemRefType &genericType, ::flatbuffers::FlatBufferBuilder *fbb); + ::flatbuffers::Offset<iree::ElementTypeDef> SerializeElementType( + const Type &genericType, ::flatbuffers::FlatBufferBuilder *fbb); + + private: + LogicalResult BeginFunction(FuncOp function, BytecodeWriter *writer); + LogicalResult EndFunction(FuncOp function, BytecodeWriter *writer); + LogicalResult BeginBlock(Block *block, BytecodeWriter *writer); + LogicalResult EndBlock(Block *block, Operation *op, BytecodeWriter *writer); + + LogicalResult WriteOperation(Block *block, Operation *baseOp, + BytecodeWriter *writer); + + llvm::StringMap<CustomWriterFn> customWriters_; + + MLIRContext *context_; + FuncOp function_; + VMFunctionTableBuilder *functionTable_; + ::flatbuffers::FlatBufferBuilder *fbb_; + ::flatbuffers::Offset<iree::BytecodeDef> bytecodeDef_; + VMFunctionSourceMap sourceMap_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
diff --git a/compiler/Serialization/VMFunctionTableBuilder.cpp b/compiler/Serialization/VMFunctionTableBuilder.cpp new file mode 100644 index 0000000..f013b07 --- /dev/null +++ b/compiler/Serialization/VMFunctionTableBuilder.cpp
@@ -0,0 +1,87 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/VMFunctionTableBuilder.h" + +#include "compiler/Serialization/VMSourceMapBuilder.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace iree_compiler { + +VMFunctionTableBuilder::VMFunctionTableBuilder( + ::flatbuffers::FlatBufferBuilder *fbb) + : fbb_(fbb) {} + +bool VMFunctionTableBuilder::IsFunctionDeclared(FuncOp funcOp) { + return functionSet_.count(funcOp.getName()) != 0; +} + +LogicalResult VMFunctionTableBuilder::DeclareFunction(FuncOp funcOp, + LinkageType linkageType) { + if (functionSet_.count(funcOp.getName())) { + return funcOp.emitError() << "Function has already been declared/defined"; + } + auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal"); + if (!functionOrdinal) { + return funcOp.emitError() << "Ordinal not assigned to function"; + } + int ordinal = functionOrdinal.getInt(); + functionDefs_.resize( + std::max(functionDefs_.size(), static_cast<size_t>(ordinal) + 1u)); + functionSourceMaps_.resize( + std::max(functionDefs_.size(), static_cast<size_t>(ordinal) + 1u)); + functionSet_.insert({funcOp.getName()}); + switch (linkageType) { + case LinkageType::kInternal: + break; + case LinkageType::kImport: + importIndices_.push_back(ordinal); + break; + case LinkageType::kExport: + exportIndices_.push_back(ordinal); + break; + } + return success(); +} + +LogicalResult VMFunctionTableBuilder::DefineFunction( + FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef, + VMFunctionSourceMap functionSourceMap) { + auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal"); + if (!functionOrdinal) { + return funcOp.emitError() << "Ordinal not assigned to function"; + } + int ordinal = functionOrdinal.getInt(); + if (!functionDefs_[ordinal].IsNull()) { + return funcOp.emitOpError() << "Function has already been defined"; + } + functionDefs_[ordinal] = functionDef; + functionSourceMaps_[ordinal] = std::move(functionSourceMap); + return success(); +} + +::flatbuffers::Offset<iree::FunctionTableDef> VMFunctionTableBuilder::Finish() { + auto functionsOffset = fbb_->CreateVector(functionDefs_); + auto importsOffset = fbb_->CreateVector(importIndices_); + auto exportsOffset = fbb_->CreateVector(exportIndices_); + iree::FunctionTableDefBuilder ftdb(*fbb_); + ftdb.add_functions(functionsOffset); + ftdb.add_imports(importsOffset); + ftdb.add_exports(exportsOffset); + return ftdb.Finish(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/VMFunctionTableBuilder.h b/compiler/Serialization/VMFunctionTableBuilder.h new file mode 100644 index 0000000..ee3427e --- /dev/null +++ b/compiler/Serialization/VMFunctionTableBuilder.h
@@ -0,0 +1,75 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_ +#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_ + +#include <string> +#include <vector> + +#include "compiler/Serialization/VMSourceMapBuilder.h" +#include "flatbuffers/flatbuffers.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "schemas/function_def_generated.h" +#include "schemas/function_table_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +enum class LinkageType { + kInternal, + kImport, + kExport, +}; + +class VMFunctionTableBuilder { + public: + explicit VMFunctionTableBuilder(::flatbuffers::FlatBufferBuilder *fbb); + + int max_function_ordinal() const { return functionDefs_.size(); } + + ArrayRef<VMFunctionSourceMap> function_source_maps() { + return llvm::makeArrayRef(functionSourceMaps_); + } + + // Returns true if |funcOp| has already been declared in the table. + bool IsFunctionDeclared(FuncOp funcOp); + + // Declares |funcOp| with the given |linkageType|. + // Fails if the function has already been declared or defined. + LogicalResult DeclareFunction(FuncOp funcOp, LinkageType linkageType); + + // Defines |funcOp| using the given |functionDef|. + LogicalResult DefineFunction( + FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef, + VMFunctionSourceMap functionSourceMap); + + ::flatbuffers::Offset<iree::FunctionTableDef> Finish(); + + private: + ::flatbuffers::FlatBufferBuilder *fbb_; + llvm::StringSet<> functionSet_; + std::vector<::flatbuffers::Offset<iree::FunctionDef>> functionDefs_; + std::vector<VMFunctionSourceMap> functionSourceMaps_; + std::vector<int> importIndices_; + std::vector<int> exportIndices_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
diff --git a/compiler/Serialization/VMModuleBuilder.cpp b/compiler/Serialization/VMModuleBuilder.cpp new file mode 100644 index 0000000..0aa7d60 --- /dev/null +++ b/compiler/Serialization/VMModuleBuilder.cpp
@@ -0,0 +1,70 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/VMModuleBuilder.h" + +#include "schemas/executable_table_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +VMModuleBuilder::VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb) + : fbb_(fbb), + deviceTable_(fbb), + functionTable_(fbb), + executableTable_(fbb), + sourceMap_(fbb) {} + +::flatbuffers::Offset<iree::ModuleDef> VMModuleBuilder::Finish() { + auto nameOffset = fbb_->CreateString("module"); + auto deviceTableOffset = deviceTable_.Finish(); + if (deviceTableOffset.IsNull()) return {}; + auto functionTableOffset = functionTable_.Finish(); + if (functionTableOffset.IsNull()) return {}; + auto executableTableOffset = executableTable_.Finish(); + if (executableTableOffset.IsNull()) return {}; + + for (int function_ordinal = 0; + function_ordinal < functionTable_.function_source_maps().size(); + ++function_ordinal) { + if (failed(sourceMap_.AddFunction( + function_ordinal, + functionTable_.function_source_maps()[function_ordinal]))) { + return {}; + } + } + auto sourceMapOffset = + sourceMap_.Finish(functionTable_.max_function_ordinal()); + if (sourceMapOffset.IsNull()) return {}; + + iree::ModuleDefBuilder mdb(*fbb_); + mdb.add_name(nameOffset); + mdb.add_device_table(deviceTableOffset); + mdb.add_function_table(functionTableOffset); + mdb.add_executable_table(executableTableOffset); + mdb.add_source_map(sourceMapOffset); + return mdb.Finish(); +} + +std::vector<uint8_t> VMModuleBuilder::Serialize( + ::flatbuffers::Offset<iree::ModuleDef> module_def) { + FinishModuleDefBuffer(*fbb_, module_def); + std::vector<uint8_t> bytes; + bytes.resize(fbb_->GetSize()); + std::memcpy(bytes.data(), fbb_->GetBufferPointer(), bytes.size()); + return bytes; +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/VMModuleBuilder.h b/compiler/Serialization/VMModuleBuilder.h new file mode 100644 index 0000000..e9a9516 --- /dev/null +++ b/compiler/Serialization/VMModuleBuilder.h
@@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_ +#define IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_ + +#include <vector> + +#include "compiler/Serialization/VMDeviceTableBuilder.h" +#include "compiler/Serialization/VMExecutableTableBuilder.h" +#include "compiler/Serialization/VMFunctionTableBuilder.h" +#include "compiler/Serialization/VMSourceMapBuilder.h" +#include "flatbuffers/flatbuffers.h" +#include "schemas/module_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +class VMModuleBuilder { + public: + explicit VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb); + + ::flatbuffers::FlatBufferBuilder *fbb() const { return fbb_; } + VMDeviceTableBuilder *device_table() { return &deviceTable_; } + VMFunctionTableBuilder *function_table() { return &functionTable_; } + VMExecutableTableBuilder *executable_table() { return &executableTable_; } + VMSourceMapBuilder *source_map() { return &sourceMap_; } + + ::flatbuffers::Offset<iree::ModuleDef> Finish(); + + std::vector<uint8_t> Serialize( + ::flatbuffers::Offset<iree::ModuleDef> module_def); + + private: + ::flatbuffers::FlatBufferBuilder *fbb_; + + VMDeviceTableBuilder deviceTable_; + VMFunctionTableBuilder functionTable_; + VMExecutableTableBuilder executableTable_; + VMSourceMapBuilder sourceMap_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
diff --git a/compiler/Serialization/VMSourceMapBuilder.cpp b/compiler/Serialization/VMSourceMapBuilder.cpp new file mode 100644 index 0000000..a534540 --- /dev/null +++ b/compiler/Serialization/VMSourceMapBuilder.cpp
@@ -0,0 +1,164 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Serialization/VMSourceMapBuilder.h" + +#include "flatbuffers/flatbuffers.h" +#include "schemas/source_map_def_generated.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/Location.h" + +namespace mlir { +namespace iree_compiler { + +VMSourceMapBuilder::VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb) + : fbb_(fbb) {} + +int VMSourceMapBuilder::GetUniqueString(std::string value) { + auto it = stringTableMap_.find(value); + if (it != stringTableMap_.end()) { + return it->second; + } + int stringIndex = stringTable_.size(); + stringTableMap_.insert({value, stringIndex}); + stringTable_.push_back(std::move(value)); + return stringIndex; +} + +LogicalResult VMSourceMapBuilder::AddFunction( + int functionOrdinal, VMFunctionSourceMap functionSourceMap) { + if (functionMaps_.size() <= functionOrdinal) { + functionMaps_.resize(functionOrdinal + 1); + } + functionMaps_[functionOrdinal] = std::move(functionSourceMap); + return success(); +} + +::flatbuffers::Offset<iree::SourceMapDef> VMSourceMapBuilder::Finish( + int maxFunctionOrdinal) { + // NOTE: we always ensure the source map table is the same size as the + // function table so that lookups at runtime can be validated once at load + // time (ensuring the tables match up) instead of on each lookup. + if (maxFunctionOrdinal < functionMaps_.size()) { + llvm::errs() << "Max function ordinal defined as " << maxFunctionOrdinal + << " but there are " << functionMaps_.size() + << " function source maps present"; + return {}; + } + functionMaps_.resize(maxFunctionOrdinal); + + std::vector<::flatbuffers::Offset<iree::FunctionSourceMapDef>> functionDefs; + functionDefs.resize(maxFunctionOrdinal); + for (int i = 0; i < functionMaps_.size(); ++i) { + const auto &functionMap = functionMaps_[i]; + functionDefs[i] = SerializeVMFunctionSourceMap(functionMap); + if (functionDefs[i].IsNull()) return {}; + } + + auto functionTableOffset = fbb_->CreateVector(functionDefs); + auto stringTableOffset = fbb_->CreateVectorOfStrings(stringTable_); + iree::SourceMapDefBuilder smdb(*fbb_); + smdb.add_function_table(functionTableOffset); + smdb.add_string_table(stringTableOffset); + return smdb.Finish(); +} + +::flatbuffers::Offset<iree::FunctionSourceMapDef> +VMSourceMapBuilder::SerializeVMFunctionSourceMap( + const VMFunctionSourceMap &functionMap) { + if (functionMap.locations.empty()) { + // Empty table. This ensures that we still have a non-null value in the + // function table but doesn't waste much space. + iree::FunctionSourceMapDefBuilder fsmdb(*fbb_); + return fsmdb.Finish(); + } + + LocationOffsetTable locationOffsetTable; + std::vector<iree::BytecodeSourceLocation> bytecodeMap; + for (const auto &offset_location : functionMap.locations) { + int locationIndex = + SerializeLocation(offset_location.second, &locationOffsetTable); + bytecodeMap.push_back({offset_location.first, locationIndex}); + } + auto locationTableOffset = + fbb_->CreateVector(locationOffsetTable.locationDefs); + auto bytecodeMapOffset = fbb_->CreateVectorOfStructs(bytecodeMap); + + iree::FunctionSourceMapDefBuilder fsmdb(*fbb_); + fsmdb.add_location_table(locationTableOffset); + fsmdb.add_bytecode_map(bytecodeMapOffset); + return fsmdb.Finish(); +} + +int VMSourceMapBuilder::SerializeLocation( + const Location &location, LocationOffsetTable *locationOffsetTable) { + auto existingIt = locationOffsetTable->locationMap.find(location); + if (existingIt != locationOffsetTable->locationMap.end()) { + return existingIt->getSecond(); + } + + iree::LocationDefUnion locationUnionType; + ::flatbuffers::Offset<void> locationUnionOffset; + if (auto fileLoc = location.dyn_cast<FileLineColLoc>()) { + locationUnionType = iree::LocationDefUnion::FileLocationDef; + int filenameIndex = GetUniqueString(fileLoc.getFilename().str()); + iree::FileLocationDefBuilder lb(*fbb_); + lb.add_filename(filenameIndex); + lb.add_line(fileLoc.getLine()); + lb.add_column(fileLoc.getColumn()); + locationUnionOffset = lb.Finish().Union(); + } else if (auto nameLoc = location.dyn_cast<NameLoc>()) { + locationUnionType = iree::LocationDefUnion::NameLocationDef; + int nameIndex = GetUniqueString(nameLoc.getName().str()); + iree::NameLocationDefBuilder lb(*fbb_); + lb.add_name(nameIndex); + locationUnionOffset = lb.Finish().Union(); + } else if (auto callSiteLoc = location.dyn_cast<CallSiteLoc>()) { + locationUnionType = iree::LocationDefUnion::CallSiteLocationDef; + int calleeIndex = + SerializeLocation(callSiteLoc.getCallee(), locationOffsetTable); + int callerIndex = + SerializeLocation(callSiteLoc.getCaller(), locationOffsetTable); + iree::CallSiteLocationDefBuilder lb(*fbb_); + lb.add_callee_location(calleeIndex); + lb.add_caller_location(callerIndex); + locationUnionOffset = lb.Finish().Union(); + } else if (auto fusedLoc = location.dyn_cast<FusedLoc>()) { + locationUnionType = iree::LocationDefUnion::FusedLocationDef; + std::vector<int> locationIndices; + locationIndices.reserve(fusedLoc.getLocations().size()); + for (const auto &child_loc : fusedLoc.getLocations()) { + int child_index = SerializeLocation(child_loc, locationOffsetTable); + locationIndices.push_back(child_index); + } + auto locationIndicesOffset = fbb_->CreateVector(locationIndices); + iree::FusedLocationDefBuilder lb(*fbb_); + lb.add_locations(locationIndicesOffset); + locationUnionOffset = lb.Finish().Union(); + } else { + llvm_unreachable("Unimplemented location kind"); + } + + iree::LocationDefBuilder ldb(*fbb_); + ldb.add_location_union_type(locationUnionType); + ldb.add_location_union(locationUnionOffset); + int locationIndex = locationOffsetTable->locationDefs.size(); + locationOffsetTable->locationDefs.push_back(ldb.Finish()); + locationOffsetTable->locationMap.insert({location, locationIndex}); + return locationIndex; +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Serialization/VMSourceMapBuilder.h b/compiler/Serialization/VMSourceMapBuilder.h new file mode 100644 index 0000000..13faf71 --- /dev/null +++ b/compiler/Serialization/VMSourceMapBuilder.h
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_ +#define IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_ + +#include <vector> + +#include "flatbuffers/flatbuffers.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "schemas/source_map_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +struct VMFunctionSourceMap { + std::vector<std::pair<int, Location>> locations; +}; + +class VMSourceMapBuilder { + public: + explicit VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb); + + LogicalResult AddFunction(int functionOrdinal, + VMFunctionSourceMap functionSourceMap); + + ::flatbuffers::Offset<iree::SourceMapDef> Finish(int maxFunctionOrdinal); + + private: + struct LocationOffsetTable { + std::vector<::flatbuffers::Offset<iree::LocationDef>> locationDefs; + llvm::DenseMap<Location, int> locationMap; + }; + + int GetUniqueString(std::string value); + + ::flatbuffers::Offset<iree::FunctionSourceMapDef> + SerializeVMFunctionSourceMap(const VMFunctionSourceMap &functionMap); + int SerializeLocation(const Location &location, + LocationOffsetTable *locationOffsetTable); + + ::flatbuffers::FlatBufferBuilder *fbb_; + std::vector<std::string> stringTable_; + llvm::StringMap<int> stringTableMap_; + std::vector<VMFunctionSourceMap> functionMaps_; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
diff --git a/compiler/Transforms/AggressiveOpElimination.cpp b/compiler/Transforms/AggressiveOpElimination.cpp new file mode 100644 index 0000000..9dd1781 --- /dev/null +++ b/compiler/Transforms/AggressiveOpElimination.cpp
@@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <deque> +#include <memory> + +#include "compiler/IR/Interpreter/HLOps.h" +#include "compiler/IR/Interpreter/LLOps.h" +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/Sequencer/LLOps.h" +#include "compiler/IR/StructureOps.h" +#include "mlir/Analysis/Dominance.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +template <typename T> +struct EraseUnused : public OpRewritePattern<T> { + using OpRewritePattern<T>::OpRewritePattern; + PatternMatchResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + if (op.use_empty()) { + rewriter.eraseOp(op); + return this->matchSuccess(); + } + return this->matchFailure(); + } +}; + +void populateAggressiveOpEliminationPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<EraseUnused<LoadOp>, EraseUnused<AllocOp>, + EraseUnused<IREESeq::HL::AllocHeapOp>, + EraseUnused<IREESeq::LL::AllocHeapOp>, + EraseUnused<IREEInterp::HL::AllocHeapOp>, + EraseUnused<IREEInterp::LL::AllocHeapOp>>(ctx); +} + +} // namespace + +// TODO(b/142012496) Make these be handled by normal DCE. +class AggressiveOpEliminationPass + : public FunctionPass<AggressiveOpEliminationPass> { + public: + void runOnFunction() override { + OwningRewritePatternList patterns; + populateAggressiveOpEliminationPatterns(patterns, &getContext()); + + applyPatternsGreedily(getFunction(), patterns); + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createAggressiveOpEliminationPass() { + return std::make_unique<AggressiveOpEliminationPass>(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Transforms/AssignFunctionOrdinals.cpp b/compiler/Transforms/AssignFunctionOrdinals.cpp similarity index 100% rename from iree/compiler/Transforms/AssignFunctionOrdinals.cpp rename to compiler/Transforms/AssignFunctionOrdinals.cpp
diff --git a/compiler/Transforms/BUILD b/compiler/Transforms/BUILD new file mode 100644 index 0000000..9590416 --- /dev/null +++ b/compiler/Transforms/BUILD
@@ -0,0 +1,41 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Transforms", + srcs = [ + "AggressiveOpElimination.cpp", + "AssignFunctionOrdinals.cpp", + "ConvertFromTupleCallingConvention.cpp", + "ConvertToMemRefCallingConvention.cpp", + "DropUnreachableFunctions.cpp", + "DropUnusedExecutables.cpp", + "LegalizeTypeStorage.cpp", + "LowerStdToIreeDialect.cpp", + "LowerXLAToIreeDialect.cpp", + ], + hdrs = [ + "ConversionUtils.h", + "Passes.h", + "Rewrites.h", + ], + deps = [ + "///compiler/IR", + "///compiler/IR/Interpreter", + "///compiler/IR/Sequencer", + "///compiler/Utils", + "@llvm//:support", + "@local_config_mlir//:Analysis", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Transforms/CMakeLists.txt b/compiler/Transforms/CMakeLists.txt similarity index 100% rename from iree/compiler/Transforms/CMakeLists.txt rename to compiler/Transforms/CMakeLists.txt
diff --git a/compiler/Transforms/ConversionUtils.h b/compiler/Transforms/ConversionUtils.h new file mode 100644 index 0000000..01f484c --- /dev/null +++ b/compiler/Transforms/ConversionUtils.h
@@ -0,0 +1,101 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_ +#define IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_ + +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +template <typename SrcOp, typename DstOp> +struct UnaryOpLowering : public OpConversionPattern<SrcOp> { + using OpConversionPattern<SrcOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + SrcOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *value = loadAccessValue(op.getLoc(), operands[0], rewriter); + value = wrapAsMemRef(value, op, rewriter); + + auto dstType = convertTypeToMemRef(op.getResult()); + auto dstOp = rewriter.create<DstOp>(op.getLoc(), dstType, value); + auto result = dstOp.getResult(); + result = wrapAsTensor(result, op, rewriter); + + rewriter.replaceOp( + op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); + return this->matchSuccess(); + } +}; + +template <typename SrcOp, typename DstOp> +struct BinaryOpLowering : public OpConversionPattern<SrcOp> { + using OpConversionPattern<SrcOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + SrcOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *lhsValue = loadAccessValue(op.getLoc(), operands[0], rewriter); + auto *rhsValue = loadAccessValue(op.getLoc(), operands[1], rewriter); + auto dstType = convertTypeToMemRef(op.getResult()); + + lhsValue = wrapAsMemRef(lhsValue, op, rewriter); + rhsValue = wrapAsMemRef(rhsValue, op, rewriter); + + auto midOp = + rewriter.create<DstOp>(op.getLoc(), dstType, lhsValue, rhsValue); + auto result = midOp.getResult(); + result = wrapAsTensor(result, op, rewriter); + + rewriter.replaceOp( + op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); + return this->matchSuccess(); + } +}; + +template <typename SrcOp, typename DstOp> +struct TernaryOpLowering : public OpConversionPattern<SrcOp> { + using OpConversionPattern<SrcOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + SrcOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *aValue = loadAccessValue(op.getLoc(), operands[0], rewriter); + auto *bValue = loadAccessValue(op.getLoc(), operands[1], rewriter); + auto *cValue = loadAccessValue(op.getLoc(), operands[2], rewriter); + + aValue = wrapAsMemRef(aValue, op, rewriter); + bValue = wrapAsMemRef(bValue, op, rewriter); + cValue = wrapAsMemRef(cValue, op, rewriter); + + auto dstType = convertTypeToMemRef(op.getResult()); + auto dstOp = + rewriter.create<DstOp>(op.getLoc(), dstType, aValue, bValue, cValue); + auto result = dstOp.getResult(); + result = wrapAsTensor(result, op, rewriter); + + rewriter.replaceOp( + op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); + return this->matchSuccess(); + } +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
diff --git a/iree/compiler/Transforms/ConvertFromTupleCallingConvention.cpp b/compiler/Transforms/ConvertFromTupleCallingConvention.cpp similarity index 100% rename from iree/compiler/Transforms/ConvertFromTupleCallingConvention.cpp rename to compiler/Transforms/ConvertFromTupleCallingConvention.cpp
diff --git a/compiler/Transforms/ConvertToMemRefCallingConvention.cpp b/compiler/Transforms/ConvertToMemRefCallingConvention.cpp new file mode 100644 index 0000000..1b65ede --- /dev/null +++ b/compiler/Transforms/ConvertToMemRefCallingConvention.cpp
@@ -0,0 +1,398 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Ops.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +void copyOperationAttrs(Operation *oldOp, Operation *newOp) { + for (const auto &oldAttr : oldOp->getAttrs()) { + newOp->setAttr(oldAttr.first, oldAttr.second); + } +} + +FunctionType getMemRefFunctionType(FunctionType type) { + Builder builder(type.getContext()); + llvm::SmallVector<Type, 8> replacementInputs; + for (auto type : type.getInputs()) { + auto memRefType = convertTypeToMemRef(type); + if (!memRefType) { + return nullptr; + } + replacementInputs.push_back(memRefType); + } + llvm::SmallVector<Type, 8> replacementResults; + for (auto type : type.getResults()) { + auto memRefType = convertTypeToMemRef(type); + if (!memRefType) { + return nullptr; + } + replacementResults.push_back(memRefType); + } + return builder.getFunctionType(replacementInputs, replacementResults); +} + +bool insertLoad(BlockArgument *oldArg, BlockArgument *newArg, + OpBuilder &builder, BlockAndValueMapping *mapping) { + auto loc = oldArg->getOwner()->getParent()->getLoc(); + + // If old arg was a memref we don't need to change anything. We still need + // to remap so that the use lists match through conversion, though. + if (oldArg->getType().isa<MemRefType>()) { + mapping->map(oldArg, newArg); + return false; + } else if (oldArg->getType().isa<TensorType>()) { + auto castOp = builder.create<IREE::MemRefToTensorOp>(loc, newArg); + mapping->map(oldArg, castOp.getResult()); + return false; + } + + // Insert the load we'll use to unbox the value. + auto loadedValue = builder.create<LoadOp>(loc, newArg, ArrayRef<Value *>{}); + mapping->map(oldArg, loadedValue); + + return false; +} + +bool insertLoad(Operation *oldOp, Value *oldValue, Value *newValue, + OpBuilder &builder, BlockAndValueMapping *mapping) { + // If old value was a memref we don't need to change anything. + if (oldValue->getType().isa<MemRefType>()) { + mapping->map(oldValue, newValue); + return false; + } else if (oldValue->getType().isa<TensorType>()) { + auto castOp = + builder.create<IREE::MemRefToTensorOp>(oldOp->getLoc(), newValue); + mapping->map(oldValue, castOp.getResult()); + return false; + } + + assert(newValue->getType().isa<MemRefType>()); + + // Insert the load we'll use to unbox the value. + auto loadedValue = + builder.create<LoadOp>(oldOp->getLoc(), newValue, ArrayRef<Value *>{}); + mapping->map(oldValue, loadedValue); + + return false; +} + +Value *insertStore(Operation *oldOp, Value *oldValue, OpBuilder &builder, + BlockAndValueMapping *mapping) { + auto *newValue = mapping->lookupOrNull(oldValue); + if (!newValue) { + return nullptr; + } + + // If the previous value was already a memref we don't need to change + // anything. + // TODO(benvanik): ensure indices make sense. + if (oldValue->getType().isa<MemRefType>()) { + return newValue; + } else if (oldValue->getType().isa<TensorType>()) { + auto castOp = + builder.create<IREE::TensorToMemRefOp>(oldOp->getLoc(), newValue); + return castOp.getResult(); + } + + // Look back up and see if we can find the memref the value was loaded from. + if (auto *sourceMemRef = resolveValueToSourceMemRef(oldValue, oldOp)) { + return mapping->lookupOrNull(sourceMemRef); + } + + // Allocate the memref to store the value. + auto newStorage = builder.create<AllocOp>( + oldOp->getLoc(), convertTypeToMemRef(oldValue->getType())); + + // Insert the store we'll use to box the value. + builder.create<StoreOp>(oldOp->getLoc(), newValue, newStorage, + ArrayRef<Value *>{}); + + return newStorage; +} + +bool convertCallOp(CallOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> newArgs; + for (auto *oldArg : oldOp->getOperands()) { + auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); + if (!newArg) { + return true; + } + newArgs.push_back(newArg); + } + + SmallVector<Type, 4> resultTypes; + for (auto oldType : oldOp->getOperation()->getResultTypes()) { + resultTypes.push_back(convertTypeToMemRef(oldType)); + } + auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(), + resultTypes, newArgs); + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + for (int i = 0; i < newOp.getNumResults(); ++i) { + auto *oldResult = oldOp->getResult(i); + auto *newResult = newOp.getResult(i); + if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder, + mapping)) { + return true; + } + } + + return false; +} + +bool convertCallIndirectOp(CallIndirectOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + // TODO(benvanik): support wrapping callee values. + oldOp->emitError("CallIndirectOp not yet supported"); + return true; +#if 0 + llvm::SmallVector<Value *, 4> newArgs; + for (auto *oldArg : oldOp->getArgOperands()) { + auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); + if (!newArg) { + return true; + } + newArgs.push_back(newArg); + } + + auto newOp = builder.create<CallIndirectOp>(oldOp->getLoc(), + oldOp->getCallee(), newArgs); + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + for (int i = 0; i < newOp.getNumResults(); ++i) { + auto *oldResult = oldOp->getResult(i); + auto *newResult = newOp.getResult(i); + if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder, + mapping)) { + return true; + } + } + + return false; +#endif // 0 +} + +bool convertReturnOp(Operation *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + BlockAndValueMapping returnMapping; + for (auto *oldArg : oldOp->getOperands()) { + auto *newArg = insertStore(oldOp, oldArg, builder, mapping); + if (!newArg) { + return true; + } + returnMapping.map(oldArg, newArg); + } + + builder.clone(*oldOp, returnMapping); + return false; +} + +bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> newArgs; + for (auto *oldArg : oldOp->getOperands()) { + auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); + if (!newArg) { + return true; + } + newArgs.push_back(newArg); + } + + auto *dest = mapping->lookupOrNull(oldOp->getDest()); + if (!dest) { + oldOp->emitError("Destination block mapping not found"); + return true; + } + + auto newOp = builder.create<BranchOp>(oldOp->getLoc(), dest, newArgs); + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + return false; +} + +bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + llvm::SmallVector<Value *, 4> trueArgs; + for (auto *oldArg : oldOp->getTrueOperands()) { + auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); + if (!newArg) { + return true; + } + trueArgs.push_back(newArg); + } + llvm::SmallVector<Value *, 4> falseArgs; + for (auto *oldArg : oldOp->getFalseOperands()) { + auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); + if (!newArg) { + return true; + } + falseArgs.push_back(newArg); + } + + auto *trueDest = mapping->lookupOrNull(oldOp->getTrueDest()); + if (!trueDest) { + oldOp->emitError("True destination block mapping not found"); + return true; + } + auto *falseDest = mapping->lookupOrNull(oldOp->getFalseDest()); + if (!falseDest) { + oldOp->emitError("False destination block mapping not found"); + return true; + } + + // Lowering will take care of the condition store. + auto *newCondition = mapping->lookupOrNull(oldOp->getCondition()); + if (!newCondition) { + oldOp->emitError("Condition value mapping not found"); + return false; + } + + auto newOp = builder.create<CondBranchOp>( + oldOp->getLoc(), newCondition, trueDest, trueArgs, falseDest, falseArgs); + copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); + + return false; +} + +bool convertOperation(Operation *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + if (isa<ConstantOp>(oldOp)) { + builder.clone(*oldOp, *mapping); + return false; + } else if (auto callOp = dyn_cast<CallOp>(oldOp)) { + return convertCallOp(&callOp, builder, mapping); + } else if (auto callIndirectOp = dyn_cast<CallIndirectOp>(oldOp)) { + return convertCallIndirectOp(&callIndirectOp, builder, mapping); + } else if (isa<ReturnOp>(oldOp) || isa<IREE::ReturnOp>(oldOp)) { + return convertReturnOp(oldOp, builder, mapping); + } else if (auto branchOp = dyn_cast<BranchOp>(oldOp)) { + return convertBranchOp(&branchOp, builder, mapping); + } else if (auto condBranchOp = dyn_cast<CondBranchOp>(oldOp)) { + return convertCondBranchOp(&condBranchOp, builder, mapping); + } else { + builder.clone(*oldOp, *mapping); + return false; + } +} + +bool convertFunction(FuncOp oldFunc, FuncOp newFunc) { + OpBuilder builder(newFunc.getBody()); + BlockAndValueMapping mapping; + + // Create new blocks matching the expected arguments of the old ones. + // This sets up the block mappings to enable us to reference blocks forward + // during conversion. + newFunc.getBlocks().clear(); + for (auto &oldBlock : oldFunc.getBlocks()) { + auto *newBlock = builder.createBlock(&newFunc.getBody()); + for (auto *oldArg : oldBlock.getArguments()) { + // Replace the block args with memrefs. + auto memRefType = convertTypeToMemRef(oldArg->getType()); + if (!memRefType) return true; + auto *newArg = newBlock->addArgument(memRefType); + + // Insert loads to preserve type, if needed. + // This will replace all uses of the oldArg with the loaded value from + // newArg so that the block contents are still using unwrapped values. + if (insertLoad(oldArg, newArg, builder, &mapping)) { + return true; + } + } + mapping.map(&oldBlock, newBlock); + } + + // Convert all ops in the blocks. + for (auto &oldBlock : oldFunc.getBlocks()) { + builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock)); + for (auto &oldOp : oldBlock.getOperations()) { + if (convertOperation(&oldOp, builder, &mapping)) { + return true; + } + } + } + + return false; +} + +} // namespace + +class ConvertToMemRefCallingConventionPass + : public ModulePass<ConvertToMemRefCallingConventionPass> { + public: + void runOnModule() override { + auto module = getModule(); + + // Build a list of (oldFunc, newFunc) for all functions we need to + // replace. This will ensure that when we go to convert function bodies we + // have only new functions defined. + std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions; + + for (auto oldFunc : module.getOps<FuncOp>()) { + // Create the replacement function, ensuring that we copy attributes. + auto functionType = getMemRefFunctionType(oldFunc.getType()); + if (!functionType) { + return signalPassFailure(); + } + + auto newFunc = FuncOp::create(oldFunc.getLoc(), oldFunc.getName(), + functionType, oldFunc.getDialectAttrs()); + convertedFunctions.push_back({oldFunc, newFunc}); + + // Perform the actual body conversion now. + if (convertFunction(oldFunc, newFunc)) { + return signalPassFailure(); + } + } + + // Replace functions in the module. + for (auto &pair : convertedFunctions) { + pair.first.erase(); + module.push_back(pair.second); + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> +createConvertToMemRefCallingConventionPass() { + return std::make_unique<ConvertToMemRefCallingConventionPass>(); +} + +static PassRegistration<ConvertToMemRefCallingConventionPass> pass( + "convert-to-memref-calling-convention", + "Convert functions to use a memref-based calling convention."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/DropUnreachableFunctions.cpp b/compiler/Transforms/DropUnreachableFunctions.cpp new file mode 100644 index 0000000..f334dba --- /dev/null +++ b/compiler/Transforms/DropUnreachableFunctions.cpp
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/ModuleUtils.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +// Drops all functions in a module that are not reachable by functions with the +// "iree.module.export" attribute. +class DropUnreachableModuleFunctionsPass + : public ModulePass<DropUnreachableModuleFunctionsPass> { + public: + void runOnModule() override { + dropUnusedFunctions(getModule(), {"iree.module.export"}); + } +}; + +// Drops all functions in a module that are not reachable by functions with the +// "iree.executable.export" attribute. +class DropUnreachableExecutableFunctionsPass + : public ModulePass<DropUnreachableExecutableFunctionsPass> { + public: + void runOnModule() override { + dropUnusedFunctions(getModule(), {"iree.executable.export"}); + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> +createDropUnreachableModuleFunctionsPass() { + return std::make_unique<DropUnreachableModuleFunctionsPass>(); +} + +std::unique_ptr<OpPassBase<ModuleOp>> +createDropUnreachableExecutableFunctionsPass() { + return std::make_unique<DropUnreachableExecutableFunctionsPass>(); +} + +static PassRegistration<DropUnreachableModuleFunctionsPass> moduleFunctionsPass( + "iree-drop-unreachable-module-functions", + "Drop all functions not reachable from an exported function"); + +static PassRegistration<DropUnreachableExecutableFunctionsPass> + executableFunctionsPass( + "iree-drop-unreachable-executable-functions", + "Drop all functions not reachable from an exported function"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/DropUnusedExecutables.cpp b/compiler/Transforms/DropUnusedExecutables.cpp new file mode 100644 index 0000000..73fa365 --- /dev/null +++ b/compiler/Transforms/DropUnusedExecutables.cpp
@@ -0,0 +1,62 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/StructureOps.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +// Drops all executables in a module that are not used by any dispatch +// sequencer op. +class DropUnusedExecutablesPass : public ModulePass<DropUnusedExecutablesPass> { + public: + void runOnModule() override { + DenseSet<StringRef> usedExecutableNames; + for (auto funcOp : getModule().getOps<FuncOp>()) { + funcOp.walk([&](IREESeq::HL::DispatchOp op) { + usedExecutableNames.insert(op.getExecutable()); + }); + } + DenseSet<Operation *> deadExecutables; + for (auto executableOp : + getModule().getOps<IREE::MultiArchExecutableOp>()) { + if (usedExecutableNames.count(executableOp.getName()) == 0) { + deadExecutables.insert(executableOp); + } + } + for (auto executableOp : deadExecutables) { + executableOp->erase(); + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createDropUnusedExecutablesPass() { + return std::make_unique<DropUnusedExecutablesPass>(); // NOLINT +} + +static PassRegistration<DropUnusedExecutablesPass> executableFunctionsPass( + "iree-drop-unused-executables", + "Drop all executables not reachable from a dispatch/reduce op."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/BUILD b/compiler/Transforms/Interpreter/BUILD new file mode 100644 index 0000000..9405f06 --- /dev/null +++ b/compiler/Transforms/Interpreter/BUILD
@@ -0,0 +1,41 @@ +# Transforms specific to the IREE interpreter. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Interpreter", + srcs = [ + "ExpandReductionsToOps.cpp", + "LowerInterpreterDialect.cpp", + "LowerStdToInterpreterDialect.cpp", + "LowerToInterpreterDialect.cpp", + "LowerXLAToInterpreterDialect.cpp", + "MakeExecutableABI.cpp", + ], + hdrs = [ + "Passes.h", + "Rewrites.h", + ], + deps = [ + "///compiler/IR", + "///compiler/IR/Interpreter", + "///compiler/Serialization", + "///compiler/Transforms", + "///compiler/Utils", + "///schemas/bytecode:interpreter_bytecode_v0", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Transforms/Interpreter/CMakeLists.txt b/compiler/Transforms/Interpreter/CMakeLists.txt similarity index 100% rename from iree/compiler/Transforms/Interpreter/CMakeLists.txt rename to compiler/Transforms/Interpreter/CMakeLists.txt
diff --git a/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp b/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp new file mode 100644 index 0000000..c8f80e6 --- /dev/null +++ b/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp
@@ -0,0 +1,216 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Interpreter/HLDialect.h" +#include "compiler/IR/Interpreter/HLOps.h" +#include "compiler/IR/Ops.h" +#include "compiler/Transforms/ConversionUtils.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +LogicalResult convertReductionOp(FuncOp entryPoint, FuncOp applyFunc, + Operation *elementOp, OpBuilder &builder) { + // Ensure that this op is pass-through and does not interact with any other + // ops within the function. + // TODO(b/139313439): support fused reductions. + for (auto *operand : elementOp->getOperands()) { + if (operand->getDefiningOp() != nullptr) { + return elementOp->emitOpError() + << "Fused reductions are not supported (operand not sourced from " + "block args)"; + } + } + for (auto *result : elementOp->getResults()) { + for (auto *user : result->getUsers()) { + if (!user->isKnownTerminator()) { + return elementOp->emitOpError() << "Fused reductions are not supported " + "(result used by non-terminator)"; + } + } + } + + // Determine the index of the args we care about. We'll use these to match up + // the operands of the entry point with our application. + // Our arguments are expanded tuples like <lhs0, rhs0>, <lhs1, rhs1>, so this + // index gets the offset * 2. + auto &applyEntryBlock = applyFunc.getBlocks().front(); + int setIndex = std::distance(applyEntryBlock.args_begin(), + llvm::find(applyEntryBlock.getArguments(), + elementOp->getOperand(0))) / + 2; + + // Map to the args from the entry point. + auto &entryPointEntryBlock = entryPoint.getBlocks().front(); + Value *srcArg = entryPointEntryBlock.getArgument(setIndex); + Value *initArg = entryPointEntryBlock.getArgument( + applyFunc.getNumArguments() / 2 + setIndex); + Value *dstArg = + entryPointEntryBlock.getArgument(applyFunc.getNumArguments() + setIndex); + auto dstType = dstArg->getType().cast<ShapedType>(); + Type elementType = dstType.getElementType(); + auto loc = elementOp->getLoc(); + auto dimensionAttr = entryPoint.getAttrOfType<IntegerAttr>( + "iree.executable.reduction.dimension"); + + Operation *expandedOp = nullptr; + if (isa<IREEInterp::HL::AddFOp>(elementOp) || + isa<IREEInterp::HL::AddIOp>(elementOp)) { + if (elementType.isa<FloatType>()) { + expandedOp = builder.create<IREEInterp::HL::ReduceSumFOp>( + loc, dstType, srcArg, initArg, dimensionAttr); + } else { + expandedOp = builder.create<IREEInterp::HL::ReduceSumIOp>( + loc, dstType, srcArg, initArg, dimensionAttr); + } + } else if (isa<IREEInterp::HL::MinFOp>(elementOp) || + isa<IREEInterp::HL::MinISOp>(elementOp) || + isa<IREEInterp::HL::MinIUOp>(elementOp)) { + if (elementType.isa<FloatType>()) { + expandedOp = builder.create<IREEInterp::HL::ReduceMinFOp>( + loc, dstType, srcArg, initArg, dimensionAttr); + } else { + expandedOp = builder.create<IREEInterp::HL::ReduceMinIOp>( + loc, dstType, srcArg, initArg, dimensionAttr); + } + } else if (isa<IREEInterp::HL::MaxFOp>(elementOp) || + isa<IREEInterp::HL::MaxISOp>(elementOp) || + isa<IREEInterp::HL::MaxIUOp>(elementOp)) { + if (elementType.isa<FloatType>()) { + expandedOp = builder.create<IREEInterp::HL::ReduceMaxFOp>( + loc, dstType, srcArg, initArg, dimensionAttr); + } else { + expandedOp = builder.create<IREEInterp::HL::ReduceMaxIOp>( + loc, dstType, srcArg, initArg, dimensionAttr); + } + } + if (!expandedOp) { + return elementOp->emitOpError() + << "No matching expanded reduction op for elemental op"; + } + llvm::SmallVector<int64_t, 4> zeroOffset(dstType.getRank(), 0); + auto zeroIndices = createArrayConstant(builder, loc, zeroOffset); + auto lengths = createArrayConstant(builder, loc, dstType.getShape()); + builder.create<IREEInterp::HL::CopyOp>( + loc, expandedOp->getResult(0), zeroIndices, dstArg, zeroIndices, lengths); + + return success(); +} + +// Replaces the given elemental |funcOp| with a widened reduction. +LogicalResult expandReductionFunction(FuncOp entryFunc) { + if (!entryFunc.empty()) { + return entryFunc.emitError() + << "Function has already been expanded or has existing contents"; + } else if (!entryFunc.getAttr("iree.executable.reduction.dimension")) { + return entryFunc.emitError() << "Windowed reductions are not yet supported"; + } + auto applySym = + entryFunc.getAttrOfType<SymbolRefAttr>("iree.executable.reduction.apply"); + if (!applySym) { + return entryFunc.emitError() << "No reduction application function defined"; + } + auto applyFunc = entryFunc.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>( + applySym.getValue()); + if (!applyFunc) { + return entryFunc.emitError() + << "Unable to find apply function " << applySym; + } + + auto *entryBlock = entryFunc.addEntryBlock(); + OpBuilder builder(entryBlock); + + if (applyFunc.getBlocks() + .front() + .walk([&](Operation *op) { + if (!op->isKnownTerminator()) { + if (failed( + convertReductionOp(entryFunc, applyFunc, op, builder))) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }) + .wasInterrupted()) { + return applyFunc.emitError() << "Unable to convert apply func"; + } + + builder.create<IREE::ReturnOp>(builder.getUnknownLoc()); + + // Remove the apply function as we have inlined it. + applyFunc.erase(); + entryFunc.removeAttr("iree.executable.reduction.apply"); + entryFunc.removeAttr("iree.executable.reduction.dimension"); + + return success(); +} + +// Limited lowering of reductions to fat reduce_* ops. +// +// The specific subset this supports is: +// * 'min', 'max', and 'add' computations, with function names matching the +// computation +// * one op per reduction (no fusions yet). +// Note: computations and shapes are not validated. +// +// TODO(b/139410773): Implement more generally, supporting custom computations. +class ExpandReductionsToOpsPass : public ModulePass<ExpandReductionsToOpsPass> { + public: + void runOnModule() override { + auto module = getModule(); + SmallVector<FuncOp, 4> reductionFuncs; + for (auto funcOp : module.getOps<FuncOp>()) { + if (funcOp.getAttr("iree.executable.reduction.apply")) { + reductionFuncs.push_back(funcOp); + } + } + for (auto funcOp : reductionFuncs) { + if (failed(expandReductionFunction(funcOp))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr<OpPassBase<ModuleOp>> createExpandReductionsToOpsPass() { + return std::make_unique<ExpandReductionsToOpsPass>(); +} + +static PassRegistration<ExpandReductionsToOpsPass> pass( + "iree-expand-reductions-to-ops", + "Expands IREE reduction functions to their interpreter ops"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp new file mode 100644 index 0000000..8238b48 --- /dev/null +++ b/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp
@@ -0,0 +1,253 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Interpreter/HLDialect.h" +#include "compiler/IR/Interpreter/HLOps.h" +#include "compiler/IR/Interpreter/LLDialect.h" +#include "compiler/IR/Interpreter/LLOps.h" +#include "compiler/IR/Ops.h" +#include "compiler/Serialization/BytecodeTables.h" +#include "schemas/bytecode/interpreter_bytecode_v0.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct LowerBranchOpPattern + : public OpRewritePattern<IREEInterp::HL::BranchOp> { + using OpRewritePattern<IREEInterp::HL::BranchOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IREEInterp::HL::BranchOp op, + PatternRewriter &rewriter) const { + SmallVector<Value *, 8> operands{op.getOperation()->getOperands()}; + + rewriter.replaceOpWithNewOp<IREEInterp::LL::BranchOp>(op, op.getDest(), + operands); + return matchSuccess(); + } +}; + +struct LowerCondCondBranchOpPattern + : public OpRewritePattern<IREEInterp::HL::CondBranchOp> { + using OpRewritePattern<IREEInterp::HL::CondBranchOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IREEInterp::HL::CondBranchOp op, + PatternRewriter &rewriter) const { + SmallVector<Value *, 8> trueOperands{op.getTrueOperands()}; + SmallVector<Value *, 8> falseOperands{op.getFalseOperands()}; + + rewriter.replaceOpWithNewOp<IREEInterp::LL::CondBranchOp>( + op, op.getCondition(), op.getTrueDest(), trueOperands, + op.getFalseDest(), falseOperands); + return matchSuccess(); + } +}; + +// Returns true if the op defined by |opName| (like 'iree_ll_interp.reshape') +// uses output operands for results (like iree_ll_interp.add_i) or returns real +// results. +bool opTakesOutputOperands(llvm::StringRef opName) { + if (!opName.consume_front("iree_ll_interp.")) { + assert(false && "op not part of IREE LL Interpreter dialect"); + return false; + } + auto opcode = GetInterpreterOpcodeByName(opName.str()); + assert(opcode.hasValue() && "op has no corresponding opcode"); + const auto &info = GetInterpreterOpcodeInfo(opcode.getValue()); + for (auto &operand : info.operands) { + if (operand == iree::OperandEncoding::kOutputSlot || + operand == iree::OperandEncoding::kVariadicOutputSlots) { + return true; + } + } + return false; +} + +template <typename SrcOp, typename DstOp> +class SimpleOpLowering : public OpRewritePattern<SrcOp> { + using OpRewritePattern<SrcOp>::OpRewritePattern; + + PatternMatchResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const { + SmallVector<Value *, 8> operands{op.getOperation()->getOperands()}; + + // Most ops take results as output operands to populate during execution. + // Certain ops, like reshape, return references to existing memrefs and + // should still retain their results. + if (!opTakesOutputOperands(DstOp::getOperationName())) { + SmallVector<Type, 8> resultTypes{op.getOperation()->getResultTypes()}; + + rewriter.replaceOpWithNewOp<DstOp>(op, resultTypes, operands, + op.getAttrs()); + return this->matchSuccess(); + } + + SmallVector<Value *, 4> replacementValues; + for (Value *result : op.getOperation()->getResults()) { + auto memRefType = result->getType().cast<MemRefType>(); + if (!memRefType.hasStaticShape()) { + // TODO(benvanik): real thing here - dynamic shaping required. + // This should emit a shape calculation based on the operation. Most + // are likely simple and by running DCE after this we can clean up + // parts that are static or unused. + op.emitOpError() << "uses unsupported dynamic shapes"; + return this->matchFailure(); + } + ArrayRef<Value *> dim_pieces; + auto allocOp = rewriter.create<IREEInterp::LL::AllocHeapOp>( + op.getLoc(), memRefType, dim_pieces); + operands.push_back(allocOp); + replacementValues.push_back(allocOp); + } + ArrayRef<Type> resultTypes; + rewriter.create<DstOp>(op.getLoc(), resultTypes, operands, op.getAttrs()); + rewriter.replaceOp(op, replacementValues); + return this->matchSuccess(); + } +}; + +} // namespace + +void populateInterpreterLoweringPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx); + patterns.insert< + SimpleOpLowering<IREE::ConstantOp, IREEInterp::LL::ConstantOp>, + SimpleOpLowering<IREEInterp::HL::CopyOp, IREEInterp::LL::DynamicCopyOp>, + SimpleOpLowering<IREEInterp::HL::SliceOp, + IREEInterp::LL::DynamicSliceOp>>(ctx); +#define SAME_NAME_SIMPLE_PATTERN(op_name) \ + SimpleOpLowering<IREEInterp::HL::op_name, IREEInterp::LL::op_name> + // clang-format off + patterns.insert< + SAME_NAME_SIMPLE_PATTERN(AssignOp), + SAME_NAME_SIMPLE_PATTERN(AbsFOp), + SAME_NAME_SIMPLE_PATTERN(AbsIOp), + SAME_NAME_SIMPLE_PATTERN(AddFOp), + SAME_NAME_SIMPLE_PATTERN(AddIOp), + SAME_NAME_SIMPLE_PATTERN(AllocHeapOp), + SAME_NAME_SIMPLE_PATTERN(AndOp), + SAME_NAME_SIMPLE_PATTERN(Atan2FOp), + SAME_NAME_SIMPLE_PATTERN(BreakOp), + SAME_NAME_SIMPLE_PATTERN(BroadcastOp), + SAME_NAME_SIMPLE_PATTERN(CallOp), + SAME_NAME_SIMPLE_PATTERN(CallIndirectOp), + SAME_NAME_SIMPLE_PATTERN(CeilFOp), + SAME_NAME_SIMPLE_PATTERN(ClampFOp), + SAME_NAME_SIMPLE_PATTERN(CloneOp), + SAME_NAME_SIMPLE_PATTERN(CmpFOp), + SAME_NAME_SIMPLE_PATTERN(CmpIOp), + SAME_NAME_SIMPLE_PATTERN(CondAssignOp), + SAME_NAME_SIMPLE_PATTERN(ConvertSSOp), + SAME_NAME_SIMPLE_PATTERN(ConvertUUOp), + SAME_NAME_SIMPLE_PATTERN(ConvertSUOp), + SAME_NAME_SIMPLE_PATTERN(ConvertUSOp), + SAME_NAME_SIMPLE_PATTERN(CondBreakOp), + SAME_NAME_SIMPLE_PATTERN(CosFOp), + SAME_NAME_SIMPLE_PATTERN(DimOp), + SAME_NAME_SIMPLE_PATTERN(DivFOp), + SAME_NAME_SIMPLE_PATTERN(DivISOp), + SAME_NAME_SIMPLE_PATTERN(DivIUOp), + SAME_NAME_SIMPLE_PATTERN(ExpFOp), + SAME_NAME_SIMPLE_PATTERN(LogFOp), + SAME_NAME_SIMPLE_PATTERN(RsqrtFOp), + SAME_NAME_SIMPLE_PATTERN(FloorFOp), + SAME_NAME_SIMPLE_PATTERN(LengthOp), + SAME_NAME_SIMPLE_PATTERN(MatMulFOp), + SAME_NAME_SIMPLE_PATTERN(MatMulIOp), + SAME_NAME_SIMPLE_PATTERN(MaxFOp), + SAME_NAME_SIMPLE_PATTERN(MaxISOp), + SAME_NAME_SIMPLE_PATTERN(MaxIUOp), + SAME_NAME_SIMPLE_PATTERN(MinFOp), + SAME_NAME_SIMPLE_PATTERN(MinISOp), + SAME_NAME_SIMPLE_PATTERN(MinIUOp), + SAME_NAME_SIMPLE_PATTERN(MulAddFOp), + SAME_NAME_SIMPLE_PATTERN(MulAddIOp), + SAME_NAME_SIMPLE_PATTERN(MulFOp), + SAME_NAME_SIMPLE_PATTERN(MulIOp), + SAME_NAME_SIMPLE_PATTERN(NotOp), + SAME_NAME_SIMPLE_PATTERN(OrOp), + SAME_NAME_SIMPLE_PATTERN(PadOp), + SAME_NAME_SIMPLE_PATTERN(RankOp), + SAME_NAME_SIMPLE_PATTERN(ReduceSumIOp), + SAME_NAME_SIMPLE_PATTERN(ReduceSumFOp), + SAME_NAME_SIMPLE_PATTERN(ReduceMinIOp), + SAME_NAME_SIMPLE_PATTERN(ReduceMinFOp), + SAME_NAME_SIMPLE_PATTERN(ReduceMaxIOp), + SAME_NAME_SIMPLE_PATTERN(ReduceMaxFOp), + SAME_NAME_SIMPLE_PATTERN(ReshapeOp), + SAME_NAME_SIMPLE_PATTERN(ReturnOp), + SAME_NAME_SIMPLE_PATTERN(SelectOp), + SAME_NAME_SIMPLE_PATTERN(ShapeOp), + SAME_NAME_SIMPLE_PATTERN(ShiftLeftOp), + SAME_NAME_SIMPLE_PATTERN(ShiftRightArithmeticOp), + SAME_NAME_SIMPLE_PATTERN(ShiftRightLogicalOp), + SAME_NAME_SIMPLE_PATTERN(SinFOp), + SAME_NAME_SIMPLE_PATTERN(SplitOp), + SAME_NAME_SIMPLE_PATTERN(SubFOp), + SAME_NAME_SIMPLE_PATTERN(SubIOp), + SAME_NAME_SIMPLE_PATTERN(TanhFOp), + SAME_NAME_SIMPLE_PATTERN(TileOp), + SAME_NAME_SIMPLE_PATTERN(TraceOp), + SAME_NAME_SIMPLE_PATTERN(TransposeOp), + SAME_NAME_SIMPLE_PATTERN(ReverseOp), + SAME_NAME_SIMPLE_PATTERN(XorOp)>(ctx); + // clang-format on +#undef SAME_NAME_SIMPLE_PATTERN +} + +namespace { +class LowerInterpreterDialectPass + : public FunctionPass<LowerInterpreterDialectPass> { + public: + void runOnFunction() override { + OwningRewritePatternList patterns; + populateInterpreterLoweringPatterns(patterns, &getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect<IREELLInterpreterDialect>(); + target.addLegalOp<FuncOp, IREE::ReturnOp>(); + if (failed(applyFullConversion(getFunction(), target, patterns))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> createLowerInterpreterDialectPass() { + return std::make_unique<LowerInterpreterDialectPass>(); +} + +static PassRegistration<LowerInterpreterDialectPass> pass( + "lower-iree-interpreter-hl-to-ll", "Lowers IREE HL ops to IREE LL ops"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp new file mode 100644 index 0000000..c174669 --- /dev/null +++ b/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp
@@ -0,0 +1,303 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Interpreter/HLDialect.h" +#include "compiler/IR/Interpreter/HLOps.h" +#include "compiler/IR/Interpreter/LLDialect.h" +#include "compiler/IR/Ops.h" +#include "compiler/Transforms/ConversionUtils.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Allocator.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct CallOpLowering : public OpConversionPattern<CallOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + CallOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto callOp = cast<CallOp>(op); + auto calleeType = callOp.getCalleeType(); + rewriter.replaceOpWithNewOp<IREEInterp::HL::CallOp>( + op, callOp.getCallee(), calleeType.getResults(), operands); + return matchSuccess(); + } +}; + +struct CallIndirectOpLowering : public OpConversionPattern<CallIndirectOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + CallIndirectOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto callOp = cast<CallIndirectOp>(op); + rewriter.replaceOpWithNewOp<IREEInterp::HL::CallIndirectOp>( + op, callOp.getCallee(), operands); + return matchSuccess(); + } +}; + +struct ReturnOpLowering : public OpConversionPattern<ReturnOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + ReturnOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREEInterp::HL::ReturnOp>(op, operands); + return matchSuccess(); + } +}; + +struct BranchOpLowering : public OpConversionPattern<BranchOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + BranchOp op, ArrayRef<Value *> properOperands, + ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREEInterp::HL::BranchOp>(op, destinations[0], + operands[0]); + return this->matchSuccess(); + } +}; + +struct CondBranchOpLowering : public OpConversionPattern<CondBranchOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + CondBranchOp op, ArrayRef<Value *> properOperands, + ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, + ConversionPatternRewriter &rewriter) const override { + auto *condValue = loadAccessValue(op.getLoc(), properOperands[0], rewriter); + rewriter.replaceOpWithNewOp<IREEInterp::HL::CondBranchOp>( + op, condValue, destinations[IREEInterp::HL::CondBranchOp::trueIndex], + operands[IREEInterp::HL::CondBranchOp::trueIndex], + destinations[IREEInterp::HL::CondBranchOp::falseIndex], + operands[IREEInterp::HL::CondBranchOp::falseIndex]); + return this->matchSuccess(); + } +}; + +template <typename SrcOp, typename DstOp> +struct CompareOpLowering : public OpConversionPattern<SrcOp> { + using OpConversionPattern<SrcOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + SrcOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto lhValue = loadAccessValue(op.getLoc(), operands[0], rewriter); + auto rhValue = loadAccessValue(op.getLoc(), operands[1], rewriter); + + lhValue = wrapAsMemRef(lhValue, op, rewriter); + rhValue = wrapAsMemRef(rhValue, op, rewriter); + + // TODO(benvanik): map predicate to stable value. + auto predicate = + rewriter.getI32IntegerAttr(static_cast<int32_t>(op.getPredicate())); + + auto dstType = convertTypeToMemRef(op.getResult()); + auto midOp = rewriter.create<DstOp>(op.getLoc(), dstType, predicate, + lhValue, rhValue); + + auto result = wrapAsTensor(midOp.getResult(), op, rewriter); + rewriter.replaceOp( + op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); + return this->matchSuccess(); + } +}; + +struct CmpIOpLowering + : public CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp> { + using CompareOpLowering::CompareOpLowering; +}; + +struct CmpFOpLowering + : public CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp> { + using CompareOpLowering::CompareOpLowering; +}; + +struct AllocOpLowering : public OpConversionPattern<AllocOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + AllocOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + // TODO(benvanik): replace with length computation. + rewriter.replaceOpWithNewOp<IREEInterp::HL::AllocHeapOp>(op, op.getType(), + operands); + return matchSuccess(); + } +}; + +struct DeallocOpLowering : public OpConversionPattern<DeallocOp> { + using OpConversionPattern::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + DeallocOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREEInterp::HL::DiscardOp>(op, operands[0]); + return matchSuccess(); + } +}; + +struct LoadOpLowering : public OpRewritePattern<LoadOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override { + if (loadOp.getMemRefType().getRank() != 0) { + loadOp.emitError() << "Cannot lower load of non-scalar"; + return matchFailure(); + } + ArrayRef<Value *> dimPieces; + auto dst = + rewriter + .create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(), dimPieces) + .getResult(); + auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {}); + rewriter.create<IREEInterp::HL::CopyOp>( + loadOp.getLoc(), loadOp.getMemRef(), + /*srcIndices=*/emptyArrayMemref, dst, + /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref); + + rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst); + + return matchSuccess(); + } +}; + +struct StoreOpLowering : public OpRewritePattern<StoreOp> { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override { + if (storeOp.getMemRefType().getRank() != 0) { + storeOp.emitError() << "Cannot lower store of non-scalar"; + return matchFailure(); + } + + auto src = rewriter.create<IREE::ScalarToMemRefOp>( + storeOp.getLoc(), storeOp.getValueToStore()); + + auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {}); + rewriter.replaceOpWithNewOp<IREEInterp::HL::CopyOp>( + storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(), + /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref); + + return matchSuccess(); + } +}; + +#define UNARY_OP_LOWERING(StdOpType, IREEOpType) \ + struct StdOpType##Lowering : public UnaryOpLowering<StdOpType, IREEOpType> { \ + using UnaryOpLowering::UnaryOpLowering; \ + }; + +#define BINARY_OP_LOWERING(StdOpType, IREEOpType) \ + struct StdOpType##Lowering \ + : public BinaryOpLowering<StdOpType, IREEOpType> { \ + using BinaryOpLowering::BinaryOpLowering; \ + }; + +#define TERNARY_OP_LOWERING(StdOpType, IREEOpType) \ + struct StdOpType##Lowering \ + : public TernaryOpLowering<StdOpType, IREEOpType> { \ + using TernaryOpLowering::TernaryOpLowering; \ + }; + +// UNARY_OP_LOWERING(RankOp, IREEInterp::HL::RankOp); +UNARY_OP_LOWERING(DimOp, IREEInterp::HL::DimOp); +// UNARY_OP_LOWERING(ShapeOp, IREEInterp::HL::ShapeOp); +// UNARY_OP_LOWERING(LengthOp, IREEInterp::HL::LengthOp); + +// UNARY_OP_LOWERING(NotOp, IREEInterp::HL::NotOp); +BINARY_OP_LOWERING(AndOp, IREEInterp::HL::AndOp); +BINARY_OP_LOWERING(OrOp, IREEInterp::HL::OrOp); +// BINARY_OP_LOWERING(XorOp, IREEInterp::HL::XorOp); +// BINARY_OP_LOWERING(ShiftLeftOp, IREEInterp::HL::ShiftLeftOp); +// BINARY_OP_LOWERING(ShiftRightLogicalOp, IREEInterp::HL::ShiftRightLogicalOp); +// BINARY_OP_LOWERING(ShiftRightArithmeticOp, +// IREEInterp::HL::ShiftRightArithmeticOp); + +BINARY_OP_LOWERING(AddIOp, IREEInterp::HL::AddIOp); +BINARY_OP_LOWERING(AddFOp, IREEInterp::HL::AddFOp); +BINARY_OP_LOWERING(SubIOp, IREEInterp::HL::SubIOp); +BINARY_OP_LOWERING(SubFOp, IREEInterp::HL::SubFOp); +// UNARY_OP_LOWERING(AbsIOp, IREEInterp::HL::AbsIOp); +// UNARY_OP_LOWERING(AbsFOp, IREEInterp::HL::AbsFOp); +BINARY_OP_LOWERING(MulIOp, IREEInterp::HL::MulIOp); +BINARY_OP_LOWERING(MulFOp, IREEInterp::HL::MulFOp); +BINARY_OP_LOWERING(DivISOp, IREEInterp::HL::DivISOp); +BINARY_OP_LOWERING(DivIUOp, IREEInterp::HL::DivIUOp); +BINARY_OP_LOWERING(DivFOp, IREEInterp::HL::DivFOp); +// BINARY_OP_LOWERING(MulAddIOp, IREEInterp::HL::MulAddIOp); +// BINARY_OP_LOWERING(MulAddFOp, IREEInterp::HL::MulAddFOp); +// UNARY_OP_LOWERING(ExpFOp, IREEInterp::HL::ExpFOp); +// UNARY_OP_LOWERING(LogFOp, IREEInterp::HL::LogFOp); +// UNARY_OP_LOWERING(RsqrtFOp, IREEInterp::HL::RsqrtFOp); +// UNARY_OP_LOWERING(CosFOp, IREEInterp::HL::CosFOp); +// UNARY_OP_LOWERING(SinFOp, IREEInterp::HL::SinFOp); +// UNARY_OP_LOWERING(TanhFOp, IREEInterp::HL::TanhFOp); +// UNARY_OP_LOWERING(Atan2FOp, IREEInterp::HL::Atan2FOp); + +// BINARY_OP_LOWERING(MinISOp, IREEInterp::HL::MinISOp); +// BINARY_OP_LOWERING(MinIUOp, IREEInterp::HL::MinIUOp); +// BINARY_OP_LOWERING(MinFOp, IREEInterp::HL::MinFOp); +// BINARY_OP_LOWERING(MaxISOp, IREEInterp::HL::MaxISOp); +// BINARY_OP_LOWERING(MaxIUOp, IREEInterp::HL::MaxIUOp); +// BINARY_OP_LOWERING(MaxFOp, IREEInterp::HL::MaxFOp); +// TERNARY_OP_LOWERING(ClampFOp, IREEInterp::HL::ClampFOp); +// UNARY_OP_LOWERING(FloorFOp, IREEInterp::HL::FloorFOp); +// UNARY_OP_LOWERING(CeilFOp, IREEInterp::HL::CeilFOp); + +} // namespace + +void populateLowerStdToInterpreterPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert< + // Control flow. + CallOpLowering, CallIndirectOpLowering, ReturnOpLowering, + BranchOpLowering, CondBranchOpLowering, CmpIOpLowering, CmpFOpLowering, + // Memory management. + AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering, + // Shape operations. + DimOpLowering, + // Logical ops. + AndOpLowering, OrOpLowering, + // Arithmetic ops. + AddIOpLowering, AddFOpLowering, SubIOpLowering, SubFOpLowering, + MulIOpLowering, MulFOpLowering, DivISOpLowering, DivIUOpLowering, + DivFOpLowering>(ctx); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp new file mode 100644 index 0000000..a1334de --- /dev/null +++ b/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp
@@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Interpreter/HLDialect.h" +#include "compiler/IR/Interpreter/LLDialect.h" +#include "compiler/Transforms/Interpreter/Rewrites.h" +#include "compiler/Transforms/Rewrites.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace iree_compiler { +namespace { + +class LowerToInterpreterDialectPass + : public FunctionPass<LowerToInterpreterDialectPass> { + public: + void runOnFunction() override { + OwningRewritePatternList patterns; + auto* ctx = &getContext(); + xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx); + xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx); + populateLowerStdToIreePatterns(patterns, ctx); + populateLowerStdToInterpreterPatterns(patterns, ctx); + populateLowerXlaToIreePatterns(patterns, ctx); + populateLowerXlaToInterpreterPatterns(patterns, ctx); + + ConversionTarget target(getContext()); + target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>(); + target.addLegalOp<FuncOp, ReturnOp>(); + if (failed(applyFullConversion(getFunction(), target, patterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> createLowerToInterpreterDialectPass() { + return std::make_unique<LowerToInterpreterDialectPass>(); +} + +static PassRegistration<LowerToInterpreterDialectPass> pass( + "lower-to-iree-interpreter", + "Convert all ops to the IREE interpreter dialect"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp b/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp new file mode 100644 index 0000000..00668ae --- /dev/null +++ b/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp
@@ -0,0 +1,565 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Interpreter/HLDialect.h" +#include "compiler/IR/Interpreter/HLOps.h" +#include "compiler/IR/Ops.h" +#include "compiler/Transforms/ConversionUtils.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// TODO(suderman): tablegen this? or something a bit more flexible. + +#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \ + struct XlaOpType##Lowering \ + : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ + using UnaryOpLowering::UnaryOpLowering; \ + }; + +#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \ + struct XlaOpType##Lowering \ + : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ + using TernaryOpLowering::TernaryOpLowering; \ + }; + +UNARY_OP_LOWERING(CopyOp, IREEInterp::HL::CloneOp); +UNARY_OP_LOWERING(ExpOp, IREEInterp::HL::ExpFOp); +UNARY_OP_LOWERING(LogOp, IREEInterp::HL::LogFOp); +UNARY_OP_LOWERING(FloorOp, IREEInterp::HL::FloorFOp); +UNARY_OP_LOWERING(RsqrtOp, IREEInterp::HL::RsqrtFOp); +UNARY_OP_LOWERING(TanhOp, IREEInterp::HL::TanhFOp); +TERNARY_OP_LOWERING(SelectOp, IREEInterp::HL::SelectOp); + +#undef UNARY_OP_LOWERING +#undef TERNARY_OP_LOWERING + +template <typename T> +static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter, + Location loc, Value *input, + MemRefType targetType) { + auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape()); + return rewriter.create<T>(loc, targetType, input, shapeOp); +} + +static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op, + Value *tensor) { + return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op, + rewriter); +} + +template <typename SrcOp> +class XlaOpLowering : public OpConversionPattern<SrcOp> { + using OpConversionPattern<SrcOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + SrcOp srcOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Value *, 4> memrefOperands; + for (auto operand : operands) { + memrefOperands.push_back(inputAsMemref(rewriter, srcOp, operand)); + } + + if (auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter)) { + rewriter.replaceOp(srcOp, + wrapAsTensor(dstOp->getResult(0), srcOp, rewriter)); + return this->matchSuccess(); + } + return this->matchFailure(); + } + + protected: + virtual Operation *rewriteInternal( + SrcOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?"); + } +}; + +struct BroadcastInDimOpLowering + : public XlaOpLowering<xla_hlo::BroadcastInDimOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::BroadcastInDimOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *inputValue = operands[0]; + auto inputType = inputValue->getType().cast<MemRefType>(); + auto finalType = convertTypeToMemRef(*op); + + // Reshape to scalar and broadcast. + auto createFinal = createShapeTargetingOp<IREEInterp::HL::BroadcastOp>; + llvm::SmallVector<int64_t, 6> intermediateShape{}; + + // Or reshape to final rank and tile. + if (inputType.getNumElements() != 1) { + createFinal = createShapeTargetingOp<IREEInterp::HL::TileOp>; + + intermediateShape = llvm::SmallVector<int64_t, 6>(finalType.getRank(), 1); + auto inputShape = inputType.getShape(); + auto dimensions = op->broadcast_dimensions(); + for (size_t i = 0; i < inputType.getRank(); ++i) { + auto index = dimensions->getValue(i).cast<IntegerAttr>().getInt(); + intermediateShape[index] = inputShape[i]; + } + } + + auto intermediateType = + MemRefType::get(intermediateShape, inputType.getElementType()); + auto reshapeOp = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( + rewriter, op->getLoc(), inputValue, intermediateType); + return createFinal(rewriter, op->getLoc(), reshapeOp->getResult(0), + finalType); + } +}; + +struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto finalType = convertTypeToMemRef(*op); + + return rewriter.create<IREEInterp::HL::ConcatOp>( + op->getLoc(), finalType, operands, + rewriter.getI32IntegerAttr(op->dimension().getZExtValue())); + } +}; + +struct DotOpLowering : public XlaOpLowering<xla_hlo::DotOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::DotOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *lhsValue = operands[0]; + auto *rhsValue = operands[1]; + + auto finalType = convertTypeToMemRef(*op); + auto elementType = finalType.getElementType(); + if (!elementType.isa<FloatType>()) { + op->emitOpError("xla_hlo.dot only supports floating point values"); + } + + Operation *matMulOp = rewriter + .create<IREEInterp::HL::MatMulFOp>( + op->getLoc(), finalType, lhsValue, rhsValue) + .getOperation(); + return matMulOp; + } +}; + +struct DynamicUpdateSliceOpLowering + : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto operand = operands[0]; + auto update = operands[1]; + + auto updateType = update->getType().cast<ShapedType>(); + Value *lengthConstant = + createArrayConstant(rewriter, op->getLoc(), updateType.getShape()); + + auto startIndices = makeArrayRef(operands).drop_front(2); + const int rank = startIndices.size(); + llvm::SmallVector<Value *, 4> valuesToConcat; + valuesToConcat.reserve(startIndices.size()); + auto type = getElementTypeOrSelf(startIndices.front()); + + // To generate the offset matrix we need to convert the variadic tensors + // into a reshaped and concated value. + for (auto index : startIndices) { + auto reshapedIndex = rewriter.create<IREEInterp::HL::ReshapeOp>( + op->getLoc(), MemRefType::get({1}, type), index, + createArrayConstant(rewriter, op->getLoc(), {1})); + valuesToConcat.push_back(reshapedIndex); + } + + auto dstOffset = rewriter + .create<IREEInterp::HL::ConcatOp>( + op->getLoc(), MemRefType::get({rank}, type), + valuesToConcat, rewriter.getI32IntegerAttr(0)) + .getResult(); + + llvm::SmallVector<int64_t, 4> zero_offset; + zero_offset.resize(updateType.getRank(), 0); + auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset); + + auto copiedOperand = rewriter.create<IREEInterp::HL::CloneOp>( + op->getLoc(), operand->getType(), operand); + + rewriter + .create<IREEInterp::HL::CopyOp>(op->getLoc(), update, srcOffset, + copiedOperand, dstOffset, + lengthConstant) + .getOperation(); + + return copiedOperand; + } +}; + +template <typename XlaOpType, typename IreeFloatOpType, typename IreeIntOpType> +struct BinaryFloatIntOpLowering : public XlaOpLowering<XlaOpType> { + using XlaOpLowering<XlaOpType>::XlaOpLowering; + + Operation *rewriteInternal( + XlaOpType *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *lhs = operands[0]; + auto *rhs = operands[1]; + auto inputType = lhs->getType().cast<MemRefType>(); + auto elementType = inputType.getElementType(); + + if (elementType.isa<FloatType>()) { + return rewriter.create<IreeFloatOpType>(op->getLoc(), inputType, lhs, + rhs); + } + + return rewriter.create<IreeIntOpType>(op->getLoc(), inputType, lhs, rhs); + } +}; + +struct MaxOpLowering + : public BinaryFloatIntOpLowering<xla_hlo::MaxOp, IREEInterp::HL::MaxFOp, + IREEInterp::HL::MaxISOp> { + using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering; +}; + +struct MinOpLowering + : public BinaryFloatIntOpLowering<xla_hlo::MinOp, IREEInterp::HL::MinFOp, + IREEInterp::HL::MinISOp> { + using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering; +}; + +struct ConvertLowering : public XlaOpLowering<xla_hlo::ConvertOp> { + using XlaOpLowering<xla_hlo::ConvertOp>::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::ConvertOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *operand = operands[0]; + auto *result = op->getResult(); + + auto operandType = operand->getType().cast<MemRefType>().getElementType(); + auto resultType = result->getType().cast<ShapedType>().getElementType(); + + auto newResultType = convertTypeToMemRef(result); + +#define ConvertCase(InType, OutType, NewOp) \ + { \ + if (operandType.isa<InType>() && resultType.isa<OutType>()) { \ + return rewriter.create<NewOp>(op->getLoc(), newResultType, operand); \ + } \ + } + ConvertCase(IntegerType, IntegerType, IREEInterp::HL::ConvertSSOp); + ConvertCase(IntegerType, FloatType, IREEInterp::HL::ConvertSFOp); + ConvertCase(FloatType, IntegerType, IREEInterp::HL::ConvertFSOp); + ConvertCase(FloatType, FloatType, IREEInterp::HL::ConvertFFOp); +#undef ConvertCase + + return nullptr; + } +}; + +// Lowers a subset of gathers along axis 0 that are really just a slice and +// reshape. +struct GatherOpLowering : public OpConversionPattern<xla_hlo::GatherOp> { + using OpConversionPattern::OpConversionPattern; + + // TODO(gcmn): This only handles a minimal number of cases. When XLA + // redefines gather to be simpler, lower it properly. + PatternMatchResult matchAndRewrite( + xla_hlo::GatherOp gatherOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + if (gatherOp.index_vector_dim() != 0) { + gatherOp.emitRemark() + << "Couldn't lower gather with index_vector_dim != 0"; + return matchFailure(); + } + if (gatherOp.start_index_map().getType().getRank() != 1 || + gatherOp.start_index_map().getValue(0).cast<IntegerAttr>().getValue() != + 0) { + gatherOp.emitRemark() + << "Couldn't lower gather with start_index_map != [0]"; + return matchFailure(); + } + if (gatherOp.collapsed_slice_dims().getType().getRank() != 1 || + gatherOp.collapsed_slice_dims() + .getValue(0) + .cast<IntegerAttr>() + .getValue() != 0) { + gatherOp.emitRemark() + << "Couldn't lower gather with collapsed_dims != [0]"; + return matchFailure(); + } + + auto resultType = gatherOp.getResult()->getType().cast<RankedTensorType>(); + if (gatherOp.offset_dims().getType().getNumElements() != + resultType.getRank()) { + gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != " + "[0,...,rank of output]"; + return matchFailure(); + } + for (auto it : llvm::enumerate(gatherOp.offset_dims())) { + if (it.index() != it.value()) { + gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != " + "[0,...,rank of output]"; + return matchFailure(); + } + } + + for (auto it : llvm::enumerate(resultType.getShape())) { + if (gatherOp.slice_sizes() + .getValue(it.index() + 1) + .cast<IntegerAttr>() + .getValue() != it.value()) { + gatherOp.emitRemark() + << "Couldn't lower gather with slice_sizes not [1] + final shape"; + return matchFailure(); + } + } + + auto inputType = gatherOp.operand()->getType().cast<RankedTensorType>(); + + auto startIndices = + inputAsMemref(rewriter, gatherOp, gatherOp.start_indices()); + auto startIndicesType = startIndices->getType().cast<MemRefType>(); + if (startIndicesType.getNumElements() != inputType.getRank()) { + auto extraDims = inputType.getRank() - startIndicesType.getNumElements(); + auto elementType = startIndicesType.getElementType(); + + if (startIndicesType.getRank() != 1) { + startIndices = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( + rewriter, gatherOp.getLoc(), startIndices, + MemRefType::get({1}, elementType)) + ->getResult(0); + } + + llvm::SmallVector<int64_t, 4> zeroes; + zeroes.resize(extraDims, 0); + + auto elementsAttr = DenseIntElementsAttr::get( + RankedTensorType::get(zeroes.size(), elementType), + llvm::makeArrayRef(zeroes)); + + auto extraStartIndices = + rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), elementsAttr); + + auto memrefOutputType = + MemRefType::get({inputType.getRank()}, elementType); + + SmallVector<Value *, 2> valuesToConcat = {startIndices, + extraStartIndices}; + startIndices = rewriter.create<IREEInterp::HL::ConcatOp>( + gatherOp.getLoc(), memrefOutputType, valuesToConcat, + rewriter.getI32IntegerAttr(0)); + } + + auto sliceSizeValues = gatherOp.slice_sizes().getValues<int64_t>(); + std::vector<int64_t> sliceSizes = {sliceSizeValues.begin(), + sliceSizeValues.end()}; + auto dstType = MemRefType::get(sliceSizes, inputType.getElementType()); + + auto src = inputAsMemref(rewriter, gatherOp, gatherOp.operand()); + std::vector<Value *> dim_pieces; + auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>( + gatherOp.getLoc(), dstType, dim_pieces); + auto lengths = rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), + gatherOp.slice_sizes()); + llvm::SmallVector<int64_t, 4> zero_offset; + zero_offset.resize(dstType.getRank(), 0); + auto dstIndices = + createArrayConstant(rewriter, gatherOp.getLoc(), zero_offset); + + rewriter.create<IREEInterp::HL::CopyOp>( + gatherOp.getLoc(), src, startIndices, dst, dstIndices, lengths); + + auto reshaped = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( + rewriter, gatherOp.getLoc(), dst, convertTypeToMemRef(gatherOp)); + rewriter.replaceOp( + gatherOp, wrapAsTensor(reshaped->getResult(0), gatherOp, rewriter)); + + return matchSuccess(); + } +}; + +struct SliceOpLowering : public XlaOpLowering<xla_hlo::SliceOp> { + using XlaOpLowering<xla_hlo::SliceOp>::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::SliceOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + // XLA slice has value semantics, whereas the IREE slice creates a view. We + // lower it to a copy if all strides are one which may be transformed to a + // slice by later optimizations. + auto isNotOne = [](APInt stride) { return stride != 1; }; + if (llvm::any_of(op->strides(), isNotOne)) { + op->emitRemark() << "Could not lower slice op with non-singular strides"; + return nullptr; + } + + auto finalType = convertTypeToMemRef(*op); + auto src = operands[0]; + std::vector<Value *> dim_pieces; + auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>( + op->getLoc(), finalType, dim_pieces); + auto srcIndices = + rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices()); + auto lengths = + createArrayConstant(rewriter, op->getLoc(), finalType.getShape()); + + llvm::SmallVector<int64_t, 4> zero_offset; + zero_offset.resize(finalType.getRank(), 0); + auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset); + + rewriter.create<IREEInterp::HL::CopyOp>(op->getLoc(), src, srcIndices, dst, + dstIndices, lengths); + return dst; + } +}; + +struct PadOpLowering : public XlaOpLowering<xla_hlo::PadOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::PadOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *src = operands[0]; + auto *paddingValue = operands[1]; + + // TODO(b/140836672) Support negative padding + for (int i = 0; i < op->edge_padding_high().getNumElements(); ++i) { + if (op->edge_padding_high().getValue<IntegerAttr>(i).getInt() < 0 || + op->edge_padding_low().getValue<IntegerAttr>(i).getInt() < 0) { + op->emitRemark() << "Could not lower pad op with negative padding"; + return nullptr; + } + } + + auto edgePaddingLowOp = + rewriter.create<IREE::ConstantOp>(op->getLoc(), op->edge_padding_low()); + auto edgePaddingHighOp = rewriter.create<IREE::ConstantOp>( + op->getLoc(), op->edge_padding_high()); + auto interiorPaddingOp = + rewriter.create<IREE::ConstantOp>(op->getLoc(), op->interior_padding()); + + return rewriter.create<IREEInterp::HL::PadOp>( + op->getLoc(), convertTypeToMemRef(*op), src, paddingValue, + edgePaddingLowOp, edgePaddingHighOp, interiorPaddingOp); + } +}; + +struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + return createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( + rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op)); + } +}; + +struct TransposeOpLowering : public XlaOpLowering<xla_hlo::TransposeOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::TransposeOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto permutationOp = + rewriter.create<IREE::ConstantOp>(op->getLoc(), op->permutation()); + + return rewriter.create<IREEInterp::HL::TransposeOp>( + op->getLoc(), convertTypeToMemRef(*op), operands[0], permutationOp); + } +}; + +struct ReverseOpLowering : public XlaOpLowering<xla_hlo::ReverseOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::ReverseOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto reverseOp = + rewriter.create<IREE::ConstantOp>(op->getLoc(), op->dimensions()); + + return rewriter.create<IREEInterp::HL::ReverseOp>( + op->getLoc(), convertTypeToMemRef(*op), operands[0], reverseOp); + } +}; + +} // namespace + +void populateLowerXlaToInterpreterPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns + .insert<BroadcastInDimOpLowering, ConcatOpLowering, ConvertLowering, + CopyOpLowering, DotOpLowering, DynamicUpdateSliceOpLowering, + ExpOpLowering, FloorOpLowering, GatherOpLowering, LogOpLowering, + MaxOpLowering, MinOpLowering, PadOpLowering, ReshapeOpLowering, + ReverseOpLowering, RsqrtOpLowering, SelectOpLowering, + SliceOpLowering, TransposeOpLowering, TanhOpLowering>(ctx); +} + +namespace { +// Just for testing these passes. +// TODO(b/141337493) can we get rid of this pass entirely? +class LowerXLAToInterpreterDialectPass + : public FunctionPass<LowerXLAToInterpreterDialectPass> { + public: + void runOnFunction() override { + OwningRewritePatternList patterns; + populateLowerXlaToInterpreterPatterns(patterns, &getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>(); + if (failed(applyPartialConversion(getFunction(), target, patterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +static PassRegistration<LowerXLAToInterpreterDialectPass> pass( + "lower-xla-to-iree-interpreter", + "Convert all XLA functions to the IREE dialect"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Interpreter/MakeExecutableABI.cpp b/compiler/Transforms/Interpreter/MakeExecutableABI.cpp new file mode 100644 index 0000000..793cb06 --- /dev/null +++ b/compiler/Transforms/Interpreter/MakeExecutableABI.cpp
@@ -0,0 +1,147 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Interpreter/HLOps.h" +#include "compiler/IR/Ops.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "compiler/Utils/OpUtils.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Replaces a load_input op with valid IR that loads the input value. +LogicalResult replaceLoadInputOp(IREE::LoadInputOp bindOp) { + OpBuilder builder(bindOp); + + Value *newValue = nullptr; + auto dstType = bindOp.getResult()->getType(); + if (dstType.isa<TensorType>()) { + auto castOp = + builder.create<IREE::MemRefToTensorOp>(bindOp.getLoc(), bindOp.src()); + newValue = castOp.getResult(); + } else if (dstType.isIntOrIndexOrFloat()) { + auto loadOp = builder.create<LoadOp>(bindOp.getLoc(), dstType, bindOp.src(), + ArrayRef<Value *>{}); + newValue = loadOp.getResult(); + } else { + return bindOp.emitError() + << "Unsupported input destination type " << dstType; + } + + bindOp.replaceAllUsesWith(newValue); + bindOp.erase(); + + return success(); +} + +// Replaces a store_output op with valid IR that stores the output value. +LogicalResult replaceStoreOutputOp(IREE::StoreOutputOp bindOp) { + OpBuilder builder(bindOp); + + auto srcType = bindOp.src()->getType(); + if (srcType.isa<MemRefType>()) { + // Already stored into the output. + } else if (srcType.isa<TensorType>()) { + auto castOp = + builder.create<IREE::TensorToMemRefOp>(bindOp.getLoc(), bindOp.src()); + + // Insert a copy to our output parameter. + auto dst = bindOp.dst()->getType().cast<ShapedType>(); + if (!dst.hasStaticShape()) { + return bindOp.emitError() + << "Dynamic output args are not yet implemented"; + } + + auto zeroValues = llvm::SmallVector<int64_t, 4>(dst.getRank()); + auto zeros = createArrayConstant(builder, bindOp.getLoc(), zeroValues); + auto lengths = + createArrayConstant(builder, bindOp.getLoc(), dst.getShape()); + builder.create<IREEInterp::HL::CopyOp>(bindOp.getLoc(), castOp.getResult(), + zeros, bindOp.dst(), zeros, lengths); + } else if (srcType.isIntOrIndexOrFloat()) { + builder.create<StoreOp>(bindOp.getLoc(), bindOp.src(), bindOp.dst(), + ArrayRef<Value *>{}); + } else { + return bindOp.emitError() << "Unsupported output src type " << srcType; + } + + bindOp.erase(); + + return success(); +} + +// Strips iree.bind_* ops from |func|. +LogicalResult stripBindingOps(FuncOp func) { + // Find iree.load_input ops to replace with memref_to_tensor if needed. + SmallVector<IREE::LoadInputOp, 8> bindInputOps; + func.walk([&](IREE::LoadInputOp bindOp) { bindInputOps.push_back(bindOp); }); + for (auto &bindOp : bindInputOps) { + if (failed(replaceLoadInputOp(bindOp))) { + return failure(); + } + } + + // Find iree.store_output ops and replace with tensor_to_memref if needed. + SmallVector<IREE::StoreOutputOp, 8> bindOutputOps; + func.walk( + [&](IREE::StoreOutputOp bindOp) { bindOutputOps.push_back(bindOp); }); + for (auto &bindOp : bindOutputOps) { + if (failed(replaceStoreOutputOp(bindOp))) { + return failure(); + } + } + + return success(); +} + +} // namespace + +// Finds iree.executable.export functions and fixes up bindings. +// For the interpreter this really just means stripping the bind ops entirely. +class MakeExecutableABIPass : public ModulePass<MakeExecutableABIPass> { + public: + void runOnModule() override { + auto module = getModule(); + for (auto func : module.getOps<FuncOp>()) { + if (func.getAttr("iree.executable.export")) { + if (failed(stripBindingOps(func))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createMakeExecutableABIPass() { + return std::make_unique<MakeExecutableABIPass>(); +} + +static PassRegistration<MakeExecutableABIPass> pass( + "iree-make-executable-abi", + "Makes functions match the IREE dispatch executable ABI."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/Passes.h b/compiler/Transforms/Interpreter/Passes.h similarity index 100% rename from iree/compiler/Transforms/Interpreter/Passes.h rename to compiler/Transforms/Interpreter/Passes.h
diff --git a/iree/compiler/Transforms/Interpreter/Rewrites.h b/compiler/Transforms/Interpreter/Rewrites.h similarity index 100% rename from iree/compiler/Transforms/Interpreter/Rewrites.h rename to compiler/Transforms/Interpreter/Rewrites.h
diff --git a/compiler/Transforms/Interpreter/test/BUILD b/compiler/Transforms/Interpreter/test/BUILD new file mode 100644 index 0000000..fb38390 --- /dev/null +++ b/compiler/Transforms/Interpreter/test/BUILD
@@ -0,0 +1,16 @@ +# Tests for lowering MLIR in various dialects to IREE interpreter bytecode. + +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/Interpreter/test/clone.mlir b/compiler/Transforms/Interpreter/test/clone.mlir similarity index 100% rename from iree/compiler/Transforms/Interpreter/test/clone.mlir rename to compiler/Transforms/Interpreter/test/clone.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/make_executable_abi.mlir b/compiler/Transforms/Interpreter/test/make_executable_abi.mlir similarity index 100% rename from iree/compiler/Transforms/Interpreter/test/make_executable_abi.mlir rename to compiler/Transforms/Interpreter/test/make_executable_abi.mlir
diff --git a/compiler/Transforms/Interpreter/test/xla/BUILD b/compiler/Transforms/Interpreter/test/xla/BUILD new file mode 100644 index 0000000..32c7031 --- /dev/null +++ b/compiler/Transforms/Interpreter/test/xla/BUILD
@@ -0,0 +1,17 @@ +# Tests specific to lowering XLA to IREE. + +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + "///tools:iree-run-mlir", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/concat.mlir b/compiler/Transforms/Interpreter/test/xla/concat.mlir similarity index 100% rename from iree/compiler/Transforms/Interpreter/test/xla/concat.mlir rename to compiler/Transforms/Interpreter/test/xla/concat.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir b/compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir similarity index 100% rename from iree/compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir rename to compiler/Transforms/Interpreter/test/xla/dynamic_update_slice.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/gather.mlir b/compiler/Transforms/Interpreter/test/xla/gather.mlir similarity index 100% rename from iree/compiler/Transforms/Interpreter/test/xla/gather.mlir rename to compiler/Transforms/Interpreter/test/xla/gather.mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/slice.mlir b/compiler/Transforms/Interpreter/test/xla/slice.mlir similarity index 100% rename from iree/compiler/Transforms/Interpreter/test/xla/slice.mlir rename to compiler/Transforms/Interpreter/test/xla/slice.mlir
diff --git a/compiler/Transforms/LegalizeTypeStorage.cpp b/compiler/Transforms/LegalizeTypeStorage.cpp new file mode 100644 index 0000000..780fc97 --- /dev/null +++ b/compiler/Transforms/LegalizeTypeStorage.cpp
@@ -0,0 +1,145 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +bool convertOperation(Operation *oldOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + OperationState state(oldOp->getLoc(), oldOp->getName()); + if (oldOp->getNumSuccessors() == 0) { + // Non-branching operations can just add all the operands. + for (auto *oldOperand : oldOp->getOperands()) { + state.operands.push_back(mapping->lookupOrDefault(oldOperand)); + } + } else { + // We add the operands separated by nullptr's for each successor. + unsigned firstSuccOperand = oldOp->getNumSuccessors() + ? oldOp->getSuccessorOperandIndex(0) + : oldOp->getNumOperands(); + auto opOperands = oldOp->getOpOperands(); + unsigned i = 0; + for (; i != firstSuccOperand; ++i) { + state.operands.push_back(mapping->lookupOrDefault(opOperands[i].get())); + } + for (unsigned succ = 0, e = oldOp->getNumSuccessors(); succ != e; ++succ) { + state.successors.push_back( + mapping->lookupOrDefault(oldOp->getSuccessor(succ))); + // Add sentinel to delineate successor operands. + state.operands.push_back(nullptr); + // Remap the successors operands. + for (auto *operand : oldOp->getSuccessorOperands(succ)) { + state.operands.push_back(mapping->lookupOrDefault(operand)); + } + } + } + for (const auto &oldType : oldOp->getResultTypes()) { + state.types.push_back(legalizeType(oldType)); + } + state.attributes = {oldOp->getAttrs().begin(), oldOp->getAttrs().end()}; + auto newOp = builder.createOperation(state); + for (int i = 0; i < newOp->getNumResults(); ++i) { + mapping->map(oldOp->getResult(i), newOp->getResult(i)); + } + return false; +} + +bool convertFunction(FuncOp oldFunction, FuncOp newFunction) { + OpBuilder builder(newFunction.getBody()); + BlockAndValueMapping mapping; + + // Create new blocks matching the expected arguments of the old ones. + // This sets up the block mappings to enable us to reference blocks forward + // during conversion. + newFunction.getBlocks().clear(); + for (auto &oldBlock : oldFunction.getBlocks()) { + auto *newBlock = builder.createBlock(&newFunction.getBody()); + mapping.map(&oldBlock, newBlock); + for (auto *oldArg : oldBlock.getArguments()) { + auto *newArg = newBlock->addArgument(legalizeType(oldArg->getType())); + mapping.map(oldArg, newArg); + } + } + + // Convert all ops in the blocks. + for (auto &oldBlock : oldFunction.getBlocks()) { + builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock)); + for (auto &oldOp : oldBlock.getOperations()) { + if (convertOperation(&oldOp, builder, &mapping)) { + return true; + } + } + } + + return false; +} + +} // namespace + +class LegalizeTypeStoragePass : public ModulePass<LegalizeTypeStoragePass> { + public: + void runOnModule() override { + auto module = getModule(); + + // Build a list of (oldFunction, newFunction) for all functions we need to + // replace. This will ensure that when we go to convert function bodies we + // have only new functions defined. + std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions; + + for (auto oldFunction : module.getOps<FuncOp>()) { + // Create the replacement function, ensuring that we copy attributes. + auto newFunction = FuncOp::create( + oldFunction.getLoc(), oldFunction.getName(), + legalizeType(oldFunction.getType()).cast<FunctionType>(), + oldFunction.getDialectAttrs()); + convertedFunctions.push_back({oldFunction, newFunction}); + + // Perform the actual body conversion now that we have proper signatures. + if (convertFunction(oldFunction, newFunction)) { + return signalPassFailure(); + } + } + + // Replace functions in the module. + for (auto &pair : convertedFunctions) { + pair.first.erase(); + module.push_back(pair.second); + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeTypeStoragePass() { + return std::make_unique<LegalizeTypeStoragePass>(); +} + +static PassRegistration<LegalizeTypeStoragePass> pass( + "iree-legalize-type-storage", + "Legalizes types to ones supported by the IREE VM."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/LowerStdToIreeDialect.cpp b/compiler/Transforms/LowerStdToIreeDialect.cpp new file mode 100644 index 0000000..5fb5c2c --- /dev/null +++ b/compiler/Transforms/LowerStdToIreeDialect.cpp
@@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Ops.h" +#include "compiler/Utils/MemRefUtils.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct ConstantOpLowering : public OpRewritePattern<ConstantOp> { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ConstantOp op, + PatternRewriter &rewriter) const override { + if (auto elementsValue = op.getValue().dyn_cast<ElementsAttr>()) { + auto ireeConst = + rewriter.create<IREE::ConstantOp>(op.getLoc(), elementsValue); + + auto result = wrapAsTensor(ireeConst.getResult(), op, rewriter); + rewriter.replaceOp(op, result); + return matchSuccess(); + } + + auto type = op.getValue().getType(); + if (!type.isIntOrFloat()) { + return matchFailure(); + } + auto elementsValue = + DenseElementsAttr::get(RankedTensorType::get({}, type), op.getValue()); + auto ireeConst = + rewriter.create<IREE::ConstantOp>(op.getLoc(), elementsValue); + rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(op, ireeConst); + return matchSuccess(); + } +}; + +struct ExtractElementOpLowering : public OpRewritePattern<ExtractElementOp> { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ExtractElementOp op, + PatternRewriter &rewriter) const override { + Value *memRefInput = + wrapAsMemRef(loadAccessValue(op.getLoc(), op.getAggregate(), rewriter), + op, rewriter); + + SmallVector<Value *, 4> indices = {op.indices().begin(), + op.indices().end()}; + rewriter.replaceOpWithNewOp<LoadOp>(op, memRefInput, indices); + return matchSuccess(); + } +}; + +} // namespace + +void populateLowerStdToIreePatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<ConstantOpLowering, ExtractElementOpLowering>(ctx); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/LowerXLAToIreeDialect.cpp b/compiler/Transforms/LowerXLAToIreeDialect.cpp new file mode 100644 index 0000000..e719c42 --- /dev/null +++ b/compiler/Transforms/LowerXLAToIreeDialect.cpp
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/MemRefUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct ConstOpLowering : public OpRewritePattern<xla_hlo::ConstOp> { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(xla_hlo::ConstOp op, + PatternRewriter &rewriter) const override { + auto ireeConst = rewriter.create<IREE::ConstantOp>(op.getLoc(), op.value()); + rewriter.replaceOp(op, wrapAsTensor(ireeConst, op, rewriter)); + return matchSuccess(); + } +}; + +} // namespace + +void populateLowerXlaToIreePatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<ConstOpLowering>(ctx); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Transforms/Passes.h b/compiler/Transforms/Passes.h similarity index 100% rename from iree/compiler/Transforms/Passes.h rename to compiler/Transforms/Passes.h
diff --git a/iree/compiler/Transforms/Rewrites.h b/compiler/Transforms/Rewrites.h similarity index 100% rename from iree/compiler/Transforms/Rewrites.h rename to compiler/Transforms/Rewrites.h
diff --git a/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp b/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp new file mode 100644 index 0000000..f64f313 --- /dev/null +++ b/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp
@@ -0,0 +1,75 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/OpUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +class AssignExecutableOrdinalsPass + : public ModulePass<AssignExecutableOrdinalsPass> { + public: + void runOnModule() override { + Builder builder(getModule()); + int nextExecutableOrdinal = 0; + for (auto multiArchExecutableOp : + getModule().getOps<IREE::MultiArchExecutableOp>()) { + multiArchExecutableOp.setAttr( + "iree.ordinal", builder.getI32IntegerAttr(nextExecutableOrdinal++)); + + // We'll scan for all entry points in the first executable. Then on all + // other executables we can reuse the ordinals (ensuring that iteration + // order does not matter). + llvm::DenseMap<StringRef, FuncOp> entryPointMap; + for (auto executableOp : + multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) { + executableOp.setAttr("iree.ordinal", + multiArchExecutableOp.getAttr("iree.ordinal")); + int nextEntryPointOrdinal = 0; + for (auto funcOp : executableOp.getInnerModule().getOps<FuncOp>()) { + if (!funcOp.getAttr("iree.executable.export")) continue; + auto it = entryPointMap.find(funcOp.getName()); + if (it == entryPointMap.end()) { + funcOp.setAttr("iree.ordinal", + builder.getI32IntegerAttr(nextEntryPointOrdinal++)); + entryPointMap.insert({funcOp.getName(), funcOp}); + } else { + funcOp.setAttr("iree.ordinal", it->second.getAttr("iree.ordinal")); + } + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableOrdinalsPass() { + return std::make_unique<AssignExecutableOrdinalsPass>(); +} + +static PassRegistration<AssignExecutableOrdinalsPass> pass( + "iree-assign-executable-ordinals", + "Assigns executable and entry point ordinals"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp b/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp new file mode 100644 index 0000000..20d51a2 --- /dev/null +++ b/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp
@@ -0,0 +1,125 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Sequencer/LLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/OpUtils.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +struct WorkloadInfo { + SmallVector<ElementsAttr, 4> staticWorkloads; + SmallVector<Value *, 4> dynamicWorkloads; +}; + +// Finds all dispatches and records their workload attributes mapped by +// (executable ordinal, entry point ordinal). +llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos( + ModuleOp moduleOp) { + llvm::StringMap<llvm::StringMap<WorkloadInfo>> workloadInfos; + for (auto funcOp : moduleOp.getOps<FuncOp>()) { + funcOp.walk([&](IREESeq::LL::DynamicDispatchOp op) { + auto &workloadInfo = + workloadInfos[op.getExecutable()][op.getEntryPoint()]; + workloadInfo.dynamicWorkloads.push_back(op.getWorkload()); + }); + funcOp.walk([&](IREESeq::LL::StaticDispatchOp op) { + auto &workloadInfo = + workloadInfos[op.getExecutable()][op.getEntryPoint()]; + for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) { + if (existingWorkloadAttr == op.getWorkload()) { + return; // Already present, ignore. + } + } + workloadInfo.staticWorkloads.push_back(op.getWorkload()); + }); + } + return workloadInfos; +} + +// Adds attributes to the given executable entry point describing the workload +// info to the backends that will be processing them. +LogicalResult attributeExecutableEntryPointWorkload( + FuncOp entryPointOp, const WorkloadInfo &workloadInfo) { + if (!workloadInfo.dynamicWorkloads.empty()) { + return entryPointOp.emitError() << "Dynamic workloads not yet supported"; + } + if (workloadInfo.staticWorkloads.size() != 1) { + return entryPointOp.emitError() << "Static workload sizes differ in shape"; + } + + // Easy because we just support static workloads now. + // When this code is adapted to support dynamic workloads we'll want to put + // a pair of attrs describing which dimensions may be static and which args + // have the dynamic values to reference. + entryPointOp.setAttr("iree.executable.workload", + workloadInfo.staticWorkloads.front()); + + return success(); +} + +} // namespace + +class AssignExecutableWorkloadAttrsPass + : public ModulePass<AssignExecutableWorkloadAttrsPass> { + public: + void runOnModule() override { + Builder builder(getModule()); + + // Find all dispatches and capture their workload information. + // We store this information by executable and then entry point ordinal. + auto executableWorkloadInfos = gatherExecutableWorkloadInfos(getModule()); + + // Process each executable with the workload information. + for (auto &executableIt : executableWorkloadInfos) { + auto multiArchExecutableOp = cast<IREE::MultiArchExecutableOp>( + getModule().lookupSymbol(executableIt.first())); + for (auto executableOp : + multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) { + for (auto &entryPointIt : executableIt.second) { + auto funcOp = cast<FuncOp>( + executableOp.getInnerModule().lookupSymbol(entryPointIt.first())); + if (failed(attributeExecutableEntryPointWorkload( + funcOp, entryPointIt.second))) { + return signalPassFailure(); + } + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> +createAssignExecutableWorkloadAttrsPass() { + return std::make_unique<AssignExecutableWorkloadAttrsPass>(); +} + +static PassRegistration<AssignExecutableWorkloadAttrsPass> pass( + "iree-assign-executable-workload-attrs", + "Assigns executable entrypoint workload attributes"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/BUILD b/compiler/Transforms/Sequencer/BUILD new file mode 100644 index 0000000..a86088b --- /dev/null +++ b/compiler/Transforms/Sequencer/BUILD
@@ -0,0 +1,46 @@ +# Transforms specific to the IREE sequencer. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Sequencer", + srcs = [ + "AssignExecutableOrdinals.cpp", + "AssignExecutableWorkloadAttrs.cpp", + "FoldCompatibleDispatchRegions.cpp", + "IdentifyDispatchRegions.cpp", + "IdentifyReductionRegions.cpp", + "LegalizeInputs.cpp", + "LowerSequencerDialect.cpp", + "LowerStdToSequencerDialect.cpp", + "LowerToSequencerDialect.cpp", + "LowerXLAToSequencerDialect.cpp", + "OutlineDispatchRegions.cpp", + "OutlineReductionRegions.cpp", + "RematerializeDispatchConstants.cpp", + ], + hdrs = [ + "Passes.h", + "Rewrites.h", + ], + deps = [ + "///compiler/IR", + "///compiler/IR/Sequencer", + "///compiler/Transforms", + "///compiler/Utils", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot", + ], +)
diff --git a/iree/compiler/Transforms/Sequencer/CMakeLists.txt b/compiler/Transforms/Sequencer/CMakeLists.txt similarity index 100% rename from iree/compiler/Transforms/Sequencer/CMakeLists.txt rename to compiler/Transforms/Sequencer/CMakeLists.txt
diff --git a/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp b/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp new file mode 100644 index 0000000..a48f735 --- /dev/null +++ b/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp
@@ -0,0 +1,63 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/DispatchUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +// Identifies dispatch regions that have compatible workloads and folds them. +// This relies on CSE having deduped workloads to simplify the logic to simply +// looking for dispatch regions using the same values. +class FoldCompatibleDispatchRegionsPass + : public FunctionPass<FoldCompatibleDispatchRegionsPass> { + public: + void runOnFunction() override { + auto func = getFunction(); + for (auto &block : func) { + if (failed(mergeBlockDispatchRegions(func, &block))) { + return signalPassFailure(); + } + } + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createFoldCompatibleDispatchRegionsPass() { + return std::make_unique<FoldCompatibleDispatchRegionsPass>(); +} + +static PassRegistration<FoldCompatibleDispatchRegionsPass> pass( + "iree-fold-compatible-dispatch-regions", + "Folds dispatch regions that have compatible workloads."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp b/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp new file mode 100644 index 0000000..54ce60e --- /dev/null +++ b/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp
@@ -0,0 +1,259 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/DispatchUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Returns true if the given |op| can be dispatched in all cases. +// Other passes may handle special cases of these ops but this initial +// identification is conservative. +bool isDispatchableOp(Operation *op) { + if (op->getDialect() && op->getDialect()->getNamespace().startswith("iree")) { + // Ignore things we've already produced as they should only relate to + // sequencer operations. + return false; + } else if (op->isKnownTerminator()) { + // Currently we skip all terminators as we want to leave them in the block + // to keep it valid. Future folding passes may take care of them if they are + // worth bringing into the dispatch region. + return false; + } else if (isa<CallOp>(op)) { + // This may be handled by a control-flow folding pass later once we have + // done our initial analysis and know what functions are compatible. + return false; + } else if (isa<CallIndirectOp>(op)) { + // Indirect calls are not supported in dispatch code. + return false; + } else if (isa<AllocOp>(op)) { + // Allocations are sequencer ops. + // Note that we could support static allocations (convert to stack/etc). + return false; + } else if (isa<ConstantOp>(op)) { + // Constants are handled in the RematerializeDispatchConstants pass. + // We do that independently so that we can more easily see the use of + // constants across all dispatches instead of just on an individual basis + // as we do here. + return false; + } else if (isa<xla_hlo::DynamicUpdateSliceOp>(op)) { + // TODO(benvanik): lower these to the sequencer dialect prior to ID'ing. + return false; + } + return true; +} + +// Returns true if the given |op| can have other ops fused into it. +// This is sketchy and it'd be nice to define this as an op property instead. +// +// What we are looking for in foldable ops is whether the execution of the op +// when fused has some possible benefit (or at least, a non-negative cost). +// Eventually we want to allow backends to vote on this and allow multiple +// folding strategies within the same executable. For now we just hardcode what +// we know for the ops we have. +// +// Preconditions: isDispatchableOp(op) == true. +bool isFusionRootOp(Operation *op) { + if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + // We have hand-written kernels for these right now we want to stand alone. + // When we do a bit more magic we should allow these ops to fold. + return false; + } + return true; +} + +// Returns true if the given |op| can be fused into other ops. +// +// Ops that perform narrowing on shapes (such as reduction ops) should not +// generally be fused with other downstream ops (probably...). This avoids +// potential oversampling and indexing issues and allows backends to perform +// more efficient rooted cascading reduction dispatches. +// +// Preconditions: isDispatchableOp(op) == true. +bool isFusableOp(Operation *op) { + if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + return false; + } else if (isa<xla_hlo::ReduceOp>(op)) { + // Reduction is usually a dedicated root operation - we can shove things in + // the front of it but not behind. + return false; + } + return true; +} + +// Puts all of the |unsortedOps| into |sortedOps| in an arbitrary topological +// order. +// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search +// +// Preconditions: |unsortedOps| has no cycles within the set of ops. +std::vector<Operation *> sortOpsTopologically( + const llvm::SetVector<Operation *> &unsortedOps) { + llvm::SetVector<Operation *> unmarkedOps; + unmarkedOps.insert(unsortedOps.begin(), unsortedOps.end()); + llvm::SetVector<Operation *> markedOps; + + using VisitFn = std::function<void(Operation * op)>; + VisitFn visit = [&](Operation *op) { + if (markedOps.count(op) > 0) return; + for (auto *result : op->getResults()) { + for (auto *user : result->getUsers()) { + // Don't visit ops not in our set. + if (unsortedOps.count(user) == 0) continue; + visit(user); + } + } + markedOps.insert(op); + }; + + while (!unmarkedOps.empty()) { + auto *op = unmarkedOps.pop_back_val(); + visit(op); + } + + auto sortedOps = markedOps.takeVector(); + std::reverse(sortedOps.begin(), sortedOps.end()); + return sortedOps; +} + +// Recursively traverses the IR DAG along the operand edges to find ops we are +// able to fuse and appends them to |subgraph|. +void gatherFusionOps(Operation *op, llvm::SetVector<Operation *> *subgraph) { + // Skip ops that are used outside of the subgraph we are building. + for (auto *result : op->getResults()) { + if (result->use_empty() || result->hasOneUse()) continue; + for (auto *user : result->getUsers()) { + if (subgraph->count(user) == 0) { + // Op that consumes the result is not (yet) in the subgraph. + // For now we'll ignore these as it may represent a fork that we don't + // want to join too early. + return; + } + } + } + + // Walk backward up to ops providing our input operands. + for (auto *operand : op->getOperands()) { + auto *sourceOp = operand->getDefiningOp(); + if (!sourceOp) continue; + if (subgraph->count(sourceOp) == 0) { + if (isDispatchableOp(sourceOp) && isFusableOp(sourceOp)) { + gatherFusionOps(sourceOp, subgraph); + } + } + } + + subgraph->insert(op); +} + +// Finds all ops that can be fused together with the given |rootOp| by searching +// backwards in the op order through input edges. +// Returns a topologically sorted list of all fused ops with |rootOp| at the +// end. +std::vector<Operation *> findFusionSubgraphFromRoot(Operation *rootOp) { + if (!isFusionRootOp(rootOp)) { + return {rootOp}; + } + llvm::SetVector<Operation *> subgraph; + subgraph.insert(rootOp); + gatherFusionOps(rootOp, &subgraph); + return sortOpsTopologically(subgraph); +} + +// Identifies ranges of dispatchable ops and moves them into dispatch regions. +LogicalResult identifyBlockDispatchRegions(FuncOp func, Block *block) { + // Fixed point iteration until we can no longer fuse anything. + bool didFindAnyNewRegions; + do { + // Iterate in reverse so we root further along in the op list. + didFindAnyNewRegions = false; + for (auto &rootOp : llvm::reverse(*block)) { + if (!isDispatchableOp(&rootOp)) { + // Op should remain at the sequencer level. + continue; + } + + // Attempt to find all operations, including rootOp, that can be fused. + // The ops will be sorted in topological order with rootOp as the last op. + // Worst case we may end up with a subgraph of only the rootOp. + auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp); + + // Compute the workload based on the output shape. + // When variadic all output shapes match so we can just take the first. + auto *workload = calculateWorkload(&rootOp, rootOp.getResult(0)); + + // Try to build a dispatch region from this root. + if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) { + return failure(); + } + + // Successfully created a dispatch region from the ops and we must now + // start over again as we've likely trashed the whole block structure. + didFindAnyNewRegions = true; + break; + } + } while (didFindAnyNewRegions); + return success(); +} + +} // namespace + +// Identifies dispatchable ops and moves them into iree.dispatch_regions. +// Some ops, such as call, will be deferred until following passes. +class IdentifyDispatchRegionsPass + : public FunctionPass<IdentifyDispatchRegionsPass> { + public: + void runOnFunction() override { + auto func = getFunction(); + for (auto &block : func) { + if (failed(identifyBlockDispatchRegions(func, &block))) { + return signalPassFailure(); + } + } + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass() { + return std::make_unique<IdentifyDispatchRegionsPass>(); +} + +static PassRegistration<IdentifyDispatchRegionsPass> pass( + "iree-identify-dispatch-regions", + "Conservatively identifies dispatch regions in functions."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp b/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp new file mode 100644 index 0000000..6c9a48a --- /dev/null +++ b/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp
@@ -0,0 +1,163 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "compiler/IR/Ops.h" +#include "compiler/IR/Types.h" +#include "compiler/Utils/DispatchUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Builds a new iree.reduction_region with the given |invocationRegion|. +// The new region will be inserted after |originalOp|. +// +// All |invocationRegion| ops must be compatible with the |workload| specified +// as they will all be dispatched with the same workgroup structure. The +// |invocationRegion| will not be modified. +LogicalResult buildReductionRegion(Operation *originalOp, + ArrayRef<Value *> operands, + ArrayRef<Value *> initialValues, + ArrayRef<int64_t> dimensions, + Region &invocationRegion) { + OpBuilder parentBuilder(originalOp); + + // Compute the workload based on the output shape. + // When variadic all output shapes match so we can just take the first. + auto *workload = calculateWorkload(originalOp, originalOp->getResult(0)); + + // Build the region op and add it to the parent block. + SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()}; + auto reductionRegionOp = parentBuilder.create<IREE::ReductionRegionOp>( + originalOp->getLoc(), resultTypes, workload, operands, initialValues, + dimensions); + + // Create the block and setup the arg mapping for captured values. + BlockAndValueMapping mapping; + invocationRegion.cloneInto(&reductionRegionOp.getBody(), mapping); + + // Replace xla_hlo.return -> iree.return. + OpBuilder regionBuilder(reductionRegionOp.getBody()); + reductionRegionOp.walk([&](xla_hlo::ReturnOp returnOp) { + regionBuilder.setInsertionPoint(returnOp); + SmallVector<Value *, 4> returnValues(returnOp.getOperands()); + regionBuilder.create<IREE::ReturnOp>(returnOp.getLoc(), returnValues); + returnOp.erase(); + }); + + // Replace usage of values with the results of the region. + for (int i = 0; i < originalOp->getNumResults(); ++i) { + originalOp->getResult(i)->replaceAllUsesWith( + reductionRegionOp.getResult(i)); + } + + return success(); +} + +// Converts an xla_hlo::ReduceOp to a reduction region and inlines the target +// computation into the region body. +LogicalResult buildReductionRegionFromXLAReduceOp(xla_hlo::ReduceOp reduceOp) { + SmallVector<Value *, 4> operands(reduceOp.getOperands()); + OperandAdaptor<xla_hlo::ReduceOp> adaptor(operands); + + SmallVector<int64_t, 4> dimensions; + for (auto dim : reduceOp.dimensions().getIntValues()) { + dimensions.push_back(dim.getSExtValue()); + } + + // Create the iree.reduction_region. + if (failed(buildReductionRegion(reduceOp, adaptor.operands(), + adaptor.init_values(), dimensions, + reduceOp.body()))) { + return failure(); + } + + // Remove original XLA reduction op. + reduceOp.erase(); + + return success(); +} + +// Identifies reduction ops and moves them into reduction regions. +LogicalResult identifyBlockReductionRegions(FuncOp funcOp, Block *block) { + // Fixed point iteration until we can no longer fuse anything. + bool didFindAnyNewRegions; + do { + // Iterate in reverse so we root further along in the op list. + didFindAnyNewRegions = false; + for (auto &rootOp : llvm::reverse(*block)) { + if (auto reduceOp = dyn_cast<xla_hlo::ReduceOp>(rootOp)) { + if (failed(buildReductionRegionFromXLAReduceOp(reduceOp))) { + return failure(); + } + + // Successfully created a dispatch region from the ops and we must now + // start over again as we've likely trashed the whole block structure. + didFindAnyNewRegions = true; + break; + } + } + } while (didFindAnyNewRegions); + return success(); +} + +} // namespace + +// Identifies reduction ops and moves their targets into iree.reduction_regions. +class IdentifyReductionRegionsPass + : public ModulePass<IdentifyReductionRegionsPass> { + public: + void runOnModule() override { + for (auto funcOp : getModule().getOps<FuncOp>()) { + for (auto &block : funcOp) { + if (failed(identifyBlockReductionRegions(funcOp, &block))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass() { + return std::make_unique<IdentifyReductionRegionsPass>(); // NOLINT +} + +static PassRegistration<IdentifyReductionRegionsPass> pass( + "iree-identify-reduction-regions", + "Identifies reduction regions based on input reduction ops."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LegalizeInputs.cpp b/compiler/Transforms/Sequencer/LegalizeInputs.cpp new file mode 100644 index 0000000..b8c860c --- /dev/null +++ b/compiler/Transforms/Sequencer/LegalizeInputs.cpp
@@ -0,0 +1,53 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/Transforms/Rewrites.h" +#include "compiler/Transforms/Sequencer/Rewrites.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace iree_compiler { +namespace { + +class LegalizeInputOpsPass + : public FunctionPass<LegalizeInputOpsPass> { + public: + void runOnFunction() override { + OwningRewritePatternList patterns; + xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect<xla_hlo::XlaHloDialect, StandardOpsDialect>(); + target.addLegalOp<FuncOp, ReturnOp>(); + target.addIllegalOp<xla_hlo::DotGeneralOp, xla_hlo::WhileOp>(); + if (failed(applyFullConversion(getFunction(), target, patterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> createLegalizeInputOpsPass() { + return std::make_unique<LegalizeInputOpsPass>(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp new file mode 100644 index 0000000..63408b7 --- /dev/null +++ b/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp
@@ -0,0 +1,305 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Ops.h" +#include "compiler/IR/Sequencer/HLDialect.h" +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/Sequencer/LLDialect.h" +#include "compiler/IR/Sequencer/LLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +template <typename SrcOp> +class SequencerLoweringPattern : public OpConversionPattern<SrcOp> { + public: + SequencerLoweringPattern(MLIRContext *context, TypeConverter &typeConverter) + : OpConversionPattern<SrcOp>(context), typeConverter_(typeConverter) {} + + protected: + TypeConverter &typeConverter_; +}; + +// Returns an integer scalar memref containing the offset specified by |indices| +// within |type|. +Value *computeOffset(Location loc, Value *reference, Value *indices, + OpBuilder &builder) { + auto referenceType = reference->getType().cast<ShapedType>(); + auto *shapeMemRef = builder + .create<IREESeq::LL::AllocHeapOp>( + loc, + MemRefType::get({referenceType.getRank()}, + builder.getIntegerType(32)), + ArrayRef<Value *>{}) + .getResult(); + builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef); + auto *resultMemRef = + builder + .create<IREESeq::LL::AllocHeapOp>( + loc, MemRefType::get({}, builder.getIntegerType(32)), + ArrayRef<Value *>{}) + .getResult(); + auto elementSizeAttr = builder.getIntegerAttr( + builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8); + builder.create<IREESeq::LL::ComputeOffsetOp>( + loc, shapeMemRef, elementSizeAttr, indices, resultMemRef); + return resultMemRef; +} + +// Returns a tuple of (offset, length) integer scalar memrefs with the range +// specified by |indices| and |lengths| within |type|. +std::pair<Value *, Value *> computeRange(Location loc, Value *reference, + Value *indices, Value *lengths, + OpBuilder &builder) { + auto referenceType = reference->getType().cast<ShapedType>(); + auto *shapeMemRef = builder + .create<IREESeq::LL::AllocHeapOp>( + loc, + MemRefType::get({referenceType.getRank()}, + builder.getIntegerType(32)), + ArrayRef<Value *>{}) + .getResult(); + builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef); + auto *offsetMemRef = + builder + .create<IREESeq::LL::AllocHeapOp>( + loc, MemRefType::get({}, builder.getIntegerType(32)), + ArrayRef<Value *>{}) + .getResult(); + auto *lengthMemRef = + builder + .create<IREESeq::LL::AllocHeapOp>( + loc, MemRefType::get({}, builder.getIntegerType(32)), + ArrayRef<Value *>{}) + .getResult(); + auto elementSizeAttr = builder.getIntegerAttr( + builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8); + builder.create<IREESeq::LL::ComputeRangeOp>(loc, shapeMemRef, elementSizeAttr, + indices, lengths, offsetMemRef, + lengthMemRef); + return {offsetMemRef, lengthMemRef}; +} + +struct LowerSliceOpPattern + : public SequencerLoweringPattern<IREESeq::HL::SliceOp> { + using SequencerLoweringPattern::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + IREESeq::HL::SliceOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + OperandAdaptor<IREESeq::HL::SliceOp> operandAdaptor(operands); + auto range = computeRange(op.getLoc(), operandAdaptor.src(), + operandAdaptor.indices(), + operandAdaptor.lengths(), rewriter); + rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicSliceOp>( + op, typeConverter_.convertType(op.getType()), + ArrayRef<Value *>{operandAdaptor.src(), range.first, range.second}, + op.getAttrs()); + return matchSuccess(); + } +}; + +struct LowerShapeOpPattern + : public SequencerLoweringPattern<IREESeq::HL::ShapeOp> { + using SequencerLoweringPattern::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + IREESeq::HL::ShapeOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto *shapeMemRef = + rewriter + .create<IREESeq::LL::AllocHeapOp>( + op.getLoc(), + MemRefType::get({op.getType().cast<ShapedType>().getRank()}, + rewriter.getIntegerType(64)), + ArrayRef<Value *>{}) + .getResult(); + op.replaceAllUsesWith(shapeMemRef); + rewriter.replaceOpWithNewOp<IREESeq::LL::ShapeOp>(op, operands[0], + shapeMemRef); + return matchSuccess(); + } +}; + +struct LowerCopyOpPattern + : public SequencerLoweringPattern<IREESeq::HL::CopyOp> { + using SequencerLoweringPattern::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + IREESeq::HL::CopyOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + OperandAdaptor<IREESeq::HL::CopyOp> operandAdaptor(operands); + auto *srcOffsetMemRef = + computeOffset(op.getLoc(), operandAdaptor.src(), + operandAdaptor.srcIndices(), rewriter); + auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(), + operandAdaptor.dstIndices(), + operandAdaptor.lengths(), rewriter); + rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicCopyOp>( + op, operandAdaptor.src(), srcOffsetMemRef, operandAdaptor.dst(), + dstRange.first, dstRange.second); + return matchSuccess(); + } +}; + +struct LowerFillOpPattern + : public SequencerLoweringPattern<IREESeq::HL::FillOp> { + using SequencerLoweringPattern::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + IREESeq::HL::FillOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + OperandAdaptor<IREESeq::HL::FillOp> operandAdaptor(operands); + auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(), + operandAdaptor.dstIndices(), + operandAdaptor.lengths(), rewriter); + rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicFillOp>( + op, operandAdaptor.value(), operandAdaptor.dst(), dstRange.first, + dstRange.second); + return matchSuccess(); + } +}; + +struct LowerBranchOpPattern + : public SequencerLoweringPattern<IREESeq::HL::BranchOp> { + using SequencerLoweringPattern< + IREESeq::HL::BranchOp>::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + IREESeq::HL::BranchOp op, ArrayRef<Value *> properOperands, + ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREESeq::LL::BranchOp>(op, destinations[0], + operands[0]); + return matchSuccess(); + } +}; + +struct LowerCondCondBranchOpPattern + : public SequencerLoweringPattern<IREESeq::HL::CondBranchOp> { + using SequencerLoweringPattern< + IREESeq::HL::CondBranchOp>::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + IREESeq::HL::CondBranchOp op, ArrayRef<Value *> properOperands, + ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREESeq::LL::CondBranchOp>( + op, properOperands[0], + destinations[IREESeq::HL::CondBranchOp::trueIndex], + operands[IREESeq::HL::CondBranchOp::trueIndex], + destinations[IREESeq::HL::CondBranchOp::falseIndex], + operands[IREESeq::HL::CondBranchOp::falseIndex]); + return matchSuccess(); + } +}; + +// Rewrites an op into one with all the same operands, results, and attributes. +// Operands and results in the ops must have the same order and attributes must +// have the same name. They must also be constructed properly by the default +// builders. +template <typename SRC, typename DST> +struct LowerIdenticalOpPattern : public SequencerLoweringPattern<SRC> { + using SequencerLoweringPattern<SRC>::SequencerLoweringPattern; + + PatternMatchResult matchAndRewrite( + SRC op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Type, 8> originalResultTypes{ + op.getOperation()->getResultTypes()}; + SmallVector<Type, 8> resultTypes; + if (failed(this->typeConverter_.convertTypes(originalResultTypes, + resultTypes))) { + op.emitOpError() << "Failed to convert result types"; + return this->matchFailure(); + } + rewriter.replaceOpWithNewOp<DST>(op, resultTypes, operands, op.getAttrs()); + return this->matchSuccess(); + } +}; + +} // namespace + +class LowerSequencerDialectPass : public ModulePass<LowerSequencerDialectPass> { + public: + void runOnModule() override { + auto *ctx = &getContext(); + LLTypeConverter typeConverter(ctx); + OwningRewritePatternList patterns; + patterns.insert< + LowerIdenticalOpPattern<IREE::ConstantOp, IREESeq::LL::ConstantOp>, + LowerIdenticalOpPattern<IREESeq::HL::DispatchOp, + IREESeq::LL::DynamicDispatchOp>, + LowerShapeOpPattern, LowerCopyOpPattern, LowerSliceOpPattern, + LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx, typeConverter); +#define IDENTICAL_OP_LOWERING(op_name) \ + LowerIdenticalOpPattern<IREESeq::HL::op_name, IREESeq::LL::op_name> + patterns.insert< + IDENTICAL_OP_LOWERING(AllocHeapOp), IDENTICAL_OP_LOWERING(CloneOp), + IDENTICAL_OP_LOWERING(ReshapeOp), IDENTICAL_OP_LOWERING(CallOp), + IDENTICAL_OP_LOWERING(ReturnOp)>(ctx, typeConverter); +#undef IDENTICAL_OP_LOWERING + + mlir::populateFuncOpTypeConversionPattern(patterns, ctx, typeConverter); + ConversionTarget target(*ctx); + target.addLegalDialect<IREELLSequencerDialect>(); + target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()); + }); + + // TODO(b/142791494): The conversion framework will recurse into the + // executable if we just call it on the top-level module. This can't be a + // function pass because type conversion replaces the original functions. + auto funcsIt = getModule().getOps<FuncOp>(); + SmallVector<Operation *, 4> funcs(funcsIt.begin(), funcsIt.end()); + + if (failed(applyFullConversion(funcs, target, patterns, &typeConverter))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createLowerSequencerDialectPass() { + return std::make_unique<LowerSequencerDialectPass>(); +} + +static PassRegistration<LowerSequencerDialectPass> pass( + "iree-lower-sequencer-dialect", + "Lowers the IREE HL sequencer dialect to the LL sequencer dialect."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp new file mode 100644 index 0000000..e3f760d --- /dev/null +++ b/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp
@@ -0,0 +1,204 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Ops.h" +#include "compiler/IR/Sequencer/HLDialect.h" +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +template <typename T> +class SequencerConversionPattern : public OpConversionPattern<T> { + using OpConversionPattern<T>::OpConversionPattern; +}; + +struct CallOpLowering : public SequencerConversionPattern<CallOp> { + using SequencerConversionPattern::SequencerConversionPattern; + + PatternMatchResult matchAndRewrite( + CallOp callOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Type, 4> resultTypes(callOp.getResultTypes()); + rewriter.replaceOpWithNewOp<IREESeq::HL::CallOp>(callOp, callOp.getCallee(), + resultTypes, operands); + + return matchSuccess(); + } +}; + +struct CallIndirectOpLowering + : public SequencerConversionPattern<CallIndirectOp> { + using SequencerConversionPattern::SequencerConversionPattern; + + PatternMatchResult matchAndRewrite( + CallIndirectOp callOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREESeq::HL::CallIndirectOp>( + callOp, callOp.getCallee(), operands); + return matchSuccess(); + } +}; + +struct ReturnOpLowering : public SequencerConversionPattern<ReturnOp> { + using SequencerConversionPattern::SequencerConversionPattern; + + PatternMatchResult matchAndRewrite( + ReturnOp returnOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + SmallVector<Value *, 4> newOperands; + newOperands.reserve(operands.size()); + for (auto *operand : operands) { + newOperands.push_back(wrapAsMemRef(operand, returnOp, rewriter)); + } + rewriter.replaceOpWithNewOp<IREESeq::HL::ReturnOp>(returnOp, newOperands); + return matchSuccess(); + } +}; + +struct BranchOpLowering : public SequencerConversionPattern<BranchOp> { + using SequencerConversionPattern::SequencerConversionPattern; + PatternMatchResult matchAndRewrite( + BranchOp branchOp, ArrayRef<Value *> properOperands, + ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREESeq::HL::BranchOp>( + branchOp, destinations[0], operands[0]); + return this->matchSuccess(); + } +}; + +struct CondBranchOpLowering : public SequencerConversionPattern<CondBranchOp> { + using SequencerConversionPattern::SequencerConversionPattern; + PatternMatchResult matchAndRewrite( + CondBranchOp condBranchOp, ArrayRef<Value *> properOperands, + ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, + ConversionPatternRewriter &rewriter) const override { + auto *condValue = + loadAccessValue(condBranchOp.getLoc(), properOperands[0], rewriter); + rewriter.replaceOpWithNewOp<IREESeq::HL::CondBranchOp>( + condBranchOp, condValue, + destinations[IREESeq::HL::CondBranchOp::trueIndex], + operands[IREESeq::HL::CondBranchOp::trueIndex], + destinations[IREESeq::HL::CondBranchOp::falseIndex], + operands[IREESeq::HL::CondBranchOp::falseIndex]); + return this->matchSuccess(); + } +}; + +struct AllocOpLowering : public SequencerConversionPattern<AllocOp> { + using SequencerConversionPattern::SequencerConversionPattern; + PatternMatchResult matchAndRewrite( + AllocOp allocOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + // TODO(benvanik): replace with length computation. + rewriter.replaceOpWithNewOp<IREESeq::HL::AllocHeapOp>( + allocOp, allocOp.getType(), operands); + return matchSuccess(); + } +}; + +struct DeallocOpLowering : public SequencerConversionPattern<DeallocOp> { + using SequencerConversionPattern::SequencerConversionPattern; + PatternMatchResult matchAndRewrite( + DeallocOp deallocOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp<IREESeq::HL::DiscardOp>(deallocOp, operands[0]); + return matchSuccess(); + } +}; + +struct LoadOpLowering : public SequencerConversionPattern<LoadOp> { + using SequencerConversionPattern::SequencerConversionPattern; + PatternMatchResult matchAndRewrite( + LoadOp loadOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + if (loadOp.getMemRefType().getRank() != 0) { + loadOp.emitError() << "Cannot lower load of non-scalar"; + return matchFailure(); + } + ArrayRef<Value *> dimPieces; + auto dst = rewriter.create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(), + dimPieces); + auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {}); + rewriter.create<IREESeq::HL::CopyOp>(loadOp.getLoc(), loadOp.getMemRef(), + /*srcIndices=*/emptyArrayMemref, dst, + /*dstIndices=*/emptyArrayMemref, + /*lengths=*/emptyArrayMemref); + + rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst); + + return matchSuccess(); + } +}; + +struct StoreOpLowering : public SequencerConversionPattern<StoreOp> { + using SequencerConversionPattern::SequencerConversionPattern; + PatternMatchResult matchAndRewrite( + StoreOp storeOp, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + if (storeOp.getMemRefType().getRank() != 0) { + storeOp.emitError() << "Cannot lower store of non-scalar"; + return matchFailure(); + } + + auto src = rewriter.create<IREE::ScalarToMemRefOp>( + storeOp.getLoc(), storeOp.getValueToStore()); + + auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {}); + rewriter.replaceOpWithNewOp<IREESeq::HL::CopyOp>( + storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(), + /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref); + + return matchSuccess(); + } +}; + +} // namespace + +void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + patterns.insert< + // Control flow. + CallOpLowering, CallIndirectOpLowering, ReturnOpLowering, + BranchOpLowering, CondBranchOpLowering, + // Memory management. + AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering>( + context); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp new file mode 100644 index 0000000..720d259 --- /dev/null +++ b/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp
@@ -0,0 +1,62 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Sequencer/HLDialect.h" +#include "compiler/IR/Sequencer/LLDialect.h" +#include "compiler/Transforms/Rewrites.h" +#include "compiler/Transforms/Sequencer/Rewrites.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace iree_compiler { +namespace { + +class LowerToSequencerDialectPass + : public FunctionPass<LowerToSequencerDialectPass> { + public: + void runOnFunction() override { + OwningRewritePatternList patterns; + auto* ctx = &getContext(); + xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx); + xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx); + populateLowerStdToIreePatterns(patterns, ctx); + populateLowerStdToSequencerPatterns(patterns, ctx); + populateLowerXlaToIreePatterns(patterns, ctx); + populateLowerXlaToSequencerPatterns(patterns, ctx); + + ConversionTarget target(getContext()); + target.addLegalDialect<IREEHLSequencerDialect, IREEDialect>(); + target.addLegalOp<FuncOp>(); + if (failed(applyFullConversion(getFunction(), target, patterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr<OpPassBase<FuncOp>> createLowerToSequencerDialectPass() { + return std::make_unique<LowerToSequencerDialectPass>(); +} + +static PassRegistration<LowerToSequencerDialectPass> pass( + "lower-to-iree-sequencer", "Convert all ops to the IREE sequencer dialect"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp b/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp new file mode 100644 index 0000000..5626d6c --- /dev/null +++ b/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp
@@ -0,0 +1,224 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/IR/Dialect.h" +#include "compiler/IR/Ops.h" +#include "compiler/IR/Sequencer/HLDialect.h" +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Transforms/ConversionUtils.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/OpCreationUtils.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// TODO(suderman): tablegen this? or something a bit more flexible. + +#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \ + struct XlaOpType##Lowering \ + : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ + using UnaryOpLowering::UnaryOpLowering; \ + }; + +#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \ + struct XlaOpType##Lowering \ + : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ + using TernaryOpLowering::TernaryOpLowering; \ + }; + +UNARY_OP_LOWERING(CopyOp, IREESeq::HL::CloneOp); + +#undef UNARY_OP_LOWERING +#undef TERNARY_OP_LOWERING + +template <typename T> +static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter, + Location loc, Value *input, + MemRefType targetType) { + auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape()); + return rewriter.create<T>(loc, targetType, input, shapeOp); +} + +static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op, + Value *tensor) { + return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op, + rewriter); +} + +template <typename SrcOp> +class XlaOpLowering : public OpConversionPattern<SrcOp> { + public: + using OpConversionPattern<SrcOp>::OpConversionPattern; + + PatternMatchResult matchAndRewrite( + SrcOp op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto srcOp = cast<SrcOp>(op); + + SmallVector<Value *, 4> memrefOperands; + for (auto operand : operands) { + memrefOperands.push_back(inputAsMemref(rewriter, op, operand)); + } + + auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter); + rewriter.replaceOp(op, wrapAsTensor(dstOp->getResult(0), srcOp, rewriter)); + return this->matchSuccess(); + } + + protected: + virtual Operation *rewriteInternal( + SrcOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?"); + } +}; + +struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto finalType = convertTypeToMemRef(*op); + + return rewriter.create<IREESeq::HL::ConcatOp>( + op->getLoc(), finalType, operands, + rewriter.getI32IntegerAttr(op->dimension().getZExtValue())); + } +}; + +struct DynamicUpdateSliceLowering + : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + auto operand = operands[0]; + auto update = operands[1]; + + auto updateType = update->getType().cast<ShapedType>(); + Value *lengthConstant = + createArrayConstant(rewriter, op->getLoc(), updateType.getShape()); + + auto startIndices = makeArrayRef(operands).drop_front(2); + const int rank = startIndices.size(); + llvm::SmallVector<Value *, 4> valuesToConcat; + valuesToConcat.reserve(startIndices.size()); + auto type = getElementTypeOrSelf(startIndices.front()); + + // To generate the offset matrix we need to convert the variadic tensors + // into a reshaped and concated value. + for (auto index : startIndices) { + auto reshapedIndex = rewriter.create<IREESeq::HL::ReshapeOp>( + op->getLoc(), MemRefType::get({1}, type), index, + createArrayConstant(rewriter, op->getLoc(), {1})); + valuesToConcat.push_back(reshapedIndex); + } + + auto dstOffset = rewriter + .create<IREESeq::HL::ConcatOp>( + op->getLoc(), MemRefType::get({rank}, type), + valuesToConcat, rewriter.getI32IntegerAttr(0)) + .getResult(); + + llvm::SmallVector<int64_t, 4> zero_offset; + zero_offset.resize(updateType.getRank(), 0); + auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset); + + auto copiedOperand = rewriter.create<IREESeq::HL::CloneOp>( + op->getLoc(), operand->getType(), operand); + + rewriter + .create<IREESeq::HL::CopyOp>(op->getLoc(), update, srcOffset, + copiedOperand, dstOffset, lengthConstant) + .getOperation(); + + return copiedOperand; + } +}; + +struct SliceLowering : public XlaOpLowering<xla_hlo::SliceOp> { + using XlaOpLowering::XlaOpLowering; + Operation *rewriteInternal( + xla_hlo::SliceOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + // XLA slice has value semantics, whereas the IREE slice creates a view. We + // lower it to a copy if all strides are one which may be transformed to a + // slice by later optimizations. + auto isNotOne = [](APInt stride) { return stride != 1; }; + if (llvm::any_of(op->strides(), isNotOne)) { + op->emitRemark() << "Could not lower slice op with non-singular strides"; + return nullptr; + } + + auto finalType = convertTypeToMemRef(*op); + + auto src = operands[0]; + std::vector<Value *> dim_pieces; + auto dst = rewriter.create<IREESeq::HL::AllocHeapOp>(op->getLoc(), + finalType, dim_pieces); + auto srcIndices = + rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices()); + auto lengths = + createArrayConstant(rewriter, op->getLoc(), finalType.getShape()); + + llvm::SmallVector<int64_t, 4> zero_offset; + zero_offset.resize(finalType.getRank(), 0); + auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset); + + rewriter.create<IREESeq::HL::CopyOp>(op->getLoc(), src, srcIndices, dst, + dstIndices, lengths); + return dst; + } +}; + +struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> { + using XlaOpLowering::XlaOpLowering; + + Operation *rewriteInternal( + xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands, + ConversionPatternRewriter &rewriter) const override { + return createShapeTargetingOp<IREESeq::HL::ReshapeOp>( + rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op)); + } +}; + +} // namespace + +void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert<ConcatOpLowering, CopyOpLowering, DynamicUpdateSliceLowering, + ReshapeOpLowering, SliceLowering>(ctx); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp b/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp new file mode 100644 index 0000000..b914d46 --- /dev/null +++ b/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp
@@ -0,0 +1,242 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <utility> + +#include "compiler/IR/Ops.h" +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/IR/Types.h" +#include "compiler/Utils/DispatchUtils.h" +#include "compiler/Utils/MemRefUtils.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Inserts a load from a wrapped memref (as inserted via insertDispatcherStore). +// Returns the value in the original type. +Value *insertDispatcheeLoad(Operation *op, Type originalType, Value *value, + OpBuilder &builder) { + // If old value was a memref we don't need to change anything. + if (originalType.isa<MemRefType>()) { + return value; + } + + auto loadInputOp = + builder.create<IREE::LoadInputOp>(op->getLoc(), originalType, value); + value->replaceAllUsesWith(loadInputOp.getResult()); + loadInputOp.setOperand(value); + return loadInputOp.getResult(); +} + +// Marshals args and results as buffers for the given region. +// Beyond inserting the appropriate tensor-to-memref ops we avoid mutating the +// interior of the dispatch region as much as possible. +LogicalResult marshalDispatchSite(IREE::DispatchRegionOp regionOp) { + auto &entryBlock = regionOp.getBody().getBlocks().front(); + OpBuilder dispatcherBuilder(regionOp); + OpBuilder dispatcheeBuilder(&entryBlock, entryBlock.begin()); + + // Wrap input operands and unwrap in the entry block. + SmallVector<Value *, 8> newArgs; + for (int i = 0; i < regionOp.getNumArgOperands(); ++i) { + // Wrap the input outside of the region. + auto *blockArg = entryBlock.getArgument(i); + Type originalType = blockArg->getType(); + auto *originalArg = regionOp.getArgOperand(i); + auto *wrappedArg = + insertDispatcherStore(regionOp, originalArg, dispatcherBuilder); + newArgs.push_back(wrappedArg); + blockArg->setType(wrappedArg->getType()); + + // Unwrap the block arg value and replace all of the uses with the newly + // unwrapped value. + insertDispatcheeLoad(regionOp, originalType, blockArg, dispatcheeBuilder); + } + + // Allocate output arguments and replace the return values with those. + SmallVector<Type, 8> newResults; + SmallVector<std::pair<int, Value *>, 8> resultIndicesToOutputArgs; + SmallVector<int, 8> deadResultIndices; + SmallVector<std::pair<Value *, Value *>, 8> replacedResults; + for (int i = 0; i < regionOp.getNumResults(); ++i) { + auto *result = regionOp.getResult(i); + auto convertedType = convertTypeToMemRef(result->getType()); + + // Allocate output buffer in the dispatcher to pass in to the region. + Value *allocatedValue = allocateDispatchOutputBuffer( + regionOp.getLoc(), convertedType, dispatcherBuilder); + if (!allocatedValue) { + regionOp.emitError("unable to allocate result value"); + return failure(); + } + newArgs.push_back(allocatedValue); + + auto *newBlockArg = entryBlock.addArgument(allocatedValue->getType()); + resultIndicesToOutputArgs.push_back({i, newBlockArg}); + + // NOTE: right now we always replace results. If we want to allow return + // values we can avoid killing them here. + deadResultIndices.push_back(i); + replacedResults.push_back({result, allocatedValue}); + } + + // Remove dead results from return statements. + regionOp.walk([&](IREE::ReturnOp returnOp) { + // Replace the results we were returning with stores to output arguments. + OpBuilder builder(returnOp); + for (auto resultToArg : resultIndicesToOutputArgs) { + auto *value = returnOp.getOperand(resultToArg.first); + auto *outputArg = resultToArg.second; + builder.create<IREE::StoreOutputOp>(returnOp.getLoc(), value, outputArg); + } + + // Filter out the results that are now dead. + SmallVector<Value *, 8> newOperands(returnOp.getOperands()); + for (int i = deadResultIndices.size() - 1; i >= 0; --i) { + newOperands.erase(newOperands.begin() + deadResultIndices[i]); + } + returnOp.getOperation()->setOperands(newOperands); + }); + + // Clone the region op with the new args/results. + auto newRegionOp = dispatcherBuilder.create<IREE::DispatchRegionOp>( + regionOp.getLoc(), newResults, regionOp.getWorkload(), newArgs); + newRegionOp.getBody().takeBody(regionOp.getBody()); + + // Marshal back the results by replacing uses of the original with loads from + // the new output arg. + for (auto &it : replacedResults) { + insertDispatcherLoad(regionOp, it.first, it.second, dispatcherBuilder); + } + + // Remove original region. + regionOp.erase(); + + return success(); +} + +// Converts a dispatch_region into a dispatch to the outlined region function. +LogicalResult convertToDispatchOp(IREE::DispatchRegionOp regionOp, + IREE::MultiArchExecutableOp executable, + FuncOp entryPoint) { + // Insert at the same place as the original region. + OpBuilder dispatcherBuilder(regionOp); + + // Ensure workload is a memref. + auto *workload = + wrapAsMemRef(regionOp.getWorkload(), regionOp, dispatcherBuilder); + + // Create the dispatch op to the executable function. + SmallVector<Value *, 8> operandValues(regionOp.getArgOperands()); + auto dispatchOp = dispatcherBuilder.create<IREESeq::HL::DispatchOp>( + regionOp.getLoc(), executable.getName(), entryPoint.getName(), workload, + entryPoint.getType().getResults(), operandValues); + + // Replace uses of the existing results with the new results. + for (int i = 0; i < regionOp.getNumResults(); ++i) { + regionOp.getResult(i)->replaceAllUsesWith(dispatchOp.getResult(i)); + } + + // Erase original region. + regionOp.erase(); + + return success(); +} + +// Outlines a dispatch region into an iree.multi_arch_executable. +LogicalResult outlineDispatchRegion(IREE::DispatchRegionOp regionOp, + int outlinedRegionOrdinal) { + // Build function type matching 1:1 with the region signature. + SmallVector<Type, 8> operandTypes; + for (auto *arg : regionOp.getArgOperands()) { + operandTypes.push_back(arg->getType()); + } + SmallVector<Type, 8> resultTypes(regionOp.getResultTypes()); + auto functionType = + FunctionType::get(operandTypes, resultTypes, regionOp.getContext()); + + // Create the executable with the region cloned into it. + IREE::MultiArchExecutableOp multiArchExecutable; + FuncOp outlinedFunc; + std::tie(multiArchExecutable, outlinedFunc) = createRegionExecutable( + regionOp, functionType, + "_dispatch_" + std::to_string(outlinedRegionOrdinal)); + outlinedFunc.setAttr("iree.executable.export", + UnitAttr::get(regionOp.getContext())); + + // Finally convert the dispatch region into a dispatch to the outlined func. + return convertToDispatchOp(regionOp, multiArchExecutable, outlinedFunc); +} + +} // namespace + +class OutlineDispatchRegionsPass + : public ModulePass<OutlineDispatchRegionsPass> { + public: + void runOnModule() override { + auto module = getModule(); + + ModuleManager moduleManager(module); + auto funcs = module.getOps<FuncOp>(); + SmallVector<FuncOp, 4> funcOps(funcs.begin(), funcs.end()); + for (auto func : funcOps) { + // Perform marshaling of the dispatcher and dispatchee I/O. + // This inserts the required stores and loads to make everything memrefs + // and adds the iree.load_input/iree.store_output ops to the dispatchee. + if (func.walk([&](IREE::DispatchRegionOp op) { + if (failed(marshalDispatchSite(op))) { + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }) + .wasInterrupted()) { + return signalPassFailure(); + } + + // Outline all of the iree.dispatch_region ops in this function. + SmallVector<IREE::DispatchRegionOp, 8> dispatchRegionOps; + func.walk( + [&](IREE::DispatchRegionOp op) { dispatchRegionOps.push_back(op); }); + for (int i = 0; i < dispatchRegionOps.size(); ++i) { + if (failed(outlineDispatchRegion(dispatchRegionOps[i], i))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createOutlineDispatchRegionsPass() { + return std::make_unique<OutlineDispatchRegionsPass>(); +} + +static PassRegistration<OutlineDispatchRegionsPass> pass( + "iree-outline-dispatch-regions", + "Outlines dispatch regions into standalone functions"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp b/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp new file mode 100644 index 0000000..407f014 --- /dev/null +++ b/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp
@@ -0,0 +1,307 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <utility> + +#include "compiler/IR/Ops.h" +#include "compiler/IR/Sequencer/HLOps.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/IR/Types.h" +#include "compiler/Utils/DispatchUtils.h" +#include "compiler/Utils/MemRefUtils.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Determines the shapes involved with reducing this dimension. +SmallVector<int64_t, 4> calculateResultShape(Value *input, + int windowDimension) { + SmallVector<int64_t, 4> resultShape; + for (auto it : + llvm::enumerate(input->getType().cast<ShapedType>().getShape())) { + if (it.index() != windowDimension) { + resultShape.push_back(it.value()); + } + } + return resultShape; +} + +// Creates an executable that holds the given elemental reduction region. +// The executable will have an entry point taking the specified reduction values +// and writing the results to output arguments. +std::pair<IREE::MultiArchExecutableOp, FuncOp> createReductionExecutable( + IREE::ReductionRegionOp regionOp, int outlinedRegionOrdinal, + int separatedReductionIndex, int reductionDimension, + SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs) { + Builder builder(regionOp.getContext()); + + // Build function type matching 1:1 with the region signature. + SmallVector<Type, 8> elementalOperandTypes; + SmallVector<Type, 8> elementalResultTypes; + for (auto *arg : regionOp.getInitialValueOperands()) { + // (in0, in1) -> out0 + elementalOperandTypes.push_back(arg->getType()); + elementalOperandTypes.push_back(arg->getType()); + elementalResultTypes.push_back(arg->getType()); + } + auto elementalFunctionType = FunctionType::get( + elementalOperandTypes, elementalResultTypes, regionOp.getContext()); + + // Create the executable with the region cloned into it. + IREE::MultiArchExecutableOp multiArchExecutable; + FuncOp elementalFunc; + std::tie(multiArchExecutable, elementalFunc) = createRegionExecutable( + regionOp, elementalFunctionType, + "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" + + std::to_string(separatedReductionIndex)); + + // Create a new entry point that we can use with the signature for this + // dimension. + SmallVector<Type, 8> allOperandTypes; + auto inputTypes = + llvm::map_range(inputs, [](Value *value) { return value->getType(); }); + allOperandTypes.append(inputTypes.begin(), inputTypes.end()); + auto initialValueTypes = llvm::map_range( + initialValues, [](Value *value) { return value->getType(); }); + allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end()); + for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) { + auto shapedType = resultType.value().cast<ShapedType>(); + allOperandTypes.push_back(MemRefType::get( + calculateResultShape(inputs[resultType.index()], reductionDimension), + shapedType.getElementType())); + } + auto entryFuncType = FunctionType::get(allOperandTypes, ArrayRef<Type>{}, + regionOp.getContext()); + auto entryFunc = + FuncOp::create(regionOp.getLoc(), + (elementalFunc.getName() + "_entry").str(), entryFuncType); + entryFunc.setAttr("iree.executable.export", + UnitAttr::get(regionOp.getContext())); + elementalFunc.getOperation()->getBlock()->push_back(entryFunc); + entryFunc.getOperation()->moveBefore(elementalFunc); + entryFunc.setAttr("iree.executable.reduction", + UnitAttr::get(regionOp.getContext())); + entryFunc.setAttr("iree.executable.reduction.apply", + builder.getSymbolRefAttr(elementalFunc)); + + return {multiArchExecutable, entryFunc}; +} + +// Converts a reduction_region into a dispatch to the outlined region function +// for a single reduction dimension. +// Returns the results of the reduction or empty if the construction fails. +SmallVector<Value *, 4> convertToDispatchOp( + IREE::ReductionRegionOp regionOp, IREE::MultiArchExecutableOp executable, + FuncOp entryFunc, int reductionDimension, + SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs, + OpBuilder &dispatcherBuilder) { + // Allocate output args and replace the return values with those. + SmallVector<Value *, 4> resultValues; + for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) { + // Allocate output buffer in the dispatcher to pass in to the region. + auto shapedType = resultType.value().cast<ShapedType>(); + Value *allocatedValue = allocateDispatchOutputBuffer( + regionOp.getLoc(), + MemRefType::get(calculateResultShape(inputs[resultType.index()], + reductionDimension), + shapedType.getElementType()), + dispatcherBuilder); + if (!allocatedValue) { + regionOp.emitError("unable to allocate result value"); + return {}; + } + resultValues.push_back(allocatedValue); + } + + // Calculate workload from the result shape. + auto *workload = + wrapAsMemRef(calculateWorkload(regionOp, resultValues.front()), regionOp, + dispatcherBuilder); + + // Create the reduce op to the executable function. + std::vector<Value *> allOperands; + allOperands.insert(allOperands.end(), inputs.begin(), inputs.end()); + allOperands.insert(allOperands.end(), initialValues.begin(), + initialValues.end()); + allOperands.insert(allOperands.end(), resultValues.begin(), + resultValues.end()); + dispatcherBuilder.create<IREESeq::HL::DispatchOp>( + regionOp.getLoc(), executable.getName(), entryFunc.getName(), workload, + ArrayRef<Type>{}, allOperands); + + return resultValues; +} + +// Outlines a reduction region into one or more iree.multi_arch_executables. +// This separates the reduction into multiple dispatches, one for each reduction +// dimension (thankfully XLA's operation semantics state this is ok). We then +// special case the first dispatch such that it takes the constant initial +// values so that we don't have to materialize a buffer for them. +LogicalResult outlineReductionRegion(IREE::ReductionRegionOp regionOp, + int outlinedRegionOrdinal) { + // Insert at the same place as the original region. + OpBuilder dispatcherBuilder(regionOp); + + // Wrap input operands in memrefs. + SmallVector<Value *, 4> initialValues{llvm::map_range( + regionOp.getInitialValueOperands(), [&](Value *originalArg) { + return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder); + })}; + SmallVector<Value *, 4> temps{ + llvm::map_range(regionOp.getReductionOperands(), [&](Value *originalArg) { + return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder); + })}; + + // Create one dispatch per dimension being reduced. + // We'll do this by chaining the original input through with the temporary + // reduction results. The results we end up with will be the originally + // requested shape and we can just substitute them. + if (regionOp.isWindowed()) { + auto windowDimensions = regionOp.window_dimensions().getValue(); + auto windowStrides = regionOp.window_strides().getValue(); + auto baseDilations = regionOp.base_dilations().getValue(); + auto windowDilations = regionOp.window_dilations().getValue(); + SmallVector<std::tuple<int64_t, int64_t, int64_t, int64_t>, 4> + sortedWindowAttrs; + for (uint64_t i = 0; i < windowDimensions.getNumElements(); ++i) { + int64_t windowDimension = + windowDimensions.getValue<IntegerAttr>({i}).getInt(); + int64_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt(); + int64_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt(); + int64_t windowDilation = + windowDilations.getValue<IntegerAttr>({i}).getInt(); + sortedWindowAttrs.push_back( + {windowDimension, windowStride, baseDilation, windowDilation}); + } + llvm::sort(sortedWindowAttrs, + [](std::tuple<int64_t, int64_t, int64_t, int64_t> a, + std::tuple<int64_t, int64_t, int64_t, int64_t> b) { + return std::get<0>(a) - std::get<0>(b); + }); + for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) { + int64_t windowDimension = std::get<0>(windowAttrs.value()); + int64_t windowStride = std::get<1>(windowAttrs.value()); + int64_t baseDilation = std::get<2>(windowAttrs.value()); + int64_t windowDilation = std::get<3>(windowAttrs.value()); + IREE::MultiArchExecutableOp multiArchExecutable; + FuncOp entryFunc; + std::tie(multiArchExecutable, entryFunc) = createReductionExecutable( + regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension, + initialValues, temps); + entryFunc.setAttr("iree.executable.reduction.padding_mode", + dispatcherBuilder.getI32IntegerAttr( + regionOp.padding_mode().getValue())); + entryFunc.setAttr("iree.executable.reduction.window_dimension", + dispatcherBuilder.getI32IntegerAttr(windowDimension)); + entryFunc.setAttr("iree.executable.reduction.window_stride", + dispatcherBuilder.getI32IntegerAttr(windowStride)); + entryFunc.setAttr("iree.executable.reduction.base_dilation", + dispatcherBuilder.getI32IntegerAttr(baseDilation)); + entryFunc.setAttr("iree.executable.reduction.window_dilation", + dispatcherBuilder.getI32IntegerAttr(windowDilation)); + temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc, + windowDimension, initialValues, + std::move(temps), dispatcherBuilder); + if (temps.empty()) { + return regionOp.emitOpError() + << "Failed to construct reduction for windowed dimension " + << windowDimension; + } + } + } else { + auto dimensions = regionOp.dimensions().getValue(); + SmallVector<int64_t, 4> sortedDimensions; + for (uint64_t i = 0; i < dimensions.getNumElements(); ++i) { + sortedDimensions.push_back( + dimensions.getValue<IntegerAttr>({i}).getInt()); + } + llvm::sort(sortedDimensions, [](int64_t a, int64_t b) { return a - b; }); + for (auto dimension : llvm::enumerate(sortedDimensions)) { + IREE::MultiArchExecutableOp multiArchExecutable; + FuncOp entryFunc; + std::tie(multiArchExecutable, entryFunc) = createReductionExecutable( + regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(), + initialValues, temps); + entryFunc.setAttr("iree.executable.reduction.dimension", + dispatcherBuilder.getI32IntegerAttr(dimension.value())); + temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc, + dimension.value(), initialValues, + std::move(temps), dispatcherBuilder); + if (temps.empty()) { + return regionOp.emitOpError() + << "Failed to construct reduction for dimension " + << dimension.value(); + } + } + } + for (auto it : llvm::enumerate(regionOp.getResults())) { + insertDispatcherLoad(regionOp, it.value(), temps[it.index()], + dispatcherBuilder); + } + + // Erase original region. + regionOp.erase(); + + return success(); +} + +} // namespace + +class OutlineReductionRegionsPass + : public ModulePass<OutlineReductionRegionsPass> { + public: + void runOnModule() override { + auto module = getModule(); + + ModuleManager moduleManager(module); + auto funcs = module.getOps<FuncOp>(); + SmallVector<FuncOp, 4> funcOps(funcs.begin(), funcs.end()); + for (auto func : funcOps) { + // Outline all of the iree.reduction_region ops in this function. + std::vector<IREE::ReductionRegionOp> reductionRegionOps; + func.walk([&](IREE::ReductionRegionOp op) { + reductionRegionOps.push_back(op); + }); + for (int i = 0; i < reductionRegionOps.size(); ++i) { + if (failed(outlineReductionRegion(reductionRegionOps[i], i))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass() { + return std::make_unique<OutlineReductionRegionsPass>(); // NOLINT +} + +static PassRegistration<OutlineReductionRegionsPass> pass( + "iree-outline-reduction-regions", + "Outlines reduction regions into standalone functions"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/Passes.h b/compiler/Transforms/Sequencer/Passes.h similarity index 100% rename from iree/compiler/Transforms/Sequencer/Passes.h rename to compiler/Transforms/Sequencer/Passes.h
diff --git a/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp b/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp new file mode 100644 index 0000000..a5479e5 --- /dev/null +++ b/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp
@@ -0,0 +1,149 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/DispatchUtils.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Chosen randomly for now. We can measure and see what makes sense. +constexpr int64_t kMaxRematerializedConstantSizeInBytes = 1 * 1024; + +// Returns true if the constant value is under a certain threshold. +// This threshold is fixed for all backends as a value that is assumed small +// enough to be worth inlining possibly several times (at the cost of binary +// bloat). +bool isConstantSmall(ConstantOp constantOp) { + if (auto shapedType = constantOp.getType().dyn_cast<ShapedType>()) { + return shapedType.getSizeInBits() / 8 <= + kMaxRematerializedConstantSizeInBytes; + } + + // Assume anything unshaped is small. This may not always be true in custom + // dialects but is in std for now. + return true; +} + +// Returns true if the dispatch region is allowed to have constants inside. +// Certain regions that may get replaced or turned into kernel imports shouldn't +// have the constants moved into them as they'll just get lost. +bool canDispatchRegionContainConstants( + IREE::DispatchRegionOp dispatchRegionOp) { + for (auto &block : dispatchRegionOp.getBody()) { + for (auto &op : block) { + if (isa<xla_hlo::DotOp>(&op)) { + return false; + } + } + } + return true; +} + +// Rematerializes a constant inside of all dispatch regions that use it. +// Afterward the constant is only removed if there are no other uses within the +// non-dispatch block (such as by sequencer ops). +LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) { + Value *constantValue = constantOp.getResult(); + SmallVector<IREE::DispatchRegionOp, 4> usingRegionOps; + for (auto *user : constantValue->getUsers()) { + if (auto dispatchRegionOp = dyn_cast<IREE::DispatchRegionOp>(user)) { + // Ensure this isn't just the workload and is used as an arg. + if (std::find(dispatchRegionOp.arg_operand_begin(), + dispatchRegionOp.arg_operand_end(), + constantValue) != dispatchRegionOp.arg_operand_end()) { + if (canDispatchRegionContainConstants(dispatchRegionOp)) { + usingRegionOps.push_back(dispatchRegionOp); + } + } + } + } + for (auto &dispatchRegionOp : usingRegionOps) { + if (failed(inlineDispatchRegionOperandsUsingValue(dispatchRegionOp, + constantValue))) { + return failure(); + } + } + + // Remove if there are no other uses within the block. + if (constantOp.use_empty()) { + constantOp.erase(); + } + + return success(); +} + +} // namespace + +// Finds constant arguments to dispatch regions that are too small to be worth +// putting into constant pools. This prevents things like a CSE'd scalar +// constant of 0.0 being passed by reference to a bunch of regions. Later +// backend-specific passes running on the dispatch regions may also be able to +// improve their constant propagation chances by having the full constant value +// available. +// +// Note that this currently only operates at the block level. Constants that are +// pushed across branches are assumed to have been rematerialized within blocks +// already, but if that isn't the case then this pass can be extended to do +// that. +class RematerializeDispatchConstantsPass + : public FunctionPass<RematerializeDispatchConstantsPass> { + public: + void runOnFunction() override { + for (auto &block : getFunction()) { + SmallVector<ConstantOp, 8> smallConstantOps; + for (auto constantOp : block.getOps<ConstantOp>()) { + if (isConstantSmall(constantOp)) { + smallConstantOps.push_back(constantOp); + } + } + // Note: we iterate in reverse so that the rematerialized constants appear + // in the same order they did originally (as insertion is at the top). + for (auto constantOp : llvm::reverse(smallConstantOps)) { + if (failed(rematerializeConstantInDispatchRegions(constantOp))) { + return signalPassFailure(); + } + } + } + } +}; + +std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass() { + return std::make_unique<RematerializeDispatchConstantsPass>(); +} + +static PassRegistration<RematerializeDispatchConstantsPass> pass( + "iree-rematerialize-dispatch-constants", + "Rematerializes small previously-CSE'd constants into dispatch regions."); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Transforms/Sequencer/Rewrites.h b/compiler/Transforms/Sequencer/Rewrites.h new file mode 100644 index 0000000..c78a131 --- /dev/null +++ b/compiler/Transforms/Sequencer/Rewrites.h
@@ -0,0 +1,40 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_ +#define IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_ + +#include "compiler/Utils/TypeConversionUtils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace iree_compiler { + +// Adds rewrite patterns for lowering IREE Sequencer HL ops (iree_hl_seq.*) +// to LL ops (iree_ll_seq.*). +void populateSequencerLoweringPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx); + +// Adds rewrite patterns for lowering xla_hlo ops to Sequencer HL ops. +void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx); + +// Adds rewrite patterns for lowering standard ops to Sequencer HL ops. +void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx); +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
diff --git a/compiler/Transforms/Sequencer/test/BUILD b/compiler/Transforms/Sequencer/test/BUILD new file mode 100644 index 0000000..8a8ee90 --- /dev/null +++ b/compiler/Transforms/Sequencer/test/BUILD
@@ -0,0 +1,16 @@ +# Tests specific to the sequencer. + +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/compiler/Transforms/test/BUILD b/compiler/Transforms/test/BUILD new file mode 100644 index 0000000..c7b98c5 --- /dev/null +++ b/compiler/Transforms/test/BUILD
@@ -0,0 +1,16 @@ +# Tests for common transforms. + +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/test/drop_unreachable_module_functions.mlir b/compiler/Transforms/test/drop_unreachable_module_functions.mlir similarity index 100% rename from iree/compiler/Transforms/test/drop_unreachable_module_functions.mlir rename to compiler/Transforms/test/drop_unreachable_module_functions.mlir
diff --git a/iree/compiler/Translation/BUILD b/compiler/Translation/BUILD similarity index 100% rename from iree/compiler/Translation/BUILD rename to compiler/Translation/BUILD
diff --git a/iree/compiler/Translation/CMakeLists.txt b/compiler/Translation/CMakeLists.txt similarity index 100% rename from iree/compiler/Translation/CMakeLists.txt rename to compiler/Translation/CMakeLists.txt
diff --git a/compiler/Translation/Interpreter/BUILD b/compiler/Translation/Interpreter/BUILD new file mode 100644 index 0000000..17858bd --- /dev/null +++ b/compiler/Translation/Interpreter/BUILD
@@ -0,0 +1,31 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Interpreter", + srcs = ["InterpreterExecutableTranslation.cpp"], + hdrs = ["InterpreterExecutableTranslation.h"], + deps = [ + "///compiler/IR", + "///compiler/IR/Interpreter", + "///compiler/Serialization", + "///compiler/Transforms", + "///compiler/Transforms/Interpreter", + "///compiler/Utils", + "///schemas", + "@com_github_google_flatbuffers//:flatbuffers", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:Support", + "@local_config_mlir//:Transforms", + "@local_config_mlir//:Translation", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Translation/Interpreter/CMakeLists.txt b/compiler/Translation/Interpreter/CMakeLists.txt similarity index 100% rename from iree/compiler/Translation/Interpreter/CMakeLists.txt rename to compiler/Translation/Interpreter/CMakeLists.txt
diff --git a/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp new file mode 100644 index 0000000..ae0fb52 --- /dev/null +++ b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp
@@ -0,0 +1,289 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Translation/Interpreter/InterpreterExecutableTranslation.h" + +#include <cstdint> +#include <iostream> +#include <vector> + +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/minireflect.h" +#include "compiler/IR/ConfigOps.h" +#include "compiler/IR/Interpreter/OpWriters.h" +#include "compiler/IR/Types.h" +#include "compiler/Serialization/VMFunctionBuilder.h" +#include "compiler/Serialization/VMFunctionTableBuilder.h" +#include "compiler/Serialization/VMModuleBuilder.h" +#include "compiler/Transforms/Interpreter/Passes.h" +#include "compiler/Transforms/Passes.h" +#include "compiler/Utils/Macros.h" +#include "compiler/Utils/OpUtils.h" +#include "compiler/Utils/TranslationUtils.h" +#include "schemas/executable_def_generated.h" +#include "schemas/module_def_generated.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Translation.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Builds a pass pipeline that optimizes and legalizes the module to the form +// expected by translation. +void buildLegalizeInputPassPipeline(PassManager *passManager) { + // Standard passes that shake out a lot of garbage. + // Some may have been run prior to translation but this ensures we are always + // in a known state. + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createLoopFusionPass()); + passManager->addPass(createLoopInvariantCodeMotionPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createSimplifyAffineStructuresPass()); + passManager->addPass(createCSEPass()); + passManager->addPass(createCanonicalizerPass()); + + // Eliminate ops we don't care about based on a lack of side-effects. + // IREE does not guarantee exception/error behavior of dead ops. + passManager->addPass(createAggressiveOpEliminationPass()); + + // Expand uses of tuples into independent args/results. + passManager->addPass(createConvertFromTupleCallingConventionPass()); + passManager->addPass(createCanonicalizerPass()); +} + +// Builds a pass pipeline that converts functions to the iree_hl_interp dialect. +void buildInterpreterConversionPassPipeline(PassManager *passManager) { + // We don't need the IREE binding ops anymore, as we match the calling + // convention exactly (we're the same VM). + passManager->addPass(createMakeExecutableABIPass()); + + // Convert to the memref calling convention and optimize away as many + // loads and stores as we can prior to progressing. + passManager->addPass(createConvertToMemRefCallingConventionPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + + // Convert various dialects to IREE opcodes and cleanup leftover conversions. + passManager->addPass(createLowerToInterpreterDialectPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createAggressiveOpEliminationPass()); + + // Widen reduction functions (that have iree.executable.reduction attrs) to + // use their primitive IREE ops. + passManager->addPass(createExpandReductionsToOpsPass()); + + // Convert any uses of index to int32_t (as we explicitly don't want to + // support dynamic index width). + // This also looks for other weird types (i1, etc). + passManager->addPass(createLegalizeTypeStoragePass()); + + // Perform any last-minute optimizations to trim down the IR. + passManager->addPass(createAggressiveOpEliminationPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createLoopFusionPass()); + passManager->addPass(createLoopInvariantCodeMotionPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createCSEPass()); + passManager->addPass(createCanonicalizerPass()); + + // Drop all functions that are not reachable. + passManager->addPass(createDropUnreachableExecutableFunctionsPass()); +} + +// Builds a pass pipeline that lowers the iree_hl_interp dialect to the +// iree_ll_interp dialect and prepares for serialization. +void buildInterpreterLoweringPassPipeline(PassManager *passManager) { + // Lower iree_hl_interp -> iree_ll_interp. + passManager->addPass(createLowerInterpreterDialectPass()); + + // Assign ordinals used by the bytecode to reference executables and + // functions. + passManager->addPass(createAssignFunctionOrdinalsPass()); +} + +class InterpreterTranslator { + public: + explicit InterpreterTranslator(ExecutableTranslationOptions options) + : options_(options) {} + + const ExecutableTranslationOptions &options() const { return options_; } + + std::unique_ptr<iree::ExecutableDefT> translateExecutable( + IREE::ExecutableOp executableOp); + + private: + LogicalResult translateExecutableModule(IREE::ExecutableOp executableOp, + ModuleOp moduleOp, + VMModuleBuilder *moduleBuilder); + LogicalResult declareFunction(FuncOp function, + VMModuleBuilder *moduleBuilder); + LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder); + + ExecutableTranslationOptions options_; +}; + +std::unique_ptr<iree::ExecutableDefT> +InterpreterTranslator::translateExecutable(IREE::ExecutableOp executableOp) { + auto moduleOp = executableOp.getInnerModule(); + + // Run all passes to go from input to the iree_ll_interp dialect. + auto executableConversionPasses = + createPassManager(moduleOp.getContext(), options()); + buildLegalizeInputPassPipeline(executableConversionPasses.get()); + buildInterpreterConversionPassPipeline(executableConversionPasses.get()); + buildInterpreterLoweringPassPipeline(executableConversionPasses.get()); + if (failed(runPassPipeline(options(), executableConversionPasses.get(), + moduleOp))) { + executableOp.emitError() << "Failed to run conversion passes"; + return {}; + } + + // Build the module bytecode. + ::flatbuffers::FlatBufferBuilder fbb; + VMModuleBuilder moduleBuilder(&fbb); + if (failed( + translateExecutableModule(executableOp, moduleOp, &moduleBuilder))) { + executableOp.emitError() << "Failed to translate executable module"; + return {}; + } + auto moduleDef = moduleBuilder.Finish(); + if (moduleDef.IsNull()) { + moduleOp.emitError() << "Failed to verify completed module def"; + return {}; + } + auto bytes = moduleBuilder.Serialize(moduleDef); + if (bytes.empty()) { + moduleOp.emitError() << "Failed to serialize final module def"; + return {}; + } + + OpBuilder builder(executableOp); + executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>( + IREE::ExecutableFormat::IreeBytecode))); + + auto executableDef = std::make_unique<iree::ExecutableDefT>(); + executableDef->format = + static_cast<uint32_t>(IREE::ExecutableFormat::IreeBytecode); + executableDef->supported_features = iree::ExecutableFeature::kDebugging; + executableDef->contents = std::move(bytes); + return executableDef; +} + +LogicalResult InterpreterTranslator::translateExecutableModule( + IREE::ExecutableOp executableOp, ModuleOp moduleOp, + VMModuleBuilder *moduleBuilder) { + // Declare functions first so that we get stable indices during declaration + // (as call ops need to use the function table). + for (auto function : moduleOp.getOps<FuncOp>()) { + RETURN_IF_FAILURE(declareFunction(function, moduleBuilder)); + } + + // Define functions now that all functions have been declared. + for (auto function : moduleOp.getOps<FuncOp>()) { + RETURN_IF_FAILURE(defineFunction(function, moduleBuilder)); + } + + return success(); +} + +LogicalResult InterpreterTranslator::declareFunction( + FuncOp function, VMModuleBuilder *moduleBuilder) { + auto *functionTable = moduleBuilder->function_table(); + if (functionTable->IsFunctionDeclared(function)) { + // Already declared. + return success(); + } + + LinkageType linkageType; + if (function.isExternal()) { + linkageType = LinkageType::kImport; + } else if (function.getAttr("iree.executable.export")) { + linkageType = LinkageType::kExport; + } else { + linkageType = LinkageType::kInternal; + } + if (failed(functionTable->DeclareFunction(function, linkageType))) { + return function.emitError() << "Unable to declare function"; + } + + // Import functions must have their definition defined here so we get their + // type. Internal and export functions will be defined during conversion. + if (linkageType == LinkageType::kImport) { + VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), + moduleBuilder->fbb()); + auto functionOffset = functionBuilder.Finish(); + if (functionOffset.IsNull()) { + return function.emitError() + << "Failed to create import function bytecode"; + } + RETURN_IF_FAILURE( + functionTable->DefineFunction(function, functionOffset, {})); + } + + return success(); +} + +LogicalResult InterpreterTranslator::defineFunction( + FuncOp function, VMModuleBuilder *moduleBuilder) { + VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), + moduleBuilder->fbb()); + registerInterpreterCustomWriters(&functionBuilder); + RETURN_IF_FAILURE(functionBuilder.ConvertBytecode()); + auto functionOffset = functionBuilder.Finish(); + if (functionOffset.IsNull()) { + return function.emitError() << "Failed to serialize function"; + } + RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction( + function, functionOffset, functionBuilder.source_map())); + return success(); +} + +} // namespace + +llvm::Optional<ExecutableTranslationResult> +translateExecutableToInterpreterExecutable( + ArrayRef<IREE::ExecutableOp> executableOps, + ExecutableTranslationOptions options) { + InterpreterTranslator translator(options); + ExecutableTranslationResult translationResult; + for (auto executableOp : llvm::make_early_inc_range(executableOps)) { + auto executableDef = translator.translateExecutable(executableOp); + if (!executableDef) { + executableOp.emitError() << "Failed to translate one or more executables"; + return llvm::None; + } + translationResult.executable_defs.push_back(std::move(executableDef)); + } + return translationResult; +} + +static ExecutableTranslationRegistration + InterpreterExecutableTranslationRegistration( + "interpreter-bytecode", translateExecutableToInterpreterExecutable); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h new file mode 100644 index 0000000..0e90e61 --- /dev/null +++ b/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h
@@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_ +#define IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_ + +#include <vector> + +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/TranslationUtils.h" +#include "mlir/IR/Module.h" + +namespace mlir { +namespace iree_compiler { + +// Translates an MLIR module into a bytecode interpreter executable. +// These executables are stored as IREE modules as defined in the +// https://github.com/google/iree/tree/master/iree/schemas/module_def.fbs +// schema. +llvm::Optional<ExecutableTranslationResult> +translateExecutableToInterpreterExecutable( + ArrayRef<IREE::ExecutableOp> executableOps, + ExecutableTranslationOptions options = {}); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
diff --git a/compiler/Translation/SPIRV/AffineExprCodegen.h b/compiler/Translation/SPIRV/AffineExprCodegen.h new file mode 100644 index 0000000..ff12b37 --- /dev/null +++ b/compiler/Translation/SPIRV/AffineExprCodegen.h
@@ -0,0 +1,143 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- AffineExprCodegen.h -------------------------------------*- C++//-*-===// +// +// Code-generation for Affine Expression. +// +//===----------------------------------------------------------------------===// +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H +#define IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H + +#include "compiler/Translation/SPIRV/XLAIndexPropagation.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/AffineExprVisitor.h" + +namespace mlir { +namespace iree_compiler { + +/// Codegenerator for affine expressions. +class AffineExprCodegen : public AffineExprVisitor<AffineExprCodegen, Value *> { + public: + explicit AffineExprCodegen(spirv::ModuleOp module, + IndexComputationCache &tensorIndices) + : builder(module.getContext()), + location(module.getLoc()), + tensorIndices(tensorIndices) {} + + Value *visitAddExpr(AffineBinaryOpExpr expr) { + auto operand1 = getValueInternal(expr.getLHS()); + auto operand2 = getValueInternal(expr.getRHS()); + return builder.create<spirv::IAddOp>(location, operand1, operand2); + } + Value *visitMulExpr(AffineBinaryOpExpr expr) { + auto operand1 = getValueInternal(expr.getLHS()); + auto operand2 = getValueInternal(expr.getRHS()); + return builder.create<spirv::IMulOp>(location, operand1, operand2); + } + Value *visitModExpr(AffineBinaryOpExpr expr) { + auto operand1 = getValueInternal(expr.getLHS()); + auto operand2 = getValueInternal(expr.getRHS()); + return builder.create<spirv::SModOp>(location, operand1, operand2); + } + Value *visitFloorDivExpr(AffineBinaryOpExpr expr) { + auto operand1 = getValueInternal(expr.getLHS()); + auto operand2 = getValueInternal(expr.getRHS()); + return builder.create<spirv::SDivOp>(location, operand1, operand2); + } + Value *visitCeilDivExpr(AffineBinaryOpExpr expr) { + // TODO(ravishankarm): Implement ceil div expr codegen. + llvm_unreachable("Unimplemented affine AffineCeilDivExpr codegen"); + return nullptr; + } + Value *visitConstantExpr(AffineConstantExpr expr) { + return builder.create<spirv::ConstantOp>( + location, builder.getIntegerType(32), + builder.getI32IntegerAttr(expr.getValue())); + } + Value *visitDimExpr(AffineDimExpr expr) { + return threadDimToDstValue.lookup(expr.getPosition()); + } + Value *visitSymbolExpr(AffineSymbolExpr expr) { + // TODO(ravishankarm): Implement symbol expr codegen. + llvm_unreachable("Unimplemented affine AffineSymbolExpr codegen"); + return nullptr; + } + + /// Set the value that contains the workitem ID along a particular + /// dimension. 0 -> x-dimension, 1 -> y-dimension, etc. + void setDimDstValue(unsigned dimID, Value *value) { + threadDimToDstValue[dimID] = value; + } + + /// Generates the scalar value for a affine expression. + Value *getValue(AffineExpr expr, OpBuilder::InsertPoint ip, Location loc) { + auto &val = exprToDstValue[expr]; + if (!val) { + location = loc; + builder.restoreInsertionPoint(ip); + val = visit(expr); + } + return val; + } + + /// Returns a list of indices of a particular tensor in the source dialect + /// needed within the dispatch function (obtained from the + /// IndexComputationCache) + SmallVector<AffineMap, 4> getIndices(Value *value) { + SmallVector<AffineMap, 4> indices; + for (auto &index : tensorIndices[value]) { + indices.push_back(index.first); + } + return indices; + } + + /// For a given tensor in the source dialect and index, return the index of + /// all operands needed to compute the result. + ArrayRef<AffineMap> getOperandIndices(Value *value, AffineMap index) { + return tensorIndices[value][index]; + } + + private: + /// Returns the Value corresponding to the AffineExpr `expr` by either + /// previously generated value for the same index, or by generating the value. + /// This version assumes the insertion point/Location has already been set. + Value *getValueInternal(AffineExpr expr) { + auto &val = exprToDstValue[expr]; + if (!val) { + val = visit(expr); + } + return val; + } + + OpBuilder builder; + + Location location; + + /// Map from launch dimension to scalar value. + DenseMap<unsigned, Value *> threadDimToDstValue; + + /// Cache of affine expression to scalar value. TODO(ravishankarm) : Might + /// need to be changed if we are handling control flow within the dispatch + /// function. + DenseMap<AffineExpr, Value *> exprToDstValue; + + /// Map from tensor value in source dialect to list of indices of the tensor + /// needed within a workitem to compute the results of the dispatch function. + IndexComputationCache &tensorIndices; +}; +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
diff --git a/compiler/Translation/SPIRV/BUILD b/compiler/Translation/SPIRV/BUILD new file mode 100644 index 0000000..287af86 --- /dev/null +++ b/compiler/Translation/SPIRV/BUILD
@@ -0,0 +1,50 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "SPIRV", + srcs = [ + "AffineExprCodegen.h", + "EmbeddedKernels.cpp", + "IREEIndexComputation.cpp", + "IREEToSPIRV.cpp", + "IREEToSPIRVPass.cpp", + "IndexComputation.cpp", + "SPIRVExecutableTranslation.cpp", + "SPIRVLowering.cpp", + "SPIRVLowering.h", + "XLAIndexPropagation.cpp", + ], + hdrs = [ + "EmbeddedKernels.h", + "IREEIndexComputation.h", + "IREEToSPIRV.h", + "IREEToSPIRVPass.h", + "IndexComputation.h", + "SPIRVExecutableTranslation.h", + "XLAIndexPropagation.h", + ], + deps = [ + "///compiler/IR", + "///compiler/Translation/SPIRV/Kernels", + "///compiler/Utils", + "///schemas", + "///schemas:spirv_executable_def_cc_fbs", + "@com_github_google_flatbuffers//:flatbuffers", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:SPIRVDialect", + "@local_config_mlir//:SPIRVDialectRegistration", + "@local_config_mlir//:SPIRVSerialization", + "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:Transforms", + "@local_config_mlir//:Translation", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Translation/SPIRV/CMakeLists.txt b/compiler/Translation/SPIRV/CMakeLists.txt similarity index 100% rename from iree/compiler/Translation/SPIRV/CMakeLists.txt rename to compiler/Translation/SPIRV/CMakeLists.txt
diff --git a/compiler/Translation/SPIRV/EmbeddedKernels.cpp b/compiler/Translation/SPIRV/EmbeddedKernels.cpp new file mode 100644 index 0000000..1fd5b1c --- /dev/null +++ b/compiler/Translation/SPIRV/EmbeddedKernels.cpp
@@ -0,0 +1,219 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Translation/SPIRV/EmbeddedKernels.h" + +#include "compiler/Translation/SPIRV/Kernels/Kernels.h" +#include "schemas/spirv_executable_def_generated.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Module.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Reads the SPIR-V code for the embedded kernel with the given file name. +// If the kernel under Kernels/ is 'matmul.comp' then |kernelName| would be +// 'matmul.spv' (because it's been compiled). +std::vector<uint32_t> readEmbeddedKernelCode(std::string kernelName) { + auto *fileToc = spirv_kernels::Kernels_create(); + for (int i = 0; i < spirv_kernels::Kernels_size(); ++i) { + if (std::strcmp(fileToc[i].name, kernelName.c_str()) == 0) { + std::vector<uint32_t> code; + code.resize(fileToc[i].size / 4); + std::memcpy(code.data(), fileToc[i].data, fileToc[i].size); + return code; + } + } + return {}; +} + +// Adds a storage buffer binding to the descriptor set layout. +void addDescriptorSetLayoutBinding(uint32_t binding, + iree::VkDescriptorSetLayoutDefT *dsl) { + auto bindingDef = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>(); + bindingDef->binding = binding; + bindingDef->descriptor_count = 1; + bindingDef->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER + bindingDef->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT + dsl->bindings.push_back(std::move(bindingDef)); +} + +// Adds a specialization map entry for |constant_id| set to a 4-byte int value. +void addSpecializationMapEntry( + uint32_t constant_id, uint32_t value, + iree::VkSpecializationInfoDefT *specializationInfoDef) { + auto specValue = std::make_unique<iree::VkSpecializationMapEntryDefT>(); + specValue->constant_id = constant_id; + specValue->uint32_value = value; + specializationInfoDef->map_entries.push_back(std::move(specValue)); +} + +LogicalResult buildReductionExecutable(IREE::ExecutableOp executableOp, + FuncOp entryFuncOp, + iree::SpirVExecutableDefT *out_def) { + auto funcType = entryFuncOp.getType(); + auto arg0 = funcType.getInput(0).cast<ShapedType>(); + if (!arg0.getElementType().isF32()) { + // When we do other types we'll need other shaders. + return entryFuncOp.emitOpError() + << "Only floating point reduction is implemented"; + } + + auto module = executableOp.getInnerModule(); + auto applyFuncAttr = entryFuncOp.getAttrOfType<SymbolRefAttr>( + "iree.executable.reduction.apply"); + auto applyFuncOp = module.lookupSymbol(applyFuncAttr.getValue()); + + // TODO(benvanik): specialize (template on shapes/types/etc). + std::string kernelName = "reduce_untiled.spv"; + llvm::Optional<uint32_t> operationId; + applyFuncOp->walk([&](Operation *op) { + if (isa<xla_hlo::AddOp>(op)) { + operationId = 0; + } else if (isa<xla_hlo::MaxOp>(op)) { + operationId = 1; + } else if (isa<xla_hlo::MinOp>(op)) { + operationId = 2; + } + }); + if (!operationId.hasValue()) { + applyFuncOp->dump(); + return applyFuncOp->emitOpError() << "Unsupported reduction operator"; + } + + out_def->tag = "__reduce__"; + out_def->entry_points = {"main"}; + + out_def->code = readEmbeddedKernelCode(kernelName); + + // arg0, arg1, ret0 + auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>(); + pipelineLayoutDef->buffer_binding_set = 0; + auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>(); + addDescriptorSetLayoutBinding(0, dsl.get()); + addDescriptorSetLayoutBinding(1, dsl.get()); + addDescriptorSetLayoutBinding(2, dsl.get()); + pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl)); + out_def->pipeline_layout = std::move(pipelineLayoutDef); + + // See the shader source for documentation on the values of A/B/C/R. + int64_t reductionDimension = + entryFuncOp + .getAttrOfType<IntegerAttr>("iree.executable.reduction.dimension") + .getInt(); + uint32_t r = arg0.getDimSize(reductionDimension); + uint32_t a = 1; + for (int i = 0; i < reductionDimension; ++i) { + a *= arg0.getDimSize(i); + } + uint32_t b = 1; + for (int i = reductionDimension + 1; i < arg0.getRank(); ++i) { + b *= arg0.getDimSize(i); + } + uint32_t c = b; + + auto specializationInfoDef = + std::make_unique<iree::VkSpecializationInfoDefT>(); + addSpecializationMapEntry(/*kOperationId*/ 100, operationId.getValue(), + specializationInfoDef.get()); + addSpecializationMapEntry(/*kA*/ 101, a, specializationInfoDef.get()); + addSpecializationMapEntry(/*kB*/ 102, b, specializationInfoDef.get()); + addSpecializationMapEntry(/*kC*/ 103, c, specializationInfoDef.get()); + addSpecializationMapEntry(/*kR*/ 104, r, specializationInfoDef.get()); + out_def->specialization_info = std::move(specializationInfoDef); + + return success(); +} + +// Builds a SPIR-V executable from a well-known matmul executable. +// |out_def| will be populated with all required information for serialization. +LogicalResult buildMatMulExecutable(IREE::ExecutableOp executableOp, + FuncOp entryFuncOp, xla_hlo::DotOp dotOp, + iree::SpirVExecutableDefT *out_def) { + auto arg0 = dotOp.getOperand(0)->getType().cast<ShapedType>(); + auto arg1 = dotOp.getOperand(1)->getType().cast<ShapedType>(); + + out_def->tag = "__matmul__"; + out_def->entry_points = {"main"}; + + // TODO(benvanik): specialize (template on shapes/types/etc). + out_def->code = readEmbeddedKernelCode("matmul.spv"); + + // arg0, arg1, ret0 + auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>(); + pipelineLayoutDef->buffer_binding_set = 0; + auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>(); + addDescriptorSetLayoutBinding(0, dsl.get()); + addDescriptorSetLayoutBinding(1, dsl.get()); + addDescriptorSetLayoutBinding(2, dsl.get()); + pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl)); + out_def->pipeline_layout = std::move(pipelineLayoutDef); + + // Shapes of [arg0, arg1, ret0]. + // arg0 = [b0, m, k] + // arg1 = [b0, k, n] + // ret0 = [b0, m, n] + // Note that we handle both batched (rank 3) and unbatched (rank 2). + uint32_t m = arg0.getRank() == 3 ? arg0.getDimSize(1) : arg0.getDimSize(0); + uint32_t k = arg0.getRank() == 3 ? arg0.getDimSize(2) : arg0.getDimSize(1); + uint32_t n = arg1.getRank() == 3 ? arg1.getDimSize(2) : arg1.getDimSize(1); + auto specializationInfoDef = + std::make_unique<iree::VkSpecializationInfoDefT>(); + addSpecializationMapEntry(/*kMatrixM*/ 100, m, specializationInfoDef.get()); + addSpecializationMapEntry(/*kMatrixK*/ 101, k, specializationInfoDef.get()); + addSpecializationMapEntry(/*kMatrixN*/ 102, n, specializationInfoDef.get()); + out_def->specialization_info = std::move(specializationInfoDef); + + return success(); +} + +} // namespace + +bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp, + iree::SpirVExecutableDefT *out_def) { + auto module = executableOp.getInnerModule(); + for (auto funcOp : module.getOps<FuncOp>()) { + if (funcOp.getAttr("iree.executable.reduction")) { + if (failed(buildReductionExecutable(executableOp, funcOp, out_def))) { + executableOp.emitOpError() << "Failed to splat in the reduction kernel"; + return false; + } + return true; + } + + for (auto &block : funcOp) { + for (auto &op : block) { + if (isa<xla_hlo::ConvOp>(&op)) { + executableOp.emitOpError() << "Conv not yet implemented"; + return false; + } else if (auto dotOp = dyn_cast_or_null<xla_hlo::DotOp>(&op)) { + if (failed(buildMatMulExecutable(executableOp, funcOp, dotOp, + out_def))) { + executableOp.emitOpError() + << "Failed to splat in the matmul kernel"; + return false; + } + return true; + } + } + } + } + return false; +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/SPIRV/EmbeddedKernels.h b/compiler/Translation/SPIRV/EmbeddedKernels.h new file mode 100644 index 0000000..3535944 --- /dev/null +++ b/compiler/Translation/SPIRV/EmbeddedKernels.h
@@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_ +#define IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_ + +#include "compiler/IR/StructureOps.h" +#include "flatbuffers/flatbuffers.h" +#include "mlir/Support/LogicalResult.h" +#include "schemas/spirv_executable_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +// Tries to match the |executableOp| against an embedded kernel and if matched +// will populate |out_def| with the kernel. +// Returns true if the kernel matched and was populated. +bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp, + iree::SpirVExecutableDefT* out_def); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
diff --git a/compiler/Translation/SPIRV/IREEIndexComputation.cpp b/compiler/Translation/SPIRV/IREEIndexComputation.cpp new file mode 100644 index 0000000..b05de91 --- /dev/null +++ b/compiler/Translation/SPIRV/IREEIndexComputation.cpp
@@ -0,0 +1,107 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- IREEIndexComputation.cpp --------------------------------*- C++//-*-===// +// +// Implementaiton of Index Propagation for IREE statements that are used in +// dispatch functions. +// +//===----------------------------------------------------------------------===// +#include "compiler/Translation/SPIRV/IREEIndexComputation.h" + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// IREELoadInputOp +//===----------------------------------------------------------------------===// + +LogicalResult IREELoadIndexPropagation::propagateIndexMap( + Operation *operation, IndexComputationCache &indexMap) const { + auto loadOp = cast<IREE::LoadInputOp>(operation); + auto result = operation->getResult(0); + auto src = loadOp.src(); + auto resultType = result->getType().dyn_cast<RankedTensorType>(); + auto srcType = src->getType().dyn_cast<MemRefType>(); + if (!resultType || !srcType || resultType.getShape() != srcType.getShape()) { + return loadOp.emitError( + "mismatch in shape of the result tensor and source memref"); + } + // Initialize the storage for the src. + indexMap[src]; + for (auto &resultIndexMap : indexMap[operation->getResult(0)]) { + indexMap[src][resultIndexMap.first]; + resultIndexMap.second.push_back(resultIndexMap.first); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// IREEStoreOutputOp +//===----------------------------------------------------------------------===// + +LogicalResult IREEStoreIndexPropagation::propagateIndexMap( + Operation *operation, IndexComputationCache &indexMap) const { + auto storeOp = cast<IREE::StoreOutputOp>(operation); + auto src = storeOp.src(); + auto srcType = src->getType().dyn_cast<ShapedType>(); + if (!srcType || !srcType.hasStaticShape()) { + return storeOp.emitError( + "can only handle store with src being tensor of static shape"); + } + + SmallVector<int64_t, 3> launchSize; + if (failed(getLaunchSize(operation, launchSize))) { + return failure(); + } + + // The launch dimensions are [x, y, z] co-ordinates. The reverse of this is + // used to determine the location of the tensor element computed by a + // workitem. The choice is failry arbitrary but is done to enable the common + // case where consecutive workitems compute "logically" adjacent tensor + // elements. + Builder builder(storeOp.getContext()); + SmallVector<AffineExpr, 4> affineExprs; + int64_t numElements = 1; + for (size_t i = launchSize.size(); i > 0; --i) { + // If launchSize along any dimension is 1, just use 0 for the index. This is + // not just an optimization. If you have an output of type memref<f32> which + // is lowered to !spv.ptr<!spv.struct<f32>, StorageBuffer> with launchSize + // <1>, then spv.AccessChain requires the indices to be a constant. + if (launchSize[i - 1] == 1) { + affineExprs.push_back(builder.getAffineConstantExpr(0)); + } else { + affineExprs.push_back(builder.getAffineDimExpr(i - 1)); + } + numElements *= launchSize[i - 1]; + } + auto launchMap = AffineMap::get(launchSize.size(), 0, affineExprs); + + // The stored tensor can be a reshape of the launch dimension. It still + // retains the requirement that each workitem is computing a single element + // of the stored tensor. + AffineMap srcMap; + SmallVector<int64_t, 3> revLaunchSize(reverse(launchSize)); + if (numElements != srcType.getNumElements() || + failed(getReshapeOperandMap(builder, launchMap, revLaunchSize, + srcType.getShape(), srcMap))) { + return storeOp.emitError( + "unable to map from launch id to element to compute within a " + "workitem"); + } + indexMap[src][srcMap]; + return success(); +} +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/SPIRV/IREEIndexComputation.h b/compiler/Translation/SPIRV/IREEIndexComputation.h new file mode 100644 index 0000000..55def25 --- /dev/null +++ b/compiler/Translation/SPIRV/IREEIndexComputation.h
@@ -0,0 +1,92 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- IREEIndexComputation.h ----------------------------------*- C++//-*-===// +// +// Index Propagation for IREE statements that are used in dispatch functions. +// +//===----------------------------------------------------------------------===// +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_H +#define IREE_COMPILER_TRANSLATION_SPIRV_H + +#include "compiler/IR/Ops.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Translation/SPIRV/XLAIndexPropagation.h" +#include "mlir/IR/Function.h" + +namespace mlir { +namespace iree_compiler { + +/// Gets the launch size associated with the dispatch function that this op is +/// part of. +inline LogicalResult getLaunchSize(Operation *op, + SmallVectorImpl<int64_t> &launchSize) { + auto funcOp = op->getParentOfType<FuncOp>(); + if (!funcOp || !funcOp.getAttr("iree.executable.export")) { + return op->emitError( + "expected operation to be in dispatch function to get launch size"); + } + auto workloadAttr = + funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workload"); + if (!workloadAttr) { + op->emitError( + "unable to find workload size, missing attribute " + "iree.executable.workload in dispatch function"); + } + launchSize.clear(); + for (auto value : workloadAttr.getValues<APInt>()) { + launchSize.push_back(value.getSExtValue()); + } + // Drop trailing ones. + auto dropFrom = launchSize.size() - 1; + while (dropFrom > 0 && launchSize[dropFrom] == 1) { + --dropFrom; + } + if (dropFrom > 0) { + launchSize.erase(std::next(launchSize.begin(), dropFrom + 1), + launchSize.end()); + } + return success(); +} + +/// Index propagation for iree.load_input operation. This operation is +/// essentially a copy from a memref to a tensor. So just copy the index map to +/// the memref operand from the result tensor. +class IREELoadIndexPropagation final + : public IndexPropagationOp<IREE::LoadInputOp> { + public: + using IndexPropagationOp<IREE::LoadInputOp>::IndexPropagationOp; + + LogicalResult propagateIndexMap( + Operation *operation, IndexComputationCache &indexMap) const override; +}; + +/// Index propagation for iree.store_output operation. The launch size is +/// assumed to match the shape of the tensor that is being stored. This +/// operation acts as a seed for the index propogation. Each workitem is assumed +/// to compute a single element of this tensor. The range of the index map is +/// the reverse of the launch dimension. +class IREEStoreIndexPropagation final + : public IndexPropagationOp<IREE::StoreOutputOp> { + public: + using IndexPropagationOp<IREE::StoreOutputOp>::IndexPropagationOp; + + LogicalResult propagateIndexMap( + Operation *operation, IndexComputationCache &indexMap) const override; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_H
diff --git a/compiler/Translation/SPIRV/IREEToSPIRV.cpp b/compiler/Translation/SPIRV/IREEToSPIRV.cpp new file mode 100644 index 0000000..97721ae --- /dev/null +++ b/compiler/Translation/SPIRV/IREEToSPIRV.cpp
@@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- IREEToSPIRV.cpp -----------------------------------------*- C++//-*-===// +// +// Translation of IREE statements in dispatch functions to SPIR-V. +// +//===----------------------------------------------------------------------===// +#include "compiler/Translation/SPIRV/IREEToSPIRV.h" + +namespace mlir { +namespace iree_compiler { + +/// IREE::LoadInputOp is essentially a memcpy. Just update the `valueCache` with +/// the value of the operand. +LogicalResult IREELoadOpSPIRVLowering::lowerOperation( + Operation *op, OpBuilder &builder, AffineMap index, + ArrayRef<Value *> operands, ValueCache &valueCache) const { + auto loadOp = cast<IREE::LoadInputOp>(op); + auto result = loadOp.getResult(); + valueCache.setOperandDstValue(result, index, operands[0]); + return success(); +} + +/// IREE::StoreOp needs to write to the spv.globalVariable created for the +/// memref that holds the result of the dispatch function. +LogicalResult IREEStoreOpSPIRVLowering::lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { + auto storeOp = cast<IREE::StoreOutputOp>(op); + auto src = storeOp.src(); + auto indices = affineExprCodegen.getIndices(src); + if (indices.size() != 1) { + return storeOp.emitError( + "expected to compute a single element of the tensor that is stored " + "into the output memref"); + } + auto var = inputBuffers.lookup(storeOp.dst()); + if (!var) { + return storeOp.emitError( + "unable to find spv.globalVariable that corresponds to the dst memref"); + } + auto ptr = genPointerOffset(builder, storeOp.getLoc(), affineExprCodegen, + indices[0], var); + auto scalarValue = valueCache.getOperandDstValue(src, indices[0]); + builder.create<spirv::StoreOp>(storeOp.getLoc(), ptr, scalarValue, + /*memory_access = */ nullptr, + /*alignment = */ nullptr); + return success(); +} + +/// IREE::ReturnOp in dispatch functions lowered to SPIR-V should have no +/// operands. +LogicalResult IREEReturnOpSPIRVLowering::lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { + builder.create<spirv::ReturnOp>(op->getLoc()); + return success(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/SPIRV/IREEToSPIRV.h b/compiler/Translation/SPIRV/IREEToSPIRV.h new file mode 100644 index 0000000..ada74b7 --- /dev/null +++ b/compiler/Translation/SPIRV/IREEToSPIRV.h
@@ -0,0 +1,69 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- IREEToSPIRV.h -------------------------------------------*- C++//-*-===// +// +// Translation of IREE statements in dispatch functions to SPIR-V. +// +//===----------------------------------------------------------------------===// +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H +#define IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H + +#include "compiler/IR/Ops.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/Translation/SPIRV/SPIRVLowering.h" + +namespace mlir { +namespace iree_compiler { + +/// Translation of iree.load_input operation. +class IREELoadOpSPIRVLowering final + : public SPIRVOpLowering<IREE::LoadInputOp> { + public: + using SPIRVOpLowering<IREE::LoadInputOp>::SPIRVOpLowering; + + LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, ArrayRef<Value *> operands, + ValueCache &valueCache) const override; +}; + +/// Translation of iree.return operation. +class IREEReturnOpSPIRVLowering final : public SPIRVOpLowering<IREE::ReturnOp> { + public: + using SPIRVOpLowering<IREE::ReturnOp>::SPIRVOpLowering; + + LogicalResult lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &codegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override; +}; + +/// Translation of iree.store_output operation. +class IREEStoreOpSPIRVLowering final + : public SPIRVOpLowering<IREE::StoreOutputOp> { + public: + using SPIRVOpLowering<IREE::StoreOutputOp>::SPIRVOpLowering; + + LogicalResult lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &codegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
diff --git a/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp b/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp new file mode 100644 index 0000000..5672738 --- /dev/null +++ b/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp
@@ -0,0 +1,196 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- IREEToSPIRVPass.cpp -------------------------------------*- C++//-*-===// +// +// Pass to translate iree executables for vulkan-spirv. +// +//===----------------------------------------------------------------------===// +#include "compiler/Translation/SPIRV/IREEToSPIRVPass.h" + +#include "compiler/Translation/SPIRV/IREEIndexComputation.h" +#include "compiler/Translation/SPIRV/IREEToSPIRV.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/Dialect/StandardOps/Ops.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +class IREEToSPIRVPass : public ModulePass<IREEToSPIRVPass> { + void runOnModule() override; +}; + +} // namespace + +void IREEToSPIRVPass::runOnModule() { + auto module = getModule(); + OpBuilder builder(module.getBodyRegion()); + + // Initialize the index computation. + IndexPropagationList<IndexPropagationOp<ConstantOp>, + // IREE-specific ops: + IndexPropagationOp<IREE::ReturnOp>, + IREELoadIndexPropagation, IREEStoreIndexPropagation, + // Standard dialect unary elementwise ops: + NoBroadcastPwOpIndexPropagation<SIToFPOp>, + NoBroadcastPwOpIndexPropagation<SignExtendIOp>, + // Standard dialect binary elementwise ops: + NoBroadcastPwOpIndexPropagation<AddFOp>, + NoBroadcastPwOpIndexPropagation<AddIOp>, + NoBroadcastPwOpIndexPropagation<AndOp>, + NoBroadcastPwOpIndexPropagation<CmpFOp>, + NoBroadcastPwOpIndexPropagation<CmpIOp>, + NoBroadcastPwOpIndexPropagation<DivFOp>, + NoBroadcastPwOpIndexPropagation<DivISOp>, + NoBroadcastPwOpIndexPropagation<DivIUOp>, + NoBroadcastPwOpIndexPropagation<MulFOp>, + NoBroadcastPwOpIndexPropagation<MulIOp>, + NoBroadcastPwOpIndexPropagation<OrOp>, + NoBroadcastPwOpIndexPropagation<RemFOp>, + NoBroadcastPwOpIndexPropagation<RemISOp>, + NoBroadcastPwOpIndexPropagation<RemIUOp>, + NoBroadcastPwOpIndexPropagation<SubFOp>, + NoBroadcastPwOpIndexPropagation<SubFOp>, + NoBroadcastPwOpIndexPropagation<SubIOp>, + NoBroadcastPwOpIndexPropagation<TruncateIOp>, + NoBroadcastPwOpIndexPropagation<XOrOp>, + NoBroadcastPwOpIndexPropagation<ZeroExtendIOp>, + // XLA unary elementwise ops: + NoBroadcastPwOpIndexPropagation<xla_hlo::AbsOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::CeilOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::ConvertOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::CosOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::ExpOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::FloorOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::LogOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::NegOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::RsqrtOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::SignOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::TanhOp>, + // XLA binary elementwise ops: + NoBroadcastPwOpIndexPropagation<xla_hlo::AddOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::AndOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::DivOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::MaxOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::MinOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::MulOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::SubOp>, + // XLA other ops: + // TODO(ravishankarm): conv, dot. + // TODO(ravishankarm): gather. + // TODO(ravishankarm): pad. + // TODO(ravishankarm): slice. + NoBroadcastPwOpIndexPropagation<xla_hlo::CopyOp>, + ReshapeOpIndexPropagation<xla_hlo::ReshapeOp>, + NoBroadcastPwOpIndexPropagation<xla_hlo::SelectOp>, + XLABroadcastOpIndexPropagation, + XLABroadcastInDimOpIndexPropagation, + XLAReverseOpIndexPropagation, + XLATransposeOpIndexPropagation> + indexPropagation; + + // Initialize the spir-v codegenerator. + SPIRVCodegen< + ConstantOpSPIRVLowering, + // IREE-specific ops: + IREELoadOpSPIRVLowering, IREEReturnOpSPIRVLowering, + IREEStoreOpSPIRVLowering, + // Standard dialect unary elementwise ops: + // Standard dialect binary elementwise ops: + SPIRVPwOpLowering<AddFOp, spirv::FAddOp>, + SPIRVPwOpLowering<DivFOp, spirv::FDivOp>, + SPIRVPwOpLowering<MulFOp, spirv::FMulOp>, + SPIRVPwOpLowering<SubFOp, spirv::FSubOp>, + SPIRVPwOpLowering<AddIOp, spirv::IAddOp>, + SPIRVPwOpLowering<DivISOp, spirv::SDivOp>, + SPIRVPwOpLowering<MulIOp, spirv::IMulOp>, + SPIRVPwOpLowering<SubIOp, spirv::ISubOp>, + // XLA unary elementwise ops: + SPIRVPwOpLowering<xla_hlo::AbsOp, spirv::GLSLSAbsOp, spirv::GLSLFAbsOp>, + SPIRVPwOpLowering<xla_hlo::CeilOp, spirv::GLSLCeilOp>, + // TODO(ravishankarm): xla_hlo::ConvertOp + SPIRVPwOpLowering<xla_hlo::CosOp, spirv::GLSLCosOp>, + SPIRVPwOpLowering<xla_hlo::ExpOp, spirv::GLSLExpOp>, + SPIRVPwOpLowering<xla_hlo::FloorOp, spirv::GLSLFloorOp>, + SPIRVPwOpLowering<xla_hlo::LogOp, spirv::GLSLLogOp>, + SPIRVPwOpLowering<xla_hlo::NegOp, spirv::FNegateOp>, + SPIRVPwOpLowering<xla_hlo::RsqrtOp, spirv::GLSLInverseSqrtOp>, + SPIRVPwOpLowering<xla_hlo::SignOp, spirv::GLSLSSignOp, + spirv::GLSLFSignOp>, + SPIRVPwOpLowering<xla_hlo::TanhOp, spirv::GLSLTanhOp>, + // XLA binary elementwise ops: + SPIRVPwOpLowering<xla_hlo::AddOp, spirv::IAddOp, spirv::FAddOp>, + SPIRVPwOpLowering<xla_hlo::AndOp, spirv::LogicalAndOp>, + SPIRVPwOpLowering<xla_hlo::DivOp, spirv::FDivOp>, + SPIRVPwOpLowering<xla_hlo::MaxOp, spirv::GLSLSMaxOp, spirv::GLSLFMaxOp>, + SPIRVPwOpLowering<xla_hlo::MinOp, spirv::GLSLSMinOp, spirv::GLSLFMinOp>, + SPIRVPwOpLowering<xla_hlo::MulOp, spirv::IMulOp, spirv::FMulOp>, + SPIRVPwOpLowering<xla_hlo::SubOp, spirv::ISubOp, spirv::FSubOp>, + // XLA other ops: + CmpFOpSPIRVLowering, + SPIRVPwOpLowering<xla_hlo::SelectOp, spirv::SelectOp>, + SPIRVIndexOpLowering<xla_hlo::BroadcastOp>, + SPIRVIndexOpLowering<xla_hlo::BroadcastInDimOp>, + SPIRVIndexOpLowering<xla_hlo::CopyOp>, + SPIRVIndexOpLowering<xla_hlo::ReshapeOp>, + SPIRVIndexOpLowering<xla_hlo::ReverseOp>, + SPIRVIndexOpLowering<xla_hlo::TransposeOp>> + spirvCodegen; + + // Create a spirv.module Op. + auto spvModule = builder.create<spirv::ModuleOp>( + module.getLoc(), + builder.getI32IntegerAttr( + static_cast<int32_t>(spirv::AddressingModel::Logical)), + builder.getI32IntegerAttr( + static_cast<int32_t>(spirv::MemoryModel::GLSL450))); + SmallVector<StringRef, 2> caps; + caps.push_back(spirv::stringifyCapability(spirv::Capability::Shader)); + spvModule.setAttr("capabilities", builder.getStrArrayAttr(caps)); + SmallVector<StringRef, 2> exts; + exts.push_back("SPV_KHR_storage_buffer_storage_class"); + spvModule.setAttr("extensions", builder.getStrArrayAttr(exts)); + + for (auto funcOp : module.getOps<FuncOp>()) { + // TODO(ravishankarm): FuncOps in executable that are not dispatch functions + // are not lowered to SPIR-V. Fix this limitation. + if (!funcOp.getAttr("iree.executable.export")) continue; + + IndexComputationCache indexMap; + if (failed(indexPropagation.propagate(funcOp.getBody(), indexMap))) { + return signalPassFailure(); + } + // dumpIndexCache(indexMap); + + ValueCache valueCache; + AffineExprCodegen affineExprCodegen(spvModule, indexMap); + if (failed(spirvCodegen.codegen(spvModule, funcOp, affineExprCodegen, + valueCache))) { + return signalPassFailure(); + } + } +} + +std::unique_ptr<OpPassBase<ModuleOp>> createIREEToSPIRVPass() { + return std::make_unique<IREEToSPIRVPass>(); +} +static PassRegistration<IREEToSPIRVPass> pass( + "convert-iree-to-spirv", + "Convert IREE dispatch functions to SPIR-V dialect"); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h b/compiler/Translation/SPIRV/IREEToSPIRVPass.h similarity index 100% rename from iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h rename to compiler/Translation/SPIRV/IREEToSPIRVPass.h
diff --git a/compiler/Translation/SPIRV/IndexComputation.cpp b/compiler/Translation/SPIRV/IndexComputation.cpp new file mode 100644 index 0000000..877fab0 --- /dev/null +++ b/compiler/Translation/SPIRV/IndexComputation.cpp
@@ -0,0 +1,269 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- IndexComputation.cpp ------------------------------------*- C++//-*-===// +// +// For an IREE dispatch function, compute the map from workitem ID to index of +// tensor computed within that workitem. +// +//===----------------------------------------------------------------------===// +#include "compiler/Translation/SPIRV/IndexComputation.h" + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" + +static llvm::cl::opt<bool> doAffineExprSimplify( + "simplify-spirv-affine-exprs", + llvm::cl::desc("Simplify affine expressions during code-generation."), + llvm::cl::init(true)); + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// Reshape Utility Functions +//===----------------------------------------------------------------------===// + +namespace { +/// Handles shapes for scalars. Shape of scalars are represented as empty vetor, +/// i.e. {}. Its easier to do index propogation to handle the scalar as vector +/// of size 1. +inline SmallVector<int64_t, 4> handleIfScalar(ArrayRef<int64_t> shape) { + SmallVector<int64_t, 4> resultShape; + if (shape.empty()) { + return {1}; + } + return SmallVector<int64_t, 4>(shape.begin(), shape.end()); +} + +/// Reshapes are often used to either add a dimension of size 1 or remove a +/// dimension of size 1. Recognizing such cases can make the code-generation +/// easier. The AffineMap needs to either add a constant 0 in the range for such +/// added dimensions or drop those dimensions. +inline LogicalResult getAffineExprForAddOrRemoveDimension( + Builder &builder, ArrayRef<AffineExpr> resultExprs, + ArrayRef<int64_t> resultShape, ArrayRef<int64_t> operandShape, + SmallVectorImpl<AffineExpr> &operandExprs) { + auto resultIndex = resultShape.size(); + auto operandIndex = operandShape.size(); + operandExprs.resize(operandShape.size()); + // Try to match up the dimensions of the operand and result by ignoring any + // dimensions of size of 1 that are introduced. + while (resultIndex > 0 && operandIndex > 0) { + if (resultShape[resultIndex - 1] == -1 || + operandShape[operandIndex - 1] == -1) { + return failure(); + } + if (resultShape[resultIndex - 1] == operandShape[operandIndex - 1]) { + operandExprs[operandIndex - 1] = resultExprs[resultIndex - 1]; + resultIndex--; + operandIndex--; + continue; + } + if (resultShape[resultIndex - 1] == 1) { + // This is a dimension that is added on the operand. This affine + // expression corresponding to this dimension is dropped. + resultIndex--; + continue; + } + if (operandShape[operandIndex - 1] == 1) { + // This is a dimension of size 1 of the operand that is dropped. Add a + // constant expr 0. + operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0); + operandIndex--; + continue; + } + return failure(); + } + // Any remaining dimensions should be 1. + while (resultIndex > 0) { + if (resultShape[resultIndex - 1] != 1) { + return failure(); + } + resultIndex--; + } + while (operandIndex > 0) { + if (operandShape[operandIndex - 1] != 1) { + return failure(); + } + // This is a dimension of size 1 that is dropped. Add a constant expression + // 0. + operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0); + operandIndex--; + } + return success(); +} + +/// Constructs the strides of an array assuming a row-major packed layout. +// TODO(ravishankarm): This assumes the shape are static. When using dynamic +// shapes, parameters of each dimension can be used to construct AffineExpr for +// strides along each dimension. Note that multiplying two symbolic constants is +// technically not affine, but you could use another symbol to represent the +// product, so it should be still representable as affine exprs. +inline LogicalResult getRowMajorPackedStrides( + Builder &builder, ArrayRef<int64_t> shape, + SmallVectorImpl<AffineExpr> &strides) { + strides.resize(shape.size()); + int64_t stride = 1; + for (auto dim : enumerate(reverse(shape))) { + if (dim.value() < 0) { + // TODO(ravishankarm) : Better error message. + return failure(); + } + strides[shape.size() - 1 - dim.index()] = + builder.getAffineConstantExpr(stride); + stride *= dim.value(); + } + return success(); +} + +/// Linearizes the index of the result position accessed using the shape of the +/// result tensor and delinearizes it to get the position of the operand. +inline LogicalResult getAffineExprForReshape( + Builder &builder, unsigned numDims, unsigned numSymbols, + ArrayRef<AffineExpr> resultExprs, ArrayRef<int64_t> resultShape, + ArrayRef<int64_t> operandShape, SmallVectorImpl<AffineExpr> &operandExprs) { + // To linearize the index, assume that the memory is laid out in + // packed-row-major layout based on the shape. + // TODO(ravishankarm) : When there is stride information, use that to map from + // index to memory location. + SmallVector<AffineExpr, 4> resultStrides; + if (failed(getRowMajorPackedStrides(builder, resultShape, resultStrides))) { + return failure(); + } + AffineExpr linearizedExpr; + for (auto index : enumerate(resultExprs)) { + auto val = getAffineBinaryOpExpr(AffineExprKind::Mul, index.value(), + resultStrides[index.index()]); + if (doAffineExprSimplify) { + val = simplifyAffineExpr(val, numDims, numSymbols); + } + linearizedExpr = (index.index() ? getAffineBinaryOpExpr(AffineExprKind::Add, + linearizedExpr, val) + : val); + if (doAffineExprSimplify) { + linearizedExpr = simplifyAffineExpr(val, numDims, numSymbols); + } + } + + // Unlinearize the index, assuming row-major-packed layout. + // TODO(ravishankarm) : When there is stride information, use that to map from + // memory location to index. + SmallVector<AffineExpr, 4> operandStrides; + if (failed(getRowMajorPackedStrides(builder, operandShape, operandStrides))) { + return failure(); + } + operandExprs.resize(operandStrides.size()); + for (auto stride : enumerate(operandStrides)) { + if (stride.index() == operandStrides.size() - 1) { + operandExprs[stride.index()] = linearizedExpr; + break; + } + auto expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, linearizedExpr, + stride.value()); + operandExprs[stride.index()] = + (doAffineExprSimplify ? simplifyAffineExpr(expr, numDims, numSymbols) + : expr); + + linearizedExpr = getAffineBinaryOpExpr(AffineExprKind::Mod, linearizedExpr, + stride.value()); + if (doAffineExprSimplify) { + linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols); + } + } + return success(); +} +} // namespace + +LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap, + ArrayRef<int64_t> resultShapeRef, + ArrayRef<int64_t> operandShapeRef, + AffineMap &operandIndexMap) { + auto resultShape = handleIfScalar(resultShapeRef); + auto operandShape = handleIfScalar(operandShapeRef); + auto resultExprs = resultIndexMap.getResults(); + assert(resultShape.size() == resultExprs.size() && + "Ranks of the Domain of index map and result must be the same"); + SmallVector<AffineExpr, 4> operandExprs; + if (failed(getAffineExprForAddOrRemoveDimension( + builder, resultExprs, resultShape, operandShape, operandExprs)) && + failed(getAffineExprForReshape( + builder, resultIndexMap.getNumDims(), resultIndexMap.getNumSymbols(), + resultExprs, resultShape, operandShape, operandExprs))) { + return failure(); + } + assert(operandExprs.size() == operandShape.size() && + "expected as many exprs for the operand as the rank of the operand"); + operandIndexMap = + AffineMap::get(resultIndexMap.getNumDims(), + resultIndexMap.getNumSymbols(), operandExprs); + + return success(); +} + +LogicalResult IndexPropagation::propagateIndexMap( + Operation *op, IndexComputationCache &indexMap) const { + if (op->getNumResults() == 0) { + // Nothing to do for this op. + return success(); + } + if (op->getNumResults() != 1) { + return op->emitError( + "default index propagation handles case with a single-return value"); + } + // Initialize the storage for all the operands. + for (auto arg : op->getOperands()) { + indexMap[arg]; + } + for (auto &resultIndexMap : indexMap[op->getResult(0)]) { + SmallVector<AffineMap, 4> operandIndices; + if (failed(this->propagateIndexMap(op, resultIndexMap.first, + operandIndices))) { + return failure(); + } + assert(operandIndices.size() == op->getNumOperands() && + "Expected as many indices as operands"); + for (auto arg : enumerate(op->getOperands())) { + indexMap[arg.value()][operandIndices[arg.index()]]; + resultIndexMap.second.push_back(operandIndices[arg.index()]); + } + } + return success(); +} + +void dumpIndexCache(IndexComputationCache &indexMap) { + for (auto &el : indexMap) { + // llvm::errs() << "Value : " << *(el.first); + // llvm::errs().flush(); + if (isa<OpResult>(el.first)) { + llvm::errs() << "Operation : " << el.first->getDefiningOp()->getName(); + } else if (isa<BlockArgument>(el.first)) { + llvm::errs() << "BlockArgument"; + } + for (auto &used : el.second) { + llvm::errs() << "\n\t" << used.first << " : ["; + std::string sep = ""; + for (auto &operand : used.second) { + llvm::errs() << sep << operand; + sep = ", "; + } + llvm::errs() << "]"; + } + llvm::errs() << "\n"; + } + llvm::errs() << "\n"; +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation.h b/compiler/Translation/SPIRV/IndexComputation.h similarity index 100% rename from iree/compiler/Translation/SPIRV/IndexComputation.h rename to compiler/Translation/SPIRV/IndexComputation.h
diff --git a/iree/compiler/Translation/SPIRV/Kernels/BUILD b/compiler/Translation/SPIRV/Kernels/BUILD similarity index 100% rename from iree/compiler/Translation/SPIRV/Kernels/BUILD rename to compiler/Translation/SPIRV/Kernels/BUILD
diff --git a/iree/compiler/Translation/SPIRV/Kernels/CMakeLists.txt b/compiler/Translation/SPIRV/Kernels/CMakeLists.txt similarity index 100% rename from iree/compiler/Translation/SPIRV/Kernels/CMakeLists.txt rename to compiler/Translation/SPIRV/Kernels/CMakeLists.txt
diff --git a/iree/compiler/Translation/SPIRV/Kernels/matmul.comp b/compiler/Translation/SPIRV/Kernels/matmul.comp similarity index 100% rename from iree/compiler/Translation/SPIRV/Kernels/matmul.comp rename to compiler/Translation/SPIRV/Kernels/matmul.comp
diff --git a/iree/compiler/Translation/SPIRV/Kernels/reduce_untiled.comp b/compiler/Translation/SPIRV/Kernels/reduce_untiled.comp similarity index 100% rename from iree/compiler/Translation/SPIRV/Kernels/reduce_untiled.comp rename to compiler/Translation/SPIRV/Kernels/reduce_untiled.comp
diff --git a/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl b/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl new file mode 100644 index 0000000..f6e224d --- /dev/null +++ b/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl
@@ -0,0 +1,32 @@ +"""Utilities for handling hand-written SPIR-V files.""" + +load("//:build_defs.google.bzl", "iree_glsl_vulkan") +load("///build_tools/embed_data:build_defs.bzl", "cc_embed_data") + +def spirv_kernel_cc_library(name, srcs): + """Compiles GLSL files into SPIR-V binaries and embeds them in a cc_library. + + Args: + name: cc_library name to depend on. + srcs: a list of GLSL source files. + """ + spv_files = [] + for src in srcs: + spv_name = src.split(".")[-2] + iree_glsl_vulkan( + name = spv_name, + srcs = [src], + ) + spv_files.append(spv_name + ".spv") + native.filegroup( + name = name + "_files", + srcs = spv_files, + ) + cc_embed_data( + name = name, + srcs = spv_files, + cc_file_output = name + ".cc", + h_file_output = name + ".h", + cpp_namespace = "mlir::iree_compiler::spirv_kernels", + flatten = True, + )
diff --git a/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp new file mode 100644 index 0000000..5ab78c7 --- /dev/null +++ b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp
@@ -0,0 +1,314 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Translation/SPIRV/SPIRVExecutableTranslation.h" + +#include <cstdint> +#include <iostream> +#include <map> +#include <vector> + +#include "flatbuffers/flatbuffers.h" +#include "compiler/Translation/SPIRV/EmbeddedKernels.h" +#include "compiler/Translation/SPIRV/IREEToSPIRVPass.h" +#include "compiler/Utils/OpUtils.h" +#include "compiler/Utils/TranslationUtils.h" +#include "schemas/executable_def_generated.h" +#include "schemas/spirv_executable_def_generated.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Serialization.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Translation.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +class SPIRVTranslator { + public: + explicit SPIRVTranslator(ExecutableTranslationOptions options) + : options_(options) {} + + const ExecutableTranslationOptions &options() const { return options_; } + + // Returns a populated ExecutableDef or nullptr if translation is + // unsuccessful. + std::unique_ptr<iree::ExecutableDefT> translateExecutable( + IREE::ExecutableOp executableOp); + + private: + // Returns a list of entry point names matching the expected export ordinals. + std::vector<std::string> populateEntryPointNames( + IREE::ExecutableOp executableOp); + + // Translates the input module into the SPIR-V dialect and returns the + // serialized code words or empty if translation failed. + std::vector<uint32_t> translateAndSerializeShaderModule( + IREE::ExecutableOp executableOp); + + // Returns a pipeline layout definition based on the bindings required. + std::unique_ptr<iree::VkPipelineLayoutDefT> populatePipelineLayout( + spirv::ModuleOp spirvModuleOp); + + ExecutableTranslationOptions options_; +}; + +std::unique_ptr<iree::ExecutableDefT> SPIRVTranslator::translateExecutable( + IREE::ExecutableOp executableOp) { + // Try first to match against an embedded kernel (such as matmul) and + // otherwise fall back to generating the kernel. + iree::SpirVExecutableDefT spirvExecutableDef; + if (!tryEmbeddedKernelRewrite(executableOp, &spirvExecutableDef)) { + // The sequencer and runtime use ordinals instead of names. We provide the + // list of entry point names here that are then passed in + // VkShaderModuleCreateInfo. + spirvExecutableDef.entry_points = populateEntryPointNames(executableOp); + + // Translate the module and generate the SPIR-V code. + // The module is expected to be modified and must contain the metadata + // required to enable the following information needed for the + // SpirVExecutableDef to be extracted. + spirvExecutableDef.code = translateAndSerializeShaderModule(executableOp); + if (spirvExecutableDef.code.empty()) { + executableOp.emitError() + << "Failed to translate and serialize SPIR-V executable"; + return {}; + } + + // Reflect against the entry thunk to identify the required pipeline + // layout based on binding information. This is used by the runtime to + // create the VkPipelineLayout. + for (auto spirvModuleOp : + executableOp.getBlock().getOps<spirv::ModuleOp>()) { + spirvExecutableDef.pipeline_layout = + populatePipelineLayout(spirvModuleOp); + if (!spirvExecutableDef.pipeline_layout) { + spirvModuleOp.emitError() + << "Failed to generate pipeline for SPIR-V module"; + return {}; + } + break; + } + } + + // Pack the executable definition and get the bytes with the proper header. + // The header is used to verify the contents at runtime. + ::flatbuffers::FlatBufferBuilder fbb; + auto executableOffset = + iree::SpirVExecutableDef::Pack(fbb, &spirvExecutableDef); + iree::FinishSpirVExecutableDefBuffer(fbb, executableOffset); + std::vector<uint8_t> bytes; + bytes.resize(fbb.GetSize()); + std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size()); + + OpBuilder builder(executableOp); + executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>( + IREE::ExecutableFormat::SpirV))); + + auto executableDef = std::make_unique<iree::ExecutableDefT>(); + executableDef->format = static_cast<uint32_t>(IREE::ExecutableFormat::SpirV); + executableDef->contents = std::move(bytes); + return executableDef; +} + +std::vector<std::string> SPIRVTranslator::populateEntryPointNames( + IREE::ExecutableOp executableOp) { + auto module = executableOp.getInnerModule(); + DenseMap<unsigned, StringRef> entryPoints; + for (auto funcOp : module.getOps<FuncOp>()) { + if (!funcOp.getAttr("iree.executable.export")) continue; + auto ordinalAttr = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal"); + entryPoints[ordinalAttr.getInt()] = funcOp.getName(); + } + std::vector<std::string> entryPointNames(entryPoints.size()); + for (auto &entry : entryPoints) { + entryPointNames[entry.first] = entry.second.str(); + } + return entryPointNames; +} + +std::vector<uint32_t> SPIRVTranslator::translateAndSerializeShaderModule( + IREE::ExecutableOp executableOp) { + auto module = executableOp.getInnerModule(); + + // We can use the workload hint to know what the expected dispatch workload + // is. If we want to remap this to make more sense for the operations we are + // performing we can do that here. + // + // Note that workloads are computed per entry point. There may be some + // dimensions of the workload that are static (in which case workloadAttr will + // have non-dynamic dims) and others that need to be taken from an argument + // shape (in which case workloadRef is the argument ordinal to take dynamic + // dimensions from). + // TODO(benvanik): make it just an arg instead? iree.workload special op? + // TODO(benvanik): instead of FuncOp have an iree.entry_point op with these. + for (auto funcOp : module.getOps<FuncOp>()) { + // TODO(ravishankarm): FuncOps in executable that are not dispatch functions + // are not lowered to SPIR-V. Fix this limitation. + if (!funcOp.getAttr("iree.executable.export")) continue; + auto workloadAttr = + funcOp.getAttrOfType<ElementsAttr>("iree.executable.workload"); + auto workloadRefAttr = + funcOp.getAttrOfType<IntegerAttr>("iree.executable.workload_ref"); + std::array<int32_t, 3> staticWorkloadDims = {-1, -1, -1}; + if (workloadAttr) { + for (unsigned i = 0; i < 3; ++i) { + if (auto dimAttr = + workloadAttr.getValue({i}).dyn_cast_or_null<IntegerAttr>()) { + staticWorkloadDims[i] = dimAttr.getInt(); + } + } + } + std::array<BlockArgument *, 3> dynamicWorkloadDimRefs; + if (workloadRefAttr) { + for (unsigned i = 0; i < 3; ++i) { + if (staticWorkloadDims[i] == -1) { + dynamicWorkloadDimRefs[i] = + funcOp.getArgument(workloadRefAttr.getInt()); + } + } + } + + // Now staticWorkloadDims will have non-negative values for known dimensions + // and any dim with -1 will need to be pulled from the corresponding shape + // dimension of dynamicWorkloadDimRefs. + + // TODO(b/137868263): use this information to map from workgroup to + // invocation and perform indexing. + } + + // Lower module to spirv::ModuleOp. + auto spirvGenPasses = createPassManager(module.getContext(), options()); + spirvGenPasses->addPass(xla_hlo::createLegalizeToStdPass()); + spirvGenPasses->addPass(createIREEToSPIRVPass()); + if (failed(runPassPipeline(options(), spirvGenPasses.get(), module))) { + executableOp.emitError() << "Failed to generate spv.module"; + return {}; + } + + auto spvModules = module.getOps<spirv::ModuleOp>(); + if (std::distance(spvModules.begin(), spvModules.end()) != 1) { + executableOp.emitError() + << "Expected a single spv.module for an IREE executable op"; + return {}; + } + + // Serialize the spirv::ModuleOp into the binary that we will embed in the + // final flatbuffer. + std::vector<uint32_t> spvBinaries; + for (auto spvModule : spvModules) { + SmallVector<uint32_t, 256> spvBinary; + if (failed(spirv::serialize(spvModule, spvBinary))) { + executableOp.emitError() << "Failed to serialize spv.module"; + return {}; + } + spvBinaries.insert(spvBinaries.end(), spvBinary.begin(), spvBinary.end()); + + // Clone the module into executableOp directly. + auto clonedModule = spvModule.clone(); + executableOp.getBlock().getOperations().insert( + std::prev(executableOp.getBlock().getOperations().end()), clonedModule); + } + // Remove the original code. + module.erase(); + + return spvBinaries; +} + +std::unique_ptr<iree::VkPipelineLayoutDefT> +SPIRVTranslator::populatePipelineLayout(spirv::ModuleOp spirvModuleOp) { + // NOTE: we currently make some assumptions about this based on the expected + // ABI of the runtime. If we wanted to support more general shaders with more + // complex I/O we'd need to find a better way to communicate this through the + // VkPipelineLayoutDef. + auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>(); + pipelineLayoutDef->buffer_binding_set = 0; + + // Build a set of descriptor_set -> binding -> variable. + // This makes it easier to write out the descriptor in a logical order, even + // though this is not strictly required. + int64_t maxDescriptorSetOrdinal = -1; + std::map<int32_t, std::map<int32_t, spirv::GlobalVariableOp>> descriptorSets; + for (auto globalVar : + spirvModuleOp.getBlock().getOps<spirv::GlobalVariableOp>()) { + auto descriptorSetAttr = + globalVar.getAttrOfType<IntegerAttr>("descriptor_set"); + auto bindingAttr = globalVar.getAttrOfType<IntegerAttr>("binding"); + if (!descriptorSetAttr || !bindingAttr) { + // Not something the runtime cares about. + continue; + } + maxDescriptorSetOrdinal = + std::max(descriptorSetAttr.getInt(), maxDescriptorSetOrdinal); + auto &descriptorSet = descriptorSets[descriptorSetAttr.getInt()]; + descriptorSet[bindingAttr.getInt()] = globalVar; + } + + // Create the individual layout and binding defs. + pipelineLayoutDef->descriptor_set_layouts.resize(maxDescriptorSetOrdinal + 1); + for (auto &descriptorSetBindings : descriptorSets) { + int32_t descriptorSet = descriptorSetBindings.first; + auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>(); + + for (auto &globalVarBinding : descriptorSetBindings.second) { + auto binding = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>(); + binding->binding = globalVarBinding.first; + binding->descriptor_count = 1; + // TODO(benvanik): pull from type info. + binding->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER + binding->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT + dsl->bindings.push_back(std::move(binding)); + } + + pipelineLayoutDef->descriptor_set_layouts[descriptorSet] = std::move(dsl); + } + + return pipelineLayoutDef; +} + +} // namespace + +llvm::Optional<ExecutableTranslationResult> +translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps, + ExecutableTranslationOptions options) { + SPIRVTranslator translator(options); + ExecutableTranslationResult translationResult; + for (auto executableOp : llvm::make_early_inc_range(executableOps)) { + auto executableDef = translator.translateExecutable(executableOp); + if (!executableDef) { + executableOp.emitError() << "Failed to translate one or more executables"; + return llvm::None; + } + translationResult.executable_defs.push_back(std::move(executableDef)); + } + return translationResult; +} + +static ExecutableTranslationRegistration SPIRVExecutableTranslationRegistration( + "vulkan-spirv", translateExecutableToSPIRVExecutable); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h new file mode 100644 index 0000000..89fe483 --- /dev/null +++ b/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h
@@ -0,0 +1,38 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_ +#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_ + +#include <vector> + +#include "compiler/IR/StructureOps.h" +#include "compiler/Utils/TranslationUtils.h" +#include "mlir/IR/Module.h" + +namespace mlir { +namespace iree_compiler { + +// Translates an MLIR module into a SPIR-V executable. +// These executables are stored as FlatBuffers in the +// https://github.com/google/iree/tree/master/iree/schemas/spirv_executable_def.fbs +// schema. +llvm::Optional<ExecutableTranslationResult> +translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps, + ExecutableTranslationOptions options = {}); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
diff --git a/compiler/Translation/SPIRV/SPIRVLowering.cpp b/compiler/Translation/SPIRV/SPIRVLowering.cpp new file mode 100644 index 0000000..23f87a7 --- /dev/null +++ b/compiler/Translation/SPIRV/SPIRVLowering.cpp
@@ -0,0 +1,131 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- SPIRVLowering.cpp ---------------------------------------*- C++//-*-===// +// +// SPIR-V Code-generation for XLA-HLO Ops within IREE Dispatch functions +// +//===----------------------------------------------------------------------===// +#include "compiler/Translation/SPIRV/SPIRVLowering.h" + +namespace mlir { +namespace iree_compiler { +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// +LogicalResult ConstantOpSPIRVLowering::lowerOperation( + Operation *op, OpBuilder &builder, AffineMap index, ArrayRef<Value *>, + ValueCache &valueCache) const { + auto constOp = cast<ConstantOp>(op); + auto attr = constOp.value().dyn_cast<DenseElementsAttr>(); + if (!attr || !attr.isSplat()) { + return op->emitError( + "unhandled constant lowering unless value is a splat dense element " + "attribute"); + } + auto resultType = constOp.getResult()->getType(); + Type resultElemType; + if (resultType.isIntOrFloat()) { + resultElemType = resultType; + } else if (auto shapedType = resultType.dyn_cast<ShapedType>()) { + resultElemType = shapedType.getElementType(); + } else { + return op->emitError("unhandled result type of constant : ") << resultType; + } + Attribute constVal = attr.getSplatValue(); + auto spirvConstOp = + builder.create<spirv::ConstantOp>(op->getLoc(), resultElemType, constVal); + valueCache.setOperandDstValue(constOp.getResult(), index, + spirvConstOp.getResult()); + return success(); +} + +//===----------------------------------------------------------------------===// +// CmpFOp +//===----------------------------------------------------------------------===// +LogicalResult CmpFOpSPIRVLowering::lowerOperation( + Operation *op, OpBuilder &builder, AffineMap index, + ArrayRef<Value *> operands, ValueCache &valueCache) const { + if (operands.size() != 2) { + return op->emitError("expected two operands in spir-v lowering of CmpFOp"); + } + Operation *spirvOp = nullptr; + auto opInfo = op->getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); + if (!opInfo) { + return op->emitError("expected CmpFOp to contain ") + << CmpFOp::getPredicateAttrName() << " attribute"; + } + auto boolType = builder.getI1Type(); + auto predicateVal = static_cast<CmpFPredicate>(opInfo.getInt()); + switch (predicateVal) { +#define DISPATCH(caseLabel, opName) \ + case caseLabel: \ + spirvOp = builder.create<opName>(op->getLoc(), boolType, operands[0], \ + operands[1]); \ + break; + + DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); + DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); + DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); + DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); + DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); + DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); + DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); + DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); + DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); + DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); + DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); + DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); + +#undef DISPATCH + + default: + return op->emitError("unhandled predicate attribute for SPIR-V lowering"); + } + valueCache.setOperandDstValue(op->getResult(0), index, spirvOp->getResult(0)); + return success(); +} + +//===----------------------------------------------------------------------===// +// ReturnOp +//===----------------------------------------------------------------------===// +LogicalResult ReturnOpSPIRVLowering::lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { + auto returnOp = cast<ReturnOp>(op); + if (returnOp.getNumOperands() != 1) { + return returnOp.emitError( + "unhandled lowering of return statement with multiple returns"); + } + auto returnTensor = returnOp.getOperand(0); + auto indices = affineExprCodegen.getIndices(returnTensor); + if (indices.size() != 1) { + return returnOp.emitError( + "expected to compute a single element of the return tensor"); + } + assert(outputBuffers.size() == 1 && "Expected a single output buffer"); + auto var = outputBuffers[0]; + auto ptr = genPointerOffset(builder, returnOp.getLoc(), affineExprCodegen, + indices[0], var); + auto scalarVal = valueCache.getOperandDstValue(returnTensor, indices[0]); + builder.create<spirv::StoreOp>(returnOp.getLoc(), ptr, scalarVal, + /*memory_access = */ nullptr, + /*alignment = */ nullptr); + builder.create<spirv::ReturnOp>(returnOp.getLoc()); + return success(); +} +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/SPIRV/SPIRVLowering.h b/compiler/Translation/SPIRV/SPIRVLowering.h new file mode 100644 index 0000000..eacb551 --- /dev/null +++ b/compiler/Translation/SPIRV/SPIRVLowering.h
@@ -0,0 +1,591 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- SPIRVLowering.h -----------------------------------------*- C++//-*-===// +// +// SPIR-V Code-generation for tensor operations within IREE Dispatch functions +// +//===----------------------------------------------------------------------===// +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H +#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H + +#include "compiler/Translation/SPIRV/AffineExprCodegen.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/Support/StringExtras.h" + +namespace mlir { +namespace iree_compiler { + +class ValueCache { + public: + Value *getOperandDstValue(Value *value, AffineMap index) { + return convertedValueMap.lookup(value).lookup(index); + } + + void setOperandDstValue(Value *value, AffineMap index, Value *scalar) { + convertedValueMap[value][index] = scalar; + } + + private: + DenseMap<Value *, DenseMap<AffineMap, Value *>> convertedValueMap; +}; + +/// Base class for lowering tensor operations in the dispatch function to SPIR-V +/// op. +class SPIRVLowering { + public: + virtual ~SPIRVLowering() = default; + virtual StringRef getOpName() = 0; + /// This method (in the derived class) should generate the scalar operation + /// corresponding the the tensor operation `op` to generate the value of the + /// result tensor at a particular `index`. The scalar value of the operands + /// needed to compute this value is passed in within `operands`. The methods + /// have to insert the scalar result value of the generated operation into the + /// `valueCache`. + virtual LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, + ArrayRef<Value *> operands, + ValueCache &valueCache) const { + return failure(); + } + + /// This method (in the derived class) should generate the scalar operations + /// corresponding to the tensor operation `op`. This should be implemented + /// when the `op` has no result value, typically store operations and return + /// operations. + virtual LogicalResult lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { + return failure(); + } +}; + +/// Base class that gets the opName for the operation. +template <typename OpTy> +class SPIRVOpLowering : public SPIRVLowering { + public: + using SPIRVLowering::SPIRVLowering; + virtual ~SPIRVOpLowering<OpTy>() {} + StringRef getOpName() override { return OpTy::getOperationName(); } +}; + +/// SPIR-V lowering for ConstantOp. +class ConstantOpSPIRVLowering final : public SPIRVOpLowering<ConstantOp> { + public: + using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; + LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, ArrayRef<Value *> operands, + ValueCache &valueCache) const override; +}; + +/// SPIR-V lowering for CmpFOp. +class CmpFOpSPIRVLowering final : public SPIRVOpLowering<CmpFOp> { + public: + using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering; + + LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, ArrayRef<Value *> operands, + ValueCache &valueCache) const override; +}; + +/// SPIR-V lowering for Min/Max operations. +template <typename OpTy, typename CmpOpTy, typename CmpFOpTy> +class CmpSelectOpSPIRVLowering final : public SPIRVOpLowering<OpTy> { + public: + using SPIRVOpLowering<OpTy>::SPIRVOpLowering; + LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, ArrayRef<Value *> operands, + ValueCache &valueCache) const override { + if (op->getNumOperands() != 2) { + return op->emitError( + "unhandled SPIR-V lowering for more than 2 operands"); + } + assert(operands.size() == op->getNumOperands() && + "expected as many operands for the replacement as the original " + "instruction"); + auto cmpSelectOp = cast<OpTy>(op); + auto result = cmpSelectOp.getResult(); + auto resultTy = result->getType().template dyn_cast<ShapedType>(); + if (!resultTy) { + return op->emitError( + "unhandled lowering of operations that don't return a " + "ShapedType"); + } + auto elementTy = resultTy.getElementType(); + auto boolTy = builder.getI1Type(); + Operation *cmpOp = nullptr; + if (elementTy.template isa<FloatType>()) { + cmpOp = builder.create<CmpFOpTy>(op->getLoc(), boolTy, operands, + ArrayRef<NamedAttribute>()); + } else { + cmpOp = builder.create<CmpOpTy>(op->getLoc(), boolTy, operands, + ArrayRef<NamedAttribute>()); + } + auto selectOp = builder.create<spirv::SelectOp>( + op->getLoc(), operands[0]->getType(), cmpOp->getResult(0), operands[0], + operands[1]); + valueCache.setOperandDstValue(op->getResult(0), index, + selectOp.getResult()); + return success(); + } +}; + +/// This class is the general template used to emit scalar instruction +/// corresponding for point-wise operations. Assumes that the original +/// instruction has a single result value of type ShapedType. +/// TODO(ravishankarm) : In XLA-HLO, the same operations is used for +/// integer/float tensor operations. So allow this op to take an additional op +/// type as a template parameter to handle such cases. Find a better way to do +/// this. +template <typename OpTy, typename ReplacementOpTy, + typename FloatOpTy = ReplacementOpTy> +class SPIRVPwOpLowering final : public SPIRVOpLowering<OpTy> { + public: + using SPIRVOpLowering<OpTy>::SPIRVOpLowering; + + LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, + ArrayRef<Value *> scalarOperands, + ValueCache &valueCache) const override { + // TODO(ravishankarm) : This check should really be a static_assert. See if + // that can be changed. + if (op->getNumOperands() == 0) { + return op->emitError("expected op to have at least one operand"); + } + auto pwOp = cast<OpTy>(op); + auto result = pwOp.getResult(); + auto resultType = result->getType().template dyn_cast<ShapedType>(); + if (!resultType) { + return op->emitError( + "unhandled lowering of operations that don't return a " + "ShapedType"); + } + auto elementType = resultType.getElementType(); + Operation *scalarOp = nullptr; + if (elementType.template isa<IntegerType>()) { + scalarOp = builder + .create<ReplacementOpTy>(op->getLoc(), elementType, + scalarOperands, + ArrayRef<NamedAttribute>()) + .getOperation(); + } else { + scalarOp = + builder + .create<FloatOpTy>(op->getLoc(), elementType, scalarOperands, + ArrayRef<NamedAttribute>()) + .getOperation(); + } + if (!scalarOp) { + return op->emitError("unable to lower operation"); + } + valueCache.setOperandDstValue(pwOp.getResult(), index, + scalarOp->getResult(0)); + return success(); + } +}; + +/// This class is the general template used to emit scalar instruction for index +/// transformation instructions like transpose. Assumes a single result value +/// and a single operand +template <typename OpTy> +class SPIRVIndexOpLowering final : public SPIRVOpLowering<OpTy> { + public: + using SPIRVOpLowering<OpTy>::SPIRVOpLowering; + + LogicalResult lowerOperation(Operation *op, OpBuilder &builder, + AffineMap index, + ArrayRef<Value *> scalarOperands, + ValueCache &valueCache) const override { + if (op->getNumOperands() != 1) { + return op->emitError( + "unhandled lowering of index transformation operation with multiple " + "operands"); + } + auto indexOp = cast<OpTy>(op); + valueCache.setOperandDstValue(indexOp.getResult(), index, + scalarOperands[0]); + return success(); + } +}; + +/// Ggenerates spv.AccessChain instruction to get the pointer value at a given +/// location of a spv.globalVariable. +inline Value *genPointerOffset(OpBuilder &builder, Location loc, + AffineExprCodegen &affineExprCodegen, + AffineMap indexMap, + spirv::GlobalVariableOp &var) { + auto basePtr = builder.create<spirv::AddressOfOp>( + loc, var.type(), builder.getSymbolRefAttr(var.sym_name())); + auto varPtrType = var.type().cast<spirv::PointerType>().getPointeeType(); + // The variable has to be a struct type with a single element. + assert(varPtrType.isa<spirv::StructType>() && + "expected variable type to be a spv.ptr<spv.struct<...>>"); + auto varStructType = varPtrType.cast<spirv::StructType>(); + assert(varStructType.getNumElements() == 1 && + "expected variable type to be a spv.ptr of spv.struct with a single " + "element"); + auto varType = varStructType.getElementType(0); + + SmallVector<Value *, 2> accessIndex; + /// For scalar values, the index-map computed with already map to the 0-th + /// element. For arrays, they map to the position accessed. So just for arrays + /// we need to add an extra 0 to index into the struct. + if (varType.isa<spirv::ArrayType>() || + varType.isa<spirv::RuntimeArrayType>()) { + auto i32Type = builder.getIntegerType(32); + auto zero = builder.create<spirv::ConstantOp>(loc, i32Type, + builder.getI32IntegerAttr(0)); + accessIndex.push_back(zero); + } + for (auto indexExpr : indexMap.getResults()) { + accessIndex.push_back(affineExprCodegen.getValue( + indexExpr, builder.saveInsertionPoint(), loc)); + } + return builder.create<spirv::AccessChainOp>(loc, basePtr, accessIndex); +} + +/// Lower return statements during SPIR-V codegeneration. +class ReturnOpSPIRVLowering : public SPIRVOpLowering<ReturnOp> { + public: + using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering; + + LogicalResult lowerOperation( + Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, + DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, + ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override; +}; + +/// Class to drive the SPIRV code-generation. +template <typename... Ts> +class SPIRVCodegen { + using OpCodegenListT = llvm::StringMap<std::unique_ptr<SPIRVLowering>>; + + public: + explicit SPIRVCodegen() { insert(); } + + LogicalResult codegen(spirv::ModuleOp &spirvModule, FuncOp &fn, + AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache) { + if (fn.getBlocks().size() != 1) { + return emitError( + fn.getLoc(), + "unimplemeneted handling multiple blocks within a function"); + } + + OpBuilder builder(spirvModule.body()); + // Create the entry function and generate global invocation ID. Creates a + // global variable for all inputs and output tensors. + return createEntryFn(builder, fn, affineExprCodegen, valueCache); + } + + private: + /// Helper method to create the entry function. Creates global variables for + /// all inputs and outputs. Inserts the spv.EntryPoint operations as well. + LogicalResult createEntryFn(OpBuilder &builder, FuncOp &fn, + AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache) { + auto loc = fn.getLoc(); + // TODO(ravishankarm) : This should actually be part of the SPIR-V + // conversion framework in MLIR core. Move it there. + auto convertType = [&loc](Type t, + spirv::PointerType &varType) -> LogicalResult { + auto shapedType = t.dyn_cast<ShapedType>(); + if (!shapedType) { + return emitError(loc, "expected ShapedType argument"); + } + auto elementType = shapedType.getElementType(); + if (!elementType.isIntOrFloat()) { + return emitError(loc, "unhandled element type ") + << elementType << " while lowering to SPIR-V"; + } + int64_t stride = elementType.getIntOrFloatBitWidth() / 8; + for (auto dim : reverse(shapedType.getShape())) { + if (dim <= 0) { + return emitError(loc, "expected tensor dimensions to be non-zero"); + } + elementType = spirv::ArrayType::get( + elementType, dim, + static_cast<spirv::ArrayType::LayoutInfo>(stride)); + stride *= dim; + } + // TODO(ravishankarm): Verify that the type of the variable passes + // spirv-val. + varType = spirv::PointerType::get( + spirv::StructType::get(elementType, + static_cast<spirv::StructType::LayoutInfo>(0)), + spirv::StorageClass::StorageBuffer); + return success(); + }; + + // Convert functions arguments and return values to + // spirv::GlobalVariables. All global variables are given a descriptor set + // of 0 and binding is the argument number. + auto fnType = fn.getType(); + auto descriptorSetAttrName = convertToSnakeCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingAttrName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + for (auto argType : enumerate(fnType.getInputs())) { + spirv::PointerType varType; + if (failed(convertType(argType.value(), varType))) { + return failure(); + } + auto varName = + fn.getName().str() + "_arg_" + std::to_string(argType.index()); + auto var = builder.create<spirv::GlobalVariableOp>( + loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr); + // Set descriptor_set to 0. + var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0)); + // Set binding to argument number. + var.setAttr(bindingAttrName, builder.getI32IntegerAttr(argType.index())); + + inputArgToVariable[fn.getArgument(argType.index())] = var; + } + for (auto resType : enumerate(fnType.getResults())) { + spirv::PointerType varType; + if (failed(convertType(resType.value(), varType))) { + return failure(); + } + auto varName = + fn.getName().str() + "_res_" + std::to_string(resType.index()); + auto var = builder.create<spirv::GlobalVariableOp>( + loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr); + // Set descriptor_set to 0. + var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0)); + // Set binding to (result number + num arguments) + var.setAttr( + bindingAttrName, + builder.getI32IntegerAttr(fnType.getNumInputs() + resType.index())); + + resultIndexToVariable.push_back(var); + } + + auto entryFnType = + builder.getFunctionType(ArrayRef<Type>(), ArrayRef<Type>()); + auto entryFn = builder.create<FuncOp>(loc, fn.getName(), entryFnType, + ArrayRef<NamedAttribute>()); + + // Start a scope to create an insertion guard to reset the builder once the + // function is lowered. + { + OpBuilder::InsertionGuard funcInsertGuard(builder); + builder.setInsertionPointToStart(entryFn.addEntryBlock()); + + // Create the Global invocation ID. + if (failed(createGlobalInvocationID(builder, fn.getLoc(), + affineExprCodegen))) { + return failure(); + } + + if (failed(lowerFunction(builder, fn, entryFn, affineExprCodegen, + valueCache))) { + return failure(); + } + } + + // Create the entry point instructions for the entry function. + if (failed(createEntryPoint(builder, loc, entryFn))) { + return failure(); + } + return success(); + } + + /// Creates the global variable for GlobalInvocationID, and gets the ID at x, + /// y and z dimensions. + LogicalResult createGlobalInvocationID(OpBuilder &builder, Location loc, + AffineExprCodegen &affineExprCodegen) { + auto moduleOp = builder.getInsertionBlock() + ->getParentOp() + ->getParentOfType<spirv::ModuleOp>(); + OpBuilder moduleBuilder(moduleOp.body()); + auto i32Type = builder.getIntegerType(32); + auto idType = VectorType::get(3, i32Type); + auto ptrIdType = + spirv::PointerType::get(idType, spirv::StorageClass::Input); + auto globalInvocationID = moduleBuilder.create<spirv::GlobalVariableOp>( + loc, TypeAttr::get(ptrIdType), + builder.getStringAttr("globalInvocationID"), nullptr); + globalInvocationID.setAttr( + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)), + builder.getStringAttr( + spirv::stringifyBuiltIn(spirv::BuiltIn::GlobalInvocationId))); + interface.push_back( + builder.getSymbolRefAttr(globalInvocationID.sym_name())); + + auto globalInvocationIDPtr = builder.create<spirv::AddressOfOp>( + loc, ptrIdType, + builder.getSymbolRefAttr(globalInvocationID.getOperation())); + auto id = builder.create<spirv::LoadOp>(loc, idType, globalInvocationIDPtr, + nullptr, nullptr); + auto id_x = builder.create<spirv::CompositeExtractOp>( + loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(0))); + auto id_y = builder.create<spirv::CompositeExtractOp>( + loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(1))); + auto id_z = builder.create<spirv::CompositeExtractOp>( + loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(2))); + affineExprCodegen.setDimDstValue(0, id_x); + affineExprCodegen.setDimDstValue(1, id_y); + affineExprCodegen.setDimDstValue(2, id_z); + return success(); + } + + /// Method to load the values of globalVariables corresponding to the + /// arguments of the dispatch function at all indices needed within the + /// dispatch function. + LogicalResult initArgValues(OpBuilder &builder, Location loc, + AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, Value *origArg) { + for (auto indexMap : affineExprCodegen.getIndices(origArg)) { + auto var = inputArgToVariable.lookup(origArg); + if (!var) { + return emitError( + loc, "undefined SPIR-V global variable for tensor argument"); + } + auto ptr = + genPointerOffset(builder, loc, affineExprCodegen, indexMap, var); + auto elementType = + ptr->getType().template cast<spirv::PointerType>().getPointeeType(); + auto val = builder.create<spirv::LoadOp>(loc, elementType, ptr, + /*memory_access =*/nullptr, + /*alignment = */ nullptr); + valueCache.setOperandDstValue(origArg, indexMap, val); + } + return success(); + } + + /// Adds the spv.EntryPointOp and records all the interface variables used in + /// the entryFn. + LogicalResult createEntryPoint(OpBuilder &builder, Location loc, + FuncOp entryFn) { + builder.create<spirv::EntryPointOp>( + loc, + builder.getI32IntegerAttr( + static_cast<int32_t>(spirv::ExecutionModel::GLCompute)), + builder.getSymbolRefAttr(entryFn), builder.getArrayAttr(interface)); + builder.create<spirv::ExecutionModeOp>( + loc, builder.getSymbolRefAttr(entryFn), + builder.getI32IntegerAttr( + static_cast<int32_t>(spirv::ExecutionMode::LocalSize)), + builder.getI32ArrayAttr({1, 1, 1})); + interface.clear(); + return success(); + } + + /// Lowers the body of the function in the original dialect to SPIR-V dialect. + LogicalResult lowerFunction(OpBuilder &builder, FuncOp fn, FuncOp entryFn, + AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache) { + for (auto arg : fn.getArguments()) { + // Load values of the argument at all indices needed for computation + // within the dispatch function. + if (failed(initArgValues(builder, fn.getLoc(), affineExprCodegen, + valueCache, arg))) { + return failure(); + } + } + + for (auto &block : fn) { + for (auto &op : block) { + // Lower individual operations. + if (failed( + lowerOperation(builder, affineExprCodegen, valueCache, &op))) { + return failure(); + } + } + } + return success(); + } + + /// Dispatches the lowering of tensor operation to SPIR-V scalar + /// operation. + LogicalResult lowerOperation(OpBuilder &builder, + AffineExprCodegen &affineExprCodegen, + ValueCache &valueCache, Operation *op) { + auto opName = op->getName().getStringRef(); + if (!opCodegenList.count(opName)) { + return op->emitError("unhandled codegen"); + } + if (op->getNumResults() > 1) { + return op->emitError("unhandled codegen for multiple result values"); + } + + // Zero return case. + if (!op->getNumResults()) { + return opCodegenList[opName]->lowerOperation( + op, builder, affineExprCodegen, valueCache, inputArgToVariable, + resultIndexToVariable); + } + + // Single return case. + auto resultTensor = op->getResult(0); + auto indices = affineExprCodegen.getIndices(resultTensor); + for (auto &index : indices) { + auto operandIndices = + affineExprCodegen.getOperandIndices(resultTensor, index); + SmallVector<Value *, 2> scalarOperands; + for (auto arg : llvm::enumerate(op->getOperands())) { + auto scalarArg = valueCache.getOperandDstValue( + arg.value(), operandIndices[arg.index()]); + if (!scalarArg) { + return op->emitError("argument ") + << arg.index() << " has no scalar value"; + } + scalarOperands.push_back(scalarArg); + } + if (failed(opCodegenList[opName]->lowerOperation( + op, builder, index, scalarOperands, valueCache))) { + return failure(); + } + } + return success(); + } + + void insert() { + std::vector<std::unique_ptr<SPIRVLowering>> objs; + using dummy = int[]; + (void)dummy{0, (objs.emplace_back(std::make_unique<Ts>()), 0)...}; + for (auto &elem : objs) { + StringRef opName = elem->getOpName(); + opCodegenList.try_emplace(opName, std::move(elem)); + } + } + + /// List of classes that implement the operation lowering from tensor + /// operations to SPIR-V. + OpCodegenListT opCodegenList; + + /// I/O interface for the entry function containing global variables that are + /// used by the entire function call tree. + SmallVector<Attribute, 4> interface; + + /// Mapping from argument of the dispatch function in tensor dialect to the + /// corresponding spv.globalVariable. + DenseMap<Value *, spirv::GlobalVariableOp> inputArgToVariable; + + /// List of spv.globalVariables created for tensors returned by the dispatch + /// function in tensor dialects. + SmallVector<spirv::GlobalVariableOp, 1> resultIndexToVariable; + + /// GlobalInvocationID variable. + spirv::GlobalVariableOp globalInvocationID; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
diff --git a/compiler/Translation/SPIRV/XLAIndexPropagation.cpp b/compiler/Translation/SPIRV/XLAIndexPropagation.cpp new file mode 100644 index 0000000..7f30579 --- /dev/null +++ b/compiler/Translation/SPIRV/XLAIndexPropagation.cpp
@@ -0,0 +1,126 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- XLAIndexPropagation.cpp ---------------------------------*- C++//-*-===// +// +// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all +// tensors needed to produce the value of the result tensors at a particlar +// index. +// +//===----------------------------------------------------------------------===// + +#include "compiler/Translation/SPIRV/XLAIndexPropagation.h" + +namespace mlir { +namespace iree_compiler { + +//===----------------------------------------------------------------------===// +// BroadcastInDimOp +//===----------------------------------------------------------------------===// + +LogicalResult XLABroadcastInDimOpIndexPropagation::propagateIndexMap( + Operation *operation, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &indexMap) const { + auto broadcastOp = cast<xla_hlo::BroadcastInDimOp>(operation); + auto broadcastDim = broadcastOp.broadcast_dimensions(); + + Builder builder(operation->getContext()); + if (!broadcastDim) { + // This is a scalar. So all indices map to the same element. + AffineMap scalarMap = + AffineMap::get(resultIndex.getNumDims(), resultIndex.getNumSymbols(), + builder.getAffineConstantExpr(0)); + indexMap.push_back(scalarMap); + return success(); + } + + // Handle non-scalar cases. + auto dimensions = broadcastDim->getValues<int64_t>(); + SmallVector<AffineExpr, 4> exprs; + for (auto resultExpr : enumerate(resultIndex.getResults())) { + if (llvm::any_of(dimensions, [&resultExpr](int64_t dim) { + return dim == resultExpr.index(); + })) { + exprs.push_back(resultExpr.value()); + } + } + auto operandMap = AffineMap::get(resultIndex.getNumDims(), + resultIndex.getNumSymbols(), exprs); + indexMap.push_back(operandMap); + return success(); +} + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +// For broadcast op, just drop the first N expressions of the resultIndex, where +// N is the number of elements in broadcast_sizes attribute. +LogicalResult XLABroadcastOpIndexPropagation::propagateIndexMap( + Operation *operation, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &indexMap) const { + auto broadcastOp = cast<xla_hlo::BroadcastOp>(operation); + auto broadcastDim = broadcastOp.broadcast_sizes(); + + SmallVector<AffineExpr, 4> exprs; + for (auto i : llvm::seq<size_t>( + broadcastDim.getType().getShape()[0], + operation->getResult(0)->getType().cast<ShapedType>().getRank())) { + exprs.push_back(resultIndex.getResult(i)); + } + + Builder builder(operation->getContext()); + if (exprs.empty()) { + // The result is a scalar. Just add a constant expr 0. + exprs.push_back(builder.getAffineConstantExpr(0)); + } + auto operandMap = AffineMap::get(resultIndex.getNumDims(), + resultIndex.getNumSymbols(), exprs); + indexMap.push_back(operandMap); + return success(); +} + +//===----------------------------------------------------------------------===// +// ReverseOp +//===----------------------------------------------------------------------===// + +LogicalResult XLAReverseOpIndexPropagation::propagateIndexMap( + Operation *op, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &indexMap) const { + auto reverseOp = cast<xla_hlo::ReverseOp>(op); + DenseSet<unsigned> dimensions; + for (auto index : reverseOp.dimensions()) { + dimensions.insert(index.getZExtValue()); + } + return propagateIndexMapImpl(op, dimensions, resultIndex, indexMap); +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +LogicalResult XLATransposeOpIndexPropagation::propagateIndexMap( + Operation *op, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &indexMap) const { + auto transposeOp = cast<xla_hlo::TransposeOp>(op); + // Compute the affine map that represents the permutation. + SmallVector<unsigned, 4> permutation; + for (auto index : transposeOp.permutation()) { + permutation.push_back(index.getZExtValue()); + } + return propagateIndexMapImpl(op, permutation, resultIndex, indexMap); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/SPIRV/XLAIndexPropagation.h b/compiler/Translation/SPIRV/XLAIndexPropagation.h new file mode 100644 index 0000000..e3baa21 --- /dev/null +++ b/compiler/Translation/SPIRV/XLAIndexPropagation.h
@@ -0,0 +1,112 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- XLAIndexPropagation.h -----------------------------------*- C++//-*-===// +// +// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all +// tensors needed to produce the value of the result tensors at a particlar +// index. +// +//===----------------------------------------------------------------------===// +#ifndef IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H +#define IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H + +#include "compiler/Translation/SPIRV/IndexComputation.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Function.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +class XLABroadcastInDimOpIndexPropagation final + : public IndexPropagationOp<xla_hlo::BroadcastInDimOp> { + public: + using IndexPropagationOp<xla_hlo::BroadcastInDimOp>::IndexPropagationOp; + + LogicalResult propagateIndexMap( + Operation *operation, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &operandIndices) const override; +}; + +// For broadcast op, just drop the first N expressions of the resultIndex, where +// N is the number of elements in broadcast_sizes attribute. +class XLABroadcastOpIndexPropagation final + : public IndexPropagationOp<xla_hlo::BroadcastOp> { + public: + using IndexPropagationOp<xla_hlo::BroadcastOp>::IndexPropagationOp; + + LogicalResult propagateIndexMap( + Operation *operation, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &operandIndices) const override; +}; + +/// For return ops, it is assumed that each thread is computing the value of one +/// element of the returned tensor. +template <typename OpTy> +class ReturnOpIndexPropagation : public IndexPropagationOp<OpTy> { + public: + using IndexPropagationOp<OpTy>::IndexPropagationOp; + + LogicalResult propagateIndexMap( + Operation *operation, IndexComputationCache &indexMap) const override { + if (operation->getNumOperands() != 1) { + return operation->emitError("unhandled multiple return values"); + } + auto returnValue = operation->getOperand(0); + auto returnType = returnValue->getType().cast<RankedTensorType>(); + auto returnRank = returnType.getRank(); + if (returnRank > 3) { + return operation->emitError("unhandled return tensor of dimension ") + << returnType.getShape().size(); + } + // Have as many symbols as the rank of the input tensor. These symbols map + // to GlobalInvocationID along the three dimensions. + Builder builder(operation->getContext()); + SmallVector<AffineExpr, 4> affineExprs; + for (size_t i = returnRank; i > 0; --i) { + affineExprs.push_back(builder.getAffineDimExpr(i - 1)); + } + indexMap[operation->getOperand(0)] + [AffineMap::get(returnRank, 0, affineExprs)]; + return success(); + } +}; + +/// Index propogation for XLA Reverse. +class XLAReverseOpIndexPropagation final + : public ReverseOpIndexPropagation<xla_hlo::ReverseOp> { + public: + using ReverseOpIndexPropagation< + xla_hlo::ReverseOp>::ReverseOpIndexPropagation; + LogicalResult propagateIndexMap( + Operation *op, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &indexMap) const override; +}; + +/// Index propogation for XLA Transpose. +class XLATransposeOpIndexPropagation final + : public TransposeOpIndexPropagation<xla_hlo::TransposeOp> { + public: + using TransposeOpIndexPropagation< + xla_hlo::TransposeOp>::TransposeOpIndexPropagation; + LogicalResult propagateIndexMap( + Operation *op, AffineMap resultIndex, + SmallVectorImpl<AffineMap> &indexMap) const override; +}; + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
diff --git a/compiler/Translation/SPIRV/test/BUILD b/compiler/Translation/SPIRV/test/BUILD new file mode 100644 index 0000000..c7b98c5 --- /dev/null +++ b/compiler/Translation/SPIRV/test/BUILD
@@ -0,0 +1,16 @@ +# Tests for common transforms. + +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_setup_lit_package( + data = [ + "///tools:iree-opt", + ], +) + +iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/SPIRV/test/exp_test.mlir b/compiler/Translation/SPIRV/test/exp_test.mlir similarity index 100% rename from iree/compiler/Translation/SPIRV/test/exp_test.mlir rename to compiler/Translation/SPIRV/test/exp_test.mlir
diff --git a/iree/compiler/Translation/SPIRV/test/simple_test.mlir b/compiler/Translation/SPIRV/test/simple_test.mlir similarity index 100% rename from iree/compiler/Translation/SPIRV/test/simple_test.mlir rename to compiler/Translation/SPIRV/test/simple_test.mlir
diff --git a/compiler/Translation/Sequencer/BUILD b/compiler/Translation/Sequencer/BUILD new file mode 100644 index 0000000..cea0b27 --- /dev/null +++ b/compiler/Translation/Sequencer/BUILD
@@ -0,0 +1,33 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Sequencer", + srcs = ["SequencerModuleTranslation.cpp"], + hdrs = ["SequencerModuleTranslation.h"], + deps = [ + "///base:status", + "///compiler/IR", + "///compiler/IR/Sequencer", + "///compiler/Serialization", + "///compiler/Transforms", + "///compiler/Transforms/Sequencer", + "///compiler/Utils", + "///hal:executable_format", + "///schemas", + "@com_github_google_flatbuffers//:flatbuffers", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:Support", + "@local_config_mlir//:Transforms", + "@local_config_mlir//:Translation", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_control_flow", + ], + alwayslink = 1, +)
diff --git a/iree/compiler/Translation/Sequencer/CMakeLists.txt b/compiler/Translation/Sequencer/CMakeLists.txt similarity index 100% rename from iree/compiler/Translation/Sequencer/CMakeLists.txt rename to compiler/Translation/Sequencer/CMakeLists.txt
diff --git a/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp b/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp new file mode 100644 index 0000000..a3d6834 --- /dev/null +++ b/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp
@@ -0,0 +1,505 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h" + +#include <cstdint> +#include <iostream> +#include <memory> +#include <unordered_map> +#include <vector> + +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/minireflect.h" +#include "base/status.h" +#include "compiler/IR/ConfigOps.h" +#include "compiler/IR/Sequencer/OpWriters.h" +#include "compiler/IR/StructureOps.h" +#include "compiler/IR/Types.h" +#include "compiler/Serialization/VMFunctionBuilder.h" +#include "compiler/Serialization/VMFunctionTableBuilder.h" +#include "compiler/Serialization/VMModuleBuilder.h" +#include "compiler/Transforms/Passes.h" +#include "compiler/Transforms/Sequencer/Passes.h" +#include "compiler/Utils/Macros.h" +#include "compiler/Utils/OpUtils.h" +#include "compiler/Utils/TranslationUtils.h" +#include "hal/executable_format.h" +#include "schemas/executable_def_generated.h" +#include "schemas/executable_table_def_generated.h" +#include "schemas/module_def_generated.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Module.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Translation.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Builds a pass pipeline that optimizes and legalizes the module to the form +// expected by partitioning. +void buildLegalizeInputPassPipeline(PassManager *passManager) { + // Convert to the subset of XLA HLO and Standard dialects supported as IREE + // input. In particular, move from XLA HLO to standard control flow. + passManager->addPass(xla_hlo::createLegalizeControlFlowPass()); + passManager->addPass(createLegalizeInputOpsPass()); + + // Standard passes that shake out a lot of garbage. + // Some may have been run prior to translation but this ensures we are always + // in a known state. + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createLoopFusionPass()); + passManager->addPass(createLoopInvariantCodeMotionPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createSimplifyAffineStructuresPass()); + passManager->addPass(createCSEPass()); + passManager->addPass(createCanonicalizerPass()); + + // Expand uses of tuples into independent args/results. + passManager->addPass(createConvertFromTupleCallingConventionPass()); + passManager->addPass(createCanonicalizerPass()); +} + +// Builds a pass pipeline that partitions the module into sequencer functions +// and executables ready to be translated. +void buildPartitioningPassPipeline(PassManager *passManager) { + // Find reduction ops and create iree.reduction_regions. We do this prior to + // performing dispatch region identification so that we can build as big of + // fused reduction regions as possible. The remaining ops will be put into + // dispatch regions. + passManager->addPass(createIdentifyReductionRegionsPass()); + passManager->addPass(createCSEPass()); + + // Create all of the dispatch regions, CSE their workloads, and fold. + passManager->addPass(createIdentifyDispatchRegionsPass()); + passManager->addPass(createCSEPass()); + passManager->addPass(createFoldCompatibleDispatchRegionsPass()); + + // Note that as we are rematerializing things here it's critical we do not run + // the canonicalizer/CSE between now and when we outline - otherwise it'll + // undo all of our work! + passManager->addPass(createRematerializeDispatchConstantsPass()); + + // Outline the dispatch regions into their own functions. This separates the + // sequencer functions performing dispatches from the dispatchees. + passManager->addPass(createOutlineDispatchRegionsPass()); + passManager->addPass(createOutlineReductionRegionsPass()); + + // Cleanup identity sequencer tensor-to-memref ops that clutter up the IR. + passManager->addPass(createCanonicalizerPass()); + + // Drop all functions that are no longer reachable. + // This is important as many of the functions remaining are probably + // dispatchable and unused now that we've outlined them executables. + passManager->addPass(createDropUnreachableModuleFunctionsPass()); + + // Drop all unused executables. + // Note that we need to have dropped unreachable functions first otherwise + // references could keep executables that are unreachable from exported + // functions alive. + passManager->addPass(createDropUnusedExecutablesPass()); +} + +// Builds a pass pipeline that converts sequencer functions to the iree_seq.hl +// dialect. +void buildSequencerConversionPassPipeline(PassManager *passManager) { + passManager->addPass(createConvertToMemRefCallingConventionPass()); + + // Convert ops that are supported by the sequencer directly to the sequencer + // dialect. The ops that remain should be only those that can be moved into + // dispatch regions. + passManager->addPass(createLowerToSequencerDialectPass()); + + // Cleanup identity sequencer tensor-to-memref ops and other memory accesses + // that clutter up the IR. + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + + // Eliminate ops we don't care about based on a lack of side-effects. + // IREE does not guarantee exception/error behavior of dead ops. + passManager->addPass(createAggressiveOpEliminationPass()); + + // Perform any last-minute optimizations to trim down the IR. + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + passManager->addPass(createCSEPass()); +} + +// Builds a pass pipeline that lowers the iree_seq.hl dialect to the iree_seq.ll +// dialect and prepares for serialization. +void buildSequencerLoweringPassPipeline(PassManager *passManager) { + // Lower iree_hl_seq -> iree_ll_seq. + passManager->addPass(createLowerSequencerDialectPass()); + passManager->addPass(createCanonicalizerPass()); + passManager->addPass(createMemRefDataFlowOptPass()); + passManager->addPass(createAggressiveOpEliminationPass()); + + // Assign ordinals used by the bytecode to reference executables and + // functions. + passManager->addPass(createAssignFunctionOrdinalsPass()); + passManager->addPass(createAssignExecutableOrdinalsPass()); + + // Plumb workload information down into executable entry points. This allows + // the backends to calculate their workgroup sizes, indexing, etc. + passManager->addPass(createAssignExecutableWorkloadAttrsPass()); +} + +// Inserts one or more iree.executable_target_config ops based on the +// translation options. +void insertTargetConfigOps(const ModuleTranslationOptions &options, + OpBuilder &builder) { + llvm::StringSet<> targetBackends; + if (options.target_backends.empty()) { + // Add all backends when none are explicitly provided. + targetBackends.insert(getExecutableTranslationRegistry().keys().begin(), + getExecutableTranslationRegistry().keys().end()); + } else { + for (auto &targetBackend : options.target_backends) { + for (auto &matchedBackend : + matchExecutableTranslationBackendNames(targetBackend)) { + targetBackends.insert(matchedBackend); + } + } + } + for (auto &targetBackend : targetBackends) { + builder.create<IREE::ExecutableTargetConfigOp>(builder.getUnknownLoc(), + targetBackend.getKey()); + } +} + +class SequencerTranslator { + public: + explicit SequencerTranslator(ModuleTranslationOptions options) + : options_(options) {} + + const ModuleTranslationOptions &options() const { return options_; } + + std::vector<uint8_t> translateModule(ModuleOp module); + + private: + LogicalResult translateMultiArchExecutable( + IREE::MultiArchExecutableOp executableOp, VMModuleBuilder *moduleBuilder); + + LogicalResult translateSequencerModule(ModuleOp module, + VMModuleBuilder *moduleBuilder); + LogicalResult declareFunction(FuncOp function, + VMModuleBuilder *moduleBuilder); + LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder); + + ModuleTranslationOptions options_; +}; + +std::vector<uint8_t> SequencerTranslator::translateModule(ModuleOp module) { + // Run one large set of passes to get to a partitioned module. + auto partitioningPasses = createPassManager(module.getContext(), options()); + buildLegalizeInputPassPipeline(partitioningPasses.get()); + buildPartitioningPassPipeline(partitioningPasses.get()); + if (failed(runPassPipeline(options(), partitioningPasses.get(), module))) { + module.emitError() << "Failed to run partitioning passes"; + return {}; + } + + // Run the sequencer-specific conversion passes on the module. + auto sequencerConversionPasses = + createPassManager(module.getContext(), options()); + buildSequencerConversionPassPipeline(sequencerConversionPasses.get()); + if (failed(runPassPipeline(options(), sequencerConversionPasses.get(), + module))) { + module.emitError() << "Failed to run sequencer conversion passes"; + return {}; + } + + // Lower sequencer functions to their final form. + auto sequencerLoweringPasses = + createPassManager(module.getContext(), options()); + buildSequencerLoweringPassPipeline(sequencerLoweringPasses.get()); + if (failed( + runPassPipeline(options(), sequencerLoweringPasses.get(), module))) { + module.emitError() << "Failed to run sequencer lowering passes"; + return {}; + } + + // Perform translation on all executables. + // We then know exactly what executable formats we have and can query them to + // see if we need to do any additional processing (such as to support better + // types/etc). + ::flatbuffers::FlatBufferBuilder fbb; + VMModuleBuilder moduleBuilder(&fbb); + for (auto multiArchExecutableOp : + module.getOps<IREE::MultiArchExecutableOp>()) { + if (failed(translateMultiArchExecutable(multiArchExecutableOp, + &moduleBuilder))) { + module.emitError() << "Failed to translate multi-arch-executable"; + return {}; + } + } + + // Build the module bytecode. + if (failed(translateSequencerModule(module, &moduleBuilder))) { + module.emitError() << "Unable to translate sequencer module"; + return {}; + } + auto moduleDef = moduleBuilder.Finish(); + if (moduleDef.IsNull()) { + module.emitError() << "Failed to verify completed module def"; + return {}; + } + auto bytes = moduleBuilder.Serialize(moduleDef); + if (bytes.empty()) { + module.emitError() << "Failed to serialize final module def"; + return {}; + } + return bytes; +} + +LogicalResult SequencerTranslator::translateMultiArchExecutable( + IREE::MultiArchExecutableOp multiArchExecutableOp, + VMModuleBuilder *moduleBuilder) { + auto &fbb = *moduleBuilder->fbb(); + + // Find the unspecified executable. This is the template from which we will + // translate to other targets. + IREE::ExecutableOp templateExecutableOp; + for (auto executableOp : + multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) { + if (executableOp.format() == + static_cast<uint32_t>(IREE::ExecutableFormat::Unspecified)) { + templateExecutableOp = executableOp; + break; + } + } + if (!templateExecutableOp) { + // Fine for there to be no unspecified executable - just ignore. + return success(); + } + int entryPointCount = 0; + for (auto func : templateExecutableOp.getInnerModule().getOps<FuncOp>()) { + if (func.getAttr("iree.executable.export")) { + ++entryPointCount; + } + } + + // For now we just add target config ops based on options. In the future we + // could do this earlier via an analysis pass determining which targets should + // be used for each executable. + OpBuilder configBuilder(templateExecutableOp); + configBuilder.setInsertionPointToStart(&templateExecutableOp.getBlock()); + insertTargetConfigOps(options(), configBuilder); + + // Find all target configs and bucket them into the backends that will + // translate them. This way we can batch the translations and possibly enable + // backends to dedupe some things. + DenseMap<StringRef, std::vector<IREE::ExecutableTargetConfigOp>> + backendTargetConfigOps; + for (auto targetConfigOp : templateExecutableOp.getBlock() + .getOps<IREE::ExecutableTargetConfigOp>()) { + auto &targetConfigOps = backendTargetConfigOps[targetConfigOp.backend()]; + targetConfigOps.push_back(targetConfigOp); + } + if (backendTargetConfigOps.empty()) { + // There are no target configs - which likely means we've already translated + // this in a previous pass. + return success(); + } + + ExecutableTranslationOptions translationOptions; + translationOptions.CopyFrom(options()); + + // Invoke each backend translator on the template executables to produce new + // executables. The backends may produce any number of executables that we + // then merge back in to the iree.multi_arch_executable and the module + // flatbuffer. + std::vector<std::unique_ptr<iree::ExecutableDefT>> translatedExecutableDefs; + for (auto it : backendTargetConfigOps) { + const auto &backendKey = it.first; + const auto &targetConfigOps = it.second; + + // Find the translator to use in the registry. It must have been linked in + // and the name must match what is used in the registration macro. + auto translateExecutableFn = + getExecutableTranslationRegistry().lookup(backendKey); + if (!translateExecutableFn) { + return multiArchExecutableOp.emitError() + << "No registered backend found for target '" << backendKey.str() + << "'; ensure it is linked in to your binary (have: " + << llvm::join(getExecutableTranslationRegistry().keys(), ", ") + << ")"; + } + + // Clone the executable for each config so that the translator is allowed to + // modify it in-place. + // We also need to strip all of the other configs so that the translator + // backend only sees the one for each of its configs. + OpBuilder builder(&multiArchExecutableOp.getBlock()); + builder.setInsertionPoint(multiArchExecutableOp.getBlock().getTerminator()); + SmallVector<IREE::ExecutableOp, 4> clonedExecutableOps; + for (auto targetConfigOp : targetConfigOps) { + auto executableCloneOp = cast<IREE::ExecutableOp>( + builder.clone(*templateExecutableOp.getOperation())); + for (auto existingTargetConfigOp : llvm::make_early_inc_range( + executableCloneOp.getBlock() + .getOps<IREE::ExecutableTargetConfigOp>())) { + existingTargetConfigOp.erase(); + } + OpBuilder configBuilder(executableCloneOp); + configBuilder.setInsertionPointToStart(&executableCloneOp.getBlock()); + configBuilder.clone(*targetConfigOp.getOperation()); + clonedExecutableOps.push_back(executableCloneOp); + } + + // Perform translation on all of the backend-specific targets. + // Note that the results here may not have the same number of executables we + // started with if the backend either couldn't satisfy some of the requests + // or decided to dedupe or expand certain ones. + auto translationResults = + translateExecutableFn(clonedExecutableOps, translationOptions); + if (!translationResults.hasValue()) { + return multiArchExecutableOp.emitError() + << "Failed to translate executable with backend " << backendKey; + } + for (auto &executableDef : translationResults.getValue().executable_defs) { + translatedExecutableDefs.push_back(std::move(executableDef)); + } + } + + // Remove configs from the template executable so that if we are called again + // we don't re-translate. + for (auto targetConfigOp : llvm::make_early_inc_range( + templateExecutableOp.getBlock() + .getOps<IREE::ExecutableTargetConfigOp>())) { + targetConfigOp.erase(); + } + + // Create multi-arch executable with all of the target-specific executables. + iree::MultiArchExecutableDefT maedf; + maedf.name = multiArchExecutableOp.getName(); + maedf.entry_point_count = entryPointCount; + maedf.executables = std::move(translatedExecutableDefs); + auto maedfOffset = iree::MultiArchExecutableDef::Pack(fbb, &maedf); + RETURN_IF_FAILURE( + moduleBuilder->executable_table()->AddMultiArchExecutable(maedfOffset)); + + return success(); +} + +LogicalResult SequencerTranslator::translateSequencerModule( + ModuleOp module, VMModuleBuilder *moduleBuilder) { + // Declare functions. This must happen first so that we get stable indices + // during declaration (as call ops need to use the function table). + for (auto function : module.getOps<FuncOp>()) { + RETURN_IF_FAILURE(declareFunction(function, moduleBuilder)); + } + + // Define functions and convert their bodies to bytecode. + for (auto function : module.getOps<FuncOp>()) { + RETURN_IF_FAILURE(defineFunction(function, moduleBuilder)); + } + + return success(); +} + +LogicalResult SequencerTranslator::declareFunction( + FuncOp function, VMModuleBuilder *moduleBuilder) { + auto *functionTable = moduleBuilder->function_table(); + if (functionTable->IsFunctionDeclared(function)) { + // Already declared. + return success(); + } + + LinkageType linkageType; + if (function.isExternal()) { + linkageType = LinkageType::kImport; + } else if (function.getAttr("iree.module.export")) { + linkageType = LinkageType::kExport; + } else { + linkageType = LinkageType::kInternal; + } + if (failed(functionTable->DeclareFunction(function, linkageType))) { + return function.emitError() + << "Unable to declare function " << function.getName(); + } + + // Import functions must have their definition defined here so we get their + // type. Internal and export functions will be defined during conversion. + if (linkageType == LinkageType::kImport) { + VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), + moduleBuilder->fbb()); + auto functionOffset = functionBuilder.Finish(); + if (functionOffset.IsNull()) { + return function.emitError() + << "Failed to create import function bytecode"; + } + RETURN_IF_FAILURE( + functionTable->DefineFunction(function, functionOffset, {})); + } + + return success(); +} + +LogicalResult SequencerTranslator::defineFunction( + FuncOp function, VMModuleBuilder *moduleBuilder) { + VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), + moduleBuilder->fbb()); + registerSequencerCustomWriters(&functionBuilder); + RETURN_IF_FAILURE(functionBuilder.ConvertBytecode()); + auto functionOffset = functionBuilder.Finish(); + if (functionOffset.IsNull()) { + return function.emitError() << "Failed to convert function to bytecode"; + } + RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction( + function, functionOffset, functionBuilder.source_map())); + return success(); +} + +} // namespace + +std::vector<uint8_t> translateMlirToIreeSequencerModule( + ModuleOp module, ModuleTranslationOptions options) { + SequencerTranslator translator(options); + return translator.translateModule(module); +} + +LogicalResult translateMlirToIreeSequencerModuleFile( + ModuleOp module, llvm::raw_ostream &output) { + ModuleTranslationOptions options; + SequencerTranslator translator(options); + auto bytecodeModule = translator.translateModule(module); + if (bytecodeModule.empty()) { + return emitError(UnknownLoc::get(module.getContext()), + "failed to translate module"); + } + + output.write(reinterpret_cast<const char *>(bytecodeModule.data()), + bytecodeModule.size()); + return success(); +} + +static TranslateFromMLIRRegistration MlirToIreeSequencerModuleTranslate( + "mlir-to-iree-module", translateMlirToIreeSequencerModuleFile); + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Translation/Sequencer/SequencerModuleTranslation.h b/compiler/Translation/Sequencer/SequencerModuleTranslation.h new file mode 100644 index 0000000..7bd6c37 --- /dev/null +++ b/compiler/Translation/Sequencer/SequencerModuleTranslation.h
@@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_ +#define IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_ + +#include <vector> + +#include "compiler/Utils/TranslationUtils.h" +#include "mlir/IR/Module.h" + +namespace mlir { +namespace iree_compiler { + +// Translates an MLIR module in a compatible IREE input dialect (such as XLA HLO +// and/or Std) into an IREE Module. Executables will be lowered based on the +// provided configuration. +// Returns an empty vector on translation failure. +std::vector<uint8_t> translateMlirToIreeSequencerModule( + ModuleOp module, ModuleTranslationOptions options = {}); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
diff --git a/compiler/Utils/BUILD b/compiler/Utils/BUILD new file mode 100644 index 0000000..bb09cb6 --- /dev/null +++ b/compiler/Utils/BUILD
@@ -0,0 +1,41 @@ +# Utilities for working with IREE MLIR types. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "Utils", + srcs = [ + "DispatchUtils.cpp", + "MemRefUtils.cpp", + "ModuleUtils.cpp", + "OpCreationUtils.cpp", + "OpUtils.cpp", + "TranslationUtils.cpp", + "TypeConversionUtils.cpp", + ], + hdrs = [ + "DispatchUtils.h", + "Macros.h", + "MemRefUtils.h", + "ModuleUtils.h", + "OpCreationUtils.h", + "OpUtils.h", + "TranslationUtils.h", + "TypeConversionUtils.h", + ], + deps = [ + "///compiler/IR", + "///schemas", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + ], +)
diff --git a/iree/compiler/Utils/CMakeLists.txt b/compiler/Utils/CMakeLists.txt similarity index 100% rename from iree/compiler/Utils/CMakeLists.txt rename to compiler/Utils/CMakeLists.txt
diff --git a/compiler/Utils/DispatchUtils.cpp b/compiler/Utils/DispatchUtils.cpp new file mode 100644 index 0000000..461d12a --- /dev/null +++ b/compiler/Utils/DispatchUtils.cpp
@@ -0,0 +1,726 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/DispatchUtils.h" + +#include "compiler/IR/Ops.h" +#include "compiler/Utils/TypeConversionUtils.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Utils.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +namespace mlir { +namespace iree_compiler { + +Value *calculateWorkload(Operation *op, Value *baseOperand) { + OpBuilder builder(op); + + std::array<int32_t, 3> workload = {1, 1, 1}; + + // TODO(b/139353314): lookup/calculate based on type/etc. + auto resultType = baseOperand->getType(); + if (auto shapedType = resultType.dyn_cast<ShapedType>()) { + if (!shapedType.hasStaticShape()) { + op->emitOpError() << "Dynamic shapes not yet supported"; + return nullptr; + } + auto shape = shapedType.getShape(); + // Drop the trailing ones from the shape. + while (shape.size() > 1 && shape.back() == 1) { + shape = shape.drop_back(); + } + if (shape.size() <= 3) { + // Maps to XYZ (possibly with 1's for unused dimensions). + for (auto dim : enumerate(shape)) { + workload[shape.size() - 1 - dim.index()] = dim.value(); + } + } else { + // Need to flatten the shape to fit XYZ. For now we just squash from LHS. + workload[2] = 1; + for (int i = 0; i < shape.size(); ++i) { + workload[2] *= shape[i]; + } + workload[1] = shape[shape.size() - 2]; + workload[0] = shape.back(); + } + } + + // TODO(b/139353314): optimize workload layout. + + auto constantType = RankedTensorType::get({3}, builder.getIntegerType(32)); + return builder.create<ConstantOp>( + op->getLoc(), constantType, + DenseIntElementsAttr::get<int32_t>(constantType, workload)); +} + +bool isTriviallyDispatchable(FuncOp func) { + if (func.empty()) return false; + auto &block = func.front(); + if (block.getOperations().size() != 2) return false; + auto &op0 = block.front(); + auto &op1 = block.back(); + auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op0); + auto returnOp = dyn_cast<ReturnOp>(op1); + if (!regionOp || !returnOp || + regionOp.getNumResults() != returnOp.getNumOperands()) { + return false; + } + for (int i = 0; i < regionOp.getNumResults(); ++i) { + if (regionOp.getResult(i) != returnOp.getOperand(i)) return false; + } + return true; +} + +namespace { + +// Returns the set of values that must be captured for use by |ops| and the +// set of values defined by |ops| that are used outside of the set. +LogicalResult analyzeOpRangeValues( + const llvm::SmallDenseSet<Operation *> &opSet, + llvm::SetVector<Value *> *capturedValues, + llvm::SetVector<Value *> *escapingValues) { + for (auto *op : opSet) { + for (auto *value : op->getOperands()) { + if (!llvm::is_contained(opSet, value->getDefiningOp())) { + // Op is using a value not in the ops set, ensure we capture it. + capturedValues->insert(value); + } + } + for (auto *value : op->getResults()) { + for (auto &use : value->getUses()) { + if (!llvm::is_contained(opSet, use.getOwner())) { + // An op outside of the ops set is using the value, needs to escape. + escapingValues->insert(value); + } + } + } + } + return success(); +} + +} // namespace + +LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, + Value *workload, ArrayRef<Operation *> ops) { + // Fused location with all ops. + SmallVector<Location, 16> opLocs; + for (auto *op : ops) { + opLocs.push_back(op->getLoc()); + } + auto regionLoc = FusedLoc::get(opLocs, func.getContext()); + + // Get a list of values that we need to capture and values that escape the + // region and need to be returned. + llvm::SmallDenseSet<Operation *> opSet; + opSet.reserve(ops.size()); + opSet.insert(ops.begin(), ops.end()); + llvm::SetVector<Value *> capturedValues; + llvm::SetVector<Value *> escapingValues; + if (failed(analyzeOpRangeValues(opSet, &capturedValues, &escapingValues))) { + return failure(); + } + SmallVector<Type, 8> escapingTypes; + for (auto *value : escapingValues) escapingTypes.push_back(value->getType()); + + // Build the region op and add it to the parent block. + OpBuilder parentBuilder(parentBlock); + parentBuilder.setInsertionPoint(ops.back()); + auto dispatchRegionOp = parentBuilder.create<IREE::DispatchRegionOp>( + regionLoc, escapingTypes, workload, capturedValues.getArrayRef()); + + // Create the block and setup the arg mapping for captured values. + auto *regionBlock = new Block(); + dispatchRegionOp.getBody().push_back(regionBlock); + OpBuilder regionBuilder(regionBlock); + BlockAndValueMapping mapping; + for (auto *capturedValue : capturedValues) { + auto *blockArg = regionBlock->addArgument(capturedValue->getType()); + mapping.map(capturedValue, blockArg); + } + + // Clone ops into the new region block. + for (auto *op : ops) { + // Note that this updates the mapping with the new values (so at the end + // we have those new values). + regionBuilder.clone(*op, mapping); + } + + // Return results (as we need a terminator in our block). + // These are all of the values that escape our region. + SmallVector<Value *, 8> resultValues; + for (auto *oldValue : escapingValues) { + resultValues.push_back(mapping.lookupOrDefault(oldValue)); + } + regionBuilder.create<IREE::ReturnOp>(opLocs.back(), resultValues); + + // Replace usage of values with the results of the region. + for (int i = 0; i < escapingValues.size(); ++i) { + escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i)); + } + + // Remove original ops from the parent region. + for (auto it = ops.rbegin(); it != ops.rend(); ++it) { + (*it)->erase(); + } + + return success(); +} + +namespace { + +// Replaces |returnOp| with a clone including |newOperands| appended. +LogicalResult appendReturnOperands(IREE::ReturnOp returnOp, + ArrayRef<Value *> newOperands) { + // Insert prior to the original return. + OpBuilder builder(returnOp); + + // Clone with new args. + SmallVector<Value *, 8> operands; + operands.reserve(returnOp.getNumOperands() + newOperands.size()); + operands.append(returnOp.operand_begin(), returnOp.operand_end()); + operands.append(newOperands.begin(), newOperands.end()); + builder.create<IREE::ReturnOp>(returnOp.getLoc(), operands); + + // Remove original. + returnOp.erase(); + + return success(); +} + +// Replaces |regionOp| with a clone including |newArgs| and |newResults|. +IREE::DispatchRegionOp appendRegionArgsAndResults( + IREE::DispatchRegionOp ®ionOp, ArrayRef<Value *> newArgs, + ArrayRef<Value *> newResults, Location otherLoc) { + // Insert prior to the original region. + OpBuilder builder(regionOp); + + // Location is original region + new region location (both probably fused). + SmallVector<Location, 2> fusedLocs = {regionOp.getLoc(), otherLoc}; + auto fusedLoc = FusedLoc::get(fusedLocs, regionOp.getContext()); + + // Clone with new results. + SmallVector<Value *, 8> operands; + operands.append(regionOp.getArgOperands().begin(), + regionOp.getArgOperands().end()); + operands.append(newArgs.begin(), newArgs.end()); + SmallVector<Type, 8> resultTypes; + resultTypes.append(regionOp.result_type_begin(), regionOp.result_type_end()); + for (auto *newResult : newResults) { + resultTypes.push_back(newResult->getType()); + } + auto newRegionOp = builder.create<IREE::DispatchRegionOp>( + fusedLoc, resultTypes, regionOp.getWorkload(), operands, + regionOp.getAttrs()); + newRegionOp.getBody().takeBody(regionOp.getBody()); + + // Replace uses of original values with the new values. + for (int i = 0; i < regionOp.getNumResults(); ++i) { + regionOp.getResult(i)->replaceAllUsesWith(newRegionOp.getResult(i)); + } + + // Erase the original region. + regionOp.erase(); + + return newRegionOp; +} + +// Removes results that are not used from the dispatch region. +// Returns the new operation. There may be unused ops in the region but DCE +// should take care of that later. +IREE::DispatchRegionOp removeUnusedResults(IREE::DispatchRegionOp regionOp) { + // Find return value within the region. + auto ®ionBlock = regionOp.getBody().getBlocks().front(); + auto returnOp = dyn_cast<IREE::ReturnOp>(regionBlock.getTerminator()); + if (!returnOp) { + regionBlock.getParent()->getParentOfType<FuncOp>().emitError() + << "Block does not contain an iree.return op"; + } + + // Calculate new return values. + SmallVector<Type, 8> newReturnTypes; + SmallVector<Value *, 8> newReturnValues; + SmallVector<Value *, 8> newRegionResults; + for (int i = 0; i < returnOp.getNumOperands(); ++i) { + auto *resultValue = regionOp.getResult(i); + if (!resultValue->use_empty()) { + // Still has uses so we will preserve it. + newReturnTypes.push_back(resultValue->getType()); + newReturnValues.push_back(returnOp.getOperand(i)); + newRegionResults.push_back(resultValue); + } + } + + // Update return op operands. We can do this in-place as we are only shrinking + // the list. + returnOp.getOperation()->setOperands(newReturnValues); + + // Insert prior to the original region. + OpBuilder builder(regionOp); + + // Clone with new results. + SmallVector<Value *, 8> operands(regionOp.getArgOperands()); + auto newRegionOp = builder.create<IREE::DispatchRegionOp>( + regionOp.getLoc(), newReturnTypes, regionOp.getWorkload(), operands, + regionOp.getAttrs()); + newRegionOp.getBody().takeBody(regionOp.getBody()); + + // Replace uses of original values with the new values. + for (int i = 0; i < newRegionResults.size(); ++i) { + newRegionResults[i]->replaceAllUsesWith(newRegionOp.getResult(i)); + } + + // Erase the original region. + regionOp.erase(); + + return newRegionOp; +} + +// Returns true if |lhs| and |rhs| have either an identical workload or one that +// is compatible. +bool areDispatchRegionWorkloadsCompatible(IREE::DispatchRegionOp &lhs, + IREE::DispatchRegionOp &rhs) { + // TODO(benvanik): more sophisticated checking; right now it's just identical. + return lhs.getWorkload() == rhs.getWorkload(); +} + +// Returns true if |value| depends in any way on |op| through any path. +// Only works if the operations are within the same block. +bool doesValueDependOnOperation(Value *value, Operation *op) { + if (!value->getDefiningOp()) { + return false; + } else if (value->getDefiningOp() == op) { + return true; + } else if (value->getDefiningOp()->isBeforeInBlock(op)) { + // Can't depend on |op| as it is defined prior to it. + return false; + } + for (auto *operand : value->getDefiningOp()->getOperands()) { + if (doesValueDependOnOperation(operand, op)) { + return true; + } + } + return true; +} + +// Returns true if |rhs| transitively depends on any out of |lhs|. +// |rhs| may depend directly on the results of |lhs| but no other ops in the +// parent block will use the results prior to |rhs|. +bool areDispatchRegionsTransitivelyDependent(IREE::DispatchRegionOp &lhs, + IREE::DispatchRegionOp &rhs) { + for (auto *arg : rhs.getArgOperands()) { + if (arg->getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs)) { + // Transitively dependent - boo - can't merge yet. + return true; + } + } + return false; +} + +// Returns true if the dispatch region contains only a single block. +// This is because our merge isn't very smart and will not preserve the CFG +// right now. We can fix this when needed. +bool isDispatchRegionMergable(IREE::DispatchRegionOp ®ionOp) { + // Disallow merging of dispatch regions containing matmuls and other big ops. + // We do this to allow backends to lower the big op as entirely isolated such + // that substituting library calls is easier. + for (auto &block : regionOp.getBody().getBlocks()) { + for (auto &op : block) { + if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { + return false; + } + } + } + return regionOp.getBody().getBlocks().size() == 1; +} + +// Merges |rhs| into |lhs| and returns the new |lhs| op. +// Precondition: !areDispatchRegionsTransitivelyDependent +IREE::DispatchRegionOp mergeDispatchRegions(IREE::DispatchRegionOp &lhs, + IREE::DispatchRegionOp &rhs) { + auto &lhsBlock = lhs.getBody().front(); + auto &rhsBlock = rhs.getBody().front(); + + // Find the values used as return values in the lhs. + // We'll need to replace the uses in rhs with these. + auto lhsReturnOp = cast<IREE::ReturnOp>(lhsBlock.getTerminator()); + SmallVector<Value *, 8> lhsReturnValues; + lhsReturnValues.reserve(lhsReturnOp.getNumOperands()); + lhsReturnValues.append(lhsReturnOp.operand_begin(), + lhsReturnOp.operand_end()); + + // Find the values used as return values in the rhs. + // We'll add these to the results of the lhs region. + auto rhsReturnOp = cast<IREE::ReturnOp>(rhsBlock.getTerminator()); + SmallVector<Value *, 8> rhsReturnValues; + rhsReturnValues.reserve(rhsReturnOp.getNumOperands()); + rhsReturnValues.append(rhsReturnOp.operand_begin(), + rhsReturnOp.operand_end()); + + // Compute new args. + BlockAndValueMapping mapping; + SmallVector<Value *, 8> newArgs; + for (int rhsOpIdx = 0; rhsOpIdx < rhs.getNumArgOperands(); ++rhsOpIdx) { + bool didElide = false; + // Find if the rhs arg already exists on the lhs and dedupe. + for (int lhsOpIdx = 0; lhsOpIdx < lhs.getNumArgOperands(); ++lhsOpIdx) { + if (rhs.getArgOperand(rhsOpIdx) == lhs.getArgOperand(lhsOpIdx)) { + mapping.map(rhsBlock.getArgument(rhsOpIdx), + lhsBlock.getArgument(lhsOpIdx)); + didElide = true; + break; + } + } + // Find if the arg has a direct dependency on the results of the lhs. + for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults(); + ++lhsResultIdx) { + if (rhs.getArgOperand(rhsOpIdx) == lhs.getResult(lhsResultIdx)) { + // Direct dependency; can elide. We'll skip adding it to the new region + // args and instead just remap it later. + mapping.map(rhsBlock.getArgument(rhsOpIdx), + lhsReturnValues[lhsResultIdx]); + didElide = true; + break; + } + } + if (!didElide) { + // Add to the lhs block. + auto *oldArg = rhs.getOperand(rhsOpIdx + 1); + auto *newArg = lhsBlock.addArgument(oldArg->getType()); + mapping.map(rhsBlock.getArgument(rhsOpIdx), newArg); + newArgs.push_back(oldArg); + } + } + + OpBuilder regionBuilder(&lhsBlock); + + // Copy ops (replacing any args as needed). + // Note that we need to insert prior to the terminator. + regionBuilder.setInsertionPoint(lhsReturnOp); + for (auto &op : rhsBlock) { + // Note that this updates the mapping with the new values (so at the end + // we have those new values). + // + // We avoid the return op here as we have already merged it above. + if (!op.isKnownTerminator()) { + regionBuilder.clone(op, mapping); + } + } + + // Compute new results and add to both region and return op. + SmallVector<Value *, 8> newResults; + for (auto *rhsResult : rhsReturnValues) { + newResults.push_back(mapping.lookupOrDefault(rhsResult)); + } + if (failed(appendReturnOperands(lhsReturnOp, newResults))) { + return nullptr; + } + auto newRegionOp = + appendRegionArgsAndResults(lhs, newArgs, newResults, rhs.getLoc()); + + // Replace uses of original values with the new values. + for (int i = 0; i < rhs.getNumResults(); ++i) { + rhs.getResult(i)->replaceAllUsesWith( + newRegionOp.getResult(lhsReturnValues.size() + i)); + } + + // Remove rhs region. + rhs.erase(); + + // Remove results from the lhs that aren't used anymore as they may have been + // elided when we merged as only the rhs was using them. + newRegionOp = removeUnusedResults(newRegionOp); + + return newRegionOp; +} + +} // namespace + +LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock) { + SmallVector<IREE::DispatchRegionOp, 8> mergableRegions; + for (auto &op : *parentBlock) { + if (auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op)) { + if (isDispatchRegionMergable(regionOp)) { + mergableRegions.push_back(regionOp); + } else { + regionOp.emitRemark( + "Unable to merge into following iree.dispatch_regions; " + "contains non-trivial control flow"); + } + } + } + for (int i = 0; i < mergableRegions.size(); ++i) { + if (!mergableRegions[i]) continue; + auto &lhs = mergableRegions[i]; + for (int j = i + 1; j < mergableRegions.size(); ++j) { + if (!mergableRegions[j]) continue; + auto &rhs = mergableRegions[j]; + if (!areDispatchRegionWorkloadsCompatible(lhs, rhs) || + areDispatchRegionsTransitivelyDependent(lhs, rhs)) { + continue; + } + if (!isDispatchRegionMergable(rhs)) { + // TODO(b/134675461): support non-trivial control flow. + rhs.emitRemark( + "Unable to merge into previous iree.dispatch_region; " + "contains non-trivial control flow"); + } + mergableRegions[i] = mergeDispatchRegions(lhs, rhs); + if (!mergableRegions[i]) { + return failure(); + } + mergableRegions[j] = nullptr; + --i; // Try again to see if there are subsequent regions to merge. + break; + } + } + + return success(); +} + +namespace { + +// Recursively clones the given |sourceOp| and returns the newly cloned op. +Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder, + BlockAndValueMapping *mapping) { + // Note that we dedupe required operands in the case of multiple arguments + // coming from the same source operation. + SmallPtrSet<Operation *, 4> operandOps; + for (auto *operand : sourceOp->getOperands()) { + operandOps.insert(operand->getDefiningOp()); + } + for (auto *operandOp : operandOps) { + recursivelyCloneOp(operandOp, builder, mapping); + } + return builder.clone(*sourceOp, *mapping); +} + +// Clones the |sourceValue| op tree into |targetBlock|. +// |mapping| is used to lookup existing values that may be present in the block +// such as block arguments or already cloned ancestor ops. |mapping| will be +// updated as the tree is cloned. +Value *cloneOpTreeIntoBlock(Value *sourceValue, Block *targetBlock, + BlockAndValueMapping *mapping) { + // If the op has already been cloned we can just reuse that. + // This happens if multiple arguments reference the same trees. + if (auto *existingValue = mapping->lookupOrNull(sourceValue)) { + return existingValue; + } + + OpBuilder builder(targetBlock); + builder.setInsertionPointToStart(targetBlock); + auto *sourceOp = sourceValue->getDefiningOp(); + auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping); + + // Return only the result matching our source value (in the case of multiple + // results). + int resultIndex = std::distance( + sourceOp->result_begin(), + std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue)); + return clonedOp->getResult(resultIndex); +} + +} // namespace + +LogicalResult inlineDispatchRegionOperandsUsingValue( + IREE::DispatchRegionOp dispatchRegionOp, Value *value) { + // Find all args that are using this value. + SmallVector<unsigned, 4> argIndices; + for (auto arg : llvm::enumerate(dispatchRegionOp.getArgOperands())) { + if (arg.value() == value) { + argIndices.push_back(arg.index()); + } + } + if (argIndices.empty()) { + // Not used? Wasteful call! + return success(); + } + + // Clone the value (and the ops required to create it) into the entry block. + auto &entryBlock = dispatchRegionOp.getBody().getBlocks().front(); + BlockAndValueMapping mapping; + auto *clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping); + + // Replace all uses of the inner operand with the new value. + for (unsigned argIndex : argIndices) { + entryBlock.getArgument(argIndex)->replaceAllUsesWith(clonedValue); + } + + // Remove the dispatch region args and the block args that have been + // replaced. + for (unsigned argIndex : llvm::reverse(argIndices)) { + dispatchRegionOp.getOperation()->eraseOperand( + dispatchRegionOp.mapArgOperandToOpOperand(argIndex)); + entryBlock.eraseArgument(argIndex); + } + + return success(); +} + +namespace { + +// Recursively finds all reachable functions from the given |rootFunc| and adds +// them to the |reachableFuncs| set. +// +// Note that indirect calls are not supported, however we don't allow those in +// dispatch regions anyway so they should not be present here. +LogicalResult findReachableFunctions(Operation *rootFunc, + llvm::SetVector<FuncOp> &reachableFuncs) { + bool allCallsValid = true; + rootFunc->walk([&](CallOp op) { + auto callee = rootFunc->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>( + op.getCallee()); + if (!callee.getAttr("iree.dispatchable")) { + allCallsValid = false; + rootFunc->emitError() << callee.getName().str() << " is not dispatchable"; + return; + } + if (reachableFuncs.insert(callee)) { + findReachableFunctions(callee, reachableFuncs); + } + }); + return success(allCallsValid); +} + +} // namespace + +std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable( + Operation *op, FunctionType functionType, StringRef symbolSuffix) { + // Create the function and take the region body directly. + // NOTE: this will get uniquified if we have multiple in the same block. + auto parentFunc = op->getParentOfType<FuncOp>(); + std::string functionName = + (parentFunc.getName().str() + "_rgn" + symbolSuffix).str(); + auto outlinedFunc = FuncOp::create(op->getLoc(), functionName, functionType); + BlockAndValueMapping mapping; + op->getRegion(0).cloneInto(&outlinedFunc.getBody(), mapping); + + // Gather all reachable functions. + llvm::SetVector<FuncOp> reachableFuncs; + findReachableFunctions(outlinedFunc, reachableFuncs); + + // Create the multi-arch executable that will contain the outlined region. + // NOTE: this will get uniquified if we have multiple in the same block. + auto parentModule = parentFunc.getParentOfType<ModuleOp>(); + OpBuilder parentModuleBuilder(parentModule); + parentModuleBuilder.setInsertionPoint(parentFunc); + std::string executableName = + (parentFunc.getName().str() + "_ex" + symbolSuffix).str(); + auto multiArchExecutable = + parentModuleBuilder.create<IREE::MultiArchExecutableOp>( + outlinedFunc.getLoc(), executableName); + + // Create the executable op initially unspecified so that later + // transformations can compile it to various formats. + OpBuilder multiArchExecutableBuilder(multiArchExecutable); + multiArchExecutableBuilder.setInsertionPointToStart( + &multiArchExecutable.getBlock()); + auto executable = multiArchExecutableBuilder.create<IREE::ExecutableOp>( + outlinedFunc.getLoc(), IREE::ExecutableFormat::Unspecified); + + // Create the inner ModuleOp that contains the original functions. We need + // to provide this shim as some ops (like std.call) look for the + // containing module to provide symbol resolution. + OpBuilder executableBuilder(executable); + executableBuilder.setInsertionPointToStart(&executable.getBlock()); + auto innerModule = executableBuilder.create<ModuleOp>(outlinedFunc.getLoc()); + + // TODO(b/137674142): make an ExecutableEntryPointOp and convert the + // entry thunk into that format. + innerModule.push_back(outlinedFunc); + + // Copy all reachable functions into the executable. + // Linker passes may dedupe these later on. + for (auto reachableFunc : reachableFuncs) { + auto clonedFunc = reachableFunc.clone(); + clonedFunc.removeAttr("iree.dispatchable"); + innerModule.push_back(clonedFunc); + } + + return std::make_pair(multiArchExecutable, outlinedFunc); +} + +Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder) { + if (!value) { + return nullptr; + } + + // If the previous value was already a memref we don't need to change + // anything. + // TODO(benvanik): ensure indices make sense. + if (value->getType().isa<MemRefType>()) { + return value; + } else if (value->getType().isa<TensorType>()) { + auto castOp = builder.create<IREE::TensorToMemRefOp>(op->getLoc(), value); + return castOp.getResult(); + } + + // Allocate the memref to store the value. + auto newStorage = builder.create<AllocOp>( + op->getLoc(), convertTypeToMemRef(value->getType())); + + // Insert the store we'll use to box the value. + builder.create<StoreOp>(op->getLoc(), value, newStorage, ArrayRef<Value *>{}); + + return newStorage; +} + +Value *insertDispatcherLoad(Operation *op, Value *originalValue, + Value *allocatedValue, OpBuilder &builder) { + // If old value was a memref we don't need to change anything. + if (originalValue->getType().isa<MemRefType>()) { + return allocatedValue; + } else if (originalValue->getType().isa<TensorType>()) { + auto castOp = + builder.create<IREE::MemRefToTensorOp>(op->getLoc(), allocatedValue); + originalValue->replaceAllUsesWith(castOp.getResult()); + return castOp.getResult(); + } + + // Insert the load we'll use to unbox the value. + auto loadOp = + builder.create<LoadOp>(op->getLoc(), allocatedValue, ArrayRef<Value *>{}); + originalValue->replaceAllUsesWith(loadOp); + return loadOp; +} + +// TODO(benvanik): enough information to walk into dispatch region and compute +// shape when not static. +Value *allocateDispatchOutputBuffer(Location loc, MemRefType type, + OpBuilder &builder) { + // TODO(benvanik): allocation algorithm: + // - synthesize shape logic (magic) [[ for now assume fixed shapes ]] + // - insert shape logic above region + // - rely on folding to merge multiple calculations together + // - unranked = death, need to be able to alloc shape outputs + // - insert alloc + SmallVector<Value *, 4> dimPieces; + return builder.create<AllocOp>(loc, type, dimPieces); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Utils/DispatchUtils.h b/compiler/Utils/DispatchUtils.h new file mode 100644 index 0000000..02a8d50 --- /dev/null +++ b/compiler/Utils/DispatchUtils.h
@@ -0,0 +1,92 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities for dispatch region and function manipulation. +// These are shared between all dispatchable types such as the standard +// iree.dispatch_region as well as dispatch-related types like +// iree.reduction_region. + +#ifndef IREE_COMPILER_UTILS_DISPATCHUTILS_H_ +#define IREE_COMPILER_UTILS_DISPATCHUTILS_H_ + +#include <utility> + +#include "compiler/IR/Ops.h" +#include "compiler/IR/StructureOps.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" + +namespace mlir { +namespace iree_compiler { + +// Calculates the workload for |op| based on the op type. +Value *calculateWorkload(Operation *op, Value *baseOperand); + +// Returns true if the func is trivially dispatchable, meaning that: +// - it contains a single block +// - it contains a single dispatch region +// - it contains a return op directly returning the dispatch region results +bool isTriviallyDispatchable(FuncOp func); + +// Builds a new iree.dispatch_region with the given |ops|. +// The region will capture all required values and return all values used +// outside of the |ops| provided. The region will be inserted at the location of +// the last operation in the set. +// +// All |ops| must be compatible with the |workload| specified as they will all +// be dispatched with the same workgroup structure. +// TODO(benvanik): ensure we want to insert at end. Maybe front? +LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, + Value *workload, ArrayRef<Operation *> ops); + +// Merges multiple dispatch regions within a block into the same region, +// if possible. Operations may be reordered if it's possible to merge more while +// still obeying data dependencies. +LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock); + +// Inlines use of the given |value| from outside of a dispatch region to inside +// of it and removes the argument. Supports multiple arguments that reference +// |value| and will clone the entire value tree. +LogicalResult inlineDispatchRegionOperandsUsingValue( + IREE::DispatchRegionOp dispatchRegionOp, Value *value); + +// Creates an iree.multi_arch_executable containing an iree.executable with an +// exported function containing the body region of |op|. Created executables +// will be named for their original function concatenated with |symbolSuffix|. +std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable( + Operation *op, FunctionType functionType, StringRef symbolSuffix); + +// Inserts a conversion of an arbitrary |value| to a memref, possibly by way of +// wrapping in an allocation. +// Returns a new memref containing the value or an alias to |value|. +Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder); + +// Inserts a load from a wrapped memref. +// Returns the value in the original type or an alias to the |value| memref. +Value *insertDispatcherLoad(Operation *op, Value *originalValue, + Value *allocatedValue, OpBuilder &builder); + +// TODO(benvanik): enough information to walk into dispatch region and compute +// shape when not static. +Value *allocateDispatchOutputBuffer(Location loc, MemRefType type, + OpBuilder &builder); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_UTILS_DISPATCHUTILS_H_
diff --git a/iree/compiler/Utils/Macros.h b/compiler/Utils/Macros.h similarity index 100% rename from iree/compiler/Utils/Macros.h rename to compiler/Utils/Macros.h
diff --git a/compiler/Utils/MemRefUtils.cpp b/compiler/Utils/MemRefUtils.cpp new file mode 100644 index 0000000..c97b610 --- /dev/null +++ b/compiler/Utils/MemRefUtils.cpp
@@ -0,0 +1,94 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/MemRefUtils.h" + +#include <cassert> + +#include "compiler/IR/Ops.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { +Value *resolveValueToSourceMemRef(Value *value, Operation *useOp) { + // TODO(benvanik): implement this for real; this is naive but enough for our + // simple load patterns. + auto *defInstr = value->getDefiningOp(); + if (auto loadOp = dyn_cast_or_null<LoadOp>(defInstr)) { + // TODO(benvanik): support views. + return loadOp.getMemRef(); + } + return nullptr; +} + +Value *wrapAsTensor(Value *value, Operation *srcOp, OpBuilder &builder) { + if (srcOp->getResult(0)->getType().isa<TensorType>()) { + if (isa_and_nonnull<IREE::TensorToMemRefOp>(value->getDefiningOp())) { + return value->getDefiningOp()->getOperand(0); + } + auto newOp = builder.create<IREE::MemRefToTensorOp>(srcOp->getLoc(), value); + value = newOp.getResult(); + } + return value; +} + +Value *wrapAsMemRef(Value *value, Operation *srcOp, OpBuilder &builder) { + if (value->getType().isa<TensorType>()) { + if (isa_and_nonnull<IREE::MemRefToTensorOp>(value->getDefiningOp())) { + return value->getDefiningOp()->getOperand(0); + } + auto newOp = builder.create<IREE::TensorToMemRefOp>(srcOp->getLoc(), value); + value = newOp.getResult(); + } + return value; +} + +Value *loadAccessValue(Location location, Value *operand, OpBuilder &builder) { + if (operand->getType().isa<MemRefType>() || + operand->getType().isa<TensorType>()) { + return operand; + } + + auto memRefType = MemRefType::get({}, operand->getType()); + if (auto loadOp = dyn_cast_or_null<LoadOp>(operand->getDefiningOp())) { + // TODO(benvanik): handle creating views. + if (loadOp.getMemRefType() == memRefType) { + return loadOp.getMemRef(); + } + } + + auto allocOp = builder.create<AllocOp>(location, memRefType); + builder.create<StoreOp>(location, operand, allocOp.getResult(), + ArrayRef<Value *>{}); + return allocOp.getResult(); +} + +Value *loadResultValue(Location location, const Type &originalType, + Value *result, OpBuilder &builder) { + if (originalType.isa<MemRefType>()) { + return result; + } else if (auto tensorType = originalType.dyn_cast<TensorType>()) { + return result; + } + + auto loadOp = builder.create<LoadOp>(location, result, ArrayRef<Value *>{}); + return loadOp.getResult(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Utils/MemRefUtils.h b/compiler/Utils/MemRefUtils.h similarity index 100% rename from iree/compiler/Utils/MemRefUtils.h rename to compiler/Utils/MemRefUtils.h
diff --git a/compiler/Utils/ModuleUtils.cpp b/compiler/Utils/ModuleUtils.cpp new file mode 100644 index 0000000..439b3fc --- /dev/null +++ b/compiler/Utils/ModuleUtils.cpp
@@ -0,0 +1,98 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/ModuleUtils.h" + +#include "llvm/ADT/SetVector.h" +#include "mlir/IR/Function.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Finds a list of functions with the given |attrName| and adds them to |funcs|. +void findFunctionsWithAttr(ModuleOp module, const char *attrName, + llvm::SetVector<FuncOp> &funcs) { + for (auto func : module.getOps<FuncOp>()) { + if (func.getAttr(attrName)) { + funcs.insert(func); + } + } +} + +// Inserts functions reachable directly from |func| to |usedFuncs|. +void insertUsedFunctions(ModuleOp module, FuncOp func, + DenseSet<FuncOp> *usedFuncs, + std::vector<FuncOp> *toSearch) { + auto onCalledFunction = [&](StringRef calleeName) { + auto calleeFunc = module.lookupSymbol<FuncOp>(calleeName); + if (usedFuncs->insert(calleeFunc).second) { + // New function found! Add to queue for searching. + toSearch->push_back(calleeFunc); + } + }; + for (auto &block : func) { + for (auto &op : block) { + // TODO(benvanik): replace with iree_hl.call check. + if (auto calleeAttr = op.getAttr("callee")) { + onCalledFunction(calleeAttr.cast<SymbolRefAttr>().getValue()); + } + } + } +} + +// Returns a set containing the names of all functions used by the given +// |rootFuncs| list. +DenseSet<FuncOp> findUsedFunctions(ModuleOp module, + ArrayRef<FuncOp> rootFuncs) { + // Breadth-first search. + DenseSet<FuncOp> usedFuncs; + usedFuncs.insert(rootFuncs.begin(), rootFuncs.end()); + std::vector<FuncOp> toSearch = {rootFuncs.begin(), rootFuncs.end()}; + while (!toSearch.empty()) { + auto func = toSearch.back(); + toSearch.pop_back(); + insertUsedFunctions(module, func, &usedFuncs, &toSearch); + } + return usedFuncs; +} + +} // namespace + +void dropUnusedFunctions(ModuleOp module, ArrayRef<const char *> keepAttrs) { + // Find all of the exported functions we'll treat as roots. + llvm::SetVector<FuncOp> rootFuncs; + for (auto keepAttr : keepAttrs) { + findFunctionsWithAttr(module, keepAttr, rootFuncs); + } + + // Find the full set of all used functions reachable from the given rootFuncs. + // This set will contain the rootFuncs. + auto usedFuncs = findUsedFunctions(module, rootFuncs.getArrayRef()); + + // Drop all unused functions. + std::vector<FuncOp> deadFuncs; + for (auto func : module.getOps<FuncOp>()) { + if (!llvm::is_contained(usedFuncs, func)) { + deadFuncs.push_back(func); + } + } + for (auto func : deadFuncs) { + func.erase(); + } +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Utils/ModuleUtils.h b/compiler/Utils/ModuleUtils.h similarity index 100% rename from iree/compiler/Utils/ModuleUtils.h rename to compiler/Utils/ModuleUtils.h
diff --git a/compiler/Utils/OpCreationUtils.cpp b/compiler/Utils/OpCreationUtils.cpp new file mode 100644 index 0000000..7d5bcca --- /dev/null +++ b/compiler/Utils/OpCreationUtils.cpp
@@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/OpCreationUtils.h" + +#include <cstdint> + +#include "compiler/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +ElementsAttr elementsAttrFromArray(OpBuilder &builder, + ArrayRef<int64_t> elements) { + return DenseIntElementsAttr::get( + RankedTensorType::get(elements.size(), builder.getIntegerType(64)), + elements); +} + +} // namespace + +IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc, + llvm::ArrayRef<int64_t> elements) { + auto elementsAttr = elementsAttrFromArray(builder, elements); + return builder.create<IREE::ConstantOp>(loc, elementsAttr); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Utils/OpCreationUtils.h b/compiler/Utils/OpCreationUtils.h new file mode 100644 index 0000000..1ad01bc --- /dev/null +++ b/compiler/Utils/OpCreationUtils.h
@@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utility functions related to the creation of new operations. Where possible, +// use custom builders. These helpers are for situations where a custom builder +// is not appropriate. + +#ifndef IREE_COMPILER_UTILS_OPCREATIONUTILS_H_ +#define IREE_COMPILER_UTILS_OPCREATIONUTILS_H_ + +#include <cstdint> + +#include "compiler/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace iree_compiler { + +IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc, + llvm::ArrayRef<int64_t> elements); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
diff --git a/compiler/Utils/OpUtils.cpp b/compiler/Utils/OpUtils.cpp new file mode 100644 index 0000000..df2c8f2 --- /dev/null +++ b/compiler/Utils/OpUtils.cpp
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/OpUtils.h" + +namespace mlir { +namespace iree_compiler { + +void removeDeadOperations(llvm::SetVector<Operation *> &deadOperations) { + while (!deadOperations.empty()) { + auto *op = deadOperations.front(); + deadOperations.erase(deadOperations.begin()); + for (auto *operand : op->getOperands()) { + // TODO(benvanik): add check for op side effects. + if (operand->hasOneUse()) { + deadOperations.insert(operand->getDefiningOp()); + } + } + op->erase(); + } +} + +void replaceSubsequentUses(Operation *userOp, Value *oldValue, + Value *newValue) { + for (auto &use : oldValue->getUses()) { + if (userOp->isBeforeInBlock(use.getOwner())) { + use.set(newValue); + } + } +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Utils/OpUtils.h b/compiler/Utils/OpUtils.h similarity index 100% rename from iree/compiler/Utils/OpUtils.h rename to compiler/Utils/OpUtils.h
diff --git a/compiler/Utils/TranslationUtils.cpp b/compiler/Utils/TranslationUtils.cpp new file mode 100644 index 0000000..a9ec1ca --- /dev/null +++ b/compiler/Utils/TranslationUtils.cpp
@@ -0,0 +1,139 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/TranslationUtils.h" + +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace iree_compiler { + +namespace { + +// Returns the static registry of translator names to translation functions. +llvm::StringMap<TranslateExecutableFn> + &getMutableExecutableTranslationRegistry() { + static llvm::StringMap<TranslateExecutableFn> registry; + return registry; +} + +// Returns true if the given |value| matches |pattern| (normal * and ? rules). +bool matchPattern(StringRef value, StringRef pattern) { + size_t nextCharIndex = pattern.find_first_of("*?"); + if (nextCharIndex == std::string::npos) { + return value == pattern; + } else if (nextCharIndex > 0) { + if (value.substr(0, nextCharIndex) != pattern.substr(0, nextCharIndex)) { + return false; + } + value = value.substr(nextCharIndex); + pattern = pattern.substr(nextCharIndex); + } + char patternChar = pattern[0]; + if (value.empty() && pattern.empty()) { + return true; + } else if (patternChar == '*' && pattern.size() > 1 && value.empty()) { + return false; + } else if (patternChar == '*' && pattern.size() == 1) { + return true; + } else if (patternChar == '?' || value[0] == patternChar) { + return matchPattern(value.substr(1), pattern.substr(1)); + } else if (patternChar == '*') { + return matchPattern(value, pattern.substr(1)) || + matchPattern(value.substr(1), pattern); + } + return false; +} + +// Force enables IR printing on the |passManager|. +void enableIRPrinting(PassManager *passManager) { + auto notVerifier = [](Pass *pass) { + return pass->getName() != "FunctionVerifier" && + pass->getName() != "ModuleVerifier"; + }; + bool printModuleScope = false; + passManager->enableIRPrinting(/*shouldPrintBeforePass=*/{}, + /*shouldPrintAfterPass=*/notVerifier, + printModuleScope, llvm::dbgs()); + passManager->disableMultithreading(); +} + +} // namespace + +ExecutableTranslationRegistration::ExecutableTranslationRegistration( + llvm::StringRef name, const TranslateExecutableFn &fn) { + auto ®istry = getMutableExecutableTranslationRegistry(); + if (registry.find(name) != registry.end()) { + llvm::report_fatal_error( + "Attempting to overwrite an existing translation function"); + } + assert(fn && "Attempting to register an empty translation function"); + registry[name] = fn; +} + +const llvm::StringMap<TranslateExecutableFn> + &getExecutableTranslationRegistry() { + return getMutableExecutableTranslationRegistry(); +} + +std::vector<std::string> matchExecutableTranslationBackendNames( + llvm::StringRef pattern) { + std::vector<std::string> matches; + for (auto &entry : getExecutableTranslationRegistry()) { + if (matchPattern(entry.getKey(), pattern)) { + matches.push_back(entry.getKey().str()); + } + } + return matches; +} + +std::unique_ptr<PassManager> createPassManager( + MLIRContext *ctx, const TranslationOptions &translationOptions) { + std::unique_ptr<PassManager> passManager(new PassManager(ctx)); + + // Enable IR printing/timing/etc from command line options. + registerPassManagerCLOptions(); + applyPassManagerCLOptions(*passManager); + + // Override with programmatic options. + if (translationOptions.print_mlir) { + enableIRPrinting(passManager.get()); + } + + return passManager; +} + +LogicalResult runPassPipeline(const TranslationOptions &translationOptions, + PassManager *passManager, ModuleOp module) { + if (translationOptions.print_mlir) { + module.dump(); + } + + // Run on the module. + if (failed(passManager->run(module))) { + return failure(); + } + + if (translationOptions.print_mlir) { + module.dump(); + } + + return success(); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/compiler/Utils/TranslationUtils.h b/compiler/Utils/TranslationUtils.h new file mode 100644 index 0000000..a96a428 --- /dev/null +++ b/compiler/Utils/TranslationUtils.h
@@ -0,0 +1,109 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_ +#define IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_ + +#include <functional> +#include <memory> + +#include "compiler/IR/StructureOps.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/PassManager.h" +#include "schemas/executable_def_generated.h" + +namespace mlir { +namespace iree_compiler { + +// Common translation options for diagnostics and debugging. +struct TranslationOptions { + // Enables MLIR IR printing during translation. + // This can be specified via the -print-ir-before-all and -print-ir-after-all + // command line flags or overridden programmatically via this flag. + bool print_mlir = false; + + void CopyFrom(const TranslationOptions &other) { + print_mlir = other.print_mlir; + } +}; + +// Options for iree.module translation for diagnostics and debugging. +struct ModuleTranslationOptions : public TranslationOptions { + // Defines which backend translators will be used to translate executables. + // If empty then all linked in translators will be used. + // TODO(benvanik): extend to allow specifying entire config blobs via mlir. + std::vector<std::string> target_backends; +}; + +// Options for iree.executable translation for diagnostics and debugging. +// Target configuration is sourced from the iree.target_config op within the +// iree.executable. +struct ExecutableTranslationOptions : public TranslationOptions {}; + +// Results of a translation operation. +// May contain zero or more executable defs depending on translation options, +// defined target configs, and support. +struct ExecutableTranslationResult { + std::vector<std::unique_ptr<iree::ExecutableDefT>> executable_defs; +}; + +// Registered function that given a set of |executableOps| containing one +// or more iree.executables will produce zero or more serialized executables. +// +// Each iree.executable provided contains one iree.executable_target_config with +// backend-specific translation information. The translator can decide whether +// to translate each independently, group them together, etc. +// +// The provided |executableOps| can be mutated by the callee and will be +// preserved for debugging after translation. If any executable in +// |executableOps| is not used by the translator then it should be erased. +using TranslateExecutableFn = + std::function<llvm::Optional<ExecutableTranslationResult>( + ArrayRef<IREE::ExecutableOp> executableOps, + ExecutableTranslationOptions options)>; + +// Registers an executable translation function. +struct ExecutableTranslationRegistration { + ExecutableTranslationRegistration(llvm::StringRef name, + const TranslateExecutableFn &fn); +}; + +// Returns a read-only reference to the translator registry. +const llvm::StringMap<TranslateExecutableFn> + &getExecutableTranslationRegistry(); + +// Returns executable translation backend names matching the given pattern. +// This accepts wildcards for any delimited value. For example, 'foo-*-bar' will +// match 'foo-123-bar' and 'foo-456-bar' and 'foo-10?' will match 'foo-101' and +// 'foo-102'. +std::vector<std::string> matchExecutableTranslationBackendNames( + llvm::StringRef pattern); + +// Creates a new pass manager initialized with the given options. +std::unique_ptr<PassManager> createPassManager( + MLIRContext *ctx, const TranslationOptions &translationOptions); + +// Runs an initialized set of passes on the given module. +LogicalResult runPassPipeline(const TranslationOptions &translationOptions, + PassManager *passManager, ModuleOp module); + +} // namespace iree_compiler +} // namespace mlir + +#endif // IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
diff --git a/compiler/Utils/TypeConversionUtils.cpp b/compiler/Utils/TypeConversionUtils.cpp new file mode 100644 index 0000000..cf416d0 --- /dev/null +++ b/compiler/Utils/TypeConversionUtils.cpp
@@ -0,0 +1,74 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "compiler/Utils/TypeConversionUtils.h" + +#include <cassert> + +#include "compiler/IR/Ops.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +namespace iree_compiler { + +Type legalizeType(Type type) { + if (type.isIndex()) { + return IntegerType::get(kIndexBitWidth, type.getContext()); + } else if (type.isInteger(1)) { + return IntegerType::get(kBoolBitWidth, type.getContext()); + } else if (auto memRefType = type.dyn_cast<MemRefType>()) { + return MemRefType::get(memRefType.getShape(), + legalizeType(memRefType.getElementType())); + } else if (auto functionType = type.dyn_cast<FunctionType>()) { + llvm::SmallVector<Type, 4> inputs; + for (const auto &oldType : functionType.getInputs()) { + inputs.push_back(legalizeType(oldType)); + } + llvm::SmallVector<Type, 4> results; + for (const auto &oldType : functionType.getResults()) { + results.push_back(legalizeType(oldType)); + } + return FunctionType::get(inputs, results, type.getContext()); + } + return type; +} + +Type LLTypeConverter::convertType(Type type) { return legalizeType(type); } + +MemRefType convertTypeToMemRef(Type type) { + if (type.isIntOrIndexOrFloat()) { + return MemRefType::get({}, type, {}, 0); + } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { + return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + } else if (auto memRefType = type.dyn_cast<MemRefType>()) { + return memRefType; + } else { + llvm_unreachable("Unconvertable type"); + } +} + +MemRefType convertTypeToMemRef(Value *value) { + return convertTypeToMemRef(value->getType()); +} + +Type MemRefTypeConverter::convertType(Type type) { + return convertTypeToMemRef(type); +} + +} // namespace iree_compiler +} // namespace mlir
diff --git a/iree/compiler/Utils/TypeConversionUtils.h b/compiler/Utils/TypeConversionUtils.h similarity index 100% rename from iree/compiler/Utils/TypeConversionUtils.h rename to compiler/Utils/TypeConversionUtils.h
diff --git a/hal/BUILD b/hal/BUILD new file mode 100644 index 0000000..aeaf38d --- /dev/null +++ b/hal/BUILD
@@ -0,0 +1,377 @@ +# HAL (Hardware Abstraction Layer). +# Subdirectories contain implementations for different hardware and +# software backends. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "allocator", + srcs = ["allocator.cc"], + hdrs = ["allocator.h"], + deps = [ + ":buffer", + "///base:source_location", + "///base:status", + "///base:tracing", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "api", + srcs = ["api.cc"], + hdrs = [ + "api.h", + "api_detail.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":api_hdrs", + ":buffer", + ":buffer_view", + ":fence", + ":heap_buffer", + ":semaphore", + "///base:api", + "///base:api_util", + "///base:shape", + "///base:tracing", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "api_hdrs", + hdrs = ["api.h"], + deps = [ + "///base:api_hdrs", + ], +) + +cc_library( + name = "buffer", + srcs = ["buffer.cc"], + hdrs = ["buffer.h"], + deps = [ + ":resource", + "///base:bitfield", + "///base:logging", + "///base:source_location", + "///base:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "buffer_test", + srcs = [ + "buffer_mapping_test.cc", + "buffer_test.cc", + ], + deps = [ + ":buffer", + ":heap_buffer", + "///base:status", + "///base:status_matchers", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "buffer_view", + srcs = ["buffer_view.cc"], + hdrs = ["buffer_view.h"], + deps = [ + ":buffer", + "///base:shape", + "///base:source_location", + "///base:status", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "buffer_view_test", + srcs = [ + "buffer_view_test.cc", + ], + deps = [ + ":buffer", + ":buffer_view", + ":heap_buffer", + "///base:status", + "///base:status_matchers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "buffer_view_string_util", + srcs = ["buffer_view_string_util.cc"], + hdrs = ["buffer_view_string_util.h"], + deps = [ + ":allocator", + ":buffer_view", + ":heap_buffer", + "///base:source_location", + "///base:status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "buffer_view_string_util_test", + srcs = ["buffer_view_string_util_test.cc"], + deps = [ + ":buffer_view_string_util", + "///base:status", + "///base:status_matchers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "command_buffer", + srcs = ["command_buffer.cc"], + hdrs = ["command_buffer.h"], + deps = [ + ":allocator", + ":buffer", + ":buffer_view", + ":event", + ":executable", + ":resource", + "///base:bitfield", + "///base:shape", + "///base:status", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "command_buffer_validation", + srcs = ["command_buffer_validation.cc"], + hdrs = ["command_buffer_validation.h"], + deps = [ + ":command_buffer", + "///base:logging", + "///base:status", + ], +) + +cc_library( + name = "command_queue", + hdrs = ["command_queue.h"], + deps = [ + ":command_buffer", + ":fence", + ":semaphore", + "///base:bitfield", + "///base:status", + "///base:time", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "deferred_buffer", + srcs = ["deferred_buffer.cc"], + hdrs = ["deferred_buffer.h"], + deps = [ + ":allocator", + ":buffer", + "///base:status", + ], +) + +cc_test( + name = "deferred_buffer_test", + srcs = ["deferred_buffer_test.cc"], + deps = [ + ":deferred_buffer", + ":heap_buffer", + "///base:status_matchers", + "///hal/testing:mock_allocator", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "device", + hdrs = ["device.h"], + deps = [ + ":allocator", + ":buffer", + ":command_queue", + ":device_info", + ":event", + ":executable_cache", + ":semaphore", + "///base:status", + "///base:time", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "device_info", + hdrs = ["device_info.h"], + deps = [ + "///base:bitfield", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "device_manager", + srcs = ["device_manager.cc"], + hdrs = ["device_manager.h"], + deps = [ + ":allocator", + ":buffer", + ":command_queue", + ":device", + ":device_placement", + ":executable_format", + ":fence", + ":heap_buffer", + "///base:source_location", + "///base:status", + "///base:time", + "///base:tracing", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "device_placement", + hdrs = ["device_placement.h"], +) + +cc_library( + name = "driver", + hdrs = ["driver.h"], + deps = [ + ":device", + ":device_info", + "///base:status", + ], +) + +cc_library( + name = "driver_registry", + srcs = ["driver_registry.cc"], + hdrs = ["driver_registry.h"], + deps = [ + ":driver", + "///base:init", + "///base:status", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "event", + hdrs = ["event.h"], + deps = [ + ":resource", + ], +) + +cc_library( + name = "executable", + hdrs = ["executable.h"], + deps = [":resource"], +) + +cc_library( + name = "executable_cache", + srcs = ["executable_cache.cc"], + hdrs = ["executable_cache.h"], + deps = [ + ":executable", + ":executable_format", + ":executable_spec", + "///base:bitfield", + "///base:ref_ptr", + "///base:status", + ], +) + +cc_library( + name = "executable_format", + hdrs = ["executable_format.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "executable_spec", + hdrs = ["executable_spec.h"], + deps = [ + ":executable_format", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "fence", + hdrs = ["fence.h"], + deps = [ + ":resource", + "///base:status", + ], +) + +cc_library( + name = "heap_buffer", + srcs = ["heap_buffer.cc"], + hdrs = ["heap_buffer.h"], + deps = [ + ":allocator", + ":buffer", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal/host:host_buffer", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "resource", + hdrs = ["resource.h"], + deps = [ + "///base:ref_ptr", + ], +) + +cc_library( + name = "semaphore", + hdrs = ["semaphore.h"], + deps = [ + ":resource", + "@com_google_absl//absl/types:variant", + ], +) + +cc_library( + name = "stack_trace", + hdrs = ["stack_trace.h"], +)
diff --git a/iree/hal/CMakeLists.txt b/hal/CMakeLists.txt similarity index 100% rename from iree/hal/CMakeLists.txt rename to hal/CMakeLists.txt
diff --git a/hal/allocator.cc b/hal/allocator.cc new file mode 100644 index 0000000..57c18c2 --- /dev/null +++ b/hal/allocator.cc
@@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/allocator.h" + +#include <cstdint> +#include <cstdlib> +#include <string> +#include <utility> + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" + +namespace iree { +namespace hal { + +bool Allocator::CanUseBuffer(Buffer* buffer, + BufferUsageBitfield intended_usage) const { + return CanUseBufferLike(buffer->allocator(), buffer->memory_type(), + buffer->usage(), intended_usage); +} + +StatusOr<ref_ptr<Buffer>> Allocator::AllocateConstant( + BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) { + if (AnyBitSet(source_buffer->usage() & BufferUsage::kConstant) && + CanUseBuffer(source_buffer.get(), buffer_usage)) { + // Buffer can be used directly by the device. + return source_buffer; + } + + IREE_TRACE_SCOPE0("Allocator::AllocateConstant"); + + // We need to map so we can copy into it. + buffer_usage |= BufferUsage::kMapping; + // It will be constant after we write it. + buffer_usage |= BufferUsage::kConstant; + + MemoryTypeBitfield memory_type = + MemoryType::kDeviceLocal | MemoryType::kHostVisible; + ASSIGN_OR_RETURN(auto device_buffer, Allocate(memory_type, buffer_usage, + source_buffer->byte_length())); + ASSIGN_OR_RETURN(auto source_mapping, + source_buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + RETURN_IF_ERROR(device_buffer->WriteData(0, source_mapping.data(), + source_mapping.byte_length())); + return device_buffer; +} + +StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + const void* data, + size_t data_length) { + return WrapMutable(memory_type, MemoryAccess::kRead, buffer_usage, + const_cast<void*>(data), data_length); +} + +StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( + MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, void* data, size_t data_length) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Allocator does not support wrapping host memory"; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/allocator.h b/hal/allocator.h new file mode 100644 index 0000000..b3dba23 --- /dev/null +++ b/hal/allocator.h
@@ -0,0 +1,138 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_ALLOCATOR_H_ +#define IREE_HAL_ALLOCATOR_H_ + +#include <cstddef> +#include <memory> + +#include "absl/types/span.h" +#include "base/status.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +// Allocates buffers for a particular device memory space. +// +// Buffers allocated are only guaranteed to work with the driver that the +// allocator services. Any attempt to use buffers on drivers they were not +// allocated from must first be checked with CanUseBuffer. +// +// Thread-safe. +class Allocator { + public: + virtual ~Allocator() = default; + + // Returns true if the device can use the given buffer for the provided usage. + // For buffers allocated from this allocator it's expected that the result + // will always be true. For buffers that originate from another allocator + // there may be limited support for cross-device usage. + // + // Returning false indicates that the buffer must be transferred externally + // into a buffer compatible with the device this allocator services. + bool CanUseBuffer(Buffer* buffer, BufferUsageBitfield intended_usage) const; + virtual bool CanUseBufferLike(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const = 0; + + // Returns true if the allocator can allocate a buffer with the given + // attributes. + virtual bool CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const = 0; + + // Adjusts allocation parameters to be compatible with the allocator. + // Certain allocators may require particular memory types to function. By + // adjusting the parameters prior to allocation callers can be sure they are + // able to successfully Allocate a buffer later on with the same parameters. + virtual Status MakeCompatible(MemoryTypeBitfield* memory_type, + BufferUsageBitfield* buffer_usage) const { + return OkStatus(); + } + + // Allocates a buffer from the allocator. + // Fails if the memory type requested for the given usage cannot be serviced. + // Callers can use CanAllocate to decide their memory use strategy. + // + // The memory type of the buffer returned may differ from the requested value + // if the device can provide more functionality; for example, if requesting + // MemoryType::kHostVisible but the memory is really host cached you may get + // a buffer back with MemoryType::kHostVisible | MemoryType::kHostCached. The + // only requirement is that the buffer satisfy the required bits. + virtual StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) = 0; + + // Allocates a buffer from the allocator for use as a constant value. + // The provided |source_buffer| may be returned if the device can use it + // directly and otherwise will be copied. + virtual StatusOr<ref_ptr<Buffer>> AllocateConstant( + BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer); + + // Wraps an existing host heap allocation in a buffer. + // Ownership of the host allocation remains with the caller and the memory + // must remain valid for so long as the Buffer may be in use. + // Will have MemoryType::kHostLocal in most cases and may not be usable + // by the device. + // + // The inference optimizer makes assumptions about buffer aliasing based on + // Buffer instances and because of this wrapping the same host buffer in + // multiple Buffers will create potential memory aliasing issues that can be + // difficult to track down. There's no checking as to whether a host buffer + // has already been wrapped so it's best for callers to ensure this is never + // possible (the simplest way being to never use Wrap and always just allocate + // new Buffers). + // + // Fails if the allocator cannot access host memory in this way. + StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + const void* data, size_t data_length); + virtual StatusOr<ref_ptr<Buffer>> WrapMutable( + MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, void* data, size_t data_length); + template <typename T> + StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + absl::Span<const T> data); + template <typename T> + StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, + absl::Span<T> data); +}; + +// Inline functions and template definitions follow: + +template <typename T> +StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + absl::Span<const T> data) { + return Wrap(memory_type, buffer_usage, data.data(), data.size() * sizeof(T)); +} + +template <typename T> +StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( + MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, absl::Span<T> data) { + return WrapMutable(memory_type, allowed_access, buffer_usage, data.data(), + data.size() * sizeof(T)); +} + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_ALLOCATOR_H_
diff --git a/hal/api.cc b/hal/api.cc new file mode 100644 index 0000000..b103406 --- /dev/null +++ b/hal/api.cc
@@ -0,0 +1,439 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/api.h" + +#include "base/api.h" +#include "base/api_util.h" +#include "base/shape.h" +#include "base/tracing.h" +#include "hal/api_detail.h" +#include "hal/buffer.h" +#include "hal/buffer_view.h" +#include "hal/fence.h" +#include "hal/heap_buffer.h" +#include "hal/semaphore.h" + +namespace iree { +namespace hal { + +//===----------------------------------------------------------------------===// +// iree::hal::Buffer +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t iree_hal_buffer_subspan( + iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer) { + IREE_TRACE_SCOPE0("iree_hal_buffer_subspan"); + + if (!out_buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_buffer = nullptr; + + if (!buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + auto handle = add_ref(reinterpret_cast<Buffer*>(buffer)); + + IREE_API_ASSIGN_OR_RETURN(auto new_handle, + Buffer::Subspan(handle, byte_offset, byte_length)); + + *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(new_handle.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_retain(iree_hal_buffer_t* buffer) { + IREE_TRACE_SCOPE0("iree_hal_buffer_retain"); + auto* handle = reinterpret_cast<Buffer*>(buffer); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_release(iree_hal_buffer_t* buffer) { + IREE_TRACE_SCOPE0("iree_hal_buffer_release"); + auto* handle = reinterpret_cast<Buffer*>(buffer); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_device_size_t +iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) { + IREE_TRACE_SCOPE0("iree_hal_buffer_byte_length"); + const auto* handle = reinterpret_cast<const Buffer*>(buffer); + CHECK(handle) << "NULL buffer handle"; + return handle->byte_length(); +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length) { + IREE_TRACE_SCOPE0("iree_hal_buffer_zero"); + auto* handle = reinterpret_cast<Buffer*>(buffer); + if (!buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR(handle->Fill8(byte_offset, byte_length, 0)); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, const void* pattern, + iree_host_size_t pattern_length) { + IREE_TRACE_SCOPE0("iree_hal_buffer_fill"); + auto* handle = reinterpret_cast<Buffer*>(buffer); + if (!buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR( + handle->Fill(byte_offset, byte_length, pattern, pattern_length)); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t iree_hal_buffer_read_data( + iree_hal_buffer_t* buffer, iree_device_size_t source_offset, + void* target_buffer, iree_device_size_t data_length) { + IREE_TRACE_SCOPE0("iree_hal_buffer_read_data"); + auto* handle = reinterpret_cast<Buffer*>(buffer); + if (!buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR( + handle->ReadData(source_offset, target_buffer, data_length)); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t iree_hal_buffer_write_data( + iree_hal_buffer_t* buffer, iree_device_size_t target_offset, + const void* source_buffer, iree_device_size_t data_length) { + IREE_TRACE_SCOPE0("iree_hal_buffer_write_data"); + auto* handle = reinterpret_cast<Buffer*>(buffer); + if (!buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR( + handle->WriteData(target_offset, source_buffer, data_length)); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t iree_hal_buffer_map( + iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, + iree_device_size_t element_offset, iree_device_size_t element_length, + iree_hal_mapped_memory_t* out_mapped_memory) { + IREE_TRACE_SCOPE0("iree_hal_buffer_map"); + + if (!out_mapped_memory) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory)); + + auto* buffer_handle = reinterpret_cast<Buffer*>(buffer); + if (!buffer_handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + IREE_API_ASSIGN_OR_RETURN( + auto mapping, buffer_handle->MapMemory<uint8_t>( + static_cast<MemoryAccessBitfield>(memory_access), + element_offset, element_length)); + + static_assert(sizeof(iree_hal_mapped_memory_t::reserved) >= + sizeof(MappedMemory<uint8_t>), + "C mapped memory struct must have large enough storage for the " + "matching C++ struct"); + auto* mapping_storage = + reinterpret_cast<MappedMemory<uint8_t>*>(out_mapped_memory->reserved); + *mapping_storage = std::move(mapping); + + out_mapped_memory->contents = {const_cast<uint8_t*>(mapping_storage->data()), + mapping_storage->size()}; + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t iree_hal_buffer_unmap( + iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory) { + IREE_TRACE_SCOPE0("iree_hal_buffer_map"); + auto* buffer_handle = reinterpret_cast<Buffer*>(buffer); + if (!buffer_handle || !mapped_memory) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto* mapping = + reinterpret_cast<MappedMemory<uint8_t>*>(mapped_memory->reserved); + mapping->reset(); + + std::memset(mapped_memory, 0, sizeof(*mapped_memory)); + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::hal::HeapBuffer +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate( + iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, + iree_host_size_t allocation_size, iree_allocator_t contents_allocator, + iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) { + IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate"); + + if (!out_buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_buffer = nullptr; + + if (!allocation_size) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto handle = HeapBuffer::Allocate( + static_cast<MemoryTypeBitfield>(memory_type), + static_cast<BufferUsageBitfield>(usage), allocation_size); + + *out_buffer = reinterpret_cast<iree_hal_buffer_t*>( + static_cast<Buffer*>(handle.release())); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy( + iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, + iree_hal_memory_access_t allowed_access, iree_byte_span_t contents, + iree_allocator_t contents_allocator, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer) { + IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate_copy"); + + if (!out_buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_buffer = nullptr; + + if (!contents.data || !contents.data_length) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto handle = HeapBuffer::AllocateCopy( + static_cast<BufferUsageBitfield>(usage), + static_cast<MemoryAccessBitfield>(allowed_access), contents.data, + contents.data_length); + + *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(handle.release()); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( + iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t usage, iree_byte_span_t contents, + iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) { + IREE_TRACE_SCOPE0("iree_hal_heap_buffer_wrap"); + + if (!out_buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_buffer = nullptr; + + if (!contents.data || !contents.data_length) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto handle = + HeapBuffer::WrapMutable(static_cast<MemoryTypeBitfield>(memory_type), + static_cast<MemoryAccessBitfield>(allowed_access), + static_cast<BufferUsageBitfield>(usage), + contents.data, contents.data_length); + + *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(handle.release()); + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::hal::BufferView +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create( + iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size, + iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_create"); + + if (!out_buffer_view) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_buffer_view = nullptr; + + if (!buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } else if (shape.rank > kMaxRank || element_size <= 0) { + return IREE_STATUS_OUT_OF_RANGE; + } + + // Allocate and initialize the iree_hal_buffer_view struct. + iree_hal_buffer_view* handle = nullptr; + IREE_API_RETURN_IF_API_ERROR(allocator.alloc( + allocator.self, sizeof(*handle), reinterpret_cast<void**>(&handle))); + new (handle) iree_hal_buffer_view(); + handle->allocator = allocator; + + handle->impl.buffer = add_ref(reinterpret_cast<Buffer*>(buffer)); + handle->impl.shape = {shape.dims, shape.rank}; + handle->impl.element_size = element_size; + + *out_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(handle); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_retain"); + auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_release"); + auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t iree_hal_buffer_view_assign( + iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer, + iree_shape_t shape, int8_t element_size) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_assign"); + auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->impl.buffer.reset(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_reset"); + auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->impl.buffer.reset(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer( + const iree_hal_buffer_view_t* buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_buffer"); + const auto* handle = + reinterpret_cast<const iree_hal_buffer_view*>(buffer_view); + CHECK(handle) << "NULL buffer_view handle"; + return reinterpret_cast<iree_hal_buffer_t*>(handle->impl.buffer.get()); +} + +IREE_API_EXPORT iree_status_t iree_hal_buffer_view_shape( + const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_shape"); + + if (!out_shape) { + return IREE_STATUS_INVALID_ARGUMENT; + } + out_shape->rank = 0; + + const auto* handle = + reinterpret_cast<const iree_hal_buffer_view*>(buffer_view); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + const auto& shape = handle->impl.shape; + return ToApiShape(shape, out_shape); +} + +IREE_API_EXPORT int8_t +iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) { + IREE_TRACE_SCOPE0("iree_hal_buffer_view_element_size"); + const auto* handle = + reinterpret_cast<const iree_hal_buffer_view*>(buffer_view); + CHECK(handle) << "NULL buffer_view handle"; + return handle->impl.element_size; +} + +//===----------------------------------------------------------------------===// +// iree::hal::Semaphore +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t +iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore) { + IREE_TRACE_SCOPE0("iree_hal_semaphore_retain"); + auto* handle = reinterpret_cast<Semaphore*>(semaphore); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t +iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore) { + IREE_TRACE_SCOPE0("iree_hal_semaphore_release"); + auto* handle = reinterpret_cast<Semaphore*>(semaphore); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::hal::Fence +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t iree_hal_fence_retain(iree_hal_fence_t* fence) { + IREE_TRACE_SCOPE0("iree_hal_fence_retain"); + auto* handle = reinterpret_cast<Fence*>(fence); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t iree_hal_fence_release(iree_hal_fence_t* fence) { + IREE_TRACE_SCOPE0("iree_hal_fence_release"); + auto* handle = reinterpret_cast<Fence*>(fence); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/api.h b/hal/api.h new file mode 100644 index 0000000..4d305f9 --- /dev/null +++ b/hal/api.h
@@ -0,0 +1,366 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// See iree/base/api.h for documentation on the API conventions used. + +#ifndef IREE_HAL_API_H_ +#define IREE_HAL_API_H_ + +#include <stdint.h> + +#include "base/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_buffer iree_hal_buffer_t; +typedef struct iree_hal_buffer_view iree_hal_buffer_view_t; +typedef struct iree_hal_semaphore iree_hal_semaphore_t; +typedef struct iree_hal_fence iree_hal_fence_t; + +// Reference to a buffer's mapped memory. +typedef struct { + // Contents of the buffer. Behavior is undefined if an access is performed + // whose type was not specified during mapping. + iree_byte_span_t contents; + + // Used internally - do not modify. + uint64_t reserved[8]; +} iree_hal_mapped_memory_t; + +// A bitfield specifying properties for a memory type. +typedef enum { + IREE_HAL_MEMORY_TYPE_NONE = 0, + + // Memory is lazily allocated by the device and only exists transiently. + // This is the optimal mode for memory used only within a single command + // buffer. Transient buffers, even if they have + // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE set, should be treated as device-local + // and opaque as they may have no memory attached to them outside of the time + // they are being evaluated on devices. + // + // This flag can be treated as a hint in most cases; allocating a buffer with + // it set _may_ return the same as if it had not be set. Certain allocation + // routines may use the hint to more tightly control reuse or defer wiring the + // memory. + IREE_HAL_MEMORY_TYPE_TRANSIENT = 1 << 0, + + // Memory allocated with this type can be mapped for host access using + // iree_hal_buffer_map. + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE = 1 << 1, + + // The host cache management commands MappedMemory::Flush and + // MappedMemory::Invalidate are not needed to flush host writes + // to the device or make device writes visible to the host, respectively. + IREE_HAL_MEMORY_TYPE_HOST_COHERENT = 1 << 2, + + // Memory allocated with this type is cached on the host. Host memory + // accesses to uncached memory are slower than to cached memory, however + // uncached memory is always host coherent. MappedMemory::Flush must be used + // to ensure the device has visibility into any changes made on the host and + // Invalidate must be used to ensure the host has visibility into any changes + // made on the device. + IREE_HAL_MEMORY_TYPE_HOST_CACHED = 1 << 3, + + // Memory is accessible as normal host allocated memory. + IREE_HAL_MEMORY_TYPE_HOST_LOCAL = + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT, + + // Memory allocated with this type is visible to the device for execution. + // Being device visible does not mean the same thing as + // IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL. Though an allocation may be visible to + // the device and therefore useable for execution it may require expensive + // mapping or implicit transfers. + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE = 1 << 4, + + // Memory allocated with this type is the most efficient for device access. + // Devices may support using memory that is not device local via + // IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE but doing so can incur non-trivial + // performance penalties. Device local memory, on the other hand, is + // guaranteed to be fast for all operations. + IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL = + IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | (1 << 5), +} iree_hal_memory_type_t; + +// A bitfield specifying how memory will be accessed in a mapped memory region. +typedef enum { + // Memory is not mapped. + IREE_HAL_MEMORY_ACCESS_NONE = 0, + // Memory will be read. + // If a buffer is only mapped for reading it may still be possible to write to + // it but the results will be undefined (as it may present coherency issues). + IREE_HAL_MEMORY_ACCESS_READ = 1 << 0, + // Memory will be written. + // If a buffer is only mapped for writing it may still be possible to read + // from it but the results will be undefined or incredibly slow (as it may + // be mapped by the driver as uncached). + IREE_HAL_MEMORY_ACCESS_WRITE = 1 << 1, + // Memory will be discarded prior to mapping. + // The existing contents will be undefined after mapping and must be written + // to ensure validity. + IREE_HAL_MEMORY_ACCESS_DISCARD = 1 << 2, + // Memory will be discarded and completely overwritten in a single operation. + IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE = + IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD, + // Memory may have any operation performed on it. + IREE_HAL_MEMORY_ACCESS_ALL = IREE_HAL_MEMORY_ACCESS_READ | + IREE_HAL_MEMORY_ACCESS_WRITE | + IREE_HAL_MEMORY_ACCESS_DISCARD, +} iree_hal_memory_access_t; + +// Bitfield that defines how a buffer is intended to be used. +// Usage allows the driver to appropriately place the buffer for more +// efficient operations of the specified types. +typedef enum { + IREE_HAL_BUFFER_USAGE_NONE = 0, + + // The buffer, once defined, will not be mapped or updated again. + // This should be used for uniform parameter values such as runtime + // constants for executables. Doing so may allow drivers to inline values or + // represent them in command buffers more efficiently (avoiding memory reads + // or swapping, etc). + IREE_HAL_BUFFER_USAGE_CONSTANT = 1 << 0, + + // The buffer can be used as the source or target of a transfer command + // (CopyBuffer, UpdateBuffer, etc). + // + // If |IREE_HAL_BUFFER_USAGE_MAPPING| is not specified drivers may safely + // assume that the host may never need visibility of this buffer as all + // accesses will happen via command buffers. + IREE_HAL_BUFFER_USAGE_TRANSFER = 1 << 1, + + // The buffer can be mapped by the host application for reading and writing. + // + // As mapping may require placement in special address ranges or system + // calls to enable visibility the driver can use the presence (or lack of) + // this flag to perform allocation-type setup and avoid initial mapping + // overhead. + IREE_HAL_BUFFER_USAGE_MAPPING = 1 << 2, + + // The buffer can be provided as an input or output to an executable. + // Buffers of this type may be directly used by drivers during dispatch. + IREE_HAL_BUFFER_USAGE_DISPATCH = 1 << 3, + + // Buffer may be used for any operation. + IREE_HAL_BUFFER_USAGE_ALL = IREE_HAL_BUFFER_USAGE_TRANSFER | + IREE_HAL_BUFFER_USAGE_MAPPING | + IREE_HAL_BUFFER_USAGE_DISPATCH, +} iree_hal_buffer_usage_t; + +//===----------------------------------------------------------------------===// +// iree::hal::Buffer +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Returns a reference to a subspan of the |buffer|. +// If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer after +// |byte_offset| (possibly 0) will be selected. +// +// The parent buffer will remain alive for the lifetime of the subspan +// returned. If the subspan is a small portion this may cause additional +// memory to remain allocated longer than required. +// +// Returns the given |buffer| if the requested span covers the entire range. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan( + iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer); + +// Retains the given |buffer| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_retain(iree_hal_buffer_t* buffer); + +// Releases the given |buffer| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_release(iree_hal_buffer_t* buffer); + +// Returns the size in bytes of the buffer. +IREE_API_EXPORT iree_device_size_t IREE_API_CALL +iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer); + +// Sets a range of the buffer to binary zero. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length); + +// Sets a range of the buffer to the given value. +// Only |pattern_length| values with 1, 2, or 4 bytes are supported. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, const void* pattern, + iree_host_size_t pattern_length); + +// Reads a block of data from the buffer at the given offset. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data( + iree_hal_buffer_t* buffer, iree_device_size_t source_offset, + void* target_buffer, iree_device_size_t data_length); + +// Writes a block of byte data into the buffer at the given offset. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data( + iree_hal_buffer_t* buffer, iree_device_size_t target_offset, + const void* source_buffer, iree_device_size_t data_length); + +// Maps the buffer to be accessed as a host pointer into |out_mapped_memory|. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map( + iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, + iree_device_size_t element_offset, iree_device_size_t element_length, + iree_hal_mapped_memory_t* out_mapped_memory); + +// Unmaps the buffer as was previously mapped to |mapped_memory|. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_unmap( + iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::hal::HeapBuffer +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Allocates a zeroed host heap buffer of the given size. +// The buffer contents will be allocated with |contents_allocator| while +// |allocator| is used for the iree_hal_buffer_t. +// +// Returns a buffer allocated with malloc that may not be usable by devices +// without copies. |memory_type| should be set to +// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate( + iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, + iree_host_size_t allocation_size, iree_allocator_t contents_allocator, + iree_allocator_t allocator, iree_hal_buffer_t** out_buffer); + +// Allocates a host heap buffer with a copy of the given data. +// The buffer contents will be allocated with |contents_allocator| while +// |allocator| is used for the iree_hal_buffer_t. +// +// Returns a buffer allocated with malloc that may not be usable by devices +// without copies. |memory_type| should be set to +// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy( + iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, + iree_hal_memory_access_t allowed_access, iree_byte_span_t contents, + iree_allocator_t contents_allocator, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer); + +// Wraps an existing host heap allocation in a buffer. +// Ownership of the host allocation remains with the caller and the memory +// must remain valid for so long as the iree_hal_buffer_t may be in use. +// +// Returns a buffer allocated with malloc that may not be usable by devices +// without copies. |memory_type| should be set to +// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( + iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t usage, iree_byte_span_t contents, + iree_allocator_t allocator, iree_hal_buffer_t** out_buffer); + +// TODO(benvanik): add a wrap that takes an allocator just for the buffer. + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::hal::BufferView +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Creates a buffer view with the given |buffer|, which may be nullptr. +// |out_buffer_view| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_create( + iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size, + iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view); + +// Retains the given |buffer_view| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view); + +// Releases the given |buffer_view| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view); + +// Sets the buffer view to point at the new |buffer| with the given metadata. +// To clear a buffer_view to empty use iree_hal_buffer_view_reset. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_assign( + iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer, + iree_shape_t shape, int8_t element_size); + +// Resets the buffer view to have an empty buffer and shape. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view); + +// Returns the buffer underlying the buffer view. +// The caller must retain the returned buffer if they want to continue using it. +IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL +iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view); + +// Returns the shape of the buffer view in |out_shape|. +// If there is not enough space in |out_shape| to store all dimensions then +// IREE_STATUS_OUT_OF_RANGE is returned and |out_shape|.rank is set to the rank. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape( + const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape); + +// Returns the size of each element in the buffer view in bytes. +IREE_API_EXPORT int8_t IREE_API_CALL +iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::hal::Semaphore +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Retains the given |semaphore| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore); + +// Releases the given |semaphore| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::hal::Fence +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Retains the given |fence| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_fence_retain(iree_hal_fence_t* fence); + +// Releases the given |fence| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_fence_release(iree_hal_fence_t* fence); + +#endif // IREE_API_NO_PROTOTYPES + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_API_H_
diff --git a/hal/api_detail.h b/hal/api_detail.h new file mode 100644 index 0000000..09b6230 --- /dev/null +++ b/hal/api_detail.h
@@ -0,0 +1,42 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Additional definitions for internal users of the api. This should only +// be included from internal implementation files. + +#ifndef IREE_HAL_API_DETAIL_H_ +#define IREE_HAL_API_DETAIL_H_ + +#include "hal/api.h" +#include "hal/buffer_view.h" + +namespace iree { +namespace hal { + +// In the API, buffer views are ref objects, and this allows parts of the +// API outside of the HAL to work with them. +struct iree_hal_buffer_view : public RefObject<iree_hal_buffer_view> { + BufferView impl; + iree_allocator_t allocator; + + static void Delete(iree_hal_buffer_view* ptr) { + ptr->impl.buffer.reset(); + ptr->allocator.free(ptr->allocator.self, ptr); + } +}; + +} // namespace hal +} // namespace iree + +#endif
diff --git a/hal/buffer.cc b/hal/buffer.cc new file mode 100644 index 0000000..bfac15d --- /dev/null +++ b/hal/buffer.cc
@@ -0,0 +1,549 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/buffer.h" + +#include <algorithm> +#include <atomic> +#include <cstdint> +#include <cstring> +#include <sstream> + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/variant.h" +#include "base/status.h" + +namespace iree { +namespace hal { + +#if HAS_IREE_BUFFER_DEBUG_NAME +namespace { +// Used for diagnostic purposes only as a default buffer name. +std::atomic<int> next_buffer_id_{0}; +} // namespace +#endif // HAS_IREE_BUFFER_DEBUG_NAME + +std::string MemoryTypeString(MemoryTypeBitfield memory_type) { + return FormatBitfieldValue(memory_type, + { + // Combined: + {MemoryType::kHostLocal, "kHostLocal"}, + {MemoryType::kDeviceLocal, "kDeviceLocal"}, + // Separate: + {MemoryType::kTransient, "kTransient"}, + {MemoryType::kHostVisible, "kHostVisible"}, + {MemoryType::kHostCoherent, "kHostCoherent"}, + {MemoryType::kHostCached, "kHostCached"}, + {MemoryType::kDeviceVisible, "kDeviceVisible"}, + }); +} + +std::string MemoryAccessString(MemoryAccessBitfield memory_access) { + return FormatBitfieldValue(memory_access, + { + // Combined: + {MemoryAccess::kAll, "kAll"}, + {MemoryAccess::kDiscardWrite, "kDiscardWrite"}, + // Separate: + {MemoryAccess::kRead, "kRead"}, + {MemoryAccess::kWrite, "kWrite"}, + {MemoryAccess::kDiscard, "kDiscard"}, + }); +} + +std::string BufferUsageString(BufferUsageBitfield buffer_usage) { + return FormatBitfieldValue(buffer_usage, + { + // Combined: + {BufferUsage::kAll, "kAll"}, + // Separate: + {BufferUsage::kConstant, "kConstant"}, + {BufferUsage::kTransfer, "kTransfer"}, + {BufferUsage::kMapping, "kMapping"}, + {BufferUsage::kDispatch, "kDispatch"}, + }); +} + +// Special router for buffers that just reference other buffers. +// We keep this out of the base Buffer so that it's a bit easier to track +// delegation. +class SubspanBuffer : public Buffer { + public: + SubspanBuffer(ref_ptr<Buffer> parent_buffer, device_size_t byte_offset, + device_size_t byte_length) + : Buffer(parent_buffer->allocator(), parent_buffer->memory_type(), + parent_buffer->allowed_access(), parent_buffer->usage(), + parent_buffer->allocation_size(), byte_offset, byte_length) { + allocated_buffer_ = parent_buffer.get(); + parent_buffer_ = std::move(parent_buffer); + } + + protected: + Status FillImpl(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length) override { + return parent_buffer_->FillImpl(byte_offset, byte_length, pattern, + pattern_length); + } + + Status ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) override { + return parent_buffer_->ReadDataImpl(source_offset, data, data_length); + } + + Status WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) override { + return parent_buffer_->WriteDataImpl(target_offset, data, data_length); + } + + Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) override { + return parent_buffer_->CopyDataImpl(target_offset, source_buffer, + source_offset, data_length); + } + + Status MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) override { + return parent_buffer_->MapMemoryImpl(mapping_mode, memory_access, + local_byte_offset, local_byte_length, + out_data); + } + + Status UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data) override { + return parent_buffer_->UnmapMemoryImpl(local_byte_offset, local_byte_length, + data); + } + + Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override { + return parent_buffer_->InvalidateMappedMemoryImpl(local_byte_offset, + local_byte_length); + } + + Status FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override { + return parent_buffer_->FlushMappedMemoryImpl(local_byte_offset, + local_byte_length); + } +}; + +// static +StatusOr<ref_ptr<Buffer>> Buffer::Subspan(const ref_ptr<Buffer>& buffer, + device_size_t byte_offset, + device_size_t byte_length) { + RETURN_IF_ERROR(buffer->CalculateRange(byte_offset, byte_length, &byte_offset, + &byte_length)); + if (byte_offset == 0 && byte_length == buffer->byte_length()) { + // Asking for the same buffer. + return add_ref(buffer); + } + + // To avoid heavy nesting of subspans that just add indirection we go to the + // parent buffer directly. If we wanted better accounting (to track where + // buffers came from) we'd want to avoid this but I'm not sure that's worth + // the super deep indirection that could arise. + if (buffer->allocated_buffer() != buffer.get()) { + CHECK(buffer->parent_buffer_); + return Buffer::Subspan(buffer->parent_buffer_, byte_offset, byte_length); + } else { + return {make_ref<SubspanBuffer>(add_ref(buffer), byte_offset, byte_length)}; + } +} + +// static +Buffer::Overlap Buffer::TestOverlap( + Buffer* lhs_buffer, device_size_t lhs_offset, device_size_t lhs_length, + Buffer* rhs_buffer, device_size_t rhs_offset, device_size_t rhs_length) { + if (lhs_buffer->allocated_buffer() != rhs_buffer->allocated_buffer()) { + // Not even the same buffers. + return Overlap::kDisjoint; + } + // Resolve offsets into the underlying allocation. + device_size_t lhs_alloc_offset = lhs_buffer->byte_offset() + lhs_offset; + device_size_t rhs_alloc_offset = rhs_buffer->byte_offset() + rhs_offset; + device_size_t lhs_alloc_length = lhs_length == kWholeBuffer + ? lhs_buffer->byte_length() - lhs_offset + : lhs_length; + device_size_t rhs_alloc_length = rhs_length == kWholeBuffer + ? rhs_buffer->byte_length() - rhs_offset + : rhs_length; + if (!lhs_alloc_length || !rhs_alloc_length) { + return Overlap::kDisjoint; + } + if (lhs_alloc_offset == rhs_alloc_offset && + lhs_alloc_length == rhs_alloc_length) { + return Overlap::kComplete; + } + return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset && + rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset + ? Overlap::kPartial + : Overlap::kDisjoint; +} + +// static +bool Buffer::DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, + device_size_t lhs_length, Buffer* rhs_buffer, + device_size_t rhs_offset, device_size_t rhs_length) { + return TestOverlap(lhs_buffer, lhs_offset, lhs_length, rhs_buffer, rhs_offset, + rhs_length) != Overlap::kDisjoint; +} + +Buffer::Buffer(Allocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, + device_size_t allocation_size, device_size_t byte_offset, + device_size_t byte_length) + : allocated_buffer_(const_cast<Buffer*>(this)), + allocator_(allocator), + memory_type_(memory_type), + allowed_access_(allowed_access), + usage_(usage), + allocation_size_(allocation_size), + byte_offset_(byte_offset), + byte_length_(byte_length) { +#if HAS_IREE_BUFFER_DEBUG_NAME + // Default name for logging. + // It'd be nice to defer this until it's required but that would require + // synchronization or something. + const char* debug_name_prefix = ""; + if ((memory_type_ & MemoryType::kHostLocal) == MemoryType::kHostLocal) { + debug_name_prefix = "host_buffer_"; + } else if ((memory_type_ & MemoryType::kDeviceLocal) == + MemoryType::kDeviceLocal) { + // TODO(benvanik): include allocator ID to differentiate devices. + debug_name_prefix = "device_buffer_"; + } + debug_name_ = absl::StrCat(debug_name_prefix, next_buffer_id_++); +#endif // HAS_IREE_BUFFER_DEBUG_NAME +} + +Buffer* Buffer::allocated_buffer() const noexcept { + Buffer* allocated_buffer = allocated_buffer_; + while (allocated_buffer != this && + allocated_buffer != allocated_buffer->allocated_buffer()) { + allocated_buffer = allocated_buffer->allocated_buffer(); + } + return allocated_buffer; +} + +std::string Buffer::DebugString() const { + std::ostringstream stream; + stream << allocated_buffer()->debug_name() << "[" + << (allocation_size() == kWholeBuffer + ? "?" + : std::to_string(allocation_size())) + << "]."; + if (AnyBitSet(memory_type() & MemoryType::kTransient)) stream << "Z"; + if ((memory_type() & MemoryType::kHostLocal) == MemoryType::kHostLocal) { + stream << "h"; + } else { + if (AnyBitSet(memory_type() & MemoryType::kHostVisible)) stream << "v"; + if (AnyBitSet(memory_type() & MemoryType::kHostCoherent)) stream << "x"; + if (AnyBitSet(memory_type() & MemoryType::kHostCached)) stream << "c"; + } + if ((memory_type() & MemoryType::kDeviceLocal) == MemoryType::kDeviceLocal) { + stream << "D"; + } else { + if (AnyBitSet(memory_type() & MemoryType::kDeviceVisible)) stream << "V"; + } + stream << "."; + if (AnyBitSet(usage() & BufferUsage::kConstant)) stream << "c"; + if (AnyBitSet(usage() & BufferUsage::kTransfer)) stream << "t"; + if (AnyBitSet(usage() & BufferUsage::kMapping)) stream << "m"; + if (AnyBitSet(usage() & BufferUsage::kDispatch)) stream << "d"; + if (byte_offset_ || byte_length_ != allocation_size_) { + stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) + << ")"; + } + return stream.str(); +} + +std::string Buffer::DebugStringShort() const { + // TODO(benvanik): figure out what's most useful here. Maybe a long variant? + std::ostringstream stream; + stream << allocated_buffer()->debug_name() << "[" + << (allocation_size() == kWholeBuffer + ? "?" + : std::to_string(allocation_size())) + << "]"; + if (byte_offset_ || byte_length_ != allocation_size_) { + stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) + << ")"; + } + return stream.str(); +} + +Status Buffer::ValidateCompatibleMemoryType( + MemoryTypeBitfield memory_type) const { + if ((memory_type_ & memory_type) != memory_type) { + // Missing one or more bits. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Buffer memory type is not compatible with the requested " + "operation; buffer has " + << MemoryTypeString(memory_type_) << ", operation requires " + << MemoryTypeString(memory_type); + } + return OkStatus(); +} + +Status Buffer::ValidateAccess(MemoryAccessBitfield memory_access) const { + if (!AnyBitSet(memory_access & + (MemoryAccess::kRead | MemoryAccess::kWrite))) { + // No actual access bits defined. + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Memory access must specify one or more of kRead or kWrite"; + } else if ((allowed_access_ & memory_access) != memory_access) { + // Bits must match exactly. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "The buffer does not support the requested access type; buffer " + "allows " + << MemoryAccessString(allowed_access_) << ", operation requires " + << MemoryAccessString(memory_access); + } + return OkStatus(); +} + +Status Buffer::ValidateUsage(BufferUsageBitfield usage) const { + if ((usage_ & usage) != usage) { + // Missing one or more bits. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Requested usage was not specified when the buffer was " + "allocated; buffer allows " + << BufferUsageString(usage_) << ", operation requires " + << BufferUsageString(usage); + } + return OkStatus(); +} + +Status Buffer::CalculateRange(device_size_t base_offset, + device_size_t max_length, device_size_t offset, + device_size_t length, + device_size_t* out_adjusted_offset, + device_size_t* out_adjusted_length) { + // Check if the start of the range runs off the end of the buffer. + if (offset > max_length) { + *out_adjusted_offset = 0; + if (out_adjusted_length) *out_adjusted_length = 0; + return OutOfRangeErrorBuilder(IREE_LOC) + << "Attempted to access an address off the end of the valid buffer " + "range (offset=" + << offset << ", length=" << length + << ", buffer byte_length=" << max_length << ")"; + } + + // Handle length as kWholeBuffer by adjusting it (if allowed). + if (length == kWholeBuffer && !out_adjusted_length) { + *out_adjusted_offset = 0; + return InvalidArgumentErrorBuilder(IREE_LOC) + << "kWholeBuffer may only be used with buffer ranges, not external " + "pointer ranges"; + } + + // Calculate the real ranges adjusted for our region within the allocation. + device_size_t adjusted_offset = base_offset + offset; + device_size_t adjusted_length = + length == kWholeBuffer ? max_length - offset : length; + if (adjusted_length == 0) { + // Fine to have a zero length. + *out_adjusted_offset = adjusted_offset; + if (out_adjusted_length) *out_adjusted_length = adjusted_length; + return OkStatus(); + } + + // Check if the end runs over the allocation. + device_size_t end = offset + adjusted_length - 1; + if (end >= max_length) { + *out_adjusted_offset = 0; + if (out_adjusted_length) *out_adjusted_length = 0; + return OutOfRangeErrorBuilder(IREE_LOC) + << "Attempted to access an address outside of the valid buffer " + "range (offset=" + << offset << ", adjusted_length=" << adjusted_length + << ", end=" << end << ", buffer byte_length=" << max_length << ")"; + } + + *out_adjusted_offset = adjusted_offset; + if (out_adjusted_length) *out_adjusted_length = adjusted_length; + return OkStatus(); +} + +Status Buffer::CalculateRange(device_size_t offset, device_size_t length, + device_size_t* out_adjusted_offset, + device_size_t* out_adjusted_length) const { + return CalculateRange(byte_offset_, byte_length_, offset, length, + out_adjusted_offset, out_adjusted_length); +} + +Status Buffer::CalculateLocalRange(device_size_t max_length, + device_size_t offset, device_size_t length, + device_size_t* out_adjusted_offset, + device_size_t* out_adjusted_length) { + return CalculateRange(0, max_length, offset, length, out_adjusted_offset, + out_adjusted_length); +} + +Status Buffer::Fill(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length) { + // If not host visible we'll need to issue command buffers. + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + RETURN_IF_ERROR( + CalculateRange(byte_offset, byte_length, &byte_offset, &byte_length)); + if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Fill patterns must be 1, 2, or 4 bytes"; + } + if ((byte_offset % pattern_length) != 0 || + (byte_length % pattern_length) != 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Attempting to fill a range with " << pattern_length + << " byte values that is not " + "aligned (offset=" + << byte_offset << ", length=" << byte_length << ")"; + } + if (byte_length == 0) { + return OkStatus(); // No-op. + } + const uint32_t kZero = 0; + if (std::memcmp(pattern, &kZero, pattern_length) == 0) { + // We can turn all-zero values into single-byte fills as that can be much + // faster on devices (doing a fill8 vs fill32). + pattern_length = 1; + } + return FillImpl(byte_offset, byte_length, pattern, pattern_length); +} + +Status Buffer::ReadData(device_size_t source_offset, void* data, + device_size_t data_length) { + // If not host visible we'll need to issue command buffers. + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + RETURN_IF_ERROR(CalculateRange(source_offset, data_length, &source_offset)); + if (data_length == 0) { + return OkStatus(); // No-op. + } + return ReadDataImpl(source_offset, data, data_length); +} + +Status Buffer::WriteData(device_size_t target_offset, const void* data, + device_size_t data_length) { + // If not host visible we'll need to issue command buffers. + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + RETURN_IF_ERROR(CalculateRange(target_offset, data_length, &target_offset)); + if (data_length == 0) { + return OkStatus(); // No-op. + } + return WriteDataImpl(target_offset, data, data_length); +} + +Status Buffer::CopyData(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) { + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + RETURN_IF_ERROR( + source_buffer->ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(source_buffer->ValidateAccess(MemoryAccess::kRead)); + RETURN_IF_ERROR(source_buffer->ValidateUsage(BufferUsage::kMapping)); + + // We need to validate both buffers. + device_size_t source_data_length = data_length; + device_size_t target_data_length = data_length; + device_size_t adjusted_source_offset; + RETURN_IF_ERROR(source_buffer->CalculateRange( + source_offset, source_data_length, &adjusted_source_offset, + &source_data_length)); + RETURN_IF_ERROR(CalculateRange(target_offset, target_data_length, + &target_offset, &target_data_length)); + device_size_t adjusted_data_length; + if (data_length == kWholeBuffer) { + // Whole buffer copy requested - that could mean either, so take the min. + adjusted_data_length = std::min(source_data_length, target_data_length); + } else { + // Specific length requested - validate that we have matching lengths. + CHECK_EQ(source_data_length, target_data_length); + adjusted_data_length = source_data_length; + } + + // Elide zero length copies. + if (adjusted_data_length == 0) { + return OkStatus(); + } + + // Check for overlap. + if (this == source_buffer && + adjusted_source_offset <= target_offset + adjusted_data_length && + target_offset <= adjusted_source_offset + adjusted_data_length) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Source and target ranges overlap within the same buffer"; + } + + return CopyDataImpl(target_offset, source_buffer, source_offset, + adjusted_data_length); +} + +Status Buffer::MapMemory(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t* byte_offset, device_size_t* byte_length, + void** out_data) { + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(ValidateAccess(memory_access)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + RETURN_IF_ERROR( + CalculateRange(*byte_offset, *byte_length, byte_offset, byte_length)); + *out_data = nullptr; + return MapMemoryImpl(mapping_mode, memory_access, *byte_offset, *byte_length, + out_data); +} + +Status Buffer::UnmapMemory(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data) { + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + // NOTE: local_byte_offset/local_byte_length are already adjusted. + return UnmapMemoryImpl(local_byte_offset, local_byte_length, data); +} + +Status Buffer::InvalidateMappedMemory(device_size_t local_byte_offset, + device_size_t local_byte_length) { + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); + if (AnyBitSet(memory_type_ & MemoryType::kHostCoherent)) { + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Buffer memory type is coherent and invalidation is not required"; + } + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + // NOTE: local_byte_offset/local_byte_length are already adjusted. + return InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length); +} + +Status Buffer::FlushMappedMemory(device_size_t local_byte_offset, + device_size_t local_byte_length) { + RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible | + MemoryType::kHostCached)); + RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); + // NOTE: local_byte_offset/local_byte_length are already adjusted. + return FlushMappedMemoryImpl(local_byte_offset, local_byte_length); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/buffer.h b/hal/buffer.h new file mode 100644 index 0000000..5de808d --- /dev/null +++ b/hal/buffer.h
@@ -0,0 +1,903 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Allocated memory buffer wrapper type and utilities. +// +// Buffers are the basic unit of memory used by the inference system. They may +// be allocated such that they are accessible from the host (normal C++ code +// running on the main CPU), a particular device (such as an accelerator) or +// family of devices, or from some mix of all of those. +// +// The type of memory a buffer is allocated within has implications on it's +// performance and lifetime. For example if an application attempts to use a +// host-allocated buffer (MemoryType::kHostLocal) on an accelerator with +// discrete memory the accelerator may either be unable to access the memory or +// take a non-trivial performance hit when attempting to do so (involving +// setting up kernel mappings, doing DMA transfers, etc). Likewise, trying to +// access a device-allocated buffer (MemoryType::kDeviceLocal) may incur similar +// overhead or not be possible at all. This may be due to restrictions in the +// memory visibility, address spaces, mixed endianness or pointer widths, +// and other weirdness. +// +// The memory types (defined by a bitfield of MemoryType values) that a +// particular context (host or device) may use vary from device to device and +// must be queried by the application when allocating buffers. It's strongly +// recommended that the most specific memory type be set as possible. For +// example allocating a buffer with MemoryType::kHostCoherent even when it will +// never be used in a way that requires coherency may occupy address space +// reservations or memory mapping that would otherwise not be needed. +// +// As buffers may sometimes not be accessible from the host the base Buffer type +// does not allow for direct void* access and instead buffers must be either +// manipulated using utility functions (such as ReadData or WriteData) or by +// mapping them into a host-accessible address space via MapMemory. Buffer must +// be unmapped before any command may use it. +// +// Buffers may map (roughly) 1:1 with an allocation either from the host heap or +// a device. Buffer::Subspan can be used to reference subspans of buffers like +// absl::Span - though unlike absl::Span the returned Buffer holds a reference +// to the parent buffer. + +#ifndef IREE_HAL_BUFFER_H_ +#define IREE_HAL_BUFFER_H_ + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + +#include "absl/types/span.h" +#include "absl/types/variant.h" +#include "base/bitfield.h" +#include "base/logging.h" +#include "base/source_location.h" +#include "base/status.h" +#include "hal/resource.h" + +// Only enable debug names in non-opt modes (unless the user forces it on). +#if !defined(NDEBUG) && !defined(HAS_IREE_BUFFER_DEBUG_NAME) +#define HAS_IREE_BUFFER_DEBUG_NAME 1 +#endif // !NDEBUG + +namespace iree { + +// std::size_t equivalent that is the size as used on device. +// As the device may have a larger memory address space than the host we treat +// all byte offsets as this type instead of the host-specified size_t. +using device_size_t = uint64_t; + +// When used as a length value in functions causes the length to be the entire +// remaining buffer from the specified offset. +constexpr device_size_t kWholeBuffer = ~0ull; + +} // namespace iree + +namespace iree { +namespace hal { + +class Allocator; +template <typename T> +class MappedMemory; + +// A bitfield specifying properties for a memory type. +enum class MemoryType : uint32_t { + kNone = 0, + + // Memory is lazily allocated by the device and only exists transiently. + // This is the optimal mode for memory used only within a single command + // buffer. Transient buffers, even if they have kHostVisible set, should be + // treated as device-local and opaque as they may have no memory attached to + // them outside of the time they are being evaluated on devices. + // + // This flag can be treated as a hint in most cases; allocating a buffer with + // it set _may_ return the same as if it had not be set. Certain allocation + // routines may use the hint to more tightly control reuse or defer wiring the + // memory. + kTransient = 1 << 0, + + // Memory allocated with this type can be mapped for host access using + // Buffer::MapMemory. + kHostVisible = 1 << 1, + + // The host cache management commands MappedMemory::Flush and + // MappedMemory::Invalidate are not needed to flush host writes + // to the device or make device writes visible to the host, respectively. + kHostCoherent = 1 << 2, + + // Memory allocated with this type is cached on the host. Host memory + // accesses to uncached memory are slower than to cached memory, however + // uncached memory is always host coherent. MappedMemory::Flush must be used + // to ensure the device has visibility into any changes made on the host and + // Invalidate must be used to ensure the host has visibility into any changes + // made on the device. + kHostCached = 1 << 3, + + // Memory is accessible as normal host allocated memory. + kHostLocal = kHostVisible | kHostCoherent, + + // Memory allocated with this type is visible to the device for execution. + // Being device visible does not mean the same thing as kDeviceLocal. Though + // an allocation may be visible to the device and therefore useable for + // execution it may require expensive mapping or implicit transfers. + kDeviceVisible = 1 << 4, + + // Memory allocated with this type is the most efficient for device access. + // Devices may support using memory that is not device local via + // kDeviceVisible but doing so can incur non-trivial performance penalties. + // Device local memory, on the other hand, is guaranteed to be fast for all + // operations. + kDeviceLocal = kDeviceVisible | (1 << 5), +}; +IREE_BITFIELD(MemoryType); +using MemoryTypeBitfield = MemoryType; +std::string MemoryTypeString(MemoryTypeBitfield memory_type); + +// A bitfield specifying how memory will be accessed in a mapped memory region. +enum class MemoryAccess : uint32_t { + // Memory is not mapped. + kNone = 0, + + // Memory will be read. + // If a buffer is only mapped for reading it may still be possible to write to + // it but the results will be undefined (as it may present coherency issues). + kRead = 1 << 0, + + // Memory will be written. + // If a buffer is only mapped for writing it may still be possible to read + // from it but the results will be undefined or incredibly slow (as it may + // be mapped by the driver as uncached). + kWrite = 1 << 1, + + // Memory will be discarded prior to mapping. + // The existing contents will be undefined after mapping and must be written + // to ensure validity. + kDiscard = 1 << 2, + + // Memory will be discarded and completely overwritten in a single operation. + kDiscardWrite = kWrite | kDiscard, + + // Memory may have any operation performed on it. + kAll = kRead | kWrite | kDiscard, +}; +IREE_BITFIELD(MemoryAccess); +using MemoryAccessBitfield = MemoryAccess; +std::string MemoryAccessString(MemoryAccessBitfield memory_access); + +// Bitfield that defines how a buffer is intended to be used. +// Usage allows the driver to appropriately place the buffer for more +// efficient operations of the specified types. +enum class BufferUsage { + kNone = 0, + + // The buffer, once defined, will not be mapped or updated again. + // This should be used for uniform parameter values such as runtime + // constants for executables. Doing so may allow drivers to inline values or + // represent them in command buffers more efficiently (avoiding memory reads + // or swapping, etc). + kConstant = 1 << 0, + + // The buffer can be used as the source or target of a transfer command + // (CopyBuffer, UpdateBuffer, etc). + // + // If |kMapping| is not specified drivers may safely assume that the host + // may never need visibility of this buffer as all accesses will happen via + // command buffers. + kTransfer = 1 << 1, + + // The buffer can be mapped by the host application for reading and writing. + // + // As mapping may require placement in special address ranges or system + // calls to enable visibility the driver can use the presence (or lack of) + // this flag to perform allocation-type setup and avoid initial mapping + // overhead. + kMapping = 1 << 2, + + // The buffer can be provided as an input or output to an executable. + // Buffers of this type may be directly used by drivers during dispatch. + kDispatch = 1 << 3, + + // Buffer may be used for any operation. + kAll = kTransfer | kMapping | kDispatch, +}; +IREE_BITFIELD(BufferUsage); +using BufferUsageBitfield = BufferUsage; +std::string BufferUsageString(BufferUsageBitfield buffer_usage); + +// A memory buffer. +// Buffers have a specific memory_type that is used to describe the capabilities +// and behavior of the backing memory of the buffer. Buffers may be any mix of +// host-accessible, host-coherent, or device-accessible for various usages. +// Depending on these memory types the buffers may be mapped for access on the +// host as memory though certain restrictions may be imposed. +// +// See MemoryType for more information about the types and what operations they +// support. +class Buffer : public Resource { + public: + // Returns a reference to a subspan of the buffer. + // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after + // |byte_offset| (possibly 0) will be selected. + // + // The parent buffer will remain alive for the lifetime of the subspan + // returned. If the subspan is a small portion this may cause additional + // memory to remain allocated longer than required. + // + // Returns the given |buffer| if the requested span covers the entire range. + static StatusOr<ref_ptr<Buffer>> Subspan(const ref_ptr<Buffer>& buffer, + device_size_t byte_offset, + device_size_t byte_length); + + // Overlap test results. + enum class Overlap { + // No overlap between the two buffers. + kDisjoint, + // Partial overlap between the two buffers. + kPartial, + // Complete overlap between the two buffers (they are the same). + kComplete, + }; + + // Tests whether the given buffers overlap, including support for subspans. + // kWholeBuffer may be used for |lhs_length| and/or |rhs_length| to use the + // lengths of those buffers, respectively. + static Overlap TestOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, + device_size_t lhs_length, Buffer* rhs_buffer, + device_size_t rhs_offset, + device_size_t rhs_length); + + // Returns true if the two buffer ranges overlap at all. + static bool DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, + device_size_t lhs_length, Buffer* rhs_buffer, + device_size_t rhs_offset, device_size_t rhs_length); + + // Disallow copies (as copying requires real work). + Buffer(const Buffer&) = delete; + Buffer& operator=(const Buffer&) = delete; + + ~Buffer() override = default; + +#if HAS_IREE_BUFFER_DEBUG_NAME + // Optionally populated name useful for logging a persistent name for the + // buffer. + absl::string_view debug_name() const { return debug_name_; } + void set_debug_name(std::string debug_name) { + debug_name_ = std::move(debug_name); + } +#else + absl::string_view debug_name() const { return ""; } + void set_debug_name(std::string debug_name) {} +#endif // HAS_IREE_BUFFER_DEBUG_NAME + + // Memory allocator this buffer was allocated from. + // May be nullptr if the buffer has no particular allocator and should be + // assumed to be allocated from the host heap. + constexpr Allocator* allocator() const { + return allocated_buffer_ == this ? allocator_ + : allocated_buffer_->allocator(); + } + + // Memory type this buffer is allocated from. + MemoryTypeBitfield memory_type() const { return memory_type_; } + + // Memory access operations allowed on the buffer. + MemoryAccessBitfield allowed_access() const { return allowed_access_; } + + // Bitfield describing how the buffer is to be used. + BufferUsageBitfield usage() const { return usage_; } + + // Returns the underlying buffer that represents the allocated memory for the + // Buffer. In most cases this is the buffer itself but for buffer subspan + // references it will point to the parent buffer. + Buffer* allocated_buffer() const noexcept; + + // Size of the resource memory allocation in bytes. + // This may be rounded up from the originally requested size or the ideal + // size for the resource based on device restrictions. + constexpr device_size_t allocation_size() const { + return allocated_buffer_ == this ? allocation_size_ + : allocated_buffer_->allocation_size(); + } + + // Range within the underlying allocation this buffer occupies. + // For buffers that map 1:1 with an allocation this should be + // [0, allocation_size()), however may still differ if the allocation needed + // to be aligned. + // + // The offset is most often manipulated by Subspan, however it's important to + // note that the offset may not be what was passed to Subspan as it refers to + // the offset in the original ancestor buffer, not the buffer from which the + // subspan was taken. + constexpr device_size_t byte_offset() const noexcept { return byte_offset_; } + constexpr device_size_t byte_length() const noexcept { return byte_length_; } + + // TODO(benvanik): add debug_name. + + // Returns a longer debug string describing the buffer and its attributes. + std::string DebugString() const; + // Returns a short debug string describing the buffer. + std::string DebugStringShort() const; + + // Sets a range of the buffer to the given value. + // This requires that the resource was allocated with + // MemoryType::kHostVisible and BufferUsage::kMapping. + // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after + // |byte_offset| (possibly 0) will be filled. + // + // The |byte_offset| and |byte_length| must be aligned to the size of the fill + // value. Multi-byte values will be written in host order for host buffers and + // device order for device buffers. + // + // Only |pattern_length| values with 1, 2, or 4 bytes are supported. + // + // Fails if the write could not be performed; either the bounds are out of + // range or the memory type does not support writing in this way. + Status Fill(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length); + template <typename T> + Status Fill8(device_size_t byte_offset, device_size_t byte_length, T value); + template <typename T> + Status Fill16(device_size_t byte_offset, device_size_t byte_length, T value); + template <typename T> + Status Fill32(device_size_t byte_offset, device_size_t byte_length, T value); + template <typename T> + Status Fill8(T value); + template <typename T> + Status Fill16(T value); + template <typename T> + Status Fill32(T value); + + // Reads a block of byte data from the resource at the given offset. + // This requires that the resource was allocated with + // MemoryType::kHostVisible and BufferUsage::kMapping. + // + // Fails if the read could not be performed; either the bounds are out of + // range or the memory type does not support reading in this way. + Status ReadData(device_size_t source_offset, void* data, + device_size_t data_length); + + // Writes a block of byte data into the resource at the given offset. + // This requires that the resource was allocated with + // MemoryType::kHostVisible and BufferUsage::kMapping. + // + // Fails if the write could not be performed; either the bounds are out of + // range or the memory type does not support writing in this way. + Status WriteData(device_size_t target_offset, const void* data, + device_size_t data_length); + + // Copies data from the provided source_buffer into the buffer. + // This requires that the resource was allocated with + // MemoryType::kHostVisible and BufferUsage::kMapping. + // The source and destination may be the same buffer but the ranges must not + // overlap (a la memcpy). + // + // Fails if the write could not be performed; either the bounds are out of + // range or the memory type does not support writing in this way. + Status CopyData(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, device_size_t data_length); + Status CopyData(device_size_t target_offset, Buffer* source_buffer) { + return CopyData(target_offset, source_buffer, 0, kWholeBuffer); + } + + // Maps the resource memory for direct access from the host. + // This requires that the resource was allocated with + // MemoryType::kHostVisible and BufferUsage::kMapping. + // + // If MemoryType::kHostCoherent was not specified then explicit + // Invalidate and Flush calls must be used to control visibility of the data + // on the device. If MemoryType::kHostCached is not set callers must not + // attempt to read from the mapped memory as doing so may produce undefined + // results and/or ultra slow reads. + // + // If the MemoryAccess::kDiscard bit is set when mapping for writes the caller + // guarantees that they will be overwriting all data in the mapped range. This + // is used as a hint to the device that the prior contents are no longer + // required and can enable optimizations that save on synchronization and + // readback. Note however that it is strictly a hint and the contents are not + // guaranteed to be zeroed during mapping. + // + // This allows mapping the memory as a C++ type. Care must be taken to ensure + // the data layout in C++ matches the expected data layout in the executables + // that consume this data. For simple primitives like uint8_t or float this is + // usually not a problem however struct packing may have many restrictions. + // + // The returned mapping should be unmapped when it is no longer required. + // Unmapping does not implicitly flush. + // + // Fails if the memory could not be mapped due to mapping exhaustion, invalid + // arguments, or unsupported memory types. + // + // Example: + // ASSIGN_OR_RETURN(auto mapping, buffer->MapForRead<MyStruct>()); + // mapping[5].foo = 3; + // std::memcpy(mapping.data(), source_data, mapping.size()); + // mapping.reset(); + template <typename T> + StatusOr<MappedMemory<T>> MapMemory( + MemoryAccessBitfield memory_access, device_size_t element_offset = 0, + device_size_t element_length = kWholeBuffer); + + protected: + template <typename T> + friend class MappedMemory; + + // Defines the mode of a MapMemory operation. + enum class MappingMode { + // The call to MapMemory will always be matched with UnmapMemory. + kScoped, + }; + + Buffer(Allocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, + device_size_t allocation_size, device_size_t byte_offset, + device_size_t byte_length); + + // Allows subclasses to override the allowed access bits. + // This should only be done when known safe by the allocation scheme. + void set_allowed_access(MemoryAccessBitfield allowed_access) { + allowed_access_ = allowed_access; + } + + // Sets a range of the buffer to the given value. + // State and parameters have already been validated. For the >8bit variants + // the offset and length have already been validated to be aligned to the + // natural alignment of the type. + virtual Status FillImpl(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, + device_size_t pattern_length) = 0; + + // Reads a block of byte data from the resource at the given offset. + // State and parameters have already been validated. + virtual Status ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) = 0; + + // Writes a block of byte data into the resource at the given offset. + // State and parameters have already been validated. + virtual Status WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) = 0; + + // Copies a block of byte data into the resource at the given offset. + // State and parameters have already been validated. + virtual Status CopyDataImpl(device_size_t target_offset, + Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) = 0; + + // Maps memory directly. + // The output data pointer will be properly aligned to the start of the data. + // |local_byte_offset| and |local_byte_length| are the adjusted values that + // should map into the local space of the buffer. + // + // Fails if the memory could not be mapped (invalid access type, invalid + // range, or unsupported memory type). + // State and parameters have already been validated. + virtual Status MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) = 0; + + // Unmaps previously mapped memory. + // No-op if the memory is not mapped. As this is often used in destructors + // we can't rely on failures here propagating with anything but CHECK/DCHECK. + // State and parameters have already been validated. + virtual Status UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, + void* data) = 0; + + // Invalidates ranges of non-coherent memory from the host caches. + // Use this before reading from non-coherent memory. + // This guarantees that device writes to the memory ranges provided are + // visible on the host. + // This is only required for memory types without kHostCoherent set. + // State and parameters have already been validated. + virtual Status InvalidateMappedMemoryImpl( + device_size_t local_byte_offset, device_size_t local_byte_length) = 0; + + // Flushes ranges of non-coherent memory from the host caches. + // Use this after writing to non-coherent memory. + // This guarantees that host writes to the memory ranges provided are made + // available for device access. + // This is only required for memory types without kHostCoherent set. + // State and parameters have already been validated. + virtual Status FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) = 0; + + // Validates the given buffer range and adjusts the offset and length if the + // provided length is kWholeBuffer or the buffer is offset within its + // allocation. This calculates the range in the given domain without adjusting + // to any particular buffer base offsets. + static Status CalculateLocalRange(device_size_t max_length, + device_size_t offset, device_size_t length, + device_size_t* out_adjusted_offset, + device_size_t* out_adjusted_length); + + private: + friend class Allocator; + + // This is not great and deserves cleanup. + friend class DeferredBuffer; + friend class SubspanBuffer; + friend class HeapBuffer; + + // Maps memory directly. + // The byte offset and byte length may be adjusted for device alignment. + // The output data pointer will be properly aligned to the start of the data. + // Fails if the memory could not be mapped (invalid access type, invalid + // range, or unsupported memory type). + Status MapMemory(MappingMode mapping_mode, MemoryAccessBitfield memory_access, + device_size_t* byte_offset, device_size_t* byte_length, + void** out_data); + + // Unmaps previously mapped memory. + // No-op if the memory is not mapped. As this is often used in destructors + // we can't rely on failures here propagating with anything but CHECK/DCHECK. + Status UnmapMemory(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data); + + // Invalidates ranges of non-coherent memory from the host caches. + // Use this before reading from non-coherent memory. + // This guarantees that device writes to the memory ranges provided are + // visible on the host. + // This is only required for memory types without kHostCoherent set. + Status InvalidateMappedMemory(device_size_t local_byte_offset, + device_size_t local_byte_length); + + // Flushes ranges of non-coherent memory from the host caches. + // Use this after writing to non-coherent memory. + // This guarantees that host writes to the memory ranges provided are made + // available for device access. + // This is only required for memory types without kHostCoherent set. + Status FlushMappedMemory(device_size_t local_byte_offset, + device_size_t local_byte_length); + + // Returns a failure if the memory type the buffer was allocated from is not + // compatible with the given type. + Status ValidateCompatibleMemoryType(MemoryTypeBitfield memory_type) const; + // Returns a failure if the buffer memory type or usage disallows the given + // access type. + Status ValidateAccess(MemoryAccessBitfield memory_access) const; + // Returns a failure if the buffer was not allocated for the given usage. + Status ValidateUsage(BufferUsageBitfield usage) const; + // Validates the given buffer range and optionally adjusts the offset and + // length if the provided length is kWholeBuffer or the buffer is offset + // within its allocation. + static Status CalculateRange(device_size_t base_offset, + device_size_t max_length, device_size_t offset, + device_size_t length, + device_size_t* out_adjusted_offset, + device_size_t* out_adjusted_length = nullptr); + Status CalculateRange(device_size_t offset, device_size_t length, + device_size_t* out_adjusted_offset, + device_size_t* out_adjusted_length = nullptr) const; + + // Points to either this or parent_buffer_.get(). + Buffer* allocated_buffer_ = nullptr; + + Allocator* allocator_ = nullptr; + MemoryTypeBitfield memory_type_ = MemoryType::kNone; + MemoryAccessBitfield allowed_access_ = MemoryAccess::kNone; + BufferUsageBitfield usage_ = BufferUsage::kNone; + + device_size_t allocation_size_ = 0; + device_size_t byte_offset_ = 0; + device_size_t byte_length_ = 0; + +#if HAS_IREE_BUFFER_DEBUG_NAME + // Friendly name for the buffer used in DebugString. May be set by the app or + // auto generated. + std::string debug_name_; +#endif // HAS_IREE_BUFFER_DEBUG_NAME + + // Defined when this buffer is a subspan of another buffer. + ref_ptr<Buffer> parent_buffer_; +}; + +// A memory mapping RAII object. +// The mapping will stay active until it is reset and will retain the buffer. +template <typename T> +class MappedMemory { + public: + using unspecified_bool_type = const T* MappedMemory<T>::*; + + MappedMemory() = default; + MappedMemory(MemoryAccessBitfield access, ref_ptr<Buffer> buffer, + device_size_t byte_offset, device_size_t byte_length, + device_size_t element_size, T* data); + + // Allow moving but disallow copying as the mapping is stateful. + MappedMemory(MappedMemory&& rhs) noexcept; + MappedMemory& operator=(MappedMemory&& rhs) noexcept; + MappedMemory(const MappedMemory&) = delete; + MappedMemory& operator=(const MappedMemory&) = delete; + + ~MappedMemory(); + + // The buffer resource that this mapping references. + const ref_ptr<Buffer>& buffer() const noexcept { return buffer_; } + // Offset, in bytes, into the resource allocation. + // This value is *informative only*, as it may vary from device to device. + device_size_t byte_offset() const noexcept { return byte_offset_; } + // Length, in bytes, of the resource mapping. + // This may be larger than the originally requested length due to alignment. + // This value is *informative only*, as it may vary from device to device. + device_size_t byte_length() const noexcept { return byte_length_; } + + // True if the mapping is empty. + bool empty() const noexcept { return element_size_ == 0; } + // The size of the mapping as requested in elements. + size_t size() const noexcept { return static_cast<size_t>(element_size_); } + + // Returns a read-only pointer to the mapped memory. + // This will be nullptr if the mapping failed or the mapping is not readable. + const T* data() const noexcept; + absl::Span<const T> contents() const noexcept { return {data(), size()}; } + + // Returns a mutable pointer to the mapped memory. + // This will be nullptr if the mapping failed or the mapping is not writable. + // If the mapping was not made with read access it may still be possible to + // read from this memory but behavior is undefined. + T* mutable_data() noexcept; + absl::Span<T> mutable_contents() noexcept { return {mutable_data(), size()}; } + + // Equivalent to absl::Span::subspan(). + // May return a 0-length span. + // Fails if the buffer is not mapped or not mapped for the requested access. + StatusOr<absl::Span<const T>> Subspan( + device_size_t element_offset = 0, + device_size_t element_length = kWholeBuffer) const noexcept; + StatusOr<absl::Span<T>> MutableSubspan( + device_size_t element_offset = 0, + device_size_t element_length = kWholeBuffer) noexcept; + + // Accesses an element in the mapped memory. + // Must be called with a valid index in [0, size()). + const T& operator[](device_size_t i) const noexcept { return data_[i]; } + + // Invalidates a range of non-coherent elements from the host caches. + Status Invalidate(device_size_t element_offset = 0, + device_size_t element_length = kWholeBuffer) const; + + // Flushes a range of non-coherent elements from the host caches. + Status Flush(device_size_t element_offset = 0, + device_size_t element_length = kWholeBuffer); + + // Unmaps the mapped memory. + // The memory will not be implicitly flushed when unmapping. + void reset(); + + private: + Status ValidateAccess(MemoryAccessBitfield memory_access) const; + Status CalculateDataRange(device_size_t element_offset, + device_size_t element_length, + device_size_t* out_adjusted_element_offset, + device_size_t* out_adjusted_element_length) const; + + MemoryAccessBitfield access_ = MemoryAccess::kNone; + ref_ptr<Buffer> buffer_; + device_size_t byte_offset_ = 0; + device_size_t byte_length_ = 0; + device_size_t element_size_ = 0; + T* data_ = nullptr; +}; + +// Inline functions and template definitions follow: + +template <typename T> +Status Buffer::Fill8(device_size_t byte_offset, device_size_t byte_length, + T value) { + auto sized_value = reinterpret_cast<uint8_t*>(&value); + return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); +} + +template <typename T> +Status Buffer::Fill16(device_size_t byte_offset, device_size_t byte_length, + T value) { + auto sized_value = reinterpret_cast<uint16_t*>(&value); + return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); +} + +template <typename T> +Status Buffer::Fill32(device_size_t byte_offset, device_size_t byte_length, + T value) { + auto sized_value = reinterpret_cast<uint32_t*>(&value); + return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); +} + +template <typename T> +Status Buffer::Fill8(T value) { + return Fill8(0, kWholeBuffer, value); +} + +template <typename T> +Status Buffer::Fill16(T value) { + return Fill16(0, kWholeBuffer, value); +} + +template <typename T> +Status Buffer::Fill32(T value) { + return Fill32(0, kWholeBuffer, value); +} + +template <typename T> +StatusOr<MappedMemory<T>> Buffer::MapMemory(MemoryAccessBitfield memory_access, + device_size_t element_offset, + device_size_t element_length) { + device_size_t byte_offset = element_offset * sizeof(T); + device_size_t byte_length = element_length == kWholeBuffer + ? kWholeBuffer + : element_length * sizeof(T); + void* data = nullptr; + RETURN_IF_ERROR(MapMemory(MappingMode::kScoped, memory_access, &byte_offset, + &byte_length, &data)); + return MappedMemory<T>{ + memory_access, add_ref(this), byte_offset, + byte_length, byte_length / sizeof(T), static_cast<T*>(data)}; +} + +template <typename T> +MappedMemory<T>::MappedMemory(MemoryAccessBitfield access, + ref_ptr<Buffer> buffer, device_size_t byte_offset, + device_size_t byte_length, + device_size_t element_size, T* data) + : access_(access), + buffer_(std::move(buffer)), + byte_offset_(byte_offset), + byte_length_(byte_length), + element_size_(element_size), + data_(data) {} + +template <typename T> +MappedMemory<T>::MappedMemory(MappedMemory<T>&& rhs) noexcept + : access_(rhs.access_), + buffer_(std::move(rhs.buffer_)), + byte_offset_(rhs.byte_offset_), + byte_length_(rhs.byte_length_), + element_size_(rhs.element_size_), + data_(rhs.data_) { + rhs.access_ = MemoryAccess::kNone; + rhs.buffer_.reset(); + rhs.byte_offset_ = 0; + rhs.byte_length_ = 0; + rhs.element_size_ = 0; + rhs.data_ = nullptr; +} + +template <typename T> +MappedMemory<T>& MappedMemory<T>::operator=(MappedMemory<T>&& rhs) noexcept { + if (this != &rhs) { + reset(); + access_ = rhs.access_; + buffer_ = std::move(rhs.buffer_); + byte_offset_ = rhs.byte_offset_; + byte_length_ = rhs.byte_length_; + element_size_ = rhs.element_size_; + data_ = rhs.data_; + + rhs.access_ = MemoryAccess::kNone; + rhs.buffer_.reset(); + rhs.byte_offset_ = 0; + rhs.byte_length_ = 0; + rhs.element_size_ = 0; + rhs.data_ = nullptr; + } + return *this; +} + +template <typename T> +MappedMemory<T>::~MappedMemory() { + // Unmap (if needed) - note that we can't fail gracefully here :( + reset(); +} + +template <typename T> +const T* MappedMemory<T>::data() const noexcept { + if (!data_ || !AnyBitSet(access_ & MemoryAccess::kRead)) { + return nullptr; + } + return data_; +} + +template <typename T> +T* MappedMemory<T>::mutable_data() noexcept { + if (!data_ || !AnyBitSet(access_ & MemoryAccess::kWrite)) { + return nullptr; + } + return data_; +} + +template <typename T> +Status MappedMemory<T>::ValidateAccess( + MemoryAccessBitfield memory_access) const { + if (!data_) { + return FailedPreconditionErrorBuilder(IREE_LOC) << "Buffer is not mapped"; + } else if (!AnyBitSet(access_ & memory_access)) { + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Buffer is not mapped for the desired access"; + } + return OkStatus(); +} + +template <typename T> +Status MappedMemory<T>::CalculateDataRange( + device_size_t element_offset, device_size_t element_length, + device_size_t* out_adjusted_element_offset, + device_size_t* out_adjusted_element_length) const { + RETURN_IF_ERROR(Buffer::CalculateLocalRange( + element_size_ * sizeof(T), element_offset * sizeof(T), + element_length == kWholeBuffer ? kWholeBuffer + : element_length * sizeof(T), + out_adjusted_element_offset, out_adjusted_element_length)); + *out_adjusted_element_offset /= sizeof(T); + *out_adjusted_element_length /= sizeof(T); + return OkStatus(); +} + +template <typename T> +inline StatusOr<absl::Span<const T>> MappedMemory<T>::Subspan( + device_size_t element_offset, device_size_t element_length) const noexcept { + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); + RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, + &element_offset, &element_length)); + return absl::Span<const T>(data_ + element_offset, element_length); +} + +template <typename T> +inline StatusOr<absl::Span<T>> MappedMemory<T>::MutableSubspan( + device_size_t element_offset, device_size_t element_length) noexcept { + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); + RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, + &element_offset, &element_length)); + return absl::Span<T>(data_ + element_offset, element_length); +} + +template <typename T> +Status MappedMemory<T>::Invalidate(device_size_t element_offset, + device_size_t element_length) const { + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); + RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, + &element_offset, &element_length)); + if (!element_length) return OkStatus(); + return buffer_->InvalidateMappedMemory( + byte_offset_ + element_offset * sizeof(T), element_length * sizeof(T)); +} + +template <typename T> +Status MappedMemory<T>::Flush(device_size_t element_offset, + device_size_t element_length) { + RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); + RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, + &element_offset, &element_length)); + if (!element_length) return OkStatus(); + return buffer_->FlushMappedMemory(byte_offset_ + element_offset * sizeof(T), + element_length * sizeof(T)); +} + +template <typename T> +void MappedMemory<T>::reset() { + if (!buffer_) return; + // TODO(benvanik): better handling of errors? may be fine to always warn. + buffer_->UnmapMemory(byte_offset_, byte_length_, data_).IgnoreError(); + buffer_.reset(); + access_ = MemoryAccess::kNone; + byte_offset_ = 0; + byte_length_ = 0; + element_size_ = 0; + data_ = nullptr; +} + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_BUFFER_H_
diff --git a/hal/buffer_mapping_test.cc b/hal/buffer_mapping_test.cc new file mode 100644 index 0000000..1bffbaa --- /dev/null +++ b/hal/buffer_mapping_test.cc
@@ -0,0 +1,539 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the MemoryMapping RAII wrapper. +// This uses a mock buffer implementation such that it is only testing +// MemoryMapping and not any real underlying memory mapping behavior. + +#include <cstdint> +#include <memory> +#include <utility> + +#include "absl/types/span.h" +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { +class Allocator; + +namespace { + +using ::testing::_; +using ::testing::DoAll; +using ::testing::Return; +using ::testing::SetArgPointee; + +static void* const kValidPtr = reinterpret_cast<void*>(0xBEEFCAFEF00D1234ull); + +class MockBuffer : public Buffer { + public: + using Buffer::MappingMode; + + MockBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, + device_size_t allocation_size) + : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, + 0, allocation_size) {} + + MOCK_METHOD4(FillImpl, + Status(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length)); + + MOCK_METHOD3(ReadDataImpl, Status(device_size_t source_offset, void* data, + device_size_t data_length)); + MOCK_METHOD3(WriteDataImpl, + Status(device_size_t target_offset, const void* data, + device_size_t data_length)); + MOCK_METHOD4(CopyDataImpl, + Status(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, device_size_t data_length)); + + MOCK_METHOD5(MapMemoryImpl, + Status(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, void** out_data)); + MOCK_METHOD3(UnmapMemoryImpl, + Status(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data)); + MOCK_METHOD2(InvalidateMappedMemoryImpl, + Status(device_size_t local_byte_offset, + device_size_t local_byte_length)); + MOCK_METHOD2(FlushMappedMemoryImpl, Status(device_size_t local_byte_offset, + device_size_t local_byte_length)); +}; + +TEST(MemoryMappingTest, MapWholeBuffer) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mapping.reset(); +} + +TEST(MemoryMappingTest, MapPartialBuffer) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 4, 12, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 4, 12)); + EXPECT_CALL(*buffer, UnmapMemoryImpl(4, 12, kValidPtr)) + .WillOnce(Return(OkStatus())); + mapping.reset(); +} + +TEST(MemoryMappingTest, EmptyHandle) { + MappedMemory<uint8_t> mm_a; + MappedMemory<uint8_t> mm_b; + mm_a = std::move(mm_b); + EXPECT_EQ(nullptr, mm_a.buffer()); + EXPECT_EQ(0, mm_a.byte_offset()); + EXPECT_EQ(0, mm_a.byte_length()); + EXPECT_TRUE(mm_a.empty()); + EXPECT_EQ(0, mm_a.size()); + EXPECT_EQ(nullptr, mm_a.data()); + EXPECT_EQ(nullptr, mm_a.mutable_data()); + EXPECT_TRUE(IsFailedPrecondition(mm_a.Subspan().status())); + EXPECT_TRUE(IsFailedPrecondition(mm_a.MutableSubspan().status())); + EXPECT_TRUE(IsFailedPrecondition(mm_a.Invalidate())); + EXPECT_TRUE(IsFailedPrecondition(mm_a.Flush())); + mm_a.reset(); +} + +TEST(MemoryMappingTest, MoveHandle) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_a, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Should be able to move the handle around without having any calls. + auto mm_b = std::move(mm_a); + mm_a = std::move(mm_b); + mm_b = std::move(mm_a); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_b.reset(); +} + +TEST(MemoryMappingTest, ReadOnlyAccess) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kRead, BufferUsage::kAll, 128); + + // Should succeed to map for reading. + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_r, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Non-mutable access is fine. + EXPECT_EQ(kValidPtr, mm_r.data()); + ASSERT_OK_AND_ASSIGN(auto span, mm_r.Subspan()); + (void)span; + + // Read-only mappings should not be able to get mutable access. + EXPECT_EQ(nullptr, mm_r.mutable_data()); + EXPECT_TRUE(IsPermissionDenied(mm_r.MutableSubspan().status())); + + // Read-only mappings should not be able to call Flush. + EXPECT_TRUE(IsPermissionDenied(mm_r.Flush())); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); + + // Should fail to map for writing. + EXPECT_TRUE(IsPermissionDenied( + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status())); +} + +TEST(MemoryMappingTest, ReadWriteAccess) { + auto buffer = std::make_shared<MockBuffer>( + nullptr, MemoryType::kHostLocal, + MemoryAccess::kRead | MemoryAccess::kWrite, BufferUsage::kAll, 128); + + // Should succeed to map for reading and/or writing. + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead | MemoryAccess::kWrite, + 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN( + auto mm_rw, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kWrite)); + + // Everything valid. + EXPECT_EQ(kValidPtr, mm_rw.data()); + ASSERT_OK_AND_ASSIGN(auto span, mm_rw.Subspan()); + EXPECT_EQ(kValidPtr, mm_rw.mutable_data()); + ASSERT_OK_AND_ASSIGN(span, mm_rw.MutableSubspan()); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_rw.reset(); + + // Should fail to map for discard. + EXPECT_TRUE(IsPermissionDenied( + buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status())); +} + +TEST(MemoryMappingTest, WriteOnlyAccess) { + auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kWrite, + BufferUsage::kAll, 128); + + // Should succeed to map for writing. + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_w, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + + // Mutable access is valid. + EXPECT_EQ(kValidPtr, mm_w.mutable_data()); + ASSERT_OK_AND_ASSIGN(auto span, mm_w.MutableSubspan()); + (void)span; + + // Write-only mappings should not be able to get non-mutable access. + EXPECT_EQ(nullptr, mm_w.data()); + EXPECT_TRUE(IsPermissionDenied(mm_w.Subspan().status())); + + // Write-only mappings should not be able to call Invalidate. + EXPECT_TRUE(IsPermissionDenied(mm_w.Invalidate())); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); + + // Should fail to map for reading. + EXPECT_TRUE(IsPermissionDenied( + buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status())); + + // Should fail to map for discard. + EXPECT_TRUE(IsPermissionDenied( + buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status())); +} + +TEST(MemoryMappingTest, WriteDiscardAccess) { + auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kDiscardWrite, + BufferUsage::kAll, 128); + + // Should succeed to map for writing with discard. + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kDiscardWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_dw, + buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite)); + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_dw.reset(); + + // Should also be ok to map for just writing. + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_w, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); + + // Should fail to map for reading. + EXPECT_TRUE(IsPermissionDenied( + buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status())); +} + +TEST(MemoryMappingTest, Subspan) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_r, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Request some valid ranges and ensure the byte offsets are correct. + ASSERT_OK_AND_ASSIGN(auto ss, mm_r.Subspan()); + EXPECT_EQ(kValidPtr, ss.data()); + EXPECT_EQ(128, ss.size()); + ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, 2)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); + EXPECT_EQ(2, ss.size()); + ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, kWholeBuffer)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); + EXPECT_EQ(28, ss.size()); + + // Zero length ranges are fine. + ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(0, 0)); + EXPECT_EQ(kValidPtr, ss.data()); + EXPECT_TRUE(ss.empty()); + ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, 0)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); + EXPECT_TRUE(ss.empty()); + ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, kWholeBuffer)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); + EXPECT_TRUE(ss.empty()); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); +} + +TEST(MemoryMappingTest, SubspanOutOfRange) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_r, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Try some invalid ranges that would overrun the span. + EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 0).status())); + EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 2).status())); + EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, kWholeBuffer).status())); + EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(100, 1234).status())); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); +} + +TEST(MemoryMappingTest, MutableSubspan) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_w, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + + // Request some valid ranges and ensure the byte offsets are correct. + ASSERT_OK_AND_ASSIGN(auto ss, mm_w.MutableSubspan()); + EXPECT_EQ(kValidPtr, ss.data()); + EXPECT_EQ(128, ss.size()); + ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, 2)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); + EXPECT_EQ(2, ss.size()); + ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, kWholeBuffer)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); + EXPECT_EQ(28, ss.size()); + + // Zero length ranges are fine. + ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(0, 0)); + EXPECT_EQ(kValidPtr, ss.data()); + EXPECT_TRUE(ss.empty()); + ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, 0)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); + EXPECT_TRUE(ss.empty()); + ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, kWholeBuffer)); + EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); + EXPECT_TRUE(ss.empty()); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); +} + +TEST(MemoryMappingTest, MutableSubspanOutOfRange) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_w, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + + // Try some invalid ranges that would overrun the span. + EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 0).status())); + EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 2).status())); + EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, kWholeBuffer).status())); + EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(100, 1234).status())); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); +} + +TEST(MemoryMappingTest, ElementOperator) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_r, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Just verify we are getting the expected pointer back. + EXPECT_EQ(kValidPtr, &mm_r[0]); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); +} + +TEST(MemoryMappingTest, Invalidate) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_r, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Invalidate a few ways. + EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(0, 128)) + .WillOnce(Return(OkStatus())); + EXPECT_OK(mm_r.Invalidate()); + EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 2)) + .WillOnce(Return(OkStatus())); + EXPECT_OK(mm_r.Invalidate(100, 2)); + EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 28)) + .WillOnce(Return(OkStatus())); + EXPECT_OK(mm_r.Invalidate(100, kWholeBuffer)); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); +} + +TEST(MemoryMappingTest, InvalidateOutOfRange) { + auto buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_r, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + + // Try to invalidate invalid ranges. + EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 0))); + EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 12345))); + EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, kWholeBuffer))); + EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1, 1234))); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); +} + +TEST(MemoryMappingTest, InvalidateBadMode) { + // Invalidate is not required on coherent memory. + auto coherent_buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*coherent_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kRead, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN( + auto mm_r, coherent_buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + EXPECT_TRUE(IsPermissionDenied(mm_r.Invalidate())); + EXPECT_CALL(*coherent_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_r.reset(); +} + +TEST(MemoryMappingTest, Flush) { + auto buffer = std::make_shared<MockBuffer>( + nullptr, MemoryType::kHostVisible | MemoryType::kHostCached, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_w, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + + // Flush a few ways. + EXPECT_CALL(*buffer, FlushMappedMemoryImpl(0, 128)) + .WillOnce(Return(OkStatus())); + EXPECT_OK(mm_w.Flush()); + EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 2)) + .WillOnce(Return(OkStatus())); + EXPECT_OK(mm_w.Flush(100, 2)); + EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 28)) + .WillOnce(Return(OkStatus())); + EXPECT_OK(mm_w.Flush(100, kWholeBuffer)); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); +} + +TEST(MemoryMappingTest, FlushOutOfRange) { + auto buffer = std::make_shared<MockBuffer>( + nullptr, MemoryType::kHostVisible | MemoryType::kHostCached, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN(auto mm_w, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + + // Try to flush invalid ranges. + EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 0))); + EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 12345))); + EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, kWholeBuffer))); + EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1, 1234))); + + EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); +} + +TEST(MemoryMappingTest, FlushBadMode) { + // Flush is not required on uncached memory. + auto uncached_buffer = + std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible, + MemoryAccess::kAll, BufferUsage::kAll, 128); + EXPECT_CALL(*uncached_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, + MemoryAccess::kWrite, 0, 128, _)) + .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); + ASSERT_OK_AND_ASSIGN( + auto mm_w, uncached_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + EXPECT_TRUE(IsPermissionDenied(mm_w.Flush())); + EXPECT_CALL(*uncached_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) + .WillOnce(Return(OkStatus())); + mm_w.reset(); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/buffer_test.cc b/hal/buffer_test.cc new file mode 100644 index 0000000..31f0448 --- /dev/null +++ b/hal/buffer_test.cc
@@ -0,0 +1,1000 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the shared buffer functionality and host heap buffers. +// This does not test device-specific buffer implementations; see the device +// code for associated tests. + +#include "hal/buffer.h" + +#include <vector> + +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/heap_buffer.h" + +namespace iree { +namespace hal { +namespace { + +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::Not; + +TEST(BufferTest, Allocate) { + auto buffer = + HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 14); + EXPECT_NE(nullptr, buffer->allocator()); + EXPECT_EQ(MemoryAccess::kAll, buffer->allowed_access()); + EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); + EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); + + // We don't currently do any padding on the host. + // Other implementations may differ. + EXPECT_GE(14, buffer->allocation_size()); + EXPECT_EQ(0, buffer->byte_offset()); + EXPECT_EQ(14, buffer->byte_length()); + + // Data should be zeroed by default. + std::vector<uint8_t> zero_data(buffer->allocation_size()); + std::vector<uint8_t> actual_data(buffer->allocation_size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(zero_data)); +} + +TEST(BufferTest, AllocateZeroLength) { + auto buffer = + HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 0); + EXPECT_NE(nullptr, buffer->allocator()); + EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); + EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); + EXPECT_EQ(0, buffer->allocation_size()); +} + +TEST(BufferTest, AllocateCopy) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + EXPECT_NE(nullptr, buffer->allocator()); + EXPECT_GE(src_data.size(), buffer->allocation_size()); + + // Data should have been copied. + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Modify the source data and ensure it is not reflected in the buffer. + src_data[0] = 0x88; + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Not(Eq(src_data))); +} + +TEST(BufferTest, AllocateCopyZeroLength) { + std::vector<uint8_t> src_data; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + EXPECT_NE(nullptr, buffer->allocator()); + EXPECT_EQ(0, buffer->allocation_size()); +} + +TEST(BufferTest, AllocateCopyTyped) { + std::vector<int32_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + absl::MakeConstSpan(src_data)); + EXPECT_NE(nullptr, buffer->allocator()); + EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); + EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); + EXPECT_GE(src_data.size() * sizeof(int32_t), buffer->allocation_size()); + + // Data should have been copied. + std::vector<int32_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), + actual_data.size() * sizeof(int32_t))); + EXPECT_THAT(actual_data, Eq(src_data)); +} + +TEST(BufferTest, WrapConstant) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = HeapBuffer::Wrap(MemoryType::kHostLocal, + BufferUsage::kTransfer | BufferUsage::kMapping, + absl::MakeConstSpan(src_data)); + EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); + EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); + EXPECT_EQ(src_data.size(), buffer->allocation_size()); + + // src_data and buffer should match after the wrapping. + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Modify the source data directly. + src_data[0] = 123; + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Attempts to modify the buffer should fail. + std::vector<uint8_t> new_data = {3, 2, 1, 0}; + EXPECT_TRUE(IsPermissionDenied( + buffer->WriteData(0, new_data.data(), new_data.size()))); +} + +TEST(BufferTest, WrapMutable) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = HeapBuffer::WrapMutable( + MemoryType::kHostLocal, MemoryAccess::kAll, + BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data)); + EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); + EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); + EXPECT_EQ(src_data.size(), buffer->allocation_size()); + + // src_data and buffer should match after the wrapping. + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Modify the source data directly. + src_data[0] = 123; + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Modify the source data via the Buffer and ensure reflected in src_data. + std::vector<uint8_t> new_data = {3, 2, 1, 0}; + EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); + EXPECT_THAT(src_data, Eq(new_data)); +} + +TEST(BufferTest, WrapExternal) { + // This is not fully supported yet, but does let us verify that the validation + // of memory types is working. + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = HeapBuffer::Wrap(MemoryType::kDeviceLocal, BufferUsage::kAll, + absl::MakeConstSpan(src_data)); + EXPECT_EQ(MemoryType::kDeviceLocal, buffer->memory_type()); + + // Should fail (for now) as the buffer is not host visible. + EXPECT_TRUE(IsPermissionDenied(buffer->Fill8(0, kWholeBuffer, 0x99u))); +} + +TEST(BufferTest, DoesOverlap) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto parent_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + + // A buffer should overlap with itself. + EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, + parent_buffer.get(), 1, 1)); + EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, + parent_buffer.get(), 0, 1)); + + // Zero length buffers never overlap. + EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 1, 1, + parent_buffer.get(), 1, 0)); + + // Subspans should offset within their allocation. + ASSERT_OK_AND_ASSIGN(auto subspan_buffer_0, + Buffer::Subspan(parent_buffer, 1, 2)); + ASSERT_OK_AND_ASSIGN(auto subspan_buffer_1, + Buffer::Subspan(parent_buffer, 2, 2)); + EXPECT_FALSE(Buffer::DoesOverlap(subspan_buffer_0.get(), 0, 1, + subspan_buffer_1.get(), 0, 1)); + EXPECT_TRUE(Buffer::DoesOverlap(subspan_buffer_0.get(), 1, 1, + subspan_buffer_1.get(), 0, 1)); + + // Mixing subspans and normal buffers. + EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, + subspan_buffer_0.get(), 0, 1)); + EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 1, 2, + subspan_buffer_0.get(), 1, 1)); + + // Independent buffers should not be able to overlap. + auto other_buffer = HeapBuffer::Allocate(BufferUsage::kAll, 128); + EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, kWholeBuffer, + other_buffer.get(), 0, kWholeBuffer)); +} + +TEST(BufferTest, Subspan) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto parent_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(parent_buffer); + + // Create a subspan of the buffer. + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, + Buffer::Subspan(parent_buffer, 1, 2)); + ASSERT_TRUE(subspan_buffer); + EXPECT_EQ(1, subspan_buffer->byte_offset()); + EXPECT_EQ(2, subspan_buffer->byte_length()); + + // Modifications to either buffer should appear in the other. + EXPECT_OK(subspan_buffer->Fill8(1, kWholeBuffer, 0xFFu)); + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xFF, 3)); + + // Subspans should be able to create subspans. + // NOTE: offset is from the original buffer. + ASSERT_OK_AND_ASSIGN(auto subsubspan_buffer, + Buffer::Subspan(subspan_buffer, 1, 1)); + ASSERT_TRUE(subsubspan_buffer); + EXPECT_EQ(2, subsubspan_buffer->byte_offset()); + EXPECT_EQ(1, subsubspan_buffer->byte_length()); + + // Zero length subspans are fine. + ASSERT_OK_AND_ASSIGN(auto zero_subspan_buffer, + Buffer::Subspan(parent_buffer, 0, 0)); + ASSERT_TRUE(zero_subspan_buffer); + EXPECT_EQ(0, zero_subspan_buffer->byte_offset()); + EXPECT_EQ(0, zero_subspan_buffer->byte_length()); + + // Subspan with kWholeBuffer should get the remaining size (or zero). + ASSERT_OK_AND_ASSIGN(auto whole_subspan_buffer, + Buffer::Subspan(parent_buffer, 1, kWholeBuffer)); + ASSERT_TRUE(whole_subspan_buffer); + EXPECT_EQ(1, whole_subspan_buffer->byte_offset()); + EXPECT_EQ(3, whole_subspan_buffer->byte_length()); + + // Zero length subspans are fine. + ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, 0)); + ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, kWholeBuffer)); +} + +TEST(BufferTest, SubspanIdentity) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto parent_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + + // Asking for a subspan of the entire buffer should return the same buffer. + // Mostly an optimization. + EXPECT_EQ(parent_buffer.get(), + Buffer::Subspan(parent_buffer, 0, kWholeBuffer).ValueOrDie().get()); + EXPECT_EQ(parent_buffer.get(), + Buffer::Subspan(parent_buffer, 0, 4).ValueOrDie().get()); +} + +TEST(BufferTest, SubspanOutOfRange) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto parent_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(parent_buffer); + + // Create a subspan of the buffer. + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, + Buffer::Subspan(parent_buffer, 1, 2)); + ASSERT_TRUE(subspan_buffer); + EXPECT_EQ(1, subspan_buffer->byte_offset()); + EXPECT_EQ(2, subspan_buffer->byte_length()); + + // Try to make subspans from invalid ranges. + EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 5, 0).status())); + EXPECT_TRUE( + IsOutOfRange(Buffer::Subspan(parent_buffer, 5, kWholeBuffer).status())); + EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 4, 1).status())); + EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 0, 123).status())); + EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 1, 2).status())); + EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 0, 44).status())); +} + +TEST(BufferTest, Fill8) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); + ASSERT_TRUE(buffer); + + // Data should be zeroed by default. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0)); + + // Fill with a sentinel. + EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); + + // Verify data. + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); + + // Zero fills are fine. + EXPECT_OK(buffer->Fill8(0, 0, 0x44u)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); + + // Fill the remaining parts of the buffer by using kWholeBuffer. + EXPECT_OK(buffer->Fill8(2, kWholeBuffer, 0x55u)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55)); + + // Fill a small region of the buffer. + EXPECT_OK(buffer->Fill8(1, 1, 0x66u)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55)); + + // Whole buffer helper. + EXPECT_OK(buffer->Fill8(0x99u)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99)); +} + +TEST(BufferTest, Fill8OutOfRange) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); + ASSERT_TRUE(buffer); + + // Fill with a sentinel. + EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); + + // Try to fill with invalid ranges. + EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 444, 0x44u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 1, 0x44u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); + + // Ensure nothing happened with the bad ranges. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); +} + +TEST(BufferTest, Fill8BadMode) { + // Fail to fill buffers not supporting mapping. + auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + EXPECT_TRUE( + IsPermissionDenied(nonmapping_buffer->Fill8(0, kWholeBuffer, 0x99u))); + + // Fail to fill constant buffers. + std::vector<uint8_t> const_data = {1, 2, 3}; + auto constant_buffer = + HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, + absl::MakeConstSpan(const_data)); + EXPECT_TRUE( + IsPermissionDenied(constant_buffer->Fill8(0, kWholeBuffer, 0x99u))); +} + +TEST(BufferTest, Fill8Subspan) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); + ASSERT_TRUE(buffer); + + // Test on subspan. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 3)); + EXPECT_OK(subspan_buffer->Fill8(2, kWholeBuffer, 0xDDu)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0xDD, 0)); +} + +TEST(BufferTest, Fill16) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Data should be zeroed by default. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); + + // Fill with a sentinel. + EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); + + // Verify data. + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); + + // Zero fills are fine. + EXPECT_OK(buffer->Fill16(0, 0, 0x5566u)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); + + // Fill the remaining parts of the buffer by using kWholeBuffer. + auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8); + EXPECT_OK(aligned_buffer->Fill16(4, kWholeBuffer, 0x5566u)); + std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size()); + EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), + aligned_actual_data.size())); + EXPECT_THAT(aligned_actual_data, + ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55)); + + // Whole buffer helper. + EXPECT_OK(aligned_buffer->Fill16(0x5566u)); + EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), + aligned_actual_data.size())); + EXPECT_THAT(aligned_actual_data, + ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55)); +} + +TEST(BufferTest, Fill16OutOfRange) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Try to fill with invalid ranges. + EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 444, 0x5566u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 4, 0x5566u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); +} + +TEST(BufferTest, Fill16Unaligned) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Try to fill with unaligned ranges. + EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(1, 4, 0x5566u))); + EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(0, 5, 0x5566u))); +} + +TEST(BufferTest, Fill16BadMode) { + // Fail to fill buffers not supporting mapping. + auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + EXPECT_TRUE( + IsPermissionDenied(nonmapping_buffer->Fill16(0, kWholeBuffer, 0x99AAu))); + + // Fail to fill constant buffers. + std::vector<uint8_t> const_data = {1, 2, 3}; + auto constant_buffer = + HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, + absl::MakeConstSpan(const_data)); + EXPECT_TRUE( + IsPermissionDenied(constant_buffer->Fill16(0, kWholeBuffer, 0x99AAu))); +} + +TEST(BufferTest, Fill16Subspan) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Fill with a sentinel. + EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); + + // Test on subspan. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 2, 4)); + EXPECT_OK(subspan_buffer->Fill16(2, kWholeBuffer, 0xAABBu)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, + ElementsAre(0x22, 0x11, 0x22, 0x11, 0xBB, 0xAA, 0, 0, 0)); +} + +TEST(BufferTest, Fill32) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Data should be zeroed by default. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); + + // Fill with a sentinel. + EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); + + // Verify data. + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, + ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); + + // Zero fills are fine. + EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, + ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); + + // Fill the remaining parts of the buffer by using kWholeBuffer. + auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8); + EXPECT_OK(aligned_buffer->Fill32(4, kWholeBuffer, 0x55667788u)); + std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size()); + EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), + aligned_actual_data.size())); + EXPECT_THAT(aligned_actual_data, + ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55)); + + // Whole buffer helper. + EXPECT_OK(aligned_buffer->Fill32(0x55667788u)); + EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), + aligned_actual_data.size())); + EXPECT_THAT(aligned_actual_data, + ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55)); +} + +TEST(BufferTest, Fill32OutOfRange) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Try to fill with invalid ranges. + EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 444, 0x55667788u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 4, 0x55667788u))); + EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); +} + +TEST(BufferTest, Fill32Unaligned) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Try to fill with unaligned ranges. + EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(1, 4, 0x55667788u))); + EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(0, 5, 0x55667788u))); +} + +TEST(BufferTest, Fill32BadMode) { + // Fail to fill buffers not supporting mapping. + auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + EXPECT_TRUE(IsPermissionDenied( + nonmapping_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu))); + + // Fail to fill constant buffers. + std::vector<uint8_t> const_data = {1, 2, 3}; + auto constant_buffer = + HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, + absl::MakeConstSpan(const_data)); + EXPECT_TRUE(IsPermissionDenied( + constant_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu))); +} + +TEST(BufferTest, Fill32Subspan) { + auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); + ASSERT_TRUE(buffer); + + // Fill with a sentinel. + EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); + + // Test on subspan. + std::vector<uint8_t> actual_data(buffer->allocation_size()); + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 4, 4)); + EXPECT_OK(subspan_buffer->Fill32(0, kWholeBuffer, 0xAABBCCDDu)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, + ElementsAre(0x44, 0x33, 0x22, 0x11, 0xDD, 0xCC, 0xBB, 0xAA, 0)); +} + +TEST(BufferTest, ReadData) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Read the data back. + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Reading zero bytes is valid. + std::vector<uint8_t> zero_data(0); + EXPECT_OK(buffer->ReadData(1, zero_data.data(), 0)); + + // Read a portion of the data. + std::vector<uint8_t> partial_data(2); + EXPECT_OK(buffer->ReadData(1, partial_data.data(), 2)); + EXPECT_THAT(partial_data, ElementsAre(1, 2)); +} + +TEST(BufferTest, ReadDataOutOfRange) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Try to read out of range. + std::vector<uint8_t> partial_data(2); + EXPECT_TRUE(IsOutOfRange(buffer->ReadData(0, partial_data.data(), 444))); + EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 444))); + EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 1))); + EXPECT_TRUE(IsInvalidArgument( + buffer->ReadData(0, partial_data.data(), kWholeBuffer))); +} + +TEST(BufferTest, ReadDataBadMode) { + // Fail to read buffers not supporting mapping. + std::vector<uint8_t> actual_data(1); + auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + EXPECT_TRUE(IsPermissionDenied( + nonmapping_buffer->ReadData(0, actual_data.data(), 1))); +} + +TEST(BufferTest, ReadDataSubspan) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Test on subspan. + std::vector<uint8_t> subspan_data(1); + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); + EXPECT_OK(subspan_buffer->ReadData(1, subspan_data.data(), 1)); + EXPECT_THAT(subspan_data, ElementsAre(2)); +} + +TEST(BufferTest, WriteData) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Read the data back - should still match. + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(src_data)); + + // Write over the entire buffer. + std::vector<uint8_t> new_data = {10, 20, 30, 40}; + EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(new_data)); + + // Writing zero bytes is valid. + std::vector<uint8_t> zero_data; + EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(new_data)); + + // Write over a portion of the buffer. + std::vector<uint8_t> partial_data = {99}; + EXPECT_OK(buffer->WriteData(1, partial_data.data(), partial_data.size())); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40)); +} + +TEST(BufferTest, WriteDataOutOfRange) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Try to write out of range. + std::vector<uint8_t> partial_data = {99}; + EXPECT_TRUE(IsOutOfRange(buffer->WriteData(0, partial_data.data(), 444))); + EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 444))); + EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 1))); + EXPECT_TRUE(IsInvalidArgument( + buffer->WriteData(0, partial_data.data(), kWholeBuffer))); +} + +TEST(BufferTest, WriteDataBadMode) { + std::vector<uint8_t> actual_data(4); + + // Fail to write buffers not supporting mapping. + auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + EXPECT_TRUE(IsPermissionDenied( + nonmapping_buffer->WriteData(0, actual_data.data(), 1))); + + // Fail to write to constant buffers. + std::vector<uint8_t> const_data = {1, 2, 3}; + auto constant_buffer = + HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer, + absl::MakeConstSpan(const_data)); + EXPECT_TRUE( + IsPermissionDenied(constant_buffer->WriteData(0, actual_data.data(), 2))); +} + +TEST(BufferTest, WriteDataSubspan) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto buffer = + HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Test on subspan. + std::vector<uint8_t> subspan_data = {0xAA}; + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); + EXPECT_OK(subspan_buffer->WriteData(1, subspan_data.data(), 1)); + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xAA, 3)); +} + +TEST(BufferTest, CopyData) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto src_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + src_data.data(), src_data.size()); + ASSERT_TRUE(src_buffer); + std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; + auto dst_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + dst_data.data(), dst_data.size()); + ASSERT_TRUE(dst_buffer); + + // Copy of length 0 should not change the dest buffer. + EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0)); + std::vector<uint8_t> actual_data(dst_data.size()); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, Eq(dst_data)); + + // Copy a subrange of the buffer. + EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4)); + + // Copy the entire buffer using kWholeBuffer. This will adjust sizes + // to ensure that the min buffer is taken. We test both src and dst buffer + // offset/length calculations (note that some may end up as 0 copies). + EXPECT_OK(dst_buffer->CopyData(3, src_buffer.get(), 0, kWholeBuffer)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1)); + EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 2, kWholeBuffer)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1)); + EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 3, kWholeBuffer)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1)); + EXPECT_OK(dst_buffer->CopyData(4, src_buffer.get(), 0, kWholeBuffer)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0)); +} + +TEST(BufferTest, CopyDataOutOfRange) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto src_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + src_data.data(), src_data.size()); + ASSERT_TRUE(src_buffer); + std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; + auto dst_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + dst_data.data(), dst_data.size()); + ASSERT_TRUE(dst_buffer); + + // Try to copy out of range of source and dest. + EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 0, 1))); + EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(4, src_buffer.get(), 0, 4))); + EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 1))); + EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 0, 123))); + EXPECT_TRUE( + IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 123, 123))); + EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 0))); +} + +TEST(BufferTest, CopyDataOverlapping) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto src_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + src_data.data(), src_data.size()); + ASSERT_TRUE(src_buffer); + std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; + auto dst_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + dst_data.data(), dst_data.size()); + ASSERT_TRUE(dst_buffer); + + // Test overlap. Non-overlapping regions should be fine, otherwise fail. + std::vector<uint8_t> actual_data(dst_data.size()); + EXPECT_OK(dst_buffer->CopyData(0, dst_buffer.get(), 4, 1)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(4, 1, 2, 3, 4)); + EXPECT_TRUE( + IsInvalidArgument(dst_buffer->CopyData(2, dst_buffer.get(), 0, 3))); + EXPECT_TRUE( + IsInvalidArgument(dst_buffer->CopyData(0, dst_buffer.get(), 0, 3))); +} + +TEST(BufferTest, CopyDataBadMode) { + // Both source and target buffers must support mapping. + auto nonmapping_src_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + auto nonmapping_dst_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); + EXPECT_TRUE(IsPermissionDenied(nonmapping_dst_buffer->CopyData( + 0, nonmapping_src_buffer.get(), 0, kWholeBuffer))); + EXPECT_TRUE(IsPermissionDenied(nonmapping_src_buffer->CopyData( + 0, nonmapping_dst_buffer.get(), 0, kWholeBuffer))); + + // Fail to copy into to constant buffers. + std::vector<uint8_t> const_data = {1, 2, 3}; + auto constant_buffer = + HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer, + absl::MakeConstSpan(const_data)); + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto src_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + src_data.data(), src_data.size()); + EXPECT_TRUE(IsPermissionDenied( + constant_buffer->CopyData(0, src_buffer.get(), 0, kWholeBuffer))); +} + +TEST(BufferTest, CopyDataSubspan) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + auto src_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + src_data.data(), src_data.size()); + ASSERT_TRUE(src_buffer); + std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; + auto dst_buffer = + HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, + dst_data.data(), dst_data.size()); + ASSERT_TRUE(dst_buffer); + + // Test on subspan. + std::vector<uint8_t> actual_data(dst_data.size()); + ASSERT_OK_AND_ASSIGN(auto subspan_src_buffer, + Buffer::Subspan(src_buffer, 1, 3)); + ASSERT_OK_AND_ASSIGN(auto subspan_dst_buffer, + Buffer::Subspan(dst_buffer, 2, 3)); + EXPECT_OK(subspan_dst_buffer->CopyData(1, subspan_src_buffer.get(), 1, 2)); + EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 1, 2, 2, 3)); +} + +// NOTE: more tests related specifically to MappedMemory are in +// buffer_mapping_test.cc. This tests the MapMemory operation and enough to +// ensure the memory was mapped to the correct range and the HostBuffer and +// SubspanBuffer work as intended for basic usage. +TEST(BufferTest, MapMemory) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // 0-length mappings are valid. + ASSERT_OK_AND_ASSIGN(auto mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 0, 0)); + EXPECT_TRUE(mapping.empty()); + EXPECT_EQ(0, mapping.size()); + EXPECT_EQ(0, mapping.byte_length()); + EXPECT_NE(nullptr, mapping.data()); + ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan()); + EXPECT_TRUE(span.empty()); + mapping.reset(); + + // Map the whole buffer for reading. + ASSERT_OK_AND_ASSIGN(mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead, + 0, kWholeBuffer)); + EXPECT_EQ(src_data.size(), mapping.size()); + ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); + EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6)); + mapping.reset(); + + // Map a portion of the buffer for reading. + ASSERT_OK_AND_ASSIGN(mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 1, 2)); + EXPECT_EQ(2, mapping.size()); + ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); + EXPECT_THAT(span, ElementsAre(1, 2)); + mapping.reset(); +} + +TEST(BufferTest, MapMemoryNonByte) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Map the buffer as non-byte values. + // Note that we'll round down to the number of valid elements at the + // alignment. + ASSERT_OK_AND_ASSIGN(auto mapping16, + buffer->MapMemory<uint16_t>(MemoryAccess::kRead)); + EXPECT_EQ(3, mapping16.size()); + EXPECT_LE(6, mapping16.byte_length()); + ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan()); + EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504)); + mapping16.reset(); +} + +TEST(BufferTest, MapMemoryOutOfRange) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Test invalid mapping ranges. + EXPECT_TRUE(IsOutOfRange( + buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 0, 123).status())); + EXPECT_TRUE(IsOutOfRange( + buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 5, 1231).status())); + EXPECT_TRUE(IsOutOfRange( + buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 6, kWholeBuffer) + .status())); + EXPECT_TRUE(IsOutOfRange( + buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 1236, 1).status())); +} + +TEST(BufferTest, MapMemoryBadMode) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto read_buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, + src_data.data(), src_data.size()); + ASSERT_TRUE(read_buffer); + + // Test mapping the read-only buffer for writing. + EXPECT_TRUE(IsPermissionDenied( + read_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status())); + EXPECT_TRUE(IsPermissionDenied( + read_buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status())); + EXPECT_TRUE(IsPermissionDenied( + read_buffer + ->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kDiscard) + .status())); + EXPECT_TRUE(IsInvalidArgument( + read_buffer->MapMemory<uint8_t>(MemoryAccess::kNone).status())); +} + +TEST(BufferTest, MapMemoryWrite) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Map and modify the data. We should see it when we read back. + ASSERT_OK_AND_ASSIGN(auto mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kWrite, 1, 2)); + auto mutable_data = mapping.mutable_data(); + mutable_data[0] = 0xAA; + mutable_data[1] = 0xBB; + mapping.reset(); + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6)); +} + +TEST(BufferTest, MapMemoryDiscard) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, + src_data.data(), src_data.size()); + ASSERT_TRUE(buffer); + + // Map for discard. Note that we can't really rely on the value of the data + // so we just trust that it's been discarded. It's a hint, anyway. We can be + // sure that the data we didn't want to discard is the same though. + std::vector<uint8_t> actual_data(src_data.size()); + ASSERT_OK_AND_ASSIGN(auto mapping, buffer->MapMemory<uint8_t>( + MemoryAccess::kDiscardWrite, 1, 2)); + EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6)); + mapping.reset(); +} + +TEST(BufferTest, MapMemorySubspan) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; + auto parent_buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, + src_data.data(), src_data.size()); + ASSERT_TRUE(parent_buffer); + ASSERT_OK_AND_ASSIGN(auto subspan_buffer, + Buffer::Subspan(parent_buffer, 1, 3)); + ASSERT_OK_AND_ASSIGN(auto mapping, subspan_buffer->MapMemory<uint8_t>( + MemoryAccess::kDiscardWrite, 1, 2)); + auto* mutable_data = mapping.mutable_data(); + mutable_data[0] = 0xCC; + mutable_data[1] = 0xDD; + mapping.reset(); + + std::vector<uint8_t> actual_data(src_data.size()); + EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); + EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6)); + + // Just here to make coverage happy; they are currently no-ops on the host. + // buffer_mapping_test.cc contains tests that ensure they are called + // correctly. + std::vector<uint8_t> external_data = {0, 1, 2, 3, 4}; + auto external_buffer = HeapBuffer::WrapMutable( + MemoryType::kHostVisible | MemoryType::kHostCached, MemoryAccess::kAll, + BufferUsage::kAll, absl::MakeSpan(external_data)); + ASSERT_OK_AND_ASSIGN(auto external_subspan_buffer, + Buffer::Subspan(external_buffer, 0, 1)); + ASSERT_OK_AND_ASSIGN( + mapping, external_subspan_buffer->MapMemory<uint8_t>(MemoryAccess::kAll)); + EXPECT_OK(mapping.Invalidate()); + EXPECT_OK(mapping.Flush()); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/buffer_view.cc b/hal/buffer_view.cc new file mode 100644 index 0000000..ff02041 --- /dev/null +++ b/hal/buffer_view.cc
@@ -0,0 +1,180 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/buffer_view.h" + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "base/source_location.h" +#include "base/status.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +namespace { +// Pretty prints an array, e.g. [1, 2, 3, 4] +inline std::string PrettyPrint(absl::Span<const int32_t> arr) { + return "[" + absl::StrJoin(arr, ",") + "]"; +} +} // namespace + +// static +bool BufferView::Equal(const BufferView& lhs, const BufferView& rhs) { + return lhs.buffer.get() == rhs.buffer.get() && + lhs.element_size == rhs.element_size && lhs.shape == rhs.shape; +} + +std::string BufferView::DebugStringShort() const { + if (element_size == 0) { + return "Ø"; + } + return shape.empty() ? std::to_string(element_size) + : absl::StrCat(absl::StrJoin(shape.subspan(), "x"), "x", + element_size); +} + +StatusOr<device_size_t> BufferView::CalculateOffset( + absl::Span<const int32_t> indices) const { + if (indices.empty()) { + return 0; + } else if (shape.empty() || indices.size() > shape.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Indices " << PrettyPrint(indices) + << " out of bounds of the rank of buffer_view " + << DebugStringShort(); + } + device_size_t offset = 0; + for (int i = 0; i < indices.size(); ++i) { + if (indices[i] >= shape[i]) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Indices[" << i << "]=" << indices[i] + << " out of bounds of buffer_view " << DebugStringShort(); + } + device_size_t axis_offset = indices[i]; + for (int j = i + 1; j < shape.size(); ++j) { + axis_offset *= shape[j]; + } + offset += axis_offset; + } + offset *= element_size; + return offset; +} + +StatusOr<BufferView> BufferView::Slice( + absl::Span<const int32_t> start_indices, + absl::Span<const int32_t> lengths) const { + if (start_indices.size() != shape.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Slice start_indices " << PrettyPrint(start_indices) + << " do not match rank of buffer_view " << DebugStringShort(); + } + if (start_indices.size() != lengths.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Slice start_indices " << PrettyPrint(start_indices) + << " and lengths " << PrettyPrint(lengths) + << " are not the same size"; + } + + // Buffer::Subspan only support contiguous memory. To ensure that this slice + // only requests such, we validate that the offset in the buffer between the + // start and end indices is the same as the requested size of the slice. + absl::InlinedVector<int32_t, 6> end_indices(lengths.size()); + device_size_t subspan_length = element_size; + for (int i = 0; i < lengths.size(); ++i) { + subspan_length *= lengths[i]; + end_indices[i] = start_indices[i] + lengths[i] - 1; + } + + ASSIGN_OR_RETURN(auto start_byte_offset, CalculateOffset(start_indices)); + // Also validates the ends are in bounds. + ASSIGN_OR_RETURN(auto end_byte_offset, CalculateOffset(end_indices)); + + auto offset_length = end_byte_offset - start_byte_offset + element_size; + if (subspan_length != offset_length) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Slice for non-contiguous region of memory unimplemented. " + "start_indices: " + << PrettyPrint(start_indices) << " lengths: " << PrettyPrint(lengths) + << " " << subspan_length << " " << offset_length << " " + << PrettyPrint(end_indices); + } + + ASSIGN_OR_RETURN(auto new_buffer, + Buffer::Subspan(buffer, start_byte_offset, subspan_length)); + return BufferView(std::move(new_buffer), Shape(lengths), element_size); +} + +// static +Status BufferView::Copy(BufferView* src, + absl::Span<const int32_t> src_start_indices, + BufferView* dst, + absl::Span<const int32_t> dst_start_indices, + absl::Span<const int32_t> lengths) { + if (src_start_indices.size() != src->shape.size() || + dst_start_indices.size() != dst->shape.size() || + src_start_indices.size() != lengths.size() || + dst_start_indices.size() != lengths.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Src/dst shape/size mismatch: src=" << src->DebugStringShort() + << ", dst=" << dst->DebugStringShort() + << ", src_indices=" << PrettyPrint(src_start_indices) + << ", dst_indices=" << PrettyPrint(dst_start_indices) + << ", lengths=" << PrettyPrint(lengths); + } + + // Copies only support contiguous memory. To ensure that this copy + // only requests such, we validate that the offset in the buffer between the + // start and end indices is the same as the requested size of the copy. + absl::InlinedVector<int32_t, 4> src_end_indices(lengths.size()); + absl::InlinedVector<int32_t, 4> dst_end_indices(lengths.size()); + device_size_t total_length = src->element_size; + for (int i = 0; i < lengths.size(); ++i) { + total_length *= lengths[i]; + src_end_indices[i] = src_start_indices[i] + lengths[i] - 1; + dst_end_indices[i] = dst_start_indices[i] + lengths[i] - 1; + } + + ASSIGN_OR_RETURN(auto src_start_byte_offset, + src->CalculateOffset(src_start_indices)); + ASSIGN_OR_RETURN(auto src_end_byte_offset, + src->CalculateOffset(src_end_indices)); + ASSIGN_OR_RETURN(auto dst_start_byte_offset, + dst->CalculateOffset(dst_start_indices)); + ASSIGN_OR_RETURN(auto dst_end_byte_offset, + dst->CalculateOffset(dst_end_indices)); + + auto src_length = + src_end_byte_offset - src_start_byte_offset + src->element_size; + auto dst_length = + dst_end_byte_offset - dst_start_byte_offset + dst->element_size; + if (src_length != dst_length || src_length != total_length) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Copy for non-contiguous region of memory unimplemented: " + << src->DebugStringShort() << ", dst=" << dst->DebugStringShort() + << ", src_indices=" << PrettyPrint(src_start_indices) + << ", dst_indices=" << PrettyPrint(dst_start_indices) + << ", lengths=" << PrettyPrint(lengths); + } + + RETURN_IF_ERROR(dst->buffer->CopyData(dst_start_byte_offset, + src->buffer.get(), + src_start_byte_offset, total_length)); + + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/buffer_view.h b/hal/buffer_view.h new file mode 100644 index 0000000..ef5bc22 --- /dev/null +++ b/hal/buffer_view.h
@@ -0,0 +1,108 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_BUFFER_VIEW_H_ +#define IREE_HAL_BUFFER_VIEW_H_ + +#include <memory> +#include <ostream> + +#include "base/shape.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +struct BufferView { + // Returns true if the given buffer_views are exactly equal. + static bool Equal(const BufferView& lhs, const BufferView& rhs); + + BufferView() = default; + BufferView(ref_ptr<Buffer> buffer, Shape shape, int8_t element_size) noexcept + : buffer(std::move(buffer)), shape(shape), element_size(element_size) {} + + BufferView(const BufferView& other) noexcept + : buffer(add_ref(other.buffer)), + shape(other.shape), + element_size(other.element_size) {} + BufferView& operator=(const BufferView& other) noexcept { + buffer = add_ref(other.buffer); + shape = other.shape; + element_size = other.element_size; + return *this; + } + BufferView(BufferView&& other) noexcept + : buffer(std::move(other.buffer)), + shape(other.shape), + element_size(other.element_size) {} + BufferView& operator=(BufferView&& other) noexcept { + buffer = std::move(other.buffer); + shape = other.shape; + element_size = other.element_size; + return *this; + } + + // Returns a string useful for printing debug messages. + std::string DebugStringShort() const; + + // Total length of the valid view range in bytes. + device_size_t byte_length() const { + return shape.element_count() * element_size; + } + + // TODO(b/134586626): remove this when byte ranges are encoded in IR. + // Calculates a byte offset into the buffer_view at the given dimension + // indices. + StatusOr<device_size_t> CalculateOffset( + absl::Span<const int32_t> indices) const; + + // TODO(b/134586626): remove this when byte ranges are encoded in IR. + // Returns a view onto the given range of the buffer underlying this view. The + // returned view starts at the offset indicated by |start_indices| and has a + // shape of |lengths|. + // Only contiguous regions of memory are supported at the moment. + StatusOr<BufferView> Slice(absl::Span<const int32_t> start_indices, + absl::Span<const int32_t> lengths) const; + + // TODO(b/134586626): remove this when byte ranges are encoded in IR. + static Status Copy(BufferView* src, + absl::Span<const int32_t> src_start_indices, + BufferView* dst, + absl::Span<const int32_t> dst_start_indices, + absl::Span<const int32_t> lengths); + + ref_ptr<Buffer> buffer; + Shape shape; + int8_t element_size; + // TODO(benvanik): strides. +}; + +inline bool operator==(const BufferView& a, const BufferView& b) { + return BufferView::Equal(a, b); +} + +inline bool operator!=(const BufferView& a, const BufferView& b) { + return !(a == b); +} + +inline std::ostream& operator<<(std::ostream& stream, + const BufferView& buffer_view) { + stream << buffer_view.DebugStringShort(); + return stream; +} + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_BUFFER_VIEW_H_
diff --git a/hal/buffer_view_string_util.cc b/hal/buffer_view_string_util.cc new file mode 100644 index 0000000..130dabf --- /dev/null +++ b/hal/buffer_view_string_util.cc
@@ -0,0 +1,542 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/buffer_view_string_util.h" + +#include <functional> +#include <sstream> +#include <type_traits> + +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" +#include "base/source_location.h" +#include "base/status.h" +#include "hal/heap_buffer.h" + +namespace iree { +namespace hal { + +namespace { + +/* clang-format off */ +constexpr char kHexValue[256] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9' + 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F' + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f' + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +}; +/* clang-format on */ + +template <typename T> +void HexStringToBytes(const char* from, T to, ptrdiff_t num) { + for (int i = 0; i < num; i++) { + to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) + + (kHexValue[from[i * 2 + 1] & 0xFF]); + } +} + +constexpr char kHexTable[513] = + "000102030405060708090a0b0c0d0e0f" + "101112131415161718191a1b1c1d1e1f" + "202122232425262728292a2b2c2d2e2f" + "303132333435363738393a3b3c3d3e3f" + "404142434445464748494a4b4c4d4e4f" + "505152535455565758595a5b5c5d5e5f" + "606162636465666768696a6b6c6d6e6f" + "707172737475767778797a7b7c7d7e7f" + "808182838485868788898a8b8c8d8e8f" + "909192939495969798999a9b9c9d9e9f" + "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf" + "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" + "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" + "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" + "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" + "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"; + +template <typename T> +void BytesToHexString(const unsigned char* src, T dest, ptrdiff_t num) { + auto dest_ptr = &dest[0]; + for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest_ptr += 2) { + const char* hex_p = &kHexTable[*src_ptr * 2]; + std::copy(hex_p, hex_p + 2, dest_ptr); + } +} + +// Returns true if the given type is represented as binary hex data. +bool IsBinaryType(absl::string_view type_str) { + return !type_str.empty() && absl::ascii_isdigit(type_str[0]); +} + +// Parses binary hex data. +Status ParseBinaryData(absl::string_view data_str, Buffer* buffer) { + data_str = absl::StripAsciiWhitespace(data_str); + ASSIGN_OR_RETURN(auto mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite)); + auto contents = mapping.mutable_contents(); + size_t dst_i = 0; + size_t src_i = 0; + while (src_i < data_str.size() && dst_i < contents.size()) { + char c = data_str[src_i]; + if (absl::ascii_isspace(c) || c == ',') { + ++src_i; + continue; + } + if (src_i + 1 >= data_str.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Invalid input hex data (offset=" << src_i << ")"; + } + HexStringToBytes(data_str.data() + src_i, contents.data() + dst_i, 1); + src_i += 2; + ++dst_i; + } + if (dst_i < contents.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Too few elements to fill type; expected " << contents.size() + << " but only read " << dst_i; + } else if (data_str.size() - src_i > 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Input data string contains more elements than the underlying " + "buffer (" + << contents.size() << ")"; + } + return OkStatus(); +} + +// Prints binary hex data. +Status PrintBinaryData(int element_size, Buffer* buffer, size_t max_entries, + std::ostream* stream) { + max_entries *= element_size; // Counting bytes, but treat them as elements. + ASSIGN_OR_RETURN(auto mapping, + buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + auto contents = mapping.contents(); + char hex_buffer[8 * 2]; + for (size_t i = 0; i < std::min(max_entries, mapping.size()); + i += element_size) { + if (i > 0) *stream << " "; + BytesToHexString(contents.data() + i, hex_buffer, element_size); + *stream << hex_buffer; + } + if (mapping.size() > max_entries) *stream << "..."; + return OkStatus(); +} + +template <typename ElementType, typename Enabled = void> +struct SimpleStrToValue { + absl::optional<ElementType> operator()(absl::string_view text) const = delete; +}; + +template <typename IntegerType> +struct SimpleStrToValue< + IntegerType, + typename std::enable_if<(sizeof(IntegerType) < 4), void>::type> { + absl::optional<IntegerType> operator()(absl::string_view text) const { + int32_t value; + return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} + : absl::nullopt; + } +}; + +template <typename IntegerType> +struct SimpleStrToValue< + IntegerType, + typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> { + absl::optional<IntegerType> operator()(absl::string_view text) const { + IntegerType value; + return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} + : absl::nullopt; + } +}; + +template <> +struct SimpleStrToValue<float, void> { + absl::optional<float> operator()(absl::string_view text) const { + float value; + return absl::SimpleAtof(text, &value) ? absl::optional<float>{value} + : absl::nullopt; + } +}; + +template <> +struct SimpleStrToValue<double, void> { + absl::optional<double> operator()(absl::string_view text) const { + double value; + return absl::SimpleAtod(text, &value) ? absl::optional<double>{value} + : absl::nullopt; + } +}; + +template <typename T> +Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start, + size_t token_end, absl::Span<T> contents, + int dst_i) { + if (dst_i >= contents.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Input data string contains more elements than the underlying " + "buffer (" + << contents.size() << ")"; + } + auto element_str = data_str.substr(token_start, token_end - token_start + 1); + auto element = SimpleStrToValue<T>()(element_str); + if (!element.has_value()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unable to parse element " << dst_i << " = '" << element_str + << "'"; + } + contents[dst_i] = element.value(); + return OkStatus(); +} + +template <typename T> +Status ParseNumericalDataAsType(absl::string_view data_str, Buffer* buffer) { + ASSIGN_OR_RETURN(auto mapping, + buffer->MapMemory<T>(MemoryAccess::kDiscardWrite)); + auto contents = mapping.mutable_contents(); + size_t src_i = 0; + size_t dst_i = 0; + size_t token_start = std::string::npos; + while (src_i < data_str.size()) { + char c = data_str[src_i++]; + bool is_separator = + absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; + if (token_start == std::string::npos) { + if (!is_separator) { + token_start = src_i - 1; + } + continue; + } else if (token_start != std::string::npos && !is_separator) { + continue; + } + RETURN_IF_ERROR(ParseNumericalDataElement<T>(data_str, token_start, + src_i - 2, contents, dst_i++)); + token_start = std::string::npos; + } + if (token_start != std::string::npos) { + RETURN_IF_ERROR(ParseNumericalDataElement<T>( + data_str, token_start, data_str.size() - 1, contents, dst_i++)); + } + if (dst_i < contents.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Input data string contains fewer elements than the underlying " + "buffer (expected " + << contents.size() << ")"; + } + return OkStatus(); +} + +// Parses numerical data (ints, floats, etc) in some typed form. +Status ParseNumericalData(absl::string_view type_str, + absl::string_view data_str, Buffer* buffer) { + if (type_str == "i8") { + return ParseNumericalDataAsType<int8_t>(data_str, buffer); + } else if (type_str == "u8") { + return ParseNumericalDataAsType<uint8_t>(data_str, buffer); + } else if (type_str == "i16") { + return ParseNumericalDataAsType<int16_t>(data_str, buffer); + } else if (type_str == "u16") { + return ParseNumericalDataAsType<uint16_t>(data_str, buffer); + } else if (type_str == "i32") { + return ParseNumericalDataAsType<int32_t>(data_str, buffer); + } else if (type_str == "u32") { + return ParseNumericalDataAsType<uint32_t>(data_str, buffer); + } else if (type_str == "i64") { + return ParseNumericalDataAsType<int64_t>(data_str, buffer); + } else if (type_str == "u64") { + return ParseNumericalDataAsType<uint64_t>(data_str, buffer); + } else if (type_str == "f32") { + return ParseNumericalDataAsType<float>(data_str, buffer); + } else if (type_str == "f64") { + return ParseNumericalDataAsType<double>(data_str, buffer); + } else { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unsupported type: " << type_str; + } +} + +template <typename T> +void PrintElementList(const Shape& shape, absl::Span<const T> data, + size_t* max_entries, std::ostream* stream) { + if (shape.empty()) { + // Scalar value. + PrintElementList({1}, data, max_entries, stream); + return; + } else if (shape.size() == 1) { + // Leaf dimension; output data. + size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0])); + *stream << absl::StrJoin(data.subspan(0, max_count), " "); + if (max_count < shape[0]) { + *stream << "..."; + } + *max_entries -= max_count; + } else { + // Nested; recurse into next dimension. + Shape nested_shape = Shape(shape.subspan(1)); + size_t length = nested_shape.element_count(); + size_t offset = 0; + for (int i = 0; i < shape[0]; ++i) { + *stream << "["; + PrintElementList<T>(nested_shape, data.subspan(offset, length), + max_entries, stream); + offset += length; + *stream << "]"; + } + } +} + +template <typename T> +Status PrintNumericalDataAsType(const Shape& shape, Buffer* buffer, + size_t max_entries, std::ostream* stream) { + ASSIGN_OR_RETURN(auto mapping, buffer->MapMemory<T>(MemoryAccess::kRead)); + PrintElementList(shape, mapping.contents(), &max_entries, stream); + return OkStatus(); +} + +// Prints numerical data (ints, floats, etc) from some typed form. +Status PrintNumericalData(const Shape& shape, absl::string_view type_str, + Buffer* buffer, size_t max_entries, + std::ostream* stream) { + if (type_str == "i8") { + return PrintNumericalDataAsType<int8_t>(shape, buffer, max_entries, stream); + } else if (type_str == "u8") { + return PrintNumericalDataAsType<uint8_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "i16") { + return PrintNumericalDataAsType<int16_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "u16") { + return PrintNumericalDataAsType<uint16_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "i32") { + return PrintNumericalDataAsType<int32_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "u32") { + return PrintNumericalDataAsType<uint32_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "i64") { + return PrintNumericalDataAsType<int64_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "u64") { + return PrintNumericalDataAsType<uint64_t>(shape, buffer, max_entries, + stream); + } else if (type_str == "f32") { + return PrintNumericalDataAsType<float>(shape, buffer, max_entries, stream); + } else if (type_str == "f64") { + return PrintNumericalDataAsType<double>(shape, buffer, max_entries, stream); + } else { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unsupported type: " << type_str; + } +} + +} // namespace + +StatusOr<int> GetTypeElementSize(absl::string_view type_str) { + if (type_str.empty()) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "Type is empty"; + } else if (IsBinaryType(type_str)) { + // If the first character is a digit then we are dealign with binary data. + // The type is just the number of bytes per element. + int element_size = 0; + if (!absl::SimpleAtoi(type_str, &element_size)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unable to parse element size type '" << type_str << "'"; + } + return element_size; + } + // We know that our types are single characters followed by bit counts. + // If we start to support other types we may need to do something more clever. + int bit_count = 0; + if (!absl::SimpleAtoi(type_str.substr(1), &bit_count)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unable to parse type bit count from '" << type_str + << "'; expecting something like 'i32'"; + } + return bit_count / 8; +} + +StatusOr<Shape> ParseShape(absl::string_view shape_str) { + std::vector<int> dims; + for (auto dim_str : absl::StrSplit(shape_str, 'x', absl::SkipWhitespace())) { + int dim_value = 0; + if (!absl::SimpleAtoi(dim_str, &dim_value)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Invalid shape dimension '" << dim_str + << "' while parsing shape '" << shape_str << "'"; + } + dims.push_back(dim_value); + } + return Shape{dims}; +} + +StatusOr<BufferView> ParseBufferViewFromString( + absl::string_view buffer_view_str, hal::Allocator* allocator) { + // Strip whitespace that may come along (linefeeds/etc). + buffer_view_str = absl::StripAsciiWhitespace(buffer_view_str); + if (buffer_view_str.empty()) { + // Empty lines denote empty buffer_views. + return BufferView{}; + } + + // Split into the components we can work with: shape, type, and data. + absl::string_view shape_and_type_str; + absl::string_view data_str; + auto equal_index = buffer_view_str.find('='); + if (equal_index == std::string::npos) { + // Treat a lack of = as defaulting the data to zeros. + shape_and_type_str = buffer_view_str; + } else { + shape_and_type_str = buffer_view_str.substr(0, equal_index); + data_str = buffer_view_str.substr(equal_index + 1); + } + absl::string_view shape_str; + absl::string_view type_str; + auto last_x_index = shape_and_type_str.rfind('x'); + if (last_x_index == std::string::npos) { + // Scalar. + type_str = shape_and_type_str; + } else { + // Has a shape. + shape_str = shape_and_type_str.substr(0, last_x_index); + type_str = shape_and_type_str.substr(last_x_index + 1); + } + + // Populate BufferView metadata required for allocation. + BufferView result; + ASSIGN_OR_RETURN(result.element_size, GetTypeElementSize(type_str)); + ASSIGN_OR_RETURN(result.shape, ParseShape(shape_str)); + + // Allocate the host buffer. + size_t allocation_size = result.shape.element_count() * result.element_size; + if (allocator) { + ASSIGN_OR_RETURN( + result.buffer, + allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible, + BufferUsage::kAll | BufferUsage::kConstant, + allocation_size)); + } else { + result.buffer = HeapBuffer::Allocate( + MemoryType::kHostLocal, BufferUsage::kAll | BufferUsage::kConstant, + allocation_size); + } + + if (!data_str.empty()) { + // Parse the data from the string right into the buffer. + if (IsBinaryType(type_str)) { + // Parse as binary hex. + RETURN_IF_ERROR(ParseBinaryData(data_str, result.buffer.get())); + } else { + // Parse as some nicely formatted type. + RETURN_IF_ERROR( + ParseNumericalData(type_str, data_str, result.buffer.get())); + } + } + + return result; +} + +StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str) { + char str_char = str.empty() ? '?' : str[0]; + switch (str_char) { + case 'b': + return BufferViewPrintMode::kBinary; + case 'i': + return BufferViewPrintMode::kSignedInteger; + case 'u': + return BufferViewPrintMode::kUnsignedInteger; + case 'f': + return BufferViewPrintMode::kFloatingPoint; + default: + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unsupported output type '" << str << "'"; + } +} + +StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view, + BufferViewPrintMode print_mode, + size_t max_entries) { + std::string result; + RETURN_IF_ERROR( + PrintBufferViewToString(buffer_view, print_mode, max_entries, &result)); + return result; +} + +Status PrintBufferViewToString(const BufferView& buffer_view, + BufferViewPrintMode print_mode, + size_t max_entries, std::string* out_result) { + std::ostringstream stream; + RETURN_IF_ERROR( + PrintBufferViewToStream(buffer_view, print_mode, max_entries, &stream)); + *out_result = stream.str(); + return OkStatus(); +} + +Status PrintBufferViewToStream(const BufferView& buffer_view, + BufferViewPrintMode print_mode, + size_t max_entries, std::ostream* stream) { + if (!buffer_view.buffer) { + // No buffer means the buffer_view is empty. We use the empty string to + // denote this (as we have no useful information). + return OkStatus(); + } + + // Pick a type based on the element size and the printing mode. + std::string type_str; + switch (print_mode) { + case BufferViewPrintMode::kBinary: + type_str = std::to_string(buffer_view.element_size); + break; + case BufferViewPrintMode::kSignedInteger: + absl::StrAppend(&type_str, "i", buffer_view.element_size * 8); + break; + case BufferViewPrintMode::kUnsignedInteger: + absl::StrAppend(&type_str, "u", buffer_view.element_size * 8); + break; + case BufferViewPrintMode::kFloatingPoint: + absl::StrAppend(&type_str, "f", buffer_view.element_size * 8); + break; + } + + // [shape]x[type]= prefix (taking into account scalar values). + *stream << absl::StrJoin(buffer_view.shape.begin(), buffer_view.shape.end(), + "x"); + if (!buffer_view.shape.empty()) *stream << "x"; + *stream << type_str; + *stream << "="; + + if (IsBinaryType(type_str)) { + return PrintBinaryData(buffer_view.element_size, buffer_view.buffer.get(), + max_entries, stream); + } else { + return PrintNumericalData(buffer_view.shape, type_str, + buffer_view.buffer.get(), max_entries, stream); + } +} + +} // namespace hal +} // namespace iree
diff --git a/hal/buffer_view_string_util.h b/hal/buffer_view_string_util.h new file mode 100644 index 0000000..0864ddb --- /dev/null +++ b/hal/buffer_view_string_util.h
@@ -0,0 +1,95 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities for working with BufferView data, mostly useful for testing. +// These functions allow for conversion between types, parsing and printing, and +// basic comparisons. +// +// The canonical BufferView string format is: +// [shape]x[type]=value,value,... +// For example: +// 2x2xi32=0,1,2,3 +// Characters like [] are optional and will be ignored during parsing: +// 2x2xi32=[[0 1][2 3]] +// +// The type may be one of the following: +// * 1/2/4/8 = 1/2/4/8 byte elements in binary hex format. +// * i8/u8 = signed/unsigned 8-bit integers. +// * i16/u16 = signed/unsigned 16-bit integers. +// * i32/u32 = signed/unsigned 32-bit integers. +// * i64/u64 = signed/unsigned 64-bit integers. +// * f32 = 32-bit floating-point number. +// * f64 = 64-bit floating-point number. + +#ifndef IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_ +#define IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_ + +#include <ostream> +#include <string> + +#include "absl/strings/string_view.h" +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/buffer_view.h" + +namespace iree { +namespace hal { + +// Returns the size, in bytes, of the given type. +StatusOr<int> GetTypeElementSize(absl::string_view type_str); + +// Returns a Shape parsed from the given NxMx... string. +StatusOr<Shape> ParseShape(absl::string_view shape_str); + +// Parses a BufferView encoded in a string. +// If an |allocator| is provided the buffer will be allocated as host-local and +// device-visible. Otherwise, buffers will be host-local. +// The format accepted matches that produced by PrintBufferViewToString. +StatusOr<BufferView> ParseBufferViewFromString( + absl::string_view buffer_view_str, hal::Allocator* allocator = nullptr); + +// Defines how the elements within a BufferView are interpreted during printing. +enum class BufferViewPrintMode { + // Interpret the data as if it were serialized bytes. + // In this mode no conversion is performed and the bytes in memory are printed + // as hex in groupings based on the element size. Shortened to 'b'. + kBinary, + // Interpret elements as signed integers; shortened to 'i'. + kSignedInteger, + // Interpret elements as unsigned integers; shortened to 'u'. + kUnsignedInteger, + // Interpret elements as floating-point values; shortened to 'f'. + kFloatingPoint, +}; + +// Returns the BufferViewPrintMode based on the shortened char in |str|. +StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str); + +// Prints a BufferView to a string encoded in the canonical format. +StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view, + BufferViewPrintMode print_mode, + size_t max_entries); +Status PrintBufferViewToString(const BufferView& buffer_view, + BufferViewPrintMode print_mode, + size_t max_entries, std::string* out_result); + +// Prints a BufferView to a string stream encoded in the canonical format. +Status PrintBufferViewToStream(const BufferView& buffer_view, + BufferViewPrintMode print_mode, + size_t max_entries, std::ostream* stream); + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
diff --git a/hal/buffer_view_string_util_test.cc b/hal/buffer_view_string_util_test.cc new file mode 100644 index 0000000..0cf39ba --- /dev/null +++ b/hal/buffer_view_string_util_test.cc
@@ -0,0 +1,186 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/buffer_view_string_util.h" + +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace hal { +namespace { + +using ::testing::ElementsAre; + +template <typename T> +StatusOr<std::vector<T>> ReadBuffer(const ref_ptr<Buffer>& buffer) { + std::vector<T> result; + result.resize(buffer->byte_length() / sizeof(T)); + RETURN_IF_ERROR( + buffer->ReadData(0, result.data(), result.size() * sizeof(T))); + return result; +} + +TEST(BufferViewUtilTest, GetTypeElementSize) { + EXPECT_EQ(1, GetTypeElementSize("1").ValueOrDie()); + EXPECT_EQ(7, GetTypeElementSize("7").ValueOrDie()); + EXPECT_EQ(4, GetTypeElementSize("i32").ValueOrDie()); + EXPECT_EQ(8, GetTypeElementSize("f64").ValueOrDie()); + + EXPECT_FALSE(GetTypeElementSize("").ok()); + EXPECT_FALSE(GetTypeElementSize(" ").ok()); + EXPECT_FALSE(GetTypeElementSize("a").ok()); + EXPECT_FALSE(GetTypeElementSize("ib").ok()); + EXPECT_FALSE(GetTypeElementSize("i").ok()); + EXPECT_FALSE(GetTypeElementSize("i543ff").ok()); +} + +TEST(BufferViewUtilTest, ParseShape) { + EXPECT_EQ((Shape{}), ParseShape("").ValueOrDie()); + EXPECT_EQ((Shape{1}), ParseShape("1").ValueOrDie()); + EXPECT_EQ((Shape{1, 2}), ParseShape("1x2").ValueOrDie()); + EXPECT_EQ((Shape{1, 2}), ParseShape(" 1 x 2 ").ValueOrDie()); + + EXPECT_FALSE(ParseShape("abc").ok()); + EXPECT_FALSE(ParseShape("1xf").ok()); + EXPECT_FALSE(ParseShape("1xff23").ok()); +} + +TEST(BufferViewUtilTest, ParseBufferViewFromStringEmpty) { + // Empty string = empty buffer_view. + ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString("")); + EXPECT_EQ(nullptr, m0.buffer.get()); + EXPECT_EQ(Shape{}, m0.shape); + EXPECT_EQ(0, m0.element_size); + + // No = means no data. + ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x2xf32")); + EXPECT_EQ(4 * 2 * 4, m1.buffer->allocation_size()); + EXPECT_EQ(Shape({4, 2}), m1.shape); + EXPECT_EQ(4, m1.element_size); + EXPECT_THAT(ReadBuffer<float>(m1.buffer).ValueOrDie(), + ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); + + // No data after = means no data. + ASSERT_OK_AND_ASSIGN(auto m2, ParseBufferViewFromString("4x2xf32=")); + EXPECT_EQ(4 * 2 * 4, m2.buffer->allocation_size()); + EXPECT_EQ(Shape({4, 2}), m2.shape); + EXPECT_EQ(4, m2.element_size); + EXPECT_THAT(ReadBuffer<float>(m2.buffer).ValueOrDie(), + ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); +} + +TEST(BufferViewUtilTest, ParseBufferViewFromStringBinary) { + ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString("4x1=00 01 02 03")); + EXPECT_EQ(Shape({4}), m0.shape); + EXPECT_EQ(1, m0.element_size); + EXPECT_THAT(ReadBuffer<uint8_t>(m0.buffer).ValueOrDie(), + ElementsAre(0, 1, 2, 3)); + + // Whitespace shouldn't matter. + ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x1=00,010203")); + EXPECT_EQ(Shape({4}), m1.shape); + EXPECT_EQ(1, m1.element_size); + EXPECT_THAT(ReadBuffer<uint8_t>(m1.buffer).ValueOrDie(), + ElementsAre(0, 1, 2, 3)); + + // Should fail on malformed hex bytes. + EXPECT_FALSE(ParseBufferViewFromString("4x1=1").ok()); + EXPECT_FALSE(ParseBufferViewFromString("4x1=00003").ok()); + EXPECT_FALSE(ParseBufferViewFromString("4x1=%0123%\1").ok()); + EXPECT_FALSE(ParseBufferViewFromString("4x1=00010203040506").ok()); +} + +TEST(BufferViewUtilTest, ParseBufferViewFromStringAllowBrackets) { + ASSERT_OK_AND_ASSIGN(auto m0, + ParseBufferViewFromString("4xi16=[[0][ 1 ][2]][3]")); + EXPECT_EQ(Shape({4}), m0.shape); + EXPECT_EQ(2, m0.element_size); + EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(), + ElementsAre(0, 1, 2, 3)); +} + +TEST(BufferViewUtilTest, ParseBufferViewFromStringInteger) { + // Signed int16. + ASSERT_OK_AND_ASSIGN(auto m0, + ParseBufferViewFromString("4xi16=0 12345 65535 -2")); + EXPECT_EQ(Shape({4}), m0.shape); + EXPECT_EQ(2, m0.element_size); + EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(), + ElementsAre(0, 12345, -1, -2)); + + // Unsigned int16. + ASSERT_OK_AND_ASSIGN(auto m1, + ParseBufferViewFromString("4xu16=0 12345 65535 -2")); + EXPECT_EQ(Shape({4}), m1.shape); + EXPECT_EQ(2, m1.element_size); + EXPECT_THAT(ReadBuffer<uint16_t>(m1.buffer).ValueOrDie(), + ElementsAre(0, 12345, 65535, 65534)); + + // Mixing separator types is ok. + ASSERT_OK_AND_ASSIGN(auto m2, + ParseBufferViewFromString("4xu16=0, 12345, 65535, -2")); + EXPECT_EQ(Shape({4}), m2.shape); + EXPECT_EQ(2, m2.element_size); + EXPECT_THAT(ReadBuffer<uint16_t>(m2.buffer).ValueOrDie(), + ElementsAre(0, 12345, 65535, 65534)); + + // Should fail on malformed integers bytes and out of bounds values. + EXPECT_FALSE(ParseBufferViewFromString("4xi32=asodfj").ok()); + EXPECT_FALSE(ParseBufferViewFromString("4xi32=0 1 2 3 4").ok()); +} + +TEST(BufferViewUtilTest, ParseBufferViewFromStringFloat) { + // Float. + ASSERT_OK_AND_ASSIGN(auto m0, + ParseBufferViewFromString("4xf32=0 1.0 1234 -2.0e-5")); + EXPECT_EQ(Shape({4}), m0.shape); + EXPECT_EQ(4, m0.element_size); + EXPECT_THAT(ReadBuffer<float>(m0.buffer).ValueOrDie(), + ElementsAre(0.0f, 1.0f, 1234.0f, -2.0e-5f)); + + // Double. + ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString( + "4xf64=0 1.0 123456789012345 -2.0e-5")); + EXPECT_EQ(Shape({4}), m1.shape); + EXPECT_EQ(8, m1.element_size); + EXPECT_THAT(ReadBuffer<double>(m1.buffer).ValueOrDie(), + ElementsAre(0.0, 1.0, 123456789012345.0, -2.0e-5)); + + // Should fail on malformed floats and out of bounds values. + EXPECT_FALSE(ParseBufferViewFromString("4xf32=asodfj").ok()); + EXPECT_FALSE(ParseBufferViewFromString("4xf32=0").ok()); + EXPECT_FALSE(ParseBufferViewFromString("4xf32=0 1 2 3 4").ok()); +} + +TEST(BufferViewUtilTest, ParseBufferViewPrintMode) { + EXPECT_EQ(BufferViewPrintMode::kBinary, + ParseBufferViewPrintMode("b").ValueOrDie()); + EXPECT_EQ(BufferViewPrintMode::kSignedInteger, + ParseBufferViewPrintMode("i").ValueOrDie()); + EXPECT_EQ(BufferViewPrintMode::kUnsignedInteger, + ParseBufferViewPrintMode("u").ValueOrDie()); + EXPECT_EQ(BufferViewPrintMode::kFloatingPoint, + ParseBufferViewPrintMode("f").ValueOrDie()); + + EXPECT_FALSE(ParseBufferViewPrintMode("").ok()); + EXPECT_FALSE(ParseBufferViewPrintMode("s").ok()); + EXPECT_FALSE(ParseBufferViewPrintMode("asdfasdf").ok()); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/buffer_view_test.cc b/hal/buffer_view_test.cc new file mode 100644 index 0000000..efc9699 --- /dev/null +++ b/hal/buffer_view_test.cc
@@ -0,0 +1,285 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/buffer_view.h" + +#include <numeric> +#include <vector> + +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/buffer.h" +#include "hal/heap_buffer.h" + +namespace iree { +namespace hal { +namespace { + +template <typename T> +BufferView MakeView(const std::vector<T> src_data, Shape shape) { + auto parent_buffer = HeapBuffer::AllocateCopy( + BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data)); + + return BufferView(std::move(parent_buffer), shape, sizeof(T)); +} + +template <typename T> +std::vector<T> ReadData(BufferView view) { + std::vector<T> data(view.shape.element_count()); + EXPECT_OK(view.buffer->ReadData(0, data.data(), data.size() * sizeof(T))); + return data; +} + +TEST(BufferViewTest, SliceWholeBuffer) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {0, 0}; + std::vector<int32_t> lengths = {2, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_TRUE(BufferView::Equal(parent_view, slice)) + << "original parent_view " << parent_view.DebugStringShort() + << " and whole slice " << slice.DebugStringShort() << " are not equal"; +} + +TEST(BufferViewTest, SliceSingleRow) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({2, 3})); +} + +TEST(BufferViewTest, SliceRowStart) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7}; + Shape shape = {2, 4}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {1, 3}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({4, 5, 6})); +} + +TEST(BufferViewTest, SliceRowEnd) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7}; + Shape shape = {2, 4}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 1}; + std::vector<int32_t> lengths = {1, 3}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6, 7})); +} + +TEST(BufferViewTest, SliceRowMiddle) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7}; + Shape shape = {2, 4}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 1}; + std::vector<int32_t> lengths = {1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6})); +} + +TEST(BufferViewTest, SliceMultiRow) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + Shape shape = {3, 3}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {2, 3}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({3, 4, 5, 6, 7, 8})); +} + +TEST(BufferViewTest, SliceHighRank) { + std::vector<uint8_t> src_data(81); + std::iota(src_data.begin(), src_data.end(), 0); + Shape shape = {3, 3, 3, 3}; + + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 2, 2, 1}; + std::vector<int32_t> lengths = {1, 1, 1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({52, 53})); +} + +TEST(BufferViewTest, SliceModifySlice) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_OK(slice.buffer->Fill8(0, kWholeBuffer, 0xFFu)); + + auto parent_data = ReadData<uint8_t>(parent_view); + EXPECT_EQ(parent_data, std::vector<uint8_t>({0, 1, 0xFFu, 0xFFu})); +} + +TEST(BufferViewTest, SliceModifyParent) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_OK(parent_view.buffer->Fill8(0, kWholeBuffer, 0xFFu)); + + EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({0xFFu, 0xFFu})); +} + +TEST(BufferViewTest, SliceMultiByteElementWholeBuffer) { + const std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3}; + + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {0, 0}; + std::vector<int32_t> lengths = {2, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_TRUE(BufferView::Equal(parent_view, slice)) + << "original parent_view " << parent_view.DebugStringShort() + << " and whole slice " << slice.DebugStringShort() << " are not equal"; +} + +TEST(BufferViewTest, SliceShapeAndElementSize) { + std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + EXPECT_EQ(slice.shape, Shape(lengths)); + EXPECT_EQ(slice.element_size, 4); +} + +TEST(BufferViewTest, SliceMultiByteElement) { + std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {1, 2}; + ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); + + EXPECT_EQ(ReadData<int32_t>(slice), std::vector<int32_t>({2, 3})); +} + +TEST(BufferViewTest, SliceIndexBadRank) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {0}; + std::vector<int32_t> lengths = {2}; + EXPECT_TRUE( + IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); +} + +TEST(BufferViewTest, SliceIndexLengthMismatch) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {0, 0}; + std::vector<int32_t> lengths = {2}; + EXPECT_TRUE( + IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); +} + +TEST(BufferViewTest, SliceIndicesOutOfBounds) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {0, 3}; + std::vector<int32_t> lengths = {1, 1}; + EXPECT_TRUE( + IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); +} + +TEST(BufferViewTest, SliceLengthsOutOfBounds) { + std::vector<uint8_t> src_data = {0, 1, 2, 3}; + Shape shape = {2, 2}; + + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {0, 0}; + std::vector<int32_t> lengths = {1, 3}; + EXPECT_TRUE( + IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); +} + +TEST(BufferViewTest, SliceNonContiguous) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + Shape shape = {3, 3}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 1}; + std::vector<int32_t> lengths = {2, 2}; + EXPECT_TRUE( + IsUnimplemented(parent_view.Slice(start_indices, lengths).status())); +} + +TEST(BufferViewTest, SliceNonContiguousMultiRowLeft) { + std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + Shape shape = {3, 3}; + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0}; + std::vector<int32_t> lengths = {2, 1}; + EXPECT_TRUE( + IsUnimplemented(parent_view.Slice(start_indices, lengths).status())); +} + +TEST(BufferViewTest, SliceHighRankNonContiguous) { + std::vector<uint8_t> src_data(81); + std::iota(src_data.begin(), src_data.end(), 0); + Shape shape = {3, 3, 3, 3}; + + auto parent_view = MakeView(src_data, shape); + + std::vector<int32_t> start_indices = {1, 0, 2, 1}; + std::vector<int32_t> lengths = {1, 2, 1, 2}; + EXPECT_TRUE( + IsUnimplemented(parent_view.Slice(start_indices, lengths).status())); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/command_buffer.cc b/hal/command_buffer.cc new file mode 100644 index 0000000..d634c75 --- /dev/null +++ b/hal/command_buffer.cc
@@ -0,0 +1,29 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/command_buffer.h" + +namespace iree { +namespace hal { + +std::string CommandCategoryString(CommandCategoryBitfield categories) { + return FormatBitfieldValue(categories, + { + {CommandCategory::kTransfer, "kTransfer"}, + {CommandCategory::kDispatch, "kDispatch"}, + }); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/command_buffer.h b/hal/command_buffer.h new file mode 100644 index 0000000..df9b40f --- /dev/null +++ b/hal/command_buffer.h
@@ -0,0 +1,383 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_COMMAND_BUFFER_H_ +#define IREE_HAL_COMMAND_BUFFER_H_ + +#include <cstdint> + +#include "base/bitfield.h" +#include "base/shape.h" +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/buffer.h" +#include "hal/buffer_view.h" +#include "hal/event.h" +#include "hal/executable.h" +#include "hal/resource.h" + +namespace iree { +namespace hal { + +// A bitfield specifying the mode of operation for a command buffer. +enum class CommandBufferMode : uint32_t { + // Command buffer will be submitted once and never used again. + // This may enable in-place patching of command buffers that reduce overhead + // when it's known that command buffers will not be reused. + kOneShot = 1 << 0, +}; +IREE_BITFIELD(CommandBufferMode); +using CommandBufferModeBitfield = CommandBufferMode; +std::string CommandBufferModeString(CommandBufferModeBitfield mode); + +// A bitfield specifying the category of commands in a command queue. +enum class CommandCategory : uint32_t { + // Command is considered a transfer operation (memcpy, etc). + kTransfer = 1 << 0, + // Command is considered a dispatch operation (dispatch/execute). + kDispatch = 1 << 1, +}; +IREE_BITFIELD(CommandCategory); +using CommandCategoryBitfield = CommandCategory; +std::string CommandCategoryString(CommandCategoryBitfield categories); + +// Bitfield specifying which execution stage a brarrier should start/end at. +// +// Maps to VkPipelineStageFlagBits. +enum class ExecutionStage : uint32_t { + // Top of the pipeline when commands are initially issued by the device. + kCommandIssue = 1 << 0, + // Stage of the pipeline when dispatch parameter data is consumed. + kCommandProcess = 1 << 1, + // Stage where dispatch commands execute. + kDispatch = 1 << 2, + // Stage where transfer (copy/clear/fill/etc) commands execute. + kTransfer = 1 << 3, + // Final stage in the pipeline when commands are retired on the device. + kCommandRetire = 1 << 4, + // Pseudo-stage for read/writes by the host. Not executed on device. + kHost = 1 << 5, +}; +IREE_BITFIELD(ExecutionStage); +using ExecutionStageBitfield = ExecutionStage; + +// Bitfield specifying which scopes will access memory and how. +// +// Maps to VkAccessFlagBits. +enum class AccessScope : uint32_t { + // Read access to indirect command data as part of an indirect dispatch. + kIndirectCommandRead = 1 << 0, + // Constant uniform buffer reads by the device. + kConstantRead = 1 << 1, + // Storage buffer reads by dispatch commands. + kDispatchRead = 1 << 2, + // Storage buffer writes by dispatch commands. + kDispatchWrite = 1 << 3, + // Source of a transfer operation. + kTransferRead = 1 << 4, + // Target of a transfer operation. + kTransferWrite = 1 << 5, + // Read operation by the host through mapped memory. + kHostRead = 1 << 6, + // Write operation by the host through mapped memory. + kHostWrite = 1 << 7, + // External/non-specific read. + kMemoryRead = 1 << 8, + // External/non-specific write. + kMemoryWrite = 1 << 9, +}; +IREE_BITFIELD(AccessScope); +using AccessScopeBitfield = AccessScope; + +// Defines a global memory barrier. +// These are cheaper to encode than buffer-specific barriers but may cause +// stalls and bubbles in device pipelines if applied too broadly. Prefer them +// over equivalently large sets of buffer-specific barriers (such as when +// completely changing execution contexts). +// +// Maps to VkMemoryBarrier. +struct MemoryBarrier { + // All access scopes prior-to the barrier (inclusive). + AccessScopeBitfield source_scope; + // All access scopes following the barrier (inclusive). + AccessScopeBitfield target_scope; +}; + +// Defines a memory barrier that applies to a range of a specific buffer. +// Use of these (vs. global memory barriers) provides fine-grained execution +// ordering to device command processors and allows for more aggressive +// reordering. +// +// Maps to VkBufferMemoryBarrier. +struct BufferBarrier { + // All access scopes prior-to the barrier (inclusive). + AccessScopeBitfield source_scope; + // All access scopes following the barrier (inclusive). + AccessScopeBitfield target_scope; + // Buffer the barrier is restricted to. + // The barrier will apply to the entire physical device allocation. + Buffer* buffer = nullptr; + // Relative offset/length within |buffer| (which may itself be mapped into the + // device allocation at an offset). + device_size_t offset = 0; + device_size_t length = kWholeBuffer; +}; + +// Represents a binding to a buffer with a set of attributes. +// This may be used by drivers to validate alignment. +struct BufferBinding { + // Access rights of the buffer contents by the executable. + MemoryAccessBitfield access = MemoryAccess::kAll; + + // The buffer this binding references. + // The buffer is not retained by the binding and must be kept alive externally + // for the duration it is in use by the queue. + Buffer* buffer = nullptr; + + // Shape of the buffer contents. + Shape shape; + + // Size of each element within the buffer, in bytes. + int8_t element_size = 0; + + BufferBinding() = default; + BufferBinding(MemoryAccessBitfield access, Buffer* buffer) + : access(access), buffer(buffer) {} + BufferBinding(MemoryAccessBitfield access, Buffer* buffer, Shape shape, + int8_t element_size) + : access(access), + buffer(buffer), + shape(shape), + element_size(element_size) {} + BufferBinding(MemoryAccessBitfield access, const BufferView& buffer_view) + : access(access), + buffer(buffer_view.buffer.get()), + shape(buffer_view.shape), + element_size(buffer_view.element_size) {} +}; + +// Wraps parameters for a Dispatch request. +struct DispatchRequest { + // Executable prepared for use on the device. + // The executable must remain alive until all in-flight dispatch requests + // that use it have completed. + Executable* executable = nullptr; + + // Executable entry point ordinal. + int entry_point = 0; + + // TODO(benvanik): predication. + + // Static workload parameters defining the X, Y, and Z workgroup counts. + std::array<int32_t, 3> workload; + + // An optional buffer containing the dynamic workload to dispatch. + // The contents need not be available at the time of recording but must be + // made visible prior to execution of the dispatch command. + // + // Buffer contents are expected to be 3 int32 values defining the X, Y, and Z + // workgroup counts. + // + // The buffer must have been allocated with BufferUsage::kDispatch and be + // of MemoryType::kDeviceVisible. + Buffer* workload_buffer = nullptr; + + // A list of buffers that contain the execution inputs/outputs. + // Order is dependent on executable arg layout. + // + // Buffers must have been allocated with BufferUsage::kDispatch and be + // of MemoryType::kDeviceVisible. + absl::Span<const BufferBinding> bindings; + + // TODO(benvanik): push-constant equivalent (uniforms, etc). +}; + +// Asynchronous command buffer recording interface. +// Commands are recorded by the implementation for later submission to command +// queues. +// +// Buffers and synchronization objects referenced must remain valid and not be +// modified or read while there are commands in-flight. The usual flow is to +// populate input buffers, Dispatch using those buffers, wait on a Fence until +// the buffers are guaranteed to no longer be in use, and then reuse or release +// the buffers. +// +// Errors that can be recognized when operations are enqueued will be returned +// immediately, such as invalid argument errors. Errors that can only be +// determined at execution time will be returned on fences. Once a failure +// occurs the device queue will enter an error state that invalidates all +// operations on the device queue (as ordering is not strict and any may still +// be in-flight). In this case the user of the device queue should treat all +// in-flight operations as cancelled and fully reset themselves. Other device +// queues that may be waiting on events from the device queue will also enter +// error states. Only once a user has acknowledged and cleared the error state +// with a Reset the queue will become usable, and otherwise all operations will +// return errors. +// +// Command buffers are thread-compatible. Use multiple command buffers if trying +// to record commands from multiple threads. Command buffers must not be mutated +// between when they have are submitted for execution on a queue and when the +// fence fires indicating the completion of their execution. +class CommandBuffer : public Resource { + public: + virtual CommandBuffer* impl() { return this; } + + // Device allocator that commands encoded into the buffer share compatibility + // with. + Allocator* allocator() const { return allocator_; } + + // Command buffer operation mode. + CommandBufferModeBitfield mode() const { return mode_; } + + // Command categories that may be recorded into the buffer. + CommandCategoryBitfield command_categories() const { + return command_categories_; + } + + // True if the command buffer is between a Begin/End recording block. + virtual bool is_recording() const = 0; + + // Resets and begins recording into the command buffer, clearing all + // previously recorded contents. + // The command buffer must not be in-flight. + virtual Status Begin() = 0; + + // Ends recording into the command buffer. + // This must be called prior to submitting the command buffer for execution. + virtual Status End() = 0; + + // TODO(benvanik): annotations for debugging and tracing: + // enter/exit + // stack frame manipulation + // explicit timers? or profiling buffer? + + // TODO(b/138719910): cross-queue and external acquire/release. + // virtual Status AcquireBuffer() = 0; + // virtual Status ReleaseBuffer() = 0; + + // Defines a memory dependency between commands recorded before and after the + // barrier. One or more memory or buffer barriers can be specified to indicate + // between which stages or buffers the dependencies exist. + virtual Status ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) = 0; + + // Sets an event to the signaled state. + // |source_stage_mask| specifies when the event is signaled. + // + // Events are only valid within a single command buffer. Events can only be + // used on non-transfer queues. + virtual Status SignalEvent(Event* event, + ExecutionStageBitfield source_stage_mask) = 0; + + // Resets an event to the non-signaled state. + // |source_stage_mask| specifies when the event is unsignaled. + // + // Events are only valid within a single command buffer. Events can only be + // used on non-transfer queues. + virtual Status ResetEvent(Event* event, + ExecutionStageBitfield source_stage_mask) = 0; + + // Waits for one or more events to be signaled and defines a memory dependency + // between the synchronization scope of the signal operations and the commands + // following the wait. + // + // |source_stage_mask| must include ExecutionStage::kHost for Event::Signal to + // be visibile. + // + // Events are only valid within a single command buffer. Events remain + // signaled even after waiting and must be reset to be reused. Events can only + // be used on non-transfer queues. + virtual Status WaitEvents( + absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) = 0; + + // Fills the target buffer with the given repeating value. + // Expects that value_length is one of 1, 2, or 4 and that the offset and + // length are aligned to the natural alignment of the value. + // The target buffer must be compatible with the devices owned by this + // device queue and be allocated with BufferUsage::kTransfer. + virtual Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, + device_size_t length, const void* pattern, + size_t pattern_length) = 0; + + // Hints to the device queue that the given buffer will not be used again. + // After encoding a discard the buffer contents will be considered undefined. + // This is because the discard may be used to elide write backs to host memory + // or aggressively reuse the allocation for other purposes. + // + // For buffers allocated with MemoryType::kTransient this may allow + // the device queue to reclaim the memory used by the buffer earlier than + // otherwise possible. + virtual Status DiscardBuffer(Buffer* buffer) = 0; + + // Updates a range of the given target buffer from the source host memory. + // The source host memory is copied immediately into the command buffer and + // occupies command buffer space. It is strongly recommended that large buffer + // updates are performed via CopyBuffer where there is the possibility of a + // zero-copy path. + // The |source_buffer| may be releaed by the caller immediately after this + // call returns. + // The |target_buffer| must be compatible with the devices owned by this + // device queue and be allocated with BufferUsage::kTransfer. + virtual Status UpdateBuffer(const void* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) = 0; + + // Copies a range of one buffer to another. + // Both buffers must be compatible with the devices owned by this device + // queue and be allocated with BufferUsage::kTransfer. Though the source and + // target buffer may be the same the ranges must not overlap (as with memcpy). + // + // This can be used to perform device->host, host->device, and device->device + // copies. + virtual Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) = 0; + + // Dispatches an execution request. + // The request may execute overlapped with any other transfer operation or + // dispatch made within the same barrier-defined sequence. + // + // The executable specified must be registered for use with the device driver + // owning this queue. It must not be unregistered until all requests that use + // it have completed. + // + // Fails if the queue does not support dispatch operations (as indicated by + // can_dispatch). + virtual Status Dispatch(const DispatchRequest& dispatch_request) = 0; + + protected: + CommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) + : allocator_(allocator), + mode_(mode), + command_categories_(command_categories) {} + + private: + Allocator* const allocator_; + const CommandBufferModeBitfield mode_; + const CommandCategoryBitfield command_categories_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_COMMAND_BUFFER_H_
diff --git a/hal/command_buffer_validation.cc b/hal/command_buffer_validation.cc new file mode 100644 index 0000000..9d86b82 --- /dev/null +++ b/hal/command_buffer_validation.cc
@@ -0,0 +1,403 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/command_buffer_validation.h" + +#include "base/logging.h" +#include "base/status.h" + +namespace iree { +namespace hal { + +namespace { + +// Command buffer validation shim. +// Wraps an existing command buffer to provide in-depth validation during +// recording. This should be enabled whenever the command buffer is being driven +// by unsafe code or when early and readable diagnostics are needed. +class ValidatingCommandBuffer : public CommandBuffer { + public: + explicit ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl); + ~ValidatingCommandBuffer() override; + + CommandBuffer* impl() override { return impl_.get(); } + + bool is_recording() const override; + + Status Begin() override; + Status End() override; + + Status ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + Status SignalEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + Status ResetEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + Status WaitEvents(absl::Span<Event*> events, + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, + device_size_t length, const void* pattern, + size_t pattern_length) override; + Status DiscardBuffer(Buffer* buffer) override; + Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + Status Dispatch(const DispatchRequest& dispatch_request) override; + + private: + // Returns a failure if the queue does not support the given caps. + Status ValidateCategories(CommandCategoryBitfield required_categories) const; + // Returns a failure if the memory type the buffer was allocated from is not + // compatible with the given type. + Status ValidateCompatibleMemoryType(Buffer* buffer, + MemoryTypeBitfield memory_type) const; + // Returns a failure if the buffer memory type or usage disallows the given + // access type. + Status ValidateAccess(Buffer* buffer, + MemoryAccessBitfield memory_access) const; + // Returns a failure if the buffer was not allocated for the given usage. + Status ValidateUsage(Buffer* buffer, BufferUsageBitfield usage) const; + // Validates that the range provided is within the given buffer. + Status ValidateRange(Buffer* buffer, device_size_t byte_offset, + device_size_t byte_length) const; + + ref_ptr<CommandBuffer> impl_; +}; + +ValidatingCommandBuffer::ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl) + : CommandBuffer(impl->allocator(), impl->mode(), + impl->command_categories()), + impl_(std::move(impl)) {} + +ValidatingCommandBuffer::~ValidatingCommandBuffer() = default; + +bool ValidatingCommandBuffer::is_recording() const { + return impl_->is_recording(); +} + +Status ValidatingCommandBuffer::Begin() { + DVLOG(3) << "CommandBuffer::Begin()"; + if (impl_->is_recording()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Command buffer is already recording"; + } + return impl_->Begin(); +} + +Status ValidatingCommandBuffer::End() { + DVLOG(3) << "CommandBuffer::End()"; + if (!impl_->is_recording()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Command buffer is not recording"; + } + return impl_->End(); +} + +Status ValidatingCommandBuffer::ValidateCategories( + CommandCategoryBitfield required_categories) const { + if (!AllBitsSet(command_categories(), required_categories)) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Operation requires categories " + << CommandCategoryString(required_categories) + << " but buffer only supports " + << CommandCategoryString(command_categories()); + } + return OkStatus(); +} + +Status ValidatingCommandBuffer::ValidateCompatibleMemoryType( + Buffer* buffer, MemoryTypeBitfield memory_type) const { + if ((buffer->memory_type() & memory_type) != memory_type) { + // Missing one or more bits. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Buffer memory type is not compatible with the requested " + "operation; buffer has " + << MemoryTypeString(buffer->memory_type()) << ", operation requires " + << MemoryTypeString(memory_type); + } + return OkStatus(); +} + +Status ValidatingCommandBuffer::ValidateAccess( + Buffer* buffer, MemoryAccessBitfield memory_access) const { + if ((buffer->allowed_access() & memory_access) != memory_access) { + // Bits must match exactly. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "The buffer does not support the requested access type; buffer " + "allows " + << MemoryAccessString(buffer->allowed_access()) + << ", operation requires " << MemoryAccessString(memory_access); + } + return OkStatus(); +} + +// Returns a failure if the buffer was not allocated for the given usage. +Status ValidatingCommandBuffer::ValidateUsage(Buffer* buffer, + BufferUsageBitfield usage) const { + if (!allocator()->CanUseBuffer(buffer, usage)) { + // Buffer cannot be used on the queue for the given usage. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Requested usage of " << buffer->DebugString() + << " is not supported for the buffer on this queue; " + "buffer allows " + << BufferUsageString(buffer->usage()) << ", queue requires " + << BufferUsageString(usage); + } + + if ((buffer->usage() & usage) != usage) { + // Missing one or more bits. + return PermissionDeniedErrorBuilder(IREE_LOC) + << "Requested usage was not specified when the buffer was " + "allocated; buffer allows " + << BufferUsageString(buffer->usage()) << ", operation requires " + << BufferUsageString(usage); + } + + return OkStatus(); +} + +// Validates that the range provided is within the given buffer. +Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer, + device_size_t byte_offset, + device_size_t byte_length) const { + // Check if the start of the range runs off the end of the buffer. + if (byte_offset > buffer->byte_length()) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Attempted to access an address off the end of the valid buffer " + "range (offset=" + << byte_offset << ", length=" << byte_length + << ", buffer byte_length=" << buffer->byte_length() << ")"; + } + + if (byte_length == 0) { + // Fine to have a zero length. + return OkStatus(); + } + + // Check if the end runs over the allocation. + device_size_t end = byte_offset + byte_length; + if (end > buffer->byte_length()) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Attempted to access an address outside of the valid buffer " + "range (offset=" + << byte_offset << ", length=" << byte_length + << ", end(inc)=" << (end - 1) + << ", buffer byte_length=" << buffer->byte_length() << ")"; + } + + return OkStatus(); +} + +Status ValidatingCommandBuffer::ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + DVLOG(3) << "CommandBuffer::ExecutionBarrier(...)"; + + // TODO(benvanik): additional synchronization validation. + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer | + CommandCategory::kDispatch)); + + return impl_->ExecutionBarrier(source_stage_mask, target_stage_mask, + memory_barriers, buffer_barriers); +} + +Status ValidatingCommandBuffer::SignalEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + DVLOG(3) << "CommandBuffer::SignalEvent(...)"; + + // TODO(benvanik): additional synchronization validation. + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); + + return impl_->SignalEvent(event, source_stage_mask); +} + +Status ValidatingCommandBuffer::ResetEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + DVLOG(3) << "CommandBuffer::ResetEvent(...)"; + + // TODO(benvanik): additional synchronization validation. + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); + + return impl_->ResetEvent(event, source_stage_mask); +} + +Status ValidatingCommandBuffer::WaitEvents( + absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + DVLOG(3) << "CommandBuffer::WaitEvents(...)"; + + // TODO(benvanik): additional synchronization validation. + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); + + return impl_->WaitEvents(events, source_stage_mask, target_stage_mask, + memory_barriers, buffer_barriers); +} + +Status ValidatingCommandBuffer::FillBuffer(Buffer* target_buffer, + device_size_t target_offset, + device_size_t length, + const void* pattern, + size_t pattern_length) { + DVLOG(3) << "CommandBuffer::FillBuffer(" << target_buffer->DebugString() + << ", " << target_offset << ", " << length << ", ??, " + << pattern_length << ")"; + + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); + RETURN_IF_ERROR( + ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible)); + RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite)); + RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer)); + RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length)); + + // Ensure the value length is supported. + if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Fill value length is not one of the supported values " + "(pattern_length=" + << pattern_length << ")"; + } + + // Ensure the offset and length have an alignment matching the value length. + if ((target_offset % pattern_length) != 0 || (length % pattern_length) != 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Fill offset and/or length do not match the natural alignment of " + "the fill value (target_offset=" + << target_offset << ", length=" << length + << ", pattern_length=" << pattern_length << ")"; + } + + return impl_->FillBuffer(target_buffer, target_offset, length, pattern, + pattern_length); +} + +Status ValidatingCommandBuffer::DiscardBuffer(Buffer* buffer) { + DVLOG(3) << "CommandBuffer::DiscardBuffer(" << buffer->DebugString() << ")"; + + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); + RETURN_IF_ERROR( + ValidateCompatibleMemoryType(buffer, MemoryType::kDeviceVisible)); + RETURN_IF_ERROR(ValidateUsage(buffer, BufferUsage::kNone)); + + return impl_->DiscardBuffer(buffer); +} + +Status ValidatingCommandBuffer::UpdateBuffer(const void* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + DVLOG(3) << "CommandBuffer::UpdateBuffer(" << source_buffer << ", " + << source_offset << ", " << target_buffer->DebugString() << ", " + << target_offset << ", " << length << ")"; + + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); + RETURN_IF_ERROR( + ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible)); + RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite)); + RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer)); + RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length)); + + return impl_->UpdateBuffer(source_buffer, source_offset, target_buffer, + target_offset, length); +} + +Status ValidatingCommandBuffer::CopyBuffer(Buffer* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + DVLOG(3) << "CommandBuffer::CopyBuffer(" << source_buffer->DebugString() + << ", " << source_offset << ", " << target_buffer->DebugString() + << ", " << target_offset << ", " << length << ")"; + + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); + + // At least source or destination must be device-visible to enable + // host->device, device->host, and device->device. + // TODO(b/117338171): host->host copies. + if (!AnyBitSet(source_buffer->memory_type() & MemoryType::kDeviceVisible) && + !AnyBitSet(target_buffer->memory_type() & MemoryType::kDeviceVisible)) { + return PermissionDeniedErrorBuilder(IREE_LOC) + << "At least one buffer must be device-visible for a copy; " + "source_buffer=" + << MemoryTypeString(source_buffer->memory_type()) + << ", target_buffer=" + << MemoryTypeString(target_buffer->memory_type()); + } + + RETURN_IF_ERROR(ValidateAccess(source_buffer, MemoryAccess::kRead)); + RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite)); + RETURN_IF_ERROR(ValidateUsage(source_buffer, BufferUsage::kTransfer)); + RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer)); + RETURN_IF_ERROR(ValidateRange(source_buffer, source_offset, length)); + RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length)); + + // Check for overlap - just like memcpy we don't handle that. + if (Buffer::TestOverlap(source_buffer, source_offset, length, target_buffer, + target_offset, + length) != Buffer::Overlap::kDisjoint) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Source and target ranges overlap within the same buffer"; + } + + return impl_->CopyBuffer(source_buffer, source_offset, target_buffer, + target_offset, length); +} + +Status ValidatingCommandBuffer::Dispatch( + const DispatchRequest& dispatch_request) { + DVLOG(3) << "CommandBuffer::Dispatch(?)"; + + RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); + + // Validate all buffers referenced have compatible memory types, access + // rights, and usage. + for (const auto& binding : dispatch_request.bindings) { + RETURN_IF_ERROR(ValidateCompatibleMemoryType(binding.buffer, + MemoryType::kDeviceVisible)) + << "input buffer: " << MemoryAccessString(binding.access) << " " + << binding.buffer->DebugStringShort(); + RETURN_IF_ERROR(ValidateAccess(binding.buffer, binding.access)); + RETURN_IF_ERROR(ValidateUsage(binding.buffer, BufferUsage::kDispatch)); + // TODO(benvanik): validate it matches the executable expectations. + // TODO(benvanik): validate buffer contains enough data for shape+size. + } + + // TODO(benvanik): validate no aliasing? + + return impl_->Dispatch(dispatch_request); +} + +} // namespace + +ref_ptr<CommandBuffer> WrapCommandBufferWithValidation( + ref_ptr<CommandBuffer> impl) { + return make_ref<ValidatingCommandBuffer>(std::move(impl)); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/command_buffer_validation.h b/hal/command_buffer_validation.h new file mode 100644 index 0000000..f60d465 --- /dev/null +++ b/hal/command_buffer_validation.h
@@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_COMMAND_BUFFER_VALIDATION_H_ +#define IREE_HAL_COMMAND_BUFFER_VALIDATION_H_ + +#include "hal/command_buffer.h" + +namespace iree { +namespace hal { + +// Wraps an existing command buffer to provide in-depth validation during +// recording. This should be enabled whenever the command buffer is being driven +// by unsafe code or when early and readable diagnostics are needed. +ref_ptr<CommandBuffer> WrapCommandBufferWithValidation( + ref_ptr<CommandBuffer> impl); + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
diff --git a/hal/command_queue.h b/hal/command_queue.h new file mode 100644 index 0000000..e4f5239 --- /dev/null +++ b/hal/command_queue.h
@@ -0,0 +1,119 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_COMMAND_QUEUE_H_ +#define IREE_HAL_COMMAND_QUEUE_H_ + +#include <cstdint> +#include <string> + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "base/bitfield.h" +#include "base/status.h" +#include "base/time.h" +#include "hal/command_buffer.h" +#include "hal/fence.h" +#include "hal/semaphore.h" + +namespace iree { +namespace hal { + +// A batch of command buffers with synchronization information for submission. +struct SubmissionBatch { + // Semaphores that must be signaled prior to the execution of any command + // buffer in this submission. For TimelineSemaphores the specified payload + // must be reached or exceeded. + absl::Span<const SemaphoreValue> wait_semaphores; + + // Command buffers that will execute in this batch. + // The command buffers will begin execution in order but may complete out of + // order. + absl::Span<CommandBuffer* const> command_buffers; + + // Semaphores to signal after execution of all command buffers complete. + // TimelineSemaphores will be set to the maximum of the specified payload or + // their current payload. + absl::Span<const SemaphoreValue> signal_semaphores; +}; + +// Asynchronous command execution queue. +// +// CommandQueues may capture device status at Fence barriers, including +// information about device state such as thermal throttling. This information +// is a snapshot of the state at the time the fence was signaled and not +// necessarily live at the time of the application query. +// +// Command queues are thread-safe and submissions may occur from multiple +// threads. +class CommandQueue { + public: + virtual ~CommandQueue() = default; + + // Name of the queue used for logging purposes. + // Try to keep at 4 characters total for prettier logging. + const std::string& name() const { return name_; } + + // Capabilities of the command queue. + CommandCategoryBitfield supported_categories() const { + return supported_categories_; + } + + // Whether this queue may be used for transfer commands. + bool can_transfer() const { + return AllBitsSet(supported_categories_, CommandCategory::kTransfer); + } + + // Whether this queue may be used for dispatch commands. + bool can_dispatch() const { + return AllBitsSet(supported_categories_, CommandCategory::kDispatch); + } + + // Submits one or more command batches for execution on the queue. + // Dependencies between |batches| on BinarySemaphores must be sorted in order + // such that all semaphores are signaled prior to any waits on them. + // Dependencies between TimelineSemaphores may occur in any order. + // + // The provided |fence| will be signaled when all |batches| have retired. + virtual Status Submit(absl::Span<const SubmissionBatch> batches, + FenceValue fence) = 0; + inline Status Submit(const SubmissionBatch& batch, FenceValue fence) { + return Submit(absl::MakeConstSpan(&batch, 1), std::move(fence)); + } + + // Blocks until all outstanding requests have been completed. + // This is equivalent to having waited on all outstanding fences. + // Implicitly calls Flush to ensure delayed requests are scheduled. + // + // If the command queue has encountered an error during submission at any + // point it will be returned here (repeatedly). + virtual Status WaitIdle(absl::Time deadline) = 0; + inline Status WaitIdle(absl::Duration timeout) { + return WaitIdle(RelativeTimeoutToDeadline(timeout)); + } + inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); } + + protected: + CommandQueue(std::string name, CommandCategoryBitfield supported_categories) + : name_(std::move(name)), supported_categories_(supported_categories) {} + + const std::string name_; + const CommandCategoryBitfield supported_categories_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_COMMAND_QUEUE_H_
diff --git a/hal/dawn/BUILD b/hal/dawn/BUILD new file mode 100644 index 0000000..0f5c39a --- /dev/null +++ b/hal/dawn/BUILD
@@ -0,0 +1,72 @@ +# HAL implementation using Dawn and SPIR-V executables. +# https://dawn.googlesource.com/dawn + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "dawn_device", + srcs = ["dawn_device.cc"], + hdrs = ["dawn_device.h"], + deps = [ + "///base:memory", + "///base:status", + "///base:tracing", + "///hal:command_queue", + "///hal:device", + "///hal:executable_cache", + "///hal:fence", + "///hal/host:host_local_allocator", + "//third_party/dawn:dawn_headers", + "//third_party/dawn:dawn_native", + "//third_party/dawn:dawn_static_proc", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "dawn_driver", + srcs = ["dawn_driver.cc"], + hdrs = ["dawn_driver.h"], + deps = [ + ":dawn_device", + "///base:status", + "///base:tracing", + "///hal:device_info", + "///hal:driver", + "//third_party/dawn:dawn_headers", + "//third_party/dawn:dawn_native", + "//third_party/dawn:dawn_static_proc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +# TODO(scotttodd): Use SwiftShader to test Vulkan backend +cc_test( + name = "dawn_driver_test", + srcs = ["dawn_driver_test.cc"], + deps = [ + ":dawn_driver", + "///base:status_matchers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "dawn_driver_module", + srcs = ["dawn_driver_module.cc"], + deps = [ + ":dawn_driver", + "///base:init", + "///base:status", + "///base:tracing", + "///hal:driver_registry", + "@com_google_absl//absl/flags:flag", + ], + alwayslink = 1, +)
diff --git a/hal/dawn/dawn_device.cc b/hal/dawn/dawn_device.cc new file mode 100644 index 0000000..c8a3618 --- /dev/null +++ b/hal/dawn/dawn_device.cc
@@ -0,0 +1,139 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/dawn/dawn_device.h" + +#include "absl/memory/memory.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/command_queue.h" +#include "hal/executable_cache.h" +#include "hal/fence.h" + +namespace iree { +namespace hal { +namespace dawn { + +namespace { + +// ExecutableCache implementation that compiles but does nothing. +// This will be replaced with something functional soon. +class NoopExecutableCache final : public ExecutableCache { + public: + explicit NoopExecutableCache() {} + ~NoopExecutableCache() override = default; + + bool CanPrepareFormat(ExecutableFormat format) const override { + return false; + } + + StatusOr<ref_ptr<Executable>> PrepareExecutable( + ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override { + return UnimplementedErrorBuilder(IREE_LOC) << "PrepareExecutable NYI"; + } +}; + +} // namespace + +DawnDevice::DawnDevice(const DeviceInfo& device_info, + ::dawn::Device backend_device) + : Device(device_info), backend_device_(backend_device) { + IREE_TRACE_SCOPE0("DawnDevice::ctor"); + + // TODO(scotttodd): construct command queues, perform other initialization + + // Log some basic device info. + std::string backend_type_str; + auto* adapter = + static_cast<dawn_native::Adapter*>(device_info.driver_handle()); + switch (adapter->GetBackendType()) { + case dawn_native::BackendType::D3D12: + backend_type_str = "D3D12"; + break; + case dawn_native::BackendType::Metal: + backend_type_str = "Metal"; + break; + case dawn_native::BackendType::Null: + backend_type_str = "Null"; + break; + case dawn_native::BackendType::OpenGL: + backend_type_str = "OpenGL"; + break; + case dawn_native::BackendType::Vulkan: + backend_type_str = "Vulkan"; + break; + } + LOG(INFO) << "Created DawnDevice '" << device_info.name() << "' (" + << backend_type_str << ")"; +} + +DawnDevice::~DawnDevice() = default; + +std::shared_ptr<ExecutableCache> DawnDevice::CreateExecutableCache() { + return std::make_shared<NoopExecutableCache>(); +} + +StatusOr<ref_ptr<CommandBuffer>> DawnDevice::CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) { + return UnimplementedErrorBuilder(IREE_LOC) << "CreateCommandBuffer NYI"; +} + +StatusOr<ref_ptr<Event>> DawnDevice::CreateEvent() { + return UnimplementedErrorBuilder(IREE_LOC) << "CreateEvent NYI"; +} + +StatusOr<ref_ptr<BinarySemaphore>> DawnDevice::CreateBinarySemaphore( + bool initial_value) { + IREE_TRACE_SCOPE0("DawnDevice::CreateBinarySemaphore"); + + return UnimplementedErrorBuilder(IREE_LOC) + << "Binary semaphores not yet implemented"; +} + +StatusOr<ref_ptr<TimelineSemaphore>> DawnDevice::CreateTimelineSemaphore( + uint64_t initial_value) { + IREE_TRACE_SCOPE0("DawnDevice::CreateTimelineSemaphore"); + + return UnimplementedErrorBuilder(IREE_LOC) + << "Timeline semaphores not yet implemented"; +} + +StatusOr<ref_ptr<Fence>> DawnDevice::CreateFence(uint64_t initial_value) { + IREE_TRACE_SCOPE0("DawnDevice::CreateFence"); + + return UnimplementedErrorBuilder(IREE_LOC) << "CreateFence NYI"; +} + +Status DawnDevice::WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) { + IREE_TRACE_SCOPE0("DawnDevice::WaitAllFences"); + + return UnimplementedErrorBuilder(IREE_LOC) << "WaitAllFences NYI"; +} + +StatusOr<int> DawnDevice::WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Time deadline) { + IREE_TRACE_SCOPE0("DawnDevice::WaitAnyFence"); + + return UnimplementedErrorBuilder(IREE_LOC) << "WaitAnyFence NYI"; +} + +Status DawnDevice::WaitIdle(absl::Time deadline) { + return UnimplementedErrorBuilder(IREE_LOC) << "WaitIdle"; +} + +} // namespace dawn +} // namespace hal +} // namespace iree
diff --git a/hal/dawn/dawn_device.h b/hal/dawn/dawn_device.h new file mode 100644 index 0000000..2c6b4d6 --- /dev/null +++ b/hal/dawn/dawn_device.h
@@ -0,0 +1,78 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DAWN_DAWN_DEVICE_H_ +#define IREE_HAL_DAWN_DAWN_DEVICE_H_ + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "base/memory.h" +#include "hal/device.h" +#include "hal/host/host_local_allocator.h" +#include "third_party/dawn/src/include/dawn/dawncpp.h" +#include "third_party/dawn/src/include/dawn_native/DawnNative.h" + +namespace iree { +namespace hal { +namespace dawn { + +class DawnDevice final : public Device { + public: + explicit DawnDevice(const DeviceInfo& device_info, + ::dawn::Device backend_device); + ~DawnDevice() override; + + Allocator* allocator() const override { return &allocator_; } + + absl::Span<CommandQueue*> dispatch_queues() const override { + return RawPtrSpan(absl::MakeSpan(command_queues_)); + } + + absl::Span<CommandQueue*> transfer_queues() const override { + return RawPtrSpan(absl::MakeSpan(command_queues_)); + } + + std::shared_ptr<ExecutableCache> CreateExecutableCache() override; + + StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) override; + + StatusOr<ref_ptr<Event>> CreateEvent() override; + + StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( + bool initial_value) override; + StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( + uint64_t initial_value) override; + + StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override; + Status WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) override; + StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Time deadline) override; + + Status WaitIdle(absl::Time deadline) override; + + private: + mutable HostLocalAllocator allocator_; + mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_; + + ::dawn::Device backend_device_; +}; + +} // namespace dawn +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DAWN_DAWN_DEVICE_H_
diff --git a/hal/dawn/dawn_driver.cc b/hal/dawn/dawn_driver.cc new file mode 100644 index 0000000..565e98a --- /dev/null +++ b/hal/dawn/dawn_driver.cc
@@ -0,0 +1,120 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/dawn/dawn_driver.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/dawn/dawn_device.h" +#include "hal/device_info.h" + +namespace iree { +namespace hal { +namespace dawn { + +namespace { + +// Populates device information from the given dawn_native::Adapter. +StatusOr<DeviceInfo> PopulateDeviceInfo(dawn_native::Adapter* adapter) { + // TODO(scotttodd): Query these for each backend or implement? + DeviceFeatureBitfield supported_features = DeviceFeature::kNone; + // supported_features |= DeviceFeature::kDebugging; + // supported_features |= DeviceFeature::kCoverage; + // supported_features |= DeviceFeature::kProfiling; + + // TODO(scotttodd): more clever/sanitized device naming. + std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name); + + return DeviceInfo(device_name, supported_features, + reinterpret_cast<void*>(adapter)); +} + +} // namespace + +DawnDriver::DawnDriver() : Driver("dawn") { + dawn_instance_ = absl::make_unique<dawn_native::Instance>(); +} + +DawnDriver::~DawnDriver() = default; + +StatusOr<std::vector<DeviceInfo>> DawnDriver::EnumerateAvailableDevices() { + IREE_TRACE_SCOPE0("DawnDriver::EnumerateAvailableDevices"); + + if (dawn_backend_adapters_.empty()) { + // Discover adapters (i.e. devices and their associated backend APIs). + // Retain the list of adapters so pointers are valid for the lifetime of + // this object. + dawn_instance_->DiscoverDefaultAdapters(); + dawn_backend_adapters_ = dawn_instance_->GetAdapters(); + } else { + // Assume that the list of adapters does not change. This is not guaranteed + // to be true, but we also don't want to invalidate pointers by requesting + // a new list each time. If the list of available devices would change, + // tearing down and creating a new DawnDriver may be your best option. + } + + // Convert to our HAL structure. + std::vector<DeviceInfo> device_infos; + device_infos.reserve(dawn_backend_adapters_.size()); + for (auto& adapter : dawn_backend_adapters_) { + // TODO(scotttodd): if we fail should we just ignore the device in the list? + ASSIGN_OR_RETURN(auto device_info, PopulateDeviceInfo(&adapter)); + device_infos.push_back(std::move(device_info)); + } + return device_infos; +} + +StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDefaultDevice() { + IREE_TRACE_SCOPE0("DawnDriver::CreateDefaultDevice"); + + // Query available devices. + ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices()); + if (available_devices.empty()) { + return NotFoundErrorBuilder(IREE_LOC) << "No devices are available"; + } + + // Create the first non-null device, if any. + for (const auto& device : available_devices) { + auto* adapter = static_cast<dawn_native::Adapter*>(device.driver_handle()); + if (adapter->GetBackendType() != dawn_native::BackendType::Null) { + return CreateDevice(device); + } + } + + // Otherwise create the first null device. + return CreateDevice(available_devices.front()); +} + +StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDevice( + const DeviceInfo& device_info) { + IREE_TRACE_SCOPE0("DawnDriver::CreateDevice"); + + auto* adapter = + static_cast<dawn_native::Adapter*>(device_info.driver_handle()); + ::DawnDevice c_backend_device = adapter->CreateDevice(); + if (!c_backend_device) { + return InternalErrorBuilder(IREE_LOC) << "Failed to create a Dawn device"; + } + DawnProcTable backend_procs = dawn_native::GetProcs(); + dawnSetProcs(&backend_procs); + ::dawn::Device backend_device = ::dawn::Device::Acquire(c_backend_device); + + return std::make_shared<DawnDevice>(device_info, backend_device); +} + +} // namespace dawn +} // namespace hal +} // namespace iree
diff --git a/hal/dawn/dawn_driver.h b/hal/dawn/dawn_driver.h new file mode 100644 index 0000000..03a2cdb --- /dev/null +++ b/hal/dawn/dawn_driver.h
@@ -0,0 +1,50 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DAWN_DAWN_DRIVER_H_ +#define IREE_HAL_DAWN_DAWN_DRIVER_H_ + +#include <memory> +#include <vector> + +#include "hal/driver.h" +#include "third_party/dawn/src/include/dawn/dawncpp.h" +#include "third_party/dawn/src/include/dawn_native/DawnNative.h" + +namespace iree { +namespace hal { +namespace dawn { + +class DawnDriver final : public Driver { + public: + DawnDriver(); + ~DawnDriver() override; + + StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; + + StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override; + + StatusOr<std::shared_ptr<Device>> CreateDevice( + const DeviceInfo& device_info) override; + + private: + std::unique_ptr<dawn_native::Instance> dawn_instance_; + std::vector<dawn_native::Adapter> dawn_backend_adapters_; +}; + +} // namespace dawn +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DAWN_DAWN_DRIVER_H_
diff --git a/hal/dawn/dawn_driver_module.cc b/hal/dawn/dawn_driver_module.cc new file mode 100644 index 0000000..317d9fe --- /dev/null +++ b/hal/dawn/dawn_driver_module.cc
@@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> + +#include "base/init.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/dawn/dawn_driver.h" +#include "hal/driver_registry.h" + +namespace iree { +namespace hal { +namespace dawn { +namespace { + +StatusOr<std::shared_ptr<Driver>> CreateDawnDriver() { + return std::make_shared<DawnDriver>(); +} + +} // namespace +} // namespace dawn +} // namespace hal +} // namespace iree + +IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dawn_driver, { + QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( + "dawn", ::iree::hal::dawn::CreateDawnDriver)); +}); +IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dawn_driver);
diff --git a/hal/dawn/dawn_driver_test.cc b/hal/dawn/dawn_driver_test.cc new file mode 100644 index 0000000..7e196f1 --- /dev/null +++ b/hal/dawn/dawn_driver_test.cc
@@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/dawn/dawn_driver.h" + +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace hal { +namespace dawn { +namespace { + +TEST(DawnDriverTest, CreateDefaultDevice) { + DawnDriver dawn_driver; + ASSERT_OK_AND_ASSIGN(auto default_device, dawn_driver.CreateDefaultDevice()); +} + +TEST(DawnDriverTest, EnumerateDevicesAndCreate) { + DawnDriver dawn_driver; + + ASSERT_OK_AND_ASSIGN(auto available_devices, + dawn_driver.EnumerateAvailableDevices()); + ASSERT_GT(available_devices.size(), 0); + + ASSERT_OK_AND_ASSIGN(auto first_device, + dawn_driver.CreateDevice(available_devices[0])); +} + +} // namespace +} // namespace dawn +} // namespace hal +} // namespace iree
diff --git a/hal/deferred_buffer.cc b/hal/deferred_buffer.cc new file mode 100644 index 0000000..5b3e82d --- /dev/null +++ b/hal/deferred_buffer.cc
@@ -0,0 +1,162 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/deferred_buffer.h" + +#include "base/status.h" + +namespace iree { +namespace hal { + +DeferredBuffer::DeferredBuffer(Allocator* allocator, + MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, + device_size_t byte_length) + : Buffer(allocator, memory_type, allowed_access, usage, 0, 0, byte_length) { +} + +DeferredBuffer::~DeferredBuffer() = default; + +Status DeferredBuffer::GrowByteLength(device_size_t new_byte_length) { + if (parent_buffer_) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Attempting to set min allocation size while bound to an " + "allocation"; + } + if (byte_length_ != kWholeBuffer && new_byte_length < byte_length_) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Attempting to shrink a buffer to " << new_byte_length + << " when it has a minimum size of " << byte_length_; + } + byte_length_ = new_byte_length; + return OkStatus(); +} + +Status DeferredBuffer::BindAllocation(ref_ptr<Buffer> allocated_buffer, + device_size_t byte_offset, + device_size_t byte_length) { + // We can only be bound to allocations that are compatible with our specified + // allocator and usage. + if (!allocator_->CanUseBuffer(allocated_buffer.get(), usage())) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Allocation is not compatible with the allocator specified for " + "the deferred buffer"; + } + + // Calculate the range in the allocated_buffer that we are interested in. + RETURN_IF_ERROR(Buffer::CalculateRange(0, allocated_buffer->byte_length(), + byte_offset, byte_length, &byte_offset, + &byte_length)); + + // Verify that we have enough bytes for what we've promised. + if (byte_length < byte_length_) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Allocation range is too small; min_allocation_size=" + << byte_length_ << " but the range of " << byte_offset << "-" + << (byte_offset + byte_length - 1) << " (" << byte_length + << "b) is too small"; + } + + allocated_buffer_ = allocated_buffer.get(); + parent_buffer_ = std::move(allocated_buffer); + byte_offset_ = byte_offset; + return OkStatus(); +} + +void DeferredBuffer::ResetAllocation() { + allocated_buffer_ = this; + parent_buffer_.reset(); + byte_offset_ = 0; +} + +StatusOr<Buffer*> DeferredBuffer::ResolveAllocation() const { + // If you get errors here then someone allocated the buffer with + // MemoryType::kTransient and you are trying to use it outside of the time + // it is actually allocated (such as during CommandBuffer evaluation). If + // you need to use the buffer in non-transient ways then allocate the buffer + // without the MemoryType::kTransient flag. + if (!parent_buffer_) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Attempting to use a transient buffer prior to allocation: " + << DebugString(); + } + return parent_buffer_.get(); +} + +Status DeferredBuffer::FillImpl(device_size_t byte_offset, + device_size_t byte_length, const void* pattern, + device_size_t pattern_length) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->FillImpl(byte_offset, byte_length, pattern, + pattern_length); +} + +Status DeferredBuffer::ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->ReadDataImpl(source_offset, data, data_length); +} + +Status DeferredBuffer::WriteDataImpl(device_size_t target_offset, + const void* data, + device_size_t data_length) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->WriteDataImpl(target_offset, data, data_length); +} + +Status DeferredBuffer::CopyDataImpl(device_size_t target_offset, + Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->CopyDataImpl(target_offset, source_buffer, + source_offset, data_length); +} + +Status DeferredBuffer::MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->MapMemoryImpl(mapping_mode, memory_access, + local_byte_offset, local_byte_length, + out_data); +} + +Status DeferredBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, + void* data) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->UnmapMemoryImpl(local_byte_offset, local_byte_length, + data); +} + +Status DeferredBuffer::InvalidateMappedMemoryImpl( + device_size_t local_byte_offset, device_size_t local_byte_length) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->InvalidateMappedMemoryImpl(local_byte_offset, + local_byte_length); +} + +Status DeferredBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) { + ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); + return allocated_buffer->FlushMappedMemoryImpl(local_byte_offset, + local_byte_length); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/deferred_buffer.h b/hal/deferred_buffer.h new file mode 100644 index 0000000..ab03c0d --- /dev/null +++ b/hal/deferred_buffer.h
@@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DEFERRED_BUFFER_H_ +#define IREE_HAL_DEFERRED_BUFFER_H_ + +#include <cstddef> +#include <memory> +#include <utility> + +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +// A Buffer that can have its underlying allocation changed at runtime. +// Unbound buffers act as a way to logically group dependent ranges of memory +// without needing to have allocated that memory yet. +// +// Usage: +// // Setup two spans referencing ranges of a deferred buffer. +// auto deferred_buffer = std::make_shared<DeferredBuffer>(..., 200); +// ASSIGN_OR_RETURN(auto span0, Buffer::Subspan(deferred_buffer, 0, 100)); +// ASSIGN_OR_RETURN(auto span1, Buffer::Subspan(deferred_buffer, 100, 100)); +// +// // Attempting to access |deferred_buffer| or |span0| or |span1| will fail. +// // ERROR: span0->Fill(false); +// +// // Now allocate a real buffer to serve as storage for the data. +// ASSIGN_OR_RETURN(auto allocated_buffer, Buffer::Allocate(..., 200)); +// RETURN_IF_ERROR(deferred_buffer->BindAllocation( +// allocated_buffer, 0, kWholeBuffer)); +// +// // And now we can use the spans. +// RETURN_IF_ERROR(span0->Fill(false)); +// +// // If at some point we want to detach the buffer from the allocation (so we +// // can use a different allocation, reuse the memory, etc). +// deferred_buffer->ResetAllocation(); +// +// Thread-compatible. Attempting to rebind the allocation while other threads +// are using the buffer will lead to undefined behavior. +class DeferredBuffer : public Buffer { + public: + DeferredBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, + device_size_t byte_length); + ~DeferredBuffer() override; + + // Grows the minimum allocation size of the buffer to |new_byte_length|. + // Attempting to bind an allocation less than this size will fail. This must + // only be called when the buffer is not bound to an allocation. + Status GrowByteLength(device_size_t new_byte_length); + + // Binds or rebinds the deferred buffer to an allocated buffer. + Status BindAllocation(ref_ptr<Buffer> allocated_buffer, + device_size_t byte_offset, device_size_t byte_length); + + // Resets the deferred buffer to have no binding. + void ResetAllocation(); + + private: + // Resolves the allocated buffer that this subspan references into. + // This will fail if the buffer has not yet been bound to an allocation or + // the allocated buffer has not been committed. + StatusOr<Buffer*> ResolveAllocation() const; + + Status FillImpl(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length) override; + Status ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) override; + Status WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) override; + Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) override; + Status MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) override; + Status UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data) override; + Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override; + Status FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DEFERRED_BUFFER_H_
diff --git a/hal/deferred_buffer_test.cc b/hal/deferred_buffer_test.cc new file mode 100644 index 0000000..6a8f051 --- /dev/null +++ b/hal/deferred_buffer_test.cc
@@ -0,0 +1,174 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/deferred_buffer.h" + +#include "absl/memory/memory.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/heap_buffer.h" +#include "hal/testing/mock_allocator.h" + +namespace iree { +namespace hal { +namespace { + +using ::iree::hal::testing::MockAllocator; +using ::testing::_; +using ::testing::Return; + +// Tests properties of unbound buffers. +TEST(DeferredBufferTest, Unbound) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + EXPECT_EQ(&allocator, deferred_buffer->allocator()); + EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); + EXPECT_EQ(0, deferred_buffer->allocation_size()); + EXPECT_EQ(0, deferred_buffer->byte_offset()); + EXPECT_EQ(100, deferred_buffer->byte_length()); +} + +// Tests that binding verifies allocators are compatible. +TEST(DeferredBufferTest, AllocatorCheck) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + auto real_buffer = + HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); + EXPECT_CALL( + allocator, + CanUseBufferLike(real_buffer->allocator(), real_buffer->memory_type(), + real_buffer->usage(), BufferUsage::kAll)) + .WillOnce(Return(false)); + EXPECT_TRUE(IsInvalidArgument( + deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100))); +} + +// Tests that binding verifies allocation sizes. +TEST(DeferredBufferTest, SizeCheck) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + auto real_buffer = + HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); + EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) + .WillRepeatedly(Return(true)); + + EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 100)); + EXPECT_EQ(256, deferred_buffer->allocation_size()); + EXPECT_EQ(10, deferred_buffer->byte_offset()); + EXPECT_EQ(100, deferred_buffer->byte_length()); + EXPECT_OK( + deferred_buffer->BindAllocation(add_ref(real_buffer), 10, kWholeBuffer)); + EXPECT_EQ(256, deferred_buffer->allocation_size()); + EXPECT_EQ(10, deferred_buffer->byte_offset()); + EXPECT_EQ(100, deferred_buffer->byte_length()); + + EXPECT_TRUE(IsOutOfRange( + deferred_buffer->BindAllocation(add_ref(real_buffer), 200, 100))); + EXPECT_TRUE(IsOutOfRange(deferred_buffer->BindAllocation(add_ref(real_buffer), + 200, kWholeBuffer))); + EXPECT_TRUE(IsOutOfRange( + deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 10))); +} + +// Tests resizing buffers after they have been allocated. +TEST(DeferredBufferTest, Resizing) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + auto real_buffer = + HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); + EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) + .WillRepeatedly(Return(true)); + + // Grow. + EXPECT_EQ(100, deferred_buffer->byte_length()); + EXPECT_OK(deferred_buffer->GrowByteLength(150)); + EXPECT_EQ(150, deferred_buffer->byte_length()); + + // Shrinking should fail. + EXPECT_TRUE(IsInvalidArgument(deferred_buffer->GrowByteLength(5))); + + // Growing should fail if bound. + EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 150)); + EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->GrowByteLength(100))); +} + +// Tests binding and rebinding behavior. +TEST(DeferredBufferTest, Rebinding) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + auto real_buffer = + HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); + EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) + .WillRepeatedly(Return(true)); + + // Safe to reset when not bound. + deferred_buffer->ResetAllocation(); + EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); + EXPECT_EQ(0, deferred_buffer->allocation_size()); + + EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100)); + EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer()); + EXPECT_EQ(256, deferred_buffer->allocation_size()); + deferred_buffer->ResetAllocation(); + EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); + EXPECT_EQ(0, deferred_buffer->allocation_size()); + EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100)); + EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer()); + EXPECT_EQ(256, deferred_buffer->allocation_size()); +} + +// Tests normal usage of bound buffers. +TEST(DeferredBufferTest, BoundUsage) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + auto real_buffer = + HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); + EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) + .WillRepeatedly(Return(true)); + EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100)); + + EXPECT_FALSE(deferred_buffer->DebugString().empty()); + EXPECT_FALSE(deferred_buffer->DebugStringShort().empty()); + + EXPECT_OK(deferred_buffer->Fill8(0, 10, 0xFF)); +} + +// Tests that unbound buffers fail to perform any buffer actions. +TEST(DeferredBufferTest, UnboundUsage) { + MockAllocator allocator; + auto deferred_buffer = absl::make_unique<DeferredBuffer>( + &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, + 100); + EXPECT_FALSE(deferred_buffer->DebugString().empty()); + EXPECT_FALSE(deferred_buffer->DebugStringShort().empty()); + + EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->Fill8(0, 10, 0xFF))); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/device.h b/hal/device.h new file mode 100644 index 0000000..69586ca --- /dev/null +++ b/hal/device.h
@@ -0,0 +1,165 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DEVICE_H_ +#define IREE_HAL_DEVICE_H_ + +#include <memory> + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "base/status.h" +#include "base/time.h" +#include "hal/allocator.h" +#include "hal/buffer.h" +#include "hal/command_queue.h" +#include "hal/device_info.h" +#include "hal/event.h" +#include "hal/executable_cache.h" +#include "hal/semaphore.h" + +namespace iree { +namespace hal { + +class Device { + public: + virtual ~Device() = default; + + // Information about device capabilities. + const DeviceInfo& info() const { return device_info_; } + + // TODO(benvanik): status (thermal, power mode, etc). + + // TODO(benvanik): throttling adjustment/power profile. + + // TODO(benvanik): control (suspend/resume, delay, etc). + + // An allocator providing buffers usable by the device. + // This allocator may be shared with other devices in the same family. + virtual Allocator* allocator() const = 0; + + // Returns a list of all general-purpose dispatch queues provided by the + // device. In general these map 1:1 with independent execution contexts, + // though some devices may hide that and expose only a single queue that is + // scheduled internally. + virtual absl::Span<CommandQueue*> dispatch_queues() const = 0; + + // Returns a list of transfer queues provided by the device. These queues may + // perform transfer operations asynchronously with respect to execution on the + // dispatch queues. For large sequences of transfer operations always prefer + // using one of these queues. + // Note that if the device does not support a dedicated transfer queue this + // list may be the same as (or a subset of) dispatch_queues. + virtual absl::Span<CommandQueue*> transfer_queues() const = 0; + + // TODO(b/137153339): accept initial cache data. + // Creates a device-specific cache for executables prepared for dispatch. + // The cache manages executable compilation, caching (on disk or in memory), + // and lifetime. Users can decide to use one or more caches to allow differing + // lifetimes (such as unloading modules), persistent on disk caching of only + // specific hot executables, etc. + // + // Returns a thread-safe cache that must remain alive until all executables + // using the cache are no longer in-flight. + virtual std::shared_ptr<ExecutableCache> CreateExecutableCache() = 0; + + // Creates a command buffer for recording commands to submit to queues owned + // by this device. The command buffer may come from a pool but will be reset + // prior to being returned to the caller. + virtual StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) = 0; + + // Creates an event for recording into command buffers. + // The returned event object is only usable with this device and events must + // only be used to synchronize within the same queue. + virtual StatusOr<ref_ptr<Event>> CreateEvent() = 0; + + // Creates a binary semaphore that can be used with command queues owned by + // this device. To use the semaphores with other devices or instances they + // must first be exported. + virtual StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( + bool initial_value) = 0; + + // Creates a timeline semaphore that can be used with command queues owned by + // this device. To use the semaphores with other devices or instances they + // must first be exported. + virtual StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( + uint64_t initial_value) = 0; + + // Creates a fence that can be used with command queues owned by this device. + // To use the fences with other devices or instances they must first be + // exported. + virtual StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) = 0; + + // TODO(benvanik): import/export semaphore utilities. + // TODO(benvanik): import/export fence utilities. + // TODO(benvanik): fences to wait handles. + + // Blocks the caller until all passed |fences| reach or exceed the specified + // payload values or the |deadline| elapses. All |fences| must be created from + // this device (or be imported into it). + // + // Returns success if the wait is successful and all fences have been + // signaled. + // + // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all fences + // having been signaled. Note that a subset of the |fences| may have been + // signaled and each can be queried to see which ones. + virtual Status WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) = 0; + inline Status WaitAllFences(absl::Span<const FenceValue> fences, + absl::Duration timeout) { + return WaitAllFences(fences, RelativeTimeoutToDeadline(timeout)); + } + + // Blocks the caller until at least one of the |fences| reaches or exceeds the + // specified payload value or the |deadline| elapses. All |fences| must be + // created from this device (or be imported into it). + // + // Returns an arbitrary index into |fences| of a fence that was signaled. Note + // that more than one fence may have been signaled and all of the other + // |fences| should be queried or waited on again until waits for them + // succeed. + // + // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any fences + // having been signaled. + virtual StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Time deadline) = 0; + inline StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Duration timeout) { + return WaitAnyFence(fences, RelativeTimeoutToDeadline(timeout)); + } + + // Blocks until all outstanding requests on all queues have been + // completed. This is equivalent to having waited on all outstanding + // fences. + virtual Status WaitIdle(absl::Time deadline) = 0; + inline Status WaitIdle(absl::Duration timeout) { + return WaitIdle(RelativeTimeoutToDeadline(timeout)); + } + inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); } + + protected: + explicit Device(DeviceInfo device_info) + : device_info_(std::move(device_info)) {} + + private: + const DeviceInfo device_info_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DEVICE_H_
diff --git a/hal/device_info.h b/hal/device_info.h new file mode 100644 index 0000000..c7bbe67 --- /dev/null +++ b/hal/device_info.h
@@ -0,0 +1,90 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DEVICE_INFO_H_ +#define IREE_HAL_DEVICE_INFO_H_ + +#include <cstdint> +#include <string> +#include <utility> + +#include "base/bitfield.h" + +namespace iree { +namespace hal { + +// Describes features supported by the device. +// These flags indicate the availability of features that may be enabled at the +// request of the calling application. Note that certain features may disable +// runtime optimizations or require compilation flags to ensure the required +// metadata is present in executables. +enum class DeviceFeature : uint32_t { + kNone = 0, + + // Device supports executable debugging. + // When present executables *may* be compiled with + // ExecutableCachingMode::kEnableDebugging and will have usable debugging + // related methods. Note that if the input executables do not have embedded + // debugging information they still may not be able to perform disassembly or + // fine-grained breakpoint insertion. + kDebugging = 1 << 0, + + // Device supports executable coverage information. + // When present executables *may* be compiled with + // ExecutableCachingMode::kEnableCoverage and will produce coverage buffers + // during dispatch. Note that input executables must have partial embedded + // debug information to allow mapping back to source offsets. + kCoverage = 1 << 1, + + // Device supports executable and command queue profiling. + // When present executables *may* be compiled with + // ExecutableCachingMode::kEnableProfiling and will produce profiling buffers + // during dispatch. Note that input executables must have partial embedded + // debug information to allow mapping back to source offsets. + kProfiling = 1 << 2, +}; +IREE_BITFIELD(DeviceFeature); +using DeviceFeatureBitfield = DeviceFeature; + +// TODO(benvanik): device info (caps, physical mappings, etc). +class DeviceInfo { + public: + DeviceInfo(std::string name, DeviceFeatureBitfield supported_features, + void* driver_handle = nullptr) + : name_(std::move(name)), + supported_features_(supported_features), + driver_handle_(driver_handle) {} + + const std::string& name() const { return name_; } + + // Features supported by the device. + DeviceFeatureBitfield supported_features() const { + return supported_features_; + } + + // Opaque handle used by drivers to correlate this device with their internal + // listing. This handle will not be valid across driver instances or outside + // of the current process. + void* driver_handle() const { return driver_handle_; } + + private: + const std::string name_; + const DeviceFeatureBitfield supported_features_; + void* driver_handle_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DEVICE_INFO_H_
diff --git a/hal/device_manager.cc b/hal/device_manager.cc new file mode 100644 index 0000000..4d1feeb --- /dev/null +++ b/hal/device_manager.cc
@@ -0,0 +1,201 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/device_manager.h" + +#include <algorithm> + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/heap_buffer.h" + +namespace iree { +namespace hal { + +DeviceManager::DeviceManager() = default; + +DeviceManager::~DeviceManager() { + IREE_TRACE_SCOPE0("DeviceManager::dtor"); + WaitIdle().IgnoreError(); +} + +Status DeviceManager::RegisterDevice(std::shared_ptr<Device> device) { + IREE_TRACE_SCOPE0("DeviceManager::RegisterDevice"); + absl::MutexLock lock(&device_mutex_); + if (std::find(devices_.begin(), devices_.end(), device) != devices_.end()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Device already registered"; + } + devices_.push_back(std::move(device)); + return OkStatus(); +} + +Status DeviceManager::UnregisterDevice(Device* device) { + IREE_TRACE_SCOPE0("DeviceManager::UnregisterDevice"); + absl::MutexLock lock(&device_mutex_); + auto it = std::find_if(devices_.begin(), devices_.end(), + [device](const std::shared_ptr<Device>& other_device) { + return device == other_device.get(); + }); + if (it == devices_.end()) { + return NotFoundErrorBuilder(IREE_LOC) << "Device not registered"; + } + devices_.erase(it); + return OkStatus(); +} + +StatusOr<DevicePlacement> DeviceManager::ResolvePlacement( + const PlacementSpec& placement_spec) const { + IREE_TRACE_SCOPE0("DeviceManager::ResolvePlacement"); + absl::MutexLock lock(&device_mutex_); + if (devices_.empty()) { + return NotFoundErrorBuilder(IREE_LOC) << "No devices registered"; + } + + // TODO(benvanik): multiple devices and placement. + QCHECK_EQ(devices_.size(), 1) + << "Multiple devices not yet supported (need placement)"; + DevicePlacement device_placement; + device_placement.device = devices_.front(); + + return device_placement; +} + +StatusOr<Allocator*> DeviceManager::FindCompatibleAllocator( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + absl::Span<const DevicePlacement> device_placements) const { + IREE_TRACE_SCOPE0("DeviceManager::FindCompatibleAllocator"); + if (device_placements.empty()) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "No placements provided"; + } + + // Find the first allocator. As we only return an allocator if all placements + // are compatible we'll compare allocator[0] against allocator[1,N]. + Allocator* some_allocator = nullptr; + for (const auto& device_placement : device_placements) { + auto* allocator = device_placement.device->allocator(); + if (!some_allocator) { + some_allocator = allocator; + continue; + } + // NOTE: as there can be asymmetry between usage restrictions (A can use B + // but B cannot use A) we have to compare both directions. + if (!some_allocator->CanUseBufferLike(allocator, memory_type, buffer_usage, + buffer_usage) || + !allocator->CanUseBufferLike(some_allocator, memory_type, buffer_usage, + buffer_usage)) { + // Allocators are not compatible. + return NotFoundErrorBuilder(IREE_LOC) + << "No single allocator found that is compatible with all " + "placements"; + } + } + return some_allocator; +} + +StatusOr<ref_ptr<Buffer>> DeviceManager::TryAllocateDeviceVisibleBuffer( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements) { + IREE_TRACE_SCOPE("DeviceManager::TryAllocateDeviceVisibleBuffer:size", int) + (static_cast<int>(allocation_size)); + if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Host-local buffers require the kHostLocal bit: " + << MemoryTypeString(memory_type); + } + + // Strip kDeviceVisible as we conditionally add it based on support. + memory_type &= ~MemoryType::kDeviceVisible; + + // Find an allocator that works for device-visible buffers. + // If this fails we'll fall back to allocation a non-device-visible buffer. + auto allocator_or = + FindCompatibleAllocator(memory_type | MemoryType::kDeviceVisible, + buffer_usage, device_placements); + if (allocator_or.ok()) { + return allocator_or.ValueOrDie()->Allocate( + memory_type | MemoryType::kDeviceVisible, buffer_usage, + allocation_size); + } + + // Fallback to allocating a host-local buffer. + return HeapBuffer::Allocate(memory_type, buffer_usage, allocation_size); +} + +StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceVisibleBuffer( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements) { + IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceVisibleBuffer:size", int) + (static_cast<int>(allocation_size)); + if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Host-local buffers require the kHostLocal bit: " + << MemoryTypeString(memory_type); + } + + // Always use device-visible. + memory_type |= MemoryType::kDeviceVisible; + + // Find an allocator that works for device-visible buffers. + ASSIGN_OR_RETURN( + auto* allocator, + FindCompatibleAllocator(memory_type, buffer_usage, device_placements)); + return allocator->Allocate(memory_type, buffer_usage, allocation_size); +} + +StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceLocalBuffer( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements) { + IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceLocalBuffer:size", int) + (static_cast<int>(allocation_size)); + if (!AnyBitSet(memory_type & MemoryType::kDeviceLocal)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Device-local buffers require the kDeviceLocal bit: " + << MemoryTypeString(memory_type); + } + + // Find an allocator that works for device-local buffers. + ASSIGN_OR_RETURN( + auto* allocator, + FindCompatibleAllocator(memory_type, buffer_usage, device_placements)); + return allocator->Allocate(memory_type, buffer_usage, allocation_size); +} + +Status DeviceManager::Submit(Device* device, CommandQueue* command_queue, + absl::Span<const SubmissionBatch> batches, + absl::Time deadline, FenceValue fence) { + IREE_TRACE_SCOPE0("DeviceManager::Submit"); + return command_queue->Submit(batches, fence); +} + +Status DeviceManager::Flush() { + IREE_TRACE_SCOPE0("DeviceManager::Flush"); + return OkStatus(); +} + +Status DeviceManager::WaitIdle(absl::Time deadline) { + IREE_TRACE_SCOPE0("DeviceManager::WaitIdle"); + absl::MutexLock lock(&device_mutex_); + for (const auto& device : devices_) { + RETURN_IF_ERROR(device->WaitIdle(deadline)); + } + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/device_manager.h b/hal/device_manager.h new file mode 100644 index 0000000..07c22e0 --- /dev/null +++ b/hal/device_manager.h
@@ -0,0 +1,209 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DEVICE_MANAGER_H_ +#define IREE_HAL_DEVICE_MANAGER_H_ + +#include <vector> + +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "base/status.h" +#include "base/time.h" +#include "hal/allocator.h" +#include "hal/buffer.h" +#include "hal/command_queue.h" +#include "hal/device.h" +#include "hal/device_placement.h" +#include "hal/executable_format.h" +#include "hal/fence.h" + +namespace iree { +namespace hal { + +// Specifies how devices should be resolved to DevicePlacements. +// Most fields are optional and when not included will be ignored. +struct PlacementSpec { + // TODO(benvanik): other requirements (features/caps, power, etc). + + // A list of executable formats that the placement should support. + // If more than one format is provided any device satisfying at least one + // will be considered for placement. The formats can be sorted in descending + // priority order to prefer the first available format in the case of ties. + absl::Span<const ExecutableFormat> available_formats; +}; + +// Manages device lifetime and placement resolution. +// Optionally the DeviceManager may be used for automatic device selection for +// allocations or batched submissions, however this is not required if specific +// devices and scheduling behavior are known to the caller. +// +// Thread-safe. Note that callers must ensure that unregistered devices are kept +// alive for as long as any commands are in-flight that may be using them. +class DeviceManager final { + public: + DeviceManager(); + ~DeviceManager(); + + // Registers a device with the manager. + // The device will be used to resolve placements. Any placements resolved + // prior to the addition of the device will need to be refreshed by the caller + // if they want to make use of the new device. + Status RegisterDevice(std::shared_ptr<Device> device); + + // Unregisters a device with the manager. + // Placements that resolved to the device prior to unregistering will remain + // valid for that device. Callers will need to refresh the placements to + // ensure the device stops being used. + Status UnregisterDevice(Device* device); + + // TODO(benvanik): dispatch info + requirements + etc -> DevicePlacement. + + // Resolves a placement spec to a device placement based on the registered + // devices. + // If the placement is not fully specified the device and queue may be chosen + // at random. See PlacementSpec for more information about resolution and + // ranking. + StatusOr<DevicePlacement> ResolvePlacement( + const PlacementSpec& placement_spec) const; + + // Finds an allocator that can allocate buffers of the given |memory_type| and + // |buffer_usage| such that the buffers can be used interchangebly. + // Fails if there is no Allocator that can satisfy that requirement. + StatusOr<Allocator*> FindCompatibleAllocator( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + absl::Span<const DevicePlacement> device_placements) const; + + // Tries to allocate a host-local buffer that _may_ be optimal for use with + // the given |device_placements| and _may_ be device-visible. The buffer can + // be used for staging uploads to device-local buffers and is useful for times + // when the buffer will be used more on the host than the device. If a buffer + // never needs to be used with a device prefer instead + // Allocator::host_local()::Allocate. + // + // Returns a buffer even if it's not possible to satisfy the requested + // |buffer_usage| for the |device_placements| at the cost of a run-time + // performance hit. + StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements); + StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer( + BufferUsageBitfield buffer_usage, device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements) { + return TryAllocateDeviceVisibleBuffer( + MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage, + allocation_size, device_placements); + } + + // Allocates a host-local buffer that is optimal for use on the host but is + // usable by the given |device_placements| (at a possible performance + // penalty). The buffer can be used for staging uploads to device-local + // buffers and is useful for times when the buffer will be used more on the + // host than the device. If a buffer never needs to be used with a device + // prefer instead HeapBuffer::Allocate. + // + // Fails if it is not possible to allocate and satisfy all |device_placements| + // for the requested |buffer_usage|. + StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements); + StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer( + BufferUsageBitfield buffer_usage, device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements) { + return AllocateDeviceVisibleBuffer( + MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage, + allocation_size, device_placements); + } + + // Allocates a device-local buffer that is optimal for use with the given + // |device_placements|. The buffer will not be host-visible and can only be + // used from compatible device queues. + // + // Fails if it is not possible to allocate and satisfy all |device_placements| + // for the requested |buffer_usage|. + StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements); + StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer( + BufferUsageBitfield buffer_usage, device_size_t allocation_size, + absl::Span<const DevicePlacement> device_placements) { + return AllocateDeviceLocalBuffer(MemoryType::kDeviceLocal, buffer_usage, + allocation_size, device_placements); + } + + // Enqueues a submission against the given target |device| |command_queue|. + // The provided |deadline| is used to determine how long the submission can + // stay waiting in the queue prior to flushing, with absl::InfinitePast + // indicating immediate submission and absl::InfiniteFuture indicating that + // Flush must be called. + // + // If a |fence| is provided it will be signaled when the submission has + // completed and otherwise the caller must use WaitIdle to ensure completion. + // If a sequence of submissions are performed then the semaphore relationships + // can be used to elide waits. Submit(A)+Submit(B, fence) where there is a + // dependency from A->B is safe. + // + // All provided resources must remain alive until the provided |fence| + // resolves or Scheduler::WaitIdle succeeds. + // + // Submissions may be made from any thread. Behavior is undefined + // if a thread is performing a WaitIdle while another thread submits work. + Status Submit(Device* device, CommandQueue* command_queue, + absl::Span<const SubmissionBatch> batches, absl::Time deadline, + FenceValue fence = {}); + Status Submit(Device* device, CommandQueue* command_queue, + absl::Span<const SubmissionBatch> batches, + absl::Duration timeout, FenceValue fence = {}) { + return Submit(device, command_queue, batches, + RelativeTimeoutToDeadline(timeout), fence); + } + Status Submit(Device* device, CommandQueue* command_queue, + absl::Span<const SubmissionBatch> batches, + FenceValue fence = {}) { + return Submit(device, command_queue, batches, absl::InfinitePast(), fence); + } + + // Flushes any requests that are pending in the scheduler and ensures they + // begin executing ASAP regardless of policy. + // + // If any used device has encountered an error during submission at any + // point it will be returned here (repeatedly). + Status Flush(); + + // Blocks until all outstanding requests have been completed. + // This is equivalent to having waited on all outstanding fences. + // Implicitly calls Flush to ensure delayed requests are scheduled. + // Work submitted from other threads during a wait may not be included in the + // wait set. + // + // If any used device has encountered an error during submission at any + // point it will be returned here (repeatedly). + Status WaitIdle(absl::Time deadline); + inline Status WaitIdle(absl::Duration timeout) { + return WaitIdle(RelativeTimeoutToDeadline(timeout)); + } + inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); } + + private: + mutable absl::Mutex device_mutex_; + std::vector<std::shared_ptr<Device>> devices_ ABSL_GUARDED_BY(device_mutex_); +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DEVICE_MANAGER_H_
diff --git a/iree/hal/device_placement.h b/hal/device_placement.h similarity index 100% rename from iree/hal/device_placement.h rename to hal/device_placement.h
diff --git a/hal/driver.h b/hal/driver.h new file mode 100644 index 0000000..bc0dbc1 --- /dev/null +++ b/hal/driver.h
@@ -0,0 +1,61 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DRIVER_H_ +#define IREE_HAL_DRIVER_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "base/status.h" +#include "hal/device.h" +#include "hal/device_info.h" + +namespace iree { +namespace hal { + +class Driver { + public: + virtual ~Driver() = default; + + // Driver name used during registration. + const std::string& name() const { return name_; } + + // TODO(benvanik): info/query (version number, etc). + + // Enumerates devices available for creation from the driver. + // This may fail if the driver is in an invalid state but otherwise will + // return an empty list if no devices are available. + virtual StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() = 0; + + // Creates the driver-defined 'default' device. + // This may simply be the first device enumerated. + virtual StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() = 0; + + // Creates a device as queried with the given |device_info|. + virtual StatusOr<std::shared_ptr<Device>> CreateDevice( + const DeviceInfo& device_info) = 0; + + protected: + explicit Driver(std::string name) : name_(std::move(name)) {} + + private: + const std::string name_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_DRIVER_H_
diff --git a/hal/driver_registry.cc b/hal/driver_registry.cc new file mode 100644 index 0000000..cb49a0a --- /dev/null +++ b/hal/driver_registry.cc
@@ -0,0 +1,87 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/driver_registry.h" + +#include "base/status.h" + +namespace iree { +namespace hal { + +// static +DriverRegistry* DriverRegistry::shared_registry() { + static auto* singleton = new DriverRegistry(); + return singleton; +} + +DriverRegistry::DriverRegistry() = default; + +DriverRegistry::~DriverRegistry() = default; + +Status DriverRegistry::Register(std::string driver_name, FactoryFn factory_fn) { + absl::MutexLock lock(&mutex_); + for (const auto& pair : driver_factory_fns_) { + if (pair.first == driver_name) { + return AlreadyExistsErrorBuilder(IREE_LOC) + << "Driver already registered: " << driver_name; + } + } + driver_factory_fns_.emplace_back(driver_name, std::move(factory_fn)); + return OkStatus(); +} + +bool DriverRegistry::HasDriver(absl::string_view driver_name) const { + absl::MutexLock lock(&mutex_); + for (const auto& pair : driver_factory_fns_) { + if (pair.first == driver_name) { + return true; + } + } + return false; +} + +std::vector<std::string> DriverRegistry::EnumerateAvailableDrivers() const { + absl::MutexLock lock(&mutex_); + std::vector<std::string> driver_names; + driver_names.reserve(driver_factory_fns_.size()); + for (const auto& pair : driver_factory_fns_) { + driver_names.push_back(pair.first); + } + return driver_names; +} + +StatusOr<std::shared_ptr<Driver>> DriverRegistry::Create( + absl::string_view driver_name) const { + FactoryFn factory_fn; + { + absl::MutexLock lock(&mutex_); + for (const auto& pair : driver_factory_fns_) { + if (pair.first == driver_name) { + factory_fn = pair.second; + break; + } + } + if (!factory_fn) { + return NotFoundErrorBuilder(IREE_LOC) + << "Driver " << driver_name << " not found"; + } + } + return factory_fn(); +} + +} // namespace hal +} // namespace iree + +IREE_REGISTER_MODULE_INITIALIZER( + iree_hal, ::iree::hal::DriverRegistry::shared_registry());
diff --git a/hal/driver_registry.h b/hal/driver_registry.h new file mode 100644 index 0000000..6cca616 --- /dev/null +++ b/hal/driver_registry.h
@@ -0,0 +1,83 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_DRIVER_REGISTRY_H_ +#define IREE_HAL_DRIVER_REGISTRY_H_ + +#include <memory> +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "base/init.h" +#include "base/status.h" +#include "hal/driver.h" + +namespace iree { +namespace hal { + +// Driver registry and factory. +// Factory functions for available drivers are registered with a given name and +// can be invoked with a call to Create. The configuration of the drivers is +// generally contained within the factory function and consumers of the drivers +// don't need to fiddle with things. +// +// This is used for dynamic *safe* link-time driver module registration. +// Roughly: driver_registry provides the shared registry and a way to create +// drivers and *_driver_module.cc files register drivers when linked in. +// Remember to alwayslink=1 on cc_libraries providing modules. +// +// If link-time driver registration is not desired (or possible) it's also +// possible to explicitly register drivers via this registry. This is useful +// when programmatically enabling drivers. +// +// Thread-safe. +class DriverRegistry final { + public: + using FactoryFn = std::function<StatusOr<std::shared_ptr<Driver>>()>; + + // The shared driver registry singleton that modules use when linked in. + static DriverRegistry* shared_registry(); + + DriverRegistry(); + ~DriverRegistry(); + + // Registers a driver and its factory function. + // The function will be called to create a new driver whenever it is requested + // via Create. + Status Register(std::string driver_name, FactoryFn factory_fn); + + // Returns true if there is a driver registered with the given name. + bool HasDriver(absl::string_view driver_name) const; + + // Returns a list of registered drivers. + std::vector<std::string> EnumerateAvailableDrivers() const; + + // TODO(benvanik): flags for enabling debug validation/control/etc. + // Creates a driver by name. + StatusOr<std::shared_ptr<Driver>> Create(absl::string_view driver_name) const; + + private: + mutable absl::Mutex mutex_; + std::vector<std::pair<std::string, FactoryFn>> driver_factory_fns_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace hal +} // namespace iree + +IREE_DECLARE_MODULE_INITIALIZER(iree_hal); +IREE_REQUIRE_MODULE_LINKED(iree_hal); + +#endif // IREE_HAL_DRIVER_REGISTRY_H_
diff --git a/hal/event.h b/hal/event.h new file mode 100644 index 0000000..53b981d --- /dev/null +++ b/hal/event.h
@@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_EVENT_H_ +#define IREE_HAL_EVENT_H_ + +#include "hal/resource.h" + +namespace iree { +namespace hal { + +// Events are used for defining synchronization scopes within CommandBuffers. +// An event only exists within a single CommandBuffer and must not be used +// across CommandBuffers from the same device or others. +// +// See CommandBuffer::SignalEvent and CommandBuffer::WaitEvents for more info. +class Event : public Resource { + public: +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_EVENT_H_
diff --git a/hal/executable.h b/hal/executable.h new file mode 100644 index 0000000..7a39b10 --- /dev/null +++ b/hal/executable.h
@@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_EXECUTABLE_H_ +#define IREE_HAL_EXECUTABLE_H_ + +#include "hal/resource.h" + +namespace iree { +namespace hal { + +class Executable : public Resource { + public: + ~Executable() override = default; + + // True if the executable was prepared with debugging enabled and the device + // and input data support debugging (symbols present, etc). + virtual bool supports_debugging() const = 0; + + // TODO(benvanik): disassembly methods. + + // TODO(benvanik): relative offset calculation: + // - step once + // - step over + // - step out + + // TODO(benvanik): create executable split on breakpoint. + // Executable should return when the breakpoint is hit without any future + // modifications to output buffers. If the breakpoint is not hit the + // executable should run to completion as normal. + + // TODO(benvanik): retrieve coverage info. + // Returns a buffer containing offset -> coverage metrics. Note that depending + // on the device this may only contain a single coverage metric for the entire + // executable or some subset of the available offsets. + + // TODO(benvanik): retrieve profiling info. + + protected: + Executable() = default; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_EXECUTABLE_H_
diff --git a/hal/executable_cache.cc b/hal/executable_cache.cc new file mode 100644 index 0000000..05617cd --- /dev/null +++ b/hal/executable_cache.cc
@@ -0,0 +1,25 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/executable_cache.h" + +namespace iree { +namespace hal { + +ExecutableCache::ExecutableCache() = default; + +ExecutableCache::~ExecutableCache() = default; + +} // namespace hal +} // namespace iree
diff --git a/hal/executable_cache.h b/hal/executable_cache.h new file mode 100644 index 0000000..cd3eaa3 --- /dev/null +++ b/hal/executable_cache.h
@@ -0,0 +1,126 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_EXECUTABLE_CACHE_H_ +#define IREE_HAL_EXECUTABLE_CACHE_H_ + +#include "base/bitfield.h" +#include "base/ref_ptr.h" +#include "base/status.h" +#include "hal/executable.h" +#include "hal/executable_format.h" +#include "hal/executable_spec.h" + +namespace iree { +namespace hal { + +// Defines how the executable cache performs preparation. +enum class ExecutableCachingMode : uint32_t { + // Allows the cache to reference the provided executable_data after it has + // prepared the executable. Callers must ensure the data remains valid for the + // lifetime of the cache. If memory mapping constant executable data from + // disk this can be used to avoid copies. + kAliasProvidedData = 1 << 0, + + // Allows the prepared executable to be cached persistently (on disk/etc). + // Enable for any executable that is likely to be used in future runs. + // Note that not all caches support persistent serialization and this is just + // a hint. + kAllowPersistentCaching = 1 << 1, + + // Allows the cache to optimize the executable as much as it can. + // This may cause preparation to take significantly longer while (hopefully) + // improving runtime performance. Avoid for one-shot executables. + kAllowOptimization = 1 << 2, + + // Enables Executable debugging methods if supported by the device and + // executable. This may disable certain optimizations or retain additional + // data to allow disassembly, stepping, etc. + // + // Device must support the DeviceFeature::kDebugging feature and executables + // must support the ExecutableFeature::kDebugging feature. + kEnableDebugging = 1 << 3, + + // Enables Executable coverage if supported by the device and executable. + // Depending on the optimization mode this may produce partial coverage + // results (for example, when certain source operations were optimized away). + // + // Device must support the DeviceFeature::kCoverage feature and executables + // must support the ExecutableFeature::kCoverage feature. + kEnableCoverage = 1 << 4, + + // Enables Executable profiling if supported by the device and executable. + // Depending on the optimization mode this may produce partial profiling + // results. Profiling attribution (whether to the entire executable or + // specific operations) depends on the implementation. + // + // Device must support the DeviceFeature::kProfiling feature and executables + // must support the ExecutableFeature::kProfiling feature. + kEnableProfiling = 1 << 5, + + // Default caching mode. + kDefault = kAllowPersistentCaching | kAllowOptimization, +}; +IREE_BITFIELD(ExecutableCachingMode); +using ExecutableCachingModeBitfield = ExecutableCachingMode; + +// A cache of prepared executables for a particular device. +// Caches may be shared across multiple devices from the same driver or specific +// to individual devices. Caches may persist prepared executables across process +// launches or reprepare them each run. Callers should assume that the cache is +// a no-op and the returned Executables only live for as long as the cache does. +// +// The term 'cache' here is rather optimistic - it's perfectly acceptable for +// implementations to not cache at all and return new Executables for each +// PrepareExecutable called (even for the same executable). Callers should +// expect such behavior and try to retain the results of the PrepareExecutable +// calls to reduce overhead in re-preparing executables. +// +// Thread-safe - multiple threads may prepare executables (including the *same* +// executable) simultaneously. +class ExecutableCache { + public: + virtual ~ExecutableCache(); + + // TODO(benvanik): status/queries (size, etc). + + // TODO(b/137153339): serialization/deserialization. + + // Returns true if the executable cache can prepare the given executable input + // format. Perparation may still fail if the particular version or features + // required by the executable are not supported. + virtual bool CanPrepareFormat(ExecutableFormat format) const = 0; + + // Prepares an executable for use. + // The provided |spec| and |executable_data| will be used to either lookup a + // previously prepared executable in the cache or prepare a new one. + // + // Depending on the driver preparation may take a non-trivial amount of time + // (such as when JITing/etc). As the cache is internally synchronized callers + // can issue preparation requests from multiple threads - even for the same + // executables - and calls will block until preparation completes. + // + // When preparing a large number of executables it's recommended to use the + // PrepareExecutables method to batch and wait on the results. + virtual StatusOr<ref_ptr<Executable>> PrepareExecutable( + ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) = 0; + + protected: + ExecutableCache(); +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/executable_format.h b/hal/executable_format.h similarity index 100% rename from iree/hal/executable_format.h rename to hal/executable_format.h
diff --git a/hal/executable_spec.h b/hal/executable_spec.h new file mode 100644 index 0000000..f871819 --- /dev/null +++ b/hal/executable_spec.h
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_EXECUTABLE_SPEC_H_ +#define IREE_HAL_EXECUTABLE_SPEC_H_ + +#include "absl/types/span.h" +#include "hal/executable_format.h" + +namespace iree { +namespace hal { + +// Defines an executable specification used by a cache to prepare an executable. +struct ExecutableSpec { + // TODO(benvanik): pre-populated hash_code/key to avoid calculation. + + // Format of the executable input data. + ExecutableFormat format = kExecutableFormatUnspecified; + + // A reference to the executable data as input to the cache. + // If ExecutableCachingMode::kAliasProvidedData is set then this reference + // may be retained by the cache and the backing buffer must be kept valid for + // the lifetime of the cache. + absl::Span<const uint8_t> executable_data; + + // TODO(benvanik): add specialization info (constants/defines). + // TODO(benvanik): add compiler flags? could treat as opaque. +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_EXECUTABLE_SPEC_H_
diff --git a/hal/fence.h b/hal/fence.h new file mode 100644 index 0000000..40071c5 --- /dev/null +++ b/hal/fence.h
@@ -0,0 +1,72 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_FENCE_H_ +#define IREE_HAL_FENCE_H_ + +#include <cstdint> + +#include "base/status.h" +#include "hal/resource.h" + +namespace iree { +namespace hal { + +// Synchronization mechanism for device->host notification. +// Fences behave like timeline semaphores and contain a monotonically increasing +// uint64_t payload. They may be waited on any number of times - even if they +// have already been signaled. +// +// A fence is updated to its new value after all prior commands have completed +// but the delay between completion and the host being woken varies. Some +// implementations may coalesce fences to avoid spurious waking while others +// will immediately synchronize with the host. +// +// The primary use of fences is for resource lifetime management: all resources +// used by a set of submission batches must be considered live until the fence +// attached to the submission has signaled. +// +// Fences may be set to a permanently failed state by implementations when +// errors occur during asynchronous execution. Users are expected to propagate +// the failures and possibly reset the entire device that produced the error. +// +// For more information on fences see the following docs describing how +// timelines are generally used (specifically in the device->host case): +// https://www.youtube.com/watch?v=SpE--Rf516Y +// https://www.khronos.org/assets/uploads/developers/library/2018-xdc/Vulkan-Timeline-Semaphores-Part-1_Sep18.pdf +// https://docs.microsoft.com/en-us/windows/win32/direct3d12/user-mode-heap-synchronization +class Fence : public Resource { + public: + // Returns a permanent failure status if the fence is indicating an + // asynchronous failure. + // + // Returns the status at the time the method is called without blocking and as + // such is only valid after a fence has been signaled. The same failure status + // will be returned regardless of when in the timeline the error occurred. + virtual Status status() const = 0; + + // Queries the current payload of the fence. As the payload is monotonically + // increasing it is guaranteed that the value is at least equal to the + // previous result of a QueryValue call and coherent with any waits for a + // specified value via Device::WaitAllFences. + virtual StatusOr<uint64_t> QueryValue() = 0; +}; + +// A reference to a fence and associated payload value. +using FenceValue = std::pair<Fence*, uint64_t>; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_FENCE_H_
diff --git a/hal/heap_buffer.cc b/hal/heap_buffer.cc new file mode 100644 index 0000000..5ea4326 --- /dev/null +++ b/hal/heap_buffer.cc
@@ -0,0 +1,190 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/heap_buffer.h" + +#include <cstdint> +#include <cstdlib> +#include <string> +#include <utility> + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/allocator.h" +#include "hal/host/host_buffer.h" + +namespace iree { +namespace hal { + +namespace { + +// An allocator that allocates or wraps host-only buffers. +// The resulting buffers are not usable by most devices without a copy and +// using a device allocator is strongly preferred. +class HeapAllocator : public Allocator { + public: + // Returns a singleton heap allocator that can provide buffers that have + // MemoryType::kHostLocal and are allocated with malloc/free. + // These buffers will not be usable by devices directly and may incur + // additional copies. + static Allocator* std_heap(); + + // TODO(benvanik): specify custom allocator (not malloc/free). + HeapAllocator(); + ~HeapAllocator() override; + + bool CanUseBufferLike(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const override; + + bool CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const override; + + StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) override; + + StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, + void* data, + size_t data_length) override; +}; + +// static +Allocator* HeapAllocator::std_heap() { + static Allocator* std_heap_allocator = new HeapAllocator(); + return std_heap_allocator; +} + +HeapAllocator::HeapAllocator() = default; + +HeapAllocator::~HeapAllocator() = default; + +bool HeapAllocator::CanUseBufferLike(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const { + // The host can use anything with kHostVisible. + if (!AnyBitSet(memory_type & MemoryType::kHostVisible)) { + return false; + } + + // Host currently uses mapping to copy buffers, which is done a lot. + if (!AnyBitSet(buffer_usage & BufferUsage::kMapping)) { + return false; + } + + return true; +} + +bool HeapAllocator::CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const { + // This host only allocator cannot serve device visible allocation as we + // can't know which devices these buffers will be used with. + return (memory_type & MemoryType::kHostLocal) == MemoryType::kHostLocal && + !AnyBitSet(memory_type & MemoryType::kDeviceLocal) && + !AnyBitSet(memory_type & MemoryType::kDeviceVisible); +} + +StatusOr<ref_ptr<Buffer>> HeapAllocator::Allocate( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + size_t allocation_size) { + IREE_TRACE_SCOPE0("HeapAllocator::Allocate"); + + if (!CanAllocate(memory_type, buffer_usage, allocation_size)) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Allocation not supported; memory_type=" + << MemoryTypeString(memory_type) + << ", buffer_usage=" << BufferUsageString(buffer_usage) + << ", allocation_size=" << allocation_size; + } + + void* malloced_data = std::calloc(1, allocation_size); + if (!malloced_data) { + return ResourceExhaustedErrorBuilder(IREE_LOC) + << "Failed to malloc " << allocation_size << " bytes"; + } + + auto buffer = + make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage, + allocation_size, malloced_data, true); + return buffer; +} + +StatusOr<ref_ptr<Buffer>> HeapAllocator::WrapMutable( + MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, void* data, size_t data_length) { + auto buffer = make_ref<HostBuffer>(this, memory_type, allowed_access, + buffer_usage, data_length, data, false); + return buffer; +} + +} // namespace + +// static +ref_ptr<Buffer> HeapBuffer::Allocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield usage, + size_t allocation_size) { + auto buffer_or = + HeapAllocator::std_heap()->Allocate(memory_type, usage, allocation_size); + return std::move(buffer_or.ValueOrDie()); +} + +// static +ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, + const void* data, size_t data_length) { + return AllocateCopy(usage, MemoryAccess::kAll, data, data_length); +} + +// static +ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, + MemoryAccessBitfield allowed_access, + const void* data, size_t data_length) { + IREE_TRACE_SCOPE0("HeapBuffer::AllocateCopy"); + // Ensure we can map so that we can copy into it. + usage |= BufferUsage::kMapping; + auto buffer_or = HeapAllocator::std_heap()->Allocate(MemoryType::kHostLocal, + usage, data_length); + auto buffer = std::move(buffer_or.ValueOrDie()); + buffer->WriteData(0, data, data_length).IgnoreError(); + buffer->set_allowed_access(allowed_access); + return buffer; +} + +// static +ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield usage, const void* data, + size_t data_length) { + auto buffer_or = + HeapAllocator::std_heap()->Wrap(memory_type, usage, data, data_length); + return std::move(buffer_or.ValueOrDie()); +} + +// static +ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, void* data, + size_t data_length) { + auto buffer_or = HeapAllocator::std_heap()->WrapMutable( + memory_type, allowed_access, usage, data, data_length); + return std::move(buffer_or.ValueOrDie()); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/heap_buffer.h b/hal/heap_buffer.h new file mode 100644 index 0000000..ceaa10f --- /dev/null +++ b/hal/heap_buffer.h
@@ -0,0 +1,117 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HEAP_BUFFER_H_ +#define IREE_HAL_HEAP_BUFFER_H_ + +#include <memory> + +#include "base/status.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +// Factory for buffers that are allocated from the host heap (malloc/free). +// These buffers cannot be used by devices and will incur copies/transfers when +// used. Prefer device-specific allocators instead. +class HeapBuffer { + public: + // Allocates a zeroed host heap buffer of the given size. + // Returns a buffer allocated with malloc and have MemoryType::kHostLocal + // and will not be usable by devices without copies. + static ref_ptr<Buffer> Allocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield usage, + size_t allocation_size); + static ref_ptr<Buffer> Allocate(BufferUsageBitfield usage, + size_t allocation_size) { + return Allocate(MemoryType::kHostLocal, usage, allocation_size); + } + + // Allocates a host heap buffer with a copy of the given data. + // Returns a buffer allocated with malloc and have MemoryType::kHostLocal + // and will not be usable by devices without copies. + static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, + const void* data, size_t data_length); + static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, + MemoryAccessBitfield allowed_access, + const void* data, size_t data_length); + template <typename T> + static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, + absl::Span<const T> data); + template <typename T> + static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, + MemoryAccessBitfield allowed_access, + absl::Span<const T> data); + + // Wraps an existing host heap allocation in a buffer. + // Ownership of the host allocation remains with the caller and the memory + // must remain valid for so long as the Buffer may be in use. + // Will have MemoryType::kHostLocal in most cases and may not be usable + // by the device. + static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield usage, const void* data, + size_t data_length); + static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, void* data, + size_t data_length); + template <typename T> + static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield usage, + absl::Span<const T> data); + template <typename T> + static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, + absl::Span<T> data); +}; + +// Inline functions and template definitions follow: + +template <typename T> +ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, + absl::Span<const T> data) { + return HeapBuffer::AllocateCopy(usage, MemoryAccess::kAll, data); +} + +template <typename T> +ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, + MemoryAccessBitfield allowed_access, + absl::Span<const T> data) { + return HeapBuffer::AllocateCopy(usage, allowed_access, data.data(), + data.size() * sizeof(T)); +} + +template <typename T> +ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type, + BufferUsageBitfield usage, + absl::Span<const T> data) { + return HeapBuffer::Wrap(memory_type, usage, data.data(), + data.size() * sizeof(T)); +} + +template <typename T> +ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, + absl::Span<T> data) { + return HeapBuffer::WrapMutable(memory_type, allowed_access, usage, + data.data(), data.size() * sizeof(T)); +} + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HEAP_BUFFER_H_
diff --git a/hal/host/BUILD b/hal/host/BUILD new file mode 100644 index 0000000..76f055c --- /dev/null +++ b/hal/host/BUILD
@@ -0,0 +1,155 @@ +# Default implementations for HAL types that use the host resources. +# These are generally just wrappers around host heap memory and host threads. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "async_command_queue", + srcs = ["async_command_queue.cc"], + hdrs = ["async_command_queue.h"], + deps = [ + ":host_submission_queue", + "///base:status", + "///base:tracing", + "///hal:command_queue", + "///hal:fence", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) + +cc_test( + name = "async_command_queue_test", + srcs = ["async_command_queue_test.cc"], + deps = [ + ":async_command_queue", + ":host_submission_queue", + "///base:status", + "///base:status_matchers", + "///base:time", + "///hal:command_queue", + "///hal/testing:mock_command_buffer", + "///hal/testing:mock_command_queue", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "host_buffer", + srcs = ["host_buffer.cc"], + hdrs = ["host_buffer.h"], + deps = [ + "///base:logging", + "///base:source_location", + "///base:status", + "///hal:buffer", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "host_event", + srcs = ["host_event.cc"], + hdrs = ["host_event.h"], + deps = [ + "///hal:event", + ], +) + +cc_library( + name = "host_fence", + srcs = ["host_fence.cc"], + hdrs = ["host_fence.h"], + deps = [ + "///base:status", + "///base:tracing", + "///hal:fence", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "host_fence_test", + srcs = ["host_fence_test.cc"], + deps = [ + ":host_fence", + "///base:status", + "///base:status_matchers", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "host_local_allocator", + srcs = ["host_local_allocator.cc"], + hdrs = ["host_local_allocator.h"], + deps = [ + ":host_buffer", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:allocator", + "///hal:buffer", + ], +) + +cc_library( + name = "host_local_command_processor", + srcs = ["host_local_command_processor.cc"], + hdrs = ["host_local_command_processor.h"], + deps = [ + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:command_buffer", + ], +) + +cc_library( + name = "host_submission_queue", + srcs = ["host_submission_queue.cc"], + hdrs = ["host_submission_queue.h"], + deps = [ + ":host_fence", + "///base:intrusive_list", + "///base:status", + "///base:tracing", + "///hal:command_queue", + "///hal:fence", + "///hal:semaphore", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + ], +) + +cc_test( + name = "host_submission_queue_test", + srcs = ["host_submission_queue_test.cc"], + deps = [ + ":host_submission_queue", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "inproc_command_buffer", + srcs = ["inproc_command_buffer.cc"], + hdrs = ["inproc_command_buffer.h"], + deps = [ + "///base:arena", + "///base:intrusive_list", + "///base:status", + "///base:tracing", + "///hal:command_buffer", + ], +)
diff --git a/iree/hal/host/CMakeLists.txt b/hal/host/CMakeLists.txt similarity index 100% rename from iree/hal/host/CMakeLists.txt rename to hal/host/CMakeLists.txt
diff --git a/hal/host/async_command_queue.cc b/hal/host/async_command_queue.cc new file mode 100644 index 0000000..0d3bc2c --- /dev/null +++ b/hal/host/async_command_queue.cc
@@ -0,0 +1,127 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/async_command_queue.h" + +#include "absl/base/thread_annotations.h" +#include "base/status.h" +#include "base/tracing.h" + +namespace iree { +namespace hal { + +AsyncCommandQueue::AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue) + : CommandQueue(target_queue->name(), target_queue->supported_categories()), + target_queue_(std::move(target_queue)) { + IREE_TRACE_SCOPE0("AsyncCommandQueue::ctor"); + thread_ = std::thread([this]() { ThreadMain(); }); +} + +AsyncCommandQueue::~AsyncCommandQueue() { + IREE_TRACE_SCOPE0("AsyncCommandQueue::dtor"); + { + // Signal to thread that we want to stop. Note that the thread may have + // already been stopped and that's ok (as we'll Join right away). + // The thread will finish processing any queued submissions. + absl::MutexLock lock(&submission_mutex_); + submission_queue_.SignalShutdown(); + } + thread_.join(); + + // Ensure we shut down OK. + { + absl::MutexLock lock(&submission_mutex_); + CHECK(submission_queue_.empty()) + << "Dirty shutdown of async queue (unexpected thread exit?)"; + } +} + +void AsyncCommandQueue::ThreadMain() { + // TODO(benvanik): make this safer (may die if trace is flushed late). + IREE_TRACE_THREAD_ENABLE(target_queue_->name().c_str()); + + bool is_exiting = false; + while (!is_exiting) { + // Block until we are either requested to exit or there are pending + // submissions. + submission_mutex_.Lock(); + submission_mutex_.Await(absl::Condition( + +[](HostSubmissionQueue* queue) { + return queue->has_shutdown() || !queue->empty(); + }, + &submission_queue_)); + if (!submission_queue_.empty()) { + // Run all ready submissions (this may be called many times). + submission_mutex_.AssertHeld(); + submission_queue_ + .ProcessBatches( + [this](absl::Span<CommandBuffer* const> command_buffers) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(submission_mutex_) { + // Release the lock while we perform the processing so that + // other threads can submit more work. + submission_mutex_.AssertHeld(); + submission_mutex_.Unlock(); + + // Relay the command buffers to the target queue. + // Since we are taking care of all synchronization they + // don't need any waiters or fences. + auto status = target_queue_->Submit( + {{}, command_buffers, {}}, {nullptr, 0u}); + + // Take back the lock so we can manipulate the queue safely. + submission_mutex_.Lock(); + submission_mutex_.AssertHeld(); + + return status; + }) + .IgnoreError(); + submission_mutex_.AssertHeld(); + } + if (submission_queue_.has_shutdown()) { + // Exit when there are no more submissions to process and an exit was + // requested (or we errored out). + is_exiting = true; + } + submission_mutex_.Unlock(); + } +} + +Status AsyncCommandQueue::Submit(absl::Span<const SubmissionBatch> batches, + FenceValue fence) { + IREE_TRACE_SCOPE0("AsyncCommandQueue::Submit"); + absl::MutexLock lock(&submission_mutex_); + return submission_queue_.Enqueue(batches, fence); +} + +Status AsyncCommandQueue::WaitIdle(absl::Time deadline) { + IREE_TRACE_SCOPE0("AsyncCommandQueue::WaitIdle"); + + // Wait until the deadline, the thread exits, or there are no more pending + // submissions. + absl::MutexLock lock(&submission_mutex_); + if (!submission_mutex_.AwaitWithDeadline( + absl::Condition( + +[](HostSubmissionQueue* queue) { + return queue->empty() || !queue->permanent_error().ok(); + }, + &submission_queue_), + deadline)) { + return DeadlineExceededErrorBuilder(IREE_LOC) + << "Deadline exceeded waiting for submission thread to go idle"; + } + return submission_queue_.permanent_error(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/async_command_queue.h b/hal/host/async_command_queue.h new file mode 100644 index 0000000..54224db --- /dev/null +++ b/hal/host/async_command_queue.h
@@ -0,0 +1,71 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_ +#define IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_ + +#include <memory> +#include <thread> // NOLINT + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "hal/command_queue.h" +#include "hal/fence.h" +#include "hal/host/host_submission_queue.h" + +namespace iree { +namespace hal { + +// Asynchronous command queue wrapper. +// This creates a single thread to perform all CommandQueue operations. Any +// submitted CommandBuffer is dispatched in FIFO order on the queue thread +// against the provided |target_queue|. +// +// Target queues will receive submissions containing only command buffers as +// all semaphore synchronization is handled by the wrapper. Fences will also be +// omitted and code should safely handle nullptr. +// +// AsyncCommandQueue (as with CommandQueue) is thread-safe. Multiple threads +// may submit command buffers concurrently, though the order of execution in +// such a case depends entirely on the synchronization primitives provided. +class AsyncCommandQueue final : public CommandQueue { + public: + explicit AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue); + ~AsyncCommandQueue() override; + + Status Submit(absl::Span<const SubmissionBatch> batches, + FenceValue fence) override; + + Status WaitIdle(absl::Time deadline) override; + + private: + // Thread entry point for the async worker thread. + // Waits for submissions to be queued up and processes them eagerly. + void ThreadMain(); + + // CommandQueue that the async queue relays submissions into. + std::unique_ptr<CommandQueue> target_queue_; + + // Thread that runs the ThreadMain() function and processes submissions. + std::thread thread_; + + // Queue that manages submission ordering. + mutable absl::Mutex submission_mutex_; + HostSubmissionQueue submission_queue_ ABSL_GUARDED_BY(submission_mutex_); +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
diff --git a/hal/host/async_command_queue_test.cc b/hal/host/async_command_queue_test.cc new file mode 100644 index 0000000..995ef42 --- /dev/null +++ b/hal/host/async_command_queue_test.cc
@@ -0,0 +1,232 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/async_command_queue.h" + +#include <cstdint> +#include <memory> +#include <utility> + +#include "absl/memory/memory.h" +#include "absl/time/time.h" +#include "base/status.h" +#include "base/status_matchers.h" +#include "base/time.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/command_queue.h" +#include "hal/host/host_submission_queue.h" +#include "hal/testing/mock_command_buffer.h" +#include "hal/testing/mock_command_queue.h" + +namespace iree { +namespace hal { +namespace { + +using ::testing::_; + +using testing::MockCommandBuffer; +using testing::MockCommandQueue; + +struct AsyncCommandQueueTest : public ::testing::Test { + MockCommandQueue* mock_target_queue; + std::unique_ptr<CommandQueue> command_queue; + + void SetUp() override { + auto mock_queue = absl::make_unique<MockCommandQueue>( + "mock", CommandCategory::kTransfer | CommandCategory::kDispatch); + mock_target_queue = mock_queue.get(); + command_queue = absl::make_unique<AsyncCommandQueue>(std::move(mock_queue)); + } + + void TearDown() override { + command_queue.reset(); + mock_target_queue = nullptr; + } +}; + +// Tests that submitting a command buffer and immediately waiting will not +// deadlock. +TEST_F(AsyncCommandQueueTest, BlockingSubmit) { + ::testing::InSequence sequence; + + auto cmd_buffer = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + + EXPECT_CALL(*mock_target_queue, Submit(_, _)) + .WillOnce( + [&](absl::Span<const SubmissionBatch> batches, FenceValue fence) { + CHECK_EQ(1, batches.size()); + CHECK_EQ(1, batches[0].command_buffers.size()); + CHECK_EQ(cmd_buffer.get(), batches[0].command_buffers[0]); + CHECK_EQ(nullptr, fence.first); + return OkStatus(); + }); + HostFence fence(0u); + ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u})); + ASSERT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true, + absl::InfiniteFuture())); +} + +// Tests that failure is propagated along the fence from the target queue. +TEST_F(AsyncCommandQueueTest, PropagateSubmitFailure) { + ::testing::InSequence sequence; + + auto cmd_buffer = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + + EXPECT_CALL(*mock_target_queue, Submit(_, _)) + .WillOnce( + [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { + return DataLossErrorBuilder(IREE_LOC); + }); + HostFence fence(0u); + ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u})); + EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( + {{&fence, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); +} + +// Tests that waiting for idle is a no-op when nothing is queued. +TEST_F(AsyncCommandQueueTest, WaitIdleWhileIdle) { + ASSERT_OK(command_queue->WaitIdle()); +} + +// Tests that waiting for idle will block when work is pending/in-flight. +TEST_F(AsyncCommandQueueTest, WaitIdleWithPending) { + ::testing::InSequence sequence; + + auto cmd_buffer = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + + EXPECT_CALL(*mock_target_queue, Submit(_, _)) + .WillOnce( + [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { + Sleep(absl::Milliseconds(100)); + return OkStatus(); + }); + HostFence fence(0u); + ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u})); + + // This should block for a sec or two. + ASSERT_OK(command_queue->WaitIdle()); + + // Should have already expired. + ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue()); + ASSERT_EQ(1u, value); +} + +// Tests that waiting for idle with multiple pending submissions will wait until +// all of them complete while still allowing incremental progress. +TEST_F(AsyncCommandQueueTest, WaitIdleAndProgress) { + ::testing::InSequence sequence; + + EXPECT_CALL(*mock_target_queue, Submit(_, _)) + .WillRepeatedly( + [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { + Sleep(absl::Milliseconds(100)); + return OkStatus(); + }); + + auto cmd_buffer_0 = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + auto cmd_buffer_1 = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + + HostFence fence_0(0u); + ASSERT_OK( + command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u})); + HostFence fence_1(0u); + ASSERT_OK( + command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u})); + + // This should block for a sec or two. + ASSERT_OK(command_queue->WaitIdle()); + + // Both should have already expired. + ASSERT_OK_AND_ASSIGN(uint64_t value_0, fence_0.QueryValue()); + ASSERT_EQ(1u, value_0); + ASSERT_OK_AND_ASSIGN(uint64_t value_1, fence_1.QueryValue()); + ASSERT_EQ(1u, value_1); +} + +// Tests that failures are sticky. +TEST_F(AsyncCommandQueueTest, StickyFailures) { + ::testing::InSequence sequence; + + // Fail. + EXPECT_CALL(*mock_target_queue, Submit(_, _)) + .WillOnce( + [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { + Sleep(absl::Milliseconds(100)); + return DataLossErrorBuilder(IREE_LOC); + }); + auto cmd_buffer_0 = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + HostFence fence_0(0u); + ASSERT_OK( + command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u})); + EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( + {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); + + // Future flushes/waits/etc should also fail. + EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); + + // Future submits should fail asynchronously. + auto cmd_buffer_1 = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + HostFence fence_1(0u); + EXPECT_TRUE(IsDataLoss( + command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u}))); +} + +// Tests that a failure with two submissions pending causes the second to +// bail as well. +TEST_F(AsyncCommandQueueTest, FailuresCascadeAcrossSubmits) { + ::testing::InSequence sequence; + + // Fail. + EXPECT_CALL(*mock_target_queue, Submit(_, _)) + .WillOnce( + [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { + Sleep(absl::Milliseconds(100)); + return DataLossErrorBuilder(IREE_LOC); + }); + + auto cmd_buffer_0 = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + auto cmd_buffer_1 = make_ref<MockCommandBuffer>( + nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); + + HostBinarySemaphore semaphore_0_1(false); + HostFence fence_0(0u); + ASSERT_OK(command_queue->Submit({{}, {cmd_buffer_0.get()}, {&semaphore_0_1}}, + {&fence_0, 1u})); + HostFence fence_1(0u); + ASSERT_OK(command_queue->Submit({{&semaphore_0_1}, {cmd_buffer_1.get()}, {}}, + {&fence_1, 1u})); + + EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); + + EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( + {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); + EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( + {{&fence_1, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); + + // Future flushes/waits/etc should also fail. + EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_buffer.cc b/hal/host/host_buffer.cc new file mode 100644 index 0000000..fe4bce4 --- /dev/null +++ b/hal/host/host_buffer.cc
@@ -0,0 +1,148 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_buffer.h" + +#include <cstdint> +#include <cstdlib> +#include <cstring> + +#include "base/logging.h" +#include "base/source_location.h" +#include "base/status.h" + +namespace iree { +namespace hal { + +class Allocator; + +HostBuffer::HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, device_size_t allocation_size, + void* data, bool owns_data) + : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, 0, + allocation_size), + data_(data), + owns_data_(owns_data) {} + +HostBuffer::~HostBuffer() { + if (owns_data_ && data_) { + std::free(data_); + data_ = nullptr; + } +} + +Status HostBuffer::FillImpl(device_size_t byte_offset, + device_size_t byte_length, const void* pattern, + device_size_t pattern_length) { + auto data_ptr = data_; + switch (pattern_length) { + case 1: { + uint8_t* data = static_cast<uint8_t*>(data_ptr); + uint8_t value_bits = *static_cast<const uint8_t*>(pattern); + std::fill_n(data + byte_offset, byte_length, value_bits); + break; + } + case 2: { + uint16_t* data = static_cast<uint16_t*>(data_ptr); + uint16_t value_bits = *static_cast<const uint16_t*>(pattern); + std::fill_n(data + byte_offset / sizeof(uint16_t), + byte_length / sizeof(uint16_t), value_bits); + break; + } + case 4: { + uint32_t* data = static_cast<uint32_t*>(data_ptr); + uint32_t value_bits = *static_cast<const uint32_t*>(pattern); + std::fill_n(data + byte_offset / sizeof(uint32_t), + byte_length / sizeof(uint32_t), value_bits); + break; + } + default: + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unsupported scalar data size: " << pattern_length; + } + return OkStatus(); +} + +Status HostBuffer::ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) { + auto data_ptr = static_cast<uint8_t*>(data_); + std::memcpy(data, data_ptr + source_offset, data_length); + return OkStatus(); +} + +Status HostBuffer::WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) { + auto data_ptr = static_cast<uint8_t*>(data_); + std::memcpy(data_ptr + target_offset, data, data_length); + return OkStatus(); +} + +Status HostBuffer::CopyDataImpl(device_size_t target_offset, + Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) { + // This is pretty terrible. Let's not do this. + // TODO(benvanik): a way for allocators to indicate transfer compat. + ASSIGN_OR_RETURN(auto source_data, + source_buffer->MapMemory<uint8_t>( + MemoryAccess::kRead, source_offset, data_length)); + CHECK_EQ(data_length, source_data.size()); + auto data_ptr = static_cast<uint8_t*>(data_); + std::memcpy(data_ptr + target_offset, source_data.data(), data_length); + return OkStatus(); +} + +Status HostBuffer::MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) { + auto data_ptr = static_cast<uint8_t*>(data_); + *out_data = data_ptr + local_byte_offset; + + // If we mapped for discard scribble over the bytes. This is not a mandated + // behavior but it will make debugging issues easier. Alternatively for + // heap buffers we could reallocate them such that ASAN yells, but that + // would only work if the entire buffer was discarded. +#ifndef NDEBUG + if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) { + std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + } +#endif // !NDEBUG + + return OkStatus(); +} + +Status HostBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, + void* data) { + // No-op? We still want error checking to make finding misuse easier. + return OkStatus(); +} + +Status HostBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) { + // No-op? We still want error checking to make finding misuse easier. + return OkStatus(); +} + +Status HostBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) { + // No-op? We still want error checking to make finding misuse easier. + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_buffer.h b/hal/host/host_buffer.h new file mode 100644 index 0000000..2d20ea7 --- /dev/null +++ b/hal/host/host_buffer.h
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_BUFFER_H_ +#define IREE_HAL_HOST_BUFFER_H_ + +#include <cstdint> + +#include "base/status.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +// A buffer type that operates on host pointers. +// This can be used by Allocator implementations when they support operating +// on host memory (or mapping their memory to host memory). +class HostBuffer : public Buffer { + public: + HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, + device_size_t allocation_size, void* data, bool owns_data); + + ~HostBuffer() override; + + protected: + Status FillImpl(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length) override; + Status ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) override; + Status WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) override; + Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) override; + Status MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) override; + Status UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data) override; + Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override; + Status FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override; + + private: + void* data_ = nullptr; + bool owns_data_ = false; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_BUFFER_H_
diff --git a/hal/host/host_event.cc b/hal/host/host_event.cc new file mode 100644 index 0000000..1b7abac --- /dev/null +++ b/hal/host/host_event.cc
@@ -0,0 +1,25 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_event.h" + +namespace iree { +namespace hal { + +HostEvent::HostEvent() = default; + +HostEvent::~HostEvent() = default; + +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_event.h b/hal/host/host_event.h new file mode 100644 index 0000000..cfdbe09 --- /dev/null +++ b/hal/host/host_event.h
@@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_HOST_EVENT_H_ +#define IREE_HAL_HOST_HOST_EVENT_H_ + +#include "hal/event.h" + +namespace iree { +namespace hal { + +class HostEvent final : public Event { + public: + HostEvent(); + ~HostEvent() override; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_HOST_EVENT_H_
diff --git a/hal/host/host_fence.cc b/hal/host/host_fence.cc new file mode 100644 index 0000000..a983af9 --- /dev/null +++ b/hal/host/host_fence.cc
@@ -0,0 +1,110 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_fence.h" + +#include <atomic> +#include <cstdint> + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "base/status.h" +#include "base/tracing.h" + +namespace iree { +namespace hal { + +HostFence::HostFence(uint64_t initial_value) : value_(initial_value) {} + +HostFence::~HostFence() = default; + +Status HostFence::status() const { + absl::MutexLock lock(&mutex_); + return status_; +} + +StatusOr<uint64_t> HostFence::QueryValue() { + return value_.load(std::memory_order_acquire); +} + +Status HostFence::Signal(uint64_t value) { + absl::MutexLock lock(&mutex_); + if (!status_.ok()) { + return status_; + } + if (value_.exchange(value) >= value) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Fence values must be monotonically increasing"; + } + return OkStatus(); +} + +Status HostFence::Fail(Status status) { + absl::MutexLock lock(&mutex_); + status_ = status; + value_.store(UINT64_MAX, std::memory_order_release); + return OkStatus(); +} + +// static +Status HostFence::WaitForFences(absl::Span<const FenceValue> fences, + bool wait_all, absl::Time deadline) { + IREE_TRACE_SCOPE0("HostFence::WaitForFences"); + + // Some of the fences may already be signaled; we only need to wait for those + // that are not yet at the expected value. + using HostFenceValue = std::pair<HostFence*, uint64_t>; + absl::InlinedVector<HostFenceValue, 4> waitable_fences; + waitable_fences.reserve(fences.size()); + for (auto& fence_value : fences) { + auto* fence = reinterpret_cast<HostFence*>(fence_value.first); + ASSIGN_OR_RETURN(uint64_t current_value, fence->QueryValue()); + if (current_value == UINT64_MAX) { + // Fence has failed. Return the error. + return fence->status(); + } else if (current_value < fence_value.second) { + // Fence has not yet hit the required value; wait for it. + waitable_fences.push_back({fence, fence_value.second}); + } + } + + // TODO(benvanik): maybe sort fences by value in case we are waiting on + // multiple values from the same fence. + + // Loop over the fences and wait for them to complete. + // TODO(b/140026716): add WaitHandle support for !wait_all (wait any). + for (auto& fence_value : waitable_fences) { + auto* fence = fence_value.first; + absl::MutexLock lock(&fence->mutex_); + if (!fence->mutex_.AwaitWithDeadline( + absl::Condition( + +[](HostFenceValue* fence_value) { + return fence_value->first->value_.load( + std::memory_order_acquire) >= fence_value->second; + }, + &fence_value), + deadline)) { + return DeadlineExceededErrorBuilder(IREE_LOC) + << "Deadline exceeded waiting for fences"; + } + if (!fence->status_.ok()) { + return fence->status_; + } + } + + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_fence.h b/hal/host/host_fence.h new file mode 100644 index 0000000..16464db --- /dev/null +++ b/hal/host/host_fence.h
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_HOST_FENCE_H_ +#define IREE_HAL_HOST_HOST_FENCE_H_ + +#include <atomic> +#include <cstdint> + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "base/status.h" +#include "hal/fence.h" + +namespace iree { +namespace hal { + +// TODO(b/140026716): add WaitHandle support for better multi-wait. +// Simple host-only fence semaphore implemented with a mutex. +// +// Thread-safe (as instances may be imported and used by others). +class HostFence final : public Fence { + public: + // Waits for one or more (or all) fences to reach or exceed the given values. + static Status WaitForFences(absl::Span<const FenceValue> fences, + bool wait_all, absl::Time deadline); + + explicit HostFence(uint64_t initial_value); + ~HostFence() override; + + Status status() const override; + StatusOr<uint64_t> QueryValue() override; + + Status Signal(uint64_t value); + Status Fail(Status status); + + private: + // The mutex is not required to query the value; this lets us quickly check if + // a required value has been exceeded. The mutex is only used to update and + // notify waiters. + std::atomic<uint64_t> value_{0}; + + // We have a full mutex here so that we can perform condvar waits on value + // changes. + mutable absl::Mutex mutex_; + Status status_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_HOST_FENCE_H_
diff --git a/hal/host/host_fence_test.cc b/hal/host/host_fence_test.cc new file mode 100644 index 0000000..c0d12a5 --- /dev/null +++ b/hal/host/host_fence_test.cc
@@ -0,0 +1,148 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_fence.h" + +#include <cstdint> +#include <thread> // NOLINT + +#include "absl/time/time.h" +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace hal { +namespace { + +// Tests that a fence that is unused properly cleans itself up. +TEST(HostFenceTest, NoOp) { + HostFence fence(123u); + EXPECT_TRUE(fence.status().ok()); + ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue()); + EXPECT_EQ(123u, value); +} + +// Tests that a fence will accept new values as it is signaled. +TEST(HostFenceTest, NormalSignaling) { + HostFence fence(2u); + EXPECT_EQ(2u, fence.QueryValue().ValueOrDie()); + EXPECT_OK(fence.Signal(3u)); + EXPECT_EQ(3u, fence.QueryValue().ValueOrDie()); + EXPECT_OK(fence.Signal(40u)); + EXPECT_EQ(40u, fence.QueryValue().ValueOrDie()); +} + +// Tests that a fence will fail to set non-increasing values. +TEST(HostFenceTest, RequireIncreasingValues) { + HostFence fence(2u); + EXPECT_EQ(2u, fence.QueryValue().ValueOrDie()); + // Same value. + EXPECT_TRUE(IsInvalidArgument(fence.Signal(2u))); + // Decreasing. + EXPECT_TRUE(IsInvalidArgument(fence.Signal(1u))); +} + +// Tests that a fence that has failed will remain in a failed state. +TEST(HostFenceTest, StickyFailure) { + HostFence fence(2u); + // Signal to 3. + EXPECT_OK(fence.Signal(3u)); + EXPECT_TRUE(fence.status().ok()); + EXPECT_EQ(3u, fence.QueryValue().ValueOrDie()); + + // Fail now. + EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC))); + EXPECT_TRUE(IsUnknown(fence.status())); + EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie()); + + // Unable to signal again (it'll return the sticky failure). + EXPECT_TRUE(IsUnknown(fence.Signal(4u))); + EXPECT_TRUE(IsUnknown(fence.status())); + EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie()); +} + +// Tests waiting on no fences. +TEST(HostFenceTest, EmptyWait) { + EXPECT_OK( + HostFence::WaitForFences({}, /*wait_all=*/true, absl::InfiniteFuture())); +} + +// Tests waiting on a fence that has already been signaled. +TEST(HostFenceTest, WaitAlreadySignaled) { + HostFence fence(2u); + // Test both previous and current values. + EXPECT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true, + absl::InfiniteFuture())); + EXPECT_OK(HostFence::WaitForFences({{&fence, 2u}}, /*wait_all=*/true, + absl::InfiniteFuture())); +} + +// Tests waiting on a fence that has not been signaled. +TEST(HostFenceTest, WaitUnsignaled) { + HostFence fence(2u); + // NOTE: we don't actually block here because otherwise we'd lock up. + EXPECT_TRUE(IsDeadlineExceeded(HostFence::WaitForFences( + {{&fence, 3u}}, /*wait_all=*/true, absl::InfinitePast()))); +} + +// Tests waiting on a failed fence (it should return the error on the fence). +TEST(HostFenceTest, WaitAlreadyFailed) { + HostFence fence(2u); + EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC))); + EXPECT_TRUE(IsUnknown(HostFence::WaitForFences( + {{&fence, 2u}}, /*wait_all=*/true, absl::InfinitePast()))); +} + +// Tests threading behavior by ping-ponging between the test main thread and +// a little thread. +TEST(HostFenceTest, PingPong) { + HostFence a2b(0u); + HostFence b2a(0u); + std::thread thread([&]() { + // Should advance right past this because the value is already set. + ASSERT_OK(HostFence::WaitForFences({{&a2b, 0u}}, /*wait_all=*/true, + absl::InfiniteFuture())); + ASSERT_OK(b2a.Signal(1u)); + // Jump ahead. + ASSERT_OK(HostFence::WaitForFences({{&a2b, 4u}}, /*wait_all=*/true, + absl::InfiniteFuture())); + }); + ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true, + absl::InfiniteFuture())); + ASSERT_OK(a2b.Signal(4u)); + thread.join(); +} + +// Tests that failure still wakes waiters and propagates the error. +TEST(HostFenceTest, FailNotifies) { + HostFence a2b(0u); + HostFence b2a(0u); + bool got_failure = false; + std::thread thread([&]() { + ASSERT_OK(b2a.Signal(1u)); + got_failure = IsUnknown(HostFence::WaitForFences( + {{&a2b, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())); + }); + ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true, + absl::InfiniteFuture())); + ASSERT_OK(a2b.Fail(UnknownErrorBuilder(IREE_LOC))); + thread.join(); + ASSERT_TRUE(got_failure); +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_local_allocator.cc b/hal/host/host_local_allocator.cc new file mode 100644 index 0000000..6b60769 --- /dev/null +++ b/hal/host/host_local_allocator.cc
@@ -0,0 +1,111 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_local_allocator.h" + +#include <cstdlib> +#include <string> +#include <utility> + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/host/host_buffer.h" + +namespace iree { +namespace hal { + +HostLocalAllocator::HostLocalAllocator() = default; + +HostLocalAllocator::~HostLocalAllocator() = default; + +bool HostLocalAllocator::CanUseBufferLike( + Allocator* source_allocator, MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const { + // Must always have visibility to the device, which ensures we can test + // against the host but have things work on devices with separate address + // spaces. + if (!AnyBitSet(memory_type & MemoryType::kDeviceVisible)) { + return false; + } + + // kHostVisible is required for mapping. + if (AnyBitSet(intended_usage & BufferUsage::kMapping) && + !AnyBitSet(memory_type & MemoryType::kHostVisible)) { + return false; + } + + // Dispatch needs to be specified if we intend to dispatch. + if (AnyBitSet(intended_usage & BufferUsage::kDispatch) && + !AnyBitSet(buffer_usage & BufferUsage::kDispatch)) { + return false; + } + + return true; +} + +bool HostLocalAllocator::CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const { + // Host allows everything, pretty much, so long as it is device-visible (as + // the host is the device here). + return AnyBitSet(memory_type & MemoryType::kDeviceVisible); +} + +Status HostLocalAllocator::MakeCompatible( + MemoryTypeBitfield* memory_type, BufferUsageBitfield* buffer_usage) const { + // Always ensure we are host-visible. + *memory_type |= MemoryType::kHostVisible; + + // Host currently uses mapping to copy buffers, which is done a lot. + // We could probably remove this restriction somehow. + *buffer_usage |= BufferUsage::kMapping; + + // TODO(b/111372612): tensorflow needs transfer too, but shouldn't. + *buffer_usage |= BufferUsage::kTransfer; + + return OkStatus(); +} + +StatusOr<ref_ptr<Buffer>> HostLocalAllocator::Allocate( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + size_t allocation_size) { + IREE_TRACE_SCOPE0("HostLocalAllocator::Allocate"); + + if (!CanAllocate(memory_type, buffer_usage, allocation_size)) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Allocation not supported; memory_type=" + << MemoryTypeString(memory_type) + << ", buffer_usage=" << BufferUsageString(buffer_usage) + << ", allocation_size=" << allocation_size; + } + + // Make compatible with our requirements. + RETURN_IF_ERROR(MakeCompatible(&memory_type, &buffer_usage)); + + void* malloced_data = std::calloc(1, allocation_size); + if (!malloced_data) { + return ResourceExhaustedErrorBuilder(IREE_LOC) + << "Failed to malloc " << allocation_size << " bytes"; + } + + auto buffer = + make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage, + allocation_size, malloced_data, true); + return buffer; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_local_allocator.h b/hal/host/host_local_allocator.h new file mode 100644 index 0000000..fd38910 --- /dev/null +++ b/hal/host/host_local_allocator.h
@@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_LOCAL_ALLOCATOR_H_ +#define IREE_HAL_HOST_LOCAL_ALLOCATOR_H_ + +#include <cstddef> +#include <memory> + +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/buffer.h" + +namespace iree { +namespace hal { + +// An allocator implementation that allocates buffers from host memory. +// This can be used for drivers that do not have a memory space of their own. +// +// Buffers allocated will have be MemoryType::kHostLocal | kDeviceVisible as +// the 'device' in the case of a host-local queue *is* the host. To keep code +// written initially for a host-local queue working when other queues are used +// the allocator only works with buffers that are kDeviceVisible. +class HostLocalAllocator : public Allocator { + public: + HostLocalAllocator(); + ~HostLocalAllocator() override; + + bool CanUseBufferLike(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const override; + + bool CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const override; + + Status MakeCompatible(MemoryTypeBitfield* memory_type, + BufferUsageBitfield* buffer_usage) const override; + + StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) override; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
diff --git a/hal/host/host_local_command_processor.cc b/hal/host/host_local_command_processor.cc new file mode 100644 index 0000000..72a3c88 --- /dev/null +++ b/hal/host/host_local_command_processor.cc
@@ -0,0 +1,120 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_local_command_processor.h" + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" + +namespace iree { +namespace hal { + +HostLocalCommandProcessor::HostLocalCommandProcessor( + Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) + : CommandBuffer(allocator, mode, command_categories) {} + +HostLocalCommandProcessor::~HostLocalCommandProcessor() = default; + +Status HostLocalCommandProcessor::Begin() { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::Begin"); + is_recording_ = true; + return OkStatus(); +} + +Status HostLocalCommandProcessor::End() { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::End"); + is_recording_ = false; + return OkStatus(); +} + +Status HostLocalCommandProcessor::ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ExecutionBarrier"); + // No-op. + return OkStatus(); +} + +Status HostLocalCommandProcessor::SignalEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::SignalEvent"); + // No-op. + return OkStatus(); +} + +Status HostLocalCommandProcessor::ResetEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ResetEvent"); + // No-op. + return OkStatus(); +} + +Status HostLocalCommandProcessor::WaitEvents( + absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::WaitEvents"); + // No-op. + return OkStatus(); +} + +Status HostLocalCommandProcessor::FillBuffer(Buffer* target_buffer, + device_size_t target_offset, + device_size_t length, + const void* pattern, + size_t pattern_length) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::FillBuffer"); + return target_buffer->Fill(target_offset, length, pattern, pattern_length); +} + +Status HostLocalCommandProcessor::DiscardBuffer(Buffer* buffer) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::DiscardBuffer"); + // No-op as we don't support lazily allocated buffers. + return OkStatus(); +} + +Status HostLocalCommandProcessor::UpdateBuffer(const void* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::UpdateBuffer"); + return target_buffer->WriteData( + target_offset, static_cast<const uint8_t*>(source_buffer) + source_offset, + length); +} + +Status HostLocalCommandProcessor::CopyBuffer(Buffer* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + IREE_TRACE_SCOPE0("HostLocalCommandProcessor::CopyBuffer"); + return target_buffer->CopyData(target_offset, source_buffer, source_offset, + length); +} + +Status HostLocalCommandProcessor::Dispatch( + const DispatchRequest& dispatch_request) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Command processor does not support dispatch operations"; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_local_command_processor.h b/hal/host/host_local_command_processor.h new file mode 100644 index 0000000..e18d2a8 --- /dev/null +++ b/hal/host/host_local_command_processor.h
@@ -0,0 +1,85 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_ +#define IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_ + +#include "hal/command_buffer.h" + +namespace iree { +namespace hal { + +// Host-local command processor for dispatching transfer operations against +// buffers allocated from the HostLocalAllocator. +// This assumes that all buffers are host-visible (if not local) and that all +// buffers can be mapped for access. +// +// Subclasses may implement Dispatch, otherwise the default implementation just +// returns failure. +// +// Thread-compatible (as with CommandBuffer itself). +class HostLocalCommandProcessor : public CommandBuffer { + public: + HostLocalCommandProcessor(Allocator* allocator, + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories); + ~HostLocalCommandProcessor() override; + + bool is_recording() const override { return is_recording_; } + + Status Begin() override; + Status End() override; + + Status ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + + Status SignalEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + + Status ResetEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + + Status WaitEvents(absl::Span<Event*> events, + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + + Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, + device_size_t length, const void* pattern, + size_t pattern_length) override; + + Status DiscardBuffer(Buffer* buffer) override; + + Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + + Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + + Status Dispatch(const DispatchRequest& dispatch_request) override; + + private: + bool is_recording_ = false; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
diff --git a/hal/host/host_submission_queue.cc b/hal/host/host_submission_queue.cc new file mode 100644 index 0000000..aa672b0 --- /dev/null +++ b/hal/host/host_submission_queue.cc
@@ -0,0 +1,295 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_submission_queue.h" + +#include <atomic> +#include <cstdint> + +#include "absl/synchronization/mutex.h" +#include "base/status.h" +#include "base/tracing.h" + +namespace iree { +namespace hal { + +HostBinarySemaphore::HostBinarySemaphore(bool initial_value) { + State state = {0}; + state.signaled = initial_value ? 1 : 0; + state_ = state; +} + +bool HostBinarySemaphore::is_signaled() const { + return state_.load(std::memory_order_acquire).signaled == 1; +} + +Status HostBinarySemaphore::BeginSignaling() { + State old_state = state_.load(std::memory_order_acquire); + if (old_state.signal_pending != 0) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "A signal operation on a binary semaphore is already pending"; + } + State new_state = old_state; + new_state.signal_pending = 1; + state_.compare_exchange_strong(old_state, new_state); + return OkStatus(); +} + +Status HostBinarySemaphore::EndSignaling() { + State old_state = state_.load(std::memory_order_acquire); + DCHECK_EQ(old_state.signal_pending, 1) + << "A signal operation on a binary semaphore was not pending"; + if (old_state.signaled != 0) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "A binary semaphore cannot be signaled multiple times"; + } + State new_state = old_state; + new_state.signal_pending = 0; + new_state.signaled = 1; + state_.compare_exchange_strong(old_state, new_state); + return OkStatus(); +} + +Status HostBinarySemaphore::BeginWaiting() { + State old_state = state_.load(std::memory_order_acquire); + if (old_state.wait_pending != 0) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "A wait operation on a binary semaphore is already pending"; + } + State new_state = old_state; + new_state.wait_pending = 1; + state_.compare_exchange_strong(old_state, new_state); + return OkStatus(); +} + +Status HostBinarySemaphore::EndWaiting() { + State old_state = state_.load(std::memory_order_acquire); + DCHECK_EQ(old_state.wait_pending, 1) + << "A wait operation on a binary semaphore was not pending"; + if (old_state.signaled != 1) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "A binary semaphore cannot be reset multiple times"; + } + State new_state = old_state; + new_state.wait_pending = 0; + new_state.signaled = 0; + state_.compare_exchange_strong(old_state, new_state); + return OkStatus(); +} + +HostSubmissionQueue::HostSubmissionQueue() = default; + +HostSubmissionQueue::~HostSubmissionQueue() = default; + +bool HostSubmissionQueue::IsBatchReady(const PendingBatch& batch) const { + for (auto& wait_point : batch.wait_semaphores) { + if (wait_point.index() == 0) { + auto* binary_semaphore = + reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(wait_point)); + if (!binary_semaphore->is_signaled()) { + return false; + } + } else { + // TODO(b/140141417): implement timeline semaphores. + return false; + } + } + return true; +} + +Status HostSubmissionQueue::Enqueue(absl::Span<const SubmissionBatch> batches, + FenceValue fence) { + IREE_TRACE_SCOPE0("HostSubmissionQueue::Enqueue"); + + if (has_shutdown_) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Cannot enqueue new submissions; queue is exiting"; + } else if (!permanent_error_.ok()) { + return permanent_error_; + } + + // Verify waiting/signaling behavior on semaphores and prepare them all. + // We need to track this to ensure that we are modeling the Vulkan behavior + // and are consistent across HAL implementations. + for (auto& batch : batches) { + for (auto& semaphore_value : batch.wait_semaphores) { + if (semaphore_value.index() == 0) { + auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>( + absl::get<0>(semaphore_value)); + RETURN_IF_ERROR(binary_semaphore->BeginWaiting()); + } else { + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; + } + } + for (auto& semaphore_value : batch.signal_semaphores) { + if (semaphore_value.index() == 0) { + auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>( + absl::get<0>(semaphore_value)); + RETURN_IF_ERROR(binary_semaphore->BeginSignaling()); + } else { + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; + } + } + } + + // Add to list - order does not matter as Process evaluates semaphores. + auto submission = absl::make_unique<Submission>(); + submission->fence = std::move(fence); + submission->pending_batches.resize(batches.size()); + for (int i = 0; i < batches.size(); ++i) { + submission->pending_batches[i] = PendingBatch{ + {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()}, + {batches[i].command_buffers.begin(), batches[i].command_buffers.end()}, + {batches[i].signal_semaphores.begin(), + batches[i].signal_semaphores.end()}, + }; + } + list_.push_back(std::move(submission)); + + return OkStatus(); +} + +Status HostSubmissionQueue::ProcessBatches(ExecuteFn execute_fn) { + IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatches"); + + if (!permanent_error_.ok()) { + // Sticky failure state. + return permanent_error_; + } + + // Repeated try to run things until we quiesce or are blocked. + while (permanent_error_.ok() && !list_.empty()) { + // NOTE: to support re-entrancy where |execute_fn| may modify the submission + // list we need to always start from the beginning. If we wanted we could + // track a list of ready submissions however that's a lot of bookkeeping and + // the list is usually short. + bool restart_iteration = false; + for (auto* submission : list_) { + for (int i = 0; i < submission->pending_batches.size(); ++i) { + auto& batch = submission->pending_batches[i]; + if (!IsBatchReady(batch)) { + // Try the next batch in the submission until we find one that is + // ready. If none are ready we'll return to the caller. + continue; + } + + // Batch can run! Process now and remove it from the list so we don't + // try to run it again. + auto batch_status = ProcessBatch(batch, execute_fn); + submission->pending_batches.erase(submission->pending_batches.begin() + + i); + if (batch_status.ok()) { + // Batch succeeded. Since we want to preserve submission order we'll + // break out of the loop and try from the first submission again. + if (submission->pending_batches.empty()) { + // All work for this submission completed successfully. Signal the + // fence and remove the submission from the list. + RETURN_IF_ERROR(CompleteSubmission(submission, OkStatus())); + list_.take(submission).reset(); + } + } else { + // Batch failed; set the permanent error flag and abort so we don't + // try to process anything else. + permanent_error_ = batch_status; + RETURN_IF_ERROR(CompleteSubmission(submission, batch_status)); + list_.take(submission).reset(); + } + restart_iteration = true; + break; + } + if (restart_iteration) break; + } + } + + if (!permanent_error_.ok()) { + // If the sticky error got set while processing we need to abort all + // remaining submissions (simulating a device loss). + FailAllPending(permanent_error_); + return permanent_error_; + } + + return OkStatus(); +} + +Status HostSubmissionQueue::ProcessBatch(const PendingBatch& batch, + const ExecuteFn& execute_fn) { + IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatch"); + + // Complete the waits on all semaphores and reset them. + for (auto& semaphore_value : batch.wait_semaphores) { + if (semaphore_value.index() == 0) { + auto* binary_semaphore = + reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value)); + RETURN_IF_ERROR(binary_semaphore->EndWaiting()); + } else { + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; + } + } + + // Let the caller handle execution of the command buffers. + RETURN_IF_ERROR(execute_fn(batch.command_buffers)); + + // Signal all semaphores to allow them to unblock waiters. + for (auto& semaphore_value : batch.signal_semaphores) { + if (semaphore_value.index() == 0) { + auto* binary_semaphore = + reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value)); + RETURN_IF_ERROR(binary_semaphore->EndSignaling()); + } else { + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; + } + } + + return OkStatus(); +} + +Status HostSubmissionQueue::CompleteSubmission(Submission* submission, + Status status) { + IREE_TRACE_SCOPE0("HostSubmissionQueue::CompleteSubmission"); + + // It's safe to drop any remaining batches - their semaphores will never be + // signaled but that's fine as we should be the only thing relying on them. + submission->pending_batches.clear(); + + // Signal the fence. + auto* fence = static_cast<HostFence*>(submission->fence.first); + if (status.ok()) { + RETURN_IF_ERROR(fence->Signal(submission->fence.second)); + } else { + RETURN_IF_ERROR(fence->Fail(std::move(status))); + } + + return OkStatus(); +} + +void HostSubmissionQueue::FailAllPending(Status status) { + IREE_TRACE_SCOPE0("HostSubmissionQueue::FailAllPending"); + while (!list_.empty()) { + auto submission = list_.take(list_.front()); + CompleteSubmission(submission.get(), status).IgnoreError(); + submission.reset(); + } +} + +void HostSubmissionQueue::SignalShutdown() { + IREE_TRACE_SCOPE0("HostSubmissionQueue::SignalShutdown"); + has_shutdown_ = true; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/host_submission_queue.h b/hal/host/host_submission_queue.h new file mode 100644 index 0000000..1ec141f --- /dev/null +++ b/hal/host/host_submission_queue.h
@@ -0,0 +1,163 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_ +#define IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_ + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "base/intrusive_list.h" +#include "base/status.h" +#include "hal/command_queue.h" +#include "hal/host/host_fence.h" +#include "hal/semaphore.h" + +namespace iree { +namespace hal { + +class HostSubmissionQueue; + +// Simple host-only binary semaphore implemented with a mutex. +// To match the expected HAL behavior (mostly dictated by Vulkan) we can only +// have a single waiter and waits can only occur once a signal has been +// enqueued. +// +// Thread-safe (as instances may be imported and used by others). +class HostBinarySemaphore final : public BinarySemaphore { + public: + explicit HostBinarySemaphore(bool initial_value); + + // Returns true if the semaphore has been signaled. + bool is_signaled() const; + + private: + friend class HostSubmissionQueue; + + // Begins a signal operation and ensures no other signal operation is pending. + Status BeginSignaling(); + // Ends a signal operation by setting the semaphore to the signaled state. + Status EndSignaling(); + + // Begins a wait operation and ensures no other wait operation is pending. + Status BeginWaiting(); + // Ends a wait operation by resetting the semaphore to the unsignaled state. + Status EndWaiting(); + + // A single 32-bit int for lock-free semaphore behavior. We need to do this + // extra tracking so that we get consistent behavior across HAL + // implementations that have strict semaphore semantics. + struct State { + uint32_t signal_pending : 1; + uint32_t wait_pending : 1; + uint32_t signaled : 1; + }; + std::atomic<State> state_{{0, 0, 0}}; +}; + +// Simple host-only timeline semaphore implemented with a mutex. +// +// Thread-safe (as instances may be imported and used by others). +class HostTimelineSemaphore final : public TimelineSemaphore { + public: + // TODO(b/140141417): implement timeline semaphores. +}; + +// A queue managing CommandQueue submissions that uses host-local +// synchronization primitives. Evaluates submission order by respecting the +// wait and signal semaphores defined per batch and notifies fences upon +// submission completion. +// +// Note that it's possible for HAL users to deadlock themselves; we don't try to +// avoid that as in device backends it may not be possible and we want to have +// some kind of warning in the host implementation that TSAN can catch. +// +// Thread-compatible. Const methods may be called from any thread. +class HostSubmissionQueue { + public: + using ExecuteFn = + std::function<Status(absl::Span<CommandBuffer* const> command_buffers)>; + + HostSubmissionQueue(); + ~HostSubmissionQueue(); + + // Returns true if the queue is currently empty. + bool empty() const { return list_.empty(); } + // Returns true if SignalShutdown has been called. + bool has_shutdown() const { return has_shutdown_; } + // The sticky error status, if an error has occurred. + Status permanent_error() const { return permanent_error_; } + + // Enqueues a new submission. + // No work will be performed until Process is called. + Status Enqueue(absl::Span<const SubmissionBatch> batches, FenceValue fence); + + // Processes all ready batches using the provided |execute_fn|. + // The function may be called several times if new batches become ready due to + // prior batches in the sequence completing during processing. + // + // Returns any errors returned by |execute_fn| (which will be the same as + // permanent_error()). When an error occurs all in-flight submissions are + // aborted, the permanent_error() is set, and the queue is shutdown. + Status ProcessBatches(ExecuteFn execute_fn); + + // Marks the queue as having shutdown. All pending submissions will be allowed + // to complete but future enqueues will fail. + void SignalShutdown(); + + private: + // A submitted command buffer batch and its synchronization information. + struct PendingBatch { + absl::InlinedVector<SemaphoreValue, 4> wait_semaphores; + absl::InlinedVector<CommandBuffer*, 4> command_buffers; + absl::InlinedVector<SemaphoreValue, 4> signal_semaphores; + }; + struct Submission : public IntrusiveLinkBase<void> { + absl::InlinedVector<PendingBatch, 4> pending_batches; + FenceValue fence; + }; + + // Returns true if all wait semaphores in the |batch| are signaled. + bool IsBatchReady(const PendingBatch& batch) const; + + // Processes a batch by resetting semaphores, dispatching the command buffers + // to the specified |execute_fn|, and signaling semaphores. + // + // Preconditions: IsBatchReady(batch) == true + Status ProcessBatch(const PendingBatch& batch, const ExecuteFn& execute_fn); + + // Completes a submission by signaling the fence with the given |status|. + Status CompleteSubmission(Submission* submission, Status status); + + // Fails all pending submissions with the given status. + // Errors that occur during this process are silently ignored. + void FailAllPending(Status status); + + // True to exit the thread after all submissions complete. + bool has_shutdown_ = false; + + // A sticky error that is set on the first failed submit. All future + // submissions will be skipped except for fences, which will receive this + // error. + Status permanent_error_; + + // Pending submissions in submission order. + // Note that we may evaluate batches within the list out of order. + IntrusiveList<std::unique_ptr<Submission>> list_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
diff --git a/hal/host/host_submission_queue_test.cc b/hal/host/host_submission_queue_test.cc new file mode 100644 index 0000000..ed71de2 --- /dev/null +++ b/hal/host/host_submission_queue_test.cc
@@ -0,0 +1,30 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/host_submission_queue.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace hal { +namespace { + +TEST(HostSubmissionQueueTest, TBD) { + // TODO(benvanik): test! +} + +} // namespace +} // namespace hal +} // namespace iree
diff --git a/hal/host/inproc_command_buffer.cc b/hal/host/inproc_command_buffer.cc new file mode 100644 index 0000000..4115825 --- /dev/null +++ b/hal/host/inproc_command_buffer.cc
@@ -0,0 +1,264 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/host/inproc_command_buffer.h" + +#include "base/tracing.h" + +namespace iree { +namespace hal { + +InProcCommandBuffer::InProcCommandBuffer( + Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) + : CommandBuffer(allocator, mode, command_categories) {} + +InProcCommandBuffer::~InProcCommandBuffer() { Reset(); } + +Status InProcCommandBuffer::Begin() { + IREE_TRACE_SCOPE0("InProcCommandBuffer::Begin"); + is_recording_ = true; + Reset(); + return OkStatus(); +} + +Status InProcCommandBuffer::End() { + IREE_TRACE_SCOPE0("InProcCommandBuffer::End"); + is_recording_ = false; + return OkStatus(); +} + +Status InProcCommandBuffer::ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::ExecutionBarrier"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<ExecutionBarrierCmd>()); + cmd->source_stage_mask = source_stage_mask; + cmd->target_stage_mask = target_stage_mask; + cmd->memory_barriers = AppendStructSpan(memory_barriers); + cmd->buffer_barriers = AppendStructSpan(buffer_barriers); + return OkStatus(); +} + +Status InProcCommandBuffer::SignalEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::SignalEvent"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<SignalEventCmd>()); + cmd->event = event; + cmd->source_stage_mask = source_stage_mask; + return OkStatus(); +} + +Status InProcCommandBuffer::ResetEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::ResetEvent"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<ResetEventCmd>()); + cmd->event = event; + cmd->source_stage_mask = source_stage_mask; + return OkStatus(); +} + +Status InProcCommandBuffer::WaitEvents( + absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::WaitEvents"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<WaitEventsCmd>()); + cmd->events = AppendStructSpan(events); + cmd->source_stage_mask = source_stage_mask; + cmd->target_stage_mask = target_stage_mask; + cmd->memory_barriers = AppendStructSpan(memory_barriers); + cmd->buffer_barriers = AppendStructSpan(buffer_barriers); + return OkStatus(); +} + +Status InProcCommandBuffer::FillBuffer(Buffer* target_buffer, + device_size_t target_offset, + device_size_t length, + const void* pattern, + size_t pattern_length) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::FillBuffer"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<FillBufferCmd>()); + cmd->target_buffer = target_buffer; + cmd->target_offset = target_offset; + cmd->length = length; + std::memcpy(cmd->pattern, pattern, pattern_length); + cmd->pattern_length = pattern_length; + return OkStatus(); +} + +Status InProcCommandBuffer::DiscardBuffer(Buffer* buffer) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::DiscardBuffer"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DiscardBufferCmd>()); + cmd->buffer = buffer; + return OkStatus(); +} + +Status InProcCommandBuffer::UpdateBuffer(const void* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::UpdateBuffer"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<UpdateBufferCmd>()); + cmd->source_buffer = AppendCmdData(source_buffer, source_offset, length); + cmd->target_buffer = target_buffer; + cmd->target_offset = target_offset; + cmd->length = length; + return OkStatus(); +} + +Status InProcCommandBuffer::CopyBuffer(Buffer* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::CopyBuffer"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<CopyBufferCmd>()); + cmd->source_buffer = source_buffer; + cmd->source_offset = source_offset; + cmd->target_buffer = target_buffer; + cmd->target_offset = target_offset; + cmd->length = length; + return OkStatus(); +} + +Status InProcCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) { + IREE_TRACE_SCOPE0("InProcCommandBuffer::Dispatch"); + ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DispatchCmd>()); + cmd->request.executable = dispatch_request.executable; + cmd->request.entry_point = dispatch_request.entry_point; + cmd->request.workload = dispatch_request.workload; + cmd->request.workload_buffer = dispatch_request.workload_buffer; + cmd->request.bindings = AppendStructSpan(dispatch_request.bindings); + return OkStatus(); +} + +void InProcCommandBuffer::Reset() { + auto* cmd_list = ¤t_cmd_list_; + cmd_list->head = cmd_list->tail = nullptr; + cmd_list->arena.Reset(); +} + +InProcCommandBuffer::CmdHeader* InProcCommandBuffer::AppendCmdHeader( + CmdType type, size_t cmd_size) { + auto* cmd_list = ¤t_cmd_list_; + auto* cmd_header = reinterpret_cast<CmdHeader*>( + cmd_list->arena.AllocateBytes(sizeof(CmdHeader) + cmd_size)); + cmd_header->next = nullptr; + cmd_header->type = type; + if (!cmd_list->head) { + cmd_list->head = cmd_header; + } else if (cmd_list->tail) { + cmd_list->tail->next = cmd_header; + } + cmd_list->tail = cmd_header; + return cmd_header; +} + +void* InProcCommandBuffer::AppendCmdData(const void* source_buffer, + device_size_t source_offset, + device_size_t source_length) { + auto* cmd_list = ¤t_cmd_list_; + + uint8_t* allocated_bytes = cmd_list->arena.AllocateBytes(source_length); + std::memcpy(allocated_bytes, + static_cast<const uint8_t*>(source_buffer) + source_offset, + source_length); + return allocated_bytes; +} + +Status InProcCommandBuffer::Process(CommandBuffer* command_processor) const { + IREE_TRACE_SCOPE0("InProcCommandBuffer::Process"); + + RETURN_IF_ERROR(command_processor->Begin()); + + // Process each command in the order they were recorded. + auto* cmd_list = ¤t_cmd_list_; + for (CmdHeader* cmd_header = cmd_list->head; cmd_header != nullptr; + cmd_header = cmd_header->next) { + auto command_status = ProcessCmd(cmd_header, command_processor); + if (!command_status.ok()) { + LOG(ERROR) << "DeviceQueue failure while executing command; permanently " + "failing all future commands: " + << command_status; + } + } + + RETURN_IF_ERROR(command_processor->End()); + + return OkStatus(); +} + +Status InProcCommandBuffer::ProcessCmd(CmdHeader* cmd_header, + CommandBuffer* command_processor) const { + switch (cmd_header->type) { + case CmdType::kExecutionBarrier: { + auto* cmd = reinterpret_cast<ExecutionBarrierCmd*>(cmd_header + 1); + return command_processor->ExecutionBarrier( + cmd->source_stage_mask, cmd->target_stage_mask, cmd->memory_barriers, + cmd->buffer_barriers); + } + case CmdType::kSignalEvent: { + auto* cmd = reinterpret_cast<SignalEventCmd*>(cmd_header + 1); + return command_processor->SignalEvent(cmd->event, cmd->source_stage_mask); + } + case CmdType::kResetEvent: { + auto* cmd = reinterpret_cast<ResetEventCmd*>(cmd_header + 1); + return command_processor->ResetEvent(cmd->event, cmd->source_stage_mask); + } + case CmdType::kWaitEvents: { + auto* cmd = reinterpret_cast<WaitEventsCmd*>(cmd_header + 1); + return command_processor->WaitEvents( + cmd->events, cmd->source_stage_mask, cmd->target_stage_mask, + cmd->memory_barriers, cmd->buffer_barriers); + } + case CmdType::kFillBuffer: { + auto* cmd = reinterpret_cast<FillBufferCmd*>(cmd_header + 1); + return command_processor->FillBuffer(cmd->target_buffer, + cmd->target_offset, cmd->length, + cmd->pattern, cmd->pattern_length); + } + case CmdType::kDiscardBuffer: { + auto* cmd = reinterpret_cast<DiscardBufferCmd*>(cmd_header + 1); + return command_processor->DiscardBuffer(cmd->buffer); + } + case CmdType::kUpdateBuffer: { + auto* cmd = reinterpret_cast<UpdateBufferCmd*>(cmd_header + 1); + return command_processor->UpdateBuffer(cmd->source_buffer, 0, + cmd->target_buffer, + cmd->target_offset, cmd->length); + } + case CmdType::kCopyBuffer: { + auto* cmd = reinterpret_cast<CopyBufferCmd*>(cmd_header + 1); + return command_processor->CopyBuffer( + cmd->source_buffer, cmd->source_offset, cmd->target_buffer, + cmd->target_offset, cmd->length); + } + case CmdType::kDispatch: { + auto* cmd = reinterpret_cast<DispatchCmd*>(cmd_header + 1); + return command_processor->Dispatch(cmd->request); + } + default: + return DataLossErrorBuilder(IREE_LOC) + << "Unrecognized command type " + << static_cast<int>(cmd_header->type) << "; corrupt buffer?"; + } +} + +} // namespace hal +} // namespace iree
diff --git a/hal/host/inproc_command_buffer.h b/hal/host/inproc_command_buffer.h new file mode 100644 index 0000000..a3bca41 --- /dev/null +++ b/hal/host/inproc_command_buffer.h
@@ -0,0 +1,241 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_ +#define IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_ + +#include "base/arena.h" +#include "base/intrusive_list.h" +#include "base/status.h" +#include "hal/command_buffer.h" + +namespace iree { +namespace hal { + +// In-process command buffer with support for recording and playback. +// Commands are recorded into heap-allocated arenas with pointers to used +// resources (Buffer*, etc). To replay a command buffer against a real +// implementation use Process to call each command method as it was originally +// recorded. +// +// Thread-compatible (as with CommandBuffer itself). +class InProcCommandBuffer final : public CommandBuffer { + public: + InProcCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories); + ~InProcCommandBuffer() override; + + bool is_recording() const override { return is_recording_; } + + Status Begin() override; + Status End() override; + + Status ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + + Status SignalEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + + Status ResetEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + + Status WaitEvents(absl::Span<Event*> events, + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + + Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, + device_size_t length, const void* pattern, + size_t pattern_length) override; + + Status DiscardBuffer(Buffer* buffer) override; + + Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + + Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + + Status Dispatch(const DispatchRequest& dispatch_request) override; + + // Processes all commands in the buffer using the given |command_processor|. + // The commands are issued in the order they were recorded. + Status Process(CommandBuffer* command_processor) const; + + private: + // Type of Cmd, used by CmdHeader to identify the command payload. + enum class CmdType { + kExecutionBarrier, + kSignalEvent, + kResetEvent, + kWaitEvents, + kFillBuffer, + kDiscardBuffer, + kUpdateBuffer, + kCopyBuffer, + kDispatch, + }; + + // Prefix for commands encoded into the CmdList. + // This is used to identify the type of a command as well as connect commands + // in the list sequence. Command data immediately follows the header in + // memory. + struct CmdHeader { + // Optional next command in the list. + CmdHeader* next; + // Type of the command. + CmdType type; + }; + + // A lightweight linked list of commands and an arena that stores them. + // CmdLists are designed to be reused so that the arena allocations are + // amortized across multiple uses. + // + // Note that this and the CmdHeader/Cmd types include raw pointers and as + // such are *not* portable across processes. It'd be possible, though, to + // extend this for cross-process use if a shared-memory Buffer was also + // implemented. For YAGNI we avoid that here. + struct CmdList : public IntrusiveLinkBase<void> { + static constexpr size_t kArenaBlockSize = 64 * 1024; + + Arena arena{kArenaBlockSize}; + CmdHeader* head = nullptr; + CmdHeader* tail = nullptr; + }; + + // Defines an execution barrier. + struct ExecutionBarrierCmd { + static constexpr CmdType kType = CmdType::kExecutionBarrier; + ExecutionStageBitfield source_stage_mask; + ExecutionStageBitfield target_stage_mask; + absl::Span<const MemoryBarrier> memory_barriers; + absl::Span<const BufferBarrier> buffer_barriers; + }; + + // Signals an event. + struct SignalEventCmd { + static constexpr CmdType kType = CmdType::kSignalEvent; + Event* event; + ExecutionStageBitfield source_stage_mask; + }; + + // Resets an event. + struct ResetEventCmd { + static constexpr CmdType kType = CmdType::kResetEvent; + Event* event; + ExecutionStageBitfield source_stage_mask; + }; + + // Waits for one or more events. + struct WaitEventsCmd { + static constexpr CmdType kType = CmdType::kWaitEvents; + absl::Span<Event*> events; + ExecutionStageBitfield source_stage_mask; + ExecutionStageBitfield target_stage_mask; + absl::Span<const MemoryBarrier> memory_barriers; + absl::Span<const BufferBarrier> buffer_barriers; + }; + + // Fills the target buffer with the given repeating value. + struct FillBufferCmd { + static constexpr CmdType kType = CmdType::kFillBuffer; + Buffer* target_buffer; + device_size_t target_offset; + device_size_t length; + uint8_t pattern[4]; + size_t pattern_length; + }; + + // Hints to the device queue that the given buffer will not be used again. + struct DiscardBufferCmd { + static constexpr CmdType kType = CmdType::kDiscardBuffer; + Buffer* buffer; + }; + + // Writes a range of the given target buffer from the embedded memory. + // The source buffer contents immediately follow the command in the arena. + struct UpdateBufferCmd { + static constexpr CmdType kType = CmdType::kUpdateBuffer; + const void* source_buffer; + Buffer* target_buffer; + device_size_t target_offset; + device_size_t length; + }; + + // Copies a range of one buffer to another. + struct CopyBufferCmd { + static constexpr CmdType kType = CmdType::kCopyBuffer; + Buffer* source_buffer; + device_size_t source_offset; + Buffer* target_buffer; + device_size_t target_offset; + device_size_t length; + }; + + // Dispatches an execution request. + struct DispatchCmd { + static constexpr CmdType kType = CmdType::kDispatch; + DispatchRequest request; + }; + + // Resets the command list. + void Reset(); + + // Allocates a command and appends it to the current command list. + // The caller must populate the fields in the returned pointer. + template <typename T> + StatusOr<T*> AppendCmd() { + return reinterpret_cast<T*>(AppendCmdHeader(T::kType, sizeof(T)) + 1); + } + + // Appends a command with the given |type| and payload |cmd_size| prefixed + // with a CmdHeader. Returns a pointer to the CmdHeader that is followed + // immediately by |cmd_size| zero bytes. + CmdHeader* AppendCmdHeader(CmdType type, size_t cmd_size); + + // Appends a byte buffer to the command buffer and returns a pointer to the + // copied data within the command buffer arena. + void* AppendCmdData(const void* source_buffer, device_size_t source_offset, + device_size_t source_length); + + // Appends a span of POD structs to the current CmdList and returns a span + // pointing into the CmdList arena. + template <typename T> + absl::Span<T> AppendStructSpan(absl::Span<T> value) { + static_assert(std::is_standard_layout<T>::value, + "Struct must be a POD type"); + void* data_ptr = AppendCmdData(value.data(), 0, value.size() * sizeof(T)); + return absl::MakeSpan(static_cast<T*>(data_ptr), value.size()); + } + + // Processes a single command. + Status ProcessCmd(CmdHeader* cmd_header, + CommandBuffer* command_processor) const; + + bool is_recording_ = false; + + // NOTE: not synchronized. Expected to be used from a single thread. + CmdList current_cmd_list_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
diff --git a/hal/interpreter/BUILD b/hal/interpreter/BUILD new file mode 100644 index 0000000..5984600 --- /dev/null +++ b/hal/interpreter/BUILD
@@ -0,0 +1,190 @@ +# HAL implementation running on the CPU using the IREE bytecode. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "bytecode_cache", + srcs = ["bytecode_cache.cc"], + hdrs = ["bytecode_cache.h"], + deps = [ + ":bytecode_executable", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:allocator", + "///hal:executable", + "///hal:executable_cache", + "///hal:executable_format", + "///rt", + ], +) + +cc_library( + name = "bytecode_dispatch", + srcs = [ + "bytecode_dispatch.cc", + "bytecode_dispatch_conversion.h", + "bytecode_dispatch_util.cc", + "bytecode_dispatch_util.h", + ], + hdrs = ["bytecode_dispatch.h"], + deps = [ + ":bytecode_kernels", + "///base:logging", + "///base:memory", + "///base:status", + "///hal:allocator", + "///hal:buffer_view", + "///hal:heap_buffer", + "///rt", + "///schemas/bytecode:interpreter_bytecode_v0", + "///vm:bytecode_module", + "///vm:bytecode_reader", + "///vm:bytecode_tables_interpreter", + "///vm:bytecode_util", + "///vm:opcode_info", + "///vm:type", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "bytecode_executable", + srcs = ["bytecode_executable.cc"], + hdrs = ["bytecode_executable.h"], + deps = [ + ":interpreter_module", + "///base:status", + "///hal:allocator", + "///hal:executable", + "///hal:executable_spec", + "///rt", + "///vm:bytecode_tables_interpreter", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "bytecode_kernels", + hdrs = ["bytecode_kernels.h"], + textual_hdrs = [ + # TODO(benvanik): SIMD variants. + "bytecode_kernels_generic.h", + "bytecode_kernels_ruy.h", + ], + deps = [ + "///base:shape", + "///base:status", + "///hal:buffer_view", + "@com_google_absl//absl/algorithm", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + "@org_tensorflow//tensorflow/lite/experimental/ruy", + "@org_tensorflow//tensorflow/lite/experimental/ruy:context", + ], +) + +cc_test( + name = "bytecode_kernels_test", + srcs = ["bytecode_kernels_test.cc"], + deps = [ + ":bytecode_kernels", + "///base:memory", + "///base:status_matchers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "interpreter_command_processor", + srcs = ["interpreter_command_processor.cc"], + hdrs = ["interpreter_command_processor.h"], + deps = [ + ":bytecode_executable", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:buffer_view", + "///hal/host:host_local_command_processor", + "///rt", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "interpreter_device", + srcs = ["interpreter_device.cc"], + hdrs = ["interpreter_device.h"], + deps = [ + ":bytecode_cache", + ":bytecode_kernels", + ":interpreter_command_processor", + "///base:memory", + "///base:status", + "///base:tracing", + "///hal:command_buffer_validation", + "///hal:command_queue", + "///hal:device", + "///hal:fence", + "///hal/host:async_command_queue", + "///hal/host:host_event", + "///hal/host:host_local_allocator", + "///hal/host:host_submission_queue", + "///hal/host:inproc_command_buffer", + "///rt", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "interpreter_driver", + srcs = ["interpreter_driver.cc"], + hdrs = ["interpreter_driver.h"], + deps = [ + ":interpreter_device", + "///hal:device_info", + "///hal:driver", + ], +) + +cc_library( + name = "interpreter_driver_module", + srcs = ["interpreter_driver_module.cc"], + deps = [ + ":interpreter_driver", + "///base:init", + "///base:status", + "///hal:driver_registry", + ], + alwayslink = 1, +) + +cc_library( + name = "interpreter_module", + srcs = ["interpreter_module.cc"], + hdrs = ["interpreter_module.h"], + deps = [ + ":bytecode_dispatch", + ":bytecode_kernels", + "///base:flatbuffer_util", + "///base:status", + "///base:tracing", + "///hal:allocator", + "///hal:buffer_view", + "///rt", + "///vm:bytecode_module", + "///vm:bytecode_tables_interpreter", + "@com_google_absl//absl/types:span", + ], +)
diff --git a/iree/hal/interpreter/CMakeLists.txt b/hal/interpreter/CMakeLists.txt similarity index 100% rename from iree/hal/interpreter/CMakeLists.txt rename to hal/interpreter/CMakeLists.txt
diff --git a/hal/interpreter/bytecode_cache.cc b/hal/interpreter/bytecode_cache.cc new file mode 100644 index 0000000..5eee0d8 --- /dev/null +++ b/hal/interpreter/bytecode_cache.cc
@@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/bytecode_cache.h" + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/executable_format.h" +#include "hal/interpreter/bytecode_executable.h" + +namespace iree { +namespace hal { + +BytecodeCache::BytecodeCache(ref_ptr<rt::Instance> instance, + hal::Allocator* allocator) + : instance_(std::move(instance)), allocator_(allocator) {} + +BytecodeCache::~BytecodeCache() = default; + +bool BytecodeCache::CanPrepareFormat(ExecutableFormat format) const { + return format == kExecutableFormatIreeBytecode; +} + +StatusOr<ref_ptr<Executable>> BytecodeCache::PrepareExecutable( + ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) { + IREE_TRACE_SCOPE0("BytecodeCache::PrepareExecutable"); + if (!CanPrepareFormat(spec.format)) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Unsupported format: " << spec.format; + } + + // Wrap the data (or copy it). + bool allow_aliasing_data = + AllBitsSet(mode, ExecutableCachingMode::kAliasProvidedData); + ASSIGN_OR_RETURN(auto executable, + BytecodeExecutable::Load(add_ref(instance_), allocator_, + spec, !allow_aliasing_data)); + + return executable; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/bytecode_cache.h b/hal/interpreter/bytecode_cache.h new file mode 100644 index 0000000..7f56e3e --- /dev/null +++ b/hal/interpreter/bytecode_cache.h
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_ + +#include "hal/allocator.h" +#include "hal/executable.h" +#include "hal/executable_cache.h" +#include "rt/instance.h" + +namespace iree { +namespace hal { + +class BytecodeCache final : public ExecutableCache { + public: + BytecodeCache(ref_ptr<rt::Instance> instance, hal::Allocator* allocator); + ~BytecodeCache() override; + + bool CanPrepareFormat(ExecutableFormat format) const override; + + StatusOr<ref_ptr<Executable>> PrepareExecutable( + ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override; + + private: + ref_ptr<rt::Instance> instance_; + hal::Allocator* allocator_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
diff --git a/hal/interpreter/bytecode_dispatch.cc b/hal/interpreter/bytecode_dispatch.cc new file mode 100644 index 0000000..8dd1892 --- /dev/null +++ b/hal/interpreter/bytecode_dispatch.cc
@@ -0,0 +1,850 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Implements a full bytecode dispatch system. +// Currently this is verbose and object oriented, but future revisions +// (once we have interesting benchmarks) will likely simplify and inline +// a lot of the checks to make things faster. Consider this to be as +// experimental an implementation as the entire rest of the project :) + +#include "hal/interpreter/bytecode_dispatch.h" + +#include <algorithm> + +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "base/logging.h" +#include "base/memory.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "hal/heap_buffer.h" +#include "hal/interpreter/bytecode_dispatch_conversion.h" +#include "hal/interpreter/bytecode_dispatch_util.h" +#include "hal/interpreter/bytecode_kernels.h" +#include "rt/function.h" +#include "schemas/bytecode/interpreter_bytecode_v0.h" +#include "vm/bytecode_module.h" +#include "vm/bytecode_reader.h" +#include "vm/bytecode_tables_interpreter.h" +#include "vm/bytecode_util.h" +#include "vm/opcode_info.h" + +namespace iree { +namespace hal { + +namespace { + +using ::iree::rt::Stack; +using ::iree::rt::StackFrame; +using ::iree::vm::BytecodeReader; + +} // namespace + +Status Dispatch(hal::Allocator* allocator, + kernels::RuntimeState* kernel_runtime_state, Stack* stack, + StackFrame* entry_stack_frame, + absl::Span<BufferView> entry_results) { + // Dispatch table mapping 1:1 with bytecode ops. + // Each entry is a label within this function that can be used for computed + // goto. You can find more information on computed goto here: + // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables + // + // Note that we ensure the table is 256 elements long exactly to make sure + // that unused opcodes are handled gracefully. + static const void* kDispatchTable[256] = { +#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name, +#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled, + IREE_INTERPRETER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED) +#undef DECLARE_DISPATCH +#undef DECLARE_DISPATCH_RESERVED + }; + + // Primary dispatch state. This is our 'native stack frame' and really just + // enough to make dereferencing common addresses (like the current offset) + // faster. You can think of this like CPU state (like PC). + // + // We hope that LLVM decides to keep these in registers (as they are touched + // for every instruction executed). The stack_frame will change as we call + // into different functions. + BytecodeReader reader(stack); + RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame)); + +#define DISPATCH_NEXT() \ + { \ + uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \ + DVLOG(1) \ + << "Interpreter dispatching op code: " \ + << GetOpcodeInfo(vm::interpreter_opcode_table(), opcode).mnemonic; \ + goto* kDispatchTable[opcode]; \ + } + +#define DISPATCH_CORE_OPCODE(opcode, body) \ + _dispatch_##opcode : {body} DISPATCH_NEXT() +#if defined(IREE_SUPPORT_F32) || defined(IREE_SUPPORT_F64) +#define DISPATCH_FLOAT_OPCODE(opcode, body) \ + _dispatch_##opcode : {body} DISPATCH_NEXT() +#else +#define DISPATCH_FLOAT_OPCODE(...) +#endif // IREE_SUPPORT_F32 || IREE_SUPPORT_F64 + + DISPATCH_NEXT(); + + DISPATCH_CORE_OPCODE(kConstant, { + ASSIGN_OR_RETURN(auto value, reader.ReadConstant()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + *dst_local = std::move(value); + }); + + DISPATCH_CORE_OPCODE(kCall, { + auto* old_stack_frame = stack->current_frame(); + ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction()); + // TODO(benvanik): rework register storage interface. + ASSIGN_OR_RETURN( + const auto* function_def, + static_cast<const vm::BytecodeModule*>(target_function.module()) + ->GetFunctionDef(target_function.linkage(), + target_function.ordinal())); + ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function)); + new_stack_frame->mutable_registers()->buffer_views.resize( + function_def->bytecode()->local_count()); + RETURN_IF_ERROR( + reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame)); + DVLOG(1) << "Call; stack now: " << stack->DebugString(); + }); + + DISPATCH_CORE_OPCODE(kCallImport, { + return UnimplementedErrorBuilder(IREE_LOC) + << "Non-module imports not supported"; + }); + + DISPATCH_CORE_OPCODE(kCallIndirect, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect"; + }); + + DISPATCH_CORE_OPCODE(kReturn, { + auto* old_stack_frame = stack->current_frame(); + auto* new_stack_frame = stack->caller_frame(); + if (old_stack_frame == entry_stack_frame) { + // Returning from entry function. Marshal results from the return stmt. + ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount()); + for (int i = 0; i < src_count; ++i) { + ASSIGN_OR_RETURN( + auto* src_local, + reader.ReadLocal(old_stack_frame->mutable_registers())); + entry_results[i] = std::move(*src_local); + } + DVLOG(1) << "Returning to entry"; + return OkStatus(); + } else if (!new_stack_frame) { + return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow"; + } + RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame, + new_stack_frame)); + RETURN_IF_ERROR(stack->PopFrame()); + DVLOG(1) << "Return; stack now: " << stack->DebugString(); + }); + + DISPATCH_CORE_OPCODE(kBranch, { + ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset()); + RETURN_IF_ERROR(reader.CopySlots()); + RETURN_IF_ERROR(reader.BranchToOffset(offset)); + }); + + DISPATCH_CORE_OPCODE(kCondBranch, { + // Evaluate condition first so we can do the copies as we read them for + // which side of the branch we take. + ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); + bool cond_value = BufferViewIsTrue(*cond_local); + ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset()); + if (cond_value) { + RETURN_IF_ERROR(reader.CopySlots()); + RETURN_IF_ERROR(reader.BranchToOffset(true_offset)); + } else { + ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount()); + RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count)); + ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset()); + RETURN_IF_ERROR(reader.CopySlots()); + RETURN_IF_ERROR(reader.BranchToOffset(false_offset)); + } + }); + + DISPATCH_CORE_OPCODE(kCmpI, { + ASSIGN_OR_RETURN(uint8_t predicate, reader.ReadUint8_t()); + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + + switch (static_cast<CmpIPredicate>(predicate)) { + case CmpIPredicate::kEq: + RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareEQ>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kNe: + RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareNE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kSlt: + RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLT>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kSle: + RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kSgt: + RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGT>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kSge: + RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kUlt: + RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLT>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kUle: + RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kUgt: + RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGT>( + lhs_local, rhs_local, dst_local)); + break; + case CmpIPredicate::kUge: + RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGE>( + lhs_local, rhs_local, dst_local)); + break; + } + }); + + DISPATCH_FLOAT_OPCODE(kCmpF, { + ASSIGN_OR_RETURN(uint8_t p, reader.ReadUint8_t()); + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + + auto predicate = static_cast<CmpFPredicate>(p); + switch (predicate) { + case CmpFPredicate::kOeq: + RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareEQ>( + lhs_local, rhs_local, dst_local)); + break; + case CmpFPredicate::kUne: + RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareNE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpFPredicate::kOlt: + RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLT>( + lhs_local, rhs_local, dst_local)); + break; + case CmpFPredicate::kOle: + RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpFPredicate::kOgt: + RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGT>( + lhs_local, rhs_local, dst_local)); + break; + case CmpFPredicate::kOge: + RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGE>( + lhs_local, rhs_local, dst_local)); + break; + case CmpFPredicate::kFalse: + case CmpFPredicate::kOne: + case CmpFPredicate::kOrd: + case CmpFPredicate::kUeq: + case CmpFPredicate::kUgt: + case CmpFPredicate::kUge: + case CmpFPredicate::kUlt: + case CmpFPredicate::kUle: + case CmpFPredicate::kUno: + case CmpFPredicate::kTrue: + // TODO(b/132183250) support these if we ever need them. + return UnimplementedErrorBuilder(IREE_LOC) + << "Unsupported comparison predicate value " + << static_cast<int>(p) << " (" + << vm::PredicateToString(predicate) << ")"; + } + }); + + DISPATCH_CORE_OPCODE(kAllocStatic, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static"; + }); + + DISPATCH_CORE_OPCODE(kAllocStack, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack"; + }); + + DISPATCH_CORE_OPCODE(kAllocStackInit, { + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented alloc_stack_init"; + }); + + DISPATCH_CORE_OPCODE(kAllocHeap, { + ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto type, reader.ReadType()); + size_t element_size = type.element_size(); + + // TODO(benvanik): more efficient reading and storage. + size_t element_count = 0; + ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count)); + size_t allocation_size = element_size * element_count; + + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + dst_local->element_size = element_size; + dst_local->shape = shape; + + // TODO(benvanik): properly allocate with attributes from op. + CHECK_EQ(heap_type, 0); + ASSIGN_OR_RETURN( + dst_local->buffer, + allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible, + BufferUsage::kAll, allocation_size)); + }); + + DISPATCH_CORE_OPCODE(kDiscard, { + // NOTE: if we were an encoder we would actually discard the buffer. + ASSIGN_OR_RETURN(auto* local, reader.ReadLocal()); + *local = {}; + }); + + DISPATCH_CORE_OPCODE(kRank, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + int32_t rank = src_local->shape.size(); + RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &rank, sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kDim, { + ASSIGN_OR_RETURN(int32_t axis, reader.ReadUint8_t()); + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(int32_t dim, src_local->shape.ResolveAxis(axis)); + RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &dim, sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kShape, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR(dst_local->buffer->WriteData( + 0, src_local->shape.subspan().data(), + src_local->shape.subspan().size() * sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kLength, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + int32_t length = src_local->shape.element_count(); + RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kDynamicSlice, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths)); + }); + + DISPATCH_CORE_OPCODE(kStaticSlice, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto indices, reader.ReadIndexList()); + ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths)); + }); + + DISPATCH_CORE_OPCODE(kDynamicCopy, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto src_indices, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_indices, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>()); + RETURN_IF_ERROR( + ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths)); + }); + + DISPATCH_CORE_OPCODE(kStaticCopy, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto src_indices, reader.ReadIndexList()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_indices, reader.ReadIndexList()); + ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList()); + RETURN_IF_ERROR( + ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths)); + }); + + DISPATCH_CORE_OPCODE(kClone, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + dst_local->element_size = src_local->element_size; + dst_local->shape = src_local->shape; + dst_local->buffer = HeapBuffer::Allocate(src_local->buffer->usage(), + src_local->buffer->byte_length()); + RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get())); + }); + + DISPATCH_CORE_OPCODE(kSplit, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented split"; + }); + + DISPATCH_CORE_OPCODE(kAssign, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + *dst_local = *src_local; + }); + + DISPATCH_CORE_OPCODE(kCondAssign, { + ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local; + }); + + DISPATCH_CORE_OPCODE(kReshape, { + // TODO(benvanik): more logic required if strides differ. + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + Shape new_shape = Shape{shape_data}; + if (src_local->shape.element_count() != new_shape.element_count()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "New element count " << new_shape.element_count() + << " != source element count " << src_local->shape.element_count(); + } + dst_local->shape = new_shape; + dst_local->buffer = add_ref(src_local->buffer); + dst_local->element_size = src_local->element_size; + }); + + DISPATCH_CORE_OPCODE(kSelect, { + ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto cond_buffer, cond_local->buffer->MapMemory<uint8_t>( + MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto lhs_buffer, lhs_local->buffer->MapMemory<uint8_t>( + MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto rhs_buffer, rhs_local->buffer->MapMemory<uint8_t>( + MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>( + MemoryAccess::kDiscardWrite)); + if (cond_local->element_size != 1) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "Select cond must be i8"; + } else if (lhs_buffer.size() != rhs_buffer.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "LHS " << lhs_buffer.size() << "b != RHS " << rhs_buffer.size() + << "b; both arguments must match"; + } else if (lhs_buffer.size() != dst_buffer.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Dest " << dst_buffer.size() << "b != LHS/RHS " + << lhs_buffer.size() << "b; dest must match inputs"; + } + switch (lhs_local->element_size) { + case 1: + RETURN_IF_ERROR(kernels::Select::Execute<uint8_t>( + cond_buffer.contents(), lhs_buffer.contents(), + rhs_buffer.contents(), dst_buffer.mutable_contents())); + break; + case 2: + RETURN_IF_ERROR(kernels::Select::Execute<uint16_t>( + cond_buffer.contents(), + ReinterpretSpan<uint16_t>(lhs_buffer.contents()), + ReinterpretSpan<uint16_t>(rhs_buffer.contents()), + ReinterpretSpan<uint16_t>(dst_buffer.mutable_contents()))); + break; + case 4: + RETURN_IF_ERROR(kernels::Select::Execute<uint32_t>( + cond_buffer.contents(), + ReinterpretSpan<uint32_t>(lhs_buffer.contents()), + ReinterpretSpan<uint32_t>(rhs_buffer.contents()), + ReinterpretSpan<uint32_t>(dst_buffer.mutable_contents()))); + break; + case 8: + RETURN_IF_ERROR(kernels::Select::Execute<uint64_t>( + cond_buffer.contents(), + ReinterpretSpan<uint64_t>(lhs_buffer.contents()), + ReinterpretSpan<uint64_t>(rhs_buffer.contents()), + ReinterpretSpan<uint64_t>(dst_buffer.mutable_contents()))); + break; + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } + }); + + DISPATCH_CORE_OPCODE(kTranspose, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Transpose>( + src_local, dst_local, src_local->shape, + absl::MakeConstSpan(perm_data))); + }); + + DISPATCH_CORE_OPCODE(kReverse, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR( + ApplyUnaryOpIU<kernels::Reverse>(src_local, dst_local, src_local->shape, + absl::MakeConstSpan(perm_data))); + }); + + DISPATCH_CORE_OPCODE(kPad, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* padding_value, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto edge_padding_low, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto edge_padding_high, + reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto interior_padding, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + + RETURN_IF_ERROR(ApplyBinaryOpIU<kernels::Pad>( + src_local, padding_value, dst_local, src_local->shape, dst_local->shape, + absl::MakeConstSpan(edge_padding_low), + absl::MakeConstSpan(edge_padding_high), + absl::MakeConstSpan(interior_padding))); + }); + + DISPATCH_CORE_OPCODE(kBroadcast, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + dst_local->shape = Shape{shape_data}; + RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Broadcast>(src_local, dst_local)); + }); + + DISPATCH_CORE_OPCODE(kTile, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + dst_local->shape = Shape{shape_data}; + RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Tile>( + src_local, dst_local, src_local->shape, dst_local->shape)); + }); + + DISPATCH_CORE_OPCODE(kNot, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpIU<kernels::Not>(&reader)); + }); + DISPATCH_CORE_OPCODE(kAnd, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::And>(&reader)); + }); + DISPATCH_CORE_OPCODE(kOr, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Or>(&reader)); + }); + DISPATCH_CORE_OPCODE(kXor, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Xor>(&reader)); + }); + DISPATCH_CORE_OPCODE(kShiftLeft, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::ShiftLeft>(&reader)); + }); + DISPATCH_CORE_OPCODE(kShiftRightLogical, { + RETURN_IF_ERROR( + DispatchElementwiseBinaryOpIU<kernels::ShiftRight>(&reader)); + }); + DISPATCH_CORE_OPCODE(kShiftRightArithmetic, { + RETURN_IF_ERROR( + DispatchElementwiseBinaryOpIS<kernels::ShiftRight>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kAddI, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Add>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kAddF, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Add>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kSubI, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Sub>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kSubF, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Sub>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kAbsI, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpIS<kernels::Abs>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kAbsF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Abs>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kMulI, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Mul>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kMulF, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Mul>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kDivIS, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Div>(&reader)); + }); + DISPATCH_CORE_OPCODE(kDivIU, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Div>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kDivF, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Div>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kMulAddI, { + RETURN_IF_ERROR(DispatchElementwiseTernaryOpIU<kernels::MulAdd>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kMulAddF, { + RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::MulAdd>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kExpF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Exp>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kLogF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Log>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kRsqrtF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Rsqrt>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kCosF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Cos>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kSinF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Sin>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kTanhF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Tanh>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kAtan2F, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Atan2>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kMinIS, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Min>(&reader)); + }); + DISPATCH_CORE_OPCODE(kMinIU, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Min>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kMinF, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Min>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kMaxIS, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Max>(&reader)); + }); + DISPATCH_CORE_OPCODE(kMaxIU, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Max>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kMaxF, { + RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Max>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kClampIS, { + RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader)); + }); + DISPATCH_CORE_OPCODE(kClampIU, { + RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kClampF, { + RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::Clamp>(&reader)); + }); + + DISPATCH_FLOAT_OPCODE(kFloorF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Floor>(&reader)); + }); + DISPATCH_FLOAT_OPCODE(kCeilF, { + RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Ceil>(&reader)); + }); + + DISPATCH_CORE_OPCODE(kConvertSS, { + ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR( + ApplyConvertSS::Apply(src_type, src_local, dst_type, dst_local)); + }); + DISPATCH_CORE_OPCODE(kConvertUU, { + ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR( + ApplyConvertUU::Apply(src_type, src_local, dst_type, dst_local)); + }); + DISPATCH_CORE_OPCODE(kConvertSU, { + ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR( + ApplyConvertSU::Apply(src_type, src_local, dst_type, dst_local)); + }); + DISPATCH_CORE_OPCODE(kConvertUS, { + ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR( + ApplyConvertUS::Apply(src_type, src_local, dst_type, dst_local)); + }); + + DISPATCH_CORE_OPCODE(kMatMulI, { + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + // TODO(benvanik): add fused matmul-with-bias op in MLIR and lower to this. + BufferView* bias_local = nullptr; + ASSIGN_OR_RETURN(auto* multiplier_mantissa_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* multiplier_exponent_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR(ValidateMatMulOpI(lhs_local, rhs_local, bias_local, + multiplier_mantissa_local, + multiplier_exponent_local, dst_local)); + auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get(); + // TODO(benvanik): define as a matrix of supported types to enable 8*8=16, + // accumulator options, and other precision modes. + switch (lhs_local->element_size) { + case 1: + RETURN_IF_ERROR(ApplyMatMulOpI<int8_t>( + mat_mul_state, lhs_local, rhs_local, bias_local, + multiplier_mantissa_local, multiplier_exponent_local, dst_local)); + break; + case 2: + RETURN_IF_ERROR(ApplyMatMulOpI<int16_t>( + mat_mul_state, lhs_local, rhs_local, bias_local, + multiplier_mantissa_local, multiplier_exponent_local, dst_local)); + break; + case 4: + RETURN_IF_ERROR(ApplyMatMulOpI<int32_t>( + mat_mul_state, lhs_local, rhs_local, bias_local, + multiplier_mantissa_local, multiplier_exponent_local, dst_local)); + break; + case 8: + RETURN_IF_ERROR(ApplyMatMulOpI<int64_t>( + mat_mul_state, lhs_local, rhs_local, bias_local, + multiplier_mantissa_local, multiplier_exponent_local, dst_local)); + break; + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } + }); + + DISPATCH_FLOAT_OPCODE(kMatMulF, { + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + BufferView* bias_local = nullptr; + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR( + ValidateMatMulOpF(lhs_local, rhs_local, bias_local, dst_local)); + auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get(); + switch (lhs_local->element_size) { + case 4: + RETURN_IF_ERROR(ApplyMatMulOpF<float>( + mat_mul_state, lhs_local, rhs_local, bias_local, dst_local)); + break; + case 8: + RETURN_IF_ERROR(ApplyMatMulOpF<double>( + mat_mul_state, lhs_local, rhs_local, bias_local, dst_local)); + break; + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } + }); + + DISPATCH_CORE_OPCODE(kReduceSumI, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(scotttodd): validate + RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceSum>( + src_local, init_local, dst_local, dimension, src_local->shape, + dst_local->shape)); + }); + + DISPATCH_FLOAT_OPCODE(kReduceSumF, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(scotttodd): validate + RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceSum>( + src_local, init_local, dst_local, dimension, src_local->shape, + dst_local->shape)); + }); + + DISPATCH_CORE_OPCODE(kReduceMinI, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(scotttodd): validate + RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMin>( + src_local, init_local, dst_local, dimension, src_local->shape, + dst_local->shape)); + }); + + DISPATCH_FLOAT_OPCODE(kReduceMinF, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(scotttodd): validate + RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMin>( + src_local, init_local, dst_local, dimension, src_local->shape, + dst_local->shape)); + }); + + DISPATCH_CORE_OPCODE(kReduceMaxI, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(scotttodd): validate + RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMax>( + src_local, init_local, dst_local, dimension, src_local->shape, + dst_local->shape)); + }); + + DISPATCH_FLOAT_OPCODE(kReduceMaxF, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(scotttodd): validate + RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMax>( + src_local, init_local, dst_local, dimension, src_local->shape, + dst_local->shape)); + }); + + DISPATCH_CORE_OPCODE(kTrace, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace"; + }); + + DISPATCH_CORE_OPCODE(kBreak, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break"; + }); + + DISPATCH_CORE_OPCODE(kCondBreak, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break"; + }); + +_dispatch_unhandled: + // TODO(benvanik): better tracing. + return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode"; +} // NOLINT(readability/fn_size) + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/bytecode_dispatch.h b/hal/interpreter/bytecode_dispatch.h new file mode 100644 index 0000000..4a9db68 --- /dev/null +++ b/hal/interpreter/bytecode_dispatch.h
@@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_ + +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/interpreter/bytecode_kernels.h" +#include "rt/stack.h" +#include "rt/stack_frame.h" + +namespace iree { +namespace hal { + +Status Dispatch(hal::Allocator* allocator, + kernels::RuntimeState* kernel_runtime_state, rt::Stack* stack, + rt::StackFrame* entry_stack_frame, + absl::Span<BufferView> entry_results); + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
diff --git a/hal/interpreter/bytecode_dispatch_conversion.h b/hal/interpreter/bytecode_dispatch_conversion.h new file mode 100644 index 0000000..8cedd75 --- /dev/null +++ b/hal/interpreter/bytecode_dispatch_conversion.h
@@ -0,0 +1,395 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Conversion helper tables. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_ + +#include "base/status.h" +#include "hal/buffer_view.h" +#include "hal/interpreter/bytecode_dispatch_util.h" +#include "schemas/bytecode/interpreter_bytecode_v0.h" +#include "vm/type.h" + +namespace iree { +namespace hal { + +template <typename KERNEL, bool src_signed, bool dst_signed, typename... ARGS> +struct ApplyConversionOp { + static Status Apply(const vm::Type& src_type, BufferView* src_local, + const vm::Type& dst_type, BufferView* dst_local, + ARGS... args) { + // Validate ranges so that we cannot go out of bounds on thunk table. + int src_type_index = src_type.type_index(); + int dst_type_index = dst_type.type_index(); + if (src_type_index < 0 || src_type_index >= kBuiltinTypeCount) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Conversion from invalid source builtin type " + << src_type_index; + } else if (dst_type_index < 0 || dst_type_index >= kBuiltinTypeCount) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Conversion to invalid dest builtin type " << dst_type_index; + } + + // All possible combinations of conversions. + using KernelFn = Status (*)(BufferView * src_local, BufferView * dst_local, + ARGS... args); + KernelFn fn = nullptr; + if (src_signed && dst_signed) { + // Signed -> signed. + static const KernelFn + kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { + // src_type = kI8: + /* kI8 */ Thunk<int8_t, int8_t>::Apply, + /* kI16 */ Thunk<int8_t, int16_t>::Apply, + /* kI32 */ Thunk<int8_t, int32_t>::Apply, + /* kI64 */ Thunk<int8_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<int8_t, float>::Apply, + /* kF64 */ Thunk<int8_t, double>::Apply, + + // src_type = kI16: + /* kI8 */ Thunk<int16_t, int8_t>::Apply, + /* kI16 */ Thunk<int16_t, int16_t>::Apply, + /* kI32 */ Thunk<int16_t, int32_t>::Apply, + /* kI64 */ Thunk<int16_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<int16_t, float>::Apply, + /* kF64 */ Thunk<int16_t, double>::Apply, + + // src_type = kI32: + /* kI8 */ Thunk<int32_t, int8_t>::Apply, + /* kI16 */ Thunk<int32_t, int16_t>::Apply, + /* kI32 */ Thunk<int32_t, int32_t>::Apply, + /* kI64 */ Thunk<int32_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<int32_t, float>::Apply, + /* kF64 */ Thunk<int32_t, double>::Apply, + + // src_type = kI64: + /* kI8 */ Thunk<int64_t, int8_t>::Apply, + /* kI16 */ Thunk<int64_t, int16_t>::Apply, + /* kI32 */ Thunk<int64_t, int32_t>::Apply, + /* kI64 */ Thunk<int64_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<int64_t, float>::Apply, + /* kF64 */ Thunk<int64_t, double>::Apply, + + // src_type = kF16: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ Thunk<uint16_t, uint16_t>::Apply, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF32: + /* kI8 */ Thunk<float, int8_t>::Apply, + /* kI16 */ Thunk<float, int16_t>::Apply, + /* kI32 */ Thunk<float, int32_t>::Apply, + /* kI64 */ Thunk<float, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<float, float>::Apply, + /* kF64 */ Thunk<float, double>::Apply, + + // src_type = kF64: + /* kI8 */ Thunk<double, int8_t>::Apply, + /* kI16 */ Thunk<double, int16_t>::Apply, + /* kI32 */ Thunk<double, int32_t>::Apply, + /* kI64 */ Thunk<double, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<double, float>::Apply, + /* kF64 */ Thunk<double, double>::Apply, + }; + fn = + kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; + } else if (src_signed && !dst_signed) { + // Signed -> unsigned. + static const KernelFn + kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { + // src_type = kI8: + /* kI8 */ Thunk<int8_t, uint8_t>::Apply, + /* kI16 */ Thunk<int8_t, uint16_t>::Apply, + /* kI32 */ Thunk<int8_t, uint32_t>::Apply, + /* kI64 */ Thunk<int8_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kI16: + /* kI8 */ Thunk<int16_t, uint8_t>::Apply, + /* kI16 */ Thunk<int16_t, uint16_t>::Apply, + /* kI32 */ Thunk<int16_t, uint32_t>::Apply, + /* kI64 */ Thunk<int16_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kI32: + /* kI8 */ Thunk<int32_t, uint8_t>::Apply, + /* kI16 */ Thunk<int32_t, uint16_t>::Apply, + /* kI32 */ Thunk<int32_t, uint32_t>::Apply, + /* kI64 */ Thunk<int32_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kI64: + /* kI8 */ Thunk<int64_t, uint8_t>::Apply, + /* kI16 */ Thunk<int64_t, uint16_t>::Apply, + /* kI32 */ Thunk<int64_t, uint32_t>::Apply, + /* kI64 */ Thunk<int64_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF16: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF32: + /* kI8 */ Thunk<float, uint8_t>::Apply, + /* kI16 */ Thunk<float, uint16_t>::Apply, + /* kI32 */ Thunk<float, uint32_t>::Apply, + /* kI64 */ Thunk<float, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF64: + /* kI8 */ Thunk<double, uint8_t>::Apply, + /* kI16 */ Thunk<double, uint16_t>::Apply, + /* kI32 */ Thunk<double, uint32_t>::Apply, + /* kI64 */ Thunk<double, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + }; + fn = + kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; + } else if (!src_signed && dst_signed) { + // Unsigned -> signed. + static const KernelFn + kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { + // src_type = kI8: + /* kI8 */ Thunk<uint8_t, int8_t>::Apply, + /* kI16 */ Thunk<uint8_t, int16_t>::Apply, + /* kI32 */ Thunk<uint8_t, int32_t>::Apply, + /* kI64 */ Thunk<uint8_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<uint8_t, float>::Apply, + /* kF64 */ Thunk<uint8_t, double>::Apply, + + // src_type = kI16: + /* kI8 */ Thunk<uint16_t, int8_t>::Apply, + /* kI16 */ Thunk<uint16_t, int16_t>::Apply, + /* kI32 */ Thunk<uint16_t, int32_t>::Apply, + /* kI64 */ Thunk<uint16_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<uint16_t, float>::Apply, + /* kF64 */ Thunk<uint16_t, double>::Apply, + + // src_type = kI32: + /* kI8 */ Thunk<uint32_t, int8_t>::Apply, + /* kI16 */ Thunk<uint32_t, int16_t>::Apply, + /* kI32 */ Thunk<uint32_t, int32_t>::Apply, + /* kI64 */ Thunk<uint32_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<uint32_t, float>::Apply, + /* kF64 */ Thunk<uint32_t, double>::Apply, + + // src_type = kI64: + /* kI8 */ Thunk<uint64_t, int8_t>::Apply, + /* kI16 */ Thunk<uint64_t, int16_t>::Apply, + /* kI32 */ Thunk<uint64_t, int32_t>::Apply, + /* kI64 */ Thunk<uint64_t, int64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ Thunk<uint64_t, float>::Apply, + /* kF64 */ Thunk<uint64_t, double>::Apply, + + // src_type = kF16: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF32: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF64: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + }; + fn = + kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; + } else if (!src_signed && !dst_signed) { + // Unsigned -> unsigned. + static const KernelFn + kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { + // src_type = kI8: + /* kI8 */ Thunk<uint8_t, uint8_t>::Apply, + /* kI16 */ Thunk<uint8_t, uint16_t>::Apply, + /* kI32 */ Thunk<uint8_t, uint32_t>::Apply, + /* kI64 */ Thunk<uint8_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kI16: + /* kI8 */ Thunk<uint16_t, uint8_t>::Apply, + /* kI16 */ Thunk<uint16_t, uint16_t>::Apply, + /* kI32 */ Thunk<uint16_t, uint32_t>::Apply, + /* kI64 */ Thunk<uint16_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kI32: + /* kI8 */ Thunk<uint32_t, uint8_t>::Apply, + /* kI16 */ Thunk<uint32_t, uint16_t>::Apply, + /* kI32 */ Thunk<uint32_t, uint32_t>::Apply, + /* kI64 */ Thunk<uint32_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kI64: + /* kI8 */ Thunk<uint64_t, uint8_t>::Apply, + /* kI16 */ Thunk<uint64_t, uint16_t>::Apply, + /* kI32 */ Thunk<uint64_t, uint32_t>::Apply, + /* kI64 */ Thunk<uint64_t, uint64_t>::Apply, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF16: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF32: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + + // src_type = kF64: + /* kI8 */ nullptr, + /* kI16 */ nullptr, + /* kI32 */ nullptr, + /* kI64 */ nullptr, + /* kF16 */ nullptr, + /* kF32 */ nullptr, + /* kF64 */ nullptr, + }; + fn = + kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; + } + if (!fn) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unsupported conversion from " << src_type_index << " to " + << dst_type_index; + } + return fn(src_local, dst_local, args...); + } + + template <typename SRC, typename DST> + struct Thunk { + static Status Apply(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + ASSIGN_OR_RETURN(auto src_buffer, + src_local->buffer->MapMemory<SRC>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<DST>( + MemoryAccess::kDiscardWrite)); + return KERNEL::Execute(src_buffer.contents(), + dst_buffer.mutable_contents(), args...); + } + }; + +// Disable F32/F64 conversions if they are not supported. +#if !defined(IREE_SUPPORT_F32) + template <typename DST> + struct Thunk<float, DST> { + static Status Apply(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported"; + } + }; + template <typename SRC> + struct Thunk<SRC, float> { + static Status Apply(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported"; + } + }; +#endif // !IREE_SUPPORT_F32 +#if !defined(IREE_SUPPORT_F64) + template <typename DST> + struct Thunk<double, DST> { + static Status Apply(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported"; + } + }; + template <typename SRC> + struct Thunk<SRC, double> { + static Status Apply(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported"; + } + }; +#endif // !IREE_SUPPORT_F64 +}; + +using ApplyConvertSS = ApplyConversionOp<kernels::Convert, /*src_signed=*/true, + /*dst_signed=*/true>; +using ApplyConvertUU = ApplyConversionOp<kernels::Convert, /*src_signed=*/false, + /*dst_signed=*/false>; +using ApplyConvertSU = ApplyConversionOp<kernels::Convert, /*src_signed=*/true, + /*dst_signed=*/false>; +using ApplyConvertUS = ApplyConversionOp<kernels::Convert, /*src_signed=*/false, + /*dst_signed=*/true>; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
diff --git a/hal/interpreter/bytecode_dispatch_util.cc b/hal/interpreter/bytecode_dispatch_util.cc new file mode 100644 index 0000000..988b525 --- /dev/null +++ b/hal/interpreter/bytecode_dispatch_util.cc
@@ -0,0 +1,107 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/bytecode_dispatch_util.h" + +namespace iree { +namespace hal { + +bool BufferViewIsTrue(const BufferView& buffer_view) { + if (buffer_view.element_size == 0 || !buffer_view.buffer || + buffer_view.byte_length() == 0) { + return false; + } + // TODO(benvanik): map more efficiently (based on element size?). + auto mapping = + buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead); + if (!mapping.ok()) { + return false; + } + for (uint8_t value : mapping.ValueOrDie().contents()) { + if (value) return true; + } + return false; +} + +Status ValidateElementwiseUnaryOp(BufferView* src_local, + BufferView* dst_local) { + // TODO(benvanik): validate shapes. + return OkStatus(); +} + +Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local) { + // TODO(benvanik): validate shapes. + return OkStatus(); +} + +Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local, + BufferView* c_local, + BufferView* dst_local) { + // TODO(benvanik): validate shapes. + return OkStatus(); +} + +Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local, + BufferView* bias_local, + BufferView* multiplier_mantissa_local, + BufferView* multiplier_exponent_local, + BufferView* dst_local) { + // TODO(benvanik): validate shapes. + return OkStatus(); +} + +Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local, + BufferView* bias_local, BufferView* dst_local) { + // TODO(benvanik): validate shapes. + return OkStatus(); +} + +Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices, + BufferView* dst_local, absl::Span<const int32_t> dst_indices, + absl::Span<const int32_t> lengths) { + ASSIGN_OR_RETURN(auto src_buffer, + src_local->buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); + // TODO(benvanik): discard if overwriting the entire buffer. + ASSIGN_OR_RETURN(auto dst_buffer, + dst_local->buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); + switch (src_local->element_size) { + case 1: + return kernels::Copy::Execute<1>(src_buffer.contents(), src_local->shape, + src_indices, + dst_buffer.mutable_contents(), + dst_local->shape, dst_indices, lengths); + case 2: + return kernels::Copy::Execute<2>(src_buffer.contents(), src_local->shape, + src_indices, + dst_buffer.mutable_contents(), + dst_local->shape, dst_indices, lengths); + case 4: + return kernels::Copy::Execute<4>(src_buffer.contents(), src_local->shape, + src_indices, + dst_buffer.mutable_contents(), + dst_local->shape, dst_indices, lengths); + case 8: + return kernels::Copy::Execute<8>(src_buffer.contents(), src_local->shape, + src_indices, + dst_buffer.mutable_contents(), + dst_local->shape, dst_indices, lengths); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << src_local->element_size; + } +} + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/bytecode_dispatch_util.h b/hal/interpreter/bytecode_dispatch_util.h new file mode 100644 index 0000000..65855c1 --- /dev/null +++ b/hal/interpreter/bytecode_dispatch_util.h
@@ -0,0 +1,513 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities used by the bytecode_dispatch routines to aid in working with the +// bytecode stream and kernel dispatch. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_ + +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "hal/heap_buffer.h" +#include "hal/interpreter/bytecode_kernels.h" +#include "rt/function.h" +#include "rt/stack.h" +#include "schemas/bytecode/interpreter_bytecode_v0.h" +#include "vm/bytecode_reader.h" +#include "vm/type.h" + +// TODO(benvanik): move to dedicated config file/build flags. +#define IREE_SUPPORT_F32 1 +#define IREE_SUPPORT_F64 1 + +namespace iree { +namespace hal { + +// Returns true if the contents of the BufferView are bitwise non-zero. +// Returns false if there is no buffer, the buffer is empty, or the contents are +// bitwise zero. +bool BufferViewIsTrue(const BufferView& buffer_view); + +Status ValidateElementwiseUnaryOp(BufferView* src_local, BufferView* dst_local); +Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local); +Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local, + BufferView* c_local, BufferView* dst_local); +Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local, + BufferView* bias_local, + BufferView* multiplier_mantissa_local, + BufferView* multiplier_exponent_local, + BufferView* dst_local); +Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local, + BufferView* bias_local, BufferView* dst_local); + +template <typename KERNEL, typename T, typename... ARGS> +Status ApplyUnaryOp(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + // TODO(benvanik): avoid mapping by changing buffer type? + ASSIGN_OR_RETURN(auto src_buffer, + src_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( + MemoryAccess::kDiscardWrite)); + return KERNEL::Execute(src_buffer.contents(), dst_buffer.mutable_contents(), + args...); +} + +template <typename KERNEL, typename T, typename... ARGS> +Status ApplyBinaryOp(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local, ARGS... args) { + ASSIGN_OR_RETURN(auto lhs_buffer, + lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto rhs_buffer, + rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( + MemoryAccess::kDiscardWrite)); + return KERNEL::Execute(lhs_buffer.contents(), rhs_buffer.contents(), + dst_buffer.mutable_contents(), args...); +} + +template <typename KERNEL, typename T, typename... ARGS> +Status ApplyTernaryOp(BufferView* a_local, BufferView* b_local, + BufferView* c_local, BufferView* dst_local, + ARGS... args) { + ASSIGN_OR_RETURN(auto a_buffer, + a_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto b_buffer, + b_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto c_buffer, + c_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( + MemoryAccess::kDiscardWrite)); + return KERNEL::Execute(a_buffer.contents(), b_buffer.contents(), + c_buffer.contents(), dst_buffer.mutable_contents(), + args...); +} + +template <typename KERNEL, typename T> +Status ApplyComparisonOp(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local) { + ASSIGN_OR_RETURN(auto lhs_buffer, + lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto rhs_buffer, + rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>( + MemoryAccess::kDiscardWrite)); + return KERNEL::Execute(lhs_buffer.contents(), rhs_buffer.contents(), + dst_buffer.mutable_contents()); +} + +template <typename KERNEL, typename... ARGS> +Status ApplyUnaryOpIS(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + switch (src_local->element_size) { + case 1: + return ApplyUnaryOp<KERNEL, int8_t>(src_local, dst_local, args...); + case 2: + return ApplyUnaryOp<KERNEL, int16_t>(src_local, dst_local, args...); + case 4: + return ApplyUnaryOp<KERNEL, int32_t>(src_local, dst_local, args...); + case 8: + return ApplyUnaryOp<KERNEL, int64_t>(src_local, dst_local, args...); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << src_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyUnaryOpIU(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + switch (src_local->element_size) { + case 1: + return ApplyUnaryOp<KERNEL, uint8_t>(src_local, dst_local, args...); + case 2: + return ApplyUnaryOp<KERNEL, uint16_t>(src_local, dst_local, args...); + case 4: + return ApplyUnaryOp<KERNEL, uint32_t>(src_local, dst_local, args...); + case 8: + return ApplyUnaryOp<KERNEL, uint64_t>(src_local, dst_local, args...); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << src_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyUnaryOpF(BufferView* src_local, BufferView* dst_local, + ARGS... args) { + switch (src_local->element_size) { +#if defined(IREE_SUPPORT_F32) + case 4: + return ApplyUnaryOp<KERNEL, float>(src_local, dst_local, args...); +#endif // IREE_SUPPORT_F32 +#if defined(IREE_SUPPORT_F64) + case 8: + return ApplyUnaryOp<KERNEL, double>(src_local, dst_local, args...); +#endif // IREE_SUPPORT_F64 + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << src_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyBinaryOpIS(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local, ARGS... args) { + switch (lhs_local->element_size) { + case 1: + return ApplyBinaryOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local, + args...); + case 2: + return ApplyBinaryOp<KERNEL, int16_t>(lhs_local, rhs_local, dst_local, + args...); + case 4: + return ApplyBinaryOp<KERNEL, int32_t>(lhs_local, rhs_local, dst_local, + args...); + case 8: + return ApplyBinaryOp<KERNEL, int64_t>(lhs_local, rhs_local, dst_local, + args...); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyBinaryOpIU(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local, ARGS... args) { + switch (lhs_local->element_size) { + case 1: + return ApplyBinaryOp<KERNEL, uint8_t>(lhs_local, rhs_local, dst_local, + args...); + case 2: + return ApplyBinaryOp<KERNEL, uint16_t>(lhs_local, rhs_local, dst_local, + args...); + case 4: + return ApplyBinaryOp<KERNEL, uint32_t>(lhs_local, rhs_local, dst_local, + args...); + case 8: + return ApplyBinaryOp<KERNEL, uint64_t>(lhs_local, rhs_local, dst_local, + args...); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyBinaryOpF(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local, ARGS... args) { + switch (lhs_local->element_size) { +#if defined(IREE_SUPPORT_F32) + case 4: + return ApplyBinaryOp<KERNEL, float>(lhs_local, rhs_local, dst_local, + args...); +#endif // IREE_SUPPORT_F32 +#if defined(IREE_SUPPORT_F64) + case 8: + return ApplyBinaryOp<KERNEL, double>(lhs_local, rhs_local, dst_local, + args...); +#endif // IREE_SUPPORT_F64 + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyTernaryOpIS(BufferView* a_local, BufferView* b_local, + BufferView* c_local, BufferView* dst_local, + ARGS... args) { + switch (a_local->element_size) { + case 1: + return ApplyTernaryOp<KERNEL, int8_t>(a_local, b_local, c_local, + dst_local, args...); + case 2: + return ApplyTernaryOp<KERNEL, int16_t>(a_local, b_local, c_local, + dst_local, args...); + case 4: + return ApplyTernaryOp<KERNEL, int32_t>(a_local, b_local, c_local, + dst_local, args...); + case 8: + return ApplyTernaryOp<KERNEL, int64_t>(a_local, b_local, c_local, + dst_local, args...); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << a_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyTernaryOpIU(BufferView* a_local, BufferView* b_local, + BufferView* c_local, BufferView* dst_local, + ARGS... args) { + switch (a_local->element_size) { + case 1: + return ApplyTernaryOp<KERNEL, uint8_t>(a_local, b_local, c_local, + dst_local, args...); + case 2: + return ApplyTernaryOp<KERNEL, uint16_t>(a_local, b_local, c_local, + dst_local, args...); + case 4: + return ApplyTernaryOp<KERNEL, uint32_t>(a_local, b_local, c_local, + dst_local, args...); + case 8: + return ApplyTernaryOp<KERNEL, uint64_t>(a_local, b_local, c_local, + dst_local, args...); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << a_local->element_size; + } +} + +template <typename KERNEL, typename... ARGS> +Status ApplyTernaryOpF(BufferView* a_local, BufferView* b_local, + BufferView* c_local, BufferView* dst_local, + ARGS... args) { + switch (a_local->element_size) { +#if defined(IREE_SUPPORT_F32) + case 4: + return ApplyTernaryOp<KERNEL, float>(a_local, b_local, c_local, dst_local, + args...); +#endif // IREE_SUPPORT_F32 +#if defined(IREE_SUPPORT_F64) + case 8: + return ApplyTernaryOp<KERNEL, double>(a_local, b_local, c_local, + dst_local, args...); +#endif // IREE_SUPPORT_F64 + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << a_local->element_size; + } +} + +template <typename KERNEL> +Status ApplyComparisonOpIS(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local) { + switch (lhs_local->element_size) { + case 1: + return ApplyComparisonOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local); + case 2: + return ApplyComparisonOp<KERNEL, int16_t>(lhs_local, rhs_local, + dst_local); + case 4: + return ApplyComparisonOp<KERNEL, int32_t>(lhs_local, rhs_local, + dst_local); + case 8: + return ApplyComparisonOp<KERNEL, int64_t>(lhs_local, rhs_local, + dst_local); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } +} + +template <typename KERNEL> +Status ApplyComparisonOpIU(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local) { + switch (lhs_local->element_size) { + case 1: + return ApplyComparisonOp<KERNEL, uint8_t>(lhs_local, rhs_local, + dst_local); + case 2: + return ApplyComparisonOp<KERNEL, uint16_t>(lhs_local, rhs_local, + dst_local); + case 4: + return ApplyComparisonOp<KERNEL, uint32_t>(lhs_local, rhs_local, + dst_local); + case 8: + return ApplyComparisonOp<KERNEL, uint64_t>(lhs_local, rhs_local, + dst_local); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } +} + +template <typename KERNEL> +Status ApplyComparisonOpF(BufferView* lhs_local, BufferView* rhs_local, + BufferView* dst_local) { + switch (lhs_local->element_size) { + case 4: + return ApplyComparisonOp<KERNEL, float>(lhs_local, rhs_local, dst_local); + case 8: + return ApplyComparisonOp<KERNEL, double>(lhs_local, rhs_local, dst_local); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented element size: " << lhs_local->element_size; + } +} + +template <typename T, typename ACC = int32_t> +Status ApplyMatMulOpI(kernels::MatMul::RuntimeState* runtime_state, + BufferView* lhs_local, BufferView* rhs_local, + BufferView* bias_local, + BufferView* multiplier_mantissa_local, + BufferView* multiplier_exponent_local, + BufferView* dst_local) { + kernels::MatMul::Buffers<T, ACC> buffers; + ASSIGN_OR_RETURN(auto lhs_buffer, + lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + buffers.lhs_buffer = lhs_buffer.contents(); + buffers.lhs_shape = lhs_local->shape; + ASSIGN_OR_RETURN(auto rhs_buffer, + rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + buffers.rhs_buffer = rhs_buffer.contents(); + buffers.rhs_shape = rhs_local->shape; + MappedMemory<ACC> bias_buffer; + if (bias_local && bias_local->buffer && !bias_local->shape.empty()) { + if (bias_local->element_size != sizeof(ACC)) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Only " << sizeof(ACC) << "b biases are supported right now"; + } + ASSIGN_OR_RETURN(bias_buffer, + bias_local->buffer->MapMemory<ACC>(MemoryAccess::kRead)); + buffers.bias_buffer = bias_buffer.contents(); + } + ASSIGN_OR_RETURN( + auto multiplier_mantissa_buffer, + multiplier_mantissa_local->buffer->MapMemory<ACC>(MemoryAccess::kRead)); + buffers.multiplier_mantissa_buffer = multiplier_mantissa_buffer.contents(); + ASSIGN_OR_RETURN(auto multiplier_exponent_buffer, + multiplier_exponent_local->buffer->MapMemory<int32_t>( + MemoryAccess::kRead)); + buffers.multiplier_exponent_buffer = multiplier_exponent_buffer.contents(); + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( + MemoryAccess::kDiscardWrite)); + buffers.dst_buffer = dst_buffer.mutable_contents(); + buffers.dst_shape = dst_local->shape; + return kernels::MatMul::Execute(runtime_state, buffers); +} + +template <typename T> +Status ApplyMatMulOpF(kernels::MatMul::RuntimeState* runtime_state, + BufferView* lhs_local, BufferView* rhs_local, + BufferView* bias_local, BufferView* dst_local) { + kernels::MatMul::Buffers<T, T> buffers; + ASSIGN_OR_RETURN(auto lhs_buffer, + lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + buffers.lhs_buffer = lhs_buffer.contents(); + buffers.lhs_shape = lhs_local->shape; + ASSIGN_OR_RETURN(auto rhs_buffer, + rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + buffers.rhs_buffer = rhs_buffer.contents(); + buffers.rhs_shape = rhs_local->shape; + MappedMemory<T> bias_buffer; + if (bias_local && bias_local->buffer && !bias_local->shape.empty()) { + ASSIGN_OR_RETURN(bias_buffer, + bias_local->buffer->MapMemory<T>(MemoryAccess::kRead)); + buffers.bias_buffer = bias_buffer.contents(); + } + ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( + MemoryAccess::kDiscardWrite)); + buffers.dst_buffer = dst_buffer.mutable_contents(); + buffers.dst_shape = dst_local->shape; + return kernels::MatMul::Execute(runtime_state, buffers); +} + +template <typename KERNEL> +Status DispatchElementwiseUnaryOpIS(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local)); + return ApplyUnaryOpIS<KERNEL>(src_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseUnaryOpIU(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local)); + return ApplyUnaryOpIU<KERNEL>(src_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseUnaryOpF(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local)); + return ApplyUnaryOpF<KERNEL>(src_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseBinaryOpIS(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local)); + return ApplyBinaryOpIS<KERNEL>(lhs_local, rhs_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseBinaryOpIU(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local)); + return ApplyBinaryOpIU<KERNEL>(lhs_local, rhs_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseBinaryOpF(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local)); + return ApplyBinaryOpF<KERNEL>(lhs_local, rhs_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseTernaryOpIS(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR( + ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local)); + return ApplyTernaryOpIS<KERNEL>(a_local, b_local, c_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseTernaryOpIU(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR( + ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local)); + return ApplyTernaryOpIU<KERNEL>(a_local, b_local, c_local, dst_local); +} + +template <typename KERNEL> +Status DispatchElementwiseTernaryOpF(vm::BytecodeReader* reader) { + ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); + RETURN_IF_ERROR( + ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local)); + return ApplyTernaryOpF<KERNEL>(a_local, b_local, c_local, dst_local); +} + +Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices, + BufferView* dst_local, absl::Span<const int32_t> dst_indices, + absl::Span<const int32_t> lengths); + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
diff --git a/hal/interpreter/bytecode_executable.cc b/hal/interpreter/bytecode_executable.cc new file mode 100644 index 0000000..ff3d6d3 --- /dev/null +++ b/hal/interpreter/bytecode_executable.cc
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/bytecode_executable.h" + +#include <iostream> + +#include "hal/interpreter/interpreter_module.h" +#include "rt/policy.h" + +namespace iree { +namespace hal { + +// static +StatusOr<ref_ptr<BytecodeExecutable>> BytecodeExecutable::Load( + ref_ptr<rt::Instance> instance, hal::Allocator* allocator, + ExecutableSpec spec, bool allow_aliasing_data) { + // Allocate the executable now. + // We do this here so that if we need to clone the data we are passing that + // to the VM loader instead of the data we may not have access to later. + auto executable = make_ref<BytecodeExecutable>(std::move(instance), allocator, + spec, allow_aliasing_data); + + // Create the executable module. + auto module_def = + ::flatbuffers::GetRoot<ModuleDef>(executable->executable_data().data()); + ASSIGN_OR_RETURN(auto module, + InterpreterModule::FromDef(allocator, *module_def)); + executable->module_ = add_ref(module); + RETURN_IF_ERROR(executable->context()->RegisterModule(std::move(module))); + + return executable; +} + +BytecodeExecutable::BytecodeExecutable(ref_ptr<rt::Instance> instance, + hal::Allocator* allocator, + ExecutableSpec spec, + bool allow_aliasing_data) + : spec_(spec), + context_( + make_ref<rt::Context>(std::move(instance), make_ref<rt::Policy>())) { + if (!allow_aliasing_data) { + // Clone data. + cloned_executable_data_ = {spec.executable_data.begin(), + spec.executable_data.end()}; + spec_.executable_data = absl::MakeConstSpan(cloned_executable_data_); + } +} + +BytecodeExecutable::~BytecodeExecutable() = default; + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/bytecode_executable.h b/hal/interpreter/bytecode_executable.h new file mode 100644 index 0000000..f90e144 --- /dev/null +++ b/hal/interpreter/bytecode_executable.h
@@ -0,0 +1,68 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_ + +#include <vector> + +#include "absl/types/span.h" +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/executable.h" +#include "hal/executable_spec.h" +#include "rt/context.h" +#include "rt/instance.h" +#include "rt/module.h" + +namespace iree { +namespace hal { + +class BytecodeExecutable final : public Executable { + public: + static StatusOr<ref_ptr<BytecodeExecutable>> Load( + ref_ptr<rt::Instance> instance, hal::Allocator* allocator, + ExecutableSpec spec, bool allow_aliasing_data); + + BytecodeExecutable(ref_ptr<rt::Instance> instance, hal::Allocator* allocator, + ExecutableSpec spec, bool allow_aliasing_data); + ~BytecodeExecutable() override; + + bool supports_debugging() const override { return false; } + + // Reference to the bytecode blob contents. + absl::Span<const uint8_t> executable_data() const { + return spec_.executable_data; + } + + // VM context with the executable registered. + const ref_ptr<rt::Context>& context() const { return context_; } + + // VM module representing the executable. + // Note that there may be more than one module in the Context and only this + // module can be used to lookup executable exports. + const ref_ptr<rt::Module>& module() const { return module_; } + + private: + ExecutableSpec spec_; + std::vector<uint8_t> cloned_executable_data_; + + ref_ptr<rt::Context> context_; + ref_ptr<rt::Module> module_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
diff --git a/hal/interpreter/bytecode_kernels.h b/hal/interpreter/bytecode_kernels.h new file mode 100644 index 0000000..2e7e90c --- /dev/null +++ b/hal/interpreter/bytecode_kernels.h
@@ -0,0 +1,371 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Defines kernel functions and provides their implementation via one (or more) +// included files. +// +// Kernels should do the simplest possible operation. Buffer validation is +// handled by the dispatch logic and need not be checked. Kernels may optionally +// accept arguments beyond just the buffers, depending on the required state +// and attributes. +// +// Kernels may optionally have runtime state. This is state that is allocated +// once for the entire Runtime (and stored on RuntimeState) and shared across +// all fibers. This enables kernels that may require thread pools or device +// handles to be shared while kernels that require transient storage to be safe +// to use from multiple fibers concurrently. +// +// All kernels are templated to enable specialization of particular types or +// type combinations. By default the bytecode_kernels_generic.h will provide C++ +// semantics as reference and platform-specific versions can be implemented +// as needed. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_ + +#include <cstdint> + +#include "absl/types/span.h" +#include "base/shape.h" +#include "base/status.h" + +namespace iree { +namespace hal { +namespace kernels { + +struct CompareEQ { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer); +}; +struct CompareNE { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer); +}; +struct CompareLT { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer); +}; +struct CompareLE { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer); +}; +struct CompareGT { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer); +}; +struct CompareGE { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer); +}; + +struct Copy { + template <int element_size> + static Status Execute(absl::Span<const uint8_t> src_buffer, + const Shape& src_shape, + absl::Span<const int32_t> src_indices, + absl::Span<uint8_t> dst_buffer, const Shape& dst_shape, + absl::Span<const int32_t> dst_indices, + absl::Span<const int32_t> lengths); +}; + +struct Select { + template <typename T> + static Status Execute(absl::Span<const uint8_t> cond_buffer, + absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Transpose { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + absl::Span<const int32_t> perm); +}; + +struct Pad { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> padding_value, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, + absl::Span<const int32_t> edge_padding_low, + absl::Span<const int32_t> edge_padding_high, + absl::Span<const int32_t> interior_padding); +}; + +struct Reverse { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + absl::Span<const int32_t> dimensions); +}; + +struct Broadcast { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Tile { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape); +}; + +struct Not { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct And { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Or { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Xor { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct ShiftLeft { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct ShiftRight { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Add { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Sub { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Abs { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Mul { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Div { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +// a + (b * c) +struct MulAdd { + template <typename T> + static Status Execute(absl::Span<const T> a_buffer, + absl::Span<const T> b_buffer, + absl::Span<const T> c_buffer, absl::Span<T> dst_buffer); +}; + +struct Exp { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Log { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Rsqrt { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Cos { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Sin { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Tanh { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Atan2 { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Min { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Max { + template <typename T> + static Status Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer); +}; + +struct Clamp { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> min_buffer, + absl::Span<const T> max_buffer, + absl::Span<T> dst_buffer); +}; + +struct Floor { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Ceil { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer); +}; + +struct Convert { + template <typename SRC, typename DST> + static Status Execute(absl::Span<const SRC> src_buffer, + absl::Span<DST> dst_buffer); +}; + +struct MatMul { + struct RuntimeState; + + static std::unique_ptr<RuntimeState> CreateRuntimeState(); + + template <typename T, typename ACC> + struct Buffers { + Shape lhs_shape; + absl::Span<const T> lhs_buffer; + Shape rhs_shape; + absl::Span<const T> rhs_buffer; + Shape dst_shape; + absl::Span<T> dst_buffer; + + // Optional bias buffer. + absl::Span<const ACC> bias_buffer; + + // Fixed-point multiplier mantissa/exponent. May be a single value (for + // uniform quantization) or one element per row of the destination matrix + // for per-channel. + absl::Span<const ACC> multiplier_mantissa_buffer; + absl::Span<const int32_t> multiplier_exponent_buffer; + }; + + template <typename T, typename ACC> + static Status Execute(RuntimeState* runtime_state, + const Buffers<T, ACC>& buffers); +}; + +struct RuntimeState { + std::unique_ptr<MatMul::RuntimeState> mat_mul_state = + MatMul::CreateRuntimeState(); +}; + +struct ReduceSum { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, int32_t dimension, + const Shape& src_shape, const Shape& dst_shape); +}; + +struct ReduceMin { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, int32_t dimension, + const Shape& src_shape, const Shape& dst_shape); +}; + +struct ReduceMax { + template <typename T> + static Status Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, int32_t dimension, + const Shape& src_shape, const Shape& dst_shape); +}; + +} // namespace kernels +} // namespace hal +} // namespace iree + +#include "hal/interpreter/bytecode_kernels_generic.h" // IWYU pragma: export +#include "hal/interpreter/bytecode_kernels_ruy.h" // IWYU pragma: export + +#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
diff --git a/hal/interpreter/bytecode_kernels_generic.h b/hal/interpreter/bytecode_kernels_generic.h new file mode 100644 index 0000000..95d12dc --- /dev/null +++ b/hal/interpreter/bytecode_kernels_generic.h
@@ -0,0 +1,708 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "base/status.h" + +namespace iree { +namespace hal { +namespace kernels { + +template <typename T> +Status CompareEQ::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] == rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status CompareNE::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] != rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status CompareLT::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] < rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status CompareLE::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] <= rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status CompareGT::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] > rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status CompareGE::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<uint8_t> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] >= rhs_buffer[i]; + } + return OkStatus(); +} + +namespace impl { +inline absl::InlinedVector<size_t, 6> ComputeCopyStrides(const Shape& shape, + size_t element_size) { + absl::InlinedVector<size_t, 6> strides(shape.size()); + strides.back() = element_size; + for (int i = shape.size() - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + return strides; +} + +inline void CopyRegion(absl::Span<const uint8_t> src_buffer, + absl::Span<const size_t> src_strides, + absl::Span<const int32_t> src_indices, + absl::Span<uint8_t> dst_buffer, + absl::Span<const size_t> dst_strides, + absl::Span<const int32_t> dst_indices, + absl::Span<const int32_t> lengths) { + if (lengths.size() > 1) { + for (int i = 0; i < lengths[0]; ++i) { + size_t src_offset = src_strides[0] * (src_indices[0] + i); + size_t dst_offset = dst_strides[0] * (dst_indices[0] + i); + CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1), + src_indices.subspan(1), dst_buffer.subspan(dst_offset), + dst_strides.subspan(1), dst_indices.subspan(1), + lengths.subspan(1)); + } + } else { + DCHECK_EQ(dst_strides.size(), 1); + DCHECK_EQ(src_strides.size(), 1); + DCHECK_EQ(src_indices.size(), 1); + DCHECK_EQ(dst_indices.size(), 1); + DCHECK_EQ(lengths.size(), 1); + auto src_offset = src_indices[0] * src_strides[0]; + auto dst_offset = dst_indices[0] * dst_strides[0]; + auto length = dst_strides[0] * lengths[0]; + std::memcpy(dst_buffer.data() + dst_offset, src_buffer.data() + src_offset, + length); + } +} +} // namespace impl + +// TODO(benvanik): replace with a real implementation once copy is defined. +// TODO(gcmn): More consistent/principled handling for scalars. +template <int element_size> +Status Copy::Execute(absl::Span<const uint8_t> src_buffer, + const Shape& src_shape, + absl::Span<const int32_t> src_indices, + absl::Span<uint8_t> dst_buffer, const Shape& dst_shape, + absl::Span<const int32_t> dst_indices, + absl::Span<const int32_t> lengths) { + DCHECK_EQ(src_indices.size(), lengths.size()); + DCHECK_EQ(dst_indices.size(), lengths.size()); + DCHECK_EQ(src_shape.size(), lengths.size()); + DCHECK_EQ(dst_shape.size(), lengths.size()); + if (lengths.empty()) { + std::memcpy(dst_buffer.data(), src_buffer.data(), element_size); + return OkStatus(); + } + + // TODO(gcmn) Maybe we can fast-path earlier if we detect contiguous memory + // across multiple rows. + auto src_strides = impl::ComputeCopyStrides(src_shape, element_size); + auto dst_strides = impl::ComputeCopyStrides(dst_shape, element_size); + DCHECK_EQ(src_strides.size(), lengths.size()); + DCHECK_EQ(dst_strides.size(), lengths.size()); + impl::CopyRegion(src_buffer, src_strides, src_indices, dst_buffer, + dst_strides, dst_indices, lengths); + return OkStatus(); +} + +template <typename T> +Status Select::Execute(absl::Span<const uint8_t> cond_buffer, + absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = cond_buffer[i] ? lhs_buffer[i] : rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Transpose::Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + absl::Span<const int32_t> perm) { + // This implementation is .... not fast. + int rank = src_shape.size(); + absl::InlinedVector<int, 8> src_strides(rank); + absl::InlinedVector<int, 8> dst_strides(rank); + size_t src_stride = 1; + size_t dst_stride = 1; + for (int dim_i = rank - 1; dim_i >= 0; --dim_i) { + src_strides[dim_i] = src_stride; + dst_strides[dim_i] = dst_stride; + src_stride *= src_shape[dim_i]; + dst_stride *= src_shape[perm[dim_i]]; + } + for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) { + size_t src_i = 0; + size_t t = dst_i; + for (int dim_i = 0; dim_i < rank; ++dim_i) { + size_t ratio = t / dst_strides[dim_i]; + t -= ratio * dst_strides[dim_i]; + src_i += ratio * src_strides[perm[dim_i]]; + } + dst_buffer[dst_i] = src_buffer[src_i]; + } + return OkStatus(); +} + +namespace impl { +inline void IncrementShapeIndex(absl::Span<int32_t> indices, + const Shape& shape) { + for (int i = indices.size() - 1; i >= 0; --i) { + if (++indices[i] < shape[i]) return; + indices[i] = 0; + } +} + +inline bool IsPadding(absl::Span<const int32_t> indices, const Shape& shape, + absl::Span<const int32_t> edge_padding_low, + absl::Span<const int32_t> edge_padding_high, + absl::Span<const int32_t> interior_padding) { + for (int i = 0; i < indices.size(); ++i) { + auto index = indices[i]; + if (index < edge_padding_low[i] || + index >= shape[i] - edge_padding_high[i] || + (index - edge_padding_low[i]) % (interior_padding[i] + 1) != 0) { + return true; + } + } + + return false; +} +} // namespace impl + +template <typename T> +Status Pad::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> padding_value_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + const Shape& dst_shape, + absl::Span<const int32_t> edge_padding_low, + absl::Span<const int32_t> edge_padding_high, + absl::Span<const int32_t> interior_padding) { + // This implementation is not at all fast, as it iterates every index in the + // destination buffer individually. Potential improvements: + // 1. Fill the dst buffer with padded value initially. Only need to iterate + // through source buffer and can exit early. + // 2. Use striding to advance through larger swaths of the buffer with a + // memcpy from src and filling (or skipping) padded incides. Especially + // useful when e.g. entire rows are padded. + + // TODO(b/140836672) support negative padding + + if (padding_value_buffer.size() != 1) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Padding value buffer is larger than one element."; + } + auto padding_value = padding_value_buffer.front(); + + absl::InlinedVector<int, 8> dst_indices(src_shape.size(), 0); + + const T* src_ptr = src_buffer.begin(); + T* dst_ptr = dst_buffer.begin(); + while (dst_ptr != dst_buffer.end()) { + if (impl::IsPadding(dst_indices, dst_shape, edge_padding_low, + edge_padding_high, interior_padding)) { + *dst_ptr++ = padding_value; + } else { + DCHECK(src_ptr != src_buffer.end()); + *dst_ptr++ = *src_ptr++; + } + impl::IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape); + } + + return OkStatus(); +} + +template <typename T> +Status Reverse::Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + absl::Span<const int32_t> dimensions) { + // This implementation is not fast either + int rank = src_shape.size(); + absl::InlinedVector<int, 8> strides(rank); + size_t stride = 1; + for (int dim_i = rank - 1; dim_i >= 0; --dim_i) { + strides[dim_i] = stride; + stride *= src_shape[dim_i]; + } + absl::flat_hash_set<int32_t> dims_set(dimensions.begin(), dimensions.end()); + for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) { + size_t src_i = 0; + size_t t = dst_i; + for (int dim_i = 0; dim_i < rank; ++dim_i) { + size_t ratio = t / strides[dim_i]; + t -= ratio * strides[dim_i]; + bool do_reverse = dims_set.contains(dim_i); + src_i += (do_reverse ? (src_shape[dim_i] - 1 - ratio) : ratio) * + strides[dim_i]; + } + dst_buffer[dst_i] = src_buffer[src_i]; + } + return OkStatus(); +} + +template <typename T> +Status Broadcast::Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = src_buffer[0]; + } + return OkStatus(); +} + +template <typename T> +Status Tile::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer, + const Shape& src_shape, const Shape& dst_shape) { + // This implementation is .... not fast. + int rank = dst_shape.size(); + absl::InlinedVector<int, 8> src_strides(rank); + absl::InlinedVector<int, 8> dst_strides(rank); + size_t src_stride = 1; + size_t dst_stride = 1; + for (int dim_i = rank - 1; dim_i >= 0; --dim_i) { + src_strides[dim_i] = src_stride; + dst_strides[dim_i] = dst_stride; + src_stride *= src_shape[dim_i]; + dst_stride *= dst_shape[dim_i]; + } + for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) { + size_t src_i = 0; + size_t t = dst_i; + for (int dim_i = 0; dim_i < rank; ++dim_i) { + src_i += t / dst_strides[dim_i] % src_shape[dim_i] * src_strides[dim_i]; + t %= dst_strides[dim_i]; + } + dst_buffer[dst_i] = src_buffer[src_i]; + } + return OkStatus(); +} + +template <typename T> +Status Not::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = ~src_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status And::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] & rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Or::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] | rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Xor::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] ^ rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status ShiftLeft::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] << rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status ShiftRight::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] >> rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Add::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Sub::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] - rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Abs::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::abs(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Mul::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] * rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status Div::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = lhs_buffer[i] / rhs_buffer[i]; + } + return OkStatus(); +} + +template <typename T> +Status MulAdd::Execute(absl::Span<const T> a_buffer, + absl::Span<const T> b_buffer, + absl::Span<const T> c_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = a_buffer[i] + (b_buffer[i] * c_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Exp::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::exp(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Rsqrt::Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = 1.0 / std::sqrt(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Log::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::log(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Cos::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::cos(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Sin::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::sin(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Tanh::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::tanh(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Atan2::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::atan2(lhs_buffer[i], rhs_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Min::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::min(lhs_buffer[i], rhs_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Max::Execute(absl::Span<const T> lhs_buffer, + absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::max(lhs_buffer[i], rhs_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Clamp::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> min_buffer, + absl::Span<const T> max_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + T src = src_buffer[i]; + T min = min_buffer[i]; + T max = max_buffer[i]; + dst_buffer[i] = src <= min ? min : src >= max ? max : src; + } + return OkStatus(); +} + +template <typename T> +Status Floor::Execute(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::floor(src_buffer[i]); + } + return OkStatus(); +} + +template <typename T> +Status Ceil::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = std::ceil(src_buffer[i]); + } + return OkStatus(); +} + +template <typename SRC, typename DST> +Status Convert::Execute(absl::Span<const SRC> src_buffer, + absl::Span<DST> dst_buffer) { + DCHECK_EQ(src_buffer.size(), dst_buffer.size()); + for (size_t i = 0; i < dst_buffer.size(); ++i) { + dst_buffer[i] = static_cast<DST>(src_buffer[i]); + } + return OkStatus(); +} + +namespace impl { + +struct SumKernel { + template <typename T> + inline void operator()(T* value0, const T value1) { + *value0 += value1; + } +}; + +struct MinKernel { + template <typename T> + inline void operator()(T* value0, const T value1) { + *value0 = std::min(*value0, value1); + } +}; + +struct MaxKernel { + template <typename T> + inline void operator()(T* value0, const T value1) { + *value0 = std::max(*value0, value1); + } +}; + +template <typename T, typename KernelImpl> +inline void ReduceDimension(absl::Span<const T> src_buffer, + absl::Span<T> dst_buffer, const Shape& src_shape, + absl::Span<const int32_t> reduce_dims, + absl::Span<const int> dst_strides, int dim, + absl::Span<int> src_indices, size_t flat_src_i, + size_t src_stride) { + if (dim < 0) { + // Base case of the recursion - figure out which elements should be acted + // upon and apply the reduction kernel to them. + + // Derive destination indices from source indices. + // For example, + // reduce_dims: [1, 2] + // src_indices: [2, 1, 3, 0] + // ^ ^ + // | | + // |----- remove these dimensions + // dst_indices: [2, 0] + // + // TODO(scotttodd): Clean this up somehow, share across recursion levels? + size_t dst_size = src_shape.size() - reduce_dims.size(); + absl::InlinedVector<int, 8> dst_indices; + for (size_t i = 0; i < src_indices.size(); ++i) { + if (std::find(std::begin(reduce_dims), std::end(reduce_dims), i) == + std::end(reduce_dims)) { + dst_indices.push_back(src_indices[i]); + } + } + // Compute the flattened index into dst_buffer at [dst_indices]. + size_t dst_i = 0; + for (size_t i = 0; i < dst_indices.size(); ++i) { + dst_i += dst_indices[i] * dst_strides[dst_size - 1 - i]; + } + + // Flattened src and dst indices have been computed, invoke the kernel. + KernelImpl()(&dst_buffer[dst_i], src_buffer[flat_src_i]); + return; + } + + // Iterate through the current dimension in the source shape, recursing + // down one dimension at a time. + // + // This touches each element in the source buffer once, tracking complete + // dimensions within the shaped source buffer and using them to compute + // the corresponding indices (shaped and flattened) within the destination + // buffer. Each element in the destination buffer will be touched multiple + // times. + // + // Note that cache coherency isn't considered here, and some computations + // are redundant, so this could be optimized substantially. + for (size_t dim_i = 0; dim_i < src_shape[dim]; ++dim_i) { + src_indices[dim] = dim_i; + + // Recurse down to the next dimension (e.g. 2 -> 1 -> 0 -> base case) + // * Add the current stride to flat_src_i + // * Multiply src_stride by this dimension's shape + ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, + reduce_dims, dst_strides, dim - 1, + src_indices, flat_src_i + dim_i * src_stride, + src_stride * src_shape[dim]); + } +} + +template <typename T, typename KernelImpl> +Status GenericReduce(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, absl::Span<T> dst_buffer, + int32_t dimension, const Shape& src_shape, + const Shape& dst_shape) { + // Initialize using init_buffer, which is expected to be a scalar. + std::fill_n(dst_buffer.data(), dst_buffer.size(), init_buffer[0]); + + // Precompute destination strides. + int dst_rank = dst_shape.size(); + absl::InlinedVector<int, 8> dst_strides; + size_t dst_stride = 1; + for (int dim_i = dst_rank - 1; dim_i >= 0; --dim_i) { + dst_strides.push_back(dst_stride); + dst_stride *= dst_shape[dim_i]; + } + + // Call the helper (recursive) function, starting with: + // * source index [0, 0, ..., 0] + // * the innermost dimension (last in the shape) + // * flat_src_i of 0 (corresponds to [0, 0, ..., 0] above) + // * source stride 1 + absl::InlinedVector<int, 8> src_indices(src_shape.size(), 0); + ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, {dimension}, + absl::MakeSpan(dst_strides), + src_shape.size() - 1, + absl::MakeSpan(src_indices), 0, 1); + + return OkStatus(); +} + +} // namespace impl + +template <typename T> +Status ReduceSum::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, int32_t dimension, + const Shape& src_shape, const Shape& dst_shape) { + return impl::GenericReduce<T, impl::SumKernel>( + src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); +} + +template <typename T> +Status ReduceMin::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, int32_t dimension, + const Shape& src_shape, const Shape& dst_shape) { + return impl::GenericReduce<T, impl::MinKernel>( + src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); +} + +template <typename T> +Status ReduceMax::Execute(absl::Span<const T> src_buffer, + absl::Span<const T> init_buffer, + absl::Span<T> dst_buffer, int32_t dimension, + const Shape& src_shape, const Shape& dst_shape) { + return impl::GenericReduce<T, impl::MaxKernel>( + src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); +} + +} // namespace kernels +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
diff --git a/hal/interpreter/bytecode_kernels_ruy.h b/hal/interpreter/bytecode_kernels_ruy.h new file mode 100644 index 0000000..ae36844 --- /dev/null +++ b/hal/interpreter/bytecode_kernels_ruy.h
@@ -0,0 +1,81 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_ +#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_ + +#include "absl/base/thread_annotations.h" +#include "absl/memory/memory.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "tensorflow/lite/experimental/ruy/context.h" +#include "tensorflow/lite/experimental/ruy/ruy.h" + +namespace iree { +namespace hal { +namespace kernels { + +// TODO(benvanik): something more clever for making this shareable. +// Maybe a factory fn based on the impl selected? +struct MatMul::RuntimeState { + // TODO(benvanik): share the thread pool but keep context per-fiber? + ruy::Context context; +}; + +inline std::unique_ptr<MatMul::RuntimeState> MatMul::CreateRuntimeState() { + return absl::make_unique<RuntimeState>(); +} + +template <typename T, typename ACC> +Status MatMul::Execute(RuntimeState* runtime_state, + const Buffers<T, ACC>& buffers) { + ruy::Matrix<T> lhs_matrix; + ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1], + ruy::Order::kRowMajor, &lhs_matrix.layout); + lhs_matrix.data.set(buffers.lhs_buffer.data()); + + ruy::Matrix<T> rhs_matrix; + ruy::MakeSimpleLayout(buffers.rhs_shape[0], buffers.rhs_shape[1], + ruy::Order::kRowMajor, &rhs_matrix.layout); + rhs_matrix.data.set(buffers.rhs_buffer.data()); + + ruy::Matrix<T> dst_matrix; + ruy::MakeSimpleLayout(buffers.dst_shape[0], buffers.dst_shape[1], + ruy::Order::kRowMajor, &dst_matrix.layout); + dst_matrix.data.set(buffers.dst_buffer.data()); + + ruy::BasicSpec<ACC, T> spec; + spec.bias = buffers.bias_buffer.data(); + + if (buffers.multiplier_mantissa_buffer.size() == 1) { + spec.multiplier_fixedpoint = buffers.multiplier_mantissa_buffer[0]; + spec.multiplier_exponent = buffers.multiplier_exponent_buffer[0]; + } else { + spec.multiplier_fixedpoint_perchannel = + buffers.multiplier_mantissa_buffer.data(); + spec.multiplier_exponent_perchannel = + buffers.multiplier_exponent_buffer.data(); + } + + ruy::Mul<ruy::kAllPaths>(lhs_matrix, rhs_matrix, spec, + &runtime_state->context, &dst_matrix); + + return OkStatus(); +} + +} // namespace kernels +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
diff --git a/hal/interpreter/bytecode_kernels_test.cc b/hal/interpreter/bytecode_kernels_test.cc new file mode 100644 index 0000000..db99add --- /dev/null +++ b/hal/interpreter/bytecode_kernels_test.cc
@@ -0,0 +1,407 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/bytecode_kernels.h" + +#include "base/memory.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace iree { +namespace hal { +namespace kernels { + +namespace { + +constexpr float kEpsilon = 0.0001f; + +template <typename T> +std::vector<T> MakeIota(int size) { + std::vector<T> v(size); + std::iota(v.begin(), v.end(), static_cast<T>(1)); + return v; +} + +TEST(Copy, WholeBuffer) { + Shape src_shape = {2, 2}; + auto src_buffer = MakeIota<uint8_t>(4); + std::vector<int32_t> src_indices = {0, 0}; + Shape dst_shape = src_shape; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {0, 0}; + std::vector<int32_t> lengths = {2, 2}; + auto expected_dst = src_buffer; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, FirstRow) { + Shape src_shape = {3, 4}; + auto src_buffer = MakeIota<uint8_t>(12); + std::vector<int32_t> src_indices = {0, 0}; + Shape dst_shape = {1, 4}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {0, 0}; + std::vector<int32_t> lengths = {1, 4}; + std::vector<uint8_t> expected_dst = {1, 2, 3, 4}; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, RowPart) { + Shape src_shape = {3, 4}; + auto src_buffer = MakeIota<uint8_t>(12); + std::vector<int32_t> src_indices = {1, 1}; + Shape dst_shape = {1, 2}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {0, 0}; + std::vector<int32_t> lengths = {1, 2}; + std::vector<uint8_t> expected_dst = {6, 7}; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, MultiRow) { + Shape src_shape = {3, 4}; + auto src_buffer = MakeIota<uint8_t>(12); + std::vector<int32_t> src_indices = {1, 0}; + Shape dst_shape = {2, 4}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {0, 0}; + std::vector<int32_t> lengths = {2, 4}; + std::vector<uint8_t> expected_dst = {5, 6, 7, 8, 9, 10, 11, 12}; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, NonContiguous) { + Shape src_shape = {3, 4}; + auto src_buffer = MakeIota<uint8_t>(12); + std::vector<int32_t> src_indices = {1, 1}; + Shape dst_shape = {2, 2}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {0, 0}; + std::vector<int32_t> lengths = {2, 2}; + std::vector<uint8_t> expected_dst = {6, 7, 10, 11}; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, MultiByte) { + Shape src_shape = {3, 4}; + auto src_vals = MakeIota<int32_t>(12); + auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals)); + std::vector<int32_t> src_indices = {1, 1}; + Shape dst_shape = {2, 2}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count() * sizeof(int32_t)); + std::vector<int32_t> dst_indices = {0, 0}; + std::vector<int32_t> lengths = {2, 2}; + std::vector<int32_t> expected_dst = {6, 7, 10, 11}; + + EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + + absl::Span<int32_t> dst_buffer_int32_t = + ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer)); + + EXPECT_EQ(dst_buffer_int32_t, expected_dst); +} + +TEST(Copy, NotFullDst) { + Shape src_shape = {3, 4}; + auto src_buffer = MakeIota<uint8_t>(12); + std::vector<int32_t> src_indices = {0, 0}; + Shape dst_shape = {4, 3}; + std::vector<uint8_t> dst_buffer(12, 42); + std::vector<int32_t> dst_indices = {1, 1}; + std::vector<int32_t> lengths = {2, 2}; + // clang-format off + std::vector<uint8_t> expected_dst = {42, 42, 42, + 42, 1, 2, + 42, 5, 6, + 42, 42, 42}; + // clang-format on + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, HighRank) { + Shape src_shape = {3, 3, 3, 3}; + auto src_buffer = MakeIota<uint8_t>(81); + std::vector<int32_t> src_indices = {1, 1, 1, 1}; + Shape dst_shape = {2, 2, 2, 2}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {0, 0, 0, 0}; + std::vector<int32_t> lengths = {2, 2, 2, 2}; + std::vector<uint8_t> expected_dst = {41, 42, 44, 45, 50, 51, 53, 54, + 68, 69, 71, 72, 77, 78, 80, 81}; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, Scalar) { + Shape src_shape = {}; + std::vector<uint8_t> src_buffer = {42}; + std::vector<int32_t> src_indices = {}; + Shape dst_shape = {}; + std::vector<uint8_t> dst_buffer(dst_shape.element_count()); + std::vector<int32_t> dst_indices = {}; + std::vector<int32_t> lengths = {}; + std::vector<uint8_t> expected_dst = {42}; + + EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Copy, ScalarMultiByte) { + Shape src_shape = {}; + std::vector<int32_t> src_vals = {INT32_MAX}; + auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals)); + std::vector<int32_t> src_indices = {}; + Shape dst_shape = {}; + std::vector<uint8_t> dst_buffer(sizeof(int32_t)); + std::vector<int32_t> dst_indices = {}; + std::vector<int32_t> lengths = {}; + std::vector<int32_t> expected_dst = {INT32_MAX}; + + EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices, + absl::MakeSpan(dst_buffer), dst_shape, dst_indices, + lengths)); + + absl::Span<int32_t> dst_buffer_int32_t = + ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer)); + + EXPECT_EQ(dst_buffer_int32_t, expected_dst); +} + +TEST(Pad, NoPadding) { + Shape src_shape = {2, 3}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {0, 0}; + std::vector<int32_t> edge_padding_high = {0, 0}; + std::vector<int32_t> interior_padding = {0, 0}; + Shape dst_shape = src_shape; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + auto expected_dst = src_buffer; + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Pad, LowHighPadding) { + Shape src_shape = {2, 3}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {0, 1}; + std::vector<int32_t> edge_padding_high = {1, 2}; + std::vector<int32_t> interior_padding = {0, 0}; + Shape dst_shape = {3, 6}; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + // clang-format off + std::vector<uint16_t> expected_dst = {0, 1, 2, 3, 0, 0, + 0, 4, 5, 6, 0, 0, + 0, 0, 0, 0, 0, 0}; + // clang-format on + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Pad, OnlyHighPadding) { + Shape src_shape = {2, 3}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {0, 0}; + std::vector<int32_t> edge_padding_high = {1, 3}; + std::vector<int32_t> interior_padding = {0, 0}; + Shape dst_shape = {3, 6}; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + // clang-format off + std::vector<uint16_t> expected_dst = {1, 2, 3, 0, 0, 0, + 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0}; + // clang-format on + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Pad, OnlyLowPadding) { + Shape src_shape = {2, 3}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {1, 3}; + std::vector<int32_t> edge_padding_high = {0, 0}; + std::vector<int32_t> interior_padding = {0, 0}; + Shape dst_shape = {3, 6}; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + // clang-format off + std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 2, 3, + 0, 0, 0, 4, 5, 6}; + // clang-format on + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Pad, OnlyInteriorPadding) { + Shape src_shape = {2, 3}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {0, 0}; + std::vector<int32_t> edge_padding_high = {0, 0}; + std::vector<int32_t> interior_padding = {1, 1}; + Shape dst_shape = {3, 5}; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + // clang-format off + std::vector<uint16_t> expected_dst = {1, 0, 2, 0, 3, + 0, 0, 0, 0, 0, + 4, 0, 5, 0, 6}; + // clang-format on + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Pad, AllPaddingTypes) { + Shape src_shape = {2, 3}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {1, 1}; + std::vector<int32_t> edge_padding_high = {1, 2}; + std::vector<int32_t> interior_padding = {1, 1}; + Shape dst_shape = {5, 8}; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + // clang-format off + std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 1, 0, 2, 0, 3, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 4, 0, 5, 0, 6, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + // clang-format on + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(Pad, HighRank) { + Shape src_shape = {2, 2, 2, 2}; + auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); + std::vector<uint16_t> pad_value_buffer = {0}; + std::vector<int32_t> edge_padding_low = {1, 0, 0, 0}; + std::vector<int32_t> edge_padding_high = {0, 1, 0, 0}; + std::vector<int32_t> interior_padding = {0, 0, 1, 0}; + Shape dst_shape = {3, 3, 3, 2}; + std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); + // clang-format off + std::vector<uint16_t> expected_dst = { 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + + 1, 2, 0, 0, 3, 4, + 5, 6, 0, 0, 7, 8, + 0, 0, 0, 0, 0, 0, + + 9, 10, 0, 0, 11, 12, + 13, 14, 0, 0, 15, 16, + 0, 0, 0, 0, 0, 0}; + // clang-format on + + ASSERT_EQ(dst_buffer.size(), expected_dst.size()); + + EXPECT_OK(Pad::Execute<uint16_t>( + src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, + dst_shape, edge_padding_low, edge_padding_high, interior_padding)); + EXPECT_EQ(dst_buffer, expected_dst); +} + +TEST(ReduceSum, Scalar) { + Shape src_shape = {5}; + int32_t dimension = 0; + Shape dst_shape = {1}; + std::vector<float> src_buffer = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; + std::vector<float> init_buffer = {0.0f}; + std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f); + std::vector<float> expected_dst = {5.0f}; + + EXPECT_OK(ReduceSum::Execute<float>(src_buffer, init_buffer, + absl::MakeSpan(dst_buffer), dimension, + src_shape, dst_shape)); + + for (int i = 0; i < dst_buffer.size(); ++i) { + EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon); + } +} + +TEST(ReduceMin, TwoDimensionsToOne) { + Shape src_shape = {3, 3}; + int32_t dimension = 0; + Shape dst_shape = {3}; + std::vector<float> src_buffer = MakeIota<float>(src_shape.element_count()); + std::vector<float> init_buffer = {std::numeric_limits<float>::max()}; + std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f); + std::vector<float> expected_dst = {1.0f, 2.0f, 3.0f}; + + EXPECT_OK(ReduceMin::Execute<float>(src_buffer, init_buffer, + absl::MakeSpan(dst_buffer), dimension, + src_shape, dst_shape)); + + for (int i = 0; i < dst_buffer.size(); ++i) { + EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon); + } +} + +} // namespace +} // namespace kernels +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/interpreter_command_processor.cc b/hal/interpreter/interpreter_command_processor.cc new file mode 100644 index 0000000..ac8b0b1 --- /dev/null +++ b/hal/interpreter/interpreter_command_processor.cc
@@ -0,0 +1,66 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/interpreter_command_processor.h" + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/buffer_view.h" +#include "hal/interpreter/bytecode_executable.h" +#include "rt/stack.h" + +namespace iree { +namespace hal { + +InterpreterCommandProcessor::InterpreterCommandProcessor( + Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) + : HostLocalCommandProcessor(allocator, mode, command_categories) {} + +InterpreterCommandProcessor::~InterpreterCommandProcessor() = default; + +Status InterpreterCommandProcessor::Dispatch( + const DispatchRequest& dispatch_request) { + IREE_TRACE_SCOPE0("InterpreterCommandProcessor::Dispatch"); + + // Lookup the exported function. + auto* executable = + static_cast<BytecodeExecutable*>(dispatch_request.executable); + const auto& module = executable->module(); + ASSIGN_OR_RETURN(auto entry_function, module->LookupFunctionByOrdinal( + rt::Function::Linkage::kExport, + dispatch_request.entry_point)); + + rt::Stack stack(executable->context().get()); + + // TODO(benvanik): avoid this by directly referencing the bindings. + absl::InlinedVector<BufferView, 8> arguments; + arguments.reserve(dispatch_request.bindings.size()); + for (auto& binding : dispatch_request.bindings) { + arguments.push_back(BufferView{add_ref(binding.buffer), binding.shape, + binding.element_size}); + } + absl::InlinedVector<BufferView, 8> results; + + RETURN_IF_ERROR(executable->module()->Execute( + &stack, entry_function, std::move(arguments), &results)); + + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/interpreter_command_processor.h b/hal/interpreter/interpreter_command_processor.h new file mode 100644 index 0000000..0913517 --- /dev/null +++ b/hal/interpreter/interpreter_command_processor.h
@@ -0,0 +1,36 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_ +#define IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_ + +#include "hal/host/host_local_command_processor.h" + +namespace iree { +namespace hal { + +class InterpreterCommandProcessor final : public HostLocalCommandProcessor { + public: + InterpreterCommandProcessor(Allocator* allocator, + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories); + ~InterpreterCommandProcessor() override; + + Status Dispatch(const DispatchRequest& dispatch_request) override; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
diff --git a/hal/interpreter/interpreter_device.cc b/hal/interpreter/interpreter_device.cc new file mode 100644 index 0000000..faab7af --- /dev/null +++ b/hal/interpreter/interpreter_device.cc
@@ -0,0 +1,173 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/interpreter_device.h" + +#include <utility> + +#include "absl/memory/memory.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/command_buffer_validation.h" +#include "hal/command_queue.h" +#include "hal/fence.h" +#include "hal/host/async_command_queue.h" +#include "hal/host/host_event.h" +#include "hal/host/host_submission_queue.h" +#include "hal/host/inproc_command_buffer.h" +#include "hal/interpreter/bytecode_cache.h" +#include "hal/interpreter/interpreter_command_processor.h" + +namespace iree { +namespace hal { + +namespace { + +// A CommandQueue that performs no synchronization (semaphores/fences) and just +// directly executes command buffers inline. +// +// This is meant to be wrapped by SyncCommandQueue or AsyncCommandQueue that +// themselves perform the synchronization/threading/etc. As such we ignore +// all semaphores in the provided batches under the assumption that if Submit is +// being called then all dependencies are valid. The wrapping queue is also +// responsible for signaling the fence as well as propagating errors in a way +// that is dependent on how it is performing its synchronization. +class UnsynchronizedCommandQueue final : public CommandQueue { + public: + UnsynchronizedCommandQueue(Allocator* allocator, std::string name, + CommandCategoryBitfield supported_categories) + : CommandQueue(std::move(name), supported_categories), + allocator_(allocator) {} + ~UnsynchronizedCommandQueue() override = default; + + Status Submit(absl::Span<const SubmissionBatch> batches, + FenceValue fence) override { + IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::Submit"); + DCHECK_EQ(nullptr, fence.first) + << "Fences must be handled by the wrapping queue"; + + // Process command buffers and propagate errors asynchronously through the + // fence. This ensures that even if we are running synchronously we still + // get consistent failure behavior with drivers that are purely async. + for (auto& batch : batches) { + DCHECK(batch.wait_semaphores.empty() && batch.signal_semaphores.empty()) + << "Semaphores must be handled by the wrapping queue"; + RETURN_IF_ERROR(ProcessCommandBuffers(batch.command_buffers)); + } + + // NOTE: fence is ignored here. + return OkStatus(); + } + + Status WaitIdle(absl::Time deadline) override { + // No-op. + return OkStatus(); + } + + private: + // Processes each command buffer in-turn with a fresh processor. + // This ensures we don't have any state that can carry across buffers. + Status ProcessCommandBuffers( + absl::Span<CommandBuffer* const> command_buffers) { + IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::ProcessCommandBuffers"); + for (auto* command_buffer : command_buffers) { + auto* inproc_command_buffer = + static_cast<InProcCommandBuffer*>(command_buffer->impl()); + InterpreterCommandProcessor command_processor( + allocator_, command_buffer->mode(), supported_categories()); + RETURN_IF_ERROR(inproc_command_buffer->Process(&command_processor)); + } + return OkStatus(); + } + + Allocator* const allocator_; +}; + +} // namespace + +InterpreterDevice::InterpreterDevice(DeviceInfo device_info) + : Device(std::move(device_info)), instance_(make_ref<rt::Instance>()) { + // We currently only expose a single command queue. + auto command_queue = absl::make_unique<UnsynchronizedCommandQueue>( + &allocator_, "cpu0", + CommandCategory::kTransfer | CommandCategory::kDispatch); + + // TODO(benvanik): allow injection of the wrapper type to support + // SyncCommandQueue without always linking in both. + auto async_command_queue = + absl::make_unique<AsyncCommandQueue>(std::move(command_queue)); + command_queues_.push_back(std::move(async_command_queue)); +} + +InterpreterDevice::~InterpreterDevice() = default; + +std::shared_ptr<ExecutableCache> InterpreterDevice::CreateExecutableCache() { + return std::make_shared<BytecodeCache>(add_ref(instance_), &allocator_); +} + +StatusOr<ref_ptr<CommandBuffer>> InterpreterDevice::CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) { + // TODO(b/140026716): conditionally enable validation. + auto impl = + make_ref<InProcCommandBuffer>(&allocator_, mode, command_categories); + return WrapCommandBufferWithValidation(std::move(impl)); +} + +StatusOr<ref_ptr<Event>> InterpreterDevice::CreateEvent() { + return make_ref<HostEvent>(); +} + +StatusOr<ref_ptr<BinarySemaphore>> InterpreterDevice::CreateBinarySemaphore( + bool initial_value) { + IREE_TRACE_SCOPE0("InterpreterDevice::CreateBinarySemaphore"); + return make_ref<HostBinarySemaphore>(initial_value); +} + +StatusOr<ref_ptr<TimelineSemaphore>> InterpreterDevice::CreateTimelineSemaphore( + uint64_t initial_value) { + IREE_TRACE_SCOPE0("InterpreterDevice::CreateTimelineSemaphore"); + + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) + << "Timeline semaphores not yet implemented"; +} + +StatusOr<ref_ptr<Fence>> InterpreterDevice::CreateFence( + uint64_t initial_value) { + IREE_TRACE_SCOPE0("InterpreterDevice::CreateFence"); + return make_ref<HostFence>(initial_value); +} + +Status InterpreterDevice::WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) { + IREE_TRACE_SCOPE0("InterpreterDevice::WaitAllFences"); + return HostFence::WaitForFences(fences, /*wait_all=*/true, deadline); +} + +StatusOr<int> InterpreterDevice::WaitAnyFence( + absl::Span<const FenceValue> fences, absl::Time deadline) { + IREE_TRACE_SCOPE0("InterpreterDevice::WaitAnyFence"); + return HostFence::WaitForFences(fences, /*wait_all=*/false, deadline); +} + +Status InterpreterDevice::WaitIdle(absl::Time deadline) { + for (auto& command_queue : command_queues_) { + RETURN_IF_ERROR(command_queue->WaitIdle(deadline)); + } + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/interpreter_device.h b/hal/interpreter/interpreter_device.h new file mode 100644 index 0000000..c5bdbf2 --- /dev/null +++ b/hal/interpreter/interpreter_device.h
@@ -0,0 +1,79 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_ +#define IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_ + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "base/memory.h" +#include "hal/device.h" +#include "hal/host/host_local_allocator.h" +#include "hal/interpreter/bytecode_kernels.h" +#include "rt/instance.h" + +namespace iree { +namespace hal { + +class InterpreterDevice final : public Device { + public: + explicit InterpreterDevice(DeviceInfo device_info); + ~InterpreterDevice() override; + + kernels::RuntimeState* kernel_runtime_state() { + return &kernel_runtime_state_; + } + + Allocator* allocator() const override { return &allocator_; } + + absl::Span<CommandQueue*> dispatch_queues() const override { + return RawPtrSpan(absl::MakeSpan(command_queues_)); + } + + absl::Span<CommandQueue*> transfer_queues() const override { + return RawPtrSpan(absl::MakeSpan(command_queues_)); + } + + std::shared_ptr<ExecutableCache> CreateExecutableCache() override; + + StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) override; + + StatusOr<ref_ptr<Event>> CreateEvent() override; + + StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( + bool initial_value) override; + StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( + uint64_t initial_value) override; + + StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override; + Status WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) override; + StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Time deadline) override; + + Status WaitIdle(absl::Time deadline) override; + + private: + ref_ptr<rt::Instance> instance_; + kernels::RuntimeState kernel_runtime_state_; + mutable HostLocalAllocator allocator_; + mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
diff --git a/hal/interpreter/interpreter_driver.cc b/hal/interpreter/interpreter_driver.cc new file mode 100644 index 0000000..f0364af --- /dev/null +++ b/hal/interpreter/interpreter_driver.cc
@@ -0,0 +1,62 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/interpreter_driver.h" + +#include <memory> + +#include "hal/device_info.h" +#include "hal/interpreter/interpreter_device.h" + +namespace iree { +namespace hal { + +namespace { + +DeviceInfo GetDefaultDeviceInfo() { + DeviceFeatureBitfield supported_features = DeviceFeature::kNone; + // TODO(benvanik): implement debugging/profiling features. + // supported_features |= DeviceFeature::kDebugging; + // supported_features |= DeviceFeature::kCoverage; + // supported_features |= DeviceFeature::kProfiling; + DeviceInfo device_info("interpreter", supported_features); + // TODO(benvanik): device info. + return device_info; +} + +} // namespace + +InterpreterDriver::InterpreterDriver() : Driver("interpreter") {} + +InterpreterDriver::~InterpreterDriver() = default; + +StatusOr<std::vector<DeviceInfo>> +InterpreterDriver::EnumerateAvailableDevices() { + std::vector<DeviceInfo> device_infos; + device_infos.push_back(GetDefaultDeviceInfo()); + return device_infos; +} + +StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDefaultDevice() { + return CreateDevice(GetDefaultDeviceInfo()); +} + +StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDevice( + const DeviceInfo& device_info) { + auto device = std::make_shared<InterpreterDevice>(device_info); + return device; +} + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/interpreter_driver.h b/hal/interpreter/interpreter_driver.h new file mode 100644 index 0000000..5389c5a --- /dev/null +++ b/hal/interpreter/interpreter_driver.h
@@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_ +#define IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_ + +#include "hal/driver.h" + +namespace iree { +namespace hal { + +class InterpreterDriver final : public Driver { + public: + InterpreterDriver(); + ~InterpreterDriver() override; + + StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; + + StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override; + + StatusOr<std::shared_ptr<Device>> CreateDevice( + const DeviceInfo& device_info) override; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
diff --git a/hal/interpreter/interpreter_driver_module.cc b/hal/interpreter/interpreter_driver_module.cc new file mode 100644 index 0000000..fc40ae6 --- /dev/null +++ b/hal/interpreter/interpreter_driver_module.cc
@@ -0,0 +1,39 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> + +#include "base/init.h" +#include "base/status.h" +#include "hal/driver_registry.h" +#include "hal/interpreter/interpreter_driver.h" + +namespace iree { +namespace hal { +namespace { + +StatusOr<std::shared_ptr<Driver>> CreateInterpreterDriver() { + return std::make_shared<InterpreterDriver>(); +} + +} // namespace +} // namespace hal +} // namespace iree + +IREE_REGISTER_MODULE_INITIALIZER(iree_hal_interpreter_driver, { + QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( + "interpreter", ::iree::hal::CreateInterpreterDriver)); +}); +IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, + iree_hal_interpreter_driver);
diff --git a/hal/interpreter/interpreter_module.cc b/hal/interpreter/interpreter_module.cc new file mode 100644 index 0000000..178d1f6 --- /dev/null +++ b/hal/interpreter/interpreter_module.cc
@@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/interpreter/interpreter_module.h" + +#include "base/flatbuffer_util.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/interpreter/bytecode_dispatch.h" +#include "vm/bytecode_tables_interpreter.h" + +namespace iree { +namespace hal { + +// static +StatusOr<ref_ptr<rt::Module>> InterpreterModule::FromDef( + hal::Allocator* allocator, const ModuleDef& module_def) { + ASSIGN_OR_RETURN(auto module_file, + vm::ModuleFile::Create(&module_def, []() {})); + if (module_file->root() == nullptr) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present"; + } + + auto module = + assign_ref(new InterpreterModule(allocator, std::move(module_file))); + + // TODO(benvanik): validate internals here? or make explicit? + + return {std::move(module)}; +} + +InterpreterModule::InterpreterModule( + hal::Allocator* allocator, std::unique_ptr<vm::ModuleFile> module_file) + : vm::BytecodeModule(std::move(module_file), + vm::interpreter_opcode_table()), + allocator_(allocator) {} + +Status InterpreterModule::Execute( + rt::Stack* stack, const rt::Function function, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::InlinedVector<hal::BufferView, 8>* results) const { + IREE_TRACE_SCOPE0("InterperterModule::Execute"); + + // Push stack frame for the function we are calling. + ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function)); + + // TODO(benvanik): rework register storage interface. + ASSIGN_OR_RETURN(const auto* function_def, + GetFunctionDef(function.linkage(), function.ordinal())); + auto* registers = callee_stack_frame->mutable_registers(); + registers->buffer_views.resize(function_def->bytecode()->local_count()); + + // Marshal input arguments. + for (int i = 0; i < arguments.size(); ++i) { + registers->buffer_views[i] = std::move(arguments[i]); + } + + // Run main dispatch loop until it exits (or errors). + RETURN_IF_ERROR(Dispatch(allocator_, &kernel_runtime_state_, stack, + callee_stack_frame, absl::MakeSpan(*results))); + + // Pop the callee frame to balance out the stack. + RETURN_IF_ERROR(stack->PopFrame()); + + return OkStatus(); +} + +} // namespace hal +} // namespace iree
diff --git a/hal/interpreter/interpreter_module.h b/hal/interpreter/interpreter_module.h new file mode 100644 index 0000000..1664589 --- /dev/null +++ b/hal/interpreter/interpreter_module.h
@@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_ +#define IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_ + +#include <memory> + +#include "absl/types/span.h" +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/buffer_view.h" +#include "hal/interpreter/bytecode_kernels.h" +#include "rt/function.h" +#include "rt/module.h" +#include "rt/stack.h" +#include "vm/bytecode_module.h" +#include "vm/bytecode_tables_interpreter.h" + +namespace iree { +namespace hal { + +class InterpreterModule final : public vm::BytecodeModule { + public: + static StatusOr<ref_ptr<rt::Module>> FromDef(hal::Allocator* allocator, + const ModuleDef& module_def); + + Status Execute( + rt::Stack* stack, const rt::Function function, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::InlinedVector<hal::BufferView, 8>* results) const override; + + private: + InterpreterModule(hal::Allocator* allocator, + std::unique_ptr<vm::ModuleFile> module_file); + + hal::Allocator* allocator_; + mutable kernels::RuntimeState kernel_runtime_state_; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
diff --git a/hal/resource.h b/hal/resource.h new file mode 100644 index 0000000..8396f64 --- /dev/null +++ b/hal/resource.h
@@ -0,0 +1,33 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_RESOURCE_H_ +#define IREE_HAL_RESOURCE_H_ + +#include "base/ref_ptr.h" + +namespace iree { +namespace hal { + +// Abstract resource type whose lifetime is managed by a ResourceSet. +// Used mostly just to get a virtual dtor, though we could add nicer logging. +class Resource : public RefObject<Resource> { + public: + virtual ~Resource() = default; +}; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_RESOURCE_H_
diff --git a/hal/semaphore.h b/hal/semaphore.h new file mode 100644 index 0000000..5ce0c25 --- /dev/null +++ b/hal/semaphore.h
@@ -0,0 +1,61 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_SEMAPHORE_H_ +#define IREE_HAL_SEMAPHORE_H_ + +#include "absl/types/variant.h" +#include "hal/resource.h" + +namespace iree { +namespace hal { + +// A synchronization primitive used to indicate submission dependencies. +// Semaphores are either of type binary (signaled or unsignaled) or timeline +// (uint64 payload with >= semantics). +class Semaphore : public Resource { + public: +}; + +// Binary semaphores have strict ordering requirements and must be carefully +// balanced. Each binary semaphore must only be waited on after a signal +// operation has been issued and each wait requires exactly one signal. They +// are commonly used only when interacting with external handles that may +// cross device or process boundaries. +class BinarySemaphore : public Semaphore { + public: +}; + +// Timeline semaphores act as a fence along a per-semaphore timeline where +// signaling is done by setting the payload to a monotonically increasing +// 64-bit integer and waiting is done by blocking until the payload is set +// greater-than or equal-to the specified value. Timeline semaphores may be +// waited on or signaled in any order and can be significantly more +// efficient due to system-level coalescing. +class TimelineSemaphore : public Semaphore { + public: + // TODO(benvanik): add value query support. + // TODO(benvanik): add host-side signal/wait. +}; + +// A reference to a strongly-typed semaphore and associated information. +// For TimelineSemaphores the provided payload is used to specify either the +// payload to wait for or new payload value. +using SemaphoreValue = + absl::variant<BinarySemaphore*, std::pair<TimelineSemaphore*, uint64_t>>; + +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_SEMAPHORE_H_
diff --git a/iree/hal/stack_trace.h b/hal/stack_trace.h similarity index 100% rename from iree/hal/stack_trace.h rename to hal/stack_trace.h
diff --git a/hal/testing/BUILD b/hal/testing/BUILD new file mode 100644 index 0000000..4f99b78 --- /dev/null +++ b/hal/testing/BUILD
@@ -0,0 +1,36 @@ +# Test utilities for HAL-specific code. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "mock_allocator", + testonly = True, + hdrs = ["mock_allocator.h"], + deps = [ + "///hal:allocator", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "mock_command_buffer", + testonly = True, + hdrs = ["mock_command_buffer.h"], + deps = [ + "///hal:command_buffer", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "mock_command_queue", + testonly = True, + hdrs = ["mock_command_queue.h"], + deps = [ + "///hal:command_queue", + "@com_google_googletest//:gtest", + ], +)
diff --git a/iree/hal/testing/CMakeLists.txt b/hal/testing/CMakeLists.txt similarity index 100% rename from iree/hal/testing/CMakeLists.txt rename to hal/testing/CMakeLists.txt
diff --git a/hal/testing/mock_allocator.h b/hal/testing/mock_allocator.h new file mode 100644 index 0000000..fa50bac --- /dev/null +++ b/hal/testing/mock_allocator.h
@@ -0,0 +1,55 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ +#define IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ + +#include "gmock/gmock.h" +#include "hal/allocator.h" + +namespace iree { +namespace hal { +namespace testing { + +class MockAllocator : public ::testing::StrictMock<Allocator> { + public: + MockAllocator() : ::testing::StrictMock<Allocator>() {} + + MOCK_CONST_METHOD4(CanUseBufferLike, + bool(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage)); + + MOCK_CONST_METHOD3(CanAllocate, bool(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size)); + + MOCK_METHOD3(Allocate, + StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size)); + + MOCK_METHOD5(WrapMutable, + StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, + void* data, size_t data_length)); +}; + +} // namespace testing +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
diff --git a/hal/testing/mock_command_buffer.h b/hal/testing/mock_command_buffer.h new file mode 100644 index 0000000..dace511 --- /dev/null +++ b/hal/testing/mock_command_buffer.h
@@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ +#define IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ + +#include "gmock/gmock.h" +#include "hal/command_buffer.h" + +namespace iree { +namespace hal { +namespace testing { + +class MockCommandBuffer : public ::testing::StrictMock<CommandBuffer> { + public: + MockCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) + : ::testing::StrictMock<CommandBuffer>(allocator, mode, + command_categories) {} + + bool is_recording() const override { return false; } + + MOCK_METHOD0(Begin, Status()); + MOCK_METHOD0(End, Status()); + + MOCK_METHOD4(ExecutionBarrier, + Status(ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers)); + + MOCK_METHOD2(SignalEvent, + Status(Event* event, ExecutionStageBitfield source_stage_mask)); + + MOCK_METHOD2(ResetEvent, + Status(Event* event, ExecutionStageBitfield source_stage_mask)); + + MOCK_METHOD5(WaitEvents, + Status(absl::Span<Event*> events, + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers)); + + MOCK_METHOD5(FillBuffer, + Status(Buffer* target_buffer, device_size_t target_offset, + device_size_t length, const void* pattern, + size_t pattern_length)); + + MOCK_METHOD1(DiscardBuffer, Status(Buffer* buffer)); + + MOCK_METHOD5(UpdateBuffer, + Status(const void* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length)); + + MOCK_METHOD5(CopyBuffer, + Status(Buffer* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length)); + + MOCK_METHOD1(Dispatch, Status(const DispatchRequest& dispatch_request)); +}; + +} // namespace testing +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
diff --git a/hal/testing/mock_command_queue.h b/hal/testing/mock_command_queue.h new file mode 100644 index 0000000..9195fb1 --- /dev/null +++ b/hal/testing/mock_command_queue.h
@@ -0,0 +1,42 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ +#define IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ + +#include "gmock/gmock.h" +#include "hal/command_queue.h" + +namespace iree { +namespace hal { +namespace testing { + +class MockCommandQueue : public ::testing::StrictMock<CommandQueue> { + public: + MockCommandQueue(std::string name, + CommandCategoryBitfield supported_categories) + : ::testing::StrictMock<CommandQueue>(std::move(name), + supported_categories) {} + + MOCK_METHOD2(Submit, Status(absl::Span<const SubmissionBatch> batches, + FenceValue fence)); + + MOCK_METHOD1(WaitIdle, Status(absl::Time deadline)); +}; + +} // namespace testing +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
diff --git a/hal/vulkan/BUILD b/hal/vulkan/BUILD new file mode 100644 index 0000000..97880b7 --- /dev/null +++ b/hal/vulkan/BUILD
@@ -0,0 +1,380 @@ +# HAL implementation using Vulkan and (likely) SPIR-V executables. + +load("//:build_defs.google.bzl", "PLATFORM_VULKAN_LOADER_COPTS", "PLATFORM_VULKAN_TEST_DEPS") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +# --define=IREE_VK=native to use the native Vulkan drivers (and real hardware). +config_setting( + name = "native_vk", + values = { + "define": "IREE_VK=native", + }, +) + +# --define=IREE_VK=swiftshader to use SwiftShader. +config_setting( + name = "swiftshader_vk", + values = { + "define": "IREE_VK=swiftshader", + }, +) + +cc_library( + name = "debug_reporter", + srcs = ["debug_reporter.cc"], + hdrs = ["debug_reporter.h"], + deps = [ + ":dynamic_symbols", + ":status_util", + "///base:status", + "///base:tracing", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "descriptor_pool_cache", + srcs = ["descriptor_pool_cache.cc"], + hdrs = ["descriptor_pool_cache.h"], + deps = [ + ":dynamic_symbols", + ":handle_util", + ":status_util", + "///base:ref_ptr", + "///base:status", + "///base:tracing", + "@com_google_absl//absl/container:inlined_vector", + ], +) + +cc_library( + name = "descriptor_set_arena", + srcs = ["descriptor_set_arena.cc"], + hdrs = ["descriptor_set_arena.h"], + deps = [ + ":descriptor_pool_cache", + ":pipeline_executable", + ":status_util", + ":vma_allocator", + "///base:arena", + "///base:math", + "///base:status", + "///base:tracing", + "///hal:command_buffer", + ], +) + +cc_library( + name = "direct_command_buffer", + srcs = ["direct_command_buffer.cc"], + hdrs = ["direct_command_buffer.h"], + deps = [ + ":descriptor_pool_cache", + ":descriptor_set_arena", + ":dynamic_symbols", + ":handle_util", + ":native_event", + ":pipeline_executable", + ":status_util", + ":vma_allocator", + "///base:math", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:command_buffer", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "direct_command_queue", + srcs = ["direct_command_queue.cc"], + hdrs = ["direct_command_queue.h"], + deps = [ + ":direct_command_buffer", + ":dynamic_symbols", + ":handle_util", + ":legacy_fence", + ":native_binary_semaphore", + ":status_util", + "///base:arena", + "///base:memory", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:command_queue", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "dynamic_symbols", + srcs = ["dynamic_symbols.cc"], + hdrs = [ + "dynamic_symbol_tables.h", + "dynamic_symbols.h", + ], + copts = PLATFORM_VULKAN_LOADER_COPTS, + linkopts = [ + "-ldl", + ], + deps = [ + "///base:file_path", + "///base:ref_ptr", + "///base:source_location", + "///base:status", + "///base:target_platform", + "///base:tracing", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_test( + name = "dynamic_symbols_test", + srcs = ["dynamic_symbols_test.cc"], + deps = [ + ":status_util", + ":dynamic_symbols", + "///base:status_matchers", + ] + PLATFORM_VULKAN_TEST_DEPS, +) + +cc_library( + name = "extensibility_util", + srcs = ["extensibility_util.cc"], + hdrs = ["extensibility_util.h"], + deps = [ + ":dynamic_symbols", + ":status_util", + "///base:memory", + "///base:status", + "///base:tracing", + "@com_google_absl//absl/types:span", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "handle_util", + hdrs = ["handle_util.h"], + deps = [ + ":dynamic_symbols", + ":extensibility_util", + "///base:ref_ptr", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/utility", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "legacy_fence", + srcs = ["legacy_fence.cc"], + hdrs = ["legacy_fence.h"], + deps = [ + ":handle_util", + ":status_util", + "///base:intrusive_list", + "///base:ref_ptr", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:fence", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "native_binary_semaphore", + srcs = ["native_binary_semaphore.cc"], + hdrs = ["native_binary_semaphore.h"], + deps = [ + ":handle_util", + "///hal:semaphore", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "native_event", + srcs = ["native_event.cc"], + hdrs = ["native_event.h"], + deps = [ + ":handle_util", + "///hal:event", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "pipeline_cache", + srcs = ["pipeline_cache.cc"], + hdrs = ["pipeline_cache.h"], + deps = [ + ":handle_util", + ":pipeline_executable", + ":status_util", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:executable", + "///hal:executable_cache", + "///hal:executable_format", + "///schemas:spirv_executable_def_cc_fbs", + "@com_github_google_flatbuffers//:flatbuffers", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "pipeline_executable", + srcs = ["pipeline_executable.cc"], + hdrs = ["pipeline_executable.h"], + deps = [ + ":handle_util", + ":status_util", + "///base:memory", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:executable", + "///hal:executable_cache", + "///hal:executable_spec", + "///schemas:spirv_executable_def_cc_fbs", + "@com_google_absl//absl/container:inlined_vector", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "status_util", + srcs = ["status_util.cc"], + hdrs = ["status_util.h"], + deps = [ + "///base:status", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "vma_allocator", + srcs = [ + "internal_vk_mem_alloc.cc", + "internal_vk_mem_alloc.h", + "vma_allocator.cc", + "vma_buffer.cc", + ], + hdrs = [ + "vma_allocator.h", + "vma_buffer.h", + ], + copts = [ + # Only needed in the implementation cc and not by external users. + "-DVMA_STATIC_VULKAN_FUNCTIONS=0", + ], + deps = [ + ":dynamic_symbols", + ":handle_util", + ":status_util", + "///base:logging", + "///base:source_location", + "///base:status", + "///base:tracing", + "///hal:allocator", + "///hal:buffer", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@vulkan_headers//:vulkan_headers_no_prototypes", + "@vulkan_memory_allocator//:impl_header_only", + ], +) + +cc_library( + name = "vulkan_device", + srcs = ["vulkan_device.cc"], + hdrs = ["vulkan_device.h"], + deps = [ + ":descriptor_pool_cache", + ":direct_command_buffer", + ":direct_command_queue", + ":dynamic_symbols", + ":extensibility_util", + ":handle_util", + ":legacy_fence", + ":native_binary_semaphore", + ":native_event", + ":pipeline_cache", + ":status_util", + ":vma_allocator", + "///base:memory", + "///base:status", + "///base:tracing", + "///hal:allocator", + "///hal:command_buffer_validation", + "///hal:command_queue", + "///hal:device", + "///hal:fence", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "vulkan_driver", + srcs = ["vulkan_driver.cc"], + hdrs = ["vulkan_driver.h"], + deps = [ + ":debug_reporter", + ":dynamic_symbols", + ":extensibility_util", + ":status_util", + ":vulkan_device", + "///base:memory", + "///base:status", + "///base:tracing", + "///hal:device_info", + "///hal:driver", + "@com_google_absl//absl/container:inlined_vector", + "@vulkan_headers//:vulkan_headers_no_prototypes", + ], +) + +cc_library( + name = "vulkan_driver_module", + srcs = ["vulkan_driver_module.cc"], + deps = [ + ":dynamic_symbols", + ":vulkan_driver", + "///base:init", + "///base:status", + "///base:tracing", + "///hal:driver_registry", + "@com_google_absl//absl/flags:flag", + ], + alwayslink = 1, +)
diff --git a/iree/hal/vulkan/CMakeLists.txt b/hal/vulkan/CMakeLists.txt similarity index 100% rename from iree/hal/vulkan/CMakeLists.txt rename to hal/vulkan/CMakeLists.txt
diff --git a/hal/vulkan/debug_reporter.cc b/hal/vulkan/debug_reporter.cc new file mode 100644 index 0000000..4558d27 --- /dev/null +++ b/hal/vulkan/debug_reporter.cc
@@ -0,0 +1,160 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/debug_reporter.h" + +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +// NOTE: |user_data| may be nullptr if we are being called during instance +// creation. Otherwise it is a pointer to the DebugReporter instance. + +// NOTE: this callback must be thread safe and must be careful not to reach too +// far outside of the call - it is called in-context from arbitrary threads with +// some amount of Vulkan state on the stack. Assume that creating or deleting +// Vulkan objects, issuing most Vulkan commands, etc are off-limits. + +VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback( + VkDebugUtilsMessageSeverityFlagBitsEXT message_severity, + VkDebugUtilsMessageTypeFlagsEXT message_type, + const VkDebugUtilsMessengerCallbackDataEXT* callback_data, + void* user_data) { + // TODO(benvanik): better logging once we have switched logging APIs. + LOG(ERROR) << callback_data->pMessage; + + return VK_FALSE; // VK_TRUE is reserved for future use. +} + +VKAPI_ATTR VkBool32 VKAPI_CALL DebugReportCallback( + VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT object_type, + uint64_t object, size_t location, int32_t message_code, + const char* layer_prefix, const char* message, void* user_data) { + // TODO(benvanik): better logging once we have switched logging APIs. + LOG(ERROR) << message; + + return VK_FALSE; // VK_TRUE is reserved for future use. +} + +} // namespace + +// static +void DebugReporter::PopulateStaticCreateInfo( + VkDebugUtilsMessengerCreateInfoEXT* create_info) { + create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; + create_info->pNext = nullptr; + create_info->flags = 0; + + // TODO(benvanik): only enable the severities that logging has enabled. + create_info->messageSeverity = + VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT; + + // TODO(benvanik): allow filtering by category as a flag. + create_info->messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; + + create_info->pfnUserCallback = DebugUtilsMessageCallback; + create_info->pUserData = nullptr; +} + +// static +void DebugReporter::PopulateStaticCreateInfo( + VkDebugReportCallbackCreateInfoEXT* create_info) { + create_info->sType = VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT; + create_info->pNext = nullptr; + create_info->flags = 0; + + // TODO(benvanik): only enable the severities that logging has enabled. + create_info->flags |= + VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | + VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | + VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_DEBUG_BIT_EXT; + + create_info->pfnCallback = DebugReportCallback; + create_info->pUserData = nullptr; +} + +// static +StatusOr<std::unique_ptr<DebugReporter>> +DebugReporter::CreateDebugUtilsMessenger( + VkInstance instance, const ref_ptr<DynamicSymbols>& syms, + const VkAllocationCallbacks* allocation_callbacks) { + IREE_TRACE_SCOPE0("DebugReporter::CreateDebugUtilsMessenger"); + + auto debug_reporter = + absl::WrapUnique(new DebugReporter(instance, syms, allocation_callbacks)); + + VkDebugUtilsMessengerCreateInfoEXT create_info; + PopulateStaticCreateInfo(&create_info); + create_info.pUserData = debug_reporter.get(); + + VK_RETURN_IF_ERROR(syms->vkCreateDebugUtilsMessengerEXT( + instance, &create_info, allocation_callbacks, + &debug_reporter->messenger_)); + + return debug_reporter; +} + +// static +StatusOr<std::unique_ptr<DebugReporter>> +DebugReporter::CreateDebugReportCallback( + VkInstance instance, const ref_ptr<DynamicSymbols>& syms, + const VkAllocationCallbacks* allocation_callbacks) { + IREE_TRACE_SCOPE0("DebugReporter::CreateDebugReportCallback"); + + auto debug_reporter = + absl::WrapUnique(new DebugReporter(instance, syms, allocation_callbacks)); + + VkDebugReportCallbackCreateInfoEXT create_info; + PopulateStaticCreateInfo(&create_info); + create_info.pUserData = debug_reporter.get(); + + VK_RETURN_IF_ERROR(syms->vkCreateDebugReportCallbackEXT( + instance, &create_info, allocation_callbacks, + &debug_reporter->callback_)); + + return debug_reporter; +} + +DebugReporter::DebugReporter(VkInstance instance, + const ref_ptr<DynamicSymbols>& syms, + const VkAllocationCallbacks* allocation_callbacks) + : instance_(instance), + syms_(add_ref(syms)), + allocation_callbacks_(allocation_callbacks) {} + +DebugReporter::~DebugReporter() { + IREE_TRACE_SCOPE0("DebugReporter::dtor"); + if (messenger_ != VK_NULL_HANDLE) { + syms_->vkDestroyDebugUtilsMessengerEXT(instance_, messenger_, + allocation_callbacks_); + } + if (callback_ != VK_NULL_HANDLE) { + syms_->vkDestroyDebugReportCallbackEXT(instance_, callback_, + allocation_callbacks_); + } +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/debug_reporter.h b/hal/vulkan/debug_reporter.h new file mode 100644 index 0000000..2eff99d --- /dev/null +++ b/hal/vulkan/debug_reporter.h
@@ -0,0 +1,87 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_ +#define IREE_HAL_VULKAN_DEBUG_REPORTER_H_ + +#include <vulkan/vulkan.h> + +#include "base/status.h" +#include "hal/vulkan/dynamic_symbols.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// A debug reporter that works with the VK_EXT_debug_utils extension. +// One reporter should be created per VkInstance to receive callbacks from the +// API and route them to our logging systems. In general VK_EXT_debug_utils +// should be preferred if available as it provides a much cleaner interface and +// more plug-points than VK_EXT_debug_report. +// +// Since creating a reporter requires a VkInstance it's not possible to report +// on messages during instance creation. To work around this it's possible to +// pass a *CreateInfo struct to vkCreateInstance as part of the +// VkInstanceCreateInfo::pNext chain. The callback will only be used this way +// during the creation call after which users can create the real +// instance-specific reporter. +class DebugReporter final { + public: + // Populates |create_info| with an instance-agnostic callback. + // This can be used during instance creation by chaining the |create_info| to + // VkInstanceCreateInfo::pNext. + // + // Only use if VK_EXT_debug_utils is present. + static void PopulateStaticCreateInfo( + VkDebugUtilsMessengerCreateInfoEXT* create_info); + + // Populates |create_info| with an instance-agnostic callback. + // This can be used during instance creation by chaining the |create_info| to + // VkInstanceCreateInfo::pNext. + // + // Only use if VK_EXT_debug_report is present. + static void PopulateStaticCreateInfo( + VkDebugReportCallbackCreateInfoEXT* create_info); + + // Creates a debug messenger for the given Vulkan |instance| with + // VK_EXT_debug_utils enabled. + static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugUtilsMessenger( + VkInstance instance, const ref_ptr<DynamicSymbols>& syms, + const VkAllocationCallbacks* allocation_callbacks); + + // Creates a debug report callback for the given Vulkan |instance| with + // VK_EXT_debug_report enabled. + static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugReportCallback( + VkInstance instance, const ref_ptr<DynamicSymbols>& syms, + const VkAllocationCallbacks* allocation_callbacks); + + ~DebugReporter(); + + private: + DebugReporter(VkInstance instance, const ref_ptr<DynamicSymbols>& syms, + const VkAllocationCallbacks* allocation_callbacks); + + VkInstance instance_ = VK_NULL_HANDLE; + ref_ptr<DynamicSymbols> syms_; + const VkAllocationCallbacks* allocation_callbacks_ = nullptr; + + VkDebugUtilsMessengerEXT messenger_ = VK_NULL_HANDLE; + VkDebugReportCallbackEXT callback_ = VK_NULL_HANDLE; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_DEBUG_REPORTER_H_
diff --git a/hal/vulkan/descriptor_pool_cache.cc b/hal/vulkan/descriptor_pool_cache.cc new file mode 100644 index 0000000..453fe78 --- /dev/null +++ b/hal/vulkan/descriptor_pool_cache.cc
@@ -0,0 +1,102 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/descriptor_pool_cache.h" + +#include <array> + +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +// TODO(benvanik): be more conservative with descriptor set count or allow +// chaining in the command buffer when pools run out. +static constexpr int kMaxDescriptorSets = 4096; + +} // namespace + +DescriptorSetGroup::~DescriptorSetGroup() { + CHECK(descriptor_pools_.empty()) + << "DescriptorSetGroup must be reset explicitly"; +} + +Status DescriptorSetGroup::Reset() { + IREE_TRACE_SCOPE0("DescriptorSetGroup::Reset"); + + RETURN_IF_ERROR(descriptor_pool_cache_->ReleaseDescriptorPools( + absl::MakeSpan(descriptor_pools_))); + descriptor_pools_.clear(); + + return OkStatus(); +} + +DescriptorPoolCache::DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device) + : logical_device_(std::move(logical_device)) {} + +StatusOr<DescriptorPool> DescriptorPoolCache::AcquireDescriptorPool( + VkDescriptorType descriptor_type, int max_descriptor_count) { + IREE_TRACE_SCOPE0("DescriptorPoolCache::AcquireDescriptorPool"); + + // TODO(benvanik): lookup in cache. + + VkDescriptorPoolCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + create_info.maxSets = kMaxDescriptorSets; + std::array<VkDescriptorPoolSize, 1> pool_sizes; + pool_sizes[0].type = descriptor_type; + pool_sizes[0].descriptorCount = max_descriptor_count; + create_info.poolSizeCount = pool_sizes.size(); + create_info.pPoolSizes = pool_sizes.data(); + + DescriptorPool descriptor_pool; + descriptor_pool.descriptor_type = descriptor_type; + descriptor_pool.max_descriptor_count = max_descriptor_count; + descriptor_pool.handle = VK_NULL_HANDLE; + + VK_RETURN_IF_ERROR(syms().vkCreateDescriptorPool( + *logical_device_, &create_info, logical_device_->allocator(), + &descriptor_pool.handle)); + + return descriptor_pool; +} + +Status DescriptorPoolCache::ReleaseDescriptorPools( + absl::Span<DescriptorPool> descriptor_pools) { + IREE_TRACE_SCOPE0("DescriptorPoolCache::ReleaseDescriptorPools"); + + for (const auto& descriptor_pool : descriptor_pools) { + // Always reset immediately. We could do this on allocation instead however + // this leads to better errors when using the validation layers as we'll + // throw if there are in-flight command buffers using the sets in the pool. + VK_RETURN_IF_ERROR(syms().vkResetDescriptorPool(*logical_device_, + descriptor_pool.handle, 0)); + + // TODO(benvanik): release to cache. + syms().vkDestroyDescriptorPool(*logical_device_, descriptor_pool.handle, + logical_device_->allocator()); + } + + return OkStatus(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/descriptor_pool_cache.h b/hal/vulkan/descriptor_pool_cache.h new file mode 100644 index 0000000..8e54112 --- /dev/null +++ b/hal/vulkan/descriptor_pool_cache.h
@@ -0,0 +1,103 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_ +#define IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_ + +#include "absl/container/inlined_vector.h" +#include "base/ref_ptr.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class DescriptorPoolCache; + +// A descriptor pool with a single descriptor type of some number. +// We only support a single descriptor type for now as we only generate SPIR-V +// that uses a single type. +struct DescriptorPool { + // Type of the descriptor in the set. + VkDescriptorType descriptor_type = VK_DESCRIPTOR_TYPE_MAX_ENUM; + // Maximum number of descriptors of the given type per allocation. + int max_descriptor_count = 0; + // Pool handle. + VkDescriptorPool handle = VK_NULL_HANDLE; +}; + +// A group of descriptor sets allocated and released together. +// The group must be explicitly reset with Reset() prior to disposing. +class DescriptorSetGroup final { + public: + DescriptorSetGroup() = default; + DescriptorSetGroup(ref_ptr<DescriptorPoolCache> descriptor_pool_cache, + absl::InlinedVector<DescriptorPool, 8> descriptor_pools) + : descriptor_pool_cache_(std::move(descriptor_pool_cache)), + descriptor_pools_(std::move(descriptor_pools)) {} + DescriptorSetGroup(const DescriptorSetGroup&) = delete; + DescriptorSetGroup& operator=(const DescriptorSetGroup&) = delete; + DescriptorSetGroup(DescriptorSetGroup&& other) noexcept + : descriptor_pool_cache_(std::move(other.descriptor_pool_cache_)), + descriptor_pools_(std::move(other.descriptor_pools_)) {} + DescriptorSetGroup& operator=(DescriptorSetGroup&& other) { + std::swap(descriptor_pool_cache_, other.descriptor_pool_cache_); + std::swap(descriptor_pools_, other.descriptor_pools_); + return *this; + } + ~DescriptorSetGroup(); + + Status Reset(); + + private: + ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; + absl::InlinedVector<DescriptorPool, 8> descriptor_pools_; +}; + +// A "cache" (or really, pool) of descriptor pools. These pools are allocated +// as needed to satisfy different descriptor size requirements and are given +// to command buffers during recording to write descriptor updates and bind +// resources. After the descriptors in the pool are no longer used (all +// command buffers using descriptor sets allocated from the pool have retired) +// the pool is returned here to be reused in the future. +class DescriptorPoolCache final : public RefObject<DescriptorPoolCache> { + public: + explicit DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device); + + const ref_ptr<VkDeviceHandle>& logical_device() const { + return logical_device_; + } + const DynamicSymbols& syms() const { return *logical_device_->syms(); } + + // Acquires a new descriptor pool for use by the caller. + // The pool will have been reset and have all descriptor sets available. + // When all sets allocated from the pool are no longer in use it must be + // returned to the cache with ReleaseDescriptorPool. + StatusOr<DescriptorPool> AcquireDescriptorPool( + VkDescriptorType descriptor_type, int max_descriptor_count); + + // Releases descriptor pools back to the cache. The pools will be reset + // immediately and must no longer be in use by any in-flight command. + Status ReleaseDescriptorPools(absl::Span<DescriptorPool> descriptor_pools); + + private: + ref_ptr<VkDeviceHandle> logical_device_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
diff --git a/hal/vulkan/descriptor_set_arena.cc b/hal/vulkan/descriptor_set_arena.cc new file mode 100644 index 0000000..b31b583 --- /dev/null +++ b/hal/vulkan/descriptor_set_arena.cc
@@ -0,0 +1,204 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/descriptor_set_arena.h" + +#include "base/math.h" +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" +#include "hal/vulkan/vma_buffer.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) { + // TODO(benvanik): assert that the buffer is from the right allocator and + // that it is compatible with our target queue family. + return static_cast<VmaBuffer*>(buffer->allocated_buffer()); +} + +StatusOr<absl::Span<VkWriteDescriptorSet>> PopulateDescriptorSetWriteInfos( + const PipelineDescriptorSets& pipeline_descriptor_sets, + absl::Span<const BufferBinding> bindings, VkDescriptorSet dst_set, + Arena* arena) { + int required_descriptor_count = + pipeline_descriptor_sets.buffer_binding_set_map.size(); + + arena->Reset(); + auto buffer_infos = + arena->AllocateSpan<VkDescriptorBufferInfo>(required_descriptor_count); + auto write_infos = arena->AllocateSpan<VkWriteDescriptorSet>(bindings.size()); + + for (int i = 0; i < bindings.size(); ++i) { + const auto& binding = bindings[i]; + + auto& buffer_info = buffer_infos[i]; + ASSIGN_OR_RETURN(auto buffer, CastBuffer(binding.buffer)); + buffer_info.buffer = buffer->handle(); + // TODO(benvanik): properly subrange (add to BufferBinding). + buffer_info.offset = binding.buffer->byte_offset(); + buffer_info.range = binding.buffer->byte_length(); + + auto& write_info = write_infos[i]; + write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + write_info.pNext = nullptr; + write_info.dstSet = dst_set; + write_info.dstBinding = pipeline_descriptor_sets.buffer_binding_set_map[i]; + write_info.dstArrayElement = 0; + write_info.descriptorCount = 1; + write_info.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + write_info.pImageInfo = nullptr; + write_info.pBufferInfo = &buffer_info; + write_info.pTexelBufferView = nullptr; + } + + return write_infos; +} + +} // namespace + +DescriptorSetArena::DescriptorSetArena( + ref_ptr<DescriptorPoolCache> descriptor_pool_cache) + : logical_device_(add_ref(descriptor_pool_cache->logical_device())), + descriptor_pool_cache_(std::move(descriptor_pool_cache)) {} + +DescriptorSetArena::~DescriptorSetArena() { + if (!used_descriptor_pools_.empty()) { + descriptor_pool_cache_ + ->ReleaseDescriptorPools(absl::MakeSpan(used_descriptor_pools_)) + .IgnoreError(); + used_descriptor_pools_.clear(); + } +} + +Status DescriptorSetArena::BindDescriptorSet( + VkCommandBuffer command_buffer, PipelineExecutable* executable, + absl::Span<const BufferBinding> bindings) { + // Always prefer using push descriptors when available as we can avoid the + // additional API overhead of updating/resetting pools. + if (logical_device_->enabled_extensions().push_descriptors) { + return PushDescriptorSet(command_buffer, executable, bindings); + } + + IREE_TRACE_SCOPE0("DescriptorSetArena::BindDescriptorSet"); + + // Pick a bucket based on the number of descriptors required. + // NOTE: right now we are 1:1 with bindings. + int required_descriptor_count = bindings.size() * 1; + int max_descriptor_count = + std::max(8, RoundUpToNearestPow2(required_descriptor_count)); + int bucket = TrailingZeros(max_descriptor_count >> 3); + if (bucket >= descriptor_pool_buckets_.size()) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Too many descriptors required: " << required_descriptor_count + << " (max=" << (1 << (descriptor_pool_buckets_.size() + 3)) << ")"; + } + if (descriptor_pool_buckets_[bucket].handle == VK_NULL_HANDLE) { + // Acquire a pool for this max_descriptor_count bucket. + ASSIGN_OR_RETURN( + descriptor_pool_buckets_[bucket], + descriptor_pool_cache_->AcquireDescriptorPool( + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count)); + used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]); + } + auto& descriptor_pool = descriptor_pool_buckets_[bucket]; + + const auto& pipeline_descriptor_sets = executable->descriptor_sets(); + + VkDescriptorSetAllocateInfo allocate_info; + allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + allocate_info.pNext = nullptr; + allocate_info.descriptorPool = descriptor_pool.handle; + allocate_info.descriptorSetCount = 1; + allocate_info.pSetLayouts = + &pipeline_descriptor_sets.buffer_binding_set_layout; + VkDescriptorSet descriptor_set = VK_NULL_HANDLE; + VkResult result = syms().vkAllocateDescriptorSets( + *logical_device_, &allocate_info, &descriptor_set); + if (result == VK_ERROR_OUT_OF_POOL_MEMORY) { + // Allocation failed because the pool is either out of descriptors or too + // fragmented. We'll just allocate another pool. + ASSIGN_OR_RETURN( + descriptor_pool_buckets_[bucket], + descriptor_pool_cache_->AcquireDescriptorPool( + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count)); + used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]); + } + + // Get a list of VkWriteDescriptorSet structs with all bound buffers. + ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( + pipeline_descriptor_sets, bindings, + descriptor_set, &scratch_arena_)); + + // This is the reason why push descriptor sets are good. + // We can't batch these effectively as we don't know prior to recording what + // descriptor sets we will need and what buffers they will point to (without + // doing just as much work as actually recording the buffer to try to find + // out). + syms().vkUpdateDescriptorSets(*logical_device_, write_infos.size(), + write_infos.data(), 0, nullptr); + + // Bind the descriptor set. + syms().vkCmdBindDescriptorSets(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, + executable->pipeline_layout(), + pipeline_descriptor_sets.buffer_binding_set, 1, + &descriptor_set, 0, nullptr); + + return OkStatus(); +} + +Status DescriptorSetArena::PushDescriptorSet( + VkCommandBuffer command_buffer, PipelineExecutable* executable, + absl::Span<const BufferBinding> bindings) { + IREE_TRACE_SCOPE0("DescriptorSetArena::PushDescriptorSet"); + + const auto& pipeline_descriptor_sets = executable->descriptor_sets(); + + // Get a list of VkWriteDescriptorSet structs with all bound buffers. + ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( + pipeline_descriptor_sets, bindings, + VK_NULL_HANDLE, &scratch_arena_)); + + // Fast path using push descriptors. These are pooled internally by the + // command buffer and prevent the need for our own pooling mechanisms. + syms().vkCmdPushDescriptorSetKHR(command_buffer, + VK_PIPELINE_BIND_POINT_COMPUTE, + executable->pipeline_layout(), + pipeline_descriptor_sets.buffer_binding_set, + write_infos.size(), write_infos.data()); + + return OkStatus(); +} + +StatusOr<DescriptorSetGroup> DescriptorSetArena::Flush() { + IREE_TRACE_SCOPE0("DescriptorSetArena::Flush"); + + if (used_descriptor_pools_.empty()) { + // No resources to free. + return DescriptorSetGroup{}; + } + + for (auto& bucket : descriptor_pool_buckets_) { + bucket = {}; + } + return DescriptorSetGroup(add_ref(descriptor_pool_cache_), + std::move(used_descriptor_pools_)); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/descriptor_set_arena.h b/hal/vulkan/descriptor_set_arena.h new file mode 100644 index 0000000..5e07268 --- /dev/null +++ b/hal/vulkan/descriptor_set_arena.h
@@ -0,0 +1,76 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_ +#define IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_ + +#include <array> +#include <vector> + +#include "base/arena.h" +#include "base/status.h" +#include "hal/command_buffer.h" +#include "hal/vulkan/descriptor_pool_cache.h" +#include "hal/vulkan/pipeline_executable.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// A reusable arena for allocating descriptor sets and batching updates. +class DescriptorSetArena final { + public: + explicit DescriptorSetArena( + ref_ptr<DescriptorPoolCache> descriptor_pool_cache); + ~DescriptorSetArena(); + + // Allocates and binds a descriptor set from the arena. + // The command buffer will have the descriptor set containing |bindings| bound + // to it. + Status BindDescriptorSet(VkCommandBuffer command_buffer, + PipelineExecutable* executable, + absl::Span<const BufferBinding> bindings); + + // Flushes all pending writes to descriptor sets allocated from the arena and + // returns a group that - when dropped - will release the descriptor sets + // back to the pools they were allocated from. + StatusOr<DescriptorSetGroup> Flush(); + + private: + const DynamicSymbols& syms() const { return *logical_device_->syms(); } + + // Pushes the descriptor set to the command buffer, if supported. + Status PushDescriptorSet(VkCommandBuffer command_buffer, + PipelineExecutable* executable, + absl::Span<const BufferBinding> bindings); + + ref_ptr<VkDeviceHandle> logical_device_; + ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; + + // Arena used for temporary binding information used during allocation. + Arena scratch_arena_; + + // A list of pools acquired on demand as different descriptor counts are + // needed. Allocation granularity is max_descriptor_count=[8, 16, 32, 64]. + std::array<DescriptorPool, 4> descriptor_pool_buckets_; + + // All pools that have been used during allocation. + absl::InlinedVector<DescriptorPool, 8> used_descriptor_pools_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
diff --git a/hal/vulkan/direct_command_buffer.cc b/hal/vulkan/direct_command_buffer.cc new file mode 100644 index 0000000..13398f7 --- /dev/null +++ b/hal/vulkan/direct_command_buffer.cc
@@ -0,0 +1,403 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/direct_command_buffer.h" + +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "base/math.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +VkPipelineStageFlags ConvertPipelineStageFlags( + ExecutionStageBitfield stage_mask) { + VkPipelineStageFlags flags = 0; + flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandIssue) + ? VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT + : 0; + flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandProcess) + ? VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT + : 0; + flags |= AnyBitSet(stage_mask & ExecutionStage::kDispatch) + ? VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT + : 0; + flags |= AnyBitSet(stage_mask & ExecutionStage::kTransfer) + ? VK_PIPELINE_STAGE_TRANSFER_BIT + : 0; + flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandRetire) + ? VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT + : 0; + flags |= AnyBitSet(stage_mask & ExecutionStage::kHost) + ? VK_PIPELINE_STAGE_HOST_BIT + : 0; + return flags; +} + +VkAccessFlags ConvertAccessMask(AccessScopeBitfield access_mask) { + VkAccessFlags flags = 0; + flags |= AnyBitSet(access_mask & AccessScope::kIndirectCommandRead) + ? VK_ACCESS_INDIRECT_COMMAND_READ_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kConstantRead) + ? VK_ACCESS_UNIFORM_READ_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kDispatchRead) + ? VK_ACCESS_SHADER_READ_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kDispatchWrite) + ? VK_ACCESS_SHADER_WRITE_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kTransferRead) + ? VK_ACCESS_TRANSFER_READ_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kTransferWrite) + ? VK_ACCESS_TRANSFER_WRITE_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kHostRead) + ? VK_ACCESS_HOST_READ_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kHostWrite) + ? VK_ACCESS_HOST_WRITE_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kMemoryRead) + ? VK_ACCESS_MEMORY_READ_BIT + : 0; + flags |= AnyBitSet(access_mask & AccessScope::kMemoryWrite) + ? VK_ACCESS_MEMORY_WRITE_BIT + : 0; + return flags; +} + +// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value. +uint32_t SplatPattern(const void* pattern, size_t pattern_length) { + switch (pattern_length) { + case 1: { + uint32_t pattern_value = *static_cast<const uint8_t*>(pattern); + return (pattern_value << 24) | (pattern_value << 16) | + (pattern_value << 8) | pattern_value; + } + case 2: { + uint32_t pattern_value = *static_cast<const uint16_t*>(pattern); + return (pattern_value << 16) | pattern_value; + } + case 4: { + uint32_t pattern_value = *static_cast<const uint32_t*>(pattern); + return pattern_value; + } + default: + return 0; // Already verified that this should not be possible. + } +} + +} // namespace + +DirectCommandBuffer::DirectCommandBuffer( + Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories, + ref_ptr<DescriptorPoolCache> descriptor_pool_cache, + ref_ptr<VkCommandPoolHandle> command_pool, VkCommandBuffer command_buffer) + : CommandBuffer(allocator, mode, command_categories), + command_pool_(std::move(command_pool)), + command_buffer_(command_buffer), + descriptor_set_arena_(std::move(descriptor_pool_cache)) {} + +DirectCommandBuffer::~DirectCommandBuffer() { + IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor"); + descriptor_set_group_.Reset().IgnoreError(); + absl::MutexLock lock(command_pool_->mutex()); + syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_, + 1, &command_buffer_); +} + +StatusOr<NativeEvent*> DirectCommandBuffer::CastEvent(Event* event) const { + // TODO(benvanik): assert the event is valid. + return static_cast<NativeEvent*>(event); +} + +StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer(Buffer* buffer) const { + // TODO(benvanik): assert that the buffer is from the right allocator and + // that it is compatible with our target queue family. + return static_cast<VmaBuffer*>(buffer->allocated_buffer()); +} + +Status DirectCommandBuffer::Begin() { + IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin"); + + is_recording_ = true; + + // NOTE: we require that command buffers not be recorded while they are + // in-flight so this is safe. + RETURN_IF_ERROR(descriptor_set_group_.Reset()); + + VkCommandBufferBeginInfo begin_info; + begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + begin_info.pNext = nullptr; + begin_info.flags = AllBitsSet(mode(), CommandBufferMode::kOneShot) + ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT + : 0; + begin_info.pInheritanceInfo = nullptr; + VK_RETURN_IF_ERROR( + syms()->vkBeginCommandBuffer(command_buffer_, &begin_info)); + + return OkStatus(); +} + +Status DirectCommandBuffer::End() { + IREE_TRACE_SCOPE0("DirectCommandBuffer::End"); + + VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_)); + + // Flush all pending descriptor set writes (if any). + ASSIGN_OR_RETURN(descriptor_set_group_, descriptor_set_arena_.Flush()); + + is_recording_ = false; + + return OkStatus(); +} + +Status DirectCommandBuffer::ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier"); + + absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( + memory_barriers.size()); + for (int i = 0; i < memory_barriers.size(); ++i) { + const auto& memory_barrier = memory_barriers[i]; + auto& info = memory_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + info.pNext = nullptr; + info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); + info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); + } + + absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( + buffer_barriers.size()); + for (int i = 0; i < buffer_barriers.size(); ++i) { + const auto& buffer_barrier = buffer_barriers[i]; + auto& info = buffer_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; + info.pNext = nullptr; + info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); + info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); + info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer)); + info.buffer = device_buffer->handle(); + info.offset = buffer_barrier.offset; + info.size = buffer_barrier.length; + } + + syms()->vkCmdPipelineBarrier( + command_buffer_, ConvertPipelineStageFlags(source_stage_mask), + ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0, + memory_barrier_infos.size(), memory_barrier_infos.data(), + buffer_barrier_infos.size(), buffer_barrier_infos.data(), 0, nullptr); + + return OkStatus(); +} + +Status DirectCommandBuffer::SignalEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent"); + ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); + syms()->vkCmdSetEvent(command_buffer_, device_event->handle(), + ConvertPipelineStageFlags(source_stage_mask)); + return OkStatus(); +} + +Status DirectCommandBuffer::ResetEvent( + Event* event, ExecutionStageBitfield source_stage_mask) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent"); + ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); + syms()->vkCmdResetEvent(command_buffer_, device_event->handle(), + ConvertPipelineStageFlags(source_stage_mask)); + return OkStatus(); +} + +Status DirectCommandBuffer::WaitEvents( + absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents"); + + absl::InlinedVector<VkEvent, 4> event_handles(events.size()); + for (int i = 0; i < events.size(); ++i) { + ASSIGN_OR_RETURN(auto* device_event, CastEvent(events[i])); + event_handles[i] = device_event->handle(); + } + + absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( + memory_barriers.size()); + for (int i = 0; i < memory_barriers.size(); ++i) { + const auto& memory_barrier = memory_barriers[i]; + auto& info = memory_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + info.pNext = nullptr; + info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); + info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); + } + + absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( + buffer_barriers.size()); + for (int i = 0; i < buffer_barriers.size(); ++i) { + const auto& buffer_barrier = buffer_barriers[i]; + auto& info = buffer_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; + info.pNext = nullptr; + info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); + info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); + info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer)); + info.buffer = device_buffer->handle(); + info.offset = buffer_barrier.offset; + info.size = buffer_barrier.length; + } + + syms()->vkCmdWaitEvents( + command_buffer_, event_handles.size(), event_handles.data(), + ConvertPipelineStageFlags(source_stage_mask), + ConvertPipelineStageFlags(target_stage_mask), memory_barrier_infos.size(), + memory_barrier_infos.data(), buffer_barrier_infos.size(), + buffer_barrier_infos.data(), 0, nullptr); + return OkStatus(); +} + +Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer, + device_size_t target_offset, + device_size_t length, + const void* pattern, + size_t pattern_length) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer"); + ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); + + // Note that fill only accepts 4-byte aligned values so we need to splat out + // our variable-length pattern. + target_offset += target_buffer->byte_offset(); + uint32_t dword_pattern = SplatPattern(pattern, pattern_length); + syms()->vkCmdFillBuffer(command_buffer_, target_device_buffer->handle(), + target_offset, length, dword_pattern); + + return OkStatus(); +} + +Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer"); + // NOTE: we could use this to prevent queue family transitions. + return OkStatus(); +} + +Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer"); + ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); + + // Vulkan only allows updates of <= 65536 because you really, really, really + // shouldn't do large updates like this (as it wastes command buffer space and + // may be slower than just using write-through mapped memory). The + // recommendation in the spec for larger updates is to split the single update + // into multiple updates over the entire desired range. + const auto* source_buffer_ptr = static_cast<const uint8_t*>(source_buffer); + target_offset += target_buffer->byte_offset(); + while (length > 0) { + device_size_t chunk_length = + std::min(static_cast<device_size_t>(65536u), length); + syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(), + target_offset, chunk_length, source_buffer_ptr); + source_buffer_ptr += chunk_length; + target_offset += chunk_length; + length -= chunk_length; + } + + return OkStatus(); +} + +Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer, + device_size_t source_offset, + Buffer* target_buffer, + device_size_t target_offset, + device_size_t length) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer"); + ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer)); + ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); + + VkBufferCopy region; + region.srcOffset = source_buffer->byte_offset() + source_offset; + region.dstOffset = target_buffer->byte_offset() + target_offset; + region.size = length; + syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(), + target_device_buffer->handle(), 1, ®ion); + + return OkStatus(); +} + +Status DirectCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) { + IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch"); + + // Get the compiled and linked pipeline for the specified entry point and + // bind it to the command buffer. + auto* executable = + static_cast<PipelineExecutable*>(dispatch_request.executable); + ASSIGN_OR_RETURN(VkPipeline pipeline, executable->GetPipelineForEntryPoint( + dispatch_request.entry_point)); + syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline); + + // Either allocate, update, and bind a descriptor set or use push descriptor + // sets to use the command buffer pool when supported. + RETURN_IF_ERROR(descriptor_set_arena_.BindDescriptorSet( + command_buffer_, executable, dispatch_request.bindings)); + + // TODO(benvanik): divide workload by caps and issue multiple dispatches. + // TODO(benvanik): track local workgroup/subgroup size and divide into groups. + if (dispatch_request.workload_buffer) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Dynamic dispatches not yet implemented"; + } + uint32_t group_count_x = dispatch_request.workload[0]; + uint32_t group_count_y = dispatch_request.workload[1]; + uint32_t group_count_z = dispatch_request.workload[2]; + + // TODO(GH-67): pre-divide workload by tile size. + if (executable->is_matmul()) { + group_count_x = (group_count_x + 16 - 1) / 16; + group_count_y = (group_count_y + 16 - 1) / 16; + group_count_z = 1; + } + + syms()->vkCmdDispatch(command_buffer_, group_count_x, group_count_y, + group_count_z); + + return OkStatus(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/direct_command_buffer.h b/hal/vulkan/direct_command_buffer.h new file mode 100644 index 0000000..9433dd3 --- /dev/null +++ b/hal/vulkan/direct_command_buffer.h
@@ -0,0 +1,103 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ +#define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ + +#include <vulkan/vulkan.h> + +#include "hal/command_buffer.h" +#include "hal/vulkan/descriptor_pool_cache.h" +#include "hal/vulkan/descriptor_set_arena.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/handle_util.h" +#include "hal/vulkan/native_event.h" +#include "hal/vulkan/pipeline_executable.h" +#include "hal/vulkan/vma_buffer.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// Command buffer implementation that directly maps to VkCommandBuffer. +// This records the commands on the calling thread without additional threading +// indirection. +class DirectCommandBuffer final : public CommandBuffer { + public: + DirectCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories, + ref_ptr<DescriptorPoolCache> descriptor_pool_cache, + ref_ptr<VkCommandPoolHandle> command_pool, + VkCommandBuffer command_buffer); + ~DirectCommandBuffer() override; + + VkCommandBuffer handle() const { return command_buffer_; } + + bool is_recording() const override { return is_recording_; } + + Status Begin() override; + Status End() override; + + Status ExecutionBarrier( + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + Status SignalEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + Status ResetEvent(Event* event, + ExecutionStageBitfield source_stage_mask) override; + Status WaitEvents(absl::Span<Event*> events, + ExecutionStageBitfield source_stage_mask, + ExecutionStageBitfield target_stage_mask, + absl::Span<const MemoryBarrier> memory_barriers, + absl::Span<const BufferBarrier> buffer_barriers) override; + + Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, + device_size_t length, const void* pattern, + size_t pattern_length) override; + Status DiscardBuffer(Buffer* buffer) override; + Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, + Buffer* target_buffer, device_size_t target_offset, + device_size_t length) override; + + Status Dispatch(const DispatchRequest& dispatch_request) override; + + private: + const ref_ptr<DynamicSymbols>& syms() const { return command_pool_->syms(); } + + StatusOr<NativeEvent*> CastEvent(Event* event) const; + StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) const; + + bool is_recording_ = false; + ref_ptr<VkCommandPoolHandle> command_pool_; + VkCommandBuffer command_buffer_; + + // TODO(b/140026716): may grow large - should try to reclaim or reuse. + DescriptorSetArena descriptor_set_arena_; + + // The current descriptor set group in use by the command buffer, if any. + // This must remain valid until all in-flight submissions of the command + // buffer complete. + DescriptorSetGroup descriptor_set_group_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
diff --git a/hal/vulkan/direct_command_queue.cc b/hal/vulkan/direct_command_queue.cc new file mode 100644 index 0000000..c2c325e --- /dev/null +++ b/hal/vulkan/direct_command_queue.cc
@@ -0,0 +1,201 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/direct_command_queue.h" + +#include <cstdint> + +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "base/memory.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/vulkan/direct_command_buffer.h" +#include "hal/vulkan/legacy_fence.h" +#include "hal/vulkan/native_binary_semaphore.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +DirectCommandQueue::DirectCommandQueue( + std::string name, CommandCategoryBitfield supported_categories, + const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue) + : CommandQueue(std::move(name), supported_categories), + logical_device_(add_ref(logical_device)), + queue_(queue) {} + +DirectCommandQueue::~DirectCommandQueue() { + IREE_TRACE_SCOPE0("DirectCommandQueue::dtor"); + absl::MutexLock lock(&queue_mutex_); + syms()->vkQueueWaitIdle(queue_); +} + +Status DirectCommandQueue::TranslateBatchInfo(const SubmissionBatch& batch, + VkSubmitInfo* submit_info, + Arena* arena) { + // TODO(benvanik): see if we can go to finer-grained stages. + // For example, if this was just queue ownership transfers then we can use + // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT. + VkPipelineStageFlags dst_stage_mask = + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; + + auto wait_semaphore_handles = + arena->AllocateSpan<VkSemaphore>(batch.wait_semaphores.size()); + auto wait_dst_stage_masks = + arena->AllocateSpan<VkPipelineStageFlags>(batch.wait_semaphores.size()); + for (int i = 0; i < batch.wait_semaphores.size(); ++i) { + const auto& semaphore_value = batch.wait_semaphores[i]; + if (semaphore_value.index() == 0) { + const auto& binary_semaphore = + static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value)); + wait_semaphore_handles[i] = binary_semaphore->handle(); + } else { + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) + << "Timeline semaphores not yet implemented"; + } + wait_dst_stage_masks[i] = dst_stage_mask; + } + + auto signal_semaphore_handles = + arena->AllocateSpan<VkSemaphore>(batch.signal_semaphores.size()); + for (int i = 0; i < batch.signal_semaphores.size(); ++i) { + const auto& semaphore_value = batch.signal_semaphores[i]; + if (semaphore_value.index() == 0) { + const auto& binary_semaphore = + static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value)); + signal_semaphore_handles[i] = binary_semaphore->handle(); + } else { + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) + << "Timeline semaphores not yet implemented"; + } + } + + auto command_buffer_handles = + arena->AllocateSpan<VkCommandBuffer>(batch.command_buffers.size()); + for (int i = 0; i < batch.command_buffers.size(); ++i) { + const auto& command_buffer = batch.command_buffers[i]; + auto* direct_command_buffer = + static_cast<DirectCommandBuffer*>(command_buffer->impl()); + command_buffer_handles[i] = direct_command_buffer->handle(); + } + + submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submit_info->pNext = nullptr; + submit_info->waitSemaphoreCount = wait_semaphore_handles.size(); + submit_info->pWaitSemaphores = wait_semaphore_handles.data(); + submit_info->pWaitDstStageMask = wait_dst_stage_masks.data(); + submit_info->commandBufferCount = command_buffer_handles.size(); + submit_info->pCommandBuffers = command_buffer_handles.data(); + submit_info->signalSemaphoreCount = signal_semaphore_handles.size(); + submit_info->pSignalSemaphores = signal_semaphore_handles.data(); + + return OkStatus(); +} + +Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches, + FenceValue fence) { + IREE_TRACE_SCOPE0("DirectCommandQueue::Submit"); + + // Map the submission batches to VkSubmitInfos. + // Note that we must keep all arrays referenced alive until submission + // completes and since there are a bunch of them we use an arena. + Arena arena(4 * 1024); + auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batches.size()); + for (int i = 0; i < batches.size(); ++i) { + RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i], &arena)); + } + + // TODO(b/140141417): implement timeline semaphore fences and switch here. + auto legacy_fence = reinterpret_cast<LegacyFence*>(fence.first); + ASSIGN_OR_RETURN(VkFence fence_handle, + legacy_fence->AcquireSignalFence(fence.second)); + + { + absl::MutexLock lock(&queue_mutex_); + VK_RETURN_IF_ERROR(syms()->vkQueueSubmit( + queue_, submit_infos.size(), submit_infos.data(), fence_handle)); + } + + return OkStatus(); +} + +Status DirectCommandQueue::WaitIdle(absl::Time deadline) { + if (deadline == absl::InfiniteFuture()) { + // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it + // requires fewer calls into the driver). + IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle"); + absl::MutexLock lock(&queue_mutex_); + VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); + return OkStatus(); + } + + IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence"); + + // Create a new fence just for this wait. This keeps us thread-safe as the + // behavior of wait+reset is racey. + VkFenceCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + VkFence fence = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(syms()->vkCreateFence( + *logical_device_, &create_info, logical_device_->allocator(), &fence)); + auto fence_cleanup = MakeCleanup([this, fence]() { + syms()->vkDestroyFence(*logical_device_, fence, + logical_device_->allocator()); + }); + + uint64_t timeout; + if (deadline == absl::InfinitePast()) { + // Do not wait. + timeout = 0; + } else if (deadline == absl::InfiniteFuture()) { + // Wait forever. + timeout = UINT64_MAX; + } else { + // Convert to relative time in nanoseconds. + // The implementation may not wait with this granularity (like, by 10000x). + absl::Time now = absl::Now(); + if (deadline < now) { + return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline in the past"; + } + timeout = static_cast<uint64_t>(absl::ToInt64Nanoseconds(deadline - now)); + } + + { + absl::MutexLock lock(&queue_mutex_); + VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence)); + } + + VkResult result = + syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout); + switch (result) { + case VK_SUCCESS: + return OkStatus(); + case VK_TIMEOUT: + return DeadlineExceededErrorBuilder(IREE_LOC) + << "Deadline exceeded waiting for idle"; + default: + return VkResultToStatus(result); + } +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/direct_command_queue.h b/hal/vulkan/direct_command_queue.h new file mode 100644 index 0000000..85b6f60 --- /dev/null +++ b/hal/vulkan/direct_command_queue.h
@@ -0,0 +1,69 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_ +#define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_ + +#include <vulkan/vulkan.h> + +#include <cstdint> +#include <string> + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "base/arena.h" +#include "base/status.h" +#include "hal/command_queue.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// Command queue implementation directly maps to VkQueue. +class DirectCommandQueue final : public CommandQueue { + public: + DirectCommandQueue(std::string name, + CommandCategoryBitfield supported_categories, + const ref_ptr<VkDeviceHandle>& logical_device, + VkQueue queue); + ~DirectCommandQueue() override; + + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + Status Submit(absl::Span<const SubmissionBatch> batches, + FenceValue fence) override; + + Status WaitIdle(absl::Time deadline) override; + + private: + Status TranslateBatchInfo(const SubmissionBatch& batch, + VkSubmitInfo* submit_info, Arena* arena); + + ref_ptr<VkDeviceHandle> logical_device_; + + // VkQueue needs to be externally synchronized. + mutable absl::Mutex queue_mutex_; + VkQueue queue_ ABSL_GUARDED_BY(queue_mutex_); +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/dynamic_symbol_tables.h b/hal/vulkan/dynamic_symbol_tables.h similarity index 100% rename from iree/hal/vulkan/dynamic_symbol_tables.h rename to hal/vulkan/dynamic_symbol_tables.h
diff --git a/hal/vulkan/dynamic_symbols.cc b/hal/vulkan/dynamic_symbols.cc new file mode 100644 index 0000000..b4823f6 --- /dev/null +++ b/hal/vulkan/dynamic_symbols.cc
@@ -0,0 +1,238 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/dynamic_symbols.h" + +#include <cstddef> +#include <cstdlib> + +#include "absl/base/attributes.h" +#include "absl/base/macros.h" +#include "absl/memory/memory.h" +#include "base/file_path.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/target_platform.h" +#include "base/tracing.h" +#include "hal/vulkan/dynamic_symbol_tables.h" + +#if defined(IREE_PLATFORM_WINDOWS) +#include <windows.h> +#else +#include <dlfcn.h> +#endif // IREE_PLATFORM_WINDOWS + +namespace iree { +namespace hal { +namespace vulkan { + +// Read-only table of function pointer information designed to be in .rdata. +// To reduce binary size this structure is packed (knowing that we won't have +// gigabytes of function pointers :). +struct FunctionPtrInfo { + // Name of the function (like 'vkSomeFunction'). + const char* function_name; + // 1 if the function pointer can be resolved via vkGetDeviceProcAddr. + uint32_t is_device : 1; + // 1 if the function is required and the loader should bail if not found. + uint32_t is_required : 1; + // TODO(benvanik): remove from table by manually walking sizeof(uintptr_t). + // An offset in bytes from the base of &syms to where the PFN_vkSomeFunction + // member is located. + uint32_t member_offset : 30; +} ABSL_ATTRIBUTE_PACKED; + +namespace { + +#define REQUIRED_PFN_FUNCTION_PTR(function_name, is_device) \ + {#function_name, is_device, 1, offsetof(DynamicSymbols, function_name)}, +#define OPTIONAL_PFN_FUNCTION_PTR(function_name, is_device) \ + {#function_name, is_device, 0, offsetof(DynamicSymbols, function_name)}, +#define EXCLUDED_PFN_FUNCTION_PTR(function_name, is_device) +#define INS_PFN_FUNCTION_PTR(requirement, function_name) \ + requirement##_PFN_FUNCTION_PTR(function_name, 0) +#define DEV_PFN_FUNCTION_PTR(requirement, function_name) \ + requirement##_PFN_FUNCTION_PTR(function_name, 1) + +// Defines the table of mandatory FunctionPtrInfos resolved prior to instance +// creation. These are safe to call with no instance parameter and should be +// exported by all loaders/ICDs. +static constexpr const FunctionPtrInfo kInstancelessFunctionPtrInfos[] = { + REQUIRED_PFN_FUNCTION_PTR(vkCreateInstance, false) // + REQUIRED_PFN_FUNCTION_PTR(vkEnumerateInstanceLayerProperties, false) // + REQUIRED_PFN_FUNCTION_PTR(vkEnumerateInstanceExtensionProperties, false) // +}; + +// Defines the table of FunctionPtrInfos for dynamic loading that must wait +// until an instance has been created to be resolved. +static constexpr const FunctionPtrInfo kDynamicFunctionPtrInfos[] = { + IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(INS_PFN_FUNCTION_PTR, + DEV_PFN_FUNCTION_PTR)}; + +} // namespace + +// static +StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::Create( + const GetProcAddrFn& get_proc_addr) { + IREE_TRACE_SCOPE0("DynamicSymbols::Create"); + + auto syms = make_ref<DynamicSymbols>(); + + // Resolve the method the shared object uses to resolve other functions. + // Some libraries will export all symbols while others will only export this + // single function. + syms->vkGetInstanceProcAddr = reinterpret_cast<PFN_vkGetInstanceProcAddr>( + get_proc_addr("vkGetInstanceProcAddr")); + if (!syms->vkGetInstanceProcAddr) { + return UnavailableErrorBuilder(IREE_LOC) + << "Required method vkGetInstanceProcAddr not " + "found in provided Vulkan library (did you pick the wrong file?)"; + } + + // Resolve the mandatory functions that we need to create instances. + // If the provided |get_proc_addr| cannot resolve these then it's not a loader + // or ICD we want to use, anyway. + for (int i = 0; i < ABSL_ARRAYSIZE(kInstancelessFunctionPtrInfos); ++i) { + const auto& function_ptr = kInstancelessFunctionPtrInfos[i]; + auto* member_ptr = reinterpret_cast<PFN_vkVoidFunction*>( + reinterpret_cast<uint8_t*>(syms.get()) + function_ptr.member_offset); + *member_ptr = + syms->vkGetInstanceProcAddr(VK_NULL_HANDLE, function_ptr.function_name); + if (*member_ptr == nullptr) { + return UnavailableErrorBuilder(IREE_LOC) + << "Mandatory Vulkan function " << function_ptr.function_name + << " not available; invalid loader/ICD?"; + } + } + + return syms; +} + +// static +StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::CreateFromSystemLoader() { + IREE_TRACE_SCOPE0("DynamicSymbols::CreateFromSystemLoader"); + +#if defined(IREE_VK_ICD_FILENAMES) +#define IREE_STRINGIFY_(x) #x +#define IREE_STRING_(x) IREE_STRINGIFY_(x) + std::string vk_icd_filenames = IREE_STRING_(IREE_VK_ICD_FILENAMES); +#undef IREE_STRINGIFY_ +#undef IREE_STRING_ +#if defined(IREE_PLATFORM_WINDOWS) + // TODO(b/138220713): Set VK_ICD_FILENAMES on Windows +#else + ::setenv("VK_ICD_FILENAMES", vk_icd_filenames.c_str(), 0); +#endif // IREE_PLATFORM_WINDOWS +#else + // Leave VK_ICD_FILENAMES unchanged and rely on the system Vulkan loader to + // discover ICDs. +#endif // IREE_VK_ICD_FILENAMES + +// NOTE: we could factor this out into base, but this is the only place we use +// it right now so it's fine. +#if defined(IREE_PLATFORM_WINDOWS) + HMODULE library = ::LoadLibraryA("vulkan-1.dll"); + if (!library) { + return UnavailableErrorBuilder(IREE_LOC) + << "Unable to open vulkan-1.dll; driver not installed/on PATH"; + } + ASSIGN_OR_RETURN(auto syms, Create([library](const char* function_name) { + return reinterpret_cast<PFN_vkVoidFunction>( + ::GetProcAddress(library, function_name)); + })); + syms->close_fn_ = [library]() { + // TODO(benvanik): disable if we want to get profiling results. Sometimes + // closing the library can prevent proper symbolization on crashes or + // in sampling profilers. + ::FreeLibrary(library); + }; + return syms; +#else + void* library = ::dlopen("libvulkan.so.1", RTLD_LAZY | RTLD_LOCAL); + if (!library) { + return UnavailableErrorBuilder(IREE_LOC) + << "Unable to open libvulkan.so; driver not installed/on " + "LD_LIBRARY_PATH"; + } + ASSIGN_OR_RETURN(auto syms, Create([library](const char* function_name) { + return reinterpret_cast<PFN_vkVoidFunction>( + ::dlsym(library, function_name)); + })); + syms->close_fn_ = [library]() { + // TODO(benvanik): disable if we want to get profiling results. Sometimes + // closing the library can prevent proper symbolization on crashes or + // in sampling profilers. + ::dlclose(library); + }; + return syms; +#endif // IREE_PLATFORM_WINDOWS +} + +Status DynamicSymbols::LoadFromInstance(VkInstance instance) { + IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromInstance"); + return LoadFromDevice(instance, VK_NULL_HANDLE); +} + +Status DynamicSymbols::LoadFromDevice(VkInstance instance, VkDevice device) { + IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromDevice"); + + if (!instance) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Instance must have been created and a default instance proc " + "lookup function is required"; + } + + // Setup the lookup methods first. The rest of the syms uses these to + // resolve function pointers. + this->vkGetDeviceProcAddr = reinterpret_cast<PFN_vkGetDeviceProcAddr>( + this->vkGetInstanceProcAddr(instance, "vkGetDeviceProcAddr")); + if (!this->vkGetDeviceProcAddr) { + return UnavailableErrorBuilder(IREE_LOC) + << "Required Vulkan function vkGetDeviceProcAddr not available; " + "invalid driver handle?"; + } + + // Load the rest of the functions. + for (int i = 0; i < ABSL_ARRAYSIZE(kDynamicFunctionPtrInfos); ++i) { + const auto& function_ptr = kDynamicFunctionPtrInfos[i]; + auto* member_ptr = reinterpret_cast<PFN_vkVoidFunction*>( + reinterpret_cast<uint8_t*>(this) + function_ptr.member_offset); + if (function_ptr.is_device && device) { + *member_ptr = + this->vkGetDeviceProcAddr(device, function_ptr.function_name); + } else { + *member_ptr = + this->vkGetInstanceProcAddr(instance, function_ptr.function_name); + } + if (*member_ptr == nullptr && function_ptr.is_required) { + return UnavailableErrorBuilder(IREE_LOC) + << "Required Vulkan function " << function_ptr.function_name + << " not available"; + } + } + + return OkStatus(); +} + +DynamicSymbols::DynamicSymbols() = default; + +DynamicSymbols::~DynamicSymbols() { + if (close_fn_) { + close_fn_(); + } +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/dynamic_symbols.h b/hal/vulkan/dynamic_symbols.h new file mode 100644 index 0000000..adc95c8 --- /dev/null +++ b/hal/vulkan/dynamic_symbols.h
@@ -0,0 +1,129 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_ +#define IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_ + +#include <vulkan/vulkan.h> + +#include <cstdint> +#include <functional> +#include <memory> + +#include "base/ref_ptr.h" +#include "base/status.h" +#include "hal/vulkan/dynamic_symbol_tables.h" + +namespace iree { +namespace hal { +namespace vulkan { + +struct FunctionPtrInfo; + +// Dynamic Vulkan function loader for use with vulkan.hpp. +// This loader is a subset of the DispatchLoaderDynamic implementation that only +// loads functions we are interested in (a compute-specific subset) and avoids +// extensions we will never use. +// +// This exposes all Vulkan methods as function pointer members. Optional +// methods will be nullptr if not present. Excluded methods will be omitted. +// +// DynamicSymbols instances are designed to be passed to vulkan.hpp methods as +// the last argument, though they may also be called directly. +// **Always make sure to pass the loader to vulkan.hpp methods!** +// +// Loading is performed by walking a table of required and optional functions +// (defined in dynamic_symbol_tables.h) and populating the member function +// pointers exposed on this struct when available. For example, if the +// vkSomeFunction method is marked in the table as OPTIONAL the loader will +// attempt to lookup the function and if successful set the +// DynamicSymbols::vkSomeFunction pointer to the resolved address. If the +// function is not found then it will be set to nullptr so users can check for +// function availability. +// +// Documentation: +// https://github.com/KhronosGroup/Vulkan-Hpp#extensions--per-device-function-pointers +// +// Usage: +// ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader()); +// VkInstance instance = VK_NULL_HANDLE; +// syms->vkCreateInstance(..., &instance); +// RETURN_IF_ERROR(syms->LoadFromInstance(instance)); +struct DynamicSymbols : public RefObject<DynamicSymbols> { + using GetProcAddrFn = + std::function<PFN_vkVoidFunction(const char* function_name)>; + + DynamicSymbols(); + ~DynamicSymbols(); + + // Creates the dynamic symbol table using the given |get_proc_addr| to resolve + // the vkCreateInstance function. + // + // After the instance is created the caller must use LoadFromInstance (or + // LoadFromDevice) to load the remaining symbols. + static StatusOr<ref_ptr<DynamicSymbols>> Create( + const GetProcAddrFn& get_proc_addr); + + // Loads all required and optional Vulkan functions from the Vulkan loader. + // This will look for a Vulkan loader on the system (like libvulkan.so) and + // dlsym the functions from that. + // + // The loaded function pointers will point to thunks in the ICD. This may + // enable additional debug checking and more readable stack traces (as + // errors come from within the ICD, where we have symbols). + static StatusOr<ref_ptr<DynamicSymbols>> CreateFromSystemLoader(); + + // Loads all required and optional Vulkan functions from the given instance. + // + // The loaded function pointers will point to thunks in the ICD. This may + // enable additional debug checking and more readable stack traces (as + // errors come from within the ICD, where we have symbols). + Status LoadFromInstance(VkInstance instance); + + // Loads all required and optional Vulkan functions from the given device, + // falling back to the instance when required. + // + // This attempts to directly query the methods from the device, bypassing any + // ICD or shim layers. These methods will generally have less overhead at + // runtime as they need not jump through the various trampolines. + Status LoadFromDevice(VkInstance instance, VkDevice device); + + // Define members for each function pointer. + // See dynamic_symbol_tables.h for the full list of methods. + // + // Each required and optional function in the loader tables will expand to + // the following member, such as for example 'vkSomeFunction': + // PFN_vkSomeFunction vkSomeFunction; +#define REQUIRED_PFN(function_name) PFN_##function_name function_name +#define OPTIONAL_PFN(function_name) PFN_##function_name function_name +#define EXCLUDED_PFN(function_name) +#define PFN_MEMBER(requirement, function_name) requirement##_PFN(function_name); + REQUIRED_PFN(vkGetInstanceProcAddr); + REQUIRED_PFN(vkGetDeviceProcAddr); + IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(PFN_MEMBER, PFN_MEMBER); +#undef REQUIRED_PFN +#undef OPTIONAL_PFN +#undef EXCLUDED_PFN +#undef PFN_MEMBER + + private: + // Optional callback on loader destruction. + std::function<void()> close_fn_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
diff --git a/hal/vulkan/dynamic_symbols_test.cc b/hal/vulkan/dynamic_symbols_test.cc new file mode 100644 index 0000000..3bf0638 --- /dev/null +++ b/hal/vulkan/dynamic_symbols_test.cc
@@ -0,0 +1,73 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/dynamic_symbols.h" + +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { +namespace { + +VkApplicationInfo GetApplicationInfo() { + VkApplicationInfo app_info; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pNext = nullptr; + app_info.pApplicationName = "IREE-ML-TEST"; + app_info.applicationVersion = 0; + app_info.pEngineName = "IREE"; + app_info.engineVersion = 0; + app_info.apiVersion = VK_API_VERSION_1_0; + return app_info; +} + +VkInstanceCreateInfo GetInstanceCreateInfo(VkApplicationInfo* app_info) { + VkInstanceCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + create_info.pApplicationInfo = app_info; + create_info.enabledLayerCount = 0; + create_info.ppEnabledLayerNames = nullptr; + create_info.enabledExtensionCount = 0; + create_info.ppEnabledExtensionNames = nullptr; + return create_info; +} + +TEST(DynamicSymbolsTest, CreateFromSystemLoader) { + auto status_or_syms = DynamicSymbols::CreateFromSystemLoader(); + ASSERT_OK(status_or_syms); + ref_ptr<DynamicSymbols> syms = std::move(status_or_syms.ValueOrDie()); + + // Create and destroy a VkInstance using the symbols. This is mainly testing + // that the symbols were loaded successfully and are actually able to be used. + VkApplicationInfo app_info = GetApplicationInfo(); + VkInstanceCreateInfo create_info = GetInstanceCreateInfo(&app_info); + VkInstance instance = VK_NULL_HANDLE; + VK_CHECK_OK( + syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)); + + ASSERT_OK(syms->LoadFromInstance(instance)); + + syms->vkDestroyInstance(instance, /*pAllocator=*/nullptr); +} + +} // namespace +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/extensibility_util.cc b/hal/vulkan/extensibility_util.cc new file mode 100644 index 0000000..09d9f54 --- /dev/null +++ b/hal/vulkan/extensibility_util.cc
@@ -0,0 +1,221 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/extensibility_util.h" + +#include "base/memory.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +StatusOr<std::vector<const char*>> MatchAvailableLayers( + absl::Span<const char* const> required_layers, + absl::Span<const char* const> optional_layers, + absl::Span<const VkLayerProperties> properties) { + IREE_TRACE_SCOPE0("MatchAvailableLayers"); + + std::vector<const char*> enabled_layers; + enabled_layers.reserve(required_layers.size() + optional_layers.size()); + + for (const char* layer_name : required_layers) { + bool found = false; + for (const auto& layer_properties : properties) { + if (std::strcmp(layer_name, layer_properties.layerName) == 0) { + VLOG(1) << "Enabling required layer: " << layer_name; + found = true; + enabled_layers.push_back(layer_name); + break; + } + } + if (!found) { + return UnavailableErrorBuilder(IREE_LOC) + << "Required layer " << layer_name << " not available"; + } + } + + for (const char* layer_name : optional_layers) { + bool found = false; + for (const auto& layer_properties : properties) { + if (std::strcmp(layer_name, layer_properties.layerName) == 0) { + VLOG(1) << "Enabling optional layer: " << layer_name; + found = true; + enabled_layers.push_back(layer_name); + break; + } + } + if (!found) { + VLOG(1) << "Optional layer " << layer_name << " not available"; + } + } + + return enabled_layers; +} + +StatusOr<std::vector<const char*>> MatchAvailableExtensions( + absl::Span<const char* const> required_extensions, + absl::Span<const char* const> optional_extensions, + absl::Span<const VkExtensionProperties> properties) { + IREE_TRACE_SCOPE0("MatchAvailableExtensions"); + + std::vector<const char*> enabled_extensions; + enabled_extensions.reserve(required_extensions.size() + + optional_extensions.size()); + + for (const char* extension_name : required_extensions) { + bool found = false; + for (const auto& extension_properties : properties) { + if (std::strcmp(extension_name, extension_properties.extensionName) == + 0) { + VLOG(1) << "Enabling required extension: " << extension_name; + found = true; + enabled_extensions.push_back(extension_name); + break; + } + } + if (!found) { + return UnavailableErrorBuilder(IREE_LOC) + << "Required extension " << extension_name << " not available"; + } + } + + for (const char* extension_name : optional_extensions) { + bool found = false; + for (const auto& extension_properties : properties) { + if (std::strcmp(extension_name, extension_properties.extensionName) == + 0) { + VLOG(1) << "Enabling optional extension: " << extension_name; + found = true; + enabled_extensions.push_back(extension_name); + break; + } + } + if (!found) { + VLOG(1) << "Optional extension " << extension_name << " not available"; + } + } + + return enabled_extensions; +} + +} // namespace + +StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers( + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { + uint32_t layer_property_count = 0; + VK_RETURN_IF_ERROR( + syms.vkEnumerateInstanceLayerProperties(&layer_property_count, nullptr)); + std::vector<VkLayerProperties> layer_properties(layer_property_count); + VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceLayerProperties( + &layer_property_count, layer_properties.data())); + ASSIGN_OR_RETURN(auto enabled_layers, + MatchAvailableLayers(extensibility_spec.required_layers, + extensibility_spec.optional_layers, + layer_properties), + _ << "Unable to find all required instance layers"); + return enabled_layers; +} + +StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions( + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { + uint32_t extension_property_count = 0; + // Warning: leak checks remain disabled if an error is returned. + IREE_DISABLE_LEAK_CHECKS(); + VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( + nullptr, &extension_property_count, nullptr)); + std::vector<VkExtensionProperties> extension_properties( + extension_property_count); + VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( + nullptr, &extension_property_count, extension_properties.data())); + ASSIGN_OR_RETURN( + auto enabled_extensions, + MatchAvailableExtensions(extensibility_spec.required_extensions, + extensibility_spec.optional_extensions, + extension_properties), + _ << "Unable to find all required instance extensions"); + IREE_ENABLE_LEAK_CHECKS(); + return enabled_extensions; +} + +StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers( + VkPhysicalDevice physical_device, + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { + uint32_t layer_property_count = 0; + VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties( + physical_device, &layer_property_count, nullptr)); + std::vector<VkLayerProperties> layer_properties(layer_property_count); + VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties( + physical_device, &layer_property_count, layer_properties.data())); + ASSIGN_OR_RETURN(auto enabled_layers, + MatchAvailableLayers(extensibility_spec.required_layers, + extensibility_spec.optional_layers, + layer_properties), + _ << "Unable to find all required device layers"); + return enabled_layers; +} + +StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions( + VkPhysicalDevice physical_device, + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { + uint32_t extension_property_count = 0; + VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( + physical_device, nullptr, &extension_property_count, nullptr)); + std::vector<VkExtensionProperties> extension_properties( + extension_property_count); + VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( + physical_device, nullptr, &extension_property_count, + extension_properties.data())); + ASSIGN_OR_RETURN( + auto enabled_extensions, + MatchAvailableExtensions(extensibility_spec.required_extensions, + extensibility_spec.optional_extensions, + extension_properties), + _ << "Unable to find all required device extensions"); + return enabled_extensions; +} + +InstanceExtensions PopulateEnabledInstanceExtensions( + absl::Span<const char* const> extension_names) { + InstanceExtensions extensions = {0}; + for (const char* extension_name : extension_names) { + if (std::strcmp(extension_name, VK_EXT_DEBUG_REPORT_EXTENSION_NAME) == 0) { + extensions.debug_report = true; + } else if (std::strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == + 0) { + extensions.debug_utils = true; + } + } + return extensions; +} + +DeviceExtensions PopulateEnabledDeviceExtensions( + absl::Span<const char* const> extension_names) { + DeviceExtensions extensions = {0}; + for (const char* extension_name : extension_names) { + if (std::strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == + 0) { + extensions.push_descriptors = true; + } + } + return extensions; +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/extensibility_util.h b/hal/vulkan/extensibility_util.h new file mode 100644 index 0000000..3e4d725 --- /dev/null +++ b/hal/vulkan/extensibility_util.h
@@ -0,0 +1,100 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities for working with layers and extensions. + +#ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ +#define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ + +#include <vulkan/vulkan.h> + +#include "absl/types/span.h" +#include "base/status.h" +#include "hal/vulkan/dynamic_symbols.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// Describes required and optional extensibility points. +struct ExtensibilitySpec { + // A list of required and optional layers. + std::vector<const char*> required_layers; + std::vector<const char*> optional_layers; + + // A list of required and optional extensions. + // Prefer using the _EXTENSION_NAME macros to make tracking easier (such as + // 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME'). + std::vector<const char*> required_extensions; + std::vector<const char*> optional_extensions; +}; + +// Returns a list of layer names available for instances. +// Fails if any required_layers are unavailable. +StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers( + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); + +// Returns a list of extension names available for instances. +// Fails if any required_extensions are unavailable. +StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions( + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); + +// Returns a list of layer names available for the given |physical_device|. +// Fails if any required_layers are unavailable. +StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers( + VkPhysicalDevice physical_device, + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); + +// Returns a list of extension names available for the given |physical_device|. +// Fails if any required_extensions are unavailable. +StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions( + VkPhysicalDevice physical_device, + const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); + +// Bits for enabled instance extensions. +// We must use this to query support instead of just detecting symbol names as +// ICDs will resolve the functions sometimes even if they don't support the +// extension (or we didn't ask for it to be enabled). +struct InstanceExtensions { + // VK_EXT_debug_report is enabled and a callback is regsitered. + // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_report + bool debug_report : 1; + + // VK_EXT_debug_utils is enabled and a debug messenger is registered. + // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_utils + bool debug_utils : 1; +}; + +// Returns a bitfield with all of the provided extension names. +InstanceExtensions PopulateEnabledInstanceExtensions( + absl::Span<const char* const> extension_names); + +// Bits for enabled device extensions. +// We must use this to query support instead of just detecting symbol names as +// ICDs will resolve the functions sometimes even if they don't support the +// extension (or we didn't ask for it to be enabled). +struct DeviceExtensions { + // VK_KHR_push_descriptor is enabled and vkCmdPushDescriptorSetKHR is valid. + bool push_descriptors : 1; +}; + +// Returns a bitfield with all of the provided extension names. +DeviceExtensions PopulateEnabledDeviceExtensions( + absl::Span<const char* const> extension_names); + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
diff --git a/hal/vulkan/handle_util.h b/hal/vulkan/handle_util.h new file mode 100644 index 0000000..c12e4d1 --- /dev/null +++ b/hal/vulkan/handle_util.h
@@ -0,0 +1,136 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Helpers for wrapping Vulkan handles that don't require us to wrap every type. +// This keeps our compilation time reasonable (as the vulkancpp library is +// insane) while giving us nice safety around cleanup and ensuring we use +// dynamic symbols and consistent allocators. +// +// Do not add functionality beyond handle management to these types. Keep our +// Vulkan usage mostly functional and C-like to ensure minimal code size and +// readability. + +#ifndef IREE_HAL_VULKAN_HANDLE_UTIL_H_ +#define IREE_HAL_VULKAN_HANDLE_UTIL_H_ + +#include <vulkan/vulkan.h> + +#include "absl/synchronization/mutex.h" +#include "absl/utility/utility.h" +#include "base/ref_ptr.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/extensibility_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class VkDeviceHandle : public RefObject<VkDeviceHandle> { + public: + VkDeviceHandle(const ref_ptr<DynamicSymbols>& syms, + DeviceExtensions enabled_extensions, + const VkAllocationCallbacks* allocator = nullptr) + : syms_(add_ref(syms)), + enabled_extensions_(enabled_extensions), + allocator_(allocator) {} + ~VkDeviceHandle() { reset(); } + + VkDeviceHandle(const VkDeviceHandle&) = delete; + VkDeviceHandle& operator=(const VkDeviceHandle&) = delete; + VkDeviceHandle(VkDeviceHandle&& other) noexcept + : value_(absl::exchange(other.value_, + static_cast<VkDevice>(VK_NULL_HANDLE))), + syms_(std::move(other.syms_)), + enabled_extensions_(other.enabled_extensions_), + allocator_(other.allocator_) {} + + void reset() { + if (value_ == VK_NULL_HANDLE) return; + syms_->vkDestroyDevice(value_, allocator_); + value_ = VK_NULL_HANDLE; + } + + VkDevice value() const noexcept { return value_; } + VkDevice* mutable_value() noexcept { return &value_; } + operator VkDevice() const noexcept { return value_; } + + const ref_ptr<DynamicSymbols>& syms() const noexcept { return syms_; } + const VkAllocationCallbacks* allocator() const noexcept { return allocator_; } + + const DeviceExtensions& enabled_extensions() const { + return enabled_extensions_; + } + + private: + VkDevice value_ = VK_NULL_HANDLE; + ref_ptr<DynamicSymbols> syms_; + DeviceExtensions enabled_extensions_; + const VkAllocationCallbacks* allocator_ = nullptr; +}; + +class VkCommandPoolHandle : public RefObject<VkCommandPoolHandle> { + public: + explicit VkCommandPoolHandle(const ref_ptr<VkDeviceHandle>& logical_device) + : logical_device_(add_ref(logical_device)) {} + ~VkCommandPoolHandle() { reset(); } + + VkCommandPoolHandle(const VkCommandPoolHandle&) = delete; + VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete; + VkCommandPoolHandle(VkCommandPoolHandle&& other) noexcept + : logical_device_(std::move(other.logical_device_)), + value_(absl::exchange(other.value_, + static_cast<VkCommandPool>(VK_NULL_HANDLE))) {} + VkCommandPoolHandle& operator=(VkCommandPoolHandle&& other) { + std::swap(logical_device_, other.logical_device_); + std::swap(value_, other.value_); + return *this; + } + + void reset() { + if (value_ == VK_NULL_HANDLE) return; + syms()->vkDestroyCommandPool(*logical_device_, value_, allocator()); + value_ = VK_NULL_HANDLE; + } + + VkCommandPool value() const noexcept { return value_; } + VkCommandPool* mutable_value() noexcept { return &value_; } + operator VkCommandPool() const noexcept { return value_; } + + const ref_ptr<VkDeviceHandle>& logical_device() const noexcept { + return logical_device_; + } + const ref_ptr<DynamicSymbols>& syms() const noexcept { + return logical_device_->syms(); + } + const VkAllocationCallbacks* allocator() const noexcept { + return logical_device_->allocator(); + } + + absl::Mutex* mutex() const { return &mutex_; } + + private: + ref_ptr<VkDeviceHandle> logical_device_; + VkCommandPool value_ = VK_NULL_HANDLE; + + // Vulkan command pools are not thread safe and require external + // synchronization. Since we allow arbitrary threads to allocate and + // deallocate the HAL command buffers we need to externally synchronize. + mutable absl::Mutex mutex_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_HANDLE_UTIL_H_
diff --git a/hal/vulkan/internal_vk_mem_alloc.cc b/hal/vulkan/internal_vk_mem_alloc.cc new file mode 100644 index 0000000..f3d317f --- /dev/null +++ b/hal/vulkan/internal_vk_mem_alloc.cc
@@ -0,0 +1,68 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file configures VMA to use common Google/Abseil types in an effort to +// better integrate with applications compiled using other Google code. By using +// the same types that dependers are likely using we can often reduce binary +// size and ease debugging (such as by using absl::Mutex to get better tsan +// warnings). + +// Only compile if an external implementation has not been otherwise linked. +#if !defined(VULKAN_MEMORY_ALLOCATOR_EXTERNAL_IMPL) + +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "base/logging.h" + +// Use std::vector instead of the VMA version. +#define VMA_USE_STL_VECTOR 1 + +// TODO(benvanik): figure out why std::list cannot be used. +// #define VMA_USE_STL_LIST 1 + +// Use absl::flat_hash_map instead of std::unordered_map. +#define VmaPair std::pair +#define VMA_MAP_TYPE(KeyT, ValueT) \ + absl::flat_hash_map<KeyT, ValueT, std::hash<KeyT>, std::equal_to<KeyT>, \ + VmaStlAllocator<std::pair<KeyT, ValueT> > > + +// Use CHECK for assertions. +#define VMA_ASSERT CHECK +#define VMA_HEAVY_ASSERT DCHECK + +// Use LOG for logging. +#ifndef NDEBUG +#define VMA_DEBUG_LOG(...) ABSL_RAW_LOG(INFO, __VA_ARGS__) +#else +#define VMA_DEBUG_LOG(...) +#endif // !NDEBUG + +// Use absl::Mutex for VMA_MUTEX. +#define VMA_MUTEX absl::Mutex +class AbslVmaRWMutex { + public: + void LockRead() ABSL_SHARED_LOCK_FUNCTION() { mutex_.ReaderLock(); } + void UnlockRead() ABSL_UNLOCK_FUNCTION() { mutex_.ReaderUnlock(); } + void LockWrite() ABSL_EXCLUSIVE_LOCK_FUNCTION() { mutex_.WriterLock(); } + void UnlockWrite() ABSL_UNLOCK_FUNCTION() { mutex_.WriterUnlock(); } + + private: + absl::Mutex mutex_; +}; +#define VMA_RW_MUTEX AbslVmaRWMutex + +#define VMA_IMPLEMENTATION +#include "vk_mem_alloc.h" + +#endif
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.h b/hal/vulkan/internal_vk_mem_alloc.h similarity index 100% rename from iree/hal/vulkan/internal_vk_mem_alloc.h rename to hal/vulkan/internal_vk_mem_alloc.h
diff --git a/hal/vulkan/legacy_fence.cc b/hal/vulkan/legacy_fence.cc new file mode 100644 index 0000000..7639817 --- /dev/null +++ b/hal/vulkan/legacy_fence.cc
@@ -0,0 +1,396 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/legacy_fence.h" + +#include <cstdint> + +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "base/intrusive_list.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +// Inserts the given |fence_signal| into |list| in ascending order. +void InsertOutstandingFenceSignal(OutstandingFenceSignal* fence_signal, + IntrusiveList<OutstandingFenceSignal>* list) { + for (auto existing_signal : *list) { + if (existing_signal->value > fence_signal->value) { + list->insert(existing_signal, fence_signal); + return; + } + } + list->push_back(fence_signal); +} + +} // namespace + +// static +StatusOr<ref_ptr<LegacyFencePool>> LegacyFencePool::Create( + ref_ptr<VkDeviceHandle> logical_device) { + IREE_TRACE_SCOPE0("LegacyFencePool::Create"); + ref_ptr<LegacyFencePool> fence_pool( + new LegacyFencePool(std::move(logical_device))); + RETURN_IF_ERROR(fence_pool->PreallocateFences()); + return fence_pool; +} + +LegacyFencePool::LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device) + : logical_device_(std::move(logical_device)) {} + +LegacyFencePool::~LegacyFencePool() { + IREE_TRACE_SCOPE0("LegacyFencePool::dtor"); + + absl::MutexLock lock(&mutex_); + for (auto& fence_signal : storage_) { + syms()->vkDestroyFence(*logical_device_, fence_signal.fence, + logical_device_->allocator()); + } + unused_fences_.clear(); + unresolved_fences_.clear(); +} + +Status LegacyFencePool::PreallocateFences() { + IREE_TRACE_SCOPE0("LegacyFencePool::PreallocateFences"); + + VkFenceCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + + absl::MutexLock lock(&mutex_); + for (int i = 0; i < kMaxInFlightFenceCount; ++i) { + auto* fence_signal = &storage_[i]; + VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info, + logical_device_->allocator(), + &fence_signal->fence)); + unused_fences_.push_back(fence_signal); + } + + return OkStatus(); +} + +StatusOr<OutstandingFenceSignal*> LegacyFencePool::Acquire() { + IREE_TRACE_SCOPE0("LegacyFencePool::Acquire"); + + absl::MutexLock lock(&mutex_); + if (unused_fences_.empty()) { + return ResourceExhaustedErrorBuilder(IREE_LOC) + << "Fence pool out of unused fences"; + } + + auto* fence_signal = unused_fences_.front(); + unused_fences_.pop_front(); + return fence_signal; +} + +void LegacyFencePool::ReleaseResolved( + IntrusiveList<OutstandingFenceSignal>* fence_signals) { + IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseResolved"); + + // Get a list of fences we need to reset. Note that not all fences may have + // been signaled and we can avoid resetting them. + absl::InlinedVector<VkFence, 8> handles; + handles.reserve(fence_signals->size()); + for (auto* fence_signal : *fence_signals) { + if (fence_signal->is_pending) { + handles.push_back(fence_signal->fence); + } + } + if (!handles.empty()) { + syms()->vkResetFences(*logical_device_, handles.size(), handles.data()); + } + + absl::MutexLock lock(&mutex_); + unused_fences_.merge_from(fence_signals); +} + +void LegacyFencePool::ReleaseUnresolved( + IntrusiveList<OutstandingFenceSignal>* fence_signals) { + IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseUnresolved"); + + absl::MutexLock lock(&mutex_); + while (!fence_signals->empty()) { + auto* fence_signal = fence_signals->front(); + fence_signals->pop_front(); + if (fence_signal->is_pending) { + // Fence was submitted and may still have a pending signal on it. We can't + // reuse it until it has resolved. + // TODO(benvanik): fix these fences by reallocating? We aren't leaking + // here (technically) but we will exhaust the pool pretty quickly. + unresolved_fences_.push_back(fence_signal); + } else { + // Fence was never actually submitted so we can reuse it no problem. + unused_fences_.push_back(fence_signal); + } + } +} + +// static +Status LegacyFence::WaitForFences(VkDeviceHandle* logical_device, + absl::Span<const FenceValue> fences, + bool wait_all, absl::Time deadline) { + IREE_TRACE_SCOPE0("LegacyFence::WaitForFences"); + + // NOTE: we could pool this state too (probably right on the LegacyFencePool) + // or be smarter about using stack-allocated storage. The best idea is to use + // real timeline semaphores, though, so not much effort has been spent on + // optimizing this. + absl::InlinedVector<VkFence, 4> handles; + handles.reserve(fences.size()); + + // Loop over the fences and wait for any/all to signal. In wait_all mode we + // perform the bookkeeping to remove fences that have already been signaled so + // that we only wait on ones we need to (and possibly avoid making the vk call + // entirely!). + while (true) { + // Grab handles and acquire fences for all fences not yet at the requested + // timeline value. + for (const auto& fence_value : fences) { + auto* fence = reinterpret_cast<LegacyFence*>(fence_value.first); + // NOTE: this will return the sticky fence error if the fence has failed. + ASSIGN_OR_RETURN(VkFence handle, + fence->AcquireWaitFence(fence_value.second)); + if (handle != VK_NULL_HANDLE) { + // Fence is unresolved and we need to really wait for it. + handles.push_back(handle); + } + } + if (handles.empty()) { + // All fences resolved. + return OkStatus(); + } + + uint64_t timeout_nanos; + if (deadline == absl::InfiniteFuture()) { + timeout_nanos = UINT64_MAX; + } else if (deadline == absl::InfinitePast()) { + timeout_nanos = 0; + } else { + auto relative_nanos = absl::ToInt64Nanoseconds(deadline - absl::Now()); + timeout_nanos = relative_nanos < 0 ? 0 : relative_nanos; + } + + // Wait on the fences we still need. + // Note that waking does not actually indicate all fences were hit! We need + // to do another pass above on the next iteration to make sure that we don't + // need to wait again on another fence. + VK_RETURN_IF_ERROR(logical_device->syms()->vkWaitForFences( + *logical_device, handles.size(), handles.data(), wait_all, + timeout_nanos)); + handles.clear(); + } + + return OkStatus(); +} + +LegacyFence::LegacyFence(ref_ptr<LegacyFencePool> fence_pool, + uint64_t initial_value) + : fence_pool_(std::move(fence_pool)), value_(initial_value) {} + +LegacyFence::~LegacyFence() { + IREE_TRACE_SCOPE0("LegacyFence::dtor"); + CHECK_OK(TryResolveOutstandingFences(UINT64_MAX)); + absl::MutexLock lock(&mutex_); + CHECK(outstanding_signals_.empty()) + << "Destroying a fence without first waiting on outstanding signals"; +} + +Status LegacyFence::status() const { + if (value_.load() != UINT64_MAX) { + return OkStatus(); + } + absl::MutexLock lock(&mutex_); + return status_; +} + +StatusOr<uint64_t> LegacyFence::QueryValue() { + RETURN_IF_ERROR(TryResolveOutstandingFences(UINT64_MAX)); + return value_.load(); +} + +StatusOr<VkFence> LegacyFence::AcquireSignalFence(uint64_t value) { + absl::MutexLock lock(&mutex_); + + // It's an error to signal out of order (as that requires a lot more + // tracking and magic to get right). + if (value_.load() >= value) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Attempting to signal a timeline fence out of order; value=" + << value_ << ", new_value=" << value; + } + + // Scan to see if there's waiters for this value (or values before it). + // We may be able to reuse a previously allocated fence in the case that a + // user is waiting prior to actually submitting the signal operation. + OutstandingFenceSignal* signal_state = nullptr; + for (auto* fence_signal : outstanding_signals_) { + if (fence_signal->value == value) { + // Fence is going to be signaled at exactly the required value. + if (fence_signal->is_pending) { + // Already have signaled to this value - that's a paddlin'. + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Duplicate signal of timeline fence for value=" << value; + } + signal_state = fence_signal; + break; + } + } + if (!signal_state) { + // Allocate a signal state entry and a VkFence to submit with. + // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush. + ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire()); + signal_state->value = value; + InsertOutstandingFenceSignal(signal_state, &outstanding_signals_); + } + + signal_state->is_pending = true; + return signal_state->fence; +} + +StatusOr<VkFence> LegacyFence::AcquireWaitFence(uint64_t value) { + // If we've already resolved then we want to avoid doing any kind of wait. + // Since the value is monotonically increasing we can do a lock-free peek + // here to see if we need to bother taking a full lock. + if (value_.load() >= value) { + return VK_NULL_HANDLE; + } + + absl::MutexLock lock(&mutex_); + + // Try to resolve any outstanding fence signals. + RETURN_IF_ERROR(TryResolveOutstandingFencesLocked(value)); + if (value_.load() >= value) { + return VK_NULL_HANDLE; + } + + // Try to find an existing fence we can reuse based on the required value. + OutstandingFenceSignal* signal_state = nullptr; + for (auto* fence_signal : outstanding_signals_) { + if (fence_signal->value >= value) { + // Fence is going to be signaled at or above the required value. + signal_state = fence_signal; + break; // |outstanding_signals_| is in sorted order. + } + } + if (!signal_state) { + // Allocate a signal state entry and a VkFence that we will need to signal + // in the future. We can't yet insert it into the queue but it will go in + // when the user tries to signal a value >= the required value. + // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush. + ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire()); + signal_state->value = value; + InsertOutstandingFenceSignal(signal_state, &outstanding_signals_); + } + + return signal_state->fence; +} + +Status LegacyFence::TryResolveOutstandingFences(uint64_t upper_value) { + absl::MutexLock lock(&mutex_); + return TryResolveOutstandingFencesLocked(upper_value); +} + +Status LegacyFence::TryResolveOutstandingFencesLocked(uint64_t upper_value) { + // Fast-path for when we have no outstanding fences. + // NOTE: we hold the lock during the entire resolve process so that any waiter + // will only be woken once we have resolved to the furthest possible value. + if (outstanding_signals_.empty() || value_ > upper_value) { + return OkStatus(); + } + + IREE_TRACE_SCOPE0("LegacyFence::TryResolveOutstandingFences"); + + IntrusiveList<OutstandingFenceSignal> resolved_fences; + IntrusiveList<OutstandingFenceSignal> unresolved_fences; + VkDevice device = *fence_pool_->logical_device(); + const auto& syms = fence_pool_->syms(); + bool keep_resolving = true; + while (keep_resolving && !outstanding_signals_.empty()) { + auto* fence_signal = outstanding_signals_.front(); + if (fence_signal->value > upper_value) { + // Signal is for a value beyond our upper limit - early exit so that we + // don't spend time dealing with signals we don't yet care about. This can + // prevent live lock where one thread is signaling fences as fast/faster + // than another thread can consume them. + keep_resolving = false; + break; + } + VkResult fence_status = syms->vkGetFenceStatus(device, fence_signal->fence); + switch (fence_status) { + case VK_SUCCESS: { + // Fence has signaled meaning that we have reached this point in the + // timeline and can advance the value. + value_.store(fence_signal->value); + outstanding_signals_.erase(fence_signal); + resolved_fences.push_back(fence_signal); + + // Run backwards and resolve any non-pending fences as they will never + // be used. + for (auto* it = fence_signal; it != nullptr;) { + auto* prev_fence_signal = it; + it = outstanding_signals_.previous(it); + if (!prev_fence_signal->is_pending) { + outstanding_signals_.erase(prev_fence_signal); + unresolved_fences.push_back(prev_fence_signal); + } + } + break; + } + case VK_NOT_READY: + if (fence_signal->is_pending) { + // Fence has not yet been signaled. We stop here and wait for future + // attempts at resolution. + keep_resolving = false; + } + // Fence is not even pending yet - we may have skipped it. Keep + // resolving to see if there's a higher value we can use. + break; + default: + // Fence indicates an error (device lost, out of memory, etc). + // Propagate this back to our status (and thus any waiters). + // Since we only take the first error we find we skip all remaining + // fences. + status_ = VkResultToStatus(fence_status); + value_.store(UINT64_MAX); + outstanding_signals_.erase(fence_signal); + resolved_fences.push_back(fence_signal); + break; + } + } + + // Release resolved fences back to the pool. Note that we can only do this + // to fences we know have actually completed: unresolved fences after an error + // may still be in-flight and we don't want to reuse them. + fence_pool_->ReleaseResolved(&resolved_fences); + fence_pool_->ReleaseUnresolved(&unresolved_fences); + if (!status_.ok()) { + fence_pool_->ReleaseUnresolved(&outstanding_signals_); + } + + return status_; +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/legacy_fence.h b/hal/vulkan/legacy_fence.h new file mode 100644 index 0000000..1ab4768 --- /dev/null +++ b/hal/vulkan/legacy_fence.h
@@ -0,0 +1,200 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// TODO(b/140141417): share the pool (and possibly most of the fence impl) with +// the timeline semaphores fallback. + +#ifndef IREE_HAL_VULKAN_LEGACY_FENCE_H_ +#define IREE_HAL_VULKAN_LEGACY_FENCE_H_ + +#include <vulkan/vulkan.h> + +#include <array> +#include <atomic> + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "base/intrusive_list.h" +#include "base/ref_ptr.h" +#include "base/status.h" +#include "hal/fence.h" +#include "hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// An outstanding legacy fence signal for a particular timeline value. +// Each signal to a new value gets a new VkFence and these are stored in a +// LegacyFence to quickly scan and process signaled fences. +// +// Must be externally synchronized via the LegacyFence mutex. +struct OutstandingFenceSignal : public IntrusiveLinkBase<void> { + // Allocated fence that is passed to vkQueueSubmit/vkWaitForFences. + // Represents a point in the timeline of value. + VkFence fence = VK_NULL_HANDLE; + + // Value that the fence payload should be when the fence is signaled. + // Note that since fences may resolve out of order we still need to check that + // we are only ever advancing the timeline and not just setting this value. + uint64_t value = UINT64_MAX; + + // True when the fence has been submitted and is pending on the device. + bool is_pending = false; +}; + +// A pool of VkFences that can be used by LegacyFence to simulate individual +// payload value signaling. Note that we prefer a pool instead of a ringbuffer +// as we want to allow out-of-order completion. +class LegacyFencePool final : public RefObject<LegacyFencePool> { + public: + static constexpr int kMaxInFlightFenceCount = 64; + + // Allocates a new fence pool and all fences. + static StatusOr<ref_ptr<LegacyFencePool>> Create( + ref_ptr<VkDeviceHandle> logical_device); + + ~LegacyFencePool(); + + const ref_ptr<VkDeviceHandle>& logical_device() const { + return logical_device_; + } + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + // Acquires a fence from the pool for use by the caller. + // The fence is guaranteed to not be in-flight and will have been reset to an + // unsignaled state. + // + // Returns RESOURCE_EXHAUSTED if the pool has no more available fences. + // Callers are expected to handle this by waiting on previous fences or for + // complete device idle. Yes, that's as bad as it sounds, and if we start + // seeing that we should bump up the max count. + StatusOr<OutstandingFenceSignal*> Acquire(); + + // Releases one or more fences back to the pool. + // The fences must either be signaled or not be in-flight. + void ReleaseResolved(IntrusiveList<OutstandingFenceSignal>* fence_signals); + + // Releases one or more unresolved fences back to the pool. + // These may be in any state and will be assumed as untouchable. + void ReleaseUnresolved(IntrusiveList<OutstandingFenceSignal>* fence_signals); + + private: + explicit LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device); + + Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_); + + ref_ptr<VkDeviceHandle> logical_device_; + + absl::Mutex mutex_; + std::array<OutstandingFenceSignal, kMaxInFlightFenceCount> storage_ + ABSL_GUARDED_BY(mutex_); + IntrusiveList<OutstandingFenceSignal> unused_fences_ ABSL_GUARDED_BY(mutex_); + IntrusiveList<OutstandingFenceSignal> unresolved_fences_ + ABSL_GUARDED_BY(mutex_); +}; + +// A fence implemented using a pool of native VkFences. +// This is supported unconditionally on all versions of Vulkan. When timeline +// semaphores are available we prefer using those instead and this is only +// present as a fallback. We keep this implementation separate so that it can be +// compiled out when the target is known to have the extension. +// +// Simulation of timeline semaphore-based fences is done via a pool of native +// VkFences that each represent a single signaled value. This means that worst +// case we are using one fence per submit however that's no different than if +// we did anything else. Though we can't cancel previously-queued fences when +// increasing values are signaled we can be clever when querying and releasing +// by always walking in reverse relying on the monotonically increasing values. +// +// Valid usage patterns we need to handle: +// 1. fence signaled and waited on (common case) +// 2. fence waited on before beginning signaling +// 3. fence signaled and never waited on +// +// Case 1 is fairly straightforward: we acquire a VkFence, pass that to the +// queue submit, and then vkWaitForFences/query it for completion. +// +// Case 2 requires that we reserve a fence during the wait so that we can pass +// it to vkWaitForFences and track it such that we can reuse it during a future +// signal operation. Since we don't know during signaling if the specific value +// we waited on will ever have its own dedicated signal operation we need to be +// conservative and try to coalesce for correctness. This means that if a wait +// for a value of 1 is performed and we get a signal for a value of 2 we need to +// combine the two. If a signal for a value of 1 is later performed it then +// becomes a no-op. This could lead to some additional latency however that's a +// risk (or benefit!) of using timelines. Rule of thumb: don't do out of order +// signaling. +// +// Case 3 is like case 2 where we need to reserve a fence to wait on, however +// since we don't know if it will ever be signaled we need to take care to +// properly release the VkFence back to the pool for reuse: we don't want to +// return it while there are still waiters for its original event. For this +// reason we track the waiters on a given fence during their wait operation and +// if a fence is released with waiters active we put them in a special +// unresolved until the waiters continue on. +class LegacyFence final : public Fence { + public: + // Waits for one or more (or all) fences to reach or exceed the given values. + static Status WaitForFences(VkDeviceHandle* logical_device, + absl::Span<const FenceValue> fences, + bool wait_all, absl::Time deadline); + + LegacyFence(ref_ptr<LegacyFencePool> fence_pool, uint64_t initial_value); + ~LegacyFence() override; + + Status status() const override; + + StatusOr<uint64_t> QueryValue() override; + + // Acquires a new fence for signaling a specific value. + StatusOr<VkFence> AcquireSignalFence(uint64_t value); + + private: + // Acquires a new fence for waiting on a specific value. + // Returns VK_NULL_HANDLE if the fence already resolved and the sticky error + // if the fence is in an error state. + StatusOr<VkFence> AcquireWaitFence(uint64_t value); + + // Runs down the outstanding fences list and resolves to the latest signaled + // value. Will early exit if the value moves beyond |upper_value|. + Status TryResolveOutstandingFences(uint64_t upper_value) + ABSL_LOCKS_EXCLUDED(mutex_); + Status TryResolveOutstandingFencesLocked(uint64_t upper_value) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + ref_ptr<LegacyFencePool> fence_pool_; + + // The current highest value of the fence as verified during a wait or query. + // Kept outside of |mutex_| so that queries do not require a lock. + std::atomic<uint64_t> value_; + + mutable absl::Mutex mutex_; + + // Sticky status failure value set on first failure. + Status status_ ABSL_GUARDED_BY(mutex_); + + // Outstanding VkFences representing signal values. + // Expected to be sorted in ascending order by value. + IntrusiveList<OutstandingFenceSignal> outstanding_signals_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_LEGACY_FENCE_H_
diff --git a/hal/vulkan/native_binary_semaphore.cc b/hal/vulkan/native_binary_semaphore.cc new file mode 100644 index 0000000..bb96897 --- /dev/null +++ b/hal/vulkan/native_binary_semaphore.cc
@@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/native_binary_semaphore.h" + +namespace iree { +namespace hal { +namespace vulkan { + +NativeBinarySemaphore::NativeBinarySemaphore( + ref_ptr<VkDeviceHandle> logical_device, VkSemaphore handle) + : logical_device_(std::move(logical_device)), handle_(handle) {} + +NativeBinarySemaphore::~NativeBinarySemaphore() { + logical_device_->syms()->vkDestroySemaphore(*logical_device_, handle_, + logical_device_->allocator()); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/native_binary_semaphore.h b/hal/vulkan/native_binary_semaphore.h new file mode 100644 index 0000000..19ac293 --- /dev/null +++ b/hal/vulkan/native_binary_semaphore.h
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_ +#define IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_ + +#include <vulkan/vulkan.h> + +#include "hal/semaphore.h" +#include "hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// A binary semaphore implemented using the native VkSemaphore type. +// This is supported unconditionally on all versions of Vulkan. +class NativeBinarySemaphore final : public BinarySemaphore { + public: + NativeBinarySemaphore(ref_ptr<VkDeviceHandle> logical_device, + VkSemaphore handle); + ~NativeBinarySemaphore() override; + + VkSemaphore handle() const { return handle_; } + + private: + ref_ptr<VkDeviceHandle> logical_device_; + VkSemaphore handle_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
diff --git a/hal/vulkan/native_event.cc b/hal/vulkan/native_event.cc new file mode 100644 index 0000000..4cdb8c6 --- /dev/null +++ b/hal/vulkan/native_event.cc
@@ -0,0 +1,31 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/native_event.h" + +namespace iree { +namespace hal { +namespace vulkan { + +NativeEvent::NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle) + : logical_device_(std::move(logical_device)), handle_(handle) {} + +NativeEvent::~NativeEvent() { + logical_device_->syms()->vkDestroyEvent(*logical_device_, handle_, + logical_device_->allocator()); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/native_event.h b/hal/vulkan/native_event.h new file mode 100644 index 0000000..2cc4c4c --- /dev/null +++ b/hal/vulkan/native_event.h
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_ +#define IREE_HAL_VULKAN_NATIVE_EVENT_H_ + +#include <vulkan/vulkan.h> + +#include "hal/event.h" +#include "hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// An event implemented with the native VkEvent type. +class NativeEvent final : public Event { + public: + NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle); + ~NativeEvent() override; + + VkEvent handle() const { return handle_; } + + private: + ref_ptr<VkDeviceHandle> logical_device_; + VkEvent handle_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_NATIVE_EVENT_H_
diff --git a/hal/vulkan/pipeline_cache.cc b/hal/vulkan/pipeline_cache.cc new file mode 100644 index 0000000..9ed48c3 --- /dev/null +++ b/hal/vulkan/pipeline_cache.cc
@@ -0,0 +1,235 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/pipeline_cache.h" + +#include "absl/synchronization/mutex.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "flatbuffers/flatbuffers.h" +#include "hal/executable_format.h" +#include "hal/vulkan/status_util.h" +#include "schemas/spirv_executable_def_generated.h" + +namespace iree { +namespace hal { +namespace vulkan { + +PipelineCache::PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device) + : logical_device_(add_ref(logical_device)) {} + +PipelineCache::~PipelineCache() { + IREE_TRACE_SCOPE0("PipelineCache::dtor"); + ClearLayoutCaches(); +} + +bool PipelineCache::CanPrepareFormat(ExecutableFormat format) const { + return format == kExecutableFormatSpirV; +} + +StatusOr<ref_ptr<Executable>> PipelineCache::PrepareExecutable( + ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) { + IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable"); + if (!CanPrepareFormat(spec.format)) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Unsupported 4CC format: 0x" << std::hex << spec.format; + } + if (spec.executable_data.size() <= 4 || + !SpirVExecutableDefBufferHasIdentifier(spec.executable_data.data())) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Supplied executable data does not contain a SpirVExecutableDef"; + } + + // Get the SPIR-V executable def flatbuffer. + const auto& spirv_executable_def = + *::flatbuffers::GetRoot<SpirVExecutableDef>(spec.executable_data.data()); + + // Create (or reuse) a pipeline layout. + if (!spirv_executable_def.pipeline_layout()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Missing pipeline layout def"; + } + ASSIGN_OR_RETURN( + auto pipeline_layout_entry, + LookupOrInsertPipelineLayout(*spirv_executable_def.pipeline_layout())); + + // Create the executable (which may itself own many pipelines). + ASSIGN_OR_RETURN(auto executable, PipelineExecutable::Create( + logical_device_, + /*pipeline_cache=*/VK_NULL_HANDLE, + pipeline_layout_entry->pipeline_layout, + pipeline_layout_entry->descriptor_sets, + mode, spirv_executable_def)); + return executable; +} + +StatusOr<const PipelineCache::CachedPipelineLayout*> +PipelineCache::LookupOrInsertPipelineLayout( + const VkPipelineLayoutDef& pipeline_layout_def) { + IREE_TRACE_SCOPE0("PipelineCache::LookupOrInsertPipelineLayout"); + absl::MutexLock lock(&mutex_); + + // Build a list of the required descriptor set layouts and push constants. + // If we were being fast about this we would just hash the def and directly + // look up the pipeline layout. + PipelineDescriptorSets descriptor_sets; + descriptor_sets.buffer_binding_set = pipeline_layout_def.buffer_binding_set(); + descriptor_sets.buffer_binding_set_layout = VK_NULL_HANDLE; + absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts; + if (pipeline_layout_def.descriptor_set_layouts()) { + const auto& layout_defs = *pipeline_layout_def.descriptor_set_layouts(); + descriptor_set_layouts.resize(layout_defs.size()); + for (int i = 0; i < descriptor_set_layouts.size(); ++i) { + if (!layout_defs[i]) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "Missing layout def"; + } + ASSIGN_OR_RETURN(descriptor_set_layouts[i], + LookupOrInsertDescriptorSetLayout(*layout_defs[i])); + if (i == pipeline_layout_def.buffer_binding_set()) { + descriptor_sets.buffer_binding_set_layout = descriptor_set_layouts[i]; + descriptor_sets.buffer_binding_set_map.resize( + layout_defs[i]->bindings()->size()); + for (int j = 0; j < layout_defs[i]->bindings()->size(); ++j) { + descriptor_sets.buffer_binding_set_map[j] = + layout_defs[i]->bindings()->Get(j)->binding(); + } + } + } + } + + absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges; + if (pipeline_layout_def.push_constant_ranges()) { + const auto& range_defs = *pipeline_layout_def.push_constant_ranges(); + push_constant_ranges.resize(range_defs.size()); + for (int i = 0; i < push_constant_ranges.size(); ++i) { + if (!range_defs[i]) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Missing push constant range def"; + } + push_constant_ranges[i].stageFlags = range_defs[i]->stage_flags(); + push_constant_ranges[i].offset = range_defs[i]->offset(); + push_constant_ranges[i].size = range_defs[i]->size(); + } + } + + // Scan for an existing pipeline layout that matches the descriptor sets. + for (auto& entry : pipeline_layout_cache_) { + if (entry.descriptor_set_layouts.size() != descriptor_set_layouts.size() || + entry.push_constant_ranges.size() != push_constant_ranges.size()) { + continue; + } + if (std::memcmp( + descriptor_set_layouts.data(), entry.descriptor_set_layouts.data(), + descriptor_set_layouts.size() * sizeof(VkDescriptorSetLayout)) == + 0 && + std::memcmp( + push_constant_ranges.data(), entry.push_constant_ranges.data(), + push_constant_ranges.size() * sizeof(VkPushConstantRange)) == 0) { + return &entry; + } + } + + VkPipelineLayoutCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + create_info.setLayoutCount = descriptor_set_layouts.size(); + create_info.pSetLayouts = descriptor_set_layouts.data(); + create_info.pushConstantRangeCount = push_constant_ranges.size(); + create_info.pPushConstantRanges = push_constant_ranges.data(); + + // Create and insert into the cache. + VkPipelineLayout pipeline_layout = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(syms()->vkCreatePipelineLayout( + *logical_device_, &create_info, logical_device_->allocator(), + &pipeline_layout)); + pipeline_layout_cache_.push_back({std::move(descriptor_set_layouts), + std::move(push_constant_ranges), + pipeline_layout, descriptor_sets}); + return &pipeline_layout_cache_.back(); +} + +StatusOr<VkDescriptorSetLayout> +PipelineCache::LookupOrInsertDescriptorSetLayout( + const VkDescriptorSetLayoutDef& descriptor_set_layout_def) { + // Build a list of bindings in the set. + // If we were being fast we would hash the bindings and directly lookup + // without doing this allocation. + absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings; + if (descriptor_set_layout_def.bindings()) { + const auto& binding_defs = *descriptor_set_layout_def.bindings(); + bindings.resize(binding_defs.size()); + for (int i = 0; i < binding_defs.size(); ++i) { + bindings[i].binding = binding_defs[i]->binding(); + bindings[i].descriptorType = + static_cast<VkDescriptorType>(binding_defs[i]->descriptor_type()); + bindings[i].descriptorCount = binding_defs[i]->descriptor_count(); + bindings[i].stageFlags = binding_defs[i]->stage_flags(); + bindings[i].pImmutableSamplers = nullptr; + } + } + + // Scan for an existing descriptor set layout that matches the bindings. + for (auto& entry : descriptor_set_layout_cache_) { + if (entry.bindings.size() != bindings.size()) continue; + if (std::memcmp(bindings.data(), entry.bindings.data(), + bindings.size() * sizeof(VkDescriptorSetLayoutBinding)) == + 0) { + return entry.descriptor_set_layout; + } + } + + VkDescriptorSetLayoutCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + if (logical_device_->enabled_extensions().push_descriptors) { + // Note that we can *only* use push descriptor sets if we set this create + // flag. That's fine, though, as the command buffer recording logic always + // prefers the extension if available. + create_info.flags |= + VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; + } + create_info.bindingCount = bindings.size(); + create_info.pBindings = bindings.data(); + + // Create and insert into the cache. + VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(syms()->vkCreateDescriptorSetLayout( + *logical_device_, &create_info, logical_device_->allocator(), + &descriptor_set_layout)); + descriptor_set_layout_cache_.push_back( + {std::move(bindings), descriptor_set_layout}); + return descriptor_set_layout; +} + +void PipelineCache::ClearLayoutCaches() { + absl::MutexLock lock(&mutex_); + for (auto& entry : pipeline_layout_cache_) { + syms()->vkDestroyPipelineLayout(*logical_device_, entry.pipeline_layout, + logical_device_->allocator()); + } + pipeline_layout_cache_.clear(); + for (auto& entry : descriptor_set_layout_cache_) { + syms()->vkDestroyDescriptorSetLayout(*logical_device_, + entry.descriptor_set_layout, + logical_device_->allocator()); + } + descriptor_set_layout_cache_.clear(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/pipeline_cache.h b/hal/vulkan/pipeline_cache.h new file mode 100644 index 0000000..5941ab8 --- /dev/null +++ b/hal/vulkan/pipeline_cache.h
@@ -0,0 +1,85 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_ +#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_ + +#include <vulkan/vulkan.h> + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "hal/executable.h" +#include "hal/executable_cache.h" +#include "hal/vulkan/handle_util.h" +#include "hal/vulkan/pipeline_executable.h" +#include "schemas/spirv_executable_def_generated.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class PipelineCache final : public ExecutableCache { + public: + explicit PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device); + ~PipelineCache() override; + + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + bool CanPrepareFormat(ExecutableFormat format) const override; + + StatusOr<ref_ptr<Executable>> PrepareExecutable( + ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override; + + private: + struct CachedDescriptorSetLayout { + absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings; + VkDescriptorSetLayout descriptor_set_layout; + }; + struct CachedPipelineLayout { + absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts; + absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges; + VkPipelineLayout pipeline_layout; + PipelineDescriptorSets descriptor_sets; + }; + + StatusOr<const CachedPipelineLayout*> LookupOrInsertPipelineLayout( + const VkPipelineLayoutDef& pipeline_layout_def) + ABSL_LOCKS_EXCLUDED(mutex_); + StatusOr<VkDescriptorSetLayout> LookupOrInsertDescriptorSetLayout( + const VkDescriptorSetLayoutDef& descriptor_set_layout_def) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + void ClearLayoutCaches() ABSL_LOCKS_EXCLUDED(mutex_); + + ref_ptr<VkDeviceHandle> logical_device_; + + // A "cache" of descriptor set and pipeline layouts for various values. + // We never evict and just do a simple linear scan on lookup. This is fine for + // now as we only support a single descriptor type and really we only need to + // check for binding count. As we go toward more general usage of descriptors + // (images/etc) we will likely want to change this to a real cache. + absl::Mutex mutex_; + std::vector<CachedDescriptorSetLayout> descriptor_set_layout_cache_ + ABSL_GUARDED_BY(mutex_); + std::vector<CachedPipelineLayout> pipeline_layout_cache_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_PIPELINE_CACHE_H_
diff --git a/hal/vulkan/pipeline_executable.cc b/hal/vulkan/pipeline_executable.cc new file mode 100644 index 0000000..f4f7a61 --- /dev/null +++ b/hal/vulkan/pipeline_executable.cc
@@ -0,0 +1,191 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/pipeline_executable.h" + +#include "absl/container/inlined_vector.h" +#include "base/memory.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +// Generates the baked specialization constant data based on the flatbuffer. +// We only support uint32_t right now so this is easy. +// Note that the returned vectors are referenced by pointers in |out_info| and +// must remain valid until the info is no longer in use. +std::pair<std::vector<VkSpecializationMapEntry>, std::vector<uint8_t>> +PopulateSpecializationInfo(const VkSpecializationInfoDef* info_def) { + int entry_count = + info_def && info_def->map_entries() ? info_def->map_entries()->size() : 0; + if (!entry_count) { + return {}; + } + + std::vector<VkSpecializationMapEntry> entries; + entries.reserve(entry_count); + std::vector<uint8_t> data; + data.resize(entry_count * sizeof(uint32_t)); + + uint32_t offset = 0; + for (const auto* entry_def : *info_def->map_entries()) { + if (!entry_def) continue; + entries.push_back({}); + auto& entry = entries.back(); + entry.constantID = entry_def->constant_id(); + entry.offset = offset; + entry.size = sizeof(uint32_t); + uint32_t value = entry_def->uint32_value(); + std::memcpy(data.data() + offset, &value, sizeof(value)); + offset += entry.size; + } + + return {std::move(entries), std::move(data)}; +} + +} // namespace + +// static +StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create( + const ref_ptr<VkDeviceHandle>& logical_device, + VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout, + PipelineDescriptorSets descriptor_sets, ExecutableCachingModeBitfield mode, + const SpirVExecutableDef& spirv_executable_def) { + IREE_TRACE_SCOPE0("PipelineExecutable::Create"); + const auto& syms = logical_device->syms(); + if (!spirv_executable_def.entry_points() || + spirv_executable_def.entry_points()->size() == 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined"; + } + if (!spirv_executable_def.code()) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "No SPIR-V code present"; + } + const auto& code = *spirv_executable_def.code(); + + // Create the shader module. + VkShaderModuleCreateInfo shader_module_create_info; + shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shader_module_create_info.pNext = nullptr; + shader_module_create_info.flags = 0; + shader_module_create_info.codeSize = code.size() * sizeof(uint32_t); + shader_module_create_info.pCode = code.data(); + VkShaderModule shader_module = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR( + syms->vkCreateShaderModule(*logical_device, &shader_module_create_info, + logical_device->allocator(), &shader_module)); + + // We only need to keep this around during pipeline creation so ensure we + // always clean it up when we exit this function. + auto shader_module_cleanup = MakeCleanup([&logical_device, shader_module]() { + logical_device->syms()->vkDestroyShaderModule( + *logical_device, shader_module, logical_device->allocator()); + }); + + // Specialization info is currently constant against all entry points. + std::vector<VkSpecializationMapEntry> spec_entries; + std::vector<uint8_t> spec_data; + std::tie(spec_entries, spec_data) = + PopulateSpecializationInfo(spirv_executable_def.specialization_info()); + VkSpecializationInfo specialization_info; + specialization_info.mapEntryCount = spec_entries.size(); + specialization_info.pMapEntries = spec_entries.data(); + specialization_info.dataSize = spec_data.size(); + specialization_info.pData = spec_data.data(); + + // Create pipelines for each entry point. + const auto& entry_points = *spirv_executable_def.entry_points(); + absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos; + pipeline_create_infos.resize(entry_points.size()); + for (int entry_ordinal = 0; entry_ordinal < entry_points.size(); + ++entry_ordinal) { + auto& pipeline_create_info = pipeline_create_infos[entry_ordinal]; + pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeline_create_info.pNext = nullptr; + pipeline_create_info.flags = 0; + if (!AllBitsSet(mode, ExecutableCachingMode::kAllowOptimization)) { + pipeline_create_info.flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; + } + if (entry_ordinal == 0) { + pipeline_create_info.flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; + } else { + pipeline_create_info.flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; + } + pipeline_create_info.layout = pipeline_layout; + pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE; + pipeline_create_info.basePipelineIndex = 0; + auto& stage_create_info = pipeline_create_info.stage; + stage_create_info.sType = + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + stage_create_info.pNext = nullptr; + stage_create_info.flags = 0; + stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT; + stage_create_info.module = shader_module; + stage_create_info.pName = entry_points[entry_ordinal]->c_str(); + stage_create_info.pSpecializationInfo = &specialization_info; + } + absl::InlinedVector<VkPipeline, 1> pipelines; + pipelines.resize(entry_points.size()); + + // Some ICDs appear to leak in here, out of our control. + // Warning: leak checks remain disabled if an error is returned. + IREE_DISABLE_LEAK_CHECKS(); + VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines( + *logical_device, pipeline_cache, pipeline_create_infos.size(), + pipeline_create_infos.data(), logical_device->allocator(), + pipelines.data())); + IREE_ENABLE_LEAK_CHECKS(); + + auto executable = + make_ref<PipelineExecutable>(CtorKey{}, logical_device, pipeline_layout, + descriptor_sets, std::move(pipelines)); + executable->tag_ = + spirv_executable_def.tag() ? spirv_executable_def.tag()->str() : ""; + return executable; +} + +PipelineExecutable::PipelineExecutable( + CtorKey ctor_key, const ref_ptr<VkDeviceHandle>& logical_device, + VkPipelineLayout pipeline_layout, PipelineDescriptorSets descriptor_sets, + absl::InlinedVector<VkPipeline, 1> pipelines) + : logical_device_(add_ref(logical_device)), + pipeline_layout_(pipeline_layout), + descriptor_sets_(descriptor_sets), + pipelines_(std::move(pipelines)) {} + +PipelineExecutable::~PipelineExecutable() { + IREE_TRACE_SCOPE0("PipelineExecutable::dtor"); + for (auto pipeline : pipelines_) { + syms()->vkDestroyPipeline(*logical_device_, pipeline, + logical_device_->allocator()); + } + pipelines_.clear(); +} + +StatusOr<VkPipeline> PipelineExecutable::GetPipelineForEntryPoint( + int entry_ordinal) const { + if (entry_ordinal < 0 || entry_ordinal >= pipelines_.size()) { + return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal"; + } + return pipelines_[entry_ordinal]; +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/pipeline_executable.h b/hal/vulkan/pipeline_executable.h new file mode 100644 index 0000000..51fe650 --- /dev/null +++ b/hal/vulkan/pipeline_executable.h
@@ -0,0 +1,91 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ +#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ + +#include <vulkan/vulkan.h> + +#include <vector> + +#include "absl/container/inlined_vector.h" +#include "base/status.h" +#include "hal/executable.h" +#include "hal/executable_cache.h" +#include "hal/executable_spec.h" +#include "hal/vulkan/handle_util.h" +#include "schemas/spirv_executable_def_generated.h" + +namespace iree { +namespace hal { +namespace vulkan { + +struct PipelineDescriptorSets { + uint32_t buffer_binding_set; + VkDescriptorSetLayout buffer_binding_set_layout; + absl::InlinedVector<uint32_t, 8> buffer_binding_set_map; +}; + +class PipelineExecutable final : public Executable { + public: + static StatusOr<ref_ptr<PipelineExecutable>> Create( + const ref_ptr<VkDeviceHandle>& logical_device, + VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout, + PipelineDescriptorSets descriptor_sets, + ExecutableCachingModeBitfield mode, + const SpirVExecutableDef& spirv_executable_def); + + // Private constructor. + struct CtorKey { + private: + friend class PipelineExecutable; + CtorKey() = default; + }; + PipelineExecutable(CtorKey ctor_key, + const ref_ptr<VkDeviceHandle>& logical_device, + VkPipelineLayout pipeline_layout, + PipelineDescriptorSets descriptor_sets, + absl::InlinedVector<VkPipeline, 1> pipelines); + ~PipelineExecutable() override; + + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + bool supports_debugging() const override { return false; } + + VkPipelineLayout pipeline_layout() const { return pipeline_layout_; } + const PipelineDescriptorSets& descriptor_sets() const { + return descriptor_sets_; + } + + bool is_matmul() const { return tag_ == "__matmul__"; } + + StatusOr<VkPipeline> GetPipelineForEntryPoint(int entry_ordinal) const; + + private: + ref_ptr<VkDeviceHandle> logical_device_; + VkPipelineLayout pipeline_layout_; + PipelineDescriptorSets descriptor_sets_; + std::string tag_; + + // One pipeline per entry point. + absl::InlinedVector<VkPipeline, 1> pipelines_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
diff --git a/hal/vulkan/status_util.cc b/hal/vulkan/status_util.cc new file mode 100644 index 0000000..b9ed027 --- /dev/null +++ b/hal/vulkan/status_util.cc
@@ -0,0 +1,231 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/status_util.h" + +#include "base/status.h" + +namespace iree { +namespace hal { +namespace vulkan { + +Status VkResultToStatus(VkResult result) { + switch (result) { + // Success codes. + case VK_SUCCESS: + // Command successfully completed. + return OkStatus(); + case VK_NOT_READY: + // A fence or query has not yet completed. + return OkStatus(); + case VK_TIMEOUT: + // A wait operation has not completed in the specified time. + return OkStatus(); + case VK_EVENT_SET: + // An event is signaled. + return OkStatus(); + case VK_EVENT_RESET: + // An event is unsignaled. + return OkStatus(); + case VK_INCOMPLETE: + // A return array was too small for the result. + return OkStatus(); + case VK_SUBOPTIMAL_KHR: + // A swapchain no longer matches the surface properties exactly, but can + // still be used to present to the surface successfully. + return OkStatus(); + + // Error codes. + case VK_ERROR_OUT_OF_HOST_MEMORY: + // A host memory allocation has failed. + return ResourceExhaustedError("VK_ERROR_OUT_OF_HOST_MEMORY"); + case VK_ERROR_OUT_OF_DEVICE_MEMORY: + // A device memory allocation has failed. + return ResourceExhaustedError("VK_ERROR_OUT_OF_DEVICE_MEMORY"); + case VK_ERROR_INITIALIZATION_FAILED: + // Initialization of an object could not be completed for + // implementation-specific reasons. + return InternalError("VK_ERROR_INITIALIZATION_FAILED"); + case VK_ERROR_DEVICE_LOST: + // The logical or physical device has been lost. + // + // A logical device may become lost for a number of + // implementation-specific reasons, indicating that pending and future + // command execution may fail and cause resources and backing memory to + // become undefined. + // + // Typical reasons for device loss will include things like execution + // timing out (to prevent denial of service), power management events, + // platform resource management, or implementation errors. + // + // When this happens, certain commands will return + // VK_ERROR_DEVICE_LOST (see Error Codes for a list of such + // commands). After any such event, the logical device is considered lost. + // It is not possible to reset the logical device to a non-lost state, + // however the lost state is specific to a logical device (VkDevice), and + // the corresponding physical device (VkPhysicalDevice) may be otherwise + // unaffected. + // + // In some cases, the physical device may also be lost, and attempting to + // create a new logical device will fail, returning VK_ERROR_DEVICE_LOST. + // This is usually indicative of a problem with the underlying + // implementation, or its connection to the host. If the physical device + // has not been lost, and a new logical device is successfully created + // from that physical device, it must be in the non-lost state. + // + // Whilst logical device loss may be recoverable, in the case of physical + // device loss, it is unlikely that an application will be able to recover + // unless additional, unaffected physical devices exist on the system. The + // error is largely informational and intended only to inform the user + // that a platform issue has occurred, and should be investigated further. + // For example, underlying hardware may have developed a fault or become + // physically disconnected from the rest of the system. In many cases, + // physical device loss may cause other more serious issues such as the + // operating system crashing; in which case it may not be reported via the + // Vulkan API. + // + // Undefined behavior caused by an application error may cause a device to + // become lost. However, such undefined behavior may also cause + // unrecoverable damage to the process, and it is then not guaranteed that + // the API objects, including the VkPhysicalDevice or the VkInstance are + // still valid or that the error is recoverable. + // + // When a device is lost, its child objects are not implicitly destroyed + // and their handles are still valid. Those objects must still be + // destroyed before their parents or the device can be destroyed (see the + // Object Lifetime section). The host address space corresponding to + // device memory mapped using vkMapMemory is still valid, and host memory + // accesses to these mapped regions are still valid, but the contents are + // undefined. It is still legal to call any API command on the device and + // child objects. + // + // Once a device is lost, command execution may fail, and commands that + // return a VkResult may return VK_ERROR_DEVICE_LOST. + // Commands that do not allow run-time errors must still operate correctly + // for valid usage and, if applicable, return valid data. + // + // Commands that wait indefinitely for device execution (namely + // vkDeviceWaitIdle, vkQueueWaitIdle, vkWaitForFences with a maximum + // timeout, and vkGetQueryPoolResults with the VK_QUERY_RESULT_WAIT_BIT + // bit set in flags) must return in finite time even in the case + // of a lost device, and return either VK_SUCCESS or + // VK_ERROR_DEVICE_LOST. For any command that may return + // VK_ERROR_DEVICE_LOST, for the purpose of determining whether a + // command buffer is in the pending state, or whether resources are + // considered in-use by the device, a return value of + // VK_ERROR_DEVICE_LOST is equivalent to VK_SUCCESS. + return InternalError("VK_ERROR_DEVICE_LOST"); + case VK_ERROR_MEMORY_MAP_FAILED: + // Mapping of a memory object has failed. + return InternalError("VK_ERROR_MEMORY_MAP_FAILED"); + case VK_ERROR_LAYER_NOT_PRESENT: + // A requested layer is not present or could not be loaded. + return UnimplementedError("VK_ERROR_LAYER_NOT_PRESENT"); + case VK_ERROR_EXTENSION_NOT_PRESENT: + // A requested extension is not supported. + return UnimplementedError("VK_ERROR_EXTENSION_NOT_PRESENT"); + case VK_ERROR_FEATURE_NOT_PRESENT: + // A requested feature is not supported. + return UnimplementedError("VK_ERROR_FEATURE_NOT_PRESENT"); + case VK_ERROR_INCOMPATIBLE_DRIVER: + // The requested version of Vulkan is not supported by the driver or is + // otherwise incompatible for implementation-specific reasons. + return FailedPreconditionError("VK_ERROR_INCOMPATIBLE_DRIVER"); + case VK_ERROR_TOO_MANY_OBJECTS: + // Too many objects of the type have already been created. + return ResourceExhaustedError("VK_ERROR_TOO_MANY_OBJECTS"); + case VK_ERROR_FORMAT_NOT_SUPPORTED: + // A requested format is not supported on this device. + return UnimplementedError("VK_ERROR_FORMAT_NOT_SUPPORTED"); + case VK_ERROR_FRAGMENTED_POOL: + // A pool allocation has failed due to fragmentation of the pool’s memory. + // This must only be returned if no attempt to allocate host or device + // memory was made to accommodate the new allocation. + return ResourceExhaustedError("VK_ERROR_FRAGMENTED_POOL"); + case VK_ERROR_OUT_OF_POOL_MEMORY: + // A pool memory allocation has failed. This must only be returned if no + // attempt to allocate host or device memory was made to accommodate the + // new allocation. If the failure was definitely due to fragmentation of + // the pool, VK_ERROR_FRAGMENTED_POOL should be returned instead. + return ResourceExhaustedError("VK_ERROR_OUT_OF_POOL_MEMORY"); + case VK_ERROR_INVALID_EXTERNAL_HANDLE: + // An external handle is not a valid handle of the specified type. + return InvalidArgumentError("VK_ERROR_INVALID_EXTERNAL_HANDLE"); + case VK_ERROR_SURFACE_LOST_KHR: + // A surface is no longer available. + return UnavailableError("VK_ERROR_SURFACE_LOST_KHR"); + case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR: + // The requested window is already in use by Vulkan or another API in a + // manner which prevents it from being used again. + return InvalidArgumentError("VK_ERROR_NATIVE_WINDOW_IN_USE_KHR"); + case VK_ERROR_OUT_OF_DATE_KHR: + // A surface has changed in such a way that it is no longer compatible + // with the swapchain, and further presentation requests using the + // swapchain will fail. Applications must query the new surface properties + // and recreate their swapchain if they wish to continue presenting to the + // surface. + return FailedPreconditionError("VK_ERROR_OUT_OF_DATE_KHR"); + case VK_ERROR_INCOMPATIBLE_DISPLAY_KHR: + // The display used by a swapchain does not use the same presentable image + // layout, or is incompatible in a way that prevents sharing an image. + return InvalidArgumentError("VK_ERROR_INCOMPATIBLE_DISPLAY_KHR"); + case VK_ERROR_VALIDATION_FAILED_EXT: + // Validation layer testing failed. It is not expected that an + // application would see this this error code during normal use of the + // validation layers. + return InvalidArgumentError("VK_ERROR_VALIDATION_FAILED_EXT"); + case VK_ERROR_INVALID_SHADER_NV: + // One or more shaders failed to compile or link. More details are + // reported back to the application when the validation layer is enabled + // using the extension VK_EXT_debug_report. + return InvalidArgumentError("VK_ERROR_INVALID_SHADER_NV"); + case VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT: + // When creating an image with + // VkImageDrmFormatModifierExplicitCreateInfoEXT, it is the application’s + // responsibility to satisfy all Valid Usage requirements. However, the + // implementation must validate that the provided pPlaneLayouts, when + // combined with the provided drmFormatModifier and other creation + // parameters in VkImageCreateInfo and its pNext chain, produce a valid + // image. (This validation is necessarily implementation-dependent and + // outside the scope of Vulkan, and therefore not described by Valid Usage + // requirements). If this validation fails, then vkCreateImage returns + // VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT. + return InvalidArgumentError( + "VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT"); + case VK_ERROR_FRAGMENTATION_EXT: + // A descriptor pool creation has failed due to fragmentation. + return ResourceExhaustedError("VK_ERROR_FRAGMENTATION_EXT"); + case VK_ERROR_NOT_PERMITTED_EXT: + // When creating a queue, the caller does not have sufficient privileges + // to request to acquire a priority above the default priority + // (VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_EXT). + return PermissionDeniedError("VK_ERROR_NOT_PERMITTED_EXT"); + case VK_ERROR_INVALID_DEVICE_ADDRESS_EXT: + // A buffer creation failed because the requested address is not + // available. + return OutOfRangeError("VK_ERROR_INVALID_DEVICE_ADDRESS_EXT"); + case VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT: + // An operation on a swapchain created with + // VK_FULL_SCREEN_EXCLUSIVE_APPLICATION_CONTROLLED_EXT failed as it did + // not have exlusive full-screen access. This may occur due to + // implementation-dependent reasons, outside of the application’s control. + return UnavailableError("VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT"); + default: + return UnknownError(std::to_string(result)); + } +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/status_util.h b/hal/vulkan/status_util.h new file mode 100644 index 0000000..8ff07f5 --- /dev/null +++ b/hal/vulkan/status_util.h
@@ -0,0 +1,87 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_STATUS_UTIL_H_ +#define IREE_HAL_VULKAN_STATUS_UTIL_H_ + +#include <vulkan/vulkan.h> + +#include "base/status.h" + +namespace iree { +namespace hal { +namespace vulkan { + +// RETURN_IF_ERROR but implicitly converts the VkResult return value to +// a Status. +// +// Usage: +// VK_RETURN_IF_ERROR(vkDoThing(...)); +#define VK_RETURN_IF_ERROR(expr) \ + RETURN_IF_ERROR(::iree::hal::vulkan::VkResultToStatus(expr)) + +// CHECK_OK but implicitly converts the VkResults return value to a +// Status and checks that it is OkStatus. +// +// Usage: +// VK_CHECK_OK(vkDoThing(...)); +#define VK_CHECK_OK(expr) CHECK_OK(::iree::hal::vulkan::VkResultToStatus(expr)) + +// Converts a VkResult to a Status object. +// +// Vulkan considers the following as "success codes" and users should ensure +// they first check the result prior to converting: +// +// - VK_SUCCESS -> OkStatus() +// - VK_NOT_READY -> OkStatus() +// - VK_TIMEOUT -> OkStatus() +// - VK_EVENT_SET -> OkStatus() +// - VK_EVENT_RESET -> OkStatus() +// - VK_INCOMPLETE -> OkStatus() +// - VK_SUBOPTIMAL_KHR -> OkStatus() +// +// The rest are considered as "error codes": +// +// - VK_ERROR_OUT_OF_HOST_MEMORY -> ResourceExhaustedError("VK...") +// - VK_ERROR_OUT_OF_DEVICE_MEMORY -> ResourceExhaustedError("VK...") +// - VK_ERROR_INITIALIZATION_FAILED -> InternalError("VK...") +// - VK_ERROR_DEVICE_LOST -> InternalError("VK...") +// - VK_ERROR_MEMORY_MAP_FAILED -> InternalError("VK...") +// - VK_ERROR_LAYER_NOT_PRESENT -> NotFoundError("VK...") +// - VK_ERROR_EXTENSION_NOT_PRESENT -> NotFoundError("VK...") +// - VK_ERROR_FEATURE_NOT_PRESENT -> NotFoundError("VK...") +// - VK_ERROR_INCOMPATIBLE_DRIVER -> FailedPreconditionError("VK...") +// - VK_ERROR_TOO_MANY_OBJECTS -> ResourceExhaustedError("VK...") +// - VK_ERROR_FORMAT_NOT_SUPPORTED -> UnimplementedError("VK...") +// - VK_ERROR_FRAGMENTED_POOL -> ResourceExhaustedError("VK...") +// - VK_ERROR_OUT_OF_POOL_MEMORY -> ResourceExhaustedError("VK...") +// - VK_ERROR_INVALID_EXTERNAL_HANDLE -> InvalidArgumentError("VK...") +// - VK_ERROR_SURFACE_LOST_KHR -> InternalError("VK...") +// - VK_ERROR_NATIVE_WINDOW_IN_USE_KHR -> InternalError("VK...") +// - VK_ERROR_OUT_OF_DATE_KHR -> InternalError("VK...") +// - VK_ERROR_INCOMPATIBLE_DISPLAY_KHR -> InternalError("VK...") +// - VK_ERROR_VALIDATION_FAILED_EXT -> InternalError("VK...") +// - VK_ERROR_INVALID_SHADER_NV -> InternalError("VK...") +// - VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT -> InternalError +// - VK_ERROR_FRAGMENTATION_EXT -> ResourceExhaustedError("VK...") +// - VK_ERROR_NOT_PERMITTED_EXT -> PermissionDeniedError("VK...") +// - VK_ERROR_INVALID_DEVICE_ADDRESS_EXT -> OutOfRangeError("VK...") +// - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT -> InternalError("VK...") +Status VkResultToStatus(VkResult result); + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_STATUS_UTIL_H_
diff --git a/hal/vulkan/vma_allocator.cc b/hal/vulkan/vma_allocator.cc new file mode 100644 index 0000000..a40dc00 --- /dev/null +++ b/hal/vulkan/vma_allocator.cc
@@ -0,0 +1,248 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/vma_allocator.h" + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/buffer.h" +#include "hal/vulkan/status_util.h" +#include "hal/vulkan/vma_buffer.h" + +#if VMA_RECORDING_ENABLED +ABSL_FLAG(std::string, vma_recording_file, "", + "File path to write a CSV containing the VMA recording."); +ABSL_FLAG(bool, vma_recording_flush_after_call, false, + "Flush the VMA recording file after every call (useful if " + "crashing/not exiting cleanly)."); +#endif // VMA_RECORDING_ENABLED + +namespace iree { +namespace hal { +namespace vulkan { + +// static +StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create( + VkPhysicalDevice physical_device, + const ref_ptr<VkDeviceHandle>& logical_device) { + IREE_TRACE_SCOPE0("VmaAllocator::Create"); + + const auto& syms = logical_device->syms(); + VmaVulkanFunctions vulkan_fns; + vulkan_fns.vkGetPhysicalDeviceProperties = + syms->vkGetPhysicalDeviceProperties; + vulkan_fns.vkGetPhysicalDeviceMemoryProperties = + syms->vkGetPhysicalDeviceMemoryProperties; + vulkan_fns.vkAllocateMemory = syms->vkAllocateMemory; + vulkan_fns.vkFreeMemory = syms->vkFreeMemory; + vulkan_fns.vkMapMemory = syms->vkMapMemory; + vulkan_fns.vkUnmapMemory = syms->vkUnmapMemory; + vulkan_fns.vkFlushMappedMemoryRanges = syms->vkFlushMappedMemoryRanges; + vulkan_fns.vkInvalidateMappedMemoryRanges = + syms->vkInvalidateMappedMemoryRanges; + vulkan_fns.vkBindBufferMemory = syms->vkBindBufferMemory; + vulkan_fns.vkBindImageMemory = syms->vkBindImageMemory; + vulkan_fns.vkGetBufferMemoryRequirements = + syms->vkGetBufferMemoryRequirements; + vulkan_fns.vkGetImageMemoryRequirements = syms->vkGetImageMemoryRequirements; + vulkan_fns.vkCreateBuffer = syms->vkCreateBuffer; + vulkan_fns.vkDestroyBuffer = syms->vkDestroyBuffer; + vulkan_fns.vkCreateImage = syms->vkCreateImage; + vulkan_fns.vkDestroyImage = syms->vkDestroyImage; + vulkan_fns.vkCmdCopyBuffer = syms->vkCmdCopyBuffer; + + VmaRecordSettings record_settings; +#if VMA_RECORDING_ENABLED + record_settings.flags = absl::GetFlag(FLAGS_vma_recording_flush_after_call) + ? VMA_RECORD_FLUSH_AFTER_CALL_BIT + : 0; + record_settings.pFilePath = absl::GetFlag(FLAGS_vma_recording_file).c_str(); +#else + record_settings.flags = 0; + record_settings.pFilePath = nullptr; +#endif // VMA_RECORDING_ENABLED + + VmaAllocatorCreateInfo create_info; + create_info.flags = 0; + create_info.physicalDevice = physical_device; + create_info.device = *logical_device; + create_info.preferredLargeHeapBlockSize = 64 * 1024 * 1024; + create_info.pAllocationCallbacks = logical_device->allocator(); + create_info.pDeviceMemoryCallbacks = nullptr; + create_info.frameInUseCount = 0; + create_info.pHeapSizeLimit = nullptr; + create_info.pVulkanFunctions = &vulkan_fns; + create_info.pRecordSettings = &record_settings; + ::VmaAllocator vma = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(vmaCreateAllocator(&create_info, &vma)); + + auto allocator = + absl::WrapUnique(new VmaAllocator(physical_device, logical_device, vma)); + // TODO(benvanik): query memory properties/types. + return allocator; +} + +VmaAllocator::VmaAllocator(VkPhysicalDevice physical_device, + const ref_ptr<VkDeviceHandle>& logical_device, + ::VmaAllocator vma) + : physical_device_(physical_device), + logical_device_(add_ref(logical_device)), + vma_(vma) {} + +VmaAllocator::~VmaAllocator() { + IREE_TRACE_SCOPE0("VmaAllocator::dtor"); + vmaDestroyAllocator(vma_); +} + +bool VmaAllocator::CanUseBufferLike(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const { + // TODO(benvanik): ensure there is a memory type that can satisfy the request. + return source_allocator == this; +} + +bool VmaAllocator::CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const { + // TODO(benvnik): ensure there is a memory type that can satisfy the request. + return true; +} + +Status VmaAllocator::MakeCompatible(MemoryTypeBitfield* memory_type, + BufferUsageBitfield* buffer_usage) const { + // TODO(benvanik): mutate to match supported memory types. + return OkStatus(); +} + +StatusOr<ref_ptr<VmaBuffer>> VmaAllocator::AllocateInternal( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + MemoryAccessBitfield allowed_access, size_t allocation_size, + VmaAllocationCreateFlags flags) { + IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal"); + + VkBufferCreateInfo buffer_create_info; + buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + buffer_create_info.pNext = nullptr; + buffer_create_info.flags = 0; + buffer_create_info.size = allocation_size; + buffer_create_info.usage = 0; + if (AllBitsSet(buffer_usage, BufferUsage::kTransfer)) { + buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT; + } + if (AllBitsSet(buffer_usage, BufferUsage::kDispatch)) { + buffer_create_info.usage |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; + buffer_create_info.usage |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + buffer_create_info.usage |= VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT; + } + buffer_create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + buffer_create_info.queueFamilyIndexCount = 0; + buffer_create_info.pQueueFamilyIndices = nullptr; + + VmaAllocationCreateInfo allocation_create_info; + allocation_create_info.flags = flags; + allocation_create_info.usage = VMA_MEMORY_USAGE_UNKNOWN; + allocation_create_info.requiredFlags = 0; + allocation_create_info.preferredFlags = 0; + allocation_create_info.memoryTypeBits = 0; // Automatic selection. + allocation_create_info.pool = VK_NULL_HANDLE; + allocation_create_info.pUserData = nullptr; + if (AllBitsSet(memory_type, MemoryType::kDeviceLocal)) { + if (AllBitsSet(memory_type, MemoryType::kHostVisible)) { + // Device-local, host-visible. + allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_TO_GPU; + allocation_create_info.preferredFlags |= + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } else { + // Device-local only. + allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; + allocation_create_info.requiredFlags |= + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + } else { + if (AllBitsSet(memory_type, MemoryType::kDeviceVisible)) { + // Host-local, device-visible. + allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU; + } else { + // Host-local only. + allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_ONLY; + } + } + if (AllBitsSet(memory_type, MemoryType::kHostCached)) { + allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + } + if (AllBitsSet(memory_type, MemoryType::kHostCoherent)) { + allocation_create_info.requiredFlags |= + VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + } + if (AllBitsSet(memory_type, MemoryType::kTransient)) { + allocation_create_info.preferredFlags |= + VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; + } + if (AllBitsSet(buffer_usage, BufferUsage::kMapping)) { + allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + } + + VkBuffer buffer = VK_NULL_HANDLE; + VmaAllocation allocation = VK_NULL_HANDLE; + VmaAllocationInfo allocation_info; + VK_RETURN_IF_ERROR(vmaCreateBuffer(vma_, &buffer_create_info, + &allocation_create_info, &buffer, + &allocation, &allocation_info)); + + return make_ref<VmaBuffer>(this, memory_type, allowed_access, buffer_usage, + allocation_size, 0, allocation_size, buffer, + allocation, allocation_info); +} + +StatusOr<ref_ptr<Buffer>> VmaAllocator::Allocate( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + size_t allocation_size) { + IREE_TRACE_SCOPE0("VmaAllocator::Allocate"); + return AllocateInternal(memory_type, buffer_usage, MemoryAccess::kAll, + allocation_size, /*flags=*/0); +} + +StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateConstant( + BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) { + IREE_TRACE_SCOPE0("VmaAllocator::AllocateConstant"); + // TODO(benvanik): import memory to avoid the copy. + ASSIGN_OR_RETURN( + auto buffer, + AllocateInternal(MemoryType::kDeviceLocal | MemoryType::kHostVisible, + buffer_usage, + MemoryAccess::kRead | MemoryAccess::kDiscardWrite, + source_buffer->byte_length(), + /*flags=*/0)); + RETURN_IF_ERROR(buffer->CopyData(0, source_buffer.get(), 0, kWholeBuffer)); + buffer->set_allowed_access(MemoryAccess::kRead); + return buffer; +} + +StatusOr<ref_ptr<Buffer>> VmaAllocator::WrapMutable( + MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, void* data, size_t data_length) { + IREE_TRACE_SCOPE0("VmaAllocator::WrapMutable"); + // TODO(benvanik): import memory. + return UnimplementedErrorBuilder(IREE_LOC) + << "Wrapping host memory is not yet implemented"; +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/vma_allocator.h b/hal/vulkan/vma_allocator.h new file mode 100644 index 0000000..ab908b3 --- /dev/null +++ b/hal/vulkan/vma_allocator.h
@@ -0,0 +1,110 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ +#define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ + +#include <vulkan/vulkan.h> + +#include <memory> + +#include "base/status.h" +#include "hal/allocator.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/handle_util.h" +#include "hal/vulkan/internal_vk_mem_alloc.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class VmaBuffer; + +// A HAL allocator using the Vulkan Memory Allocator (VMA) to manage memory. +// VMA (//third_party/vulkan_memory_allocator) provides dlmalloc-like behavior +// with suballocations made with various policies (best fit, first fit, etc). +// This reduces the number of allocations we need from the Vulkan implementation +// (which can sometimes be limited to as little as 4096 total allowed) and +// manages higher level allocation semantics like slab allocation and +// defragmentation. +// +// VMA is internally synchronized and the functionality exposed on the HAL +// interface is thread-safe. +// +// More information: +// https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator +// https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/ +class VmaAllocator final : public Allocator { + public: + static StatusOr<std::unique_ptr<VmaAllocator>> Create( + VkPhysicalDevice physical_device, + const ref_ptr<VkDeviceHandle>& logical_device); + + ~VmaAllocator() override; + + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + ::VmaAllocator vma() const { return vma_; } + + bool CanUseBufferLike(Allocator* source_allocator, + MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + BufferUsageBitfield intended_usage) const override; + + bool CanAllocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) const override; + + Status MakeCompatible(MemoryTypeBitfield* memory_type, + BufferUsageBitfield* buffer_usage) const override; + + StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, + BufferUsageBitfield buffer_usage, + size_t allocation_size) override; + + StatusOr<ref_ptr<Buffer>> AllocateConstant( + BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) override; + + StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield buffer_usage, + void* data, + size_t data_length) override; + + private: + VmaAllocator(VkPhysicalDevice physical_device, + const ref_ptr<VkDeviceHandle>& logical_device, + ::VmaAllocator vma); + + StatusOr<ref_ptr<VmaBuffer>> AllocateInternal( + MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, + MemoryAccessBitfield allowed_access, size_t allocation_size, + VmaAllocationCreateFlags flags); + + VkPhysicalDevice physical_device_; + ref_ptr<VkDeviceHandle> logical_device_; + + // Internally synchronized. We could externally synchronize if we thought it + // was worth it, however I'm not sure we'd be able to do much better with the + // current Allocator API. + ::VmaAllocator vma_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
diff --git a/hal/vulkan/vma_buffer.cc b/hal/vulkan/vma_buffer.cc new file mode 100644 index 0000000..b7d3b6b --- /dev/null +++ b/hal/vulkan/vma_buffer.cc
@@ -0,0 +1,163 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/vma_buffer.h" + +#include "base/source_location.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/vulkan/status_util.h" +#include "hal/vulkan/vma_allocator.h" + +namespace iree { +namespace hal { +namespace vulkan { + +VmaBuffer::VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, + BufferUsageBitfield usage, device_size_t allocation_size, + device_size_t byte_offset, device_size_t byte_length, + VkBuffer buffer, VmaAllocation allocation, + VmaAllocationInfo allocation_info) + : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, + byte_offset, byte_length), + vma_(allocator->vma()), + buffer_(buffer), + allocation_(allocation), + allocation_info_(allocation_info) { + // TODO(benvanik): set debug name instead and use the + // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag. + vmaSetAllocationUserData(vma_, allocation_, this); +} + +VmaBuffer::~VmaBuffer() { + IREE_TRACE_SCOPE0("VmaBuffer::dtor"); + vmaDestroyBuffer(vma_, buffer_, allocation_); +} + +Status VmaBuffer::FillImpl(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length) { + ASSIGN_OR_RETURN(auto mapping, MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, + byte_offset, byte_length)); + void* data_ptr = static_cast<void*>(mapping.mutable_data()); + switch (pattern_length) { + case 1: { + uint8_t* data = static_cast<uint8_t*>(data_ptr); + uint8_t value_bits = *static_cast<const uint8_t*>(pattern); + std::fill_n(data + byte_offset, byte_length, value_bits); + break; + } + case 2: { + uint16_t* data = static_cast<uint16_t*>(data_ptr); + uint16_t value_bits = *static_cast<const uint16_t*>(pattern); + std::fill_n(data + byte_offset / sizeof(uint16_t), + byte_length / sizeof(uint16_t), value_bits); + break; + } + case 4: { + uint32_t* data = static_cast<uint32_t*>(data_ptr); + uint32_t value_bits = *static_cast<const uint32_t*>(pattern); + std::fill_n(data + byte_offset / sizeof(uint32_t), + byte_length / sizeof(uint32_t), value_bits); + break; + } + default: + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unsupported scalar data size: " << pattern_length; + } + return OkStatus(); +} + +Status VmaBuffer::ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) { + ASSIGN_OR_RETURN( + auto mapping, + MapMemory<uint8_t>(MemoryAccess::kRead, source_offset, data_length)); + std::memcpy(data, mapping.data(), mapping.byte_length()); + return OkStatus(); +} + +Status VmaBuffer::WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) { + ASSIGN_OR_RETURN(auto mapping, + MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, + target_offset, data_length)); + std::memcpy(mapping.mutable_data(), data, mapping.byte_length()); + return OkStatus(); +} + +Status VmaBuffer::CopyDataImpl(device_size_t target_offset, + Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) { + // This is pretty terrible. Let's not do this. + // TODO(benvanik): a way for allocators to indicate transfer compat. + ASSIGN_OR_RETURN(auto source_mapping, + source_buffer->MapMemory<uint8_t>( + MemoryAccess::kRead, source_offset, data_length)); + CHECK_EQ(data_length, source_mapping.size()); + ASSIGN_OR_RETURN(auto target_mapping, + MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, + target_offset, data_length)); + CHECK_EQ(data_length, target_mapping.size()); + std::memcpy(target_mapping.mutable_data() + target_offset, + source_mapping.data(), data_length); + return OkStatus(); +} + +Status VmaBuffer::MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) { + uint8_t* data_ptr = nullptr; + VK_RETURN_IF_ERROR( + vmaMapMemory(vma_, allocation_, reinterpret_cast<void**>(&data_ptr))); + *out_data = data_ptr + local_byte_offset; + + // If we mapped for discard scribble over the bytes. This is not a mandated + // behavior but it will make debugging issues easier. Alternatively for + // heap buffers we could reallocate them such that ASAN yells, but that + // would only work if the entire buffer was discarded. +#ifndef NDEBUG + if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) { + std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + } +#endif // !NDEBUG + + return OkStatus(); +} + +Status VmaBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data) { + vmaUnmapMemory(vma_, allocation_); + return OkStatus(); +} + +Status VmaBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) { + vmaInvalidateAllocation(vma_, allocation_, local_byte_offset, + local_byte_length); + return OkStatus(); +} + +Status VmaBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) { + vmaFlushAllocation(vma_, allocation_, local_byte_offset, local_byte_length); + return OkStatus(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/vma_buffer.h b/hal/vulkan/vma_buffer.h new file mode 100644 index 0000000..5642c87 --- /dev/null +++ b/hal/vulkan/vma_buffer.h
@@ -0,0 +1,79 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_ +#define IREE_HAL_VULKAN_VMA_BUFFER_H_ + +#include <vulkan/vulkan.h> + +#include "hal/buffer.h" +#include "vk_mem_alloc.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class VmaAllocator; + +// A buffer implementation representing an allocation made from within a pool of +// a Vulkan Memory Allocator instance. See VmaAllocator for more information. +class VmaBuffer final : public Buffer { + public: + VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type, + MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, + device_size_t allocation_size, device_size_t byte_offset, + device_size_t byte_length, VkBuffer buffer, + VmaAllocation allocation, VmaAllocationInfo allocation_info); + ~VmaBuffer() override; + + VkBuffer handle() const { return buffer_; } + VmaAllocation allocation() const { return allocation_; } + const VmaAllocationInfo& allocation_info() const { return allocation_info_; } + + // Exposed so that VmaAllocator can reset access after initial mapping. + using Buffer::set_allowed_access; + + private: + Status FillImpl(device_size_t byte_offset, device_size_t byte_length, + const void* pattern, device_size_t pattern_length) override; + Status ReadDataImpl(device_size_t source_offset, void* data, + device_size_t data_length) override; + Status WriteDataImpl(device_size_t target_offset, const void* data, + device_size_t data_length) override; + Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, + device_size_t source_offset, + device_size_t data_length) override; + Status MapMemoryImpl(MappingMode mapping_mode, + MemoryAccessBitfield memory_access, + device_size_t local_byte_offset, + device_size_t local_byte_length, + void** out_data) override; + Status UnmapMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length, void* data) override; + Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override; + Status FlushMappedMemoryImpl(device_size_t local_byte_offset, + device_size_t local_byte_length) override; + + ::VmaAllocator vma_; + VkBuffer buffer_; + VmaAllocation allocation_; + VmaAllocationInfo allocation_info_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_VMA_BUFFER_H_
diff --git a/hal/vulkan/vulkan_device.cc b/hal/vulkan/vulkan_device.cc new file mode 100644 index 0000000..8fe5e60 --- /dev/null +++ b/hal/vulkan/vulkan_device.cc
@@ -0,0 +1,500 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/vulkan_device.h" + +#include <functional> +#include <utility> + +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/command_buffer_validation.h" +#include "hal/command_queue.h" +#include "hal/fence.h" +#include "hal/vulkan/direct_command_buffer.h" +#include "hal/vulkan/direct_command_queue.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/extensibility_util.h" +#include "hal/vulkan/legacy_fence.h" +#include "hal/vulkan/native_binary_semaphore.h" +#include "hal/vulkan/native_event.h" +#include "hal/vulkan/pipeline_cache.h" +#include "hal/vulkan/status_util.h" +#include "hal/vulkan/vma_allocator.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +constexpr uint32_t kInvalidQueueFamilyIndex = -1; + +struct QueueFamilyInfo { + uint32_t dispatch_index = kInvalidQueueFamilyIndex; + uint32_t dispatch_queue_count = 0; + uint32_t transfer_index = kInvalidQueueFamilyIndex; + uint32_t transfer_queue_count = 0; +}; + +// Finds the first queue in the listing (which is usually the driver-preferred) +// that has all of the |required_queue_flags| and none of the +// |excluded_queue_flags|. +// Returns kInvalidQueueFamilyIndex if no matching queue is found. +uint32_t FindFirstQueueFamilyWithFlags( + absl::Span<const VkQueueFamilyProperties> queue_family_properties, + uint32_t required_queue_flags, uint32_t excluded_queue_flags) { + for (int queue_family_index = 0; + queue_family_index < queue_family_properties.size(); + ++queue_family_index) { + const auto& properties = queue_family_properties[queue_family_index]; + if ((properties.queueFlags & required_queue_flags) == + required_queue_flags && + (properties.queueFlags & excluded_queue_flags) == 0) { + return queue_family_index; + } + } + return kInvalidQueueFamilyIndex; +} + +// Selects queue family indices for compute and transfer queues. +// Note that both queue families may be the same if there is only one family +// available. +StatusOr<QueueFamilyInfo> SelectQueueFamilies( + VkPhysicalDevice physical_device, const ref_ptr<DynamicSymbols>& syms) { + // Enumerate queue families available on the device. + uint32_t queue_family_count = 0; + syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device, + &queue_family_count, nullptr); + absl::InlinedVector<VkQueueFamilyProperties, 4> queue_family_properties( + queue_family_count); + syms->vkGetPhysicalDeviceQueueFamilyProperties( + physical_device, &queue_family_count, queue_family_properties.data()); + + QueueFamilyInfo queue_family_info; + + // Try to find a dedicated compute queue (no graphics caps). + // Some may support both transfer and compute. If that fails then fallback to + // any queue that supports compute. + queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( + queue_family_properties, VK_QUEUE_COMPUTE_BIT, VK_QUEUE_GRAPHICS_BIT); + if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { + queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( + queue_family_properties, VK_QUEUE_COMPUTE_BIT, 0); + } + if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { + return NotFoundErrorBuilder(IREE_LOC) + << "Unable to find any queue family support compute operations"; + } + queue_family_info.dispatch_queue_count = + queue_family_properties[queue_family_info.dispatch_index].queueCount; + + // Try to find a dedicated transfer queue (no compute or graphics caps). + // Not all devices have one, and some have only a queue family for everything + // and possibly a queue family just for compute/etc. If that fails then + // fallback to any queue that supports transfer. Finally, if /that/ fails then + // we just won't create a transfer queue and instead use the compute queue for + // all operations. + queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( + queue_family_properties, VK_QUEUE_TRANSFER_BIT, + VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT); + if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { + queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( + queue_family_properties, VK_QUEUE_TRANSFER_BIT, VK_QUEUE_GRAPHICS_BIT); + } + if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { + queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( + queue_family_properties, VK_QUEUE_TRANSFER_BIT, 0); + } + if (queue_family_info.transfer_index != kInvalidQueueFamilyIndex) { + queue_family_info.transfer_queue_count = + queue_family_properties[queue_family_info.transfer_index].queueCount; + } + + // Ensure that we don't share the dispatch queues with transfer queues if that + // would put us over the queue count. + if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { + queue_family_info.transfer_queue_count = std::min( + queue_family_properties[queue_family_info.dispatch_index].queueCount - + queue_family_info.dispatch_queue_count, + queue_family_info.transfer_queue_count); + } + + return queue_family_info; +} + +// Creates a transient command pool for the given queue family. +// Command buffers allocated from the pool must only be issued on queues +// belonging to the specified family. +StatusOr<ref_ptr<VkCommandPoolHandle>> CreateTransientCommandPool( + const ref_ptr<VkDeviceHandle>& logical_device, + uint32_t queue_family_index) { + VkCommandPoolCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT; + create_info.queueFamilyIndex = queue_family_index; + + auto command_pool = make_ref<VkCommandPoolHandle>(logical_device); + VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateCommandPool( + *logical_device, &create_info, logical_device->allocator(), + command_pool->mutable_value())); + return command_pool; +} + +} // namespace + +// static +StatusOr<std::shared_ptr<VulkanDevice>> VulkanDevice::Create( + const DeviceInfo& device_info, VkPhysicalDevice physical_device, + const ExtensibilitySpec& extensibility_spec, + const ref_ptr<DynamicSymbols>& syms) { + IREE_TRACE_SCOPE0("VulkanDevice::Create"); + + // Find the layers and extensions we need (or want) that are also available + // on the device. This will fail when required ones are not present. + ASSIGN_OR_RETURN( + auto enabled_layer_names, + MatchAvailableDeviceLayers(physical_device, extensibility_spec, *syms)); + ASSIGN_OR_RETURN(auto enabled_extension_names, + MatchAvailableDeviceExtensions(physical_device, + extensibility_spec, *syms)); + auto enabled_device_extensions = + PopulateEnabledDeviceExtensions(enabled_extension_names); + + // Find queue families we will expose as HAL queues. + ASSIGN_OR_RETURN(auto queue_family_info, + SelectQueueFamilies(physical_device, syms)); + + // Limit the number of queues we create (for now). + // We may want to allow this to grow, but each queue adds overhead and we need + // to measure to make sure we can effectively use them all. + queue_family_info.dispatch_queue_count = + std::min(2u, queue_family_info.dispatch_queue_count); + queue_family_info.transfer_queue_count = + std::min(1u, queue_family_info.transfer_queue_count); + bool has_dedicated_transfer_queues = + queue_family_info.transfer_queue_count > 0; + + // Setup the queue info we'll be using. + // Each queue here (created from within a family) will map to a HAL queue. + // + // Note that we need to handle the case where we have transfer queues that are + // of the same queue family as the dispatch queues: Vulkan requires that all + // queues created from the same family are done in the same + // VkDeviceQueueCreateInfo struct. + DVLOG(1) << "Creating " << queue_family_info.dispatch_queue_count + << " dispatch queue(s) in queue family " + << queue_family_info.dispatch_index; + absl::InlinedVector<VkDeviceQueueCreateInfo, 2> queue_create_info; + absl::InlinedVector<float, 4> dispatch_queue_priorities; + absl::InlinedVector<float, 4> transfer_queue_priorities; + queue_create_info.push_back({}); + auto& dispatch_queue_info = queue_create_info.back(); + dispatch_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + dispatch_queue_info.pNext = nullptr; + dispatch_queue_info.flags = 0; + dispatch_queue_info.queueFamilyIndex = queue_family_info.dispatch_index; + dispatch_queue_info.queueCount = queue_family_info.dispatch_queue_count; + if (has_dedicated_transfer_queues) { + if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { + DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count + << " dedicated transfer queue(s) in shared queue family " + << queue_family_info.transfer_index; + dispatch_queue_info.queueCount += queue_family_info.transfer_queue_count; + } else { + DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count + << " dedicated transfer queue(s) in independent queue family " + << queue_family_info.transfer_index; + queue_create_info.push_back({}); + auto& transfer_queue_info = queue_create_info.back(); + transfer_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + transfer_queue_info.pNext = nullptr; + transfer_queue_info.queueFamilyIndex = queue_family_info.transfer_index; + transfer_queue_info.queueCount = queue_family_info.transfer_queue_count; + transfer_queue_info.flags = 0; + transfer_queue_priorities.resize(transfer_queue_info.queueCount); + transfer_queue_info.pQueuePriorities = transfer_queue_priorities.data(); + } + } + dispatch_queue_priorities.resize(dispatch_queue_info.queueCount); + dispatch_queue_info.pQueuePriorities = dispatch_queue_priorities.data(); + + // TODO(benvanik): specify features with VkPhysicalDeviceFeatures. + + // Create device and its queues. + VkDeviceCreateInfo device_create_info = {}; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pNext = nullptr; + device_create_info.enabledLayerCount = enabled_layer_names.size(); + device_create_info.ppEnabledLayerNames = enabled_layer_names.data(); + device_create_info.enabledExtensionCount = enabled_extension_names.size(); + device_create_info.ppEnabledExtensionNames = enabled_extension_names.data(); + device_create_info.queueCreateInfoCount = queue_create_info.size(); + device_create_info.pQueueCreateInfos = queue_create_info.data(); + device_create_info.pEnabledFeatures = nullptr; + auto logical_device = make_ref<VkDeviceHandle>( + syms, enabled_device_extensions, /*allocator=*/nullptr); + VK_RETURN_IF_ERROR(syms->vkCreateDevice(physical_device, &device_create_info, + logical_device->allocator(), + logical_device->mutable_value())); + + // Create the device memory allocator. + // TODO(benvanik): allow other types to be plugged in. + ASSIGN_OR_RETURN(auto allocator, + VmaAllocator::Create(physical_device, logical_device)); + + // Create command pools for each queue family. If we don't have a transfer + // queue then we'll ignore that one and just use the dispatch pool. + // If we wanted to expose the pools through the HAL to allow the VM to more + // effectively manage them (pool per fiber, etc) we could, however I doubt the + // overhead of locking the pool will be even a blip. + ASSIGN_OR_RETURN(auto dispatch_command_pool, + CreateTransientCommandPool( + logical_device, queue_family_info.dispatch_index)); + ref_ptr<VkCommandPoolHandle> transfer_command_pool; + if (has_dedicated_transfer_queues) { + ASSIGN_OR_RETURN(transfer_command_pool, + CreateTransientCommandPool( + logical_device, queue_family_info.transfer_index)); + } + + // Get the queues and create the HAL wrappers. + absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues; + for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) { + VkQueue queue = VK_NULL_HANDLE; + syms->vkGetDeviceQueue(*logical_device, queue_family_info.dispatch_index, i, + &queue); + std::string queue_name = absl::StrCat(device_info.name(), ":d", i); + command_queues.push_back(absl::make_unique<DirectCommandQueue>( + std::move(queue_name), + CommandCategory::kDispatch | CommandCategory::kTransfer, logical_device, + queue)); + } + if (has_dedicated_transfer_queues) { + uint32_t base_queue_index = 0; + if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { + // Sharing a family, so transfer queues follow compute queues. + base_queue_index = queue_family_info.dispatch_index; + } + for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) { + VkQueue queue = VK_NULL_HANDLE; + syms->vkGetDeviceQueue(*logical_device, queue_family_info.transfer_index, + base_queue_index + i, &queue); + std::string queue_name = absl::StrCat(device_info.name(), ":t", i); + command_queues.push_back(absl::make_unique<DirectCommandQueue>( + std::move(queue_name), CommandCategory::kTransfer, logical_device, + queue)); + } + } + + // TODO(b/140141417): implement timeline semaphore fences and switch here. + ASSIGN_OR_RETURN(auto legacy_fence_pool, + LegacyFencePool::Create(add_ref(logical_device))); + + return std::make_shared<VulkanDevice>( + CtorKey{}, device_info, physical_device, std::move(logical_device), + std::move(allocator), std::move(command_queues), + std::move(dispatch_command_pool), std::move(transfer_command_pool), + std::move(legacy_fence_pool)); +} + +VulkanDevice::VulkanDevice( + CtorKey ctor_key, const DeviceInfo& device_info, + VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device, + std::unique_ptr<Allocator> allocator, + absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues, + ref_ptr<VkCommandPoolHandle> dispatch_command_pool, + ref_ptr<VkCommandPoolHandle> transfer_command_pool, + ref_ptr<LegacyFencePool> legacy_fence_pool) + : Device(device_info), + physical_device_(physical_device), + logical_device_(std::move(logical_device)), + allocator_(std::move(allocator)), + command_queues_(std::move(command_queues)), + descriptor_pool_cache_( + make_ref<DescriptorPoolCache>(add_ref(logical_device_))), + dispatch_command_pool_(std::move(dispatch_command_pool)), + transfer_command_pool_(std::move(transfer_command_pool)), + legacy_fence_pool_(std::move(legacy_fence_pool)) { + // Populate the queue lists based on queue capabilities. + for (auto& command_queue : command_queues_) { + if (command_queue->can_dispatch()) { + dispatch_queues_.push_back(command_queue.get()); + if (transfer_command_pool_ == VK_NULL_HANDLE) { + transfer_queues_.push_back(command_queue.get()); + } + } else { + transfer_queues_.push_back(command_queue.get()); + } + } +} + +VulkanDevice::~VulkanDevice() { + IREE_TRACE_SCOPE0("VulkanDevice::dtor"); + + // Drop all command queues. These may wait until idle. + command_queues_.clear(); + dispatch_queues_.clear(); + transfer_queues_.clear(); + + // Drop command pools now that we know there are no more outstanding command + // buffers. + dispatch_command_pool_.reset(); + transfer_command_pool_.reset(); + + // Now that no commands are outstanding we can release all descriptor sets. + descriptor_pool_cache_.reset(); + + // Finally, destroy the device. + logical_device_.reset(); +} + +std::shared_ptr<ExecutableCache> VulkanDevice::CreateExecutableCache() { + IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableCache"); + return std::make_shared<PipelineCache>(logical_device_); +} + +StatusOr<ref_ptr<CommandBuffer>> VulkanDevice::CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) { + IREE_TRACE_SCOPE0("VulkanDevice::CreateCommandBuffer"); + + // Select the command pool to used based on the types of commands used. + // Note that we may not have a dedicated transfer command pool if there are no + // dedicated transfer queues. + ref_ptr<VkCommandPoolHandle> command_pool; + if (transfer_command_pool_ && + !AllBitsSet(command_categories, CommandCategory::kDispatch)) { + command_pool = add_ref(transfer_command_pool_); + } else { + command_pool = add_ref(dispatch_command_pool_); + } + + VkCommandBufferAllocateInfo allocate_info; + allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + allocate_info.pNext = nullptr; + allocate_info.commandPool = *command_pool; + allocate_info.commandBufferCount = 1; + allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + + VkCommandBuffer command_buffer = VK_NULL_HANDLE; + { + absl::MutexLock lock(command_pool->mutex()); + VK_RETURN_IF_ERROR(syms()->vkAllocateCommandBuffers( + *logical_device_, &allocate_info, &command_buffer)); + } + + // TODO(b/140026716): conditionally enable validation. + auto impl = make_ref<DirectCommandBuffer>( + allocator(), mode, command_categories, add_ref(descriptor_pool_cache_), + add_ref(command_pool), command_buffer); + return WrapCommandBufferWithValidation(std::move(impl)); +} + +StatusOr<ref_ptr<Event>> VulkanDevice::CreateEvent() { + IREE_TRACE_SCOPE0("VulkanDevice::CreateEvent"); + + // TODO(b/138729892): pool events. + VkEventCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + VkEvent event_handle = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(syms()->vkCreateEvent(*logical_device_, &create_info, + logical_device_->allocator(), + &event_handle)); + + return make_ref<NativeEvent>(add_ref(logical_device_), event_handle); +} + +StatusOr<ref_ptr<BinarySemaphore>> VulkanDevice::CreateBinarySemaphore( + bool initial_value) { + IREE_TRACE_SCOPE0("VulkanDevice::CreateBinarySemaphore"); + + VkSemaphoreCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = initial_value ? VK_FENCE_CREATE_SIGNALED_BIT : 0; + VkSemaphore semaphore_handle = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info, + logical_device_->allocator(), + &semaphore_handle)); + + return make_ref<NativeBinarySemaphore>(add_ref(logical_device_), + semaphore_handle); +} + +StatusOr<ref_ptr<TimelineSemaphore>> VulkanDevice::CreateTimelineSemaphore( + uint64_t initial_value) { + IREE_TRACE_SCOPE0("VulkanDevice::CreateTimelineSemaphore"); + + // TODO(b/140141417): implement timeline semaphores. + return UnimplementedErrorBuilder(IREE_LOC) + << "Timeline semaphores not yet implemented"; +} + +StatusOr<ref_ptr<Fence>> VulkanDevice::CreateFence(uint64_t initial_value) { + IREE_TRACE_SCOPE0("VulkanDevice::CreateFence"); + + // TODO(b/140141417): implement timeline semaphore fences and switch here. + // NOTE: we'll want some magic factory so that we can cleanly compile out the + // legacy implementation and pool. + + return make_ref<LegacyFence>(add_ref(legacy_fence_pool_), initial_value); +} + +Status VulkanDevice::WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) { + IREE_TRACE_SCOPE0("VulkanDevice::WaitAllFences"); + + // TODO(b/140141417): implement timeline semaphore fences and switch here. + + return LegacyFence::WaitForFences(logical_device_.get(), fences, + /*wait_all=*/true, deadline); +} + +StatusOr<int> VulkanDevice::WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Time deadline) { + IREE_TRACE_SCOPE0("VulkanDevice::WaitAnyFence"); + + // TODO(b/140141417): implement timeline semaphore fences and switch here. + + return LegacyFence::WaitForFences(logical_device_.get(), fences, + /*wait_all=*/false, deadline); +} + +Status VulkanDevice::WaitIdle(absl::Time deadline) { + if (deadline == absl::InfiniteFuture()) { + // Fast path for using vkDeviceWaitIdle, which is usually cheaper (as it + // requires fewer calls into the driver). + IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#vkDeviceWaitIdle"); + VK_RETURN_IF_ERROR(syms()->vkDeviceWaitIdle(*logical_device_)); + return OkStatus(); + } + + IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#Fences"); + for (auto& command_queue : command_queues_) { + RETURN_IF_ERROR(command_queue->WaitIdle(deadline)); + } + return OkStatus(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/vulkan_device.h b/hal/vulkan/vulkan_device.h new file mode 100644 index 0000000..b208bb2 --- /dev/null +++ b/hal/vulkan/vulkan_device.h
@@ -0,0 +1,120 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_ +#define IREE_HAL_VULKAN_VULKAN_DEVICE_H_ + +#include <vulkan/vulkan.h> + +#include <functional> +#include <memory> + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "base/memory.h" +#include "hal/allocator.h" +#include "hal/device.h" +#include "hal/vulkan/descriptor_pool_cache.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/extensibility_util.h" +#include "hal/vulkan/handle_util.h" +#include "hal/vulkan/legacy_fence.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class VulkanDevice final : public Device { + public: + static StatusOr<std::shared_ptr<VulkanDevice>> Create( + const DeviceInfo& device_info, VkPhysicalDevice physical_device, + const ExtensibilitySpec& extensibility_spec, + const ref_ptr<DynamicSymbols>& syms); + + // Private constructor. + struct CtorKey { + private: + friend class VulkanDevice; + CtorKey() = default; + }; + VulkanDevice( + CtorKey ctor_key, const DeviceInfo& device_info, + VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device, + std::unique_ptr<Allocator> allocator, + absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues, + ref_ptr<VkCommandPoolHandle> dispatch_command_pool, + ref_ptr<VkCommandPoolHandle> transfer_command_pool, + ref_ptr<LegacyFencePool> legacy_fence_pool); + ~VulkanDevice() override; + + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + Allocator* allocator() const override { return allocator_.get(); } + + absl::Span<CommandQueue*> dispatch_queues() const override { + return absl::MakeSpan(dispatch_queues_); + } + + absl::Span<CommandQueue*> transfer_queues() const override { + return absl::MakeSpan(transfer_queues_); + } + + std::shared_ptr<ExecutableCache> CreateExecutableCache() override; + + StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( + CommandBufferModeBitfield mode, + CommandCategoryBitfield command_categories) override; + + StatusOr<ref_ptr<Event>> CreateEvent() override; + + StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( + bool initial_value) override; + StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( + uint64_t initial_value) override; + + StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override; + Status WaitAllFences(absl::Span<const FenceValue> fences, + absl::Time deadline) override; + StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, + absl::Time deadline) override; + + Status WaitIdle(absl::Time deadline) override; + + private: + VkPhysicalDevice physical_device_; + ref_ptr<VkDeviceHandle> logical_device_; + + std::unique_ptr<Allocator> allocator_; + + mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues_; + mutable absl::InlinedVector<CommandQueue*, 4> dispatch_queues_; + mutable absl::InlinedVector<CommandQueue*, 4> transfer_queues_; + + ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; + + ref_ptr<VkCommandPoolHandle> dispatch_command_pool_; + ref_ptr<VkCommandPoolHandle> transfer_command_pool_; + + // TODO(b/140141417): implement timeline semaphore fences and conditionally + // compile the legacy fence pool out. + ref_ptr<LegacyFencePool> legacy_fence_pool_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_VULKAN_DEVICE_H_
diff --git a/hal/vulkan/vulkan_driver.cc b/hal/vulkan/vulkan_driver.cc new file mode 100644 index 0000000..f4d5e34 --- /dev/null +++ b/hal/vulkan/vulkan_driver.cc
@@ -0,0 +1,229 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "hal/vulkan/vulkan_driver.h" + +#include <memory> + +#include "absl/container/inlined_vector.h" +#include "base/memory.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/device_info.h" +#include "hal/vulkan/extensibility_util.h" +#include "hal/vulkan/status_util.h" +#include "hal/vulkan/vulkan_device.h" + +namespace iree { +namespace hal { +namespace vulkan { + +namespace { + +// Returns a VkApplicationInfo struct populated with the default app info. +// We may allow hosting applications to override this via weak-linkage if it's +// useful, otherwise this is enough to create the application. +VkApplicationInfo GetDefaultApplicationInfo() { + VkApplicationInfo info; + info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + info.pNext = nullptr; + info.pApplicationName = "IREE-ML"; + info.applicationVersion = 0; + info.pEngineName = "IREE"; + info.engineVersion = 0; + info.apiVersion = VK_API_VERSION_1_0; + return info; +} + +// Populates device information from the given Vulkan physical device handle. +StatusOr<DeviceInfo> PopulateDeviceInfo(VkPhysicalDevice physical_device, + const ref_ptr<DynamicSymbols>& syms) { + VkPhysicalDeviceFeatures physical_device_features; + syms->vkGetPhysicalDeviceFeatures(physical_device, &physical_device_features); + // TODO(benvanik): check and optionally require these features: + // - physical_device_features.robustBufferAccess + // - physical_device_features.shaderInt16 + // - physical_device_features.shaderInt64 + // - physical_device_features.shaderFloat64 + + VkPhysicalDeviceProperties physical_device_properties; + syms->vkGetPhysicalDeviceProperties(physical_device, + &physical_device_properties); + // TODO(benvanik): check and optionally require reasonable limits. + + // TODO(benvanik): more clever/sanitized device naming. + std::string name = std::string(physical_device_properties.deviceName); + + DeviceFeatureBitfield supported_features = DeviceFeature::kNone; + // TODO(benvanik): implement debugging/profiling features. + // TODO(benvanik): use props to determine if we have timing info. + // supported_features |= DeviceFeature::kDebugging; + // supported_features |= DeviceFeature::kCoverage; + // supported_features |= DeviceFeature::kProfiling; + return DeviceInfo(std::move(name), supported_features, physical_device); +} + +} // namespace + +// static +StatusOr<std::shared_ptr<VulkanDriver>> VulkanDriver::Create( + Options options, ref_ptr<DynamicSymbols> syms) { + IREE_TRACE_SCOPE0("VulkanDriver::Create"); + + // Find the layers and extensions we need (or want) that are also available + // on the instance. This will fail when required ones are not present. + ASSIGN_OR_RETURN( + auto enabled_layer_names, + MatchAvailableInstanceLayers(options.instance_extensibility, *syms)); + ASSIGN_OR_RETURN( + auto enabled_extension_names, + MatchAvailableInstanceExtensions(options.instance_extensibility, *syms)); + auto instance_extensions = + PopulateEnabledInstanceExtensions(enabled_extension_names); + + // Create the instance this driver will use for all requests. + VkApplicationInfo app_info = GetDefaultApplicationInfo(); + app_info.apiVersion = options.api_version; + VkInstanceCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + create_info.pApplicationInfo = &app_info; + create_info.enabledLayerCount = enabled_layer_names.size(); + create_info.ppEnabledLayerNames = enabled_layer_names.data(); + create_info.enabledExtensionCount = enabled_extension_names.size(); + create_info.ppEnabledExtensionNames = enabled_extension_names.data(); + + // If we have the debug_utils extension then we can chain a one-shot messenger + // callback that we can use to log out the instance creation errors. Once we + // have the real instance we can then register a real messenger. + union { + VkDebugUtilsMessengerCreateInfoEXT debug_utils_create_info; + VkDebugReportCallbackCreateInfoEXT debug_report_create_info; + }; + if (instance_extensions.debug_utils) { + create_info.pNext = &debug_utils_create_info; + DebugReporter::PopulateStaticCreateInfo(&debug_utils_create_info); + } else if (instance_extensions.debug_report) { + create_info.pNext = &debug_report_create_info; + DebugReporter::PopulateStaticCreateInfo(&debug_report_create_info); + } + + // Some ICDs appear to leak in here, out of our control. + // Warning: leak checks remain disabled if an error is returned. + IREE_DISABLE_LEAK_CHECKS(); + VkInstance instance = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR( + syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)) + << "Unable to create Vulkan instance"; + IREE_ENABLE_LEAK_CHECKS(); + + // TODO(benvanik): enable validation layers if needed. + + // Now that the instance has been created we can fetch all of the instance + // symbols. + RETURN_IF_ERROR(syms->LoadFromInstance(instance)); + + // The real debug messenger (not just the static one used above) can now be + // created as we've loaded all the required symbols. + // TODO(benvanik): strip in release builds. + std::unique_ptr<DebugReporter> debug_reporter; + if (instance_extensions.debug_utils) { + ASSIGN_OR_RETURN(debug_reporter, DebugReporter::CreateDebugUtilsMessenger( + instance, syms, + /*allocation_callbacks=*/nullptr)); + } else if (instance_extensions.debug_report) { + ASSIGN_OR_RETURN(debug_reporter, + DebugReporter::CreateDebugReportCallback( + instance, syms, /*allocation_callbacks=*/nullptr)); + } + + return std::make_shared<VulkanDriver>( + CtorKey{}, std::move(syms), instance, std::move(debug_reporter), + std::move(options.device_extensibility)); +} + +VulkanDriver::VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms, + VkInstance instance, + std::unique_ptr<DebugReporter> debug_reporter, + ExtensibilitySpec device_extensibility_spec) + : Driver("vulkan"), + syms_(std::move(syms)), + instance_(instance), + debug_reporter_(std::move(debug_reporter)), + device_extensibility_spec_(std::move(device_extensibility_spec)) {} + +VulkanDriver::~VulkanDriver() { + IREE_TRACE_SCOPE0("VulkanDriver::dtor"); + debug_reporter_.reset(); + syms()->vkDestroyInstance(instance_, /*pAllocator=*/nullptr); +} + +StatusOr<std::vector<DeviceInfo>> VulkanDriver::EnumerateAvailableDevices() { + IREE_TRACE_SCOPE0("VulkanDriver::EnumerateAvailableDevices"); + + // Query all available devices (at this moment, note that this may change!). + uint32_t physical_device_count = 0; + VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( + instance_, &physical_device_count, nullptr)); + absl::InlinedVector<VkPhysicalDevice, 2> physical_devices( + physical_device_count); + VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( + instance_, &physical_device_count, physical_devices.data())); + + // Convert to our HAL structure. + std::vector<DeviceInfo> device_infos; + device_infos.reserve(physical_device_count); + for (auto physical_device : physical_devices) { + // TODO(benvanik): if we fail should we just ignore the device in the list? + ASSIGN_OR_RETURN(auto device_info, + PopulateDeviceInfo(physical_device, syms())); + device_infos.push_back(std::move(device_info)); + } + return device_infos; +} + +StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDefaultDevice() { + IREE_TRACE_SCOPE0("VulkanDriver::CreateDefaultDevice"); + + // Query available devices. + ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices()); + if (available_devices.empty()) { + return NotFoundErrorBuilder(IREE_LOC) << "No devices are available"; + } + + // Just create the first one we find. + return CreateDevice(available_devices.front()); +} + +StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDevice( + const DeviceInfo& device_info) { + IREE_TRACE_SCOPE0("VulkanDriver::CreateDevice"); + + auto physical_device = + static_cast<VkPhysicalDevice>(device_info.driver_handle()); + + // Attempt to create the device. + // This may fail if the device was enumerated but is in exclusive use, + // disabled by the system, or permission is denied. + ASSIGN_OR_RETURN(auto device, + VulkanDevice::Create(device_info, physical_device, + device_extensibility_spec_, syms())); + + return device; +} + +} // namespace vulkan +} // namespace hal +} // namespace iree
diff --git a/hal/vulkan/vulkan_driver.h b/hal/vulkan/vulkan_driver.h new file mode 100644 index 0000000..34ec759 --- /dev/null +++ b/hal/vulkan/vulkan_driver.h
@@ -0,0 +1,84 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_ +#define IREE_HAL_VULKAN_VULKAN_DRIVER_H_ + +#include <vulkan/vulkan.h> + +#include <memory> +#include <vector> + +#include "hal/driver.h" +#include "hal/vulkan/debug_reporter.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/extensibility_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class VulkanDriver final : public Driver { + public: + struct Options { + // Vulkan version that will be requested. + // Driver creation will fail if the required version is not available. + uint32_t api_version = VK_API_VERSION_1_0; + + // Extensibility descriptions for instances and devices. + // Device descriptions will be used for all devices created by the driver. + ExtensibilitySpec instance_extensibility; + ExtensibilitySpec device_extensibility; + }; + + static StatusOr<std::shared_ptr<VulkanDriver>> Create( + Options options, ref_ptr<DynamicSymbols> syms); + + // TODO(benvanik): method to wrap an existing instance/device (interop). + + // Private constructor. + struct CtorKey { + private: + friend class VulkanDriver; + CtorKey() = default; + }; + VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms, + VkInstance instance, + std::unique_ptr<DebugReporter> debug_reporter, + ExtensibilitySpec device_extensibility_spec); + ~VulkanDriver() override; + + const ref_ptr<DynamicSymbols>& syms() const { return syms_; } + + VkInstance instance() const { return instance_; } + + StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; + + StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override; + + StatusOr<std::shared_ptr<Device>> CreateDevice( + const DeviceInfo& device_info) override; + + private: + ref_ptr<DynamicSymbols> syms_; + VkInstance instance_; + std::unique_ptr<DebugReporter> debug_reporter_; + ExtensibilitySpec device_extensibility_spec_; +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_VULKAN_DRIVER_H_
diff --git a/hal/vulkan/vulkan_driver_module.cc b/hal/vulkan/vulkan_driver_module.cc new file mode 100644 index 0000000..51da626 --- /dev/null +++ b/hal/vulkan/vulkan_driver_module.cc
@@ -0,0 +1,102 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <memory> + +#include "absl/flags/flag.h" +#include "base/init.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/driver_registry.h" +#include "hal/vulkan/dynamic_symbols.h" +#include "hal/vulkan/vulkan_driver.h" + +ABSL_FLAG(bool, vulkan_validation_layers, true, + "Enables standard Vulkan validation layers."); +ABSL_FLAG(bool, vulkan_debug_utils, true, + "Enables VK_EXT_debug_utils, records markers, and logs errors."); +ABSL_FLAG(bool, vulkan_debug_report, false, + "Enables VK_EXT_debug_report and logs errors."); +ABSL_FLAG(bool, vulkan_push_descriptors, true, + "Enables use of vkCmdPushDescriptorSetKHR, if available."); + +namespace iree { +namespace hal { +namespace vulkan { +namespace { + +StatusOr<std::shared_ptr<Driver>> CreateVulkanDriver() { + IREE_TRACE_SCOPE0("CreateVulkanDriver"); + + // Load the Vulkan library. This will fail if the library cannot be found or + // does not have the expected functions. + ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader()); + + // Setup driver options from flags. We do this here as we want to enable other + // consumers that may not be using modules/command line flags to be able to + // set their options however they want. + VulkanDriver::Options options; + + // TODO: validation layers have bugs when using VK_EXT_debug_report, so if the + // user requested that we force them off with a warning. Prefer using + // VK_EXT_debug_utils when available. + if (absl::GetFlag(FLAGS_vulkan_debug_report) && + absl::GetFlag(FLAGS_vulkan_validation_layers)) { + LOG(WARNING) << "VK_EXT_debug_report has issues with modern validation " + "layers; disabling validation"; + absl::SetFlag(&FLAGS_vulkan_validation_layers, false); + } + + // REQUIRED: these are required extensions that must be present for IREE to + // work (such as those relied upon by SPIR-V kernels, etc). + options.device_extensibility.required_extensions.push_back( + VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); + + if (absl::GetFlag(FLAGS_vulkan_validation_layers)) { + options.instance_extensibility.optional_layers.push_back( + "VK_LAYER_LUNARG_standard_validation"); + } + + if (absl::GetFlag(FLAGS_vulkan_debug_report)) { + options.instance_extensibility.optional_extensions.push_back( + VK_EXT_DEBUG_REPORT_EXTENSION_NAME); + } + if (absl::GetFlag(FLAGS_vulkan_debug_utils)) { + options.instance_extensibility.optional_extensions.push_back( + VK_EXT_DEBUG_UTILS_EXTENSION_NAME); + } + + if (absl::GetFlag(FLAGS_vulkan_push_descriptors)) { + options.instance_extensibility.optional_extensions.push_back( + VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); + options.device_extensibility.optional_extensions.push_back( + VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); + } + + // Create the driver and VkInstance. + ASSIGN_OR_RETURN(auto driver, VulkanDriver::Create(options, std::move(syms))); + + return driver; +} + +} // namespace +} // namespace vulkan +} // namespace hal +} // namespace iree + +IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vulkan_driver, { + QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( + "vulkan", ::iree::hal::vulkan::CreateVulkanDriver)); +}); +IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vulkan_driver);
diff --git a/iree/BUILD b/iree/BUILD deleted file mode 100644 index d21f74c..0000000 --- a/iree/BUILD +++ /dev/null
@@ -1,31 +0,0 @@ -# Main IREE build file. -# Note that project-wide, bazel repo aliases are used: -# "@com_google_absl//absl/python" -# "@com_google_absl//absl" -# "@com_google_benchmark//:benchmark" -# "@local_config_mlir//" -# "@llvm//" -# "@com_github_google_flatbuffers//:flatbuffers" -# "@org_tensorflow//tensorflow" -# -# Various scripts and helpers operate on these prefixes textually, so -# avoid doing any systematic construction that would break the matching. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -# Enables the debug service and other profiling features. -# $ bazel build --define=IREE_DEBUG=1 :some_target -config_setting( - name = "debug", - define_values = {"IREE_DEBUG": "1"}, -) - -# Marker library which can be extended to provide flags for things that -# need to know the platform target. -cc_library( - name = "target_config", - defines = ["IREE_UNSPECIFIED_TARGET=1"], -)
diff --git a/iree/base/BUILD b/iree/base/BUILD deleted file mode 100644 index 3b8249e..0000000 --- a/iree/base/BUILD +++ /dev/null
@@ -1,362 +0,0 @@ -# Common types and utilities used in the IREE codebase. - -load("//iree:build_defs.bzl", "platform_trampoline_deps") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "api", - srcs = ["api.cc"], - hdrs = ["api.h"], - visibility = ["//visibility:public"], - deps = [ - ":api_hdrs", - ":api_util", - ":file_mapping", - ":tracing", - ], -) - -cc_library( - name = "api_hdrs", - hdrs = ["api.h"], -) - -cc_library( - name = "api_util", - hdrs = ["api_util.h"], - deps = [ - ":api_hdrs", - ":logging", - ":shape", - ":status", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "arena", - srcs = ["arena.cc"], - hdrs = ["arena.h"], - deps = [ - ":logging", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "arena_test", - srcs = ["arena_test.cc"], - deps = [ - ":arena", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "bitfield", - hdrs = ["bitfield.h"], - deps = [ - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "bitfield_test", - srcs = ["bitfield_test.cc"], - deps = [ - ":bitfield", - "@com_google_absl//absl/base:core_headers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "file_io", - hdrs = ["file_io.h"], - deps = [ - ":status", - ":target_platform", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ] + platform_trampoline_deps("file_io"), -) - -cc_library( - name = "file_io_hdrs", - hdrs = ["file_io.h"], - deps = [":status"], -) - -cc_library( - name = "file_mapping", - hdrs = ["file_mapping.h"], - deps = [ - ":ref_ptr", - ":status", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ] + platform_trampoline_deps("file_mapping"), -) - -cc_library( - name = "file_mapping_hdrs", - hdrs = ["file_mapping.h"], - deps = [ - ":ref_ptr", - ":status", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "file_path", - srcs = ["file_path.cc"], - hdrs = ["file_path.h"], - deps = [ - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "flatbuffer_util", - srcs = ["flatbuffer_util.cc"], - hdrs = ["flatbuffer_util.h"], - deps = [ - ":file_mapping", - ":memory", - ":source_location", - ":status", - ":tracing", - "@com_github_google_flatbuffers//:flatbuffers", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "init", - hdrs = ["init.h"], - deps = platform_trampoline_deps("init"), -) - -cc_library( - name = "intrusive_list", - hdrs = [ - "intrusive_list.h", - "intrusive_list_ref_ptr.inc", - "intrusive_list_unique_ptr.inc", - ], - deps = [ - ":logging", - ":ref_ptr", - ], -) - -cc_test( - name = "intrusive_list_test", - srcs = [ - "intrusive_list_ref_ptr_test.cc", - "intrusive_list_test.cc", - "intrusive_list_unique_ptr_test.cc", - ], - deps = [ - ":intrusive_list", - "@com_google_absl//absl/memory", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "logging", - hdrs = ["logging.h"], - deps = platform_trampoline_deps("logging"), -) - -cc_library( - name = "math", - hdrs = ["math.h"], - deps = [ - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "memory", - hdrs = ["memory.h"], - deps = [ - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "ref_ptr", - hdrs = ["ref_ptr.h"], - deps = [ - ":logging", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_test( - name = "ref_ptr_test", - size = "small", - srcs = ["ref_ptr_test.cc"], - deps = [ - ":ref_ptr", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "shape", - srcs = ["shape.cc"], - hdrs = ["shape.h"], - deps = [ - ":logging", - ":source_location", - ":status", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "shape_test", - srcs = ["shape_test.cc"], - deps = [ - ":shape", - ":status", - ":status_matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "source_location", - hdrs = ["source_location.h"], - deps = platform_trampoline_deps("source_location"), -) - -cc_library( - name = "status", - hdrs = ["status.h"], - deps = [ - ":source_location", - ] + platform_trampoline_deps("status"), -) - -cc_library( - name = "status_matchers", - testonly = 1, - hdrs = ["status_matchers.h"], - deps = platform_trampoline_deps("status_matchers"), -) - -cc_library( - name = "target_platform", - hdrs = ["target_platform.h"], -) - -cc_library( - name = "time", - hdrs = ["time.h"], - deps = [ - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "tracing", - hdrs = ["tracing.h"], - deps = [ - "//iree:target_config", - "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp", - ] + select({ - "@com_google_tracing_framework_cpp//:wtf_enable": [":tracing_enabled"], - "//conditions:default": [":tracing_disabled"], - }), -) - -cc_library( - name = "tracing_disabled", - srcs = [ - "tracing.h", - "tracing_disabled.cc", - ], - visibility = ["//visibility:private"], - deps = [ - ":init", - ":logging", - "@com_google_absl//absl/flags:flag", - "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp", - ], - alwayslink = 1, -) - -cc_library( - name = "tracing_enabled", - srcs = [ - "tracing.cc", - "tracing.h", - ], - visibility = ["//visibility:private"], - deps = [ - ":file_io", - ":file_path", - ":init", - ":logging", - ":status", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_tracing_framework_cpp//:tracing_framework_bindings_cpp", - ], - alwayslink = 1, -) - -# Dependent code has been removed and wait_handle is currently incompatible -# with Windows, so excluding entirely. -# See google/iree/65 -# cc_library( -# name = "wait_handle", -# srcs = ["wait_handle.cc"], -# hdrs = ["wait_handle.h"], -# deps = [ -# ":logging", -# ":ref_ptr", -# ":source_location", -# ":status", -# ":time", -# "@com_google_absl//absl/base:core_headers", -# "@com_google_absl//absl/container:fixed_array", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/time", -# "@com_google_absl//absl/types:span", -# ], -# ) - -# cc_test( -# name = "wait_handle_test", -# srcs = ["wait_handle_test.cc"], -# deps = [ -# ":status", -# ":status_matchers", -# ":wait_handle", -# "@com_google_absl//absl/time", -# "@com_google_googletest//:gtest_main", -# ], -# )
diff --git a/iree/base/api.cc b/iree/base/api.cc deleted file mode 100644 index cff917f..0000000 --- a/iree/base/api.cc +++ /dev/null
@@ -1,124 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/api.h" - -#include <cstdlib> -#include <string> - -#include "iree/base/api_util.h" -#include "iree/base/file_mapping.h" -#include "iree/base/tracing.h" - -namespace iree { - -//===----------------------------------------------------------------------===// -// iree Core API -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_api_version_check(iree_api_version_t expected_version, - iree_api_version_t* out_actual_version) { - iree_api_version_t actual_version = IREE_API_VERSION_0; - *out_actual_version = actual_version; - return expected_version == actual_version ? IREE_STATUS_OK - : IREE_STATUS_OUT_OF_RANGE; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_allocator_alloc(void* self, iree_host_size_t byte_length, void** out_ptr) { - IREE_TRACE_SCOPE0("iree_allocator_alloc"); - - if (!out_ptr) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_ptr = nullptr; - - if (byte_length <= 0) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - *out_ptr = std::malloc(byte_length); - if (!*out_ptr) { - return IREE_STATUS_RESOURCE_EXHAUSTED; - } - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_allocator_free(void* self, - void* ptr) { - IREE_TRACE_SCOPE0("iree_allocator_free"); - if (ptr) { - std::free(ptr); - } - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::FileMapping -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_file_mapping_open_read(iree_string_view_t path, iree_allocator_t allocator, - iree_file_mapping_t** out_file_mapping) { - IREE_TRACE_SCOPE0("iree_file_mapping_open_read"); - - if (!out_file_mapping) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_file_mapping = nullptr; - - IREE_API_ASSIGN_OR_RETURN( - auto file_mapping, - FileMapping::OpenRead(std::string(path.data, path.size))); - - *out_file_mapping = - reinterpret_cast<iree_file_mapping_t*>(file_mapping.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_file_mapping_retain(iree_file_mapping_t* file_mapping) { - IREE_TRACE_SCOPE0("iree_file_mapping_retain"); - auto* handle = reinterpret_cast<FileMapping*>(file_mapping); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_file_mapping_release(iree_file_mapping_t* file_mapping) { - IREE_TRACE_SCOPE0("iree_file_mapping_release"); - auto* handle = reinterpret_cast<FileMapping*>(file_mapping); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_byte_span_t IREE_API_CALL -iree_file_mapping_data(iree_file_mapping_t* file_mapping) { - IREE_TRACE_SCOPE0("iree_file_mapping_data"); - auto* handle = reinterpret_cast<FileMapping*>(file_mapping); - CHECK(handle) << "NULL file_mapping handle"; - auto data = handle->data(); - return {const_cast<uint8_t*>(data.data()), data.size()}; -} - -} // namespace iree
diff --git a/iree/base/api_util.h b/iree/base/api_util.h deleted file mode 100644 index 5eaa5e5..0000000 --- a/iree/base/api_util.h +++ /dev/null
@@ -1,126 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_API_UTIL_H_ -#define IREE_BASE_API_UTIL_H_ - -#include "absl/base/macros.h" -#include "absl/time/time.h" -#include "iree/base/api.h" -#include "iree/base/logging.h" -#include "iree/base/shape.h" -#include "iree/base/status.h" - -namespace iree { - -inline iree_status_t ToApiStatus(Status status) { - DLOG(ERROR) << status; - return static_cast<iree_status_t>(status.code()); -} - -inline Status FromApiStatus(iree_status_t status_code, SourceLocation loc) { - return StatusBuilder(static_cast<StatusCode>(status_code), loc); -} - -// Internal helper for concatenating macro values. -#define IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y -#define IREE_API_STATUS_MACROS_IMPL_CONCAT_(x, y) \ - IREE_API_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) - -// clang-format off -#define IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT -// clang-format on - -namespace status_macro_internal { -class StatusAdaptorForApiMacros { - public: - StatusAdaptorForApiMacros(const Status& status) : status_(status) {} - StatusAdaptorForApiMacros(Status&& status) : status_(std::move(status)) {} - StatusAdaptorForApiMacros(const StatusAdaptorForApiMacros&) = delete; - StatusAdaptorForApiMacros& operator=(const StatusAdaptorForApiMacros&) = - delete; - explicit operator bool() const { return ABSL_PREDICT_TRUE(status_.ok()); } - Status&& Consume() { return std::move(status_); } - - private: - Status status_; -}; -} // namespace status_macro_internal - -#define IREE_API_RETURN_IF_ERROR(expr) \ - IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if (::iree::status_macro_internal::StatusAdaptorForApiMacros \ - status_adaptor = {expr}) { \ - } else /* NOLINT */ \ - return ::iree::ToApiStatus(status_adaptor.Consume()) - -#define IREE_API_RETURN_IF_API_ERROR(expr) \ - IREE_API_STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if (iree_status_t status = (expr)) { \ - return status; \ - } - -#define IREE_API_ASSIGN_OR_RETURN(...) \ - IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_( \ - (__VA_ARGS__, IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ - IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \ - (__VA_ARGS__) - -#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, \ - ...) \ - NAME -#define IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ - IREE_API_STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args - -#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ - IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_)) -#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, \ - error_expression) \ - IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ - IREE_API_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, \ - rexpr, error_expression) -#define IREE_API_STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ - error_expression) \ - auto statusor = (rexpr); \ - if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ - return ::iree::ToApiStatus(std::move(statusor).status()); \ - } \ - lhs = std::move(statusor).ValueOrDie() - -// Converts an iree_time_t to its equivalent absl::Time. -inline absl::Time ToAbslTime(iree_time_t time) { - if (time == IREE_TIME_INFINITE_PAST) { - return absl::InfinitePast(); - } else if (time == IREE_TIME_INFINITE_FUTURE) { - return absl::InfiniteFuture(); - } else { - return absl::FromUnixNanos(time); - } -} - -// Converts a Shape to an iree_shape_t. -inline iree_status_t ToApiShape(const Shape& shape, iree_shape_t* out_shape) { - out_shape->rank = shape.size(); - if (shape.size() > ABSL_ARRAYSIZE(out_shape->dims)) { - return IREE_STATUS_OUT_OF_RANGE; - } - for (int i = 0; i < out_shape->rank; ++i) { - out_shape->dims[i] = shape[i]; - } - return IREE_STATUS_OK; -} - -} // namespace iree - -#endif // IREE_BASE_API_UTIL_H_
diff --git a/iree/base/arena.cc b/iree/base/arena.cc deleted file mode 100644 index 9a6d7c2..0000000 --- a/iree/base/arena.cc +++ /dev/null
@@ -1,125 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/arena.h" - -#include <memory> - -#include "absl/base/attributes.h" -#include "iree/base/logging.h" - -namespace iree { - -namespace { - -// Rounds up to the next alignment value, if it is not already aligned. -template <typename T> -ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr T RoundToAlignment( - T value, T alignment) noexcept { - return ((value + alignment - 1) / alignment) * alignment; -} - -} // namespace - -Arena::Arena(size_t block_size) : block_size_(block_size) {} - -Arena::~Arena() { Clear(); } - -void Arena::Clear() { - // Deallocate all memory. - auto block_header = block_list_head_; - while (block_header) { - auto next_block = block_header->next_block; - std::free(block_header); - block_header = next_block; - } - block_list_head_ = nullptr; - block_header = unused_block_list_head_; - while (block_header) { - auto next_block = block_header->next_block; - std::free(block_header); - block_header = next_block; - } - unused_block_list_head_ = nullptr; - - bytes_allocated_ = 0; - block_bytes_allocated_ = 0; -} - -void Arena::Reset() { - // Move all blocks to the unused list and reset allocation count only. - auto block_header = block_list_head_; - while (block_header) { - auto next_block = block_header->next_block; - block_header->bytes_allocated = 0; - block_header->next_block = unused_block_list_head_; - unused_block_list_head_ = block_header; - block_header = next_block; - } - block_list_head_ = nullptr; - - bytes_allocated_ = 0; -} - -uint8_t* Arena::AllocateBytes(size_t length) { - if (!length) { - // Guarantee zero-length allocations return nullptr. - return nullptr; - } - - // Pad length allocated so we are machine word aligned. - // This ensures the next allocation starts at the right boundary. - size_t aligned_length = RoundToAlignment(length, sizeof(uintptr_t)); - - if (aligned_length > block_size_) { - // This allocation is larger than an entire block. That's bad. - // We could allocate this with malloc (and then keep track of those to free - // things), but for now let's just die. - CHECK(false); - return nullptr; - } - - if (!block_list_head_ || - block_list_head_->bytes_allocated + aligned_length > block_size_) { - // Check to see if we have an existing unused block we can use. - if (unused_block_list_head_) { - // Move block from unused list to main list. - auto block_header = unused_block_list_head_; - unused_block_list_head_ = block_header->next_block; - block_header->next_block = block_list_head_; - block_header->bytes_allocated = 0; - block_list_head_ = block_header; - } else { - // Allocate a new block. - auto block_ptr = reinterpret_cast<uint8_t*>( - std::malloc(sizeof(BlockHeader) + block_size_)); - auto block_header = reinterpret_cast<BlockHeader*>(block_ptr); - block_header->next_block = block_list_head_; - block_header->bytes_allocated = 0; - block_list_head_ = block_header; - block_bytes_allocated_ += sizeof(BlockHeader) + block_size_; - } - } - - BlockHeader* target_block = block_list_head_; - auto data_ptr = reinterpret_cast<uint8_t*>(target_block) + - sizeof(BlockHeader) + target_block->bytes_allocated; - target_block->bytes_allocated += aligned_length; - - bytes_allocated_ += length; - - return data_ptr; -} - -} // namespace iree
diff --git a/iree/base/arena_test.cc b/iree/base/arena_test.cc deleted file mode 100644 index fb0e801..0000000 --- a/iree/base/arena_test.cc +++ /dev/null
@@ -1,148 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/arena.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace iree { -namespace { - -// Tests basic block allocations. -TEST(ArenaTest, BasicAllocation) { - Arena arena(64); - EXPECT_EQ(64, arena.block_size()); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - // Zero byte allocations should return nullptr and not allocate bytes. - auto zero_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(0)); - EXPECT_EQ(0, zero_ptr); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - arena.Clear(); - - // Allocations must be machine word aligned. - auto one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1)); - EXPECT_NE(0, one_ptr); - EXPECT_EQ(0, one_ptr % sizeof(uintptr_t)); - one_ptr = reinterpret_cast<uintptr_t>(arena.AllocateBytes(1)); - EXPECT_NE(0, one_ptr); - EXPECT_EQ(0, one_ptr % sizeof(uintptr_t)); - EXPECT_EQ(2, arena.bytes_allocated()); - EXPECT_LT(2, arena.block_bytes_allocated()); - - arena.Clear(); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); -} - -// Tests typed allocations. -TEST(ArenaTest, TypedAllocations) { - Arena arena(64); - - EXPECT_NE(nullptr, arena.Allocate<int>()); - EXPECT_EQ(4, arena.bytes_allocated()); - EXPECT_EQ(64 + Arena::kBlockOverhead, arena.block_bytes_allocated()); - arena.Clear(); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - struct MyType { - MyType() {} - explicit MyType(int initial_value) : value(initial_value) {} - - int value = 5; - }; - auto my_type_ptr = arena.Allocate<MyType>(); - EXPECT_NE(nullptr, my_type_ptr); - EXPECT_EQ(sizeof(MyType), arena.bytes_allocated()); - EXPECT_EQ(5, my_type_ptr->value); // Default ctor must be called. - arena.Clear(); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - my_type_ptr = arena.Allocate<MyType>(10); - EXPECT_NE(nullptr, my_type_ptr); - EXPECT_EQ(sizeof(MyType), arena.bytes_allocated()); - EXPECT_EQ(10, my_type_ptr->value); // Ctor should have been called. - arena.Clear(); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); -} - -// Tests multiple blocks. -TEST(ArenaTest, MultipleBlocks) { - Arena arena(16); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - // Allocate one entire block. - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(16, arena.bytes_allocated()); - EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated()); - - // Allocate into the next block. - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(32, arena.bytes_allocated()); - EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); - - // Clear. - arena.Clear(); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - // Allocate again. - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(16, arena.bytes_allocated()); - EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated()); - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(32, arena.bytes_allocated()); - EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); -} - -// Tests fast reset. -TEST(ArenaTest, FastReset) { - Arena arena(16); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(0, arena.block_bytes_allocated()); - - // Allocate one entire block. - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(16, arena.bytes_allocated()); - EXPECT_EQ(16 + Arena::kBlockOverhead, arena.block_bytes_allocated()); - - // Allocate into the next block. - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(32, arena.bytes_allocated()); - EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); - - // Reset (without deallocating). - arena.Reset(); - EXPECT_EQ(0, arena.bytes_allocated()); - EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); - - // Allocate again. - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(16, arena.bytes_allocated()); - EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); - EXPECT_NE(nullptr, arena.AllocateBytes(16)); - EXPECT_EQ(32, arena.bytes_allocated()); - EXPECT_EQ(32 + 2 * Arena::kBlockOverhead, arena.block_bytes_allocated()); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/bitfield_test.cc b/iree/base/bitfield_test.cc deleted file mode 100644 index 0e545a6..0000000 --- a/iree/base/bitfield_test.cc +++ /dev/null
@@ -1,82 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/bitfield.h" - -#include <cstdint> -#include <vector> - -#include "gtest/gtest.h" - -namespace iree { - -// NOTE: define here so that we don't get internal linkage warnings. -enum class MyValue : uint32_t { - kNone = 0, - kA = 1 << 0, - kB = 1 << 1, - kAll = kA | kB, -}; -IREE_BITFIELD(MyValue); - -namespace { - -// Tests general usage. -TEST(BitfieldTest, FormatBitfieldValue) { - std::vector<std::pair<MyValue, const char *>> mappings = { - {MyValue::kA, "kA"}, - {MyValue::kB, "kB"}, - }; - EXPECT_EQ("", - FormatBitfieldValue(MyValue::kNone, absl::MakeConstSpan(mappings))); - EXPECT_EQ("kA", - FormatBitfieldValue(MyValue::kA, absl::MakeConstSpan(mappings))); - EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB, - absl::MakeConstSpan(mappings))); -} - -// Tests that empty mapping tables are fine. -TEST(BitfieldTest, FormatBitfieldValueEmpty) { - EXPECT_EQ("", FormatBitfieldValue(MyValue::kNone, {})); -} - -// Tests that values not found in the mappings are still displayed. -TEST(BitfieldTest, FormatBitfieldValueUnhandledValues) { - EXPECT_EQ("kA|2h", FormatBitfieldValue(MyValue::kA | MyValue::kB, - { - {MyValue::kA, "kA"}, - })); -} - -// Tests priority order in the mapping table. -TEST(BitfieldTest, FormatBitfieldValuePriority) { - // No priority, will do separate. - EXPECT_EQ("kA|kB", FormatBitfieldValue(MyValue::kA | MyValue::kB, - { - {MyValue::kA, "kA"}, - {MyValue::kB, "kB"}, - {MyValue::kAll, "kAll"}, - })); - - // Priority on the combined flag, use that instead. - EXPECT_EQ("kAll", FormatBitfieldValue(MyValue::kA | MyValue::kB, - { - {MyValue::kAll, "kAll"}, - {MyValue::kA, "kA"}, - {MyValue::kB, "kB"}, - })); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/file_io.h b/iree/base/file_io.h deleted file mode 100644 index 0a2003d..0000000 --- a/iree/base/file_io.h +++ /dev/null
@@ -1,48 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_FILE_IO_H_ -#define IREE_BASE_FILE_IO_H_ - -#include <string> - -#include "iree/base/status.h" - -namespace iree { -namespace file_io { - -// Checks if a file exists at the provided path. -// -// Returns an OK status if the file definitely exists. -// Errors can include PermissionDeniedError, NotFoundError, etc. -Status FileExists(const std::string& path); - -// Synchronously reads a file's contents into a string. -StatusOr<std::string> GetFileContents(const std::string& path); - -// Deletes the file at the provided path. -Status DeleteFile(const std::string& path); - -// Moves a file from 'source_path' to 'destination_path'. -// -// This may simply rename the file, but may fall back to a full copy and delete -// of the original if renaming is not possible (for example when moving between -// physical storage locations). -Status MoveFile(const std::string& source_path, - const std::string& destination_path); - -} // namespace file_io -} // namespace iree - -#endif // IREE_BASE_FILE_IO_H_
diff --git a/iree/base/file_mapping.h b/iree/base/file_mapping.h deleted file mode 100644 index e871370..0000000 --- a/iree/base/file_mapping.h +++ /dev/null
@@ -1,51 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_FILE_MAPPING_H_ -#define IREE_BASE_FILE_MAPPING_H_ - -#include <cstdint> -#include <string> - -#include "absl/types/span.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" - -namespace iree { - -// A memory-mapped file handle. -class FileMapping : public RefObject<FileMapping> { - public: - // Opens a file and maps it into the calling process memory. - // The file will be opened for shared read access. - static StatusOr<ref_ptr<FileMapping>> OpenRead(std::string path); - - virtual ~FileMapping() = default; - - // Read-only contents of the file. - inline absl::Span<const uint8_t> data() const noexcept { return data_; } - - protected: - explicit FileMapping(absl::Span<const uint8_t> data) : data_(data) {} - - absl::Span<const uint8_t> data_; - - private: - FileMapping(const FileMapping&) = delete; - FileMapping& operator=(const FileMapping&) = delete; -}; - -} // namespace iree - -#endif // IREE_BASE_FILE_MAPPING_H_
diff --git a/iree/base/file_path.cc b/iree/base/file_path.cc deleted file mode 100644 index 13230ce..0000000 --- a/iree/base/file_path.cc +++ /dev/null
@@ -1,83 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/file_path.h" - -#include "absl/strings/str_cat.h" - -namespace iree { -namespace file_path { - -namespace { - -std::pair<absl::string_view, absl::string_view> SplitPath( - absl::string_view path) { - size_t pos = path.find_last_of('/'); - // Handle the case with no '/' in 'path'. - if (pos == absl::string_view::npos) { - return std::make_pair(path.substr(0, 0), path); - } - // Handle the case with a single leading '/' in 'path'. - if (pos == 0) { - return std::make_pair(path.substr(0, 1), absl::ClippedSubstr(path, 1)); - } - return std::make_pair(path.substr(0, pos), - absl::ClippedSubstr(path, pos + 1)); -} - -// Return the parts of the basename of path, split on the final ".". -// If there is no "." in the basename or "." is the final character in the -// basename, the second value will be empty. -std::pair<absl::string_view, absl::string_view> SplitBasename( - absl::string_view path) { - path = Basename(path); - size_t pos = path.find_last_of('.'); - if (pos == absl::string_view::npos) - return std::make_pair(path, absl::ClippedSubstr(path, path.size(), 0)); - return std::make_pair(path.substr(0, pos), - absl::ClippedSubstr(path, pos + 1)); -} - -} // namespace - -std::string JoinPaths(absl::string_view path1, absl::string_view path2) { - if (path1.empty()) return std::string(path2); - if (path2.empty()) return std::string(path1); - if (path1.back() == '/') { - if (path2.front() == '/') - return absl::StrCat(path1, absl::ClippedSubstr(path2, 1)); - } else { - if (path2.front() != '/') return absl::StrCat(path1, "/", path2); - } - return absl::StrCat(path1, path2); -} - -absl::string_view DirectoryName(absl::string_view path) { - return SplitPath(path).first; -} - -absl::string_view Basename(absl::string_view path) { - return SplitPath(path).second; -} - -absl::string_view Stem(absl::string_view path) { - return SplitBasename(path).first; -} - -absl::string_view Extension(absl::string_view path) { - return SplitBasename(path).second; -} - -} // namespace file_path -} // namespace iree
diff --git a/iree/base/flatbuffer_util.cc b/iree/base/flatbuffer_util.cc deleted file mode 100644 index 7ce44f4..0000000 --- a/iree/base/flatbuffer_util.cc +++ /dev/null
@@ -1,145 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/flatbuffer_util.h" - -#include <cerrno> -#include <cstring> - -#include "absl/memory/memory.h" -#include "iree/base/file_mapping.h" -#include "iree/base/memory.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { - -FlatBufferFileBase::~FlatBufferFileBase() { - if (deleter_) { - deleter_(); - deleter_ = []() {}; - } -} - -Status FlatBufferFileBase::Create(const void* root_ptr, - std::function<void()> deleter) { - IREE_TRACE_SCOPE0("FlatBufferFileBase::Create"); - - root_ptr_ = root_ptr; - deleter_ = std::move(deleter); - - return OkStatus(); -} - -Status FlatBufferFileBase::CreateWithBackingBuffer( - const void* root_ptr, ::flatbuffers::DetachedBuffer backing_buffer) { - IREE_TRACE_SCOPE0("FlatBufferFileBase::Create"); - - root_ptr_ = root_ptr; - - // Pass along the buffer provided so we keep it alive until the - // FlatBufferFileBase is destructed. - auto backing_buffer_baton = IreeMoveToLambda(backing_buffer); - deleter_ = [backing_buffer_baton]() { (void)backing_buffer_baton.value; }; - - return OkStatus(); -} - -Status FlatBufferFileBase::Wrap(const void* root_ptr) { - IREE_TRACE_SCOPE0("FlatBufferFileBase::Wrap"); - return Create(root_ptr, []() {}); -} - -Status FlatBufferFileBase::FromBuffer(Identifier identifier, - absl::Span<const uint8_t> buffer_data, - std::function<void()> deleter, - size_t root_type_size, - VerifierFn verifier_fn) { - IREE_TRACE_SCOPE("FlatBufferFileBase::FromBuffer:size", int) - (static_cast<int>(buffer_data.size())); - - // Sanity check buffer for the minimum size as FlatBuffers doesn't. - if (buffer_data.size() < 16) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Provided serialized flatbuffer buffer is too small to be legit " - "at size=" - << buffer_data.size(); - } - - // Ensure the buffer has the BIPE magic bytes. - if (identifier.has_value() && !::flatbuffers::BufferHasIdentifier( - buffer_data.data(), identifier.value())) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Provided serialized buffer does not contain the expected type; " - "magic bytes mismatch (expected " - << identifier.value() << ")"; - } - - // Verify the FlatBuffer contains valid offsets and won't try to read out of - // bounds of the buffer. We inline a bit of VerifyBufferFromStart so this code - // can stay generic. - { - IREE_TRACE_SCOPE0("FlatBufferFileBase::FromBufferVerification"); - ::flatbuffers::Verifier verifier{buffer_data.data(), buffer_data.size()}; - if (!verifier_fn(identifier.value_or(nullptr), &verifier)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "FlatBuffer failed to verify as expected type; possibly " - "corrupt input"; - } - } - - // Resolve the root pointer in the buffer. - // This is GetMutableRoot such that we don't need to know T. - root_ptr_ = buffer_data.data() + - ::flatbuffers::EndianScalar( - *reinterpret_cast<const ::flatbuffers::uoffset_t*>( - buffer_data.data())); - if (!root_ptr_) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Unable to resolve root table"; - } - deleter_ = std::move(deleter); - - return OkStatus(); -} - -Status FlatBufferFileBase::WrapBuffer(Identifier identifier, - absl::Span<const uint8_t> buffer_data, - size_t root_type_size, - VerifierFn verifier_fn) { - IREE_TRACE_SCOPE0("FlatBufferFileBase::WrapBuffer"); - return FromBuffer( - identifier, buffer_data, []() {}, root_type_size, verifier_fn); -} - -Status FlatBufferFileBase::LoadFile(Identifier identifier, std::string path, - size_t root_type_size, - VerifierFn verifier_fn) { - IREE_TRACE_SCOPE0("FlatBufferFileBase::LoadFile"); - - ASSIGN_OR_RETURN(auto file_mapping, FileMapping::OpenRead(path)); - auto buffer_data = file_mapping->data(); - - auto handle_baton = IreeMoveToLambda(file_mapping); - return FromBuffer( - identifier, buffer_data, - [handle_baton]() { - // Keeping the mmap handle alive. - (void)handle_baton.value; - }, - root_type_size, verifier_fn); -} - -} // namespace iree
diff --git a/iree/base/flatbuffer_util.h b/iree/base/flatbuffer_util.h deleted file mode 100644 index 0e9344f..0000000 --- a/iree/base/flatbuffer_util.h +++ /dev/null
@@ -1,321 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_FLATBUFFER_UTIL_H_ -#define IREE_BASE_FLATBUFFER_UTIL_H_ - -#include <cstddef> -#include <cstdint> -#include <functional> -#include <memory> -#include <string> -#include <utility> -#include <vector> - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "flatbuffers/flatbuffers.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" - -namespace iree { - -// Wraps a FlatBuffer String in an absl::string_view. -// Returns empty-string ("") for nullptr values. -inline absl::string_view WrapString(const ::flatbuffers::String* value) { - return value ? absl::string_view{value->data(), value->size()} : ""; -} - -// Base type for FlatBufferFile<T>. See below. -class FlatBufferFileBase { - public: - using Identifier = absl::optional<const char*>; - - virtual ~FlatBufferFileBase(); - - protected: - template <typename T> - friend class FlatBufferFile; - - using VerifierFn = bool (*)(const char* identifier, - ::flatbuffers::Verifier* verifier); - - FlatBufferFileBase() = default; - - const void* root_ptr() const { return root_ptr_; } - - // Redirections of template static methods on FlatBufferFile so we can put the - // implementations in a shared compilation unit. - // See FlatBufferFile<T> for doc comments. - Status Create(const void* root_ptr, std::function<void()> deleter); - Status CreateWithBackingBuffer(const void* root_ptr, - ::flatbuffers::DetachedBuffer backing_buffer); - Status Wrap(const void* root); - Status FromBuffer(Identifier identifier, - absl::Span<const uint8_t> buffer_data, - std::function<void()> deleter, size_t root_type_size, - VerifierFn verifier_fn); - // Initializes from an STL byte based container (string and vector of - // char/byte should be compatible). - template <typename Container> - Status FromContainer(Identifier identifier, Container container, - size_t root_type_size, VerifierFn verifier_fn); - Status WrapBuffer(Identifier identifier, - absl::Span<const uint8_t> buffer_data, - size_t root_type_size, VerifierFn verifier_fn); - Status LoadFile(Identifier identifier, std::string path, - size_t root_type_size, VerifierFn verifier_fn); - - private: - const void* root_ptr_ = nullptr; - std::function<void()> deleter_; -}; - -// Immutable root FlatBuffer type wrapper with support for loading and backing -// buffer management. -// -// Immutable and thread-safe. -template <typename T> -class FlatBufferFile final : public FlatBufferFileBase { - public: - // Creates a FlatBufferFile from an in-memory root pointer. - // The provided |deleter| will be called when the FlatBufferFile is destructed - // and can be used to deallocate/clean up resources. - // - // This assumes that the root pointer has already been verified as valid. - // If verification is required instead use FromBuffer on the original buffer. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Create( - const T* root, std::function<void()> deleter); - - // Creates a FlatBufferFile from an in-memory root pointer and the detached - // backing buffer storing it. - // - // Example: - // FlatBufferBuilder fbb; - // MyTypeBuilder mtb(fbb); - // fbb.Finish(mtb.Finish()); - // auto my_type = FlatBufferFile<MyType>::CreateWithBackingBuffer( - // fbb.Release()); - // my_type->foo(); - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> CreateWithBackingBuffer( - ::flatbuffers::DetachedBuffer backing_buffer); - - // Wraps a caller-owned in-memory root pointer. - // The provided |root| must remain valid for the lifetime of the returned - // FlatBufferFile. - // - // This assumes that the root pointer has already been verified as valid. - // If verification is required instead use FromBuffer on the original buffer. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> Wrap(const T* root); - - // Creates a FlatBufferFile wrapping an external data buffer with a deleter - // function that will be called when the FlatBufferFile is destructed. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer( - Identifier identifier, absl::Span<const uint8_t> buffer_data, - std::function<void()> deleter); - - // Creates a FlatBufferFile from a serialized data buffer. - // The FlatBufferFile takes ownership of the vector. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromBuffer( - Identifier identifier, std::vector<uint8_t> buffer_data); - - // Loads a FlatBufferFile from an external buffer owned by the caller. - // The buffer must remain valid until the Pipeline is destroyed. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> WrapBuffer( - Identifier identifier, absl::Span<const uint8_t> buffer_data); - - // Loads the FlatBufferFile from a serialized byte-based STL container. - template <typename Container> - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromContainer( - Identifier identifier, Container buffer_data); - - // Loads a FlatBufferFile from a serialized string. - // The FlatBufferFile takes ownership of the string. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromString( - Identifier identifier, std::string buffer_data) { - return FromContainer(identifier, std::move(buffer_data)); - } - - // Loads a FlatBufferFile from a serialized byte vector. - // The FlatBufferFile takes ownership of the vector. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> FromVector( - Identifier identifier, std::vector<uint8_t> buffer_data) { - return FromContainer(identifier, std::move(buffer_data)); - } - - // Loads a FlatBufferFile from a serialized file on the file system. - // This will attempt to mmap the file and is the preferred way of loading as - // only those pages that contain requested tables will be read. - static StatusOr<std::unique_ptr<FlatBufferFile<T>>> LoadFile( - Identifier identifier, std::string path); - - // Returns a vector of file references that share the same underlying data - // buffer. The buffer will be kept alive until the last file is released. - static StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>> - CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file, int count); - - ~FlatBufferFile() override = default; - - // Typed root pointer of the file. - const T* root() const { return reinterpret_cast<const T*>(root_ptr()); } - - private: - FlatBufferFile() = default; - - // Conforms to VerifierFn. - static bool VerifierFnT(const char* identifier, - ::flatbuffers::Verifier* verifier) { - return verifier->VerifyBuffer<T>(identifier); - } -}; - -template <typename Container> -Status FlatBufferFileBase::FromContainer(Identifier identifier, - Container container, - size_t root_type_size, - VerifierFn verifier_fn) { - static_assert(sizeof(*container.data()) == 1, - "Expected container of byte sized elements"); - auto buffer_data = absl::MakeConstSpan( - // Double static_cast through void is safer than reinterpret_cast. - static_cast<const uint8_t*>(static_cast<const void*>(container.data())), - container.size()); - // Use a baton to keep the container alive until the FlatBufferFileBase is - // destroyed. - auto buffer_data_baton = IreeMoveToLambda(container); - return FromBuffer( - identifier, buffer_data, - [buffer_data_baton]() { - // Keeping the container alive. - (void)buffer_data_baton.value; - }, - root_type_size, verifier_fn); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Create( - const T* root, std::function<void()> deleter) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - RETURN_IF_ERROR(base_file->Create(root, std::move(deleter))); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> -FlatBufferFile<T>::CreateWithBackingBuffer( - ::flatbuffers::DetachedBuffer backing_buffer) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - auto* root_ptr = ::flatbuffers::GetRoot<T>(backing_buffer.data()); - RETURN_IF_ERROR( - base_file->CreateWithBackingBuffer(root_ptr, std::move(backing_buffer))); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::Wrap( - const T* root) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - RETURN_IF_ERROR(base_file->Wrap(root)); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer( - Identifier identifier, absl::Span<const uint8_t> buffer_data, - std::function<void()> deleter) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - RETURN_IF_ERROR(base_file->FromBuffer( - identifier, buffer_data, std::move(deleter), sizeof(T), VerifierFnT)); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromBuffer( - Identifier identifier, std::vector<uint8_t> buffer_data) { - auto* buffer_data_ptr = new decltype(buffer_data); - (*buffer_data_ptr) = std::move(buffer_data); - return FromBuffer(identifier, absl::MakeConstSpan(*buffer_data_ptr), - [buffer_data_ptr]() { delete buffer_data_ptr; }); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::WrapBuffer( - Identifier identifier, absl::Span<const uint8_t> buffer_data) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - RETURN_IF_ERROR( - base_file->WrapBuffer(identifier, buffer_data, sizeof(T), VerifierFnT)); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -template <typename Container> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::FromContainer( - Identifier identifier, Container buffer_data) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - RETURN_IF_ERROR(base_file->FromContainer(identifier, std::move(buffer_data), - sizeof(T), VerifierFnT)); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -StatusOr<std::unique_ptr<FlatBufferFile<T>>> FlatBufferFile<T>::LoadFile( - Identifier identifier, std::string path) { - std::unique_ptr<FlatBufferFile<T>> flat_buffer_file{new FlatBufferFile<T>}; - auto* base_file = static_cast<FlatBufferFileBase*>(flat_buffer_file.get()); - RETURN_IF_ERROR( - base_file->LoadFile(identifier, std::move(path), sizeof(T), VerifierFnT)); - return std::move(flat_buffer_file); -} - -// static -template <typename T> -StatusOr<std::vector<std::unique_ptr<FlatBufferFile<T>>>> -FlatBufferFile<T>::CreateShareGroup(std::unique_ptr<FlatBufferFile<T>> file, - int count) { - // Create a shared_ptr wrapper for the base file that will be. - std::shared_ptr<FlatBufferFile<T>> shared_file{file.release()}; - - // Create N files. We wrap and keep the shared_ptr alive in the deleter - // capture. By wrapping we avoid reverifying the entire buffer. - std::vector<std::unique_ptr<FlatBufferFile<T>>> list; - for (int i = 0; i < count; ++i) { - ASSIGN_OR_RETURN(auto new_file, FlatBufferFile<T>::Create( - shared_file->root(), [shared_file]() { - // Each new file keeps a reference to - // the shared file to keep it alive. - (void)shared_file; - })); - list.push_back(std::move(new_file)); - } - return std::move(list); -} - -} // namespace iree - -#endif // IREE_BASE_FLATBUFFER_UTIL_H_
diff --git a/iree/base/init.h b/iree/base/init.h deleted file mode 100644 index 4e8f345..0000000 --- a/iree/base/init.h +++ /dev/null
@@ -1,52 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INIT_H_ -#define IREE_BASE_INIT_H_ - -// Initializer macros are defined in separate files: -// IREE_DECLARE_MODULE_INITIALIZER(name) -// IREE_REGISTER_MODULE_INITIALIZER(name, body) -// IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) -// IREE_REQUIRE_MODULE_INITIALIZED(name) -// IREE_RUN_MODULE_INITIALIZERS() -// IREE_REQUIRE_MODULE_LINKED(name) -// -// These macros allow for arranging pieces of initialization code to be -// executed at a well-defined time and in a well-defined order. -// -// Initialization happens automatically during InitializeEnvironment(), which -// should be called early in main(), before other code runs. - -#ifdef IREE_CONFIG_GOOGLE_INTERNAL -#include "iree/base/google/init_google.h" -#else -#include "iree/base/internal/init_internal.h" -#endif // IREE_CONFIG_GOOGLE_INTERNAL - -namespace iree { - -// Initializes the system environment in a binary. -// -// This first parses command line flags, then resolves module initializers -// by calling IREE_RUN_MODULE_INITIALIZERS(). -// -// 'argc' and 'argv' are the command line flags to parse. -// -// This should typically be called early in main(), before other code runs. -void InitializeEnvironment(int* argc, char*** argv); - -} // namespace iree - -#endif // IREE_BASE_INIT_H_
diff --git a/iree/base/internal/BUILD b/iree/base/internal/BUILD deleted file mode 100644 index 6ffe4ac..0000000 --- a/iree/base/internal/BUILD +++ /dev/null
@@ -1,118 +0,0 @@ -# Implementations for iree/base/ - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "file_handle_win32", - srcs = ["file_handle_win32.cc"], - hdrs = ["file_handle_win32.h"], - deps = [ - "//iree/base:status", - "//iree/base:target_platform", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "file_io_internal", - srcs = [ - "file_io_posix.cc", - "file_io_win32.cc", - ], - deps = [ - ":file_handle_win32", - "//iree/base:file_io_hdrs", - "//iree/base:status", - "//iree/base:target_platform", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "file_mapping_internal", - srcs = [ - "file_mapping_posix.cc", - "file_mapping_win32.cc", - ], - deps = [ - ":file_handle_win32", - "//iree/base:file_mapping_hdrs", - "//iree/base:target_platform", - "//iree/base:tracing", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "init_internal", - srcs = ["init_internal.cc"], - hdrs = ["init_internal.h"], - deps = [ - "//iree/base:target_platform", - "@com_google_absl//absl/flags:parse", - ], -) - -cc_library( - name = "logging_internal", - srcs = ["logging.cc"], - hdrs = ["logging.h"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/flags:flag", - ], -) - -cc_library( - name = "source_location_internal", - hdrs = ["source_location.h"], -) - -cc_library( - name = "status_internal", - srcs = [ - "status.cc", - "status_builder.cc", - "status_errno.cc", - "status_errors.cc", - "status_win32_errors.cc", - "statusor.cc", - ], - hdrs = [ - "status.h", - "status_builder.h", - "status_errno.h", - "status_errors.h", - "status_macros.h", - "status_win32_errors.h", - "statusor.h", - ], - deps = [ - ":logging_internal", - "//iree/base:source_location", - "//iree/base:target_platform", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/debugging:stacktrace", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "status_matchers_internal", - testonly = 1, - hdrs = ["status_matchers.h"], - deps = [ - "//iree/base:status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_googletest//:gtest", - ], -)
diff --git a/iree/base/internal/file_handle_win32.cc b/iree/base/internal/file_handle_win32.cc deleted file mode 100644 index b4bf05a..0000000 --- a/iree/base/internal/file_handle_win32.cc +++ /dev/null
@@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/file_handle_win32.h" - -#include "absl/memory/memory.h" -#include "iree/base/target_platform.h" - -#if defined(IREE_PLATFORM_WINDOWS) - -#include <windows.h> - -namespace iree { - -// static -StatusOr<std::unique_ptr<FileHandle>> FileHandle::OpenRead(std::string path, - DWORD file_flags) { - HANDLE handle = ::CreateFileA( - /*lpFileName=*/path.c_str(), /*dwDesiredAccess=*/GENERIC_READ, - /*dwShareMode=*/FILE_SHARE_READ, /*lpSecurityAttributes=*/nullptr, - /*dwCreationDisposition=*/OPEN_EXISTING, - /*dwFlagsAndAttributes=*/FILE_ATTRIBUTE_NORMAL | file_flags, - /*hTemplateFile=*/nullptr); - if (handle == INVALID_HANDLE_VALUE) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Unable to open file " << path; - } - - BY_HANDLE_FILE_INFORMATION file_info; - if (::GetFileInformationByHandle(handle, &file_info) == FALSE) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Unable to query file info for " << path; - } - - uint64_t file_size = (static_cast<uint64_t>(file_info.nFileSizeHigh) << 32) | - file_info.nFileSizeLow; - return absl::make_unique<FileHandle>(handle, file_size); -} - -FileHandle::~FileHandle() { ::CloseHandle(handle_); } - -} // namespace iree - -#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/file_handle_win32.h b/iree/base/internal/file_handle_win32.h deleted file mode 100644 index 793e5c6..0000000 --- a/iree/base/internal/file_handle_win32.h +++ /dev/null
@@ -1,57 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_ -#define IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_ - -#include <memory> -#include <string> - -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" - -#if defined(IREE_PLATFORM_WINDOWS) - -#include <windows.h> - -namespace iree { - -class FileHandle { - public: - static StatusOr<std::unique_ptr<FileHandle>> OpenRead(std::string path, - DWORD file_flags); - - FileHandle(HANDLE handle, size_t size) : handle_(handle), size_(size) {} - ~FileHandle(); - - absl::string_view path() const { return path_; } - HANDLE handle() const { return handle_; } - size_t size() const { return size_; } - - private: - FileHandle(const FileHandle&) = delete; - FileHandle& operator=(const FileHandle&) = delete; - - std::string path_; - HANDLE handle_; - size_t size_; -}; - -} // namespace iree - -#endif // IREE_PLATFORM_WINDOWS - -#endif // IREE_BASE_INTERNAL_FILE_HANDLE_WIN32_H_
diff --git a/iree/base/internal/file_io_posix.cc b/iree/base/internal/file_io_posix.cc deleted file mode 100644 index 3674509..0000000 --- a/iree/base/internal/file_io_posix.cc +++ /dev/null
@@ -1,89 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <cstdio> - -#include "iree/base/file_io.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" - -#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \ - defined(IREE_PLATFORM_LINUX) - -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -namespace iree { -namespace file_io { - -Status FileExists(const std::string& path) { - struct stat stat_buf; - return stat(path.c_str(), &stat_buf) == 0 ? OkStatus() - : NotFoundErrorBuilder(IREE_LOC); -} - -StatusOr<std::string> GetFileContents(const std::string& path) { - std::unique_ptr<FILE, void (*)(FILE*)> file = {std::fopen(path.c_str(), "r"), - +[](FILE* file) { - if (file) fclose(file); - }}; - if (file == nullptr) { - return ErrnoToCanonicalStatusBuilder(errno, "Failed to open file", - IREE_LOC); - } - if (std::fseek(file.get(), 0, SEEK_END) == -1) { - return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file", - IREE_LOC); - } - size_t file_size = std::ftell(file.get()); - if (file_size == -1L) { - return ErrnoToCanonicalStatusBuilder(errno, "Failed to read file length", - IREE_LOC); - } - if (std::fseek(file.get(), 0, SEEK_SET) == -1) { - return ErrnoToCanonicalStatusBuilder(errno, "Failed to seek file", - IREE_LOC); - } - std::string contents; - contents.resize(file_size); - if (std::fread(const_cast<char*>(contents.data()), file_size, 1, - file.get()) != file_size) { - return UnavailableErrorBuilder(IREE_LOC) - << "Unable to read entire file contents"; - } - return contents; -} - -Status DeleteFile(const std::string& path) { - if (::remove(path.c_str()) == -1) { - return ErrnoToCanonicalStatusBuilder(errno, "Failed to delete file", - IREE_LOC); - } - return OkStatus(); -} - -Status MoveFile(const std::string& source_path, - const std::string& destination_path) { - if (::rename(source_path.c_str(), destination_path.c_str()) == -1) { - return ErrnoToCanonicalStatusBuilder(errno, "Failed to rename file", - IREE_LOC); - } - return OkStatus(); -} - -} // namespace file_io -} // namespace iree - -#endif // IREE_PLATFORM_*
diff --git a/iree/base/internal/file_io_win32.cc b/iree/base/internal/file_io_win32.cc deleted file mode 100644 index bb06b5f..0000000 --- a/iree/base/internal/file_io_win32.cc +++ /dev/null
@@ -1,76 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "iree/base/file_io.h" -#include "iree/base/internal/file_handle_win32.h" -#include "iree/base/target_platform.h" - -#if defined(IREE_PLATFORM_WINDOWS) - -#include <windows.h> - -namespace iree { -namespace file_io { - -Status FileExists(const std::string& path) { - DWORD attrs = ::GetFileAttributesA(path.c_str()); - if (attrs == INVALID_FILE_ATTRIBUTES) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Unable to find/access file: " << path; - } - return OkStatus(); -} - -StatusOr<std::string> GetFileContents(const std::string& path) { - ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path), - FILE_FLAG_SEQUENTIAL_SCAN)); - std::string result; - result.resize(file->size()); - DWORD bytes_read = 0; - if (::ReadFile(file->handle(), const_cast<char*>(result.data()), - result.size(), &bytes_read, nullptr) == FALSE) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Unable to read file span of " << result.size() << " bytes"; - } else if (bytes_read != file->size()) { - return ResourceExhaustedErrorBuilder(IREE_LOC) - << "Unable to read all " << file->size() - << " bytes from the file (got " << bytes_read << ")"; - } - return result; -} - -Status DeleteFile(const std::string& path) { - if (::DeleteFileA(path.c_str()) == FALSE) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Unable to delete/access file: " << path; - } - return OkStatus(); -} - -Status MoveFile(const std::string& source_path, - const std::string& destination_path) { - if (::MoveFileA(source_path.c_str(), destination_path.c_str()) == FALSE) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Unable to move file " << source_path << " to " - << destination_path; - } - return OkStatus(); -} - -} // namespace file_io -} // namespace iree - -#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/file_mapping_posix.cc b/iree/base/internal/file_mapping_posix.cc deleted file mode 100644 index 38006ee..0000000 --- a/iree/base/internal/file_mapping_posix.cc +++ /dev/null
@@ -1,106 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/file_mapping.h" -#include "iree/base/target_platform.h" -#include "iree/base/tracing.h" - -#if defined(IREE_PLATFORM_ANDROID) || defined(IREE_PLATFORM_APPLE) || \ - defined(IREE_PLATFORM_LINUX) - -#include <fcntl.h> -#include <sys/mman.h> -#include <sys/stat.h> -#include <sys/types.h> -#include <unistd.h> - -#include <cerrno> - -namespace iree { - -namespace { - -class FileDescriptor { - public: - static StatusOr<std::unique_ptr<FileDescriptor>> OpenRead(std::string path) { - struct stat buf; - if (::lstat(path.c_str(), &buf) == -1) { - return NotFoundErrorBuilder(IREE_LOC) - << "Unable to stat file " << path << ": " << ::strerror(errno); - } - uint64_t file_size = static_cast<size_t>(buf.st_size); - - int fd = ::open(path.c_str(), O_RDONLY); - if (fd == -1) { - return UnavailableErrorBuilder(IREE_LOC) - << "Unable to open file " << path << ": " << ::strerror(errno); - } - - return absl::make_unique<FileDescriptor>(std::move(path), fd, file_size); - } - - FileDescriptor(std::string path, int fd, size_t size) - : path_(std::move(path)), fd_(fd), size_(size) {} - ~FileDescriptor() { ::close(fd_); } - - absl::string_view path() const { return path_; } - int fd() const { return fd_; } - size_t size() const { return size_; } - - private: - FileDescriptor(const FileDescriptor&) = delete; - FileDescriptor& operator=(const FileDescriptor&) = delete; - - std::string path_; - int fd_; - size_t size_; -}; - -class MMapMapping : public FileMapping { - public: - MMapMapping(void* data, size_t data_size) - : FileMapping( - absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)) {} - - ~MMapMapping() override { - if (::munmap(const_cast<uint8_t*>(data_.data()), data_.size()) != 0) { - LOG(WARNING) << "Unable to unmap file: " << strerror(errno); - } - } -}; - -} // namespace - -// static -StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) { - IREE_TRACE_SCOPE0("FileMapping::Open"); - - // Open the file for reading. Note that we only need to keep it open long - // enough to map it and we can close the descriptor after that. - ASSIGN_OR_RETURN(auto file, FileDescriptor::OpenRead(std::move(path))); - - // Map the file from the file descriptor. - void* data = - ::mmap(nullptr, file->size(), PROT_READ, MAP_SHARED, file->fd(), 0); - if (data == MAP_FAILED) { - return UnavailableErrorBuilder(IREE_LOC) - << "Mapping failed on file (ensure uncompressed): " << file->path(); - } - - return make_ref<MMapMapping>(data, file->size()); -} - -} // namespace iree - -#endif // IREE_PLATFORM_*
diff --git a/iree/base/internal/file_mapping_win32.cc b/iree/base/internal/file_mapping_win32.cc deleted file mode 100644 index ebe8fb7..0000000 --- a/iree/base/internal/file_mapping_win32.cc +++ /dev/null
@@ -1,98 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "iree/base/file_mapping.h" -#include "iree/base/internal/file_handle_win32.h" -#include "iree/base/target_platform.h" -#include "iree/base/tracing.h" - -#if defined(IREE_PLATFORM_WINDOWS) - -#include <windows.h> - -namespace iree { - -namespace { - -class Win32FileMapping : public FileMapping { - public: - Win32FileMapping(HANDLE mapping_handle, void* data, size_t data_size) - : FileMapping( - absl::MakeSpan(reinterpret_cast<uint8_t*>(data), data_size)), - mapping_handle_(mapping_handle) {} - - ~Win32FileMapping() override { - if (!data_.empty()) { - if (::UnmapViewOfFile(data_.data()) == FALSE) { - LOG(WARNING) << "Unable to unmap file: " << GetLastError(); - } - data_ = {}; - } - if (mapping_handle_) { - ::CloseHandle(mapping_handle_); - mapping_handle_ = nullptr; - } - } - - private: - HANDLE mapping_handle_; -}; - -} // namespace - -// static -StatusOr<ref_ptr<FileMapping>> FileMapping::OpenRead(std::string path) { - IREE_TRACE_SCOPE0("FileMapping::Open"); - - // Open the file for reading. Note that we only need to keep it open long - // enough to map it and we can close the descriptor after that. - ASSIGN_OR_RETURN(auto file, FileHandle::OpenRead(std::move(path), - FILE_FLAG_RANDOM_ACCESS)); - - HANDLE mapping_handle = ::CreateFileMappingA( - /*hFile=*/file->handle(), /*lpFileMappingAttributes=*/nullptr, - /*flProtect=*/PAGE_READONLY, /*dwMaximumSizeHigh=*/0, - /*dwMaximumSizeLow=*/0, /*lpName=*/nullptr); - if (!mapping_handle) { - return Win32ErrorToCanonicalStatusBuilder(GetLastError(), IREE_LOC) - << "Failed to create mapping on file (ensure uncompressed): " - << file->path(); - } - - void* data = - ::MapViewOfFileEx(/*hFileMappingObject=*/mapping_handle, - /*dwDesiredAccess=*/FILE_MAP_READ, - /*dwFileOffsetHigh=*/0, /*dwFileOffsetLow=*/0, - /*dwNumberOfBytesToMap=*/0, /*lpBaseAddress=*/nullptr); - if (!data) { - DWORD map_view_error = GetLastError(); - ::CloseHandle(mapping_handle); - return Win32ErrorToCanonicalStatusBuilder(map_view_error, IREE_LOC) - << "Failed to map view of file: " << file->path(); - } - - auto result = make_ref<Win32FileMapping>(mapping_handle, data, file->size()); - - // NOTE: file mappings hold references to the file, so we don't need to keep - // the file around any longer than this function. - file.reset(); - - return result; -} - -} // namespace iree - -#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/init_internal.cc b/iree/base/internal/init_internal.cc deleted file mode 100644 index c379199..0000000 --- a/iree/base/internal/init_internal.cc +++ /dev/null
@@ -1,110 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/init_internal.h" - -#include <string.h> - -#include <set> - -#include "absl/flags/parse.h" - -namespace iree { - -static Initializer::NameMap* static_name_map = nullptr; - -struct Initializer::InitializerData { - Initializer* initializer_obj; - std::set<std::string> dependency_names; - - InitializerData() : initializer_obj(nullptr) {} - explicit InitializerData(Initializer* i) : initializer_obj(i) {} -}; - -Initializer::DependencyRegisterer::DependencyRegisterer( - const char* name, Initializer* initializer, const Dependency& dependency) { - NameMap* name_map = InitializerNameMap(); - - // Insert 'dependency' into the 'dependency_names' set for 'initializer'. - InitializerData* initializer_data = &(*name_map)[name]; - initializer_data->dependency_names.insert(dependency.name); - - // Ensure that 'dependency' exists in the map. - InitializerData* dependency_data = &(*name_map)[dependency.name]; - dependency_data->initializer_obj = dependency.initializer; -} - -Initializer::Initializer(const char* name, InitializerFunc function) - : name_(name), function_(function), done_(false) { - // Register this Initializer instance (wrapped by an InitializerData) within - // the static name map. - NameMap* name_map = InitializerNameMap(); - InitializerData* initializer_data = &(*name_map)[name]; - initializer_data->initializer_obj = this; -} - -void Initializer::RunInitializers() { - // Run each registered Initializer, in lexicographic order of their names. - // Initializer dependencies will be run first as needed. - NameMap* name_map = InitializerNameMap(); - for (auto& p : *name_map) { - RunInitializer(&p.second); - } -} - -void Initializer::Require() { - NameMap* name_map = InitializerNameMap(); - InitializerData* initializer_data = &(name_map->find(name_)->second); - RunInitializer(initializer_data); -} - -Initializer::NameMap* Initializer::InitializerNameMap() { - if (static_name_map == nullptr) { - static_name_map = new Initializer::NameMap; - } - return static_name_map; -} - -void Initializer::RunInitializer(InitializerData* initializer_data) { - if (initializer_data->initializer_obj->done_) { - return; - } - - // Run Initializer dependencies first. - NameMap* name_map = InitializerNameMap(); - for (const auto& dependency_name : initializer_data->dependency_names) { - auto dep_init = name_map->find(dependency_name); - RunInitializer(&dep_init->second); - } - - // Finally run the Initializer itself. - initializer_data->initializer_obj->function_(); - initializer_data->initializer_obj->done_ = true; -} - -void InitializeEnvironment(int* argc, char*** argv) { - auto positional_args = absl::ParseCommandLine(*argc, *argv); - if (positional_args.size() < *argc) { - // Edit the passed argument refs to only include positional args. - *argc = positional_args.size(); - for (int i = 0; i < *argc; ++i) { - (*argv)[i] = positional_args[i]; - } - (*argv)[*argc + 1] = nullptr; - } - - IREE_RUN_MODULE_INITIALIZERS(); -} - -} // namespace iree
diff --git a/iree/base/internal/init_internal.h b/iree/base/internal/init_internal.h deleted file mode 100644 index 79431d8..0000000 --- a/iree/base/internal/init_internal.h +++ /dev/null
@@ -1,110 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_INIT_INTERNAL_H_ -#define IREE_BASE_INTERNAL_INIT_INTERNAL_H_ - -#include <map> -#include <string> - -#include "iree/base/target_platform.h" - -namespace iree { - -// A static instance of this class is declared for each piece of initialization -// code using the initializer macros. -class Initializer { - public: - typedef void (*InitializerFunc)(); - - Initializer(const char* name, InitializerFunc function); - - // Runs all registered initializers that have not yet run. - // The initializers are invoked in lexicographically increasing order by name, - // except as necessary to satisfy dependencies. - // - // This is normally called by InitializeEnvironment(), so application code - // typically should not call it directly. - static void RunInitializers(); - - // Runs this initializer if it has not yet run, including any dependencies. - void Require(); - - struct Dependency { - Dependency(const char* n, Initializer* i) : name(n), initializer(i) {} - const char* const name; - Initializer* const initializer; - }; - - // A static instance of this class is declared for each piece of - // initializer ordering definition. - struct DependencyRegisterer { - DependencyRegisterer(const char* name, Initializer* initializer, - const Dependency& dependency); - }; - - struct InitializerData; - typedef std::map<std::string, InitializerData> NameMap; - - private: - static NameMap* InitializerNameMap(); - static void RunInitializer(InitializerData* initializer_data); - - const std::string name_; - InitializerFunc function_; - bool done_; -}; - -// In iree/base/init.h: -void InitializeEnvironment(int* argc, char*** argv); - -} // namespace iree - -#define IREE_DECLARE_MODULE_INITIALIZER(name) \ - extern ::iree::Initializer iree_initializer_##name - -#define IREE_REGISTER_MODULE_INITIALIZER(name, body) \ - static void iree_init_##name() { body; } \ - ::iree::Initializer iree_initializer_##name(#name, iree_init_##name) - -#define IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) \ - namespace { \ - static ::iree::Initializer::DependencyRegisterer \ - iree_initializer_dependency_##name1##_##name2( \ - #name2, &iree_initializer_##name2, \ - ::iree::Initializer::Dependency(#name1, &iree_initializer_##name1)); \ - } - -#define IREE_REQUIRE_MODULE_INITIALIZED(name) \ - do { \ - IREE_DECLARE_MODULE_INITIALIZER(name); \ - iree_initializer_##name.Require(); \ - } while (0) - -#define IREE_RUN_MODULE_INITIALIZERS() \ - do { \ - ::iree::Initializer::RunInitializers(); \ - } while (0) - -#if !defined(IREE_COMPILER_MSVC) -#define IREE_ATTRIBUTE_USED __attribute__((used)) -#else -#define IREE_ATTRIBUTE_USED -#endif // IREE_COMPILER_MSVC - -#define IREE_REQUIRE_MODULE_LINKED(name) \ - IREE_ATTRIBUTE_USED static ::iree::Initializer* iree_module_ref_##name = \ - &iree_initializer_##name - -#endif // IREE_BASE_INTERNAL_INIT_INTERNAL_H_
diff --git a/iree/base/internal/logging.cc b/iree/base/internal/logging.cc deleted file mode 100644 index 29a2604..0000000 --- a/iree/base/internal/logging.cc +++ /dev/null
@@ -1,106 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/logging.h" - -#include <string> - -#include "absl/flags/flag.h" - -ABSL_FLAG(int, iree_minloglevel, 0, - "Minimum logging level. 0 = INFO and above."); -ABSL_FLAG(int, iree_v, 0, - "Verbosity level maximum. 1 = VLOG(0-1), 2 = VLOG(0-2)."); -ABSL_FLAG(bool, iree_logtostderr, false, "Logs to stderr instead of stdout"); - -namespace iree { -namespace internal { - -namespace { - -// Parse log level (int64_t) from environment variable (char*). -// Returns true if the value was present and parsed successfully. -bool LogLevelStrToInt(const char* iree_env_var_val, int64_t* out_level) { - *out_level = 0; - if (iree_env_var_val == nullptr) { - return false; - } - - std::string min_log_level(iree_env_var_val); - std::istringstream ss(min_log_level); - int64_t level; - if (!(ss >> level)) { - // Invalid vlog level setting, set level to default (0). - return false; - } - - *out_level = level; - return true; -} - -int64_t MinLogLevelFromEnv() { - const char* iree_env_var_val = getenv("IREE_MIN_LOG_LEVEL"); - int64_t level = 0; - if (LogLevelStrToInt(iree_env_var_val, &level)) { - return level; - } - return absl::GetFlag(FLAGS_iree_minloglevel); -} - -int64_t MinVLogLevelFromEnv() { - const char* iree_env_var_val = getenv("IREE_MIN_VLOG_LEVEL"); - int64_t level = 0; - if (LogLevelStrToInt(iree_env_var_val, &level)) { - return level; - } - return absl::GetFlag(FLAGS_iree_v); -} - -} // namespace - -LogMessage::LogMessage(const char* file_name, int line, int severity) - : file_name_(file_name), line_(line), severity_(severity) {} - -LogMessage::~LogMessage() { - // Read the min log level once during the first call to logging. - static int64_t min_log_level = MinLogLevelFromEnv(); - if (ABSL_PREDICT_TRUE(severity_ >= min_log_level)) { - EmitLogMessage(); - } -} - -int64_t LogMessage::MinVLogLevel() { - static int64_t min_vlog_level = MinVLogLevelFromEnv(); - return min_vlog_level; -} - -void LogMessage::EmitLogMessage() { - // TODO(scotttodd): Include current system time - fprintf(absl::GetFlag(FLAGS_iree_logtostderr) ? stderr : stdout, - "%c %s:%d] %s\n", "IWEF"[severity_], file_name_, line_, - str().c_str()); -} - -LogMessageFatal::LogMessageFatal(const char* file, int line) - : LogMessage(file, line, FATAL) {} - -LogMessageFatal::~LogMessageFatal() { - EmitLogMessage(); - - // abort() ensures we don't return (as promised via ATTRIBUTE_NORETURN). - abort(); -} - -} // namespace internal -} // namespace iree
diff --git a/iree/base/internal/status.cc b/iree/base/internal/status.cc deleted file mode 100644 index f4f4d57..0000000 --- a/iree/base/internal/status.cc +++ /dev/null
@@ -1,178 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/status.h" - -#include <atomic> -#include <memory> - -#include "absl/base/attributes.h" -#include "absl/debugging/stacktrace.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" - -ABSL_FLAG(bool, iree_status_save_stack_trace, false, - "Save and display the full stack trace of the point of error") - .OnUpdate([]() { - iree::StatusSavesStackTrace( - absl::GetFlag(FLAGS_iree_status_save_stack_trace)); - }); - -namespace iree { - -namespace status_internal { - -ABSL_CONST_INIT std::atomic<bool> iree_save_stack_trace{false}; - -} // namespace status_internal - -bool DoesStatusSaveStackTrace() { - return status_internal::iree_save_stack_trace.load(std::memory_order_relaxed); -} -void StatusSavesStackTrace(bool on_off) { - status_internal::iree_save_stack_trace.store(on_off, - std::memory_order_relaxed); -} - -std::string StatusCodeToString(StatusCode code) { - switch (code) { - case StatusCode::kOk: - return "OK"; - case StatusCode::kCancelled: - return "CANCELLED"; - case StatusCode::kUnknown: - return "UNKNOWN"; - case StatusCode::kInvalidArgument: - return "INVALID_ARGUMENT"; - case StatusCode::kDeadlineExceeded: - return "DEADLINE_EXCEEDED"; - case StatusCode::kNotFound: - return "NOT_FOUND"; - case StatusCode::kAlreadyExists: - return "ALREADY_EXISTS"; - case StatusCode::kPermissionDenied: - return "PERMISSION_DENIED"; - case StatusCode::kUnauthenticated: - return "UNAUTHENTICATED"; - case StatusCode::kResourceExhausted: - return "RESOURCE_EXHAUSTED"; - case StatusCode::kFailedPrecondition: - return "FAILED_PRECONDITION"; - case StatusCode::kAborted: - return "ABORTED"; - case StatusCode::kOutOfRange: - return "OUT_OF_RANGE"; - case StatusCode::kUnimplemented: - return "UNIMPLEMENTED"; - case StatusCode::kInternal: - return "INTERNAL"; - case StatusCode::kUnavailable: - return "UNAVAILABLE"; - case StatusCode::kDataLoss: - return "DATA_LOSS"; - default: - return ""; - } -} - -Status::Status() {} - -Status::Status(StatusCode code, absl::string_view message) { - state_ = absl::make_unique<State>(); - state_->code = code; - state_->message = std::string(message); -} - -Status::Status(const Status& x) { - if (x.ok()) return; - - state_ = absl::make_unique<State>(); - state_->code = x.state_->code; - state_->message = x.state_->message; -} - -Status& Status::operator=(const Status& x) { - if (x.ok()) { - state_ = nullptr; - } else { - state_ = absl::make_unique<State>(); - state_->code = x.state_->code; - state_->message = x.state_->message; - } - return *this; -} - -Status::~Status() {} - -bool Status::ok() const { return state_ == nullptr; } - -StatusCode Status::code() const { - return ok() ? StatusCode::kOk : state_->code; -} - -absl::string_view Status::message() const { - return ok() ? absl::string_view() : absl::string_view(state_->message); -} - -std::string Status::ToString() const { - if (ok()) { - return "OK"; - } - - std::string text; - absl::StrAppend(&text, StatusCodeToString(state_->code), ": ", - state_->message); - // TODO(scotttodd): Payloads (stack traces) - return text; -} - -void Status::IgnoreError() const { - // no-op -} - -bool Status::EqualsSlow(const Status& a, const Status& b) { - if (a.code() != b.code()) return false; - if (a.message() != b.message()) return false; - // TODO(scotttodd): Payloads - return true; -} - -bool operator==(const Status& lhs, const Status& rhs) { - return lhs.state_ == rhs.state_ || Status::EqualsSlow(lhs, rhs); -} - -bool operator!=(const Status& lhs, const Status& rhs) { return !(lhs == rhs); } - -std::ostream& operator<<(std::ostream& os, const Status& x) { - os << x.ToString(); - return os; -} - -Status OkStatus() { return Status(); } - -Status Annotate(const Status& s, absl::string_view msg) { - if (s.ok() || msg.empty()) return s; - - absl::string_view new_msg = msg; - std::string annotated; - if (!s.message().empty()) { - absl::StrAppend(&annotated, s.message(), "; ", msg); - new_msg = annotated; - } - Status result(s.code(), new_msg); - // TODO(scotttodd): Copy payload(s) into the new Status - return result; -} - -} // namespace iree
diff --git a/iree/base/internal/status.h b/iree/base/internal/status.h deleted file mode 100644 index b83e249..0000000 --- a/iree/base/internal/status.h +++ /dev/null
@@ -1,130 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_H_ -#define IREE_BASE_INTERNAL_STATUS_H_ - -#include <atomic> -#include <string> - -#include "absl/base/attributes.h" -#include "absl/flags/flag.h" -#include "absl/strings/string_view.h" -#include "iree/base/internal/logging.h" - -ABSL_DECLARE_FLAG(bool, iree_status_save_stack_trace); - -namespace iree { - -// True if Status objects will capture stack traces on init for non-ok Statuses. -bool DoesStatusSaveStackTrace(); - -// Enables/disables status stack trace saving. This is global for the process. -// While useful for debugging, stack traces can impact performance severely. -void StatusSavesStackTrace(bool on_off); - -enum class StatusCode : int { - kOk = 0, - kCancelled = 1, - kUnknown = 2, - kInvalidArgument = 3, - kDeadlineExceeded = 4, - kNotFound = 5, - kAlreadyExists = 6, - kPermissionDenied = 7, - kResourceExhausted = 8, - kFailedPrecondition = 9, - kAborted = 10, - kOutOfRange = 11, - kUnimplemented = 12, - kInternal = 13, - kUnavailable = 14, - kDataLoss = 15, - kUnauthenticated = 16, - kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 -}; - -std::string StatusCodeToString(StatusCode code); - -class ABSL_MUST_USE_RESULT Status; - -// A Status value can be either OK or not-OK -// * OK indicates that the operation succeeded. -// * A not-OK value indicates that the operation failed and contains details -// about the error. -class Status final { - public: - // Creates an OK status with no message. - Status(); - - // Creates a status with the specified code and error message. - Status(StatusCode code, absl::string_view message); - - Status(const Status&); - Status& operator=(const Status& x); - - ~Status(); - - // Returns true if the Status is OK. - ABSL_MUST_USE_RESULT bool ok() const; - - // Returns the error code. - StatusCode code() const; - - // Returns the error message. Note: prefer ToString() for debug logging. - // This message rarely describes the error code. It is not unusual for the - // error message to be the empty string. - absl::string_view message() const; - - // Return a combination of the error code name and message. - std::string ToString() const; - - // Compatibility with upstream API. Equiv to ToString(). - std::string error_message() const { return ToString(); } - - friend bool operator==(const Status&, const Status&); - friend bool operator!=(const Status&, const Status&); - - // Ignores any errors, potentially suppressing complaints from any tools. - void IgnoreError() const; - - private: - static bool EqualsSlow(const Status& a, const Status& b); - - struct State { - StatusCode code; - std::string message; - }; - // OK status has a nullptr state_. Otherwise, 'state_' points to - // a 'State' structure containing the error code and message(s). - std::unique_ptr<State> state_; -}; - -// Returns an OK status, equivalent to a default constructed instance. -Status OkStatus(); - -// Prints a human-readable representation of `x` to `os`. -std::ostream& operator<<(std::ostream& os, const Status& x); - -// Returns a Status that is identical to `s` except that the message() -// has been augmented by adding `msg` to the end of the original message. -Status Annotate(const Status& s, absl::string_view msg); - -#define CHECK_OK(val) CHECK_EQ(::iree::OkStatus(), (val)) -#define QCHECK_OK(val) QCHECK_EQ(::iree::OkStatus(), (val)) -#define DCHECK_OK(val) DCHECK_EQ(::iree::OkStatus(), (val)) - -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUS_H_
diff --git a/iree/base/internal/status_builder.cc b/iree/base/internal/status_builder.cc deleted file mode 100644 index 7ac0e87..0000000 --- a/iree/base/internal/status_builder.cc +++ /dev/null
@@ -1,140 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/status_builder.h" - -#include <cstdio> - -#include "iree/base/internal/status_errors.h" - -namespace iree { - -StatusBuilder::StatusBuilder(const Status& original_status, - SourceLocation location) - : status_(original_status), loc_(location) {} - -StatusBuilder::StatusBuilder(Status&& original_status, SourceLocation location) - : status_(original_status), loc_(location) {} - -StatusBuilder::StatusBuilder(const StatusBuilder& sb) - : status_(sb.status_), loc_(sb.loc_), message_(sb.message_) {} - -StatusBuilder::StatusBuilder(StatusCode code, SourceLocation location) - : status_(code, ""), loc_(location) {} - -StatusBuilder& StatusBuilder::operator=(const StatusBuilder& sb) { - status_ = sb.status_; - loc_ = sb.loc_; - message_ = sb.message_; - return *this; -} - -StatusBuilder::operator Status() const& { - return StatusBuilder(*this).CreateStatus(); -} -StatusBuilder::operator Status() && { return std::move(*this).CreateStatus(); } - -bool StatusBuilder::ok() const { return status_.ok(); } - -StatusCode StatusBuilder::code() const { return status_.code(); } - -SourceLocation StatusBuilder::source_location() const { return loc_; } - -Status StatusBuilder::CreateStatus() && { - Status result = JoinMessageToStatus(status_, message_); - - // Reset the status after consuming it. - status_ = UnknownError(""); - message_ = ""; - return result; -} - -Status StatusBuilder::JoinMessageToStatus(Status s, absl::string_view msg) { - if (msg.empty()) return s; - return Annotate(s, msg); -} - -std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder) { - return os << static_cast<Status>(builder); -} - -std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder) { - return os << static_cast<Status>(std::move(builder)); -} - -StatusBuilder AbortedErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kAborted, location); -} - -StatusBuilder AlreadyExistsErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kAlreadyExists, location); -} - -StatusBuilder CancelledErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kCancelled, location); -} - -StatusBuilder DataLossErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kDataLoss, location); -} - -StatusBuilder DeadlineExceededErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kDeadlineExceeded, location); -} - -StatusBuilder FailedPreconditionErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kFailedPrecondition, location); -} - -StatusBuilder InternalErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kInternal, location); -} - -StatusBuilder InvalidArgumentErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kInvalidArgument, location); -} - -StatusBuilder NotFoundErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kNotFound, location); -} - -StatusBuilder OutOfRangeErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kOutOfRange, location); -} - -StatusBuilder PermissionDeniedErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kPermissionDenied, location); -} - -StatusBuilder UnauthenticatedErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kUnauthenticated, location); -} - -StatusBuilder ResourceExhaustedErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kResourceExhausted, location); -} - -StatusBuilder UnavailableErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kUnavailable, location); -} - -StatusBuilder UnimplementedErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kUnimplemented, location); -} - -StatusBuilder UnknownErrorBuilder(SourceLocation location) { - return StatusBuilder(StatusCode::kUnknown, location); -} - -} // namespace iree
diff --git a/iree/base/internal/status_builder.h b/iree/base/internal/status_builder.h deleted file mode 100644 index e05665b..0000000 --- a/iree/base/internal/status_builder.h +++ /dev/null
@@ -1,137 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_BUILDER_H_ -#define IREE_BASE_INTERNAL_STATUS_BUILDER_H_ - -#include "iree/base/internal/status.h" -#include "iree/base/source_location.h" - -namespace iree { - -// Creates a status based on an original_status, but enriched with additional -// information. The builder implicitly converts to Status and StatusOr<T> -// allowing for it to be returned directly. -class ABSL_MUST_USE_RESULT StatusBuilder { - public: - // Creates a `StatusBuilder` based on an original status. - explicit StatusBuilder(const Status& original_status, - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); - explicit StatusBuilder(Status&& original_status, - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); - - // Creates a `StatusBuilder` from a status code. - // A typical user will not specify `location`, allowing it to default to the - // current location. - explicit StatusBuilder(StatusCode code, - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); - - StatusBuilder(const StatusBuilder& sb); - StatusBuilder& operator=(const StatusBuilder& sb); - StatusBuilder(StatusBuilder&&) = default; - StatusBuilder& operator=(StatusBuilder&&) = default; - - // Appends to the extra message that will be added to the original status. - template <typename T> - StatusBuilder& operator<<(const T& value) &; - template <typename T> - StatusBuilder&& operator<<(const T& value) &&; - - // No-op functions that may be added later. - StatusBuilder& LogError() & { return *this; } - StatusBuilder&& LogError() && { return std::move(LogError()); } - StatusBuilder& LogWarning() & { return *this; } - StatusBuilder&& LogWarning() && { return std::move(LogWarning()); } - StatusBuilder& LogInfo() & { return *this; } - StatusBuilder&& LogInfo() && { return std::move(LogInfo()); } - - // Returns true if the Status created by this builder will be ok(). - bool ok() const; - - // Returns the error code for the Status created by this builder. - StatusCode code() const; - - // Returns the source location used to create this builder. - SourceLocation source_location() const; - - // Implicit conversion to Status. - operator Status() const&; - operator Status() &&; - - private: - Status CreateStatus() &&; - - static Status JoinMessageToStatus(Status s, absl::string_view msg); - - // The status that the result will be based on. - Status status_; - - // The location to record if this status is logged. - SourceLocation loc_; - - // The message that will be added to the original status. - std::string message_; -}; - -template <typename T> -StatusBuilder& StatusBuilder::operator<<(const T& value) & { - return *this; -} -template <typename T> -StatusBuilder&& StatusBuilder::operator<<(const T& value) && { - return std::move(operator<<(value)); -} - -// Implicitly converts `builder` to `Status` and write it to `os`. -std::ostream& operator<<(std::ostream& os, const StatusBuilder& builder); -std::ostream& operator<<(std::ostream& os, StatusBuilder&& builder); - -// Each of the functions below creates StatusBuilder with a canonical error. -// The error code of the StatusBuilder matches the name of the function. -StatusBuilder AbortedErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder AlreadyExistsErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder CancelledErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder DataLossErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder DeadlineExceededErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder FailedPreconditionErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder InternalErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder InvalidArgumentErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder NotFoundErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder OutOfRangeErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder PermissionDeniedErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder UnauthenticatedErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder ResourceExhaustedErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder UnavailableErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder UnimplementedErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); -StatusBuilder UnknownErrorBuilder( - SourceLocation location IREE_LOC_CURRENT_DEFAULT_ARG); - -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUS_BUILDER_H_
diff --git a/iree/base/internal/status_errno.cc b/iree/base/internal/status_errno.cc deleted file mode 100644 index deaa5c1..0000000 --- a/iree/base/internal/status_errno.cc +++ /dev/null
@@ -1,175 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/status_errno.h" - -#include <cerrno> - -#include "absl/strings/str_cat.h" - -namespace iree { - -StatusCode ErrnoToCanonicalCode(int error_number) { - switch (error_number) { - case 0: - return StatusCode::kOk; - case EINVAL: // Invalid argument - case ENAMETOOLONG: // Filename too long - case E2BIG: // Argument list too long - case EDESTADDRREQ: // Destination address required - case EDOM: // Mathematics argument out of domain of function - case EFAULT: // Bad address - case EILSEQ: // Illegal byte sequence - case ENOPROTOOPT: // Protocol not available - case ENOSTR: // Not a STREAM - case ENOTSOCK: // Not a socket - case ENOTTY: // Inappropriate I/O control operation - case EPROTOTYPE: // Protocol wrong type for socket - case ESPIPE: // Invalid seek - return StatusCode::kInvalidArgument; - case ETIMEDOUT: // Connection timed out - case ETIME: // Timer expired - return StatusCode::kDeadlineExceeded; - case ENODEV: // No such device - case ENOENT: // No such file or directory -#ifdef ENOMEDIUM - case ENOMEDIUM: // No medium found -#endif - case ENXIO: // No such device or address - case ESRCH: // No such process - return StatusCode::kNotFound; - case EEXIST: // File exists - case EADDRNOTAVAIL: // Address not available - case EALREADY: // Connection already in progress -#ifdef ENOTUNIQ - case ENOTUNIQ: // Name not unique on network -#endif - return StatusCode::kAlreadyExists; - case EPERM: // Operation not permitted - case EACCES: // Permission denied -#ifdef ENOKEY - case ENOKEY: // Required key not available -#endif - case EROFS: // Read only file system - return StatusCode::kPermissionDenied; - case ENOTEMPTY: // Directory not empty - case EISDIR: // Is a directory - case ENOTDIR: // Not a directory - case EADDRINUSE: // Address already in use - case EBADF: // Invalid file descriptor -#ifdef EBADFD - case EBADFD: // File descriptor in bad state -#endif - case EBUSY: // Device or resource busy - case ECHILD: // No child processes - case EISCONN: // Socket is connected -#ifdef EISNAM - case EISNAM: // Is a named type file -#endif -#ifdef ENOTBLK - case ENOTBLK: // Block device required -#endif - case ENOTCONN: // The socket is not connected - case EPIPE: // Broken pipe -#ifdef ESHUTDOWN - case ESHUTDOWN: // Cannot send after transport endpoint shutdown -#endif - case ETXTBSY: // Text file busy -#ifdef EUNATCH - case EUNATCH: // Protocol driver not attached -#endif - return StatusCode::kFailedPrecondition; - case ENOSPC: // No space left on device -#ifdef EDQUOT - case EDQUOT: // Disk quota exceeded -#endif - case EMFILE: // Too many open files - case EMLINK: // Too many links - case ENFILE: // Too many open files in system - case ENOBUFS: // No buffer space available - case ENODATA: // No message is available on the STREAM read queue - case ENOMEM: // Not enough space - case ENOSR: // No STREAM resources -#ifdef EUSERS - case EUSERS: // Too many users -#endif - return StatusCode::kResourceExhausted; -#ifdef ECHRNG - case ECHRNG: // Channel number out of range -#endif - case EFBIG: // File too large - case EOVERFLOW: // Value too large to be stored in data type - case ERANGE: // Result too large - return StatusCode::kOutOfRange; -#ifdef ENOPKG - case ENOPKG: // Package not installed -#endif - case ENOSYS: // Function not implemented - case ENOTSUP: // Operation not supported - case EAFNOSUPPORT: // Address family not supported -#ifdef EPFNOSUPPORT - case EPFNOSUPPORT: // Protocol family not supported -#endif - case EPROTONOSUPPORT: // Protocol not supported -#ifdef ESOCKTNOSUPPORT - case ESOCKTNOSUPPORT: // Socket type not supported -#endif - case EXDEV: // Improper link - return StatusCode::kUnimplemented; - case EAGAIN: // Resource temporarily unavailable -#ifdef ECOMM - case ECOMM: // Communication error on send -#endif - case ECONNREFUSED: // Connection refused - case ECONNABORTED: // Connection aborted - case ECONNRESET: // Connection reset - case EINTR: // Interrupted function call -#ifdef EHOSTDOWN - case EHOSTDOWN: // Host is down -#endif - case EHOSTUNREACH: // Host is unreachable - case ENETDOWN: // Network is down - case ENETRESET: // Connection aborted by network - case ENETUNREACH: // Network unreachable - case ENOLCK: // No locks available - case ENOLINK: // Link has been severed -#ifdef ENONET - case ENONET: // Machine is not on the network -#endif - return StatusCode::kUnavailable; - case EDEADLK: // Resource deadlock avoided -#ifdef ESTALE - case ESTALE: // Stale file handle -#endif - return StatusCode::kAborted; - case ECANCELED: // Operation cancelled - return StatusCode::kCancelled; - default: - return StatusCode::kUnknown; - } -} - -Status ErrnoToCanonicalStatus(int error_number, absl::string_view message) { - // TODO(scotttodd): convert error number to a string - return Status(ErrnoToCanonicalCode(error_number), - absl::StrCat(message, ": ", error_number)); -} - -StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number, - absl::string_view message, - SourceLocation location) { - return StatusBuilder(ErrnoToCanonicalStatus(error_number, message), location); -} - -} // namespace iree
diff --git a/iree/base/internal/status_errno.h b/iree/base/internal/status_errno.h deleted file mode 100644 index 8b3d74b..0000000 --- a/iree/base/internal/status_errno.h +++ /dev/null
@@ -1,41 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_ERRNO_H_ -#define IREE_BASE_INTERNAL_STATUS_ERRNO_H_ - -#include "absl/strings/string_view.h" -#include "iree/base/internal/status.h" -#include "iree/base/internal/statusor.h" -#include "iree/base/source_location.h" - -namespace iree { - -// Returns the code for |error_number|, which should be an |errno| value. -// See https://en.cppreference.com/w/cpp/error/errno_macros and similar refs. -StatusCode ErrnoToCanonicalCode(int error_number); - -// Returns a Status, using a code of `ErrnoToCode(error_number)`, and a -// |message| with the result of `StrError(error_number)` appended. -Status ErrnoToCanonicalStatus(int error_number, absl::string_view message); - -// Returns a StatusBuilder using a status of -// `ErrnoToCanonicalStatus(error_number, message)` and |location|. -StatusBuilder ErrnoToCanonicalStatusBuilder(int error_number, - absl::string_view message, - SourceLocation location); - -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUS_ERRNO_H_
diff --git a/iree/base/internal/status_errors.cc b/iree/base/internal/status_errors.cc deleted file mode 100644 index 28acbbc..0000000 --- a/iree/base/internal/status_errors.cc +++ /dev/null
@@ -1,147 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/status_errors.h" - -namespace iree { - -Status AbortedError(absl::string_view message) { - return Status(StatusCode::kAborted, message); -} - -Status AlreadyExistsError(absl::string_view message) { - return Status(StatusCode::kAlreadyExists, message); -} - -Status CancelledError(absl::string_view message) { - return Status(StatusCode::kCancelled, message); -} - -Status DataLossError(absl::string_view message) { - return Status(StatusCode::kDataLoss, message); -} - -Status DeadlineExceededError(absl::string_view message) { - return Status(StatusCode::kDeadlineExceeded, message); -} - -Status FailedPreconditionError(absl::string_view message) { - return Status(StatusCode::kFailedPrecondition, message); -} - -Status InternalError(absl::string_view message) { - return Status(StatusCode::kInternal, message); -} - -Status InvalidArgumentError(absl::string_view message) { - return Status(StatusCode::kInvalidArgument, message); -} - -Status NotFoundError(absl::string_view message) { - return Status(StatusCode::kNotFound, message); -} - -Status OutOfRangeError(absl::string_view message) { - return Status(StatusCode::kOutOfRange, message); -} - -Status PermissionDeniedError(absl::string_view message) { - return Status(StatusCode::kPermissionDenied, message); -} - -Status ResourceExhaustedError(absl::string_view message) { - return Status(StatusCode::kResourceExhausted, message); -} - -Status UnauthenticatedError(absl::string_view message) { - return Status(StatusCode::kUnauthenticated, message); -} - -Status UnavailableError(absl::string_view message) { - return Status(StatusCode::kUnavailable, message); -} - -Status UnimplementedError(absl::string_view message) { - return Status(StatusCode::kUnimplemented, message); -} - -Status UnknownError(absl::string_view message) { - return Status(StatusCode::kUnknown, message); -} - -bool IsAborted(const Status& status) { - return status.code() == StatusCode::kAborted; -} - -bool IsAlreadyExists(const Status& status) { - return status.code() == StatusCode::kAlreadyExists; -} - -bool IsCancelled(const Status& status) { - return status.code() == StatusCode::kCancelled; -} - -bool IsDataLoss(const Status& status) { - return status.code() == StatusCode::kDataLoss; -} - -bool IsDeadlineExceeded(const Status& status) { - return status.code() == StatusCode::kDeadlineExceeded; -} - -bool IsFailedPrecondition(const Status& status) { - return status.code() == StatusCode::kFailedPrecondition; -} - -bool IsInternal(const Status& status) { - return status.code() == StatusCode::kInternal; -} - -bool IsInvalidArgument(const Status& status) { - return status.code() == StatusCode::kInvalidArgument; -} - -bool IsNotFound(const Status& status) { - return status.code() == StatusCode::kNotFound; -} - -bool IsOutOfRange(const Status& status) { - return status.code() == StatusCode::kOutOfRange; -} - -bool IsPermissionDenied(const Status& status) { - return status.code() == StatusCode::kPermissionDenied; -} - -bool IsResourceExhausted(const Status& status) { - return status.code() == StatusCode::kResourceExhausted; -} - -bool IsUnauthenticated(const Status& status) { - return status.code() == StatusCode::kUnauthenticated; -} - -bool IsUnavailable(const Status& status) { - return status.code() == StatusCode::kUnavailable; -} - -bool IsUnimplemented(const Status& status) { - return status.code() == StatusCode::kUnimplemented; -} - -bool IsUnknown(const Status& status) { - return status.code() == StatusCode::kUnknown; -} - -} // namespace iree
diff --git a/iree/base/internal/status_errors.h b/iree/base/internal/status_errors.h deleted file mode 100644 index ac23d74..0000000 --- a/iree/base/internal/status_errors.h +++ /dev/null
@@ -1,60 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_ERRORS_H_ -#define IREE_BASE_INTERNAL_STATUS_ERRORS_H_ - -#include "absl/base/attributes.h" -#include "absl/strings/string_view.h" -#include "iree/base/internal/status.h" - -namespace iree { - -Status AbortedError(absl::string_view message); -Status AlreadyExistsError(absl::string_view message); -Status CancelledError(absl::string_view message); -Status DataLossError(absl::string_view message); -Status DeadlineExceededError(absl::string_view message); -Status FailedPreconditionError(absl::string_view message); -Status InternalError(absl::string_view message); -Status InvalidArgumentError(absl::string_view message); -Status NotFoundError(absl::string_view message); -Status OutOfRangeError(absl::string_view message); -Status PermissionDeniedError(absl::string_view message); -Status ResourceExhaustedError(absl::string_view message); -Status UnauthenticatedError(absl::string_view message); -Status UnavailableError(absl::string_view message); -Status UnimplementedError(absl::string_view message); -Status UnknownError(absl::string_view message); - -ABSL_MUST_USE_RESULT bool IsAborted(const Status& status); -ABSL_MUST_USE_RESULT bool IsAlreadyExists(const Status& status); -ABSL_MUST_USE_RESULT bool IsCancelled(const Status& status); -ABSL_MUST_USE_RESULT bool IsDataLoss(const Status& status); -ABSL_MUST_USE_RESULT bool IsDeadlineExceeded(const Status& status); -ABSL_MUST_USE_RESULT bool IsFailedPrecondition(const Status& status); -ABSL_MUST_USE_RESULT bool IsInternal(const Status& status); -ABSL_MUST_USE_RESULT bool IsInvalidArgument(const Status& status); -ABSL_MUST_USE_RESULT bool IsNotFound(const Status& status); -ABSL_MUST_USE_RESULT bool IsOutOfRange(const Status& status); -ABSL_MUST_USE_RESULT bool IsPermissionDenied(const Status& status); -ABSL_MUST_USE_RESULT bool IsResourceExhausted(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnauthenticated(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnavailable(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnimplemented(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnknown(const Status& status); - -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUS_ERRORS_H_
diff --git a/iree/base/internal/status_macros.h b/iree/base/internal/status_macros.h deleted file mode 100644 index def1e6e..0000000 --- a/iree/base/internal/status_macros.h +++ /dev/null
@@ -1,108 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_MACROS_H_ -#define IREE_BASE_INTERNAL_STATUS_MACROS_H_ - -#include "iree/base/internal/status.h" -#include "iree/base/internal/status_builder.h" -#include "iree/base/internal/statusor.h" -#include "iree/base/source_location.h" - -// Evaluates an expression that produces a `iree::Status`. If the status is not -// ok, returns it from the current function. -#define RETURN_IF_ERROR(expr) \ - STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ - if (iree::status_macro_internal::StatusAdaptorForMacros \ - status_macro_internal_adaptor = {(expr), IREE_LOC}) { \ - } else /* NOLINT */ \ - return status_macro_internal_adaptor.Consume() - -// Executes an expression `rexpr` that returns a `iree::StatusOr<T>`. On OK, -// moves its value into the variable defined by `lhs`, otherwise returns -// from the current function. -#define ASSIGN_OR_RETURN(...) \ - STATUS_MACROS_IMPL_GET_VARIADIC_((__VA_ARGS__, \ - STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_, \ - STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_)) \ - (__VA_ARGS__) - -// ================================================================= -// == Implementation details, do not rely on anything below here. == -// ================================================================= - -// MSVC incorrectly expands variadic macros, splice together a macro call to -// work around the bug. -#define STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_(_1, _2, _3, NAME, ...) NAME -#define STATUS_MACROS_IMPL_GET_VARIADIC_(args) \ - STATUS_MACROS_IMPL_GET_VARIADIC_HELPER_ args - -#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_2_(lhs, rexpr) \ - STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, std::move(_)) -#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_3_(lhs, rexpr, error_expression) \ - STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_( \ - STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __LINE__), lhs, rexpr, \ - error_expression) -#define STATUS_MACROS_IMPL_ASSIGN_OR_RETURN_(statusor, lhs, rexpr, \ - error_expression) \ - auto statusor = (rexpr); \ - if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ - iree::StatusBuilder _(std::move(statusor).status(), IREE_LOC); \ - (void)_; /* error_expression is allowed to not use this variable */ \ - return (error_expression); \ - } \ - lhs = std::move(statusor).ValueOrDie() - -// Internal helper for concatenating macro values. -#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y -#define STATUS_MACROS_IMPL_CONCAT_(x, y) STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) - -// clang-format off -#define STATUS_MACROS_IMPL_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT -// clang-format on - -namespace iree { -namespace status_macro_internal { - -// Provides a conversion to bool so that it can be used inside an if statement -// that declares a variable. -class StatusAdaptorForMacros { - public: - StatusAdaptorForMacros(const Status& status, SourceLocation loc) - : builder_(status, loc) {} - - StatusAdaptorForMacros(Status&& status, SourceLocation loc) - : builder_(std::move(status), loc) {} - - StatusAdaptorForMacros(const StatusBuilder& builder, SourceLocation loc) - : builder_(builder) {} - - StatusAdaptorForMacros(StatusBuilder&& builder, SourceLocation loc) - : builder_(std::move(builder)) {} - - StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete; - StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete; - - explicit operator bool() const { return ABSL_PREDICT_TRUE(builder_.ok()); } - - StatusBuilder&& Consume() { return std::move(builder_); } - - private: - StatusBuilder builder_; -}; - -} // namespace status_macro_internal -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUS_MACROS_H_
diff --git a/iree/base/internal/status_matchers.h b/iree/base/internal/status_matchers.h deleted file mode 100644 index 83e6a8f..0000000 --- a/iree/base/internal/status_matchers.h +++ /dev/null
@@ -1,299 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_MATCHERS_H_ -#define IREE_BASE_INTERNAL_STATUS_MATCHERS_H_ - -#include <memory> - -#include "absl/strings/str_cat.h" -#include "absl/types/optional.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" - -#undef EXPECT_OK -#undef ASSERT_OK -#undef ASSERT_OK_AND_ASSIGN - -namespace iree { - -namespace internal { - -// Implements a gMock matcher that checks that an iree::StaturOr<T> has an OK -// status and that the contained T value matches another matcher. -template <typename T> -class IsOkAndHoldsMatcher - : public ::testing::MatcherInterface<const StatusOr<T> &> { - public: - template <typename MatcherT> - IsOkAndHoldsMatcher(MatcherT &&value_matcher) - : value_matcher_(::testing::SafeMatcherCast<const T &>(value_matcher)) {} - - // From testing::MatcherInterface. - void DescribeTo(std::ostream *os) const override { - *os << "is OK and contains a value that "; - value_matcher_.DescribeTo(os); - } - - // From testing::MatcherInterface. - void DescribeNegationTo(std::ostream *os) const override { - *os << "is not OK or contains a value that "; - value_matcher_.DescribeNegationTo(os); - } - - // From testing::MatcherInterface. - bool MatchAndExplain( - const StatusOr<T> &status_or, - ::testing::MatchResultListener *listener) const override { - if (!status_or.ok()) { - *listener << "which is not OK"; - return false; - } - - ::testing::StringMatchResultListener value_listener; - bool is_a_match = - value_matcher_.MatchAndExplain(status_or.ValueOrDie(), &value_listener); - std::string value_explanation = value_listener.str(); - if (!value_explanation.empty()) { - *listener << absl::StrCat("which contains a value ", value_explanation); - } - - return is_a_match; - } - - private: - const ::testing::Matcher<const T &> value_matcher_; -}; - -// A polymorphic IsOkAndHolds() matcher. -// -// IsOkAndHolds() returns a matcher that can be used to process an IsOkAndHolds -// expectation. However, the value type T is not provided when IsOkAndHolds() is -// invoked. The value type is only inferable when the gUnit framework invokes -// the matcher with a value. Consequently, the IsOkAndHolds() function must -// return an object that is implicitly convertible to a matcher for StatusOr<T>. -// gUnit refers to such an object as a polymorphic matcher, since it can be used -// to match with more than one type of value. -template <typename ValueMatcherT> -class IsOkAndHoldsGenerator { - public: - explicit IsOkAndHoldsGenerator(ValueMatcherT value_matcher) - : value_matcher_(std::move(value_matcher)) {} - - template <typename T> - operator ::testing::Matcher<const StatusOr<T> &>() const { - return ::testing::MakeMatcher(new IsOkAndHoldsMatcher<T>(value_matcher_)); - } - - private: - const ValueMatcherT value_matcher_; -}; - -// Implements a gMock matcher for checking error-code expectations on -// iree::Status objects. -template <typename Enum> -class StatusMatcher : public ::testing::MatcherInterface<const Status &> { - public: - StatusMatcher(Enum code, absl::optional<absl::string_view> message) - : code_(code), message_(message) {} - - // From testing::MatcherInterface. - // - // Describes the expected error code. - void DescribeTo(std::ostream *os) const override { - *os << "error code " << StatusCodeToString(code_); - if (message_.has_value()) { - *os << "::'" << message_.value() << "'"; - } - } - - // From testing::MatcherInterface. - // - // Tests whether |status| has an error code that meets this matcher's - // expectation. If an error message string is specified in this matcher, it - // also tests that |status| has an error message that matches that - // expectation. - bool MatchAndExplain( - const Status &status, - ::testing::MatchResultListener *listener) const override { - if (status.code() != code_) { - *listener << "whose error code is " << StatusCodeToString(status.code()); - return false; - } - if (message_.has_value() && status.message() != message_.value()) { - *listener << "whose error message is '" << status.message() << "'"; - return false; - } - return true; - } - - private: - // Expected error code. - const Enum code_; - - // Expected error message (empty if none expected and verified). - const absl::optional<std::string> message_; -}; - -// Implements a gMock matcher that checks whether a status container (e.g. -// iree::Status or iree::StatusOr<T>) has an OK status. -template <class T> -class IsOkMatcherImpl : public ::testing::MatcherInterface<T> { - public: - IsOkMatcherImpl() = default; - - // From testing::MatcherInterface. - // - // Describes the OK expectation. - void DescribeTo(std::ostream *os) const override { *os << "is OK"; } - - // From testing::MatcherInterface. - // - // Describes the negative OK expectation. - void DescribeNegationTo(std::ostream *os) const override { - *os << "is not OK"; - } - - // From testing::MatcherInterface. - // - // Tests whether |status_container|'s OK value meets this matcher's - // expectation. - bool MatchAndExplain( - const T &status_container, - ::testing::MatchResultListener *listener) const override { - if (!status_container.ok()) { - *listener << "which is not OK"; - return false; - } - return true; - } -}; - -// IsOkMatcherGenerator is an intermediate object returned by iree::IsOk(). -// It implements implicit type-cast operators to supported matcher types: -// Matcher<const Status &> and Matcher<const StatusOr<T> &>. These typecast -// operators create gMock matchers that test OK expectations on a status -// container. -class IsOkMatcherGenerator { - public: - // Type-cast operator for Matcher<const iree::Status &>. - operator ::testing::Matcher<const Status &>() const { - return ::testing::MakeMatcher( - new internal::IsOkMatcherImpl<const Status &>()); - } - - // Type-cast operator for Matcher<const iree::StatusOr<T> &>. - template <class T> - operator ::testing::Matcher<const StatusOr<T> &>() const { - return ::testing::MakeMatcher( - new internal::IsOkMatcherImpl<const StatusOr<T> &>()); - } -}; - -} // namespace internal - -// Returns a gMock matcher that expects an iree::StatusOr<T> object to have an -// OK status and for the contained T object to match |value_matcher|. -// -// Example: -// -// StatusOr<string> raven_speech_result = raven.Speak(); -// EXPECT_THAT(raven_speech_result, IsOkAndHolds(HasSubstr("nevermore"))); -// -// If foo is an object of type T and foo_result is an object of type -// StatusOr<T>, you can write: -// -// EXPECT_THAT(foo_result, IsOkAndHolds(foo)); -// -// instead of: -// -// EXPECT_THAT(foo_result, IsOkAndHolds(Eq(foo))); -template <typename ValueMatcherT> -internal::IsOkAndHoldsGenerator<ValueMatcherT> IsOkAndHolds( - ValueMatcherT value_matcher) { - return internal::IsOkAndHoldsGenerator<ValueMatcherT>(value_matcher); -} - -// Returns a gMock matcher that expects an iree::Status object to have the -// given |code|. -template <typename Enum> -::testing::Matcher<const Status &> StatusIs(Enum code) { - return ::testing::MakeMatcher( - new internal::StatusMatcher<Enum>(code, absl::nullopt)); -} - -// Returns a gMock matcher that expects an iree::Status object to have the -// given |code| and |message|. -template <typename Enum> -::testing::Matcher<const Status &> StatusIs(Enum code, - absl::string_view message) { - return ::testing::MakeMatcher( - new internal::StatusMatcher<Enum>(code, message)); -} - -// Returns an internal::IsOkMatcherGenerator, which may be typecast to a -// Matcher<iree::Status> or Matcher<iree::StatusOr<T>>. These gMock -// matchers test that a given status container has an OK status. -inline internal::IsOkMatcherGenerator IsOk() { - return internal::IsOkMatcherGenerator(); -} - -// Macros for testing the results of functions that return iree::Status or -// iree::StatusOr<T> (for any type T). -#define EXPECT_OK(rexpr) EXPECT_THAT(rexpr, ::iree::IsOk()) -#define ASSERT_OK(rexpr) ASSERT_THAT(rexpr, ::iree::IsOk()) - -// Executes an expression that returns an iree::StatusOr<T>, and assigns the -// contained variable to lhs if the error code is OK. -// If the Status is non-OK, generates a test failure and returns from the -// current function, which must have a void return type. -// -// Example: Assigning to an existing value -// ASSERT_OK_AND_ASSIGN(ValueType value, MaybeGetValue(arg)); -// -// The value assignment example might expand into: -// StatusOr<ValueType> status_or_value = MaybeGetValue(arg); -// ASSERT_OK(status_or_value.status()); -// ValueType value = status_or_value.ValueOrDie(); -#define ASSERT_OK_AND_ASSIGN(lhs, rexpr) \ - IREE_ASSERT_OK_AND_ASSIGN_IMPL( \ - IREE_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \ - rexpr); - -#define IREE_ASSERT_OK_AND_ASSIGN_IMPL(statusor, lhs, rexpr) \ - auto statusor = (rexpr); \ - ASSERT_OK(statusor.status()) << statusor.status(); \ - lhs = std::move(statusor.ValueOrDie()) -#define IREE_STATUS_MACROS_CONCAT_NAME(x, y) \ - IREE_STATUS_MACROS_CONCAT_IMPL(x, y) -#define IREE_STATUS_MACROS_CONCAT_IMPL(x, y) x##y - -// Implements the PrintTo() method for iree::StatusOr<T>. This method is -// used by gUnit to print iree::StatusOr<T> objects for debugging. The -// implementation relies on gUnit for printing values of T when a -// iree::StatusOr<T> object is OK and contains a value. -template <typename T> -void PrintTo(const StatusOr<T> &statusor, std::ostream *os) { - if (!statusor.ok()) { - *os << statusor.status(); - } else { - *os << absl::StrCat("OK: ", - ::testing::PrintToString(statusor.ValueOrDie())); - } -} - -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUS_MATCHERS_H_
diff --git a/iree/base/internal/status_win32_errors.cc b/iree/base/internal/status_win32_errors.cc deleted file mode 100644 index 16814d9..0000000 --- a/iree/base/internal/status_win32_errors.cc +++ /dev/null
@@ -1,63 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/status_win32_errors.h" - -#include "absl/strings/str_cat.h" - -#if defined(IREE_PLATFORM_WINDOWS) - -#include <windows.h> - -namespace iree { - -StatusCode Win32ErrorToCanonicalCode(uint32_t error) { - switch (error) { - case ERROR_SUCCESS: - return StatusCode::kOk; - case ERROR_FILE_NOT_FOUND: - case ERROR_PATH_NOT_FOUND: - return StatusCode::kNotFound; - case ERROR_TOO_MANY_OPEN_FILES: - case ERROR_OUTOFMEMORY: - case ERROR_HANDLE_DISK_FULL: - case ERROR_HANDLE_EOF: - return StatusCode::kResourceExhausted; - case ERROR_ACCESS_DENIED: - return StatusCode::kPermissionDenied; - case ERROR_INVALID_HANDLE: - return StatusCode::kInvalidArgument; - case ERROR_NOT_READY: - case ERROR_READ_FAULT: - return StatusCode::kUnavailable; - case ERROR_WRITE_FAULT: - return StatusCode::kDataLoss; - case ERROR_NOT_SUPPORTED: - return StatusCode::kUnimplemented; - default: - return StatusCode::kUnknown; - } -} - -StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error, - SourceLocation location) { - // TODO(benvanik): use FormatMessage; or defer until required? - return StatusBuilder( - Status(Win32ErrorToCanonicalCode(error), absl::StrCat("<TBD>", error)), - location); -} - -} // namespace iree - -#endif // IREE_PLATFORM_WINDOWS
diff --git a/iree/base/internal/status_win32_errors.h b/iree/base/internal/status_win32_errors.h deleted file mode 100644 index 599852d..0000000 --- a/iree/base/internal/status_win32_errors.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_ -#define IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_ - -#include "absl/strings/string_view.h" -#include "iree/base/internal/statusor.h" -#include "iree/base/source_location.h" -#include "iree/base/target_platform.h" - -#if defined(IREE_PLATFORM_WINDOWS) - -namespace iree { - -// Returns the code for |error| which should be a Win32 error dword. -StatusCode Win32ErrorToCanonicalCode(uint32_t error); - -// Returns a StatusBuilder with a status describing the |error| and |location|. -StatusBuilder Win32ErrorToCanonicalStatusBuilder(uint32_t error, - SourceLocation location); - -} // namespace iree - -#endif // IREE_PLATFORM_WINDOWS - -#endif // IREE_BASE_INTERNAL_STATUS_WIN32_ERRORS_H_
diff --git a/iree/base/internal/statusor.cc b/iree/base/internal/statusor.cc deleted file mode 100644 index f7eb153..0000000 --- a/iree/base/internal/statusor.cc +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/internal/statusor.h" - -#include "iree/base/internal/status_errors.h" - -namespace iree { - -namespace internal_statusor { - -void Helper::HandleInvalidStatusCtorArg(Status* status) { - const char* kMessage = - "An OK status is not a valid constructor argument to StatusOr<T>"; - LOG(ERROR) << kMessage; - *status = InternalError(kMessage); - abort(); -} - -void Helper::Crash(const Status& status) { - LOG(FATAL) << "Attempting to fetch value instead of handling error " - << status; - abort(); -} - -} // namespace internal_statusor - -} // namespace iree
diff --git a/iree/base/internal/statusor.h b/iree/base/internal/statusor.h deleted file mode 100644 index 4784d38..0000000 --- a/iree/base/internal/statusor.h +++ /dev/null
@@ -1,699 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_INTERNAL_STATUSOR_H_ -#define IREE_BASE_INTERNAL_STATUSOR_H_ - -#include "absl/base/attributes.h" -#include "iree/base/internal/status.h" -#include "iree/base/internal/status_builder.h" - -namespace iree { - -template <typename T> -class ABSL_MUST_USE_RESULT StatusOr; - -namespace internal_statusor { - -template <typename T, typename U> -using IsStatusOrConversionAmbiguous = - absl::disjunction<std::is_constructible<T, StatusOr<U>&>, - std::is_constructible<T, const StatusOr<U>&>, - std::is_constructible<T, StatusOr<U>&&>, - std::is_constructible<T, const StatusOr<U>&&>, - std::is_convertible<StatusOr<U>&, T>, - std::is_convertible<const StatusOr<U>&, T>, - std::is_convertible<StatusOr<U>&&, T>, - std::is_convertible<const StatusOr<U>&&, T>>; - -template <typename T, typename U> -using IsStatusOrConversionAssigmentAmbiguous = - absl::disjunction<IsStatusOrConversionAmbiguous<T, U>, - std::is_assignable<T&, StatusOr<U>&>, - std::is_assignable<T&, const StatusOr<U>&>, - std::is_assignable<T&, StatusOr<U>&&>, - std::is_assignable<T&, const StatusOr<U>&&>>; - -template <typename T, typename U> -struct IsAmbiguousStatusOrForInitialization - : // Strip const-value refs from type and check again, else false_type. - public absl::conditional_t< - std::is_same<absl::remove_cv_t<absl::remove_reference_t<U>>, - U>::value, - std::false_type, - IsAmbiguousStatusOrForInitialization< - T, absl::remove_cv_t<absl::remove_reference_t<U>>>> {}; - -template <typename T, typename U> -struct IsAmbiguousStatusOrForInitialization<T, StatusOr<U>> - : public IsStatusOrConversionAmbiguous<T, U> {}; - -template <typename T, typename U> -using IsStatusOrDirectInitializationAmbiguous = absl::disjunction< - std::is_same<StatusOr<T>, absl::remove_cv_t<absl::remove_reference_t<U>>>, - std::is_same<Status, absl::remove_cv_t<absl::remove_reference_t<U>>>, - std::is_same<StatusBuilder, absl::remove_cv_t<absl::remove_reference_t<U>>>, - std::is_same<absl::in_place_t, - absl::remove_cv_t<absl::remove_reference_t<U>>>, - IsAmbiguousStatusOrForInitialization<T, U>>; - -template <typename T, typename U> -using IsStatusOrDirectInitializationValid = absl::disjunction< - // The is_same allows nested status ors to ignore this check iff same type. - std::is_same<T, absl::remove_cv_t<absl::remove_reference_t<U>>>, - absl::negation<IsStatusOrDirectInitializationAmbiguous<T, U>>>; - -class Helper { - public: - ABSL_ATTRIBUTE_NORETURN static void HandleInvalidStatusCtorArg(Status*); - ABSL_ATTRIBUTE_NORETURN static void Crash(const Status& status); -}; - -// Construct an instance of T in `p` through placement new, passing Args... to -// the constructor. -// This abstraction is here mostly for the gcc performance fix. -template <typename T, typename... Args> -void PlacementNew(void* p, Args&&... args) { -#if defined(__GNUC__) && !defined(__clang__) - // Teach gcc that 'p' cannot be null, fixing code size issues. - if (p == nullptr) __builtin_unreachable(); -#endif - new (p) T(std::forward<Args>(args)...); -} - -// Helper base class to hold the data and all operations. -// We move all this to a base class to allow mixing with the appropriate -// TraitsBase specialization. -template <typename T> -class StatusOrData { - template <typename U> - friend class StatusOrData; - - public: - StatusOrData() = delete; - - StatusOrData(const StatusOrData& other) { - if (other.ok()) { - MakeValue(other.data_); - MakeStatus(); - } else { - MakeStatus(other.status_); - } - } - - StatusOrData(StatusOrData&& other) noexcept { - if (other.ok()) { - MakeValue(std::move(other.data_)); - MakeStatus(); - } else { - MakeStatus(std::move(other.status_)); - } - } - - template <typename U> - explicit StatusOrData(const StatusOrData<U>& other) { - if (other.ok()) { - MakeValue(other.data_); - MakeStatus(); - } else { - MakeStatus(other.status_); - } - } - - template <typename U> - explicit StatusOrData(StatusOrData<U>&& other) { - if (other.ok()) { - MakeValue(std::move(other.data_)); - MakeStatus(); - } else { - MakeStatus(std::move(other.status_)); - } - } - - template <typename... Args> - explicit StatusOrData(absl::in_place_t, Args&&... args) - : data_(std::forward<Args>(args)...) { - MakeStatus(); - } - - explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } - explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } - - explicit StatusOrData(const Status& status) : status_(status) { - EnsureNotOk(); - } - explicit StatusOrData(Status&& status) : status_(status) { EnsureNotOk(); } - - explicit StatusOrData(const StatusBuilder& builder) : status_(builder) { - EnsureNotOk(); - } - explicit StatusOrData(StatusBuilder&& builder) : status_(std::move(builder)) { - EnsureNotOk(); - } - - StatusOrData& operator=(const StatusOrData& other) { - if (this == &other) return *this; - if (other.ok()) - Assign(other.data_); - else - Assign(other.status_); - return *this; - } - - StatusOrData& operator=(StatusOrData&& other) { - if (this == &other) return *this; - if (other.ok()) - Assign(std::move(other.data_)); - else - Assign(std::move(other.status_)); - return *this; - } - - ~StatusOrData() { - if (ok()) { - status_.~Status(); - data_.~T(); - } else { - status_.~Status(); - } - } - - void Assign(const T& value) { - if (ok()) { - data_.~T(); - MakeValue(value); - } else { - MakeValue(value); - status_ = OkStatus(); - } - } - - void Assign(T&& value) { - if (ok()) { - data_.~T(); - MakeValue(std::move(value)); - } else { - MakeValue(std::move(value)); - status_ = OkStatus(); - } - } - - void Assign(const Status& status) { - Clear(); - status_ = status; - EnsureNotOk(); - } - - void Assign(Status&& status) { - Clear(); - status_ = std::move(status); - EnsureNotOk(); - } - - bool ok() const { return status_.ok(); } - - protected: - // status_ will always be active after the constructor. - // Union to be able to initialize exactly how we need without waste. - // Eg. in the copy constructor we use the default constructor of Status in - // the ok() path to avoid an extra Ref call. - union { - Status status_; - }; - - // data_ is active iff status_.ok()==true - struct Dummy {}; - union { - // When T is const, we need some non-const object we can cast to void* for - // the placement new. dummy_ is that object. - Dummy dummy_; - T data_; - }; - - void Clear() { - if (ok()) data_.~T(); - } - - void EnsureOk() const { - if (!ok()) Helper::Crash(status_); - } - - void EnsureNotOk() { - if (ok()) Helper::HandleInvalidStatusCtorArg(&status_); - } - - // Construct the value (data_) through placement new with the passed arg. - template <typename Arg> - void MakeValue(Arg&& arg) { - internal_statusor::PlacementNew<T>(&dummy_, std::forward<Arg>(arg)); - } - - // Construct the status (status_) through placement new with the passed arg. - template <typename... Args> - void MakeStatus(Args&&... args) { - internal_statusor::PlacementNew<Status>(&status_, - std::forward<Args>(args)...); - } -}; - -// Helper base class to allow implicitly deleted constructors and assignment -// operations in StatusOr. -// TraitsBase will explicitly delete what it can't support and StatusOr will -// inherit that behavior implicitly. -template <bool Copy, bool Move> -struct TraitsBase { - TraitsBase() = default; - TraitsBase(const TraitsBase&) = default; - TraitsBase(TraitsBase&&) = default; - TraitsBase& operator=(const TraitsBase&) = default; - TraitsBase& operator=(TraitsBase&&) = default; -}; - -template <> -struct TraitsBase<false, true> { - TraitsBase() = default; - TraitsBase(const TraitsBase&) = delete; - TraitsBase(TraitsBase&&) = default; - TraitsBase& operator=(const TraitsBase&) = delete; - TraitsBase& operator=(TraitsBase&&) = default; -}; - -template <> -struct TraitsBase<false, false> { - TraitsBase() = default; - TraitsBase(const TraitsBase&) = delete; - TraitsBase(TraitsBase&&) = delete; - TraitsBase& operator=(const TraitsBase&) = delete; - TraitsBase& operator=(TraitsBase&&) = delete; -}; - -} // namespace internal_statusor - -// StatusOr<T> is the union of a Status object and a T object. -// -// A StatusOr object either holds a usable value, or an error Status explaining -// why such a value is not present. -template <typename T> -class StatusOr : private internal_statusor::StatusOrData<T>, - private internal_statusor::TraitsBase< - std::is_copy_constructible<T>::value, - std::is_move_constructible<T>::value> { - template <typename U> - friend class StatusOr; - - typedef internal_statusor::StatusOrData<T> Base; - - public: - typedef T element_type; - - // Constructs a new StatusOr with StatusCode::kUnknown status. - explicit StatusOr(); - - // StatusOr<T> is copy constructible/assignable if T is copy constructible. - StatusOr(const StatusOr&) = default; - StatusOr& operator=(const StatusOr&) = default; - - // StatusOr<T> is move constructible/assignable if T is move constructible. - StatusOr(StatusOr&&) = default; - StatusOr& operator=(StatusOr&&) = default; - - // Converting constructors from StatusOr<U>, when T is constructible from U. - // To avoid ambiguity, they are disabled if T is also constructible from - // StatusOr<U>. Explicit iff the corresponding construction of T from U is - // explicit. - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, const U&>, - std::is_convertible<const U&, T>, - absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< - T, U>>>::value, - int> = 0> - StatusOr(const StatusOr<U>& other) // NOLINT - : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, const U&>, - absl::negation<std::is_convertible<const U&, T>>, - absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< - T, U>>>::value, - int> = 0> - explicit StatusOr(const StatusOr<U>& other) - : Base(static_cast<const typename StatusOr<U>::Base&>(other)) {} - - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, - std::is_convertible<U&&, T>, - absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< - T, U>>>::value, - int> = 0> - StatusOr(StatusOr<U>&& other) // NOLINT - : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, - absl::negation<std::is_convertible<U&&, T>>, - absl::negation<internal_statusor::IsStatusOrConversionAmbiguous< - T, U>>>::value, - int> = 0> - explicit StatusOr(StatusOr<U>&& other) - : Base(static_cast<typename StatusOr<U>::Base&&>(other)) {} - - // Conversion copy/move assignment operator, T must be constructible and - // assignable from U. Only enable if T cannot be directly assigned from - // StatusOr<U>. - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, - std::is_constructible<T, const U&>, - std::is_assignable<T, const U&>, - absl::negation< - internal_statusor::IsStatusOrConversionAssigmentAmbiguous< - T, U>>>::value, - int> = 0> - StatusOr& operator=(const StatusOr<U>& other) { - this->Assign(other); - return *this; - } - template < - typename U, - absl::enable_if_t< - absl::conjunction< - absl::negation<std::is_same<T, U>>, std::is_constructible<T, U&&>, - std::is_assignable<T, U&&>, - absl::negation< - internal_statusor::IsStatusOrConversionAssigmentAmbiguous< - T, U>>>::value, - int> = 0> - StatusOr& operator=(StatusOr<U>&& other) { - this->Assign(std::move(other)); - return *this; - } - - // Constructs a new StatusOr with the given value. After calling this - // constructor, this->ok() will be true and the contained value may be - // retrieved with ValueOrDie(), operator*(), or operator->(). - StatusOr(const T& value); - - // Constructs a new StatusOr with the given non-ok status. After calling this - // constructor, this->ok() will be false and calls to ValueOrDie() will - // CHECK-fail. - StatusOr(const Status& status); - StatusOr& operator=(const Status& status); - StatusOr(const StatusBuilder& builder); - StatusOr& operator=(const StatusBuilder& builder); - - // Similar to the `const T&` overload. - // - // REQUIRES: T is move constructible. - StatusOr(T&& value); - - // RValue versions of the operations declared above. - StatusOr(Status&& status); - StatusOr& operator=(Status&& status); - StatusOr(StatusBuilder&& builder); - StatusOr& operator=(StatusBuilder&& builder); - - // Constructs the inner value T in-place using the provided args, using the - // T(args...) constructor. - template <typename... Args> - explicit StatusOr(absl::in_place_t, Args&&... args); - template <typename U, typename... Args> - explicit StatusOr(absl::in_place_t, std::initializer_list<U> ilist, - Args&&... args); - - // Constructs the inner value T in-place using the provided args, using the - // T(U) (direct-initialization) constructor. Only valid if T can be - // constructed from a U. Can accept move or copy constructors. Explicit it - // U is not convertible to T. To avoid ambiguity, this is disabled if U is - // a StatusOr<J>, where J is convertible to T. - template < - typename U = T, - absl::enable_if_t< - absl::conjunction< - internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>, - std::is_constructible<T, U&&>, - std::is_convertible<U&&, T>>::value, - int> = 0> - StatusOr(U&& u) // NOLINT - : StatusOr(absl::in_place, std::forward<U>(u)) {} - - template < - typename U = T, - absl::enable_if_t< - absl::conjunction< - internal_statusor::IsStatusOrDirectInitializationValid<T, U&&>, - std::is_constructible<T, U&&>, - absl::negation<std::is_convertible<U&&, T>>>::value, - int> = 0> - explicit StatusOr(U&& u) // NOLINT - : StatusOr(absl::in_place, std::forward<U>(u)) {} - - // Returns this->ok() - explicit operator bool() const { return ok(); } - - // Returns this->status().ok() - ABSL_MUST_USE_RESULT bool ok() const { return this->status_.ok(); } - - // Returns a reference to our status. If this contains a T, then - // returns OkStatus(). - const Status& status() const&; - Status status() &&; - - // Returns a reference to our current value, or CHECK-fails if !this->ok(). If - // you have already checked the status using this->ok() or operator bool(), - // then you probably want to use operator*() or operator->() to access the - // current value instead of ValueOrDie(). - const T& ValueOrDie() const&; - T& ValueOrDie() &; - const T&& ValueOrDie() const&&; - T&& ValueOrDie() &&; - - // Returns a reference to the current value. - // - // REQUIRES: this->ok() == true, otherwise the behavior is undefined. - const T& operator*() const&; - T& operator*() &; - const T&& operator*() const&&; - T&& operator*() &&; - - // Returns a pointer to the current value. - // - // REQUIRES: this->ok() == true, otherwise the behavior is undefined. - const T* operator->() const; - T* operator->(); - - // Returns a copy of the current value if this->ok() == true. Otherwise - // returns a default value. - template <typename U> - T value_or(U&& default_value) const&; - template <typename U> - T value_or(U&& default_value) &&; - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; - - private: - using internal_statusor::StatusOrData<T>::Assign; - template <typename U> - void Assign(const StatusOr<U>& other); - template <typename U> - void Assign(StatusOr<U>&& other); -}; - -//////////////////////////////////////////////////////////////////////////////// -// Implementation details for StatusOr<T> - -template <typename T> -StatusOr<T>::StatusOr() : Base(Status(StatusCode::kUnknown, "")) {} - -template <typename T> -StatusOr<T>::StatusOr(const T& value) : Base(value) {} - -template <typename T> -StatusOr<T>::StatusOr(const Status& status) : Base(status) {} - -template <typename T> -StatusOr<T>::StatusOr(const StatusBuilder& builder) : Base(builder) {} - -template <typename T> -StatusOr<T>& StatusOr<T>::operator=(const Status& status) { - this->Assign(status); - return *this; -} - -template <typename T> -StatusOr<T>& StatusOr<T>::operator=(const StatusBuilder& builder) { - return *this = static_cast<Status>(builder); -} - -template <typename T> -StatusOr<T>::StatusOr(T&& value) : Base(std::move(value)) {} - -template <typename T> -StatusOr<T>::StatusOr(Status&& status) : Base(std::move(status)) {} - -template <typename T> -StatusOr<T>::StatusOr(StatusBuilder&& builder) : Base(std::move(builder)) {} - -template <typename T> -StatusOr<T>& StatusOr<T>::operator=(Status&& status) { - this->Assign(std::move(status)); - return *this; -} - -template <typename T> -StatusOr<T>& StatusOr<T>::operator=(StatusBuilder&& builder) { - return *this = static_cast<Status>(std::move(builder)); -} - -template <typename T> -template <typename U> -inline void StatusOr<T>::Assign(const StatusOr<U>& other) { - if (other.ok()) { - this->Assign(other.ValueOrDie()); - } else { - this->Assign(other.status()); - } -} - -template <typename T> -template <typename U> -inline void StatusOr<T>::Assign(StatusOr<U>&& other) { - if (other.ok()) { - this->Assign(std::move(other).ValueOrDie()); - } else { - this->Assign(std::move(other).status()); - } -} -template <typename T> -template <typename... Args> -StatusOr<T>::StatusOr(absl::in_place_t, Args&&... args) - : Base(absl::in_place, std::forward<Args>(args)...) {} - -template <typename T> -template <typename U, typename... Args> -StatusOr<T>::StatusOr(absl::in_place_t, std::initializer_list<U> ilist, - Args&&... args) - : Base(absl::in_place, ilist, std::forward<Args>(args)...) {} - -template <typename T> -const Status& StatusOr<T>::status() const& { - return this->status_; -} -template <typename T> -Status StatusOr<T>::status() && { - return ok() ? OkStatus() : std::move(this->status_); -} - -template <typename T> -const T& StatusOr<T>::ValueOrDie() const& { - this->EnsureOk(); - return this->data_; -} - -template <typename T> -T& StatusOr<T>::ValueOrDie() & { - this->EnsureOk(); - return this->data_; -} - -template <typename T> -const T&& StatusOr<T>::ValueOrDie() const&& { - this->EnsureOk(); - return std::move(this->data_); -} - -template <typename T> -T&& StatusOr<T>::ValueOrDie() && { - this->EnsureOk(); - return std::move(this->data_); -} - -template <typename T> -const T& StatusOr<T>::operator*() const& { - this->EnsureOk(); - return this->data_; -} - -template <typename T> -T& StatusOr<T>::operator*() & { - this->EnsureOk(); - return this->data_; -} - -template <typename T> -const T&& StatusOr<T>::operator*() const&& { - this->EnsureOk(); - return std::move(this->data_); -} - -template <typename T> -T&& StatusOr<T>::operator*() && { - this->EnsureOk(); - return std::move(this->data_); -} - -template <typename T> -const T* StatusOr<T>::operator->() const { - this->EnsureOk(); - return &this->data_; -} - -template <typename T> -T* StatusOr<T>::operator->() { - this->EnsureOk(); - return &this->data_; -} - -template <typename T> -template <typename U> -T StatusOr<T>::value_or(U&& default_value) const& { - if (ok()) { - return this->data_; - } - return std::forward<U>(default_value); -} - -template <typename T> -template <typename U> -T StatusOr<T>::value_or(U&& default_value) && { - if (ok()) { - return std::move(this->data_); - } - return std::forward<U>(default_value); -} - -template <typename T> -void StatusOr<T>::IgnoreError() const { - // no-op -} - -} // namespace iree - -#endif // IREE_BASE_INTERNAL_STATUSOR_H_
diff --git a/iree/base/intrusive_list.h b/iree/base/intrusive_list.h deleted file mode 100644 index ce31c54..0000000 --- a/iree/base/intrusive_list.h +++ /dev/null
@@ -1,758 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Doubly linked list using element interior storage. -// This has the performance of std::list (that means O(1) on insert and remove) -// but performs no allocations and has better caching behavior. -// -// Elements are maintained in lists by way of IntrusiveListLinks, with each link -// allowing the element to exist in one list simultaneously. In the most simple -// case subclassing IntrusiveLinkBase will let the type be added to a list with -// little boilerplate. If an element must be in more than one list -// simultaneously IntrusiveListLinks can be added as members. -// -// Usage (simple): -// class MySimpleElement : public IntrusiveLinkBase {}; -// IntrusiveList<MySimpleElement> list; -// list.push_back(new MySimpleElement()); -// for (auto element : list) { ... } -// -// Usage (multiple lists): -// class MultiElement { -// public: -// IntrusiveListLink list_link_a; -// IntrusiveListLink list_link_b; -// }; -// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_a)> list_a; -// IntrusiveList<MultiElement, offsetof(MultiElement, list_link_b)> list_b; -// -// By default elements in the list are not retained and must be kept alive -// externally. For automatic memory management there are specializations for -// std::unique_ptr. -// -// Usage (unique_ptr): -// IntrusiveList<std::unique_ptr<MyElement>> list; -// list.push_back(absl::make_unique<MyElement>()); -// std::unique_ptr<MyElement> elm = list.take(list.front()); -// -// This type is thread-unsafe. - -#ifndef IREE_BASE_INTRUSIVE_LIST_H_ -#define IREE_BASE_INTRUSIVE_LIST_H_ - -#include <cstddef> -#include <cstdint> -#include <functional> -#include <iterator> -#include <limits> - -#include "iree/base/logging.h" - -namespace iree { - -// Define to enable extensive checks after each mutation of the intrusive list. -// #define IREE_PARANOID_INTRUSIVE_LIST - -// Storage for the doubly-linked list. -// This is embedded within all elements in an intrusive list. -struct IntrusiveListLink { - IntrusiveListLink* prev = nullptr; - IntrusiveListLink* next = nullptr; - - IntrusiveListLink() = default; - - // Prevent copies. - IntrusiveListLink(const IntrusiveListLink&) = delete; - IntrusiveListLink& operator=(const IntrusiveListLink&) = delete; -}; - -template <class T> -struct IntrusiveLinkBase : public T { - public: - IntrusiveListLink link; -}; - -template <> -struct IntrusiveLinkBase<void> { - public: - IntrusiveListLink link; -}; - -// Base type for intrusive lists. -// This is either used directly when the list is on naked pointers or -// specialized to std::unique_ptr. -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -class IntrusiveListBase { - public: - using self_type = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; - - IntrusiveListBase() = default; - virtual ~IntrusiveListBase() { clear(); } - - // Prevent copies. - IntrusiveListBase(const IntrusiveListBase&) = delete; - IntrusiveListBase& operator=(const IntrusiveListBase&) = delete; - - // Returns true if the list is empty. - // Performance: O(1) - constexpr bool empty() const { return head_ == nullptr; } - - // Returns the total number of items in the list. - // Performance: O(1) - constexpr size_t size() const { return count_; } - - // Returns true if the given item is contained within the list. - // Performance: O(n) - bool contains(T* value) const; - - // Appends the contents of the given list to this one. - // The |other_list| is cleared. - // Performance: O(1) - void merge_from(self_type* other_list); - - // Removes all items from the list. - // Performance: O(n) - void clear(); - - IteratorT begin() const { return IteratorT(head_); } - IteratorT end() const { return IteratorT(nullptr); } - ReverseIteratorT rbegin() const { return ReverseIteratorT(tail_); } - ReverseIteratorT rend() const { return ReverseIteratorT(nullptr); } - - // Returns the next item in the list relative to the given item. - // |value| must exist in the list. - // Performance: O(1) - T* next(T* value) const; - - // Returns the previous item in the list relative to the given item. - // |value| must exist in the list. - // Performance: O(1) - T* previous(T* value) const; - - // Returns the item at the front of the list, if any. - // Performance: O(1) - T* front() const; - - // Inserts an item at the front of the list. - // Performance: O(1) - void push_front(T* value); - - // Removes the item at the front of the list. - // Performance: O(1) - void pop_front(); - - // Returns the item at the back of the list, if any. - // Performance: O(1) - T* back() const; - - // Inserts an item at the back of the list. - // Performance: O(1) - void push_back(T* value); - - // Removes the item at the back of the list. - // Performance: O(1) - void pop_back(); - - // Inserts an item into the list before the given iterator. - // Performance: O(1) - void insert(const IteratorT& it, T* value) { return insert(*it, value); } - void insert(T* position, T* value); - - // Erases the given item from the list. - // Returns the item following the erased item, if any. - // Performance: O(1) - T* erase(T* value); - - // Erases the item from the list at the given iterator. - // Performance: O(1) - IteratorT erase(const IteratorT& it); - ReverseIteratorT erase(const ReverseIteratorT& it); - - // Replaces the item with a new item at the same position. - // |new_value| must not be contained in any list. - // Performance: O(1) - void replace(T* old_value, T* new_value); - - // Sorts the list with the given comparison function. - // The sort function is the same as used by std::sort. - // - // Uses merge sort O(N log N) using the algorithm described here: - // http://www.chiark.greenend.org.uk/~sgtatham/algorithms/listsort.html - void sort(bool (*compare_fn)(T* a, T* b)); - - protected: - // Called when an item is added to the list. - virtual void OnAdd(T* value) {} - // Called when an item is removed from the list. - virtual void OnRemove(T* value) {} - // Called when an item is removed and deallocated. - virtual void OnDeallocate(T* value) {} - - // Performs expensive correctness checks on the list structure. It's too slow - // to use in normal builds (even dbg), so it should only be used when there's - // a suspected issue with an intrusive list. Define - // IREE_PARANOID_INTRUSIVE_LIST to enable. - void CheckCorrectness() const; - - IntrusiveListLink* head_ = nullptr; - IntrusiveListLink* tail_ = nullptr; - size_t count_ = 0; -}; - -// Basic iterator for an IntrusiveList. -template <typename T, size_t kOffset, bool kForward> -class IntrusiveListIterator - : public std::iterator<std::input_iterator_tag, int> { - public: - using self_type = IntrusiveListIterator<T, kOffset, kForward>; - - explicit IntrusiveListIterator(IntrusiveListLink* current) - : current_(current) {} - IntrusiveListIterator& operator++(); - self_type operator++(int); - self_type& operator--(); - self_type operator--(int); - bool operator==(const self_type& rhs) const; - bool operator!=(const self_type& rhs) const; - T* operator*() const; - - protected: - IntrusiveListLink* current_; -}; - -// Specialized IntrusiveListBase used for unreferenced naked pointers. -// This very thinly wraps the base type and does no special memory management. -template <typename T, size_t kOffset> -class IntrusiveListUnrefBase - : public IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>, - IntrusiveListIterator<T, kOffset, false>, - kOffset> { - public: - using IteratorT = IntrusiveListIterator<T, kOffset, true>; - using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>; - using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; - - using base_list::clear; - - // Removes all items from the list and calls the given deleter function for - // each of them. The built-in OnDeallocate will not be used. - // Performance: O(n) - void clear(const std::function<void(T*)>& deleter); - - private: - using base_list::count_; - using base_list::head_; - using base_list::tail_; -}; - -constexpr size_t kUseDefaultLinkOffset = std::numeric_limits<size_t>::max(); - -// IntrusiveList for raw pointers with a specified offset. -// Use this if there are multiple links within a type. -// -// Usage: -// struct MyType { -// IntrusiveListLink link_a; -// IntrusiveListLink link_b; -// }; -// IntrusiveList<MyType, offsetof(MyType, link_a)> list_a; -// IntrusiveList<MyType, offsetof(MyType, link_b)> list_b; -template <typename T, size_t kOffset = kUseDefaultLinkOffset> -class IntrusiveList : public IntrusiveListUnrefBase<T, kOffset> {}; - -// IntrusiveList for raw pointers. -// Items added to the list will not be owned by the list and must be freed by -// the caller. -// -// Usage: -// struct MyType : public IntrusiveListBase<void> {}; -// IntrusiveList<MyType> list; -// auto* p = new MyType(); -// list.push_back(p); // p is not retained and won't be freed! -// delete p; -template <typename T> -class IntrusiveList<T, kUseDefaultLinkOffset> - : public IntrusiveListUnrefBase<T, offsetof(T, link)> {}; - -// -- implementation -- - -namespace impl { - -// Maps an IntrusiveListLink to its containing type T. -template <typename T, size_t kOffset> -static inline T* LinkToT(IntrusiveListLink* link) { - if (link) { - return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(link) - kOffset); - } else { - return nullptr; - } -} - -// Maps a containing type T to its IntrusiveListLink. -template <typename T, size_t kOffset> -static inline IntrusiveListLink* TToLink(T* value) { - if (value) { - return reinterpret_cast<IntrusiveListLink*>( - reinterpret_cast<uintptr_t>(value) + kOffset); - } else { - return nullptr; - } -} - -} // namespace impl - -template <typename T, size_t kOffset, bool kForward> -IntrusiveListIterator<T, kOffset, kForward>& -IntrusiveListIterator<T, kOffset, kForward>::operator++() { - if (current_) { - current_ = kForward ? current_->next : current_->prev; - } - return *this; -} - -template <typename T, size_t kOffset, bool kForward> -IntrusiveListIterator<T, kOffset, kForward> -IntrusiveListIterator<T, kOffset, kForward>::operator++(int) { - self_type tmp(current_); - operator++(); - return tmp; -} - -template <typename T, size_t kOffset, bool kForward> -IntrusiveListIterator<T, kOffset, kForward>& -IntrusiveListIterator<T, kOffset, kForward>::operator--() { - if (current_) { - current_ = kForward ? current_->prev : current_->next; - } - return *this; -} - -template <typename T, size_t kOffset, bool kForward> -IntrusiveListIterator<T, kOffset, kForward> -IntrusiveListIterator<T, kOffset, kForward>::operator--(int) { - self_type tmp(current_); - operator--(); - return tmp; -} - -template <typename T, size_t kOffset, bool kForward> -bool IntrusiveListIterator<T, kOffset, kForward>::operator==( - const self_type& rhs) const { - return rhs.current_ == current_; -} - -template <typename T, size_t kOffset, bool kForward> -bool IntrusiveListIterator<T, kOffset, kForward>::operator!=( - const self_type& rhs) const { - return !operator==(rhs); -} - -template <typename T, size_t kOffset, bool kForward> -T* IntrusiveListIterator<T, kOffset, kForward>::operator*() const { - return impl::LinkToT<T, kOffset>(current_); -} - -template <typename T, size_t kOffset> -void IntrusiveListUnrefBase<T, kOffset>::clear( - const std::function<void(T*)>& deleter) { - auto* link = head_; - while (link) { - auto* next = link->next; - link->prev = link->next = nullptr; - deleter(impl::LinkToT<T, kOffset>(link)); - link = next; - } - head_ = tail_ = nullptr; - count_ = 0; -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, - kOffset>::CheckCorrectness() const { -#if defined(IREE_PARANOID_INTRUSIVE_LIST) - auto* link = head_; - IntrusiveListLink* previous = nullptr; - size_t actual_count = 0; - while (link) { - ++actual_count; - if (!link->prev) { - DCHECK_EQ(link, head_); - } - if (!link->next) { - DCHECK_EQ(link, tail_); - } - DCHECK_EQ(link->prev, previous); - previous = link; - link = link->next; - } - DCHECK_EQ(actual_count, count_); -#endif // IREE_PARANOID_INTRUSIVE_LIST -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -bool IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::contains( - T* value) const { - if (!value) return false; - // TODO(benvanik): faster way of checking? requires list ptr in link? - auto* needle = impl::TToLink<T, kOffset>(value); - auto* link = head_; - while (link) { - if (link == needle) { - return true; - } - link = link->next; - } - return false; -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::merge_from( - self_type* other_list) { - if (tail_) { - tail_->next = other_list->head_; - } - if (other_list->head_) { - other_list->head_->prev = tail_; - } - if (!head_) { - head_ = other_list->head_; - } - tail_ = other_list->tail_; - - other_list->head_ = nullptr; - other_list->tail_ = nullptr; - - count_ += other_list->count_; - other_list->count_ = 0; -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::clear() { - auto* link = head_; - while (link) { - auto* next = link->next; - link->prev = link->next = nullptr; - OnDeallocate(impl::LinkToT<T, kOffset>(link)); - link = next; - } - head_ = tail_ = nullptr; - count_ = 0; -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::next( - T* value) const { - if (!value) { - return nullptr; - } - auto* link = impl::TToLink<T, kOffset>(value); - return impl::LinkToT<T, kOffset>(link->next); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::previous( - T* value) const { - if (!value) { - return nullptr; - } - auto* link = impl::TToLink<T, kOffset>(value); - return impl::LinkToT<T, kOffset>(link->prev); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::front() - const { - return impl::LinkToT<T, kOffset>(head_); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_front( - T* value) { - DCHECK(value); - auto* link = impl::TToLink<T, kOffset>(value); - DCHECK(!link->next); - DCHECK(!link->prev); - link->next = head_; - link->prev = nullptr; - head_ = link; - if (link->next) { - link->next->prev = link; - } - if (!tail_) { - tail_ = link; - } - ++count_; - OnAdd(value); - CheckCorrectness(); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_front() { - DCHECK(head_); - auto* link = head_; - if (link) { - head_ = head_->next; - link->next = link->prev = nullptr; - if (head_) { - head_->prev = nullptr; - } - if (link == tail_) { - tail_ = nullptr; - } - --count_; - OnDeallocate(impl::LinkToT<T, kOffset>(link)); - } - CheckCorrectness(); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -inline T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::back() - const { - return impl::LinkToT<T, kOffset>(tail_); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::push_back( - T* value) { - DCHECK(value); - auto* link = impl::TToLink<T, kOffset>(value); - DCHECK(!link->next); - DCHECK(!link->prev); - link->prev = tail_; - link->next = nullptr; - tail_ = link; - if (link->prev) { - link->prev->next = link; - } - if (!head_) { - head_ = link; - } - ++count_; - OnAdd(value); - CheckCorrectness(); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::pop_back() { - DCHECK(tail_); - auto* link = tail_; - if (link) { - tail_ = tail_->prev; - link->next = link->prev = nullptr; - if (tail_) { - tail_->next = nullptr; - } - if (link == head_) { - head_ = nullptr; - } - --count_; - OnDeallocate(impl::LinkToT<T, kOffset>(link)); - } - CheckCorrectness(); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::insert( - T* position, T* value) { - DCHECK(value); - auto* link = impl::TToLink<T, kOffset>(value); - auto* position_link = impl::TToLink<T, kOffset>(position); - DCHECK(!link->next); - DCHECK(!link->prev); - - if (position_link == head_) { - push_front(value); - } else if (position_link == nullptr) { - push_back(value); - } else { - link->next = position_link; - link->prev = position_link->prev; - position_link->prev->next = link; - position_link->prev = link; - ++count_; - OnAdd(value); - } - CheckCorrectness(); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -T* IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase(T* value) { - if (!value) { - return nullptr; - } - auto* link = impl::TToLink<T, kOffset>(value); - if (link->prev) { - DCHECK_NE(link, head_); - link->prev->next = link->next; - } else { - DCHECK_EQ(link, head_); - head_ = link->next; - } - if (link->next) { - DCHECK_NE(link, tail_); - link->next->prev = link->prev; - } else { - DCHECK_EQ(link, tail_); - tail_ = link->prev; - } - auto* next = link->next; - link->next = link->prev = nullptr; - --count_; - OnDeallocate(value); - CheckCorrectness(); - return impl::LinkToT<T, kOffset>(next); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -IteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::erase( - const IteratorT& it) { - return IteratorT(impl::TToLink<T, kOffset>(erase(*it))); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -ReverseIteratorT IntrusiveListBase<T, IteratorT, ReverseIteratorT, - kOffset>::erase(const ReverseIteratorT& it) { - return ReverseIteratorT(impl::TToLink<T, kOffset>(erase(*it))); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::replace( - T* old_value, T* new_value) { - DCHECK(old_value); - DCHECK(new_value); - DCHECK_NE(old_value, new_value); - auto* old_link = impl::TToLink<T, kOffset>(old_value); - auto* new_link = impl::TToLink<T, kOffset>(new_value); - new_link->next = old_link->next; - new_link->prev = old_link->prev; - if (new_link->prev) { - new_link->prev->next = new_link; - } else { - head_ = new_link; - } - if (new_link->next) { - new_link->next->prev = new_link; - } else { - tail_ = new_link; - } - old_link->next = old_link->prev = nullptr; - OnAdd(new_value); - OnDeallocate(old_value); - CheckCorrectness(); -} - -template <typename T, typename IteratorT, typename ReverseIteratorT, - size_t kOffset> -void IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>::sort( - bool (*compare_fn)(T* a, T* b)) { - if (empty()) { - // Empty list no-op. - return; - } - // Repeatedly run until the list is sorted. - int in_size = 1; - while (true) { - IntrusiveListLink* p = head_; - IntrusiveListLink* q = nullptr; - IntrusiveListLink* e = nullptr; - IntrusiveListLink* tail = nullptr; - head_ = nullptr; - tail_ = nullptr; - // Repeatedly merge sublists. - int merge_count = 0; - do { - ++merge_count; - q = p; - // Determine the size of the first part and find the second. - int p_size = 0; - for (int i = 0; i < in_size; ++i) { - ++p_size; - q = q->next; - if (!q) { - break; - } - } - // Merge the two lists (if we have two). - int q_size = in_size; - while (p_size > 0 || (q_size > 0 && q)) { - if (p_size == 0) { - // p is empty; e must come from q. - e = q; - q = q->next; - --q_size; - } else if (q_size == 0 || !q) { - // q is empty; e must come from p. - e = p; - p = p->next; - --p_size; - } else if (compare_fn(impl::LinkToT<T, kOffset>(p), - impl::LinkToT<T, kOffset>(q))) { - // p <= q; e must come from p. - e = p; - p = p->next; - --p_size; - } else { - // q < p; e must come from q. - e = q; - q = q->next; - --q_size; - } - // Append e to the merged list. - if (tail) { - tail->next = e; - } else { - head_ = e; - } - e->prev = tail; - tail = e; - } - p = q; - } while (p); - tail->next = nullptr; - if (merge_count <= 1) { - // List is now sorted; stash and return. - tail_ = tail; - CheckCorrectness(); - return; - } - // Run merge again with larger lists. - in_size *= 2; - } -} - -} // namespace iree - -// Specializations: -#include "iree/base/intrusive_list_ref_ptr.inc" -#include "iree/base/intrusive_list_unique_ptr.inc" - -#endif // IREE_BASE_INTRUSIVE_LIST_H_
diff --git a/iree/base/intrusive_list_ref_ptr.inc b/iree/base/intrusive_list_ref_ptr.inc deleted file mode 100644 index cdc7ad5..0000000 --- a/iree/base/intrusive_list_ref_ptr.inc +++ /dev/null
@@ -1,174 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "iree/base/intrusive_list.h" - -#ifndef IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_ -#define IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_ - -#include <cstddef> -#include <iterator> - -#include "iree/base/intrusive_list.h" -#include "iree/base/ref_ptr.h" - -namespace iree { - -// Iterator for an IntrusiveList specialized to ref_ptr. -template <typename T, size_t kOffset, bool kForward> -class IntrusiveListRefPtrIterator - : public std::iterator<std::input_iterator_tag, int> { - public: - using self_type = IntrusiveListRefPtrIterator<T, kOffset, kForward>; - - explicit IntrusiveListRefPtrIterator(IntrusiveListLink* current) - : current_(current) {} - self_type& operator++() { - if (current_) { - current_ = kForward ? current_->next : current_->prev; - } - return *this; - } - self_type operator++(int) { - self_type tmp(current_); - operator++(); - return tmp; - } - self_type& operator--() { - if (current_) { - current_ = kForward ? current_->prev : current_->next; - } - return *this; - } - self_type operator--(int) { - self_type tmp(current_); - operator--(); - return tmp; - } - bool operator==(const self_type& rhs) const { - return rhs.current_ == current_; - } - bool operator!=(const self_type& rhs) const { return !operator==(rhs); } - ref_ptr<T> operator*() const { - return add_ref(impl::LinkToT<T, kOffset>(current_)); - } - - protected: - IntrusiveListLink* current_; -}; - -// Specialized IntrusiveListBase for ref_ptr types. -// This makes the list methods accept/return ref_ptrs and iterate with -// a ref_ptr iterator. -template <typename T, size_t kOffset> -class IntrusiveListRefPtrBase - : private IntrusiveListBase< - T, IntrusiveListRefPtrIterator<T, kOffset, true>, - IntrusiveListRefPtrIterator<T, kOffset, false>, kOffset> { - public: - using IteratorT = IntrusiveListRefPtrIterator<T, kOffset, true>; - using ReverseIteratorT = IntrusiveListRefPtrIterator<T, kOffset, false>; - using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; - - IntrusiveListRefPtrBase() = default; - - using base_list::empty; - using base_list::size; - - using base_list::contains; - bool contains(const ref_ptr<T>& value) const { - return base_list::contains(value.get()); - } - - using base_list::clear; - - using base_list::begin; - using base_list::end; - using base_list::rbegin; - using base_list::rend; - - inline ref_ptr<T> next(const ref_ptr<T>& value) const { - return add_ref(base_list::next(value.get())); - } - inline ref_ptr<T> next(T* value) const { - return add_ref(base_list::next(value)); - } - - inline ref_ptr<T> previous(const ref_ptr<T>& value) const { - return add_ref(base_list::previous(value.get())); - } - inline ref_ptr<T> previous(T* value) const { - return add_ref(base_list::previous(value)); - } - - // Performance: O(1) - inline ref_ptr<T> front() const { - return add_ref(impl::LinkToT<T, kOffset>(head_)); - } - - void push_front(const ref_ptr<T>& value) { - base_list::push_front(value.get()); - } - - using base_list::pop_front; - - // Performance: O(1) - inline ref_ptr<T> back() const { - return add_ref(impl::LinkToT<T, kOffset>(tail_)); - } - - void push_back(const ref_ptr<T>& value) { base_list::push_back(value.get()); } - - using base_list::pop_back; - - void insert(const IteratorT& it, const ref_ptr<T>& value) { - base_list::insert(it, value.get()); - } - - using base_list::erase; - - ref_ptr<T> erase(const ref_ptr<T>& value) { - return add_ref(base_list::erase(value.get())); - } - - void replace(const ref_ptr<T>& old_value, const ref_ptr<T>& new_value) { - base_list::replace(old_value.get(), new_value.get()); - } - void replace(T* old_value, const ref_ptr<T>& new_value) { - base_list::replace(old_value, new_value.get()); - } - - using base_list::sort; - - private: - void OnAdd(T* value) override { value->AddReference(); } - void OnRemove(T* value) override { value->ReleaseReference(); } - void OnDeallocate(T* value) override { value->ReleaseReference(); } - - using base_list::count_; - using base_list::head_; - using base_list::tail_; -}; - -template <typename U, size_t kOffset> -class IntrusiveList<ref_ptr<U>, kOffset> - : public IntrusiveListRefPtrBase<U, kOffset> {}; - -template <typename U> -class IntrusiveList<ref_ptr<U>, kUseDefaultLinkOffset> - : public IntrusiveListRefPtrBase<U, offsetof(U, link)> {}; - -} // namespace iree - -#endif // IREE_BASE_INTRUSIVE_LIST_REF_PTR_H_
diff --git a/iree/base/intrusive_list_ref_ptr_test.cc b/iree/base/intrusive_list_ref_ptr_test.cc deleted file mode 100644 index b94da7d..0000000 --- a/iree/base/intrusive_list_ref_ptr_test.cc +++ /dev/null
@@ -1,100 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "gtest/gtest.h" -#include "iree/base/intrusive_list.h" - -namespace iree { -namespace { - -static int alloc_count = 0; -struct RefCountedType : public RefObject<RefCountedType> { - IntrusiveListLink link; - RefCountedType() { ++alloc_count; } - ~RefCountedType() { --alloc_count; } - static void Deallocate(RefCountedType* value) { delete value; } - using RefObject<RefCountedType>::counter_; -}; - -TEST(IntrusiveListRefPtrTest, PushAndClear) { - alloc_count = 0; - IntrusiveList<ref_ptr<RefCountedType>> list; - EXPECT_EQ(0, alloc_count); - list.push_back(make_ref<RefCountedType>()); - EXPECT_EQ(1, alloc_count); - EXPECT_NE(nullptr, list.front()); - EXPECT_EQ(2, list.front()->counter_); - list.clear(); - EXPECT_EQ(0, alloc_count); -} - -TEST(IntrusiveListRefPtrTest, PushPop) { - alloc_count = 0; - IntrusiveList<ref_ptr<RefCountedType>> list; - list.push_back(make_ref<RefCountedType>()); - EXPECT_EQ(1, alloc_count); - list.push_back(make_ref<RefCountedType>()); - EXPECT_EQ(2, alloc_count); - EXPECT_NE(list.front(), list.back()); - list.pop_back(); - EXPECT_EQ(1, alloc_count); - list.pop_front(); - EXPECT_EQ(0, alloc_count); -} - -TEST(IntrusiveListRefPtrTest, PushErase) { - alloc_count = 0; - IntrusiveList<ref_ptr<RefCountedType>> list; - list.push_back(make_ref<RefCountedType>()); - EXPECT_EQ(1, alloc_count); - EXPECT_NE(nullptr, list.front()); - EXPECT_EQ(2, list.front()->counter_); - auto item = list.front(); - EXPECT_NE(nullptr, item.get()); - EXPECT_EQ(3, list.front()->counter_); - EXPECT_EQ(1, alloc_count); - list.erase(item); - EXPECT_EQ(1, alloc_count); - item.reset(); - EXPECT_EQ(0, alloc_count); -} - -TEST(IntrusiveListRefPtrTest, PushReplace) { - alloc_count = 0; - IntrusiveList<ref_ptr<RefCountedType>> list; - list.push_back(make_ref<RefCountedType>()); - EXPECT_EQ(1, alloc_count); - list.replace(list.front(), make_ref<RefCountedType>()); - EXPECT_EQ(1, alloc_count); - list.clear(); - EXPECT_EQ(0, alloc_count); -} - -TEST(IntrusiveListRefPtrTest, Iteration) { - alloc_count = 0; - IntrusiveList<ref_ptr<RefCountedType>> list; - list.push_back(make_ref<RefCountedType>()); - list.push_back(make_ref<RefCountedType>()); - list.push_back(make_ref<RefCountedType>()); - EXPECT_EQ(3, alloc_count); - for (auto item : list) { - const ref_ptr<RefCountedType>& item_ref = item; - EXPECT_NE(nullptr, item_ref.get()); - } - list.clear(); - EXPECT_EQ(0, alloc_count); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/intrusive_list_test.cc b/iree/base/intrusive_list_test.cc deleted file mode 100644 index 37216ce..0000000 --- a/iree/base/intrusive_list_test.cc +++ /dev/null
@@ -1,523 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/intrusive_list.h" - -#include <algorithm> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace iree { -namespace { - -using ::testing::ElementsAre; - -struct Item { - size_t some_data_0; - IntrusiveListLink list_a; - size_t some_data_1; - IntrusiveListLink list_b; - size_t some_data_2; - int value; - - static const size_t kToken = 0xDEADBEEF; - explicit Item(int value) - : some_data_0(kToken), - some_data_1(kToken), - some_data_2(kToken), - value(value) {} - bool is_valid() { - return some_data_0 == kToken && some_data_1 == kToken && - some_data_2 == kToken; - } -}; - -template <typename T, size_t V> -std::vector<T*> ExtractItems(const IntrusiveList<T, V>& list) { - std::vector<T*> items; - for (auto* item : list) { - items.push_back(item); - } - return items; -} - -template <typename T, size_t V> -std::vector<int> ExtractValues(const IntrusiveList<T, V>& list) { - std::vector<int> values; - for (auto* item : list) { - values.push_back(item->value); - } - return values; -} - -template <typename T, size_t V> -std::vector<int> ExtractValuesMutable(const IntrusiveList<T, V>& list) { - std::vector<int> values; - for (auto* item : list) { - values.push_back(item->value); - } - return values; -} - -TEST(IntrusiveListTest, PushPopItems) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - EXPECT_TRUE(items.empty()); - EXPECT_EQ(items.size(), 0u); - EXPECT_EQ(items.front(), nullptr); - EXPECT_EQ(items.back(), nullptr); - EXPECT_TRUE(items.begin() == items.end()); - items.push_front(&item1); - EXPECT_FALSE(items.empty()); - EXPECT_EQ(items.size(), 1u); - EXPECT_EQ(items.front(), &item1); - EXPECT_EQ(items.back(), &item1); - EXPECT_FALSE(items.begin() == items.end()); - items.push_front(&item2); - EXPECT_EQ(items.size(), 2u); - EXPECT_EQ(items.front(), &item2); - EXPECT_EQ(items.back(), &item1); - items.push_front(&item3); - EXPECT_EQ(items.size(), 3u); - EXPECT_EQ(items.front(), &item3); - EXPECT_EQ(items.back(), &item1); - EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1)); - - items.push_back(&item4); - EXPECT_EQ(items.size(), 4u); - EXPECT_EQ(items.front(), &item3); - EXPECT_EQ(items.back(), &item4); - EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2, 1, 4)); - - items.pop_front(); - EXPECT_EQ(items.size(), 3u); - EXPECT_EQ(items.front(), &item2); - EXPECT_EQ(items.back(), &item4); - EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1, 4)); - - items.pop_back(); - EXPECT_EQ(items.size(), 2u); - EXPECT_EQ(items.front(), &item2); - EXPECT_EQ(items.back(), &item1); - EXPECT_THAT(ExtractValues(items), ElementsAre(2, 1)); - - items.pop_back(); - items.pop_front(); - EXPECT_TRUE(items.empty()); - EXPECT_EQ(items.size(), 0u); - EXPECT_EQ(items.front(), nullptr); - EXPECT_EQ(items.back(), nullptr); - EXPECT_TRUE(items.begin() == items.end()); - - EXPECT_TRUE(item1.is_valid()); - EXPECT_TRUE(item2.is_valid()); - EXPECT_TRUE(item3.is_valid()); - EXPECT_TRUE(item4.is_valid()); -} - -TEST(IntrusiveListTest, Contains) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - items.push_back(&item1); - items.push_back(&item2); - items.push_back(&item3); - // item4 omitted. - - EXPECT_TRUE(items.contains(&item1)); - EXPECT_TRUE(items.contains(&item2)); - EXPECT_TRUE(items.contains(&item3)); - EXPECT_FALSE(items.contains(&item4)); - - EXPECT_FALSE(items.contains(nullptr)); -} - -TEST(IntrusiveListTest, MergeFrom) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items0; - items0.push_back(&item1); - items0.push_back(&item2); - items0.push_back(&item3); - - IntrusiveList<Item, offsetof(Item, list_a)> items1; - items1.push_back(&item4); - - items0.merge_from(&items1); - EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4)); - EXPECT_TRUE(items1.empty()); -} - -TEST(IntrusiveListTest, MergeFromEmpty) { - IntrusiveList<Item, offsetof(Item, list_a)> items0; - IntrusiveList<Item, offsetof(Item, list_a)> items1; - items0.merge_from(&items1); -} - -TEST(IntrusiveListTest, MergeFromAll) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - IntrusiveList<Item, offsetof(Item, list_a)> items0; - items0.push_back(&item1); - items0.push_back(&item2); - items0.push_back(&item3); - items0.push_back(&item4); - IntrusiveList<Item, offsetof(Item, list_a)> items1; - - // Merge all items from items1 into items0. Shouldn't change anything. - items0.merge_from(&items1); - EXPECT_THAT(ExtractValues(items0), ElementsAre(1, 2, 3, 4)); - EXPECT_TRUE(items1.empty()); - - // Merge all items from items0 into items1. Should move everything. - items1.merge_from(&items0); - EXPECT_TRUE(items0.empty()); - EXPECT_THAT(ExtractValues(items1), ElementsAre(1, 2, 3, 4)); -} - -TEST(IntrusiveListTest, Erase) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - items.push_back(&item1); - items.push_back(&item2); - items.push_back(&item3); - items.push_back(&item4); - - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); - items.erase(&item3); - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 4)); - items.erase(&item1); - EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4)); - items.erase(&item4); - EXPECT_THAT(ExtractValues(items), ElementsAre(2)); - items.erase(&item2); - EXPECT_TRUE(items.empty()); - - items.push_back(&item1); - items.push_back(&item2); - items.push_back(&item3); - items.push_back(&item4); - - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); - auto it = items.begin(); - items.erase(it); - EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4)); - it = items.end(); - items.erase(it); - EXPECT_THAT(ExtractValues(items), ElementsAre(2, 3, 4)); - it = items.begin(); - ++it; - items.erase(it); - EXPECT_THAT(ExtractValues(items), ElementsAre(2, 4)); - - it = items.begin(); - it = items.erase(it); - EXPECT_EQ(4, (*it)->value); - EXPECT_THAT(ExtractValues(items), ElementsAre(4)); - it = items.erase(it); - EXPECT_TRUE(items.empty()); - EXPECT_EQ(items.end(), it); -} - -TEST(IntrusiveListTest, MultipleLists) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items_a; - IntrusiveList<Item, offsetof(Item, list_b)> items_b; - items_a.push_back(&item1); - items_a.push_back(&item2); - items_a.push_back(&item3); - items_a.push_back(&item4); - items_b.push_front(&item1); - items_b.push_front(&item2); - items_b.push_front(&item3); - items_b.push_front(&item4); - EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4)); - EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 3, 2, 1)); - items_b.erase(&item3); - EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3, 4)); - EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1)); - items_a.pop_back(); - EXPECT_THAT(ExtractValues(items_a), ElementsAre(1, 2, 3)); - EXPECT_THAT(ExtractValues(items_b), ElementsAre(4, 2, 1)); -} - -TEST(IntrusiveListTest, MutableIterator) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - items.push_back(&item4); - items.push_front(&item1); - items.push_front(&item2); - items.push_front(&item3); - - EXPECT_THAT(ExtractValuesMutable(items), ElementsAre(3, 2, 1, 4)); -} - -struct BaseType { - explicit BaseType(int value) : value(value) {} - int value; - IntrusiveListLink base_link; -}; -struct SubType : public BaseType { - explicit SubType(int value) : BaseType(value) {} - IntrusiveListLink sub_link; -}; -TEST(IntrusiveListTest, SimpleType) { - SubType item1(1); - SubType item2(2); - SubType item3(3); - SubType item4(4); - - IntrusiveList<BaseType, offsetof(BaseType, base_link)> items_a; - items_a.push_front(&item1); - items_a.push_front(&item2); - items_a.push_front(&item3); - items_a.push_front(&item4); - EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1)); - - IntrusiveList<SubType, offsetof(SubType, sub_link)> items_b; - items_b.push_back(&item1); - items_b.push_back(&item2); - items_b.push_back(&item3); - items_b.push_back(&item4); - EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4)); -} - -struct AbstractType { - explicit AbstractType(int value) : value(value) {} - virtual ~AbstractType() = default; - virtual int DoSomething() = 0; - int value; - IntrusiveListLink base_link; -}; -struct ImplType : public AbstractType { - explicit ImplType(int value) : AbstractType(value) {} - int DoSomething() override { return value; } - IntrusiveListLink sub_link; -}; - -TEST(IntrusiveListTest, ComplexType) { - ImplType item1(1); - ImplType item2(2); - ImplType item3(3); - ImplType item4(4); - - IntrusiveList<AbstractType, offsetof(AbstractType, base_link)> items_a; - items_a.push_front(&item1); - items_a.push_front(&item2); - items_a.push_front(&item3); - items_a.push_front(&item4); - EXPECT_THAT(ExtractValues(items_a), ElementsAre(4, 3, 2, 1)); - - IntrusiveList<ImplType, offsetof(ImplType, sub_link)> items_b; - items_b.push_back(&item1); - items_b.push_back(&item2); - items_b.push_back(&item3); - items_b.push_back(&item4); - EXPECT_THAT(ExtractValues(items_b), ElementsAre(1, 2, 3, 4)); -} - -bool Comparison(Item* a, Item* b) { return a->value < b->value; } - -TEST(IntrusiveListTest, Inserting) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - items.insert(items.end(), &item3); - items.insert(items.begin(), &item1); - items.insert(items.end(), &item4); - - auto pos = std::upper_bound(items.begin(), items.end(), &item2, Comparison); - items.insert(pos, &item2); - - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); -} - -// TODO(benvanik): test reverse iteration. - -TEST(IntrusiveListTest, NextPrevious) { - Item item1(1); - Item item2(2); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - EXPECT_EQ(nullptr, items.previous(nullptr)); - EXPECT_EQ(nullptr, items.next(nullptr)); - - items.push_back(&item1); - EXPECT_EQ(nullptr, items.previous(&item1)); - EXPECT_EQ(nullptr, items.next(&item1)); - - items.push_back(&item2); - EXPECT_EQ(nullptr, items.previous(&item1)); - EXPECT_EQ(&item2, items.next(&item1)); - EXPECT_EQ(&item1, items.previous(&item2)); - EXPECT_EQ(nullptr, items.next(&item2)); -} - -TEST(IntrusiveListTest, Clear) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - - // Empty clear. - items.clear(); - EXPECT_TRUE(items.empty()); - - // 1 item clear. - items.push_back(&item1); - items.clear(); - EXPECT_TRUE(items.empty()); - - // Multi-item clear. - items.push_back(&item1); - items.push_back(&item2); - items.push_back(&item3); - items.push_back(&item4); - items.clear(); - EXPECT_TRUE(items.empty()); -} - -TEST(IntrusiveListTest, ClearDeleter) { - Item item1(1); - Item item2(2); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - - // No-op first. - int delete_count = 0; - items.clear([&](Item* item) { ++delete_count; }); - EXPECT_EQ(0, delete_count); - - // Now with items. - items.push_back(&item1); - items.push_back(&item2); - items.clear([&](Item* item) { ++delete_count; }); - EXPECT_EQ(2, delete_count); - EXPECT_TRUE(items.empty()); -} - -TEST(IntrusiveListTest, Replace) { - Item item1(1); - Item item2(2); - Item item3(3); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - items.push_back(&item1); - items.push_back(&item2); - - items.replace(&item1, &item3); - EXPECT_THAT(ExtractValues(items), ElementsAre(3, 2)); - EXPECT_FALSE(items.contains(&item1)); - items.replace(&item2, &item1); - EXPECT_THAT(ExtractValues(items), ElementsAre(3, 1)); - EXPECT_FALSE(items.contains(&item2)); -} - -TEST(IntrusiveListTest, Sort) { - Item item1(1); - Item item2(2); - Item item3(3); - Item item4(4); - - IntrusiveList<Item, offsetof(Item, list_a)> items; - - // Empty sort. - items.sort([](Item* a, Item* b) { return a->value < b->value; }); - - // Single item sort. - items.clear(); - items.push_back(&item1); - items.sort([](Item* a, Item* b) { return a->value < b->value; }); - EXPECT_THAT(ExtractValues(items), ElementsAre(1)); - - // Already sorted. - items.clear(); - items.push_back(&item1); - items.push_back(&item2); - items.push_back(&item3); - items.push_back(&item4); - items.sort([](Item* a, Item* b) { return a->value < b->value; }); - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); - - // Reverse. - items.clear(); - items.push_back(&item4); - items.push_back(&item3); - items.push_back(&item2); - items.push_back(&item1); - items.sort([](Item* a, Item* b) { return a->value < b->value; }); - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); - - // Random. - items.clear(); - items.push_back(&item2); - items.push_back(&item4); - items.push_back(&item1); - items.push_back(&item3); - items.sort([](Item* a, Item* b) { return a->value < b->value; }); - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 2, 3, 4)); - - // Stability. - Item item1a(1); - Item item2a(2); - items.clear(); - items.push_back(&item2); - items.push_back(&item4); - items.push_back(&item1); - items.push_back(&item3); - items.push_back(&item1a); - items.push_back(&item2a); - items.sort([](Item* a, Item* b) { return a->value <= b->value; }); - EXPECT_THAT(ExtractValues(items), ElementsAre(1, 1, 2, 2, 3, 4)); - auto items_vector = ExtractItems(items); - EXPECT_EQ(&item1, items_vector[0]); - EXPECT_EQ(&item1a, items_vector[1]); - EXPECT_EQ(&item2, items_vector[2]); - EXPECT_EQ(&item2a, items_vector[3]); - items.clear(); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/intrusive_list_unique_ptr.inc b/iree/base/intrusive_list_unique_ptr.inc deleted file mode 100644 index 94c541d..0000000 --- a/iree/base/intrusive_list_unique_ptr.inc +++ /dev/null
@@ -1,140 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IWYU pragma: private, include "iree/base/intrusive_list.h" - -#ifndef IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_ -#define IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_ - -#include <cstddef> -#include <memory> - -#include "iree/base/intrusive_list.h" -#include "iree/base/logging.h" - -namespace iree { - -// Specialized IntrusiveListBase for std::unique_ptr types. -// This makes the list methods accept std::unique_ptrs and contains a special -// take() method that takes ownership of a list item. -template <typename T, size_t kOffset> -class IntrusiveListUniquePtrBase - : private IntrusiveListBase<T, IntrusiveListIterator<T, kOffset, true>, - IntrusiveListIterator<T, kOffset, false>, - kOffset> { - public: - using IteratorT = IntrusiveListIterator<T, kOffset, true>; - using ReverseIteratorT = IntrusiveListIterator<T, kOffset, false>; - using base_list = IntrusiveListBase<T, IteratorT, ReverseIteratorT, kOffset>; - - IntrusiveListUniquePtrBase() = default; - - using base_list::empty; - using base_list::size; - - using base_list::contains; - - using base_list::clear; - - using base_list::begin; - using base_list::end; - using base_list::rbegin; - using base_list::rend; - - using base_list::next; - - using base_list::previous; - - using base_list::front; - - void push_front(std::unique_ptr<T> value) { - base_list::push_front(value.release()); - } - - using base_list::pop_front; - - using base_list::back; - - void push_back(std::unique_ptr<T> value) { - base_list::push_back(value.release()); - } - - using base_list::pop_back; - - void insert(const IteratorT& it, std::unique_ptr<T> value) { - base_list::insert(it, value.release()); - } - - using base_list::erase; - - // Removes an item from the list at the given iterator and transfers ownership - // to the caller. - // Performance: O(1) - std::unique_ptr<T> take(IteratorT& it) { // NOLINT(runtime/references) - return take(*it); - } - - // Removes an item from the list and transfers ownership to the caller. - // Performance: O(1) - std::unique_ptr<T> take(T* value) { - if (!value) { - return {nullptr}; - } - auto* link = impl::TToLink<T, kOffset>(value); - if (link->prev) { - DCHECK_NE(link, head_); - link->prev->next = link->next; - } else { - DCHECK_EQ(link, head_); - head_ = link->next; - } - if (link->next) { - DCHECK_NE(link, tail_); - link->next->prev = link->prev; - } else { - DCHECK_EQ(link, tail_); - tail_ = link->prev; - } - link->next = link->prev = nullptr; - --count_; - base_list::OnRemove(value); - base_list::CheckCorrectness(); - return std::unique_ptr<T>(value); - } - - void replace(T* old_value, std::unique_ptr<T> new_value) { - base_list::replace(old_value, new_value.release()); - } - - using base_list::sort; - - private: - void OnDeallocate(T* value) override { delete value; } - - using base_list::count_; - using base_list::head_; - using base_list::tail_; -}; - -template <typename U, size_t kOffset> -class IntrusiveList<std::unique_ptr<U>, kOffset> - : public IntrusiveListUniquePtrBase<U, kOffset> {}; - -template <typename U> -class IntrusiveList<std::unique_ptr<U>, kUseDefaultLinkOffset> - : public IntrusiveListUniquePtrBase<U, offsetof(U, link)> {}; - -} // namespace iree - -#endif // IREE_BASE_INTRUSIVE_LIST_UNIQUE_PTR_H_
diff --git a/iree/base/intrusive_list_unique_ptr_test.cc b/iree/base/intrusive_list_unique_ptr_test.cc deleted file mode 100644 index be52c61..0000000 --- a/iree/base/intrusive_list_unique_ptr_test.cc +++ /dev/null
@@ -1,84 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/memory/memory.h" -#include "gtest/gtest.h" -#include "iree/base/intrusive_list.h" - -namespace iree { -namespace { - -struct AllocatedType : public IntrusiveLinkBase<void> { - AllocatedType() { ++alloc_count; } - ~AllocatedType() { --alloc_count; } - static int alloc_count; -}; -int AllocatedType::alloc_count = 0; - -TEST(IntrusiveListUniquePtrTest, UniquePtr) { - AllocatedType::alloc_count = 0; - - // Push/clear. - IntrusiveList<std::unique_ptr<AllocatedType>> list; - EXPECT_EQ(0, AllocatedType::alloc_count); - list.push_back(absl::make_unique<AllocatedType>()); - EXPECT_EQ(1, AllocatedType::alloc_count); - EXPECT_NE(nullptr, list.front()); - list.clear(); - EXPECT_EQ(0, AllocatedType::alloc_count); - - // Push/pop. - list.push_back(absl::make_unique<AllocatedType>()); - EXPECT_EQ(1, AllocatedType::alloc_count); - EXPECT_NE(nullptr, list.front()); - for (auto item : list) { - EXPECT_EQ(item, list.front()); - } - list.pop_back(); - EXPECT_EQ(0, AllocatedType::alloc_count); - - // Push/take. - list.push_back(absl::make_unique<AllocatedType>()); - EXPECT_EQ(1, AllocatedType::alloc_count); - EXPECT_NE(nullptr, list.front()); - auto item = list.take(list.front()); - EXPECT_TRUE(list.empty()); - EXPECT_NE(nullptr, item.get()); - EXPECT_EQ(1, AllocatedType::alloc_count); - item.reset(); - EXPECT_EQ(0, AllocatedType::alloc_count); - - // Push/replace. - list.push_back(absl::make_unique<AllocatedType>()); - EXPECT_EQ(1, AllocatedType::alloc_count); - list.replace(list.front(), absl::make_unique<AllocatedType>()); - EXPECT_EQ(1, AllocatedType::alloc_count); - list.clear(); - EXPECT_EQ(0, AllocatedType::alloc_count); - - // Iteration. - list.push_back(absl::make_unique<AllocatedType>()); - list.push_back(absl::make_unique<AllocatedType>()); - list.push_back(absl::make_unique<AllocatedType>()); - EXPECT_EQ(3, AllocatedType::alloc_count); - for (auto item : list) { - AllocatedType* item_ptr = item; - EXPECT_NE(nullptr, item_ptr); - } - list.clear(); - EXPECT_EQ(0, AllocatedType::alloc_count); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/logging.h b/iree/base/logging.h deleted file mode 100644 index 8bbf57a..0000000 --- a/iree/base/logging.h +++ /dev/null
@@ -1,63 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_LOGGING_H_ -#define IREE_BASE_LOGGING_H_ - -// Logging macros live in their own file so that we can use external versions -// as required. -// -// LOG(severity) << ...; -// Logs a message at the given severity. -// Severity: -// INFO Logs information text. -// WARNING Logs a warning. -// ERROR Logs an error. -// FATAL Logs an error and exit(1). -// -// VLOG(level) << ...; -// Logs a verbose message at the given verbosity level. -// -// DVLOG(level) << ...; -// Behaves like `VLOG` in debug mode (i.e. `#ifndef NDEBUG`). -// Otherwise, it compiles away and does nothing. -// -// CHECK(condition) << ...; -// Runtime asserts that the given condition is true even in release builds. -// It's recommended that DCHECK is used instead as too many CHECKs -// can impact performance. -// -// CHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...; -// Runtime assert the specified operation with the given values. -// -// DCHECK(condition) << ...; -// Runtime asserts that the given condition is true only in non-opt builds. -// -// DCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...; -// Runtime assert the specified operation with the given values in non-opt -// builds. -// -// QCHECK(condition) << ...; -// QCHECK_EQ|NE|LT|GT|LE|GE(val1, val2) << ...; -// These behave like `CHECK` but do not print a full stack trace. -// They are useful when problems are definitely unrelated to program flow, -// e.g. when validating user input. - -#ifdef IREE_CONFIG_GOOGLE_INTERNAL -#include "iree/base/google/logging_google.h" -#else -#include "iree/base/internal/logging.h" -#endif // IREE_CONFIG_GOOGLE_INTERNAL - -#endif // IREE_BASE_LOGGING_H_
diff --git a/iree/base/ref_ptr.h b/iree/base/ref_ptr.h deleted file mode 100644 index dd44034..0000000 --- a/iree/base/ref_ptr.h +++ /dev/null
@@ -1,364 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_REF_PTR_H_ -#define IREE_BASE_REF_PTR_H_ - -#include <atomic> -#include <cstdint> -#include <type_traits> -#include <utility> - -#include "absl/base/attributes.h" -#include "iree/base/logging.h" - -namespace iree { - -// Use this to get really verbose refptr logging: -// #define IREE_VERBOSE_REF_PTR - -template <class T> -class ref_ptr; - -// Allocates a new ref_ptr type. -// Like make_unique, but for ref_ptr. -// -// Usage: -// ref_ptr<MyType> p = make_ref<MyType>(1, 2, 3); -template <typename T, typename... Args> -ref_ptr<T> make_ref(Args&&... args) { - return ref_ptr<T>(new T(std::forward<Args>(args)...)); -} - -// Assigns a raw pointer to a ref_ptr without adding a reference. -// -// Usage: -// ref_ptr<MyType> p = assign_ref(new MyType()); -template <typename T> -inline ref_ptr<T> assign_ref(T* value) { - return ref_ptr<T>(value); -} - -// Adds a reference to the given raw pointer. -// -// Usage: -// MyType* raw_ptr = AcquirePointerFromSomewhere(); -// ref_ptr<MyType> p = add_ref(raw_ptr); -template <typename T> -inline ref_ptr<T> add_ref(T* value) { - if (value) ref_ptr_add_ref(value); - return ref_ptr<T>(value); -} - -// Adds a reference to the given ref_ptr. -// -// Usage: -// ref_ptr<MyType> a = make_ref<MyType>(); -// ref_ptr<MyType> p = add_ref(a); -template <typename T> -inline ref_ptr<T> add_ref(const ref_ptr<T>& value) { - if (value.get()) ref_ptr_add_ref(value.get()); - return ref_ptr<T>(value.get()); -} - -// Reference counted pointer container. -// This is modeled on boost::instrusive_ptr in that it requires no -// extra storage over the pointer type and should compile to almost -// no additional code. It also allows us to round-trip object pointers -// through regular pointers, which is critical when having to round-trip -// them through JNI/etc where we can't use things like unique_ptr/shared_ptr. -// -// ref_ptr<Foo> p1(new Foo()); // ref count 1 -// ref_ptr<Foo> p2(p1); // ref count 2 -// p1.reset(); // ref count 1 -// p2.reset(); // ref count 0, deleted -// -// When round-tripping the pointer through external APIs, use release(): -// ref_ptr<Foo> p1(new Foo()); // ref count 1 -// Foo* raw_p = p1.release(); // ref count 1 -// // pass to API -// ref_ptr<Foo> p2(raw_p); // ref count 1 (don't add ref) -// p2.reset(); // ref count 0, deleted -// -// See the boost intrusive_ptr docs for details of behavior: -// http://www.boost.org/doc/libs/1_55_0/libs/smart_ptr/intrusive_ptr.html -// -// ref_ptr manages the target objects in a thread-safe way, though you'll want -// to take care with objects that may have pinned threads for deallocation. If -// you release the last reference to an object on a thread other than what it -// was expecting you're gonna have a bad time. -// -// Compatible only with types that subclass RefObject or implement the following -// methods: -// ref_ptr_add_ref -// ref_ptr_release_ref -template <class T> -class ref_ptr { - private: - typedef ref_ptr this_type; - typedef T* this_type::*unspecified_bool_type; - - public: - // Initializes with nullptr. - ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr() noexcept = default; - - // Initializes with nullptr so that there is no way to create an - // uninitialized ref_ptr. - ABSL_ATTRIBUTE_ALWAYS_INLINE ref_ptr(std::nullptr_t) noexcept {} // NOLINT - - // Initializes the pointer to the given value. - // The value will not have its reference count incremented (as it is with - // unique_ptr). Use Retain to add to the reference count. - ABSL_ATTRIBUTE_ALWAYS_INLINE explicit ref_ptr(T* p) noexcept : px_(p) {} - - // Decrements the reference count of the owned pointer. - ABSL_ATTRIBUTE_ALWAYS_INLINE ~ref_ptr() noexcept { - if (px_) ref_ptr_release_ref(px_); - } - - // No implicit ref_ptr copying allowed; use add_ref instead. - ref_ptr(const ref_ptr&) noexcept = delete; - ref_ptr& operator=(const ref_ptr&) noexcept = delete; - - // Move support to transfer ownership from one ref_ptr to another. - ref_ptr(ref_ptr&& rhs) noexcept : px_(rhs.release()) {} - ref_ptr& operator=(ref_ptr&& rhs) noexcept { - if (px_ != rhs.px_) { - if (px_) ref_ptr_release_ref(px_); - px_ = rhs.release(); - } - return *this; - } - - // Move support from another compatible type. - template <typename U> - ref_ptr(ref_ptr<U>&& rhs) noexcept : px_(rhs.release()) {} // NOLINT - template <typename U> - ref_ptr& operator=(ref_ptr<U>&& rhs) noexcept { - if (px_ != rhs.get()) { - if (px_) ref_ptr_release_ref(px_); - px_ = rhs.release(); - } - return *this; - } - - // Resets the object to nullptr and decrements the reference count, possibly - // deleting it. - void reset() noexcept { - if (px_) { - ref_ptr_release_ref(px_); - px_ = nullptr; - } - } - - // Releases a pointer. - // Returns the current pointer held by this object without having - // its reference count decremented and resets the ref_ptr to empty. - // Returns nullptr if the ref_ptr holds no value. - // To re-wrap in a ref_ptr use either ref_ptr<T>(value) or assign(). - ABSL_ATTRIBUTE_ALWAYS_INLINE T* release() noexcept { - T* p = px_; - px_ = nullptr; - return p; - } - - // Assigns a pointer. - // The pointer will be accepted by the ref_ptr and its reference count will - // not be incremented. - ABSL_ATTRIBUTE_ALWAYS_INLINE void assign(T* value) noexcept { - reset(); - px_ = value; - } - - // Gets the pointer referenced by this instance. - // operator* and operator-> will assert() if there is no current object. - constexpr T* get() const noexcept { return px_; } - constexpr T& operator*() const noexcept { return *px_; } - constexpr T* operator->() const noexcept { return px_; } - - // Support boolean expression evaluation ala unique_ptr/shared_ptr: - // https://en.cppreference.com/w/cpp/memory/shared_ptr/operator_bool - constexpr operator unspecified_bool_type() const noexcept { - return px_ ? &this_type::px_ : nullptr; - } - // Supports unary expression evaluation. - constexpr bool operator!() const noexcept { return !px_; } - - // Swap support. - void swap(ref_ptr& rhs) { std::swap(px_, rhs.px_); } - - private: - T* px_ = nullptr; -}; - -// Base class for reference counted objects. -// Reference counted objects should be used with the ref_ptr pointer type. -// As reference counting can be tricky always prefer to use unique_ptr and -// avoid this type. Only use this when unique_ptr is not possible, such as -// when round-tripping objects through marshaling boundaries (v8/Java) or -// any objects that may have their lifetime tied to a garbage collected -// object. -// -// Subclasses should protect their dtor so that reference counting must -// be used. -// -// This is designed to avoid the need for extra vtable space or for adding -// methods to the vtable of subclasses. This differs from the boost Pointable -// version of this object. -// Inspiration for this comes from Peter Weinert's Dr. Dobb's article: -// http://www.drdobbs.com/cpp/a-base-class-for-intrusively-reference-c/229218807 -// -// RefObjects are thread safe and may be used with ref_ptrs from multiple -// threads. -// -// Subclasses may implement a custom Delete operator to handle their -// deallocation. It should be thread safe as it may be called from any thread. -// -// Usage: -// class MyRefObject : public RefObject<MyRefObject> { -// public: -// MyRefObject() = default; -// // Optional; can be used to return to pool/etc - must be public: -// static void Delete(MyRefObject* ptr) { -// ::operator delete(ptr); -// } -// }; -template <class T> -class RefObject { - static_assert(!std::is_array<T>::value, "T must not be an array"); - - // value is true if a static Delete(T*) function is present. - struct has_custom_deleter { - template <typename C> - static auto Test(C* p) -> decltype(C::Delete(nullptr), std::true_type()); - template <typename> - static std::false_type Test(...); - static constexpr bool value = - std::is_same<std::true_type, decltype(Test<T>(nullptr))>::value; - }; - - template <typename V, bool has_custom_deleter> - struct delete_thunk { - static void Delete(V* p) { - auto ref_obj = static_cast<RefObject<V>*>(p); - int previous_count = ref_obj->counter_.fetch_sub(1); -#ifdef IREE_VERBOSE_REF_PTR - LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now " - << previous_count - 1 - << (previous_count == 1 ? " DEAD (CUSTOM)" : ""); -#endif // IREE_VERBOSE_REF_PTR - if (previous_count == 1) { - // We delete type T pointer here to avoid the need for a virtual dtor. - V::Delete(p); - } - } - }; - - template <typename V> - struct delete_thunk<V, false> { - static void Delete(V* p) { - auto ref_obj = static_cast<RefObject<V>*>(p); - int previous_count = ref_obj->counter_.fetch_sub(1); -#ifdef IREE_VERBOSE_REF_PTR - LOG(INFO) << "ro-- " << typeid(V).name() << " " << p << " now " - << previous_count - 1 << (previous_count == 1 ? " DEAD" : ""); -#endif // IREE_VERBOSE_REF_PTR - if (previous_count == 1) { - // We delete type T pointer here to avoid the need for a virtual dtor. - delete p; - } - } - }; - - public: - // Adds a reference; used by ref_ptr. - friend void ref_ptr_add_ref(T* p) { - auto ref_obj = static_cast<RefObject*>(p); - ++ref_obj->counter_; - -#ifdef IREE_VERBOSE_REF_PTR - LOG(INFO) << "ro++ " << typeid(T).name() << " " << p << " now " - << ref_obj->counter_; -#endif // IREE_VERBOSE_REF_PTR - } - - // Releases a reference, potentially deleting the object; used by ref_ptr. - friend void ref_ptr_release_ref(T* p) { - delete_thunk<T, has_custom_deleter::value>::Delete(p); - } - - // Adds a reference. - // ref_ptr should be used instead of this in most cases. This is required - // for when interoperating with marshaling APIs. - void AddReference() { ref_ptr_add_ref(static_cast<T*>(this)); } - - // Releases a reference, potentially deleting the object. - // ref_ptr should be used instead of this in most cases. This is required - // for when interoperating with marshaling APIs. - void ReleaseReference() { ref_ptr_release_ref(static_cast<T*>(this)); } - - protected: - RefObject() { ref_ptr_add_ref(static_cast<T*>(this)); } - RefObject(const RefObject&) = default; - RefObject& operator=(const RefObject&) { return *this; } - - std::atomic<intptr_t> counter_{0}; -}; - -// Various comparison operator overloads. - -template <class T, class U> -inline bool operator==(ref_ptr<T> const& a, ref_ptr<U> const& b) { - return a.get() == b.get(); -} - -template <class T, class U> -inline bool operator!=(ref_ptr<T> const& a, ref_ptr<U> const& b) { - return a.get() != b.get(); -} - -template <class T, class U> -inline bool operator==(ref_ptr<T> const& a, U* b) { - return a.get() == b; -} - -template <class T, class U> -inline bool operator!=(ref_ptr<T> const& a, U* b) { - return a.get() != b; -} - -template <class T, class U> -inline bool operator==(T* a, ref_ptr<U> const& b) { - return a == b.get(); -} - -template <class T, class U> -inline bool operator!=(T* a, ref_ptr<U> const& b) { - return a != b.get(); -} - -template <class T> -inline bool operator<(ref_ptr<T> const& a, ref_ptr<T> const& b) { - return a.get() < b.get(); -} - -// Swaps the pointers of two ref_ptrs. -template <class T> -void swap(ref_ptr<T>& lhs, ref_ptr<T>& rhs) { - lhs.swap(rhs); -} - -} // namespace iree - -#endif // IREE_BASE_REF_PTR_H_
diff --git a/iree/base/ref_ptr_test.cc b/iree/base/ref_ptr_test.cc deleted file mode 100644 index e518eb4..0000000 --- a/iree/base/ref_ptr_test.cc +++ /dev/null
@@ -1,330 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/ref_ptr.h" - -#include "gtest/gtest.h" - -namespace iree { -namespace { - -class MyType : public RefObject<MyType> { - public: - int x = 5; - - using RefObject<MyType>::counter_; // Expose for testing. -}; - -TEST(RefPtrTest, Construction) { - // Empty. - ref_ptr<MyType> n1; - EXPECT_EQ(nullptr, n1.get()); - ref_ptr<MyType> n2(nullptr); - EXPECT_EQ(nullptr, n2.get()); - - // Assign a new ptr and add ref. - MyType* a_ptr = new MyType(); - EXPECT_EQ(1, a_ptr->counter_); - ref_ptr<MyType> a(a_ptr); - EXPECT_EQ(1, a->counter_); - - // Assign existing ptr without adding a ref. - ref_ptr<MyType> b(a_ptr); - EXPECT_EQ(1, b->counter_); - - // Add a new ref. - ref_ptr<MyType> c = add_ref(b); - EXPECT_EQ(2, c->counter_); - - b.release(); -} - -TEST(RefPtrTest, Assign) { - // Ok to assign nothing. - ref_ptr<MyType> n1 = assign_ref<MyType>(nullptr); - EXPECT_EQ(nullptr, n1.get()); - - ref_ptr<MyType> mt = make_ref<MyType>(); - EXPECT_EQ(1, mt->counter_); - ref_ptr<MyType> n2 = assign_ref(mt.get()); - EXPECT_EQ(1, mt->counter_); - mt.release(); // must release, as we assigned to n2. - EXPECT_EQ(1, n2->counter_); - n2.reset(); -} - -TEST(RefPtrTest, Retain) { - // Ok to retain nothing. - ref_ptr<MyType> n1 = add_ref<MyType>(nullptr); - EXPECT_EQ(nullptr, n1.get()); - - ref_ptr<MyType> mt = make_ref<MyType>(); - EXPECT_EQ(1, mt->counter_); - ref_ptr<MyType> n2 = add_ref(mt.get()); - EXPECT_EQ(2, mt->counter_); - mt.reset(); - EXPECT_EQ(1, n2->counter_); - n2.reset(); -} - -TEST(RefPtrTest, Reset) { - ref_ptr<MyType> a(new MyType()); - ref_ptr<MyType> b(new MyType()); - - // Reset to drop reference. - ref_ptr<MyType> a_copy = add_ref(a); - EXPECT_EQ(2, a_copy->counter_); - a.reset(); - EXPECT_EQ(1, a_copy->counter_); - - // Reset via = operator. - a = nullptr; - EXPECT_EQ(1, a_copy->counter_); - a = add_ref(a_copy); - EXPECT_EQ(2, a_copy->counter_); - - // No-op on empty ptrs. - ref_ptr<MyType> n; - n.reset(); - n.assign(nullptr); -} - -TEST(RefPtrTest, ReleaseAssign) { - ref_ptr<MyType> a(new MyType()); - - // Release a's pointer. - MyType* a_raw_ptr = a.get(); - MyType* a_ptr = a.release(); - EXPECT_EQ(a_raw_ptr, a_ptr); - EXPECT_EQ(nullptr, a.get()); - EXPECT_EQ(1, a_ptr->counter_); - - // Re-wrap in a ref_ptr. - a.assign(a_ptr); - EXPECT_EQ(1, a->counter_); - - // No-op on empty ptrs. - ref_ptr<MyType> n; - EXPECT_EQ(nullptr, n.release()); -} - -TEST(RefPtrTest, Accessors) { - ref_ptr<MyType> a(new MyType()); - EXPECT_EQ(5, a->x); - a->x = 100; - EXPECT_EQ(100, a->x); - - MyType& ra = *a; - ra.x = 200; - EXPECT_EQ(200, ra.x); - - const MyType& cra = *a; - EXPECT_EQ(200, cra.x); -} - -TEST(RefPtrTest, BooleanExpressions) { - ref_ptr<MyType> a(new MyType()); - ref_ptr<MyType> n; - - EXPECT_NE(nullptr, a.get()); - EXPECT_TRUE(a); - EXPECT_FALSE(!a); - EXPECT_EQ(true, static_cast<bool>(a)); - - EXPECT_EQ(nullptr, n.get()); - EXPECT_FALSE(n); - EXPECT_TRUE(!n); - EXPECT_EQ(false, static_cast<bool>(n)); -} - -TEST(RefPtrTest, Comparisons) { - ref_ptr<MyType> a(new MyType()); - ref_ptr<MyType> b(new MyType()); - ref_ptr<MyType> n; - - EXPECT_TRUE(a == a); - EXPECT_TRUE(a == a.get()); - EXPECT_TRUE(a.get() == a); - EXPECT_FALSE(a != a); - EXPECT_FALSE(a != a.get()); - EXPECT_FALSE(a.get() != a); - - EXPECT_FALSE(a == b); - EXPECT_FALSE(a == b.get()); - EXPECT_FALSE(a.get() == b); - EXPECT_TRUE(a != b); - EXPECT_TRUE(a != b.get()); - EXPECT_TRUE(a.get() != b); - - EXPECT_TRUE(n == n); - EXPECT_TRUE(n == n.get()); - EXPECT_TRUE(n.get() == n); - EXPECT_FALSE(n != n); - EXPECT_FALSE(n != n.get()); - EXPECT_FALSE(n.get() != n); - - EXPECT_FALSE(a < a); - EXPECT_TRUE(n < a); -} - -TEST(RefPtrTest, Swap) { - ref_ptr<MyType> a(new MyType()); - ref_ptr<MyType> b(new MyType()); - MyType* a_ptr = a.get(); - MyType* b_ptr = b.get(); - - swap(a, a); - EXPECT_EQ(a_ptr, a); - - swap(a, b); - EXPECT_EQ(a_ptr, b.get()); - EXPECT_EQ(b_ptr, a.get()); - - swap(a, b); - EXPECT_EQ(a_ptr, a.get()); - EXPECT_EQ(b_ptr, b.get()); - - ref_ptr<MyType> c; - swap(a, c); - EXPECT_EQ(a_ptr, c.get()); - EXPECT_EQ(nullptr, a.get()); -} - -TEST(RefPtrTest, Move) { - auto a = make_ref<MyType>(); - auto b = make_ref<MyType>(); - ref_ptr<MyType> c; - EXPECT_EQ(nullptr, c.get()); - - c = std::move(a); - EXPECT_NE(nullptr, c.get()); - - b = std::move(c); - EXPECT_NE(nullptr, b.get()); -} - -TEST(RefPtrTest, MoveCompatible) { - struct MyBaseType : public RefObject<MyBaseType> { - int x = 5; - using RefObject<MyBaseType>::counter_; // Expose for testing. - }; - struct MyTypeA : public MyBaseType { - int a = 6; - }; - struct MyTypeB : public MyBaseType { - int b = 7; - }; - - ref_ptr<MyTypeA> a = make_ref<MyTypeA>(); - EXPECT_EQ(1, a->counter_); - ref_ptr<MyBaseType> base = add_ref(a); - EXPECT_EQ(a.get(), base.get()); - EXPECT_EQ(2, a->counter_); - - base = make_ref<MyTypeB>(); - EXPECT_EQ(1, a->counter_); - EXPECT_EQ(1, base->counter_); -} - -TEST(RefPtrTest, StackAllocation) { - static int alloc_count = 0; - class StackAllocationType : public RefObject<StackAllocationType> { - public: - StackAllocationType() { ++alloc_count; } - ~StackAllocationType() { --alloc_count; } - }; - { - StackAllocationType a; - EXPECT_EQ(1, alloc_count); - } - EXPECT_EQ(0, alloc_count); -} - -TEST(RefPtrTest, DefaultDeleter) { - static int alloc_count = 0; - class DefaultDeleterType : public RefObject<DefaultDeleterType> { - public: - DefaultDeleterType() { ++alloc_count; } - ~DefaultDeleterType() { --alloc_count; } - }; - - // Empty is ok. - ref_ptr<DefaultDeleterType> n; - n.reset(); - - // Lifecycle. - EXPECT_EQ(0, alloc_count); - ref_ptr<DefaultDeleterType> a = make_ref<DefaultDeleterType>(); - EXPECT_EQ(1, alloc_count); - a.reset(); - EXPECT_EQ(0, alloc_count); -} - -TEST(RefPtrTest, InlineDeallocator) { - static int alloc_count = 0; - class CustomDeleterType : public RefObject<CustomDeleterType> { - public: - CustomDeleterType() { ++alloc_count; } - static void Delete(CustomDeleterType* ptr) { - --alloc_count; - ::operator delete(ptr); - } - }; - - // Empty is ok. - ref_ptr<CustomDeleterType> n; - n.reset(); - - // Lifecycle. - EXPECT_EQ(0, alloc_count); - auto a = make_ref<CustomDeleterType>(); - EXPECT_EQ(1, alloc_count); - a.reset(); - EXPECT_EQ(0, alloc_count); -} - -class VirtualDtorTypeA : public RefObject<VirtualDtorTypeA> { - public: - VirtualDtorTypeA() { ++alloc_count_a; } - virtual ~VirtualDtorTypeA() { --alloc_count_a; } - static int alloc_count_a; -}; -int VirtualDtorTypeA::alloc_count_a = 0; - -class VirtualDtorTypeB : public VirtualDtorTypeA { - public: - VirtualDtorTypeB() { ++alloc_count_b; } - ~VirtualDtorTypeB() override { --alloc_count_b; } - static int alloc_count_b; -}; -int VirtualDtorTypeB::alloc_count_b = 0; - -TEST(RefPtrTest, VirtualDestructor) { - // Empty is ok. - ref_ptr<VirtualDtorTypeB> n; - n.reset(); - - // Lifecycle. - EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a); - EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b); - ref_ptr<VirtualDtorTypeB> a = make_ref<VirtualDtorTypeB>(); - EXPECT_EQ(1, VirtualDtorTypeA::alloc_count_a); - EXPECT_EQ(1, VirtualDtorTypeB::alloc_count_b); - a.reset(); - EXPECT_EQ(0, VirtualDtorTypeA::alloc_count_a); - EXPECT_EQ(0, VirtualDtorTypeB::alloc_count_b); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/shape.cc b/iree/base/shape.cc deleted file mode 100644 index 875d119..0000000 --- a/iree/base/shape.cc +++ /dev/null
@@ -1,100 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/shape.h" - -#include <cstddef> - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { - -Shape::Shape(const int* values, int size) : rank_(size) { - QCHECK_LE(size, kMaxRank) - << "Max rank of " << kMaxRank << ", shape has " << size; - std::memcpy(value_, values, size * sizeof(int)); -} - -std::string Shape::DebugString() const { - return absl::StrCat("[", absl::StrJoin(subspan(), ","), "]"); -} - -absl::Span<const int> Shape::subspan(size_type pos, size_type len) const { - if (len == npos) { - len = rank_ - pos; - } - return absl::MakeConstSpan(&value_[pos], len); -} - -void Shape::push_back(int dim) { - DCHECK_LE(rank_ + 1, kMaxRank); - value_[rank_++] = dim; -} - -void Shape::insert(iterator pos, int dim) { - int axis = static_cast<int>(pos - value_); - DCHECK_GE(axis, 0); - DCHECK_LE(axis, rank_); - DCHECK_LE(rank_ + 1, kMaxRank); - ++rank_; - for (int i = rank_ - 1; i > axis; --i) { - value_[i] = value_[i - 1]; - } - value_[axis] = dim; -} - -void Shape::erase(iterator pos) { - int axis = static_cast<int>(pos - value_); - DCHECK_GE(axis, 0); - DCHECK_LE(axis, rank_); - for (int i = axis; i < rank_ - 1; ++i) { - value_[i] = value_[i + 1]; - } - --rank_; -} - -int Shape::element_count() const { - size_t element_count = 1; - for (int i = 0; i < rank_; ++i) { - int dim = value_[i]; - if (dim == -1) { - return 0; - } - element_count *= dim; - } - return element_count; -} - -StatusOr<int> Shape::ResolveAxis(int axis) const { - if (rank_ == 0 && (axis == -1 || axis == 0)) { - // Scalar axes resolves to 0. - return 0; - } - - int new_axis = axis; - if (new_axis < 0) { - new_axis += rank_; - } - if (new_axis < 0 || new_axis >= rank_) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Axis " << new_axis << " (orig " << axis - << ") out of bounds of rank " << rank_; - } - return new_axis; -} - -} // namespace iree
diff --git a/iree/base/shape.h b/iree/base/shape.h deleted file mode 100644 index 758d66f..0000000 --- a/iree/base/shape.h +++ /dev/null
@@ -1,156 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_SHAPE_H_ -#define IREE_BASE_SHAPE_H_ - -#include <array> -#include <cstring> -#include <initializer_list> -#include <iterator> -#include <string> -#include <type_traits> -#include <vector> - -#include "absl/meta/type_traits.h" -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/status.h" - -namespace iree { - -// For simplicity we limit our shapes to a max of rank-N (shape.size() == N) as -// this prevents dynamic allocations and rarely are there greater ranks. -constexpr int kMaxRank = 5; - -// Represent indices and lengths of tensors. -using Index = std::array<int, kMaxRank>; -using Length = std::array<int, kMaxRank>; - -// Represents the number of elements in multiple dimensions. -// Can be rank-0 (scalar) to rank-kMaxRank. Tries to match the API of -// std::vector and can be converted to a Span via subspan(). -// -// https://www.tensorflow.org/guide/tensors#shape -class Shape { - public: - using size_type = int; - static constexpr size_type npos = ~(size_type(0)); // NOLINT - using iterator = int*; - using const_iterator = const int*; - - Shape() = default; - Shape(const int* values, int size); - Shape(std::initializer_list<int> values) - : Shape(values.begin(), values.size()) {} - explicit Shape(absl::Span<const int> values) - : Shape(values.data(), values.size()) {} - - template <typename Iterator> - using EnableIfForwardIterator = absl::enable_if_t<std::is_convertible< - typename std::iterator_traits<Iterator>::iterator_category, - std::forward_iterator_tag>::value>; - template <typename Iterator, EnableIfForwardIterator<Iterator>* = nullptr> - Shape(Iterator first, Iterator last) { - rank_ = std::distance(first, last); - QCHECK_LE(rank_, kMaxRank); - for (int i = 0; first != last; ++i, static_cast<void>(++first)) { - value_[i] = *first; - } - } - - // Returns a string representation of the given shape. - std::string DebugString() const; - - // Size (aka 'rank') of the shape, counting the number of dimensions. - constexpr size_type size() const noexcept { return rank_; } - - // Whether the shape is rank-0 (scalar). - constexpr bool empty() const noexcept { return rank_ == 0; } - - // Returns the total elements in the tensor shape. - // Returns 0 if the tensor shape is not complete and 1 if the shape is a - // scalar value. - int element_count() const; - - // Resolves an axis in [-R,R) to the real axis value and verifies the range. - StatusOr<int> ResolveAxis(int axis) const; - - // Compares two shapes for equality. - inline static bool Equal(const Shape& a, const Shape& b) { - return a.rank_ == b.rank_ && - std::memcmp(a.value_, b.value_, a.rank_ * sizeof(value_[0])) == 0; - } - - int& operator[](size_type i) noexcept { - DCHECK_GE(i, 0); - DCHECK_LT(i, rank_); - return value_[i]; - } - - const int& operator[](size_type i) const noexcept { - DCHECK_GE(i, 0); - DCHECK_LT(i, rank_); - return value_[i]; - } - - int front() const noexcept { - DCHECK_GE(rank_, 1); - return value_[0]; - } - - int back() const noexcept { - DCHECK_GE(rank_, 1); - return value_[rank_ - 1]; - } - - constexpr iterator begin() const noexcept { - return const_cast<iterator>(&value_[0]); - } - constexpr iterator end() const noexcept { - return const_cast<iterator>(&value_[rank_]); - } - constexpr const_iterator cbegin() const noexcept { return &value_[0]; } - constexpr const_iterator cend() const noexcept { return &value_[rank_]; } - - absl::Span<const int> subspan(size_type pos = 0, size_type len = npos) const; - absl::Span<const int> data() const { return subspan(); } - - void push_back(int dim); - - void insert(iterator pos, int dim); - - void erase(iterator pos); - - void clear() { rank_ = 0; } - - private: - size_type rank_ = 0; - int value_[kMaxRank]; -}; - -inline bool operator==(const Shape& a, const Shape& b) { - return Shape::Equal(a, b); -} - -inline bool operator!=(const Shape& a, const Shape& b) { return !(a == b); } - -inline std::ostream& operator<<(std::ostream& stream, const Shape& shape) { - stream << shape.DebugString(); - return stream; -} - -} // namespace iree - -#endif // IREE_BASE_SHAPE_H_
diff --git a/iree/base/shape_test.cc b/iree/base/shape_test.cc deleted file mode 100644 index ca163a3..0000000 --- a/iree/base/shape_test.cc +++ /dev/null
@@ -1,222 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/shape.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" - -namespace iree { -namespace { - -using ::testing::ElementsAre; - -// Tests shapes that represent 0-D scalar values. -TEST(ShapeTest, Scalar) { - Shape shape; - EXPECT_EQ(0, shape.size()); - EXPECT_TRUE(shape.empty()); - EXPECT_EQ(1, shape.element_count()); - EXPECT_EQ(shape, shape); - EXPECT_EQ(0, shape.subspan().size()); - for (const int dim : shape) { - FAIL() << "Should have no dimensions, have: " << dim; - } - EXPECT_EQ(shape.begin(), shape.end()); - EXPECT_EQ(shape.cbegin(), shape.cend()); - shape.clear(); - EXPECT_EQ(0, shape.size()); -} - -// Tests the various ways of constructing a 1+D shape. -TEST(ShapeTest, NonScalarConstruction) { - EXPECT_EQ(0, Shape().size()); - EXPECT_EQ(0, Shape({}).size()); - EXPECT_EQ(1, Shape({10}).size()); - EXPECT_EQ(4, Shape({10, 20, 30, 40}).size()); - - std::vector<int> empty_data = {}; - EXPECT_EQ(0, Shape(empty_data.data(), empty_data.size()).size()); - EXPECT_EQ(0, Shape(empty_data.begin(), empty_data.end()).size()); - EXPECT_EQ(0, Shape(absl::MakeConstSpan(empty_data)).size()); - - EXPECT_THAT(Shape({}).subspan(), ElementsAre()); - EXPECT_THAT(Shape({10}).subspan(), ElementsAre(10)); - EXPECT_THAT(Shape({10, 20, 30, 40}).subspan(), ElementsAre(10, 20, 30, 40)); - - std::vector<int> valid_data = {10, 20, 30, 40}; - EXPECT_THAT(Shape(valid_data.begin(), valid_data.end()).subspan(), - ElementsAre(10, 20, 30, 40)); - EXPECT_THAT(Shape(absl::MakeConstSpan(valid_data)).subspan(), - ElementsAre(10, 20, 30, 40)); -} - -// Tests shapes that represent 1+D multidimensional values. -TEST(ShapeTest, NonScalarAccess) { - Shape shape = {1, 2, 3, 4}; - EXPECT_EQ(4, shape.size()); - EXPECT_FALSE(shape.empty()); - EXPECT_EQ(1 * 2 * 3 * 4, shape.element_count()); - EXPECT_EQ(shape, shape); - EXPECT_NE(shape, Shape({4, 3, 2, 1})); - EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4)); - std::vector<int> readout; - for (const int dim : shape) { - readout.push_back(dim); - } - EXPECT_THAT(readout, ElementsAre(1, 2, 3, 4)); - EXPECT_EQ(1, shape[0]); - EXPECT_EQ(2, shape[1]); - EXPECT_EQ(3, shape[2]); - EXPECT_EQ(4, shape[3]); - EXPECT_EQ(1, shape.front()); - EXPECT_EQ(4, shape.back()); -} - -TEST(ShapeTest, PushBack) { - Shape shape; - EXPECT_EQ(0, shape.size()); - - shape.push_back(10); - EXPECT_EQ(1, shape.size()); - EXPECT_EQ(10, shape.front()); - EXPECT_EQ(10, shape.back()); - EXPECT_EQ(10, shape[0]); - EXPECT_THAT(shape.subspan(), ElementsAre(10)); - - shape.push_back(20); - EXPECT_EQ(2, shape.size()); - EXPECT_EQ(10, shape.front()); - EXPECT_EQ(20, shape.back()); - EXPECT_EQ(10, shape[0]); - EXPECT_EQ(20, shape[1]); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20)); -} - -TEST(ShapeTest, Insert) { - Shape shape; - EXPECT_EQ(0, shape.size()); - - shape.insert(shape.begin(), 20); - EXPECT_THAT(shape.subspan(), ElementsAre(20)); - shape.insert(shape.begin(), 10); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20)); - shape.insert(shape.end(), 40); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 40)); - shape.insert(shape.begin() + 2, 30); - EXPECT_THAT(shape.subspan(), ElementsAre(10, 20, 30, 40)); - - Shape ex_shape{72, 4}; - ex_shape.insert(ex_shape.begin(), 144); - EXPECT_THAT(ex_shape.subspan(), ElementsAre(144, 72, 4)); -} - -TEST(ShapeTest, Erase) { - Shape shape = {1, 2, 3, 4}; - EXPECT_THAT(shape.subspan(), ElementsAre(1, 2, 3, 4)); - shape.erase(shape.begin()); - EXPECT_THAT(shape.subspan(), ElementsAre(2, 3, 4)); - shape.erase(shape.end()); - EXPECT_THAT(shape.subspan(), ElementsAre(2, 3)); - shape.erase(shape.begin() + 1); - EXPECT_THAT(shape.subspan(), ElementsAre(2)); - shape.erase(shape.end()); - EXPECT_THAT(shape.subspan(), ElementsAre()); -} - -TEST(ShapeTest, Clear) { - Shape shape; - EXPECT_EQ(0, shape.size()); - shape.clear(); - EXPECT_EQ(0, shape.size()); - - shape = Shape({1}); - shape.clear(); - EXPECT_EQ(0, shape.size()); - - shape = Shape({1, 2, 3, 4}); - shape.clear(); - EXPECT_EQ(0, shape.size()); -} - -TEST(ShapeTest, DebugString) { - EXPECT_EQ("[]", Shape({}).DebugString()); - EXPECT_EQ("[1]", Shape({1}).DebugString()); - EXPECT_EQ("[1,2]", Shape({1, 2}).DebugString()); -} - -TEST(ShapeTest, ElementCount) { - EXPECT_EQ(1, Shape({}).element_count()); - EXPECT_EQ(0, Shape({0}).element_count()); - EXPECT_EQ(1, Shape({1}).element_count()); - EXPECT_EQ(2, Shape({2, 1}).element_count()); - EXPECT_EQ(10, Shape({2, 5}).element_count()); - EXPECT_EQ(9216, Shape({72, 1, 128}).element_count()); - EXPECT_EQ(9216, Shape({1, 72, 128}).element_count()); - - // Partial shaping should yield no elements. - EXPECT_EQ(0, Shape({1, -1, 2, 3}).element_count()); -} - -TEST(ShapeTest, ResolveAxis) { - int axis; - ASSERT_OK_AND_ASSIGN(axis, Shape({0}).ResolveAxis(0)); - EXPECT_EQ(0, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(1)); - EXPECT_EQ(1, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(2)); - EXPECT_EQ(2, axis); - - EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(3).status())); -} - -TEST(ShapeTest, ResolveAxisNegative) { - int axis; - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-3)); - EXPECT_EQ(0, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-2)); - EXPECT_EQ(1, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({0, 1, 2}).ResolveAxis(-1)); - EXPECT_EQ(2, axis); - - EXPECT_TRUE(IsInvalidArgument(Shape({0, 1, 2}).ResolveAxis(-4).status())); -} - -TEST(ShapeTest, ResolveAxisScalar) { - int axis; - ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(0)); - EXPECT_EQ(0, axis); - ASSERT_OK_AND_ASSIGN(axis, Shape({}).ResolveAxis(-1)); - EXPECT_EQ(0, axis); - - EXPECT_TRUE(IsInvalidArgument(Shape({}).ResolveAxis(1).status())); -} - -TEST(ShapeTest, Equality) { - EXPECT_EQ(Shape({}), Shape({})); - EXPECT_EQ(Shape({0}), Shape({0})); - EXPECT_EQ(Shape({1}), Shape({1})); - EXPECT_EQ(Shape({1, 2}), Shape({1, 2})); - - EXPECT_NE(Shape({}), Shape({1})); - EXPECT_NE(Shape({-1}), Shape({1})); - EXPECT_NE(Shape({1}), Shape({})); - EXPECT_NE(Shape({1}), Shape({2})); - EXPECT_NE(Shape({1, 2}), Shape({3, 4})); -} - -} // namespace -} // namespace iree
diff --git a/iree/base/source_location.h b/iree/base/source_location.h deleted file mode 100644 index d4a10ca..0000000 --- a/iree/base/source_location.h +++ /dev/null
@@ -1,24 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_SOURCE_LOCATION_H_ -#define IREE_BASE_SOURCE_LOCATION_H_ - -#ifdef IREE_CONFIG_GOOGLE_INTERNAL -#include "iree/base/google/source_location_google.h" -#else -#include "iree/base/internal/source_location.h" -#endif // IREE_CONFIG_GOOGLE_INTERNAL - -#endif // IREE_BASE_SOURCE_LOCATION_H_
diff --git a/iree/base/status.h b/iree/base/status.h deleted file mode 100644 index 4565101..0000000 --- a/iree/base/status.h +++ /dev/null
@@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_STATUS_H_ -#define IREE_BASE_STATUS_H_ - -#ifdef IREE_CONFIG_GOOGLE_INTERNAL -#include "iree/base/google/status_google.h" -#else -#include "iree/base/internal/status.h" -#include "iree/base/internal/status_builder.h" -#include "iree/base/internal/status_errno.h" -#include "iree/base/internal/status_errors.h" -#include "iree/base/internal/status_macros.h" -#include "iree/base/internal/status_win32_errors.h" -#include "iree/base/internal/statusor.h" -#endif // IREE_CONFIG_GOOGLE_INTERNAL - -#include "iree/base/source_location.h" // IWYU pragma: export - -#endif // IREE_BASE_STATUS_H_
diff --git a/iree/base/status_matchers.h b/iree/base/status_matchers.h deleted file mode 100644 index 7920c57..0000000 --- a/iree/base/status_matchers.h +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_STATUS_MATCHERS_H_ -#define IREE_BASE_STATUS_MATCHERS_H_ - -#ifdef IREE_CONFIG_GOOGLE_INTERNAL - -#include "iree/base/google/status_matchers_google.h" // IWYU pragma: export - -#else - -#include "iree/base/internal/status_matchers.h" // IWYU pragma: export - -#endif // IREE_CONFIG_GOOGLE_INTERNAL - -#endif // IREE_BASE_STATUS_MATCHERS_H_
diff --git a/iree/base/tracing.cc b/iree/base/tracing.cc deleted file mode 100644 index 51fa5d0..0000000 --- a/iree/base/tracing.cc +++ /dev/null
@@ -1,188 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Force the header to detect WTF_ENABLE so that this library builds -// (for when building recursively). -#if !defined(WTF_ENABLE) -#define WTF_ENABLE -#endif - -#include "iree/base/tracing.h" - -#include <thread> // NOLINT: Fiber doesn't work during startup on Android. - -#include "absl/base/attributes.h" -#include "absl/base/const_init.h" -#include "absl/base/thread_annotations.h" -#include "absl/flags/flag.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "iree/base/file_io.h" -#include "iree/base/file_path.h" -#include "iree/base/init.h" -#include "iree/base/logging.h" -#include "iree/base/status.h" - -ABSL_FLAG(int32_t, iree_trace_file_period, 5, - "Seconds between automatic flushing of WTF trace files. 0 to " - "disable auto-flush."); -ABSL_FLAG(std::string, iree_trace_file, "/dev/null", - "wtf-trace file to save if --define=GLOBAL_WTF_ENABLE=1 was used " - "when building."); - -namespace iree { -namespace { - -// Guards global WTF state (like the flush fiber and IO). -ABSL_CONST_INIT absl::Mutex global_tracing_mutex(absl::kConstInit); - -// True when tracing has been enabled and initialized. -bool global_tracing_initialized ABSL_GUARDED_BY(global_tracing_mutex) = false; - -// If there is an existing file at the given path back it up by moving it aside. -// Only kMaxBackups will be kept to avoid unbounded growth. -void RollTraceFiles(const std::string& path) { - std::string path_stem = file_path::JoinPaths(file_path::DirectoryName(path), - file_path::Stem(path)); - const int kMaxBackups = 5; - for (int i = kMaxBackups; i >= 0; i--) { - std::string source_name; - if (i > 0) { - source_name = absl::StrCat(path_stem, ".", i, ".wtf-trace"); - } else { - source_name = path; - } - if (!file_io::FileExists(source_name).ok()) { - continue; - } - - Status status; - if (i == kMaxBackups) { - status = file_io::DeleteFile(source_name); - } else { - std::string backup_name = - absl::StrCat(path_stem, ".", (i + 1), ".wtf-trace"); - status = file_io::MoveFile(source_name, backup_name); - } - if (!status.ok()) { - LOG(WARNING) << "Could not remove backup trace file " << source_name - << ": " << status; - } - } -} - -// Flushes all recorded trace data since the last flush. -void FlushTraceFile() ABSL_EXCLUSIVE_LOCKS_REQUIRED(global_tracing_mutex) { - if (!global_tracing_initialized) return; - - const auto& trace_path = absl::GetFlag(FLAGS_iree_trace_file); - - static ::wtf::Runtime::SaveCheckpoint checkpoint; - static bool is_first_flush = true; - - if (is_first_flush && trace_path != "/dev/null") { - // Backup existing any existing trace files at the specified path. - RollTraceFiles(trace_path); - } - - auto save_options = - ::wtf::Runtime::SaveOptions::ForStreamingFile(&checkpoint); - if (is_first_flush) { - // On the first time, truncate the file. All subsequent flushes append. - save_options.open_mode = std::ios_base::trunc; - } - - is_first_flush = false; - - auto* runtime = ::wtf::Runtime::GetInstance(); - if (!runtime->SaveToFile(trace_path, save_options)) { - LOG(ERROR) << "Error saving WTF file: " << trace_path; - return; - } - - VLOG(1) << "Flushed WTF trace to: " << trace_path; -} - -} // namespace - -void InitializeTracing() { - if (!::wtf::kMasterEnable) { - if (!absl::GetFlag(FLAGS_iree_trace_file).empty()) { - LOG(WARNING) << "WTF trace save requested but WTF is not compiled in. " - << "Enable by building with --define=GLOBAL_WTF_ENABLE=1."; - } - return; - } - - absl::MutexLock lock(&global_tracing_mutex); - if (global_tracing_initialized) return; - global_tracing_initialized = true; - - LOG(INFO) << "Tracing enabled and streaming to: " - << absl::GetFlag(FLAGS_iree_trace_file); - - // Enable tracing on this thread, which we know is main. - IREE_TRACE_THREAD_ENABLE("main"); - - // Register atexit callback to stop tracking. - atexit(StopTracing); - - // Launch a thread to periodically flush the trace. - if (absl::GetFlag(FLAGS_iree_trace_file_period) > 0) { - auto flush_thread = std::thread(+[]() { - absl::Duration period = - absl::Seconds(absl::GetFlag(FLAGS_iree_trace_file_period)); - while (true) { - absl::SleepFor(period); - absl::MutexLock lock(&global_tracing_mutex); - if (!global_tracing_initialized) { - return; - } - FlushTraceFile(); - } - }); - flush_thread.detach(); - } -} - -// Stops tracing if currently initialized. -void StopTracing() { - if (!::wtf::kMasterEnable) return; - absl::MutexLock lock(&global_tracing_mutex); - if (!global_tracing_initialized) return; - - // Flush any pending trace data. - FlushTraceFile(); - - // Mark WTF as uninitialized to kill the flush thread. - global_tracing_initialized = false; - - LOG(INFO) << "Tracing stopped and flushed to file: " - << absl::GetFlag(FLAGS_iree_trace_file); -} - -void FlushTrace() { - if (!::wtf::kMasterEnable) return; - absl::MutexLock lock(&global_tracing_mutex); - if (!global_tracing_initialized) return; - FlushTraceFile(); -} - -} // namespace iree - -IREE_DECLARE_MODULE_INITIALIZER(iree_tracing); - -IREE_REGISTER_MODULE_INITIALIZER(iree_tracing, ::iree::InitializeTracing());
diff --git a/iree/base/tracing_disabled.cc b/iree/base/tracing_disabled.cc deleted file mode 100644 index 96ea115..0000000 --- a/iree/base/tracing_disabled.cc +++ /dev/null
@@ -1,29 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file is linked in only when WTF is not enabled. It allows us to keep the -// same flags and functions without needing to do a bunch of ifdef hackery or -// undefok mangling. - -#include <cstdint> -#include <string> - -#include "absl/flags/flag.h" -#include "iree/base/tracing.h" - -// TODO(benvanik): remove this when disabled so that we don't dep on flags. -ABSL_FLAG(int32_t, iree_trace_file_period, 0, - "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF."); -ABSL_FLAG(std::string, iree_trace_file, "", - "Flag for tracing. Use --define=GLOBAL_WTF_ENABLE=1 to enable WTF.");
diff --git a/iree/base/wait_handle.cc b/iree/base/wait_handle.cc deleted file mode 100644 index 39e5ca9..0000000 --- a/iree/base/wait_handle.cc +++ /dev/null
@@ -1,532 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/wait_handle.h" - -#include <errno.h> -#include <fcntl.h> -#include <poll.h> -#include <time.h> -#include <unistd.h> - -#include <type_traits> -#include <utility> - -#include "absl/container/fixed_array.h" -#include "absl/strings/str_cat.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -// TODO(benvanik): organize these macros - they are terrible. - -#if !defined(__ANDROID__) && !defined(OS_IOS) && !defined(__EMSCRIPTEN__) -#define IREE_HAS_PPOLL 1 -#endif // !__ANDROID__ && !__EMSCRIPTEN__ -#define IREE_HAS_POLL 1 - -#if !defined(OS_IOS) && !defined(OS_MACOSX) && !defined(__EMSCRIPTEN__) -#define IREE_HAS_EVENTFD 1 -#endif -#define IREE_HAS_PIPE 1 -// #define IREE_HAS_SYNC_FILE 1 - -#if defined(IREE_HAS_EVENTFD) -#include <sys/eventfd.h> -#endif // IREE_HAS_EVENTFD - -namespace iree { - -namespace { - -constexpr int kInvalidFd = WaitableObject::kInvalidFd; -constexpr int kSignaledFd = WaitableObject::kSignaledFd; - -// Retries a syscall until it succeeds or fails for a real reason. -template <typename SyscallT, typename... ParamsT> -StatusOr<typename std::result_of<SyscallT(ParamsT...)>::type> Syscall( - SyscallT syscall, ParamsT&&... params) { - while (true) { - const auto rv = syscall(std::forward<ParamsT>(params)...); - if (rv >= 0) return rv; - if (errno == EINTR) { - // Retry on EINTR. - continue; - } else { - return ErrnoToCanonicalStatus(errno, ""); - } - } -} - -#if defined(IREE_HAS_PPOLL) - -// ppoll(), present on Linux. -// ppoll is preferred as it has a much better timing mechanism; poll can have a -// large slop on the deadline. -// Documentation: https://linux.die.net/man/2/poll -StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) { - // Convert the deadline into a tmo_p struct for ppoll that controls whether - // the call is blocking or non-blocking. Note that we must do this every - // iteration of the loop as a previous ppoll may have taken some of the - // time. - // - // See the ppoll docs for more information as to what the expected value is: - // http://man7.org/linux/man-pages/man2/poll.2.html - timespec timeout_spec; - timespec* tmo_p; - if (deadline == absl::InfinitePast()) { - // 0 for non-blocking. - timeout_spec = {0}; - tmo_p = &timeout_spec; - } else if (deadline == absl::InfiniteFuture()) { - // nullptr to ppoll() to block forever. - tmo_p = nullptr; - } else { - // Wait only for as much time as we have before the deadline is exceeded. - absl::Duration remaining_time = deadline - absl::Now(); - if (remaining_time < absl::ZeroDuration()) { - // Note: we likely have already bailed before getting here with a negative - // duration. - return DeadlineExceededErrorBuilder(IREE_LOC); - } - timeout_spec = absl::ToTimespec(remaining_time); - tmo_p = &timeout_spec; - } - return Syscall(::ppoll, poll_fds.data(), poll_fds.size(), tmo_p, nullptr); -} - -#elif defined(IREE_HAS_POLL) - -// poll(), present pretty much everywhere. -// Documentation: https://linux.die.net/man/2/poll -StatusOr<int> SystemPoll(absl::Span<pollfd> poll_fds, absl::Time deadline) { - int timeout; - if (deadline == absl::InfinitePast()) { - // Don't block. - timeout = 0; - } else if (deadline == absl::InfiniteFuture()) { - // Block forever. - timeout = -1; - } else { - absl::Duration remaining_time = deadline - absl::Now(); - if (remaining_time < absl::ZeroDuration()) { - return DeadlineExceededErrorBuilder(IREE_LOC); - } - timeout = static_cast<int>(absl::ToInt64Milliseconds(remaining_time)); - } - return Syscall(::poll, poll_fds.data(), poll_fds.size(), timeout); -} - -#else -#error "No SystemPoll implementation" -#endif // IREE_HAS_PPOLL / IREE_HAS_POLL / etc - -// Builds the list of pollfds to for ppoll wait on and will perform any -// required wait handle callbacks. -// -// The provided deadline will be observed if any of the wait handles needs to -// block for acquiring an fd. -StatusOr<absl::FixedArray<pollfd>> AcquireWaitHandles( - WaitHandle::WaitHandleSpan wait_handles, absl::Time deadline) { - absl::FixedArray<pollfd> poll_fds{wait_handles.size()}; - for (int i = 0; i < wait_handles.size(); ++i) { - poll_fds[i].events = POLLIN | POLLPRI | POLLERR | POLLHUP | POLLNVAL; - poll_fds[i].revents = 0; - // NOTE: poll will ignore any negative fds and our kInvalidFd == -1 so we - // can still put them in the list and it'll just skip them. - if (!wait_handles[i] || !wait_handles[i]->object()) { - poll_fds[i].fd = kInvalidFd; - continue; - } - - // Acquire the file descriptor for waiting. - // This may block (if |deadline| allows it) if the fd is not yet available. - // This is like a pre-wait for the actual poll operation. It can be bad with - // WaitAny, though we could handle that better here. - ASSIGN_OR_RETURN(auto fd_info, - wait_handles[i]->object()->AcquireFdForWait(deadline)); - poll_fds[i].fd = fd_info.second; - - // Abort if deadline exceeded. - if (deadline != absl::InfinitePast() && deadline < absl::Now()) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded acquiring for fds"; - } - } - return poll_fds; -} - -Status ClearFd(WaitableObject::FdType fd_type, int fd) { - // Read in a loop until the read would block. - // Depending on how the users setup the fd the act of reading may reset the - // entire handle (such as with the default eventfd mode) or multiple reads - // may be required (such as with semaphores). - while (true) { -#if defined(IREE_HAS_EVENTFD) - eventfd_t val = 0; - int rv = ::eventfd_read(fd, &val); -#elif defined(IREE_HAS_PIPE) - char buf; - int rv = ::read(fd, &buf, 1); -#else - return UnimplementedErrorBuilder(IREE_LOC) << "fd_type cannot be cleared"; -#endif // IREE_HAS_EVENTFD - if (rv != -1) { - // Success! Keep going. - continue; - } else { - if (errno == EWOULDBLOCK) { - // The read would have blocked meaning that we've hit the end and - // successfully cleared the fd. - return OkStatus(); - } else if (errno == EINTR) { - // Retry. - continue; - } else { - return ErrnoToCanonicalStatus(errno, "ClearFd failed"); - } - } - } -} - -// Performs a single poll on multiple fds and returns information about the -// signaled fds, if any. -Status MultiPoll(WaitHandle::WaitHandleSpan wait_handles, - absl::Span<pollfd> poll_fds, absl::Time deadline, - int* out_any_signaled_index, int* out_unsignaled_count) { - *out_any_signaled_index = -1; - *out_unsignaled_count = 0; - - // poll has a nasty behavior where it allows -1 for fds... except for at [0]. - // To keep the rest of the code sane we correct for that here as epoll doesn't - // have that behavior and we may want to special case this later. - bool any_valid_fds = true; - int swapped_zero_index = -1; - if (poll_fds[0].fd < 0) { - // Find a valid handle. - for (int i = 1; i < poll_fds.size(); ++i) { - if (poll_fds[i].fd > 0) { - swapped_zero_index = i; - std::swap(poll_fds[0], poll_fds[i]); - break; - } - } - if (swapped_zero_index == -1) { - // No valid handles found, meaning that all handles are invalid. - // We'll skip the wait below so we can share the processing code for any - // fds that may be kSignaledFd. - any_valid_fds = false; - } - } - - // Pass handles to ppoll. - // http://man7.org/linux/man-pages/man2/poll.2.html - if (any_valid_fds) { - ASSIGN_OR_RETURN(int rv, SystemPoll(poll_fds, deadline)); - if (rv == 0) { - // Call timed out and no descriptors were ready. - // If this was just a poll then that's fine. - return DeadlineExceededErrorBuilder(IREE_LOC); - } - } - - // If we had swapped fds[0] above we need to correct for that now. - if (swapped_zero_index != -1) { - std::swap(poll_fds[0], poll_fds[swapped_zero_index]); - } - - // |rv| denotes the number of fds that were ready. Run through the list and - // find the ones that were ready and mark them as completed. - for (int i = 0; i < poll_fds.size(); ++i) { - if (poll_fds[i].fd == kSignaledFd || poll_fds[i].revents == POLLIN) { - // First attempt any resolve actions. If these fail we can't consider the - // fd as having been signaled. - ASSIGN_OR_RETURN( - bool resolved, - wait_handles[i]->object()->TryResolveWakeOnFd(poll_fds[i].fd)); - if (!resolved) { - ++(*out_unsignaled_count); - continue; - } - - // Successful wait. Kill the fd so it is ignored on the next poll. - poll_fds[i].fd = kInvalidFd; - *out_any_signaled_index = i; - } else if (poll_fds[i].revents) { - if (poll_fds[i].revents & POLLERR) { - return InternalErrorBuilder(IREE_LOC); - } else if (poll_fds[i].revents & POLLHUP) { - return CancelledErrorBuilder(IREE_LOC); - } else if (poll_fds[i].revents & POLLNVAL) { - return InvalidArgumentErrorBuilder(IREE_LOC); - } else { - return UnknownErrorBuilder(IREE_LOC); - } - } else if (poll_fds[i].fd != kInvalidFd) { - ++(*out_unsignaled_count); - } - } - - return OkStatus(); -} - -} // namespace - -// static -std::atomic<uint64_t> WaitHandle::next_unique_id_{1}; - -// static -WaitHandle WaitHandle::AlwaysSignaling() { - class AlwaysSignalingObject : public WaitableObject { - public: - std::string DebugString() const override { return "signal"; } - StatusOr<std::pair<FdType, int>> AcquireFdForWait( - absl::Time deadline) override { - return std::make_pair(FdType::kPermanent, kSignaledFd); - } - StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; } - }; - static auto* obj = new AlwaysSignalingObject(); - return WaitHandle(add_ref(obj)); -} - -// static -WaitHandle WaitHandle::AlwaysFailing() { - class AlwaysFailingObject : public WaitableObject { - public: - std::string DebugString() const override { return "fail"; } - StatusOr<std::pair<FdType, int>> AcquireFdForWait( - absl::Time deadline) override { - return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject"; - } - StatusOr<bool> TryResolveWakeOnFd(int fd) override { - return InternalErrorBuilder(IREE_LOC) << "AlwaysFailingObject"; - } - }; - static auto* obj = new AlwaysFailingObject(); - return WaitHandle(add_ref(obj)); -} - -// static -Status WaitHandle::WaitAll(WaitHandleSpan wait_handles, absl::Time deadline) { - if (wait_handles.empty()) return OkStatus(); - - // Build the list of pollfds to wait on. - ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline)); - - // Loop until all handles have been signaled or the deadline is exceeded. - int unsignaled_count = 0; - do { - int any_signaled_index = 0; - RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline, - &any_signaled_index, &unsignaled_count)); - } while (unsignaled_count > 0 && absl::Now() < deadline); - - if (unsignaled_count == 0) { - // All waits resolved. - return OkStatus(); - } else { - // One or more were unsignaled. - return DeadlineExceededErrorBuilder(IREE_LOC); - } -} - -// static -StatusOr<bool> WaitHandle::TryWaitAll(WaitHandleSpan wait_handles) { - auto status = WaitAll(wait_handles, absl::InfinitePast()); - if (status.ok()) { - return true; - } else if (IsDeadlineExceeded(status)) { - return false; - } - return status; -} - -// static -StatusOr<int> WaitHandle::WaitAny(WaitHandleSpan wait_handles, - absl::Time deadline) { - if (wait_handles.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "At least one wait handle is required for WaitAny"; - } - - // Build the list of pollfds to wait on. - ASSIGN_OR_RETURN(auto poll_fds, AcquireWaitHandles(wait_handles, deadline)); - - // Poll once; this makes a WaitAny just a WaitMulti that doesn't loop. - int any_signaled_index = -1; - int unsignaled_count = 0; - RETURN_IF_ERROR(MultiPoll(wait_handles, absl::MakeSpan(poll_fds), deadline, - &any_signaled_index, &unsignaled_count)); - if (any_signaled_index == -1) { - // No wait handles were valid. Pretend 0 was signaled. - return 0; - } - return any_signaled_index; -} - -// static -StatusOr<int> WaitHandle::TryWaitAny(WaitHandleSpan wait_handles) { - auto status_or = WaitAny(wait_handles, absl::InfinitePast()); - return IsDeadlineExceeded(status_or.status()) ? -1 : status_or; -} - -// Storage for static class variables; these won't be needed when we can use -// c++17 everywhere. -constexpr int WaitableObject::kInvalidFd; -constexpr int WaitableObject::kSignaledFd; - -WaitHandle::WaitHandle(ref_ptr<WaitableObject> object) - : unique_id_(++next_unique_id_), object_(std::move(object)) {} - -WaitHandle::~WaitHandle() { Dispose(); } - -void WaitHandle::Dispose() { object_.reset(); } - -WaitHandle::WaitHandle(WaitHandle&& other) - : unique_id_(other.unique_id_), object_(std::move(other.object_)) { - other.unique_id_ = 0; -} - -WaitHandle& WaitHandle::operator=(WaitHandle&& other) { - if (this != std::addressof(other)) { - // Close current handle. - Dispose(); - - // Take ownership of handle and resources. - object_ = std::move(other.object_); - - other.unique_id_ = ++next_unique_id_; - } - return *this; -} - -std::string WaitHandle::DebugString() const { - return object_ ? object_->DebugString() : absl::StrCat("wh_", unique_id_); -} - -StatusOr<bool> WaitHandle::TryWait() { - auto status = WaitAll({this}, absl::InfinitePast()); - if (status.ok()) { - return true; - } else if (IsDeadlineExceeded(status)) { - return false; - } - return status; -} - -ManualResetEvent::ManualResetEvent(const char* debug_name) - : debug_name_(debug_name) { - Initialize(); -} - -ManualResetEvent::~ManualResetEvent() { Dispose(); } - -void ManualResetEvent::Initialize() { -#if defined(IREE_HAS_EVENTFD) - // Create with an eventfd by default when we support it. - // eventfd has lower overhead than pipes (the syscalls are cheap). - // This usually will only fail if the system is completely out of handles. - // - // Docs: http://man7.org/linux/man-pages/man2/eventfd.2.html - fd_type_ = FdType::kEventFd; - fd_ = Syscall(::eventfd, 0, EFD_CLOEXEC | EFD_NONBLOCK).ValueOrDie(); -#elif defined(IREE_HAS_PIPE) - // Android/Linux/iOS-compatible POSIX pipe handle. - // Two handles are generated: one for transmitting and one for receiving. - // - // Docs: http://man7.org/linux/man-pages/man2/pipe.2.html - fd_type_ = FdType::kPipe; - int pipefd[2]; - Syscall(::pipe, pipefd).ValueOrDie(); - Syscall(::fcntl, pipefd[0], F_SETFL, O_NONBLOCK).ValueOrDie(); - fd_ = pipefd[0]; - write_fd_ = pipefd[1]; -#else -// NOTE: sync_file does not use Notifier as they come from the kernel. -#error "No fd-based sync primitive on this platform" -#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE / etc -} - -void ManualResetEvent::Dispose() { - if (fd_ != kInvalidFd) { - // Always signal, as we need to ensure waiters are woken. - CHECK_OK(Set()); - Syscall(::close, fd_).ValueOrDie(); - fd_ = kInvalidFd; - } - if (write_fd_ != kInvalidFd) { - Syscall(::close, write_fd_).ValueOrDie(); - write_fd_ = kInvalidFd; - } -} - -ManualResetEvent::ManualResetEvent(ManualResetEvent&& other) - : fd_type_(other.fd_type_), - fd_(other.fd_), - write_fd_(other.write_fd_), - debug_name_(other.debug_name_) { - other.fd_type_ = FdType::kPermanent; - other.fd_ = kInvalidFd; - other.write_fd_ = kInvalidFd; - other.debug_name_ = nullptr; -} - -ManualResetEvent& ManualResetEvent::operator=(ManualResetEvent&& other) { - if (this != std::addressof(other)) { - Dispose(); - fd_type_ = other.fd_type_; - fd_ = other.fd_; - write_fd_ = other.write_fd_; - debug_name_ = other.debug_name_; - other.fd_type_ = FdType::kPermanent; - other.fd_ = kInvalidFd; - other.write_fd_ = kInvalidFd; - other.debug_name_ = nullptr; - other.Initialize(); - } - return *this; -} - -std::string ManualResetEvent::DebugString() const { - if (debug_name_) { - return debug_name_; - } -#if defined(IREE_HAS_EVENTFD) - return absl::StrCat("eventfd_", fd_); -#elif defined(IREE_HAS_PIPE) - return absl::StrCat("pipe_", fd_, "_", write_fd_); -#else - return absl::StrCat("unknown_", fd_, "_", write_fd_); -#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE -} - -Status ManualResetEvent::Set() { -#if defined(IREE_HAS_EVENTFD) - return Syscall(::eventfd_write, fd_, 1ull).status(); -#elif defined(IREE_HAS_PIPE) - char buf = '\n'; - return Syscall(::write, write_fd_, &buf, 1).status(); -#else - return UnimplementedErrorBuilder(IREE_LOC) - << "No fd-based sync primitive on this platform"; -#endif // IREE_HAS_EVENTFD / IREE_HAS_PIPE -} - -Status ManualResetEvent::Reset() { return ClearFd(fd_type_, fd_); } - -WaitHandle ManualResetEvent::OnSet() { return WaitHandle(add_ref(this)); } - -} // namespace iree
diff --git a/iree/base/wait_handle.h b/iree/base/wait_handle.h deleted file mode 100644 index d051e51..0000000 --- a/iree/base/wait_handle.h +++ /dev/null
@@ -1,321 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BASE_WAIT_HANDLE_H_ -#define IREE_BASE_WAIT_HANDLE_H_ - -#include <atomic> -#include <cstdint> -#include <string> -#include <utility> - -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/base/time.h" - -namespace iree { - -// Interfaces for waitable objects that can produce WaitHandles. -// WaitableObjects are much like ::thread::Selectable, only they support both -// the classic locking style as well as file descriptors for use with select(). -// -// Usage: -// class MyWaitableObject : public WaitableObject { -// public: -// std::string DebugString() const override { return "something useful"; } -// WaitHandle OnAsyncTask() { -// return WaitHandle(retain_ref(this)); -// } -// private: -// StatusOr<std::pair<FdType, int>> AcquireFdForWait( -// absl::Time deadline) override { -// // If blocking traditionally do so now and then return this: -// return std::make_pair(FdType::kPermanent, kSignaledFd); -// // Otherwise, see ManualResetEvent for an example using fds. -// } -// StatusOr<bool> TryResolveWakeOnFd(int fd) override { -// // Return true iff the object is really acquired, such as the semaphore -// // being decremented. -// return true; -// } -// }; -class WaitableObject : public RefObject<WaitableObject> { - public: - // Indicates that a file descriptor is invalid. It will not block when waited - // upon. - constexpr static int kInvalidFd = -1; - // Indicates that a file descriptor should be treated as signaled. - // Waiting on this fd should return as if it has already been signaled. - constexpr static int kSignaledFd = -2; - - // Defines the type of the native handle used for synchronization. - enum class FdType : uint16_t { - // Event has no handle and should be treated as permanently signaled. - kPermanent, - - // Android/Linux/iOS-compatible POSIX pipe handle. - // Two handles are generated: one for transmitting and one for receiving. - // - // More information: - // http://man7.org/linux/man-pages/man2/pipe.2.html - kPipe, - - // Android/Linux eventfd handle. - // These are akin to pipe() but require only a single handle and have - // significantly lower overhead (equivalent if not slightly better than - // pthreads condvars). - // - // eventfds support acting as both semaphores and auto reset events. - // - // More information: - // http://man7.org/linux/man-pages/man2/eventfd.2.html - kEventFd, - - // Android/Linux sync_file handle (aka 'sync fence'). - // The handle is allocated indirectly by the device driver via the - // <linux/sync_file.h> API. It may be waited upon with poll(), select(), or - // epoll() and must be closed with close() when no longer required. If - // waiting on multiple sync_files the caller should first merge them - // together. - // - // A sync_file must only be used as fences (one-shot manual reset events). - // - // More information: - // https://www.kernel.org/doc/Documentation/sync_file.txt - // https://lwn.net/Articles/702339/ - // https://source.android.com/devices/graphics/implement-vsync#explicit_synchronization - kSyncFile, - }; - - virtual ~WaitableObject() = default; - - // Returns a string representing the object, either specified as a debug_name - // or a unique ID. - virtual std::string DebugString() const = 0; - - // Attempts to acquire a file descriptor for the waitable objects by the given - // |deadline|. In many cases this will return immediately with a valid fd. - // - // In cases where the file descriptor may not be available the call may block - // until either it is available or the |deadline| has elapsed. Use - // absl::InfinitePast() to prevent blocking. - // - // Returns a valid file descriptor or kInvalidFd as an indication that the - // object should not be waited on (already signaled, etc). Can return - // kSignaledFd to indicate that it's already known that the handle has been - // signaled and the caller should resolve as if it caused a wake normally. - virtual StatusOr<std::pair<FdType, int>> AcquireFdForWait( - absl::Time deadline) = 0; - - // Tries to resolve the object with the given |fd|. - // In many cases this will no-op, however some types may require additional - // checks to ensure that the wait operation succeeded (such as semaphores - // that may need to query a count). If resolution fails the waitable object - // must not be considered signaled. This call will never block. - virtual StatusOr<bool> TryResolveWakeOnFd(int fd) = 0; -}; - -// Handle to waitable objects. -// WaitHandles are created by a particular synchronization primitive, such as -// Fence, as a way for one or more observers to poll or wait for notification. -// -// External synchronization primitives can be wrapped in WaitHandles to enable -// other libraries or languages to be waited on alongside WaitHandles created -// by the IREE primitives like Fence. See the notes on WaitHandleType for a list -// of handle types that are supported. -// -// Wait handles are thread-safe in that multiple threads may be waiting on them -// concurrently. -class WaitHandle { - public: - // Returns a WaitHandle that when waited on will never block. - static WaitHandle AlwaysSignaling(); - - // Returns a WaitHandle that when waited on will always fail. - static WaitHandle AlwaysFailing(); - - using WaitHandleSpan = absl::Span<WaitHandle* const>; - - // Blocks the caller until all passed |wait_handles| are signaled or the - // |deadline| elapses. - // - // Returns success if the wait is successful and all events have been - // signaled. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all handles - // having been signaled. Note that a subset of the |wait_handles| may have - // been signaled and each can be queried to see which one. - static Status WaitAll(WaitHandleSpan wait_handles, absl::Time deadline); - static Status WaitAll(WaitHandleSpan wait_handles, absl::Duration timeout) { - return WaitAll(wait_handles, RelativeTimeoutToDeadline(timeout)); - } - static Status WaitAll(WaitHandleSpan wait_handles) { - return WaitAll(wait_handles, absl::InfiniteFuture()); - } - - // Tries waiting on the handles and returns immediately if it would have - // blocked. The caller will not be blocked even if a handle has not yet been - // signaled. - // - // Returns true if all handles have been signaled. - static StatusOr<bool> TryWaitAll(WaitHandleSpan wait_handles); - - // Blocks the caller until at least one of the |wait_handles| is signaled or - // the |deadline| elapses. - // - // Returns the index into |wait_handles| of a handle that was signaled. Note - // that more than one handle may have been signaled and all of the other - // |wait_handles| should be queried or waited on again until waits for them - // succeed. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any handles - // having been signaled. - static StatusOr<int> WaitAny(WaitHandleSpan wait_handles, - absl::Time deadline); - static StatusOr<int> WaitAny(WaitHandleSpan wait_handles, - absl::Duration timeout) { - return WaitAny(wait_handles, RelativeTimeoutToDeadline(timeout)); - } - static StatusOr<int> WaitAny(WaitHandleSpan wait_handles) { - return WaitAny(wait_handles, absl::InfiniteFuture()); - } - - // Tries waiting for at least one handle to complete and returns immediately - // if none have been. The caller will not be blocked even if a handle has not - // yet been signaled. - // - // Returns the index into |wait_handles| of a handle that was signaled. Note - // that more than one handle may have been signaled and all of the other - // |wait_handles| should be queried or waited on again until waits for them - // succeed. - // - // Returns -1 if no handles were signaled. - static StatusOr<int> TryWaitAny(WaitHandleSpan wait_handles); - - // Default constructor creates a permanently signaled handle. - // Waiting on this handle will never block. - WaitHandle() = default; - - // Wraps an existing sync file descriptor. - // Ownership of the file descriptor is transferred to the WaitHandle and must - // be duplicated by the caller if they want to continue using it. - explicit WaitHandle(ref_ptr<WaitableObject> object); - - ~WaitHandle(); - - // Copying not supported. Create a new WaitHandle from the source. - WaitHandle(const WaitHandle&) = delete; - WaitHandle& operator=(const WaitHandle&) = delete; - - // Moving supported; sync primitive ownership is transferred. - WaitHandle(WaitHandle&& other); - WaitHandle& operator=(WaitHandle&& other); - - // Unique ID for the WaitHandle instance. - // Two wait handles, even if waiting on the same underlying primitive, will - // have differing unique_ids. This can be used for deduping the handles or - // storing handles in a map. - uint64_t unique_id() const { return unique_id_; } - - // Returns a unique string representing the handle. - std::string DebugString() const; - - // Blocks the caller until the handle is signaled or the |deadline| elapses. - // - // If waiting on multiple wait handles use WaitAll or WaitAny instead of - // multiple calls to Wait as they can significantly reduce overhead. - // - // Returns success if the wait is successful and the |wait_handle| was - // signaled. Returns DEADLINE_EXCEEDED if the timeout elapses without the - // handle having been signaled. - Status Wait(absl::Time deadline) { return WaitAll({this}, deadline); } - Status Wait(absl::Duration timeout) { - return WaitAll({this}, RelativeTimeoutToDeadline(timeout)); - } - Status Wait() { return WaitAll({this}, absl::InfiniteFuture()); } - - // Tries waiting on the handle and returns immediately if it would have - // waited. The caller will not be blocked even if the handle has not yet been - // signaled. - // - // Returns true if the handle has been signaled. - StatusOr<bool> TryWait(); - - // These accessors should generally be considered opaque but may be useful to - // code trying to interop with other runtimes. - const ref_ptr<WaitableObject>& object() const { return object_; } - - private: - // Disposes the handle by closing the fd and issuing callbacks. - void Dispose(); - - static std::atomic<uint64_t> next_unique_id_; - - uint64_t unique_id_ = 0; - ref_ptr<WaitableObject> object_; -}; - -// A manually-resettable event primitive. -// Effectively a binary semaphore with a maximum_count of 1 when running in -// auto-reset mode but also provides a sticky manual reset mode. -class ManualResetEvent : public WaitableObject { - public: - explicit ManualResetEvent(const char* debug_name = nullptr); - - ~ManualResetEvent() override; - - // Copying not supported. - ManualResetEvent(const ManualResetEvent&) = delete; - ManualResetEvent& operator=(const ManualResetEvent&) = delete; - - // Moving supported; sync primitive ownership is transferred. - ManualResetEvent(ManualResetEvent&& other); - ManualResetEvent& operator=(ManualResetEvent&& other); - - std::string DebugString() const override; - - // Sets the specified event object to the signaled state. - // The event stays signaled until Reset is called. Multiple waiters will be - // woken. - Status Set(); - - // Resets the specified event object to the nonsignaled state. - // Resetting an event that is already reset has no effect. - Status Reset(); - - // Returns a WaitHandle that will be signaled when the event is set. - WaitHandle OnSet(); - - protected: - void Initialize(); - void Dispose(); - - StatusOr<std::pair<FdType, int>> AcquireFdForWait( - absl::Time deadline) override { - return std::make_pair(fd_type_, fd_); - } - StatusOr<bool> TryResolveWakeOnFd(int fd) override { return true; } - - FdType fd_type_ = FdType::kPermanent; - int fd_ = kInvalidFd; - int write_fd_ = kInvalidFd; // Used only for fd_type_ == kPipe. - const char* debug_name_ = nullptr; -}; - -} // namespace iree - -#endif // IREE_BASE_WAIT_HANDLE_H_
diff --git a/iree/base/wait_handle_test.cc b/iree/base/wait_handle_test.cc deleted file mode 100644 index 600a63f..0000000 --- a/iree/base/wait_handle_test.cc +++ /dev/null
@@ -1,555 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/base/wait_handle.h" - -#include <unistd.h> - -#include <string> -#include <thread> // NOLINT -#include <type_traits> - -#include "absl/time/time.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" - -// StatusOr<bool> will be true if the status is ok, which is bad. -#define ASSERT_STATUSOR_TRUE(x) ASSERT_TRUE(x.ValueOrDie()) -#define ASSERT_STATUSOR_FALSE(x) ASSERT_FALSE(x.ValueOrDie()) - -namespace iree { -namespace { - -using ::testing::_; -using ::testing::Return; - -// Tests the AlwaysSignaling helper. -TEST(WaitHandleTest, AlwaysSignaling) { - ASSERT_OK(WaitHandle::AlwaysSignaling().Wait()); - EXPECT_FALSE(WaitHandle::AlwaysSignaling().DebugString().empty()); -} - -// Tests the AlwaysFailing helper. -TEST(WaitHandleTest, AlwaysFailing) { - ASSERT_FALSE(WaitHandle::AlwaysFailing().Wait().ok()); - EXPECT_FALSE(WaitHandle::AlwaysFailing().DebugString().empty()); -} - -// Tests the basic lifecycle of a permanently signaled wait handle. -TEST(WaitHandleTest, LifecyclePermanentSignaled) { - // Just to be sure it's ok to safely no-op a WaitHandle value. - WaitHandle wh_never_used; - (void)wh_never_used; - - // Try waiting; should return immediately. - WaitHandle wh0; - ASSERT_OK(wh0.Wait()); - - // Waits on multiple permanent handles should be ok. - WaitHandle wh1; - ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1})); -} - -// Tests moving permanent WaitHandles around. -TEST(WaitHandleTest, MovePermanent) { - WaitHandle wh0; - WaitHandle wh1{std::move(wh0)}; - WaitHandle wh2 = std::move(wh1); - wh1 = std::move(wh2); -} - -// Tests moving around real handles (that may require closing). -TEST(WaitHandleTest, MoveRealHandle) { - ManualResetEvent fence0; - WaitHandle wh0 = fence0.OnSet(); - WaitHandle wh1{std::move(wh0)}; - WaitHandle wh2 = std::move(wh1); - wh1 = std::move(wh2); - - // Now overwrite the handle value to force a close. - ManualResetEvent fence1; - WaitHandle wh3 = fence1.OnSet(); - wh1 = std::move(wh3); - wh1 = WaitHandle(); // Ensure handle dies first. -} - -// Tests the various forms of waiting on a single WaitHandle. -// Since these just call WaitAll we leave the involved testing to those. -TEST(WaitHandleTest, SingleWait) { - WaitHandle wh; - ASSERT_OK(wh.Wait()); - ASSERT_OK(wh.Wait(absl::Now() + absl::Seconds(1))); - ASSERT_OK(wh.Wait(absl::Seconds(1))); - ASSERT_STATUSOR_TRUE(wh.TryWait()); -} - -// Tests using WaitAll with no valid handles. This should no-op. -TEST(WaitHandleTest, WaitAllNop) { - ASSERT_OK(WaitHandle::WaitAll({})); - ASSERT_OK(WaitHandle::WaitAll({nullptr})); - ASSERT_OK(WaitHandle::WaitAll({nullptr, nullptr})); -} - -// Tests polling with WaitAll with multiple wait handles. -TEST(WaitHandleTest, WaitAllPoll) { - ManualResetEvent fence0; - WaitHandle wh0 = fence0.OnSet(); - ManualResetEvent fence1; - WaitHandle wh1 = fence1.OnSet(); - - // Poll; should return immediately with timeout. - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()))); - - // Notify fence1. - ASSERT_OK(fence1.Set()); - - // Poll; should return immediately with timeout as fence1 is not signaled. - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()))); - - // Notify fence0. - ASSERT_OK(fence0.Set()); - - // Poll again; should return immediately with success. - ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast())); -} - -// Tests waiting when the first file handle is invalid. This is to verify a -// workaround for bad poll() behavior with fds[0] == -1. -TEST(WaitHandleTest, WaitAllWithInvalid0) { - ManualResetEvent fence; - WaitHandle wh = fence.OnSet(); - - // Poll; should return immediately with timeout as fence is not signaled. - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast()))); - - // Notify fence. - ASSERT_OK(fence.Set()); - - // Poll again; should return immediately with success. - ASSERT_OK(WaitHandle::WaitAll({nullptr, &wh}, absl::InfinitePast())); -} - -// Tests exceeding the timeout deadline with WaitAll. -TEST(WaitHandleTest, WaitAllTimeout) { - ManualResetEvent fence; - WaitHandle wh = fence.OnSet(); - - // Wait with timeout on the unsignaled fence: - // Via polling (should never block): - ASSERT_TRUE( - IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::InfinitePast()))); - ASSERT_STATUSOR_FALSE(WaitHandle::TryWaitAll({&wh})); - // Via time in the near future (should block): - ASSERT_TRUE( - IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250)))); - // Via time in the past, should exceed deadline. - ASSERT_TRUE( - IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250)))); - - // Notify and ensure no more timeouts. - ASSERT_OK(fence.Set()); - ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::InfinitePast())); - ASSERT_STATUSOR_TRUE(WaitHandle::TryWaitAll({&wh})); - ASSERT_OK(WaitHandle::WaitAll({&wh}, absl::Milliseconds(250))); - - // Via time in the past, should exceed deadline even if signaled. - ASSERT_TRUE( - IsDeadlineExceeded(WaitHandle::WaitAll({&wh}, absl::Milliseconds(-250)))); -} - -// Tests using WaitAll to wait on other threads. -TEST(WaitHandleTest, WaitAllThreaded) { - // Spin up two threads. - ManualResetEvent fence0; - std::thread t0{[&]() { - ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); - ASSERT_OK(fence0.Set()); - }}; - ManualResetEvent fence1; - std::thread t1{[&]() { - ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); - ASSERT_OK(fence1.Set()); - }}; - - // Wait on both threads to complete. - WaitHandle wh0 = fence0.OnSet(); - WaitHandle wh1 = fence1.OnSet(); - ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1})); - - t0.join(); - t1.join(); -} - -// Tests using WaitAll with multiple wait handles from the same fence. -TEST(WaitHandleTest, WaitAllSameSource) { - ManualResetEvent fence; - WaitHandle wh0 = fence.OnSet(); - WaitHandle wh1 = fence.OnSet(); - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAll({&wh0, &wh1}, absl::InfinitePast()))); - ASSERT_OK(fence.Set()); - ASSERT_OK(WaitHandle::WaitAll({&wh0, &wh1})); -} - -// Tests using WaitAll with literally the same wait handles. -TEST(WaitHandleTest, WaitAllSameHandle) { - ManualResetEvent fence; - WaitHandle wh = fence.OnSet(); - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAll({&wh, &wh}, absl::InfinitePast()))); - ASSERT_OK(fence.Set()); - ASSERT_OK(WaitHandle::WaitAll({&wh, &wh})); -} - -// Tests WaitAll when a wait handle fails. -TEST(WaitHandleTest, WaitAllFailure) { - WaitHandle good_wh; - // Create a purposefully bad handle to induce an error. - WaitHandle bad_wh = WaitHandle::AlwaysFailing(); - // Should fail with some posixy error. - ASSERT_FALSE(WaitHandle::WaitAll({&good_wh, &bad_wh}).ok()); -} - -// Tests using WaitAny with no valid handles. This should no-op. -TEST(WaitHandleTest, WaitAnyNop) { - ASSERT_TRUE(IsInvalidArgument(WaitHandle::WaitAny({}).status())); - ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({nullptr})); - ASSERT_EQ(0, index); - ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({nullptr, nullptr})); - ASSERT_EQ(0, index); -} - -// Tests polling with WaitAny with multiple wait handles. -TEST(WaitHandleTest, WaitAnyPoll) { - ManualResetEvent fence0; - WaitHandle wh0 = fence0.OnSet(); - ManualResetEvent fence1; - WaitHandle wh1 = fence1.OnSet(); - - // Poll; should return immediately with timeout. - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status())); - - // Notify fence1. - ASSERT_OK(fence1.Set()); - - // Poll; should return immediately with fence1 signaled. - ASSERT_OK_AND_ASSIGN(int index, - WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast())); - EXPECT_EQ(1, index); - - // Notify fence0. - ASSERT_OK(fence0.Set()); - - // Poll again; should return immediately; which one is signaled is undefined. - ASSERT_OK_AND_ASSIGN(index, - WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast())); - ASSERT_TRUE(index == 0 || index == 1); -} - -// Tests exceeding the timeout deadline with WaitAny. -TEST(WaitHandleTest, WaitAnyTimeout) { - ManualResetEvent fence0; - WaitHandle wh0 = fence0.OnSet(); - ManualResetEvent fence1; - WaitHandle wh1 = fence1.OnSet(); - - // Wait with timeout on the unsignaled fences: - // Via polling (should never block): - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status())); - ASSERT_OK_AND_ASSIGN(int index, WaitHandle::TryWaitAny({&wh0, &wh1})); - ASSERT_EQ(-1, index); - // Via time in the near future (should block): - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250)).status())); - - // Notify one of the fences. Should return immediately. - ASSERT_OK(fence1.Set()); - ASSERT_OK_AND_ASSIGN(index, - WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast())); - ASSERT_EQ(1, index); - ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0, &wh1})); - ASSERT_EQ(1, index); - ASSERT_OK_AND_ASSIGN( - index, WaitHandle::WaitAny({&wh0, &wh1}, absl::Milliseconds(250))); - ASSERT_EQ(1, index); - - // The unnotified fence should still timeout. - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh0}, absl::InfinitePast()).status())); - ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0})); - ASSERT_EQ(-1, index); - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250)).status())); - - // Notify last fence and ensure complete. - ASSERT_OK(fence0.Set()); - ASSERT_OK_AND_ASSIGN(index, - WaitHandle::WaitAny({&wh0}, absl::InfinitePast())); - ASSERT_EQ(0, index); - ASSERT_OK_AND_ASSIGN(index, WaitHandle::TryWaitAny({&wh0})); - ASSERT_EQ(0, index); - ASSERT_OK_AND_ASSIGN(index, - WaitHandle::WaitAny({&wh0}, absl::Milliseconds(250))); - ASSERT_EQ(0, index); -} - -// Tests using WaitAny to wait on other threads. -TEST(WaitHandleTest, WaitAnyThreaded) { - // Spin up two threads. - // t1 will wait on t0 such that they will act in sequence. - ManualResetEvent fence0; - std::thread t0{[&]() { - ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); - ASSERT_OK(fence0.Set()); - }}; - ManualResetEvent fence1; - std::thread t1{[&]() { - ASSERT_OK(fence0.OnSet().Wait()); - ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(250))); - ASSERT_OK(fence1.Set()); - }}; - - // Wait on both threads. We expect 0 to complete first. - WaitHandle wh0 = fence0.OnSet(); - WaitHandle wh1 = fence1.OnSet(); - ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1})); - ASSERT_EQ(0, index); - - // Now wait for thread 1. - ASSERT_OK_AND_ASSIGN(index, WaitHandle::WaitAny({&wh1})); - ASSERT_EQ(0, index); - - t0.join(); - t1.join(); -} - -// Tests using WaitAny with multiple wait handles from the same fence. -TEST(WaitHandleTest, WaitAnySameSource) { - ManualResetEvent fence; - WaitHandle wh0 = fence.OnSet(); - WaitHandle wh1 = fence.OnSet(); - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh0, &wh1}, absl::InfinitePast()).status())); - ASSERT_OK(fence.Set()); - ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh0, &wh1})); - ASSERT_TRUE(index == 0 || index == 1); -} - -// Tests using WaitAny with literally the same wait handles. -TEST(WaitHandleTest, WaitAnySameHandle) { - ManualResetEvent fence; - WaitHandle wh = fence.OnSet(); - ASSERT_TRUE(IsDeadlineExceeded( - WaitHandle::WaitAny({&wh, &wh}, absl::InfinitePast()).status())); - ASSERT_OK(fence.Set()); - ASSERT_OK_AND_ASSIGN(int index, WaitHandle::WaitAny({&wh, &wh})); - ASSERT_TRUE(index == 0 || index == 1); -} - -// Tests WaitAny when a wait handle fails. -TEST(WaitHandleTest, WaitAnyFailure) { - WaitHandle good_wh; - // Create a purposefully bad handle to induce an error. - WaitHandle bad_wh = WaitHandle::AlwaysFailing(); - // Should fail with some posixy error. - ASSERT_FALSE(WaitHandle::WaitAny({&good_wh, &bad_wh}).ok()); -} - -// ManualResetEvent with innards exposed. Meh. -class ExposedManualResetEvent : public ManualResetEvent { - public: - using ManualResetEvent::AcquireFdForWait; - using ManualResetEvent::TryResolveWakeOnFd; -}; - -// Mock type for the WaitableObject methods. -class MockWaitableObject : public ::testing::StrictMock<WaitableObject> { - public: - MockWaitableObject() : ::testing::StrictMock<WaitableObject>() {} - - MOCK_CONST_METHOD0(DebugString, std::string()); - MOCK_METHOD1(AcquireFdForWait, - StatusOr<std::pair<FdType, int>>(absl::Time deadline)); - MOCK_METHOD1(TryResolveWakeOnFd, StatusOr<bool>(int fd)); - - WaitHandle OnSomething() { return WaitHandle(add_ref(this)); } -}; - -// Tests normal AcquireFdForWait + TryResolveWakeOnFd use. -TEST(WaitableObjectTest, AcquireAndResolve) { - MockWaitableObject mwo; - WaitHandle wh = mwo.OnSomething(); - - // Use a MRE for testing, as we can just use its fd. - ExposedManualResetEvent mre; - - // Try waiting; we should see the AcquireFdForWait and then return because - // the fd has not been resolved. - EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) { - // Return the valid FD from the MRE. - return mre.AcquireFdForWait(deadline); - }); - ASSERT_STATUSOR_FALSE(wh.TryWait()); - - // Signal the MRE. - ASSERT_OK(mre.Set()); - - // Try waiting again; we should get the AcquireFdForWait and then also get - // the TryResolveWakeOnFd. - EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([&](absl::Time deadline) { - // Return the valid (and now signaled) FD from the MRE. - return mre.AcquireFdForWait(deadline); - }); - EXPECT_CALL(mwo, TryResolveWakeOnFd(_)).WillOnce(Return(true)); - ASSERT_STATUSOR_TRUE(wh.TryWait()); -} - -// Tests timing out in AcquireFdForWait. -TEST(WaitableObjectTest, AcquireFdForWaitTimeout) { - ManualResetEvent mre; - WaitHandle always_wait = mre.OnSet(); - WaitHandle always_signal = WaitHandle::AlwaysSignaling(); - MockWaitableObject mwo; - WaitHandle wh = mwo.OnSomething(); - - // Make the AcquireFdForWait take longer than the timeout. We should hit - // deadline exceeded even though always_wait hasn't be signaled. - EXPECT_CALL(mwo, AcquireFdForWait(_)).WillOnce([](absl::Time deadline) { - ::usleep(absl::ToInt64Microseconds(absl::Milliseconds(10))); - return std::make_pair(WaitableObject::FdType::kPermanent, - WaitableObject::kInvalidFd); - }); - ASSERT_TRUE(IsDeadlineExceeded(WaitHandle::WaitAll( - {&wh, &always_signal}, absl::Now() - absl::Milliseconds(250)))); -} - -// Tests TryResolveWakeOnFd when a handle is a permanent kSignaledFd. -TEST(WaitableObjectTest, SignaledFd) { - MockWaitableObject mwo; - WaitHandle wh = mwo.OnSomething(); - - // Return the kSignaledFd handle and expect that we still get our notify call. - // We can do this multiple times. - for (int i = 0; i < 4; ++i) { - EXPECT_CALL(mwo, AcquireFdForWait(_)) - .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent, - WaitableObject::kSignaledFd))); - EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd)) - .WillOnce(Return(true)); - ASSERT_STATUSOR_TRUE(wh.TryWait()); - } -} - -// Tests that waiting will not resolve if TryResolveWakeOnFd returns false. -TEST(WaitableObjectTest, UnresolvedWake) { - MockWaitableObject mwo; - WaitHandle wh = mwo.OnSomething(); - - // Fail to resolve the first time. - // Since we are only trying to wait it should bail. - EXPECT_CALL(mwo, AcquireFdForWait(_)) - .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent, - WaitableObject::kSignaledFd))); - EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd)) - .WillOnce(Return(false)); - ASSERT_STATUSOR_FALSE(wh.TryWait()); - - // Resolve on the next try. - EXPECT_CALL(mwo, AcquireFdForWait(_)) - .WillOnce(Return(std::make_pair(WaitableObject::FdType::kPermanent, - WaitableObject::kSignaledFd))); - EXPECT_CALL(mwo, TryResolveWakeOnFd(WaitableObject::kSignaledFd)) - .WillOnce(Return(true)); - ASSERT_STATUSOR_TRUE(wh.TryWait()); -} - -// Tests the normal lifecycle of a ManualResetEvent. -TEST(ManualResetEventTest, Lifecycle) { - ManualResetEvent ev; - EXPECT_FALSE(ev.DebugString().empty()); - WaitHandle wh0 = ev.OnSet(); - EXPECT_EQ(ev.DebugString(), wh0.DebugString()); - WaitHandle wh1 = ev.OnSet(); - EXPECT_EQ(ev.DebugString(), wh1.DebugString()); - // Should not be set. - ASSERT_STATUSOR_FALSE(wh0.TryWait()); - ASSERT_STATUSOR_FALSE(wh1.TryWait()); - // Set should be sticky. - ASSERT_OK(ev.Set()); - ASSERT_STATUSOR_TRUE(wh0.TryWait()); - ASSERT_STATUSOR_TRUE(wh1.TryWait()); - // Reset should clear. - ASSERT_OK(ev.Reset()); - ASSERT_STATUSOR_FALSE(wh0.TryWait()); - ASSERT_STATUSOR_FALSE(wh1.TryWait()); - // Setting again should enable the previous WaitHandles to be signaled. - ASSERT_OK(ev.Set()); - ASSERT_STATUSOR_TRUE(wh0.TryWait()); - ASSERT_STATUSOR_TRUE(wh1.TryWait()); -} - -// Tests moving ManualResetEvents around. -TEST(ManualResetEventTest, Move) { - ManualResetEvent ev0; - WaitHandle wh = ev0.OnSet(); - ManualResetEvent ev1{std::move(ev0)}; - ManualResetEvent ev2 = std::move(ev1); - ev1 = std::move(ev2); - ASSERT_OK(ev1.Set()); - ASSERT_STATUSOR_TRUE(wh.TryWait()); -} - -// Tests redundantly setting and resetting ManualResetEvents. -TEST(ManualResetEventTest, RedundantUse) { - ManualResetEvent ev; - ASSERT_OK(ev.Reset()); - ASSERT_OK(ev.Reset()); - ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie()); - ASSERT_OK(ev.Set()); - ASSERT_OK(ev.Set()); - ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie()); - ASSERT_OK(ev.Reset()); - ASSERT_FALSE(ev.OnSet().TryWait().ValueOrDie()); -} - -// Tests waiting on an initially-set ManualResetEvent; -TEST(ManualResetEventTest, SetThenWait) { - ManualResetEvent ev; - ASSERT_OK(ev.Set()); - ASSERT_TRUE(ev.OnSet().TryWait().ValueOrDie()); -} - -// Tests that dangling an event will not wake waiters. -// This is intentional (for now); we could with a bit of wrangling make it so -// that WaitableObjects tracked their waiters and ensured they were all cleaned -// up, but that seems hard. Don't drop your objects. -TEST(ManualResetEventTest, NeverSet) { - ManualResetEvent ev; - WaitHandle wh = ev.OnSet(); - ASSERT_STATUSOR_FALSE(wh.TryWait()); - // Kill event to unblock waiters. - ev = ManualResetEvent(); - // Waiter should not have woken. - ASSERT_STATUSOR_FALSE(wh.TryWait()); -} - -} // namespace -} // namespace iree
diff --git a/iree/bindings/python/BUILD b/iree/bindings/python/BUILD deleted file mode 100644 index 0ac8ef1..0000000 --- a/iree/bindings/python/BUILD +++ /dev/null
@@ -1,25 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:build_defs.bzl", "iree_py_library") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_py_library( - name = "pathsetup", - imports = ["."], -)
diff --git a/iree/bindings/python/pyiree/BUILD b/iree/bindings/python/pyiree/BUILD deleted file mode 100644 index 558df05..0000000 --- a/iree/bindings/python/pyiree/BUILD +++ /dev/null
@@ -1,103 +0,0 @@ -# Copyright 2019 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("//iree:build_defs.bzl", "NUMPY_DEPS", "iree_py_extension") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -COMPILER_DEPS = [ - "//iree/compiler/Translation/Sequencer", - "//iree/compiler/Translation/Interpreter", - "//iree/compiler/Translation/SPIRV", -] - -DRIVER_DEPS = [ - "//iree/hal/interpreter:interpreter_driver_module", - "//iree/hal/vulkan:vulkan_driver_module", -] - -iree_py_extension( - name = "binding", - srcs = [ - "binding.cc", - "binding.h", - "compiler.cc", - "compiler.h", - "hal.cc", - "hal.h", - "initialize.cc", - "initialize.h", - "rt.cc", - "rt.h", - "status_utils.cc", - "status_utils.h", - "vm.cc", - "vm.h", - ], - copts = [ - "-fexceptions", - ], - features = ["-use_header_modules"], - deps = [ - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "//iree/base:api", - "//iree/base:init", - "//iree/base:status", - "//iree/hal:api", - "//iree/rt:api", - "//iree/schemas", - "//iree/vm:api", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@iree_pybind11//:pybind11", - ] + COMPILER_DEPS + DRIVER_DEPS, -) - -py_test( - name = "compiler_test", - srcs = ["compiler_test.py"], - python_version = "PY3", - deps = [ - ":binding", - "//iree/bindings/python:pathsetup", - "@absl_py//absl/testing:absltest", - ], -) - -py_test( - name = "hal_test", - srcs = ["hal_test.py"], - python_version = "PY3", - deps = [ - ":binding", - "//iree/bindings/python:pathsetup", - "@absl_py//absl/testing:absltest", - ], -) - -py_test( - name = "runtime_test", - srcs = ["runtime_test.py"], - python_version = "PY3", - deps = NUMPY_DEPS + [ - ":binding", - "@absl_py//absl/testing:absltest", - ], -)
diff --git a/iree/bindings/python/pyiree/binding.cc b/iree/bindings/python/pyiree/binding.cc deleted file mode 100644 index 3fc4b59..0000000 --- a/iree/bindings/python/pyiree/binding.cc +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/binding.h" - -#include "iree/bindings/python/pyiree/compiler.h" -#include "iree/bindings/python/pyiree/hal.h" -#include "iree/bindings/python/pyiree/initialize.h" -#include "iree/bindings/python/pyiree/rt.h" -#include "iree/bindings/python/pyiree/status_utils.h" -#include "iree/bindings/python/pyiree/vm.h" - -namespace iree { -namespace python { - -PYBIND11_MODULE(binding, m) { - m.doc() = "IREE Binding Backend Helpers"; - py::class_<OpaqueBlob, std::shared_ptr<OpaqueBlob>>(m, "OpaqueBlob"); - m.def("initialize_extension", &InitializeExtension); - - auto compiler_m = m.def_submodule("compiler", "IREE compiler support"); - SetupCompilerBindings(compiler_m); - - auto hal_m = m.def_submodule("hal", "IREE HAL support"); - SetupHalBindings(hal_m); - - auto rt_m = m.def_submodule("rt", "IREE RT api"); - SetupRtBindings(rt_m); - - auto vm_m = m.def_submodule("vm", "IREE VM api"); - SetupVmBindings(vm_m); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/binding.h b/iree/bindings/python/pyiree/binding.h deleted file mode 100644 index 47ee323..0000000 --- a/iree/bindings/python/pyiree/binding.h +++ /dev/null
@@ -1,147 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_ - -#include <vector> - -#include "absl/types/optional.h" -#include "iree/base/api.h" -#include "pybind11/pybind11.h" -#include "pybind11/stl.h" - -namespace pybind11 { -namespace detail { -#if !defined(ABSL_HAVE_STD_OPTIONAL) -// Make absl::optional act like the future C++17 optional for pybind11. -// If ABSL_HAVE_STD_OPTIONAL is defined then absl::optional == std::optional -// and the default type caster is sufficient. -template <typename T> -struct type_caster<absl::optional<T>> : optional_caster<absl::optional<T>> {}; -#endif -} // namespace detail -} // namespace pybind11 - -namespace iree { -namespace python { - -namespace py = pybind11; - -// Wrapper around a blob of memory. -// Used to transport blobs back and forth between C++ and Python. -class OpaqueBlob { - public: - OpaqueBlob() : data_(nullptr), size_(0) {} - OpaqueBlob(void* data, size_t size) : data_(data), size_(size) {} - virtual ~OpaqueBlob() = default; - - void* data() { return data_; } - const void* data() const { return data_; } - size_t size() const { return size_; } - - // Create a free function from the OpaqueBlob shared pointer. - using BufferFreeFn = void (*)(void* self, iree_byte_span_t); - static std::pair<BufferFreeFn, void*> CreateFreeFn( - std::shared_ptr<OpaqueBlob> blob) { - // Note that there are more efficient ways to write this which - // don't bounce through an extra heap alloc, but this is not - // intended to be a high impact code path. - struct Holder { - std::shared_ptr<OpaqueBlob> blob; - }; - Holder* holder = new Holder{std::move(blob)}; - auto free_fn = +([](void* self, iree_byte_span_t) { - Holder* self_holder = static_cast<Holder*>(self); - delete self_holder; - }); - return {free_fn, holder}; - } - - protected: - void* data_; - size_t size_; -}; - -// Opaque blob that owns a vector. -class OpaqueByteVectorBlob : public OpaqueBlob { - public: - OpaqueByteVectorBlob(std::vector<uint8_t> v) - : OpaqueBlob(), v_(std::move(v)) { - data_ = v_.data(); - size_ = v_.size(); - } - - private: - std::vector<uint8_t> v_; -}; - -template <typename T> -struct ApiPtrAdapter {}; - -template <typename Self, typename T> -class ApiRefCounted { - public: - ApiRefCounted() : instance_(nullptr) {} - ApiRefCounted(ApiRefCounted&& other) : instance_(other.instance_) { - other.instance_ = nullptr; - } - void operator=(const ApiRefCounted&) = delete; - - ~ApiRefCounted() { Release(); } - - // Creates an instance of the ref counted wrapper based on an instance - // that has already been retained. Ownership is transferred to the - // wrapper. - static Self CreateRetained(T* retained_inst) { - auto self = Self(); - self.instance_ = retained_inst; - return self; - } - - // Creates a new instance, retaining the underlying object. - static Self RetainAndCreate(T* non_retained_inst) { - auto self = Self(); - self.instance_ = non_retained_inst; - if (non_retained_inst) { - ApiPtrAdapter<T>::Retain(non_retained_inst); - } - return self; - } - - T* raw_ptr() { - if (!instance_) { - throw std::invalid_argument("API object is null"); - } - return instance_; - } - void Retain() { - if (instance_) { - ApiPtrAdapter<T>::Retain(instance_); - } - } - void Release() { - if (instance_) { - ApiPtrAdapter<T>::Release(instance_); - } - } - - private: - T* instance_; -}; - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_PYIREE_BINDING_H_
diff --git a/iree/bindings/python/pyiree/compiler.cc b/iree/bindings/python/pyiree/compiler.cc deleted file mode 100644 index c093d56..0000000 --- a/iree/bindings/python/pyiree/compiler.cc +++ /dev/null
@@ -1,93 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/compiler.h" - -#include <stdexcept> - -#include "iree/bindings/python/pyiree/binding.h" -#include "iree/bindings/python/pyiree/initialize.h" -#include "iree/bindings/python/pyiree/status_utils.h" -#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h" -#include "iree/schemas/module_def_generated.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/Parser.h" - -namespace py = pybind11; - -using namespace mlir; -using namespace mlir::iree_compiler; - -using llvm::MemoryBuffer; -using llvm::MemoryBufferRef; -using llvm::StringRef; - -namespace iree { -namespace python { - -namespace { - -OwningModuleRef parseMLIRModuleFromString(StringRef contents, - MLIRContext* context) { - std::unique_ptr<MemoryBuffer> contents_buffer; - if (contents.back() == 0) { - // If it has a nul terminator, just use as-is. - contents_buffer = MemoryBuffer::getMemBuffer(contents.drop_back()); - } else { - // Otherwise, make a copy. - contents_buffer = MemoryBuffer::getMemBufferCopy(contents, "EMBED"); - } - - llvm::SourceMgr source_mgr; - source_mgr.AddNewSourceBuffer(std::move(contents_buffer), llvm::SMLoc()); - OwningModuleRef mlir_module = parseSourceFile(source_mgr, context); - return mlir_module; -} - -} // namespace - -std::shared_ptr<OpaqueBlob> CompileModuleFromAsm(const std::string& moduleAsm) { - InitializeExtension({}); - - MLIRContext context; - - // Arrange to get a view that includes a terminating null to avoid additional - // copy. - const char* moduleAsmChars = moduleAsm.c_str(); - StringRef moduleAsmSr(moduleAsmChars, moduleAsm.size() + 1); - - // TODO(laurenzo): This error handling is super hoaky. Hook into the MLIR - // error reporter and plumb through properly. - OwningModuleRef mlirModule = parseMLIRModuleFromString(moduleAsmSr, &context); - if (!mlirModule) { - throw std::runtime_error("Failed to parse MLIR asm"); - } - - auto moduleBlob = - mlir::iree_compiler::translateMlirToIreeSequencerModule(mlirModule.get()); - if (moduleBlob.empty()) { - throw std::runtime_error("Failed to translate MLIR module"); - } - return std::make_shared<OpaqueByteVectorBlob>(std::move(moduleBlob)); -} - -void SetupCompilerBindings(pybind11::module m) { - m.def("compile_module_from_asm", CompileModuleFromAsm); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/compiler.h b/iree/bindings/python/pyiree/compiler.h deleted file mode 100644 index faa65b7..0000000 --- a/iree/bindings/python/pyiree/compiler.h +++ /dev/null
@@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_ - -#include <string> - -#include "iree/bindings/python/pyiree/binding.h" - -namespace iree { -namespace python { - -void SetupCompilerBindings(pybind11::module m); - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_PYIREE_COMPILER_H_
diff --git a/iree/bindings/python/pyiree/hal.cc b/iree/bindings/python/pyiree/hal.cc deleted file mode 100644 index 648822b..0000000 --- a/iree/bindings/python/pyiree/hal.cc +++ /dev/null
@@ -1,135 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/hal.h" - -#include "iree/hal/api.h" - -namespace iree { -namespace python { - -namespace { - -class HalMappedMemory { - public: - HalMappedMemory(iree_hal_mapped_memory_t mapped_memory, - iree_hal_buffer_view_t* bv) - : mapped_memory_(mapped_memory), bv_(bv) { - iree_hal_buffer_view_retain(bv_); - } - ~HalMappedMemory() { - if (bv_) { - iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); - CHECK_EQ(iree_hal_buffer_unmap(buffer, &mapped_memory_), IREE_STATUS_OK); - iree_hal_buffer_view_release(bv_); - } - } - HalMappedMemory(HalMappedMemory&& other) - : mapped_memory_(other.mapped_memory_), bv_(other.bv_) { - other.bv_ = nullptr; - } - - static HalMappedMemory Create(HalBufferView& bv) { - iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv.raw_ptr()); - iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); - iree_hal_mapped_memory_t mapped_memory; - CheckApiStatus(iree_hal_buffer_map(buffer, IREE_HAL_MEMORY_ACCESS_READ, - 0 /* element_offset */, byte_length, - &mapped_memory), - "Could not map memory"); - return HalMappedMemory(mapped_memory, bv.raw_ptr()); - } - - py::buffer_info ToBufferInfo() { - iree_hal_buffer_t* buffer = iree_hal_buffer_view_buffer(bv_); - iree_shape_t shape; - CheckApiStatus(iree_hal_buffer_view_shape(bv_, &shape), - "Error getting buffer view shape"); - int8_t element_size = iree_hal_buffer_view_element_size(bv_); - iree_device_size_t byte_length = iree_hal_buffer_byte_length(buffer); - absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> dims; - dims.resize(shape.rank); - for (int i = 0; i < shape.rank; ++i) { - dims[i] = shape.dims[i]; - } - absl::InlinedVector<ssize_t, IREE_SHAPE_MAX_RANK> strides; - strides.resize(shape.rank); - for (int i = 1; i < shape.rank; ++i) { - strides[i - 1] = shape.dims[i] * element_size; - } - if (!strides.empty()) { - strides.back() = 1 * element_size; - } - - // TODO(laurenzo): We need to figure out how to propagate dtype in the - // buffer view. - return py::buffer_info( - mapped_memory_.contents.data, element_size, - py::format_descriptor<float>::format(), // TODO(laurenzo): DTYPE! - shape.rank, dims, strides); - } - - private: - iree_hal_mapped_memory_t mapped_memory_; - iree_hal_buffer_view_t* bv_; -}; - -} // namespace - -void SetupHalBindings(pybind11::module m) { - // Enums. - py::enum_<iree_hal_memory_type_t>(m, "MemoryType") - .value("NONE", IREE_HAL_MEMORY_TYPE_NONE) - .value("TRANSIENT", IREE_HAL_MEMORY_TYPE_TRANSIENT) - .value("HOST_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) - .value("HOST_COHERENT", IREE_HAL_MEMORY_TYPE_HOST_COHERENT) - .value("HOST_CACHED", IREE_HAL_MEMORY_TYPE_HOST_CACHED) - .value("HOST_LOCAL", IREE_HAL_MEMORY_TYPE_HOST_LOCAL) - .value("DEVICE_VISIBLE", IREE_HAL_MEMORY_TYPE_HOST_VISIBLE) - .value("DEVICE_LOCAL", IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) - .export_values(); - py::enum_<iree_hal_buffer_usage_t>(m, "BufferUsage") - .value("NONE", IREE_HAL_BUFFER_USAGE_NONE) - .value("CONSTANT", IREE_HAL_BUFFER_USAGE_CONSTANT) - .value("TRANSFER", IREE_HAL_BUFFER_USAGE_TRANSFER) - .value("MAPPING", IREE_HAL_BUFFER_USAGE_MAPPING) - .value("DISPATCH", IREE_HAL_BUFFER_USAGE_DISPATCH) - .value("ALL", IREE_HAL_BUFFER_USAGE_ALL) - .export_values(); - py::enum_<iree_hal_memory_access_t>(m, "MemoryAccess") - .value("NONE", IREE_HAL_MEMORY_ACCESS_NONE) - .value("READ", IREE_HAL_MEMORY_ACCESS_READ) - .value("WRITE", IREE_HAL_MEMORY_ACCESS_WRITE) - .value("DISCARD", IREE_HAL_MEMORY_ACCESS_DISCARD) - .value("DISCARD_WRITE", IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) - .value("ALL", IREE_HAL_MEMORY_ACCESS_ALL) - .export_values(); - - py::class_<HalShape>(m, "Shape").def(py::init(&HalShape::FromIntVector)); - py::class_<HalBufferView>(m, "BufferView") - .def("map", HalMappedMemory::Create); - py::class_<HalMappedMemory>(m, "MappedMemory", py::buffer_protocol()) - .def_buffer(&HalMappedMemory::ToBufferInfo); - py::class_<HalBuffer>(m, "Buffer") - .def_static("allocate_heap", &HalBuffer::AllocateHeapBuffer, - py::arg("memory_type"), py::arg("usage"), - py::arg("allocation_size")) - .def("fill_zero", &HalBuffer::FillZero, py::arg("byte_offset"), - py::arg("byte_length")) - .def("create_view", &HalBuffer::CreateView, py::arg("shape"), - py::arg("element_size")); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/hal.h b/iree/bindings/python/pyiree/hal.h deleted file mode 100644 index eab644e..0000000 --- a/iree/bindings/python/pyiree/hal.h +++ /dev/null
@@ -1,97 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BINDINGS_PYTHON_PYIREE_HAL_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_HAL_H_ - -#include "iree/bindings/python/pyiree/binding.h" -#include "iree/bindings/python/pyiree/status_utils.h" -#include "iree/hal/api.h" - -namespace iree { -namespace python { - -template <> -struct ApiPtrAdapter<iree_hal_buffer_t> { - static void Retain(iree_hal_buffer_t* b) { iree_hal_buffer_retain(b); } - static void Release(iree_hal_buffer_t* b) { iree_hal_buffer_release(b); } -}; - -template <> -struct ApiPtrAdapter<iree_hal_buffer_view_t> { - static void Retain(iree_hal_buffer_view_t* bv) { - iree_hal_buffer_view_retain(bv); - } - static void Release(iree_hal_buffer_view_t* bv) { - iree_hal_buffer_view_release(bv); - } -}; - -struct HalShape { - public: - static HalShape FromIntVector(std::vector<int32_t> indices) { - if (indices.size() > IREE_SHAPE_MAX_RANK) { - throw RaiseValueError("Shape exceeded maximum rank"); - } - HalShape s; - s.s.rank = indices.size(); - for (size_t i = 0, e = indices.size(); i < e; ++i) { - s.s.dims[i] = indices[i]; - } - return s; - } - - iree_shape_t s; -}; - -class HalBufferView - : public ApiRefCounted<HalBufferView, iree_hal_buffer_view_t> { - public: -}; - -class HalBuffer : public ApiRefCounted<HalBuffer, iree_hal_buffer_t> { - public: - static HalBuffer AllocateHeapBuffer(int32_t memory_type, int32_t usage, - iree_host_size_t allocation_size) { - iree_hal_buffer_t* buffer = nullptr; - CheckApiStatus( - iree_hal_heap_buffer_allocate( - static_cast<iree_hal_memory_type_t>(memory_type), - static_cast<iree_hal_buffer_usage_t>(usage), allocation_size, - IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &buffer), - "Error allocating heap buffer"); - return HalBuffer::CreateRetained(buffer); - } - - void FillZero(iree_device_size_t byte_offset, - iree_device_size_t byte_length) { - CheckApiStatus(iree_hal_buffer_zero(raw_ptr(), byte_offset, byte_length), - "Error zero filling buffer"); - } - - HalBufferView CreateView(HalShape& shape, size_t element_size) { - iree_hal_buffer_view_t* bv; - CheckApiStatus(iree_hal_buffer_view_create(raw_ptr(), shape.s, element_size, - IREE_ALLOCATOR_DEFAULT, &bv), - "Error creating buffer view"); - return HalBufferView::CreateRetained(bv); - } -}; - -void SetupHalBindings(pybind11::module m); - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_PYIREE_HAL_H_
diff --git a/iree/bindings/python/pyiree/initialize.cc b/iree/bindings/python/pyiree/initialize.cc deleted file mode 100644 index 2d4f48e..0000000 --- a/iree/bindings/python/pyiree/initialize.cc +++ /dev/null
@@ -1,53 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/initialize.h" - -#include <string.h> - -#include <mutex> // NOLINT - -#include "iree/base/init.h" - -namespace iree { -namespace python { - -namespace { - -void InternalInitialize(const std::vector<std::string>& arguments) { - int argc = arguments.size() + 1; // plus one for program name. - char** argv = static_cast<char**>( - malloc(sizeof(char*) * (argc + 1))); // plus one for null terminator. - char** orig_argv = argv; - argv[0] = strdup("<python_extension>"); - for (int i = 1; i < argc; ++i) { - argv[i] = strdup(arguments[i - 1].c_str()); - } - argv[argc] = nullptr; - InitializeEnvironment(&argc, &argv); - for (int i = 0; i < argc; ++i) { - free(argv[i]); - } - free(orig_argv); -} - -} // namespace - -void InitializeExtension(const std::vector<std::string>& arguments) { - static std::once_flag init_once; - std::call_once(init_once, InternalInitialize, arguments); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/rt.cc b/iree/bindings/python/pyiree/rt.cc deleted file mode 100644 index f09f20d..0000000 --- a/iree/bindings/python/pyiree/rt.cc +++ /dev/null
@@ -1,150 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/rt.h" - -#include "iree/base/api.h" -#include "iree/bindings/python/pyiree/status_utils.h" -#include "iree/hal/api.h" - -namespace iree { -namespace python { - -HalBufferView RtContext::WrapPyBufferForInput(py::buffer py_buffer) { - auto py_buffer_info = py_buffer.request(false /* writable */); - if (py_buffer_info.ndim > IREE_SHAPE_MAX_RANK || py_buffer_info.ndim < 0) { - RaiseValueError("Unsupported buffer rank"); - } - if (py_buffer_info.size < 0) { - RaiseValueError("Illegal buffer size"); - } - - // For the moment, allocate a device visible buffer of equivalent size and - // copy into it. - // TODO(laurenzo): Once sequencer is in place, switch to HeapBuffer, wrap - // and retain the original buffer. - iree_host_size_t byte_size = py_buffer_info.size * py_buffer_info.itemsize; - HalBuffer buffer = - AllocateDeviceVisible(byte_size, IREE_HAL_BUFFER_USAGE_CONSTANT | - IREE_HAL_BUFFER_USAGE_TRANSFER | - IREE_HAL_BUFFER_USAGE_DISPATCH); - CheckApiStatus(iree_hal_buffer_write_data(buffer.raw_ptr(), 0, - py_buffer_info.ptr, byte_size), - "Error writing to input buffer"); - - // Create the buffer view. - // TODO(laurenzo): This does no validation on dtype and only cares if the - // elementsize matches. Figure out where to enforce actual dtype. - iree_shape_t shape; - shape.rank = py_buffer_info.ndim; - - // Verify strides are row-major. - // TODO(laurenzo): Test this with rank>1. - for (int i = 1; i < shape.rank; ++i) { - if ((py_buffer_info.strides[i - 1] * py_buffer_info.itemsize) != - py_buffer_info.shape[i]) { - RaiseValueError("Expected row-major layout"); - } - } - if (!py_buffer_info.strides.empty()) { - if (py_buffer_info.strides.back() != 1) { - RaiseValueError("Expected row-major layout"); - } - } - - // Populate shape. - for (int i = 0; i < shape.rank; ++i) { - ssize_t dim = py_buffer_info.shape[i]; - if (dim < 0) { - RaiseValueError("Unsupported negative dim"); - } - shape.dims[i] = dim; - } - - iree_hal_buffer_view_t* bv; - CheckApiStatus(iree_hal_buffer_view_create(buffer.raw_ptr(), shape, - py_buffer_info.itemsize, - IREE_ALLOCATOR_DEFAULT, &bv), - "Error allocating buffer view"); - - return HalBufferView::CreateRetained(bv); -} - -void SetupRtBindings(pybind11::module m) { - // BufferPlacement. - py::enum_<BufferPlacement>(m, "BufferPlacement") - .value("HEAP", BufferPlacement::kHeap) - .value("DEVICE_VISIBLE", BufferPlacement::kDeviceVisible) - .value("DEVICE_LOCAL", BufferPlacement::kDeviceLocal) - .export_values(); - - // RtModule. - py::class_<RtModule>(m, "Module") - .def_property_readonly("name", &RtModule::name) - .def("lookup_function_by_ordinal", &RtModule::lookup_function_by_ordinal) - .def("lookup_function_by_name", &RtModule::lookup_function_by_name); - // RtFunction. - py::class_<RtFunction>(m, "Function") - .def_property_readonly("name", &RtFunction::name) - .def_property_readonly("signature", &RtFunction::signature); - py::class_<iree_rt_function_signature_t>(m, "FunctionSignature") - .def_readonly("argument_count", - &iree_rt_function_signature_t::argument_count) - .def_readonly("result_count", - &iree_rt_function_signature_t::result_count); - - // RtPolicy. - py::class_<RtPolicy>(m, "Policy").def(py::init(&RtPolicy::Create)); - - // RtInstance. - py::class_<RtInstance>(m, "Instance") - .def(py::init(&RtInstance::Create), - py::arg_v("driver_name", absl::optional<std::string>())); - - // RtContext. - py::class_<RtContext>(m, "Context") - .def(py::init(&RtContext::Create), py::arg("instance"), py::arg("policy")) - .def_property_readonly("context_id", &RtContext::context_id) - .def("register_modules", &RtContext::RegisterModules, py::arg("modules")) - .def("register_module", &RtContext::RegisterModule, py::arg("module")) - .def("lookup_module_by_name", &RtContext::LookupModuleByName, - py::arg("name")) - .def("resolve_function", &RtContext::ResolveFunction, - py::arg("full_name")) - .def("allocate", &RtContext::Allocate, py::arg("allocation_size"), - py::arg("placement") = BufferPlacement::kHeap, - py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) - .def("allocate_device_local", &RtContext::AllocateDeviceLocal, - py::arg("allocation_size"), - py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) - .def("allocate_device_visible", &RtContext::AllocateDeviceVisible, - py::arg("allocation_size"), - py::arg("usage") = IREE_HAL_BUFFER_USAGE_ALL) - .def("wrap_for_input", &RtContext::WrapPyBufferForInput, py::arg("v")) - .def("invoke", &RtContext::Invoke, py::arg("f"), py::arg("policy"), - py::arg("arguments"), - py::arg("results") = absl::optional<std::vector<HalBufferView*>>()); - - // RtInvocation. - py::class_<RtInvocation>(m, "Invocation") - .def("query_status", &RtInvocation::QueryStatus) - .def("await", &RtInvocation::Await, - py::arg("deadline") = IREE_TIME_INFINITE_FUTURE) - .def("await_optional", &RtInvocation::AwaitOptional, - py::arg("deadline") = IREE_TIME_INFINITE_FUTURE) - .def_property_readonly("results", &RtInvocation::ConsumeResults); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/rt.h b/iree/bindings/python/pyiree/rt.h deleted file mode 100644 index 2f1dbd0..0000000 --- a/iree/bindings/python/pyiree/rt.h +++ /dev/null
@@ -1,390 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BINDINGS_PYTHON_PYIREE_RT_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_RT_H_ - -#include "absl/container/inlined_vector.h" -#include "iree/base/api.h" -#include "iree/bindings/python/pyiree/binding.h" -#include "iree/bindings/python/pyiree/hal.h" -#include "iree/bindings/python/pyiree/initialize.h" -#include "iree/bindings/python/pyiree/status_utils.h" -#include "iree/hal/api.h" -#include "iree/rt/api.h" - -namespace iree { -namespace python { - -// When creating a buffer via the context, switch between the different -// allocation entry-points via an enum (these are separate functions in the -// C API). -enum class BufferPlacement { - kHeap, - kDeviceVisible, - kDeviceLocal, -}; - -// Adapts API pointer access to retain/release API calls. -template <> -struct ApiPtrAdapter<iree_rt_module_t> { - static void Retain(iree_rt_module_t* m) { iree_rt_module_retain(m); } - static void Release(iree_rt_module_t* m) { iree_rt_module_release(m); } -}; - -template <> -struct ApiPtrAdapter<iree_rt_instance_t> { - static void Retain(iree_rt_instance_t* inst) { - iree_rt_instance_retain(inst); - } - static void Release(iree_rt_instance_t* inst) { - iree_rt_instance_release(inst); - } -}; - -template <> -struct ApiPtrAdapter<iree_rt_policy_t> { - static void Retain(iree_rt_policy_t* p) { iree_rt_policy_retain(p); } - static void Release(iree_rt_policy_t* p) { iree_rt_policy_release(p); } -}; - -template <> -struct ApiPtrAdapter<iree_rt_context_t> { - static void Retain(iree_rt_context_t* c) { iree_rt_context_retain(c); } - static void Release(iree_rt_context_t* c) { iree_rt_context_release(c); } -}; - -template <> -struct ApiPtrAdapter<iree_rt_invocation_t> { - static void Retain(iree_rt_invocation_t* inv) { - iree_rt_invocation_retain(inv); - } - static void Release(iree_rt_invocation_t* inv) { - iree_rt_invocation_release(inv); - } -}; - -// Wrapper classes. These mirror the Python declarations. -class RtFunction { - public: - // Note that this will retain the module. - RtFunction(iree_rt_function_t function) : function_(function) { - iree_rt_module_retain(function_.module); - } - ~RtFunction() { - if (function_.module) iree_rt_module_release(function_.module); - } - RtFunction(RtFunction&& other) : function_(other.function_) { - other.function_.module = nullptr; - } - void operator=(const RtFunction&) = delete; - - std::string name() { - auto sv = iree_rt_function_name(&function_); - return std::string(sv.data, sv.size); - } - - iree_rt_function_signature_t signature() { - iree_rt_function_signature_t sig; - CheckApiStatus(iree_rt_function_signature(&function_, &sig), - "Error getting function signature"); - return sig; - } - - iree_rt_function_t& raw_function() { return function_; } - - private: - iree_rt_function_t function_; -}; - -class RtModule : public ApiRefCounted<RtModule, iree_rt_module_t> { - public: - std::string name() { - auto sv = iree_rt_module_name(raw_ptr()); - return std::string(sv.data, sv.size); - } - - absl::optional<RtFunction> lookup_function_by_ordinal(int32_t ordinal) { - iree_rt_function_t f; - // TODO(laurenzo): Support an optional linkage argument. - auto module = raw_ptr(); - auto status = iree_rt_module_lookup_function_by_ordinal( - module, IREE_RT_FUNCTION_LINKAGE_EXPORT, ordinal, &f); - if (status == IREE_STATUS_NOT_FOUND) { - return absl::optional<RtFunction>(); - } - CheckApiStatus(status, "Error looking up function"); - return RtFunction(f); - } - - absl::optional<RtFunction> lookup_function_by_name(const std::string& name) { - iree_rt_function_t f; - // TODO(laurenzo): Support an optional linkage argument. - auto module = raw_ptr(); - iree_string_view_t name_sv{name.data(), name.size()}; - auto status = iree_rt_module_lookup_function_by_name( - module, IREE_RT_FUNCTION_LINKAGE_EXPORT, name_sv, &f); - if (status == IREE_STATUS_NOT_FOUND) { - return absl::optional<RtFunction>(); - } - CheckApiStatus(status, "Error looking up function"); - return RtFunction(f); - } -}; - -class RtInstance : public ApiRefCounted<RtInstance, iree_rt_instance_t> { - public: - // TODO(laurenzo): Support optional allocator argument. - static RtInstance Create(absl::optional<std::string> driver_name) { - InitializeExtension({}); - iree_rt_instance_t* raw_inst; - CheckApiStatus(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &raw_inst), - "Error creating instance"); - RtInstance inst = RtInstance::CreateRetained(raw_inst); - - if (!driver_name) { - driver_name = "interpreter"; - } - CheckApiStatus(iree_rt_instance_register_driver_ex( - raw_inst, iree_string_view_t{driver_name->c_str(), - driver_name->size()}), - "Error registering drivers"); - - return inst; - } -}; - -class RtPolicy : public ApiRefCounted<RtPolicy, iree_rt_policy_t> { - public: - // TODO(laurenzo): Support optional allocator argument. - static RtPolicy Create() { - iree_rt_policy_t* policy; - CheckApiStatus(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &policy), - "Error creating policy"); - return RtPolicy::CreateRetained(policy); - } -}; - -class RtInvocation : public ApiRefCounted<RtInvocation, iree_rt_invocation_t> { - public: - // Returns whether ready. - // Raises exception on error. - bool QueryStatus() { - auto status = iree_rt_invocation_query_status(raw_ptr()); - if (status == IREE_STATUS_OK) { - return true; - } else if (status == IREE_STATUS_UNAVAILABLE) { - return false; - } else { - CheckApiStatus(status, "Error in function invocation"); - return false; - } - } - - // TODO(laurenzo): Convert to the pybind chrono support. - // Returns whether the invocation is ready. - bool AwaitOptional(iree_time_t epoch_nanos_deadline) { - auto status = iree_rt_invocation_await(raw_ptr(), epoch_nanos_deadline); - if (status == IREE_STATUS_OK) { - return true; - } else if (status == IREE_STATUS_DEADLINE_EXCEEDED) { - return false; - } else { - CheckApiStatus(status, "Error in invocation"); - return false; - } - } - - // Similar to AwaitOptional but will raise an error unless if the status - // is ready. - void Await(iree_time_t epoch_nanos_deadline) { - if (!AwaitOptional(epoch_nanos_deadline)) { - RaiseValueError("Deadline expired"); - } - } - - std::vector<HalBufferView> ConsumeResults() { - static constexpr size_t kInlineSize = 8; - iree_host_size_t result_count; - absl::InlinedVector<iree_hal_buffer_view_t*, kInlineSize> result_bvs; - result_bvs.resize(kInlineSize); - auto status = iree_rt_invocation_consume_results( - raw_ptr(), kInlineSize, IREE_ALLOCATOR_DEFAULT, &result_bvs[0], - &result_count); - if (status == IREE_STATUS_OUT_OF_RANGE) { - // Resize/retry. - result_bvs.resize(result_count); - status = iree_rt_invocation_consume_results( - raw_ptr(), result_count, IREE_ALLOCATOR_DEFAULT, &result_bvs[0], - &result_count); - } - CheckApiStatus(status, "Error consuming invocation results"); - result_bvs.resize(result_count); - std::vector<HalBufferView> results; - for (auto* raw_bv : result_bvs) { - results.push_back(HalBufferView::CreateRetained(raw_bv)); - } - return results; - } -}; - -class RtContext : public ApiRefCounted<RtContext, iree_rt_context_t> { - public: - static RtContext Create(RtInstance* instance, RtPolicy* policy) { - iree_rt_context_t* context; - // TODO(laurenzo): Support optional allocator argument. - CheckApiStatus( - iree_rt_context_create(instance->raw_ptr(), policy->raw_ptr(), - IREE_ALLOCATOR_DEFAULT, &context), - "Error creating instance"); - return RtContext::CreateRetained(context); - } - - int context_id() { return iree_rt_context_id(raw_ptr()); } - - void RegisterModules(std::vector<RtModule*> modules) { - std::vector<iree_rt_module_t*> module_raw_ptrs; - module_raw_ptrs.resize(modules.size()); - for (size_t i = 0, e = modules.size(); i < e; ++i) { - auto module_raw_ptr = modules[i]->raw_ptr(); - module_raw_ptrs[i] = module_raw_ptr; - } - CheckApiStatus( - iree_rt_context_register_modules(raw_ptr(), module_raw_ptrs.data(), - module_raw_ptrs.size()), - "Error registering modules"); - } - - void RegisterModule(RtModule* module) { - iree_rt_module_t* module_raw_ptr = module->raw_ptr(); - CheckApiStatus( - iree_rt_context_register_modules(raw_ptr(), &module_raw_ptr, 1), - "Error registering module"); - } - - absl::optional<RtModule> LookupModuleByName(const std::string& name) { - iree_rt_module_t* module = iree_rt_context_lookup_module_by_name( - raw_ptr(), {name.data(), name.size()}); - if (!module) { - return absl::optional<RtModule>(); - } - return RtModule::RetainAndCreate(module); - } - - absl::optional<RtFunction> ResolveFunction(const std::string& full_name) { - iree_rt_function_t f; - auto status = iree_rt_context_resolve_function( - raw_ptr(), {full_name.data(), full_name.size()}, &f); - if (status == IREE_STATUS_NOT_FOUND) { - return absl::optional<RtFunction>(); - } - CheckApiStatus(status, "Error resolving function"); - return RtFunction(f); - } - - // Convenience method to allocate host, device-visible or device-local - // buffers. - HalBuffer Allocate(iree_host_size_t allocation_size, - BufferPlacement placement, int32_t usage) { - iree_hal_buffer_t* raw_buffer = nullptr; - switch (placement) { - case BufferPlacement::kHeap: - // Even though allocating a heap buffer does not require the context, - // provide it here to make the API easier to navigate. - CheckApiStatus( - iree_hal_heap_buffer_allocate( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - static_cast<iree_hal_buffer_usage_t>(usage), allocation_size, - IREE_ALLOCATOR_DEFAULT, IREE_ALLOCATOR_DEFAULT, &raw_buffer), - "Error allocating heap buffer"); - break; - case BufferPlacement::kDeviceLocal: - CheckApiStatus( - iree_rt_context_allocate_device_local_buffer( - raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage), - allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer), - "Error allocating device local buffer"); - break; - case BufferPlacement::kDeviceVisible: - CheckApiStatus( - iree_rt_context_allocate_device_visible_buffer( - raw_ptr(), static_cast<iree_hal_buffer_usage_t>(usage), - allocation_size, IREE_ALLOCATOR_DEFAULT, &raw_buffer), - "Error allocating device visible buffer"); - break; - default: - throw RaiseValueError("Unknown BufferPlacement"); - } - - return HalBuffer::CreateRetained(raw_buffer); - } - - HalBuffer AllocateHeap(iree_host_size_t allocation_size, int32_t usage) { - return Allocate(allocation_size, BufferPlacement::kHeap, usage); - } - - HalBuffer AllocateDeviceLocal(iree_host_size_t allocation_size, - int32_t usage) { - return Allocate(allocation_size, BufferPlacement::kDeviceLocal, usage); - } - - HalBuffer AllocateDeviceVisible(iree_host_size_t allocation_size, - int32_t usage) { - return Allocate(allocation_size, BufferPlacement::kDeviceVisible, usage); - } - - // One stop convenience method for wrapping a python buffer protocol buffer - // for input to a function. At the runtime's discretion, this may make a copy - // or do something smarter, meaning the data in the backing python buffer - // will either be accessed immediately or at some future point. - HalBufferView WrapPyBufferForInput(py::buffer py_buffer); - - RtInvocation Invoke(RtFunction& f, RtPolicy& policy, - std::vector<HalBufferView*> arguments, - absl::optional<std::vector<HalBufferView*>> results) { - absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_arguments; - raw_arguments.resize(arguments.size()); - for (size_t i = 0, e = arguments.size(); i < e; ++i) { - auto inst = arguments[i]; - CheckApiNotNull(inst, "Argument buffer view cannot be None"); - raw_arguments[i] = inst->raw_ptr(); - } - absl::InlinedVector<iree_hal_buffer_view_t*, 8> raw_results; - if (results) { - raw_results.resize(results->size()); - for (size_t i = 0, e = results->size(); i < e; ++i) { - auto inst = (*results)[i]; - CheckApiNotNull(inst, "Result buffer view cannot be None"); - raw_results[i] = inst->raw_ptr(); - } - } - - iree_rt_invocation_t* invocation; - CheckApiStatus(iree_rt_invocation_create( - raw_ptr(), &f.raw_function(), policy.raw_ptr(), - nullptr /* dependencies */, raw_arguments.data(), - raw_arguments.size(), raw_results.data(), - raw_results.size(), IREE_ALLOCATOR_DEFAULT, &invocation), - "Error invoking function"); - - return RtInvocation::CreateRetained(invocation); - } -}; - -void SetupRtBindings(pybind11::module m); - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_PYIREE_RT_H_
diff --git a/iree/bindings/python/pyiree/status_utils.cc b/iree/bindings/python/pyiree/status_utils.cc deleted file mode 100644 index a8c5008..0000000 --- a/iree/bindings/python/pyiree/status_utils.cc +++ /dev/null
@@ -1,72 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/status_utils.h" - -#include "absl/strings/str_cat.h" - -namespace iree { -namespace python { - -namespace { - -PyObject* StatusToPyExcClass(const Status& status) { - switch (status.code()) { - case StatusCode::kInvalidArgument: - return PyExc_ValueError; - case StatusCode::kOutOfRange: - return PyExc_IndexError; - case StatusCode::kUnimplemented: - return PyExc_NotImplementedError; - default: - return PyExc_RuntimeError; - } -} - -PyObject* ApiStatusToPyExcClass(iree_status_t status) { - switch (status) { - case IREE_STATUS_INVALID_ARGUMENT: - return PyExc_ValueError; - case IREE_STATUS_OUT_OF_RANGE: - return PyExc_IndexError; - case IREE_STATUS_UNIMPLEMENTED: - return PyExc_NotImplementedError; - default: - return PyExc_RuntimeError; - } -} - -} // namespace - -pybind11::error_already_set StatusToPyExc(const Status& status) { - assert(!status.ok()); - PyErr_SetString(StatusToPyExcClass(status), status.error_message().c_str()); - return pybind11::error_already_set(); -} - -pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, - const char* message) { - assert(status != IREE_STATUS_OK); - auto full_message = absl::StrCat(message, ": ", static_cast<int>(status)); - PyErr_SetString(ApiStatusToPyExcClass(status), full_message.c_str()); - return pybind11::error_already_set(); -} - -pybind11::error_already_set RaiseValueError(const char* message) { - PyErr_SetString(PyExc_ValueError, message); - return pybind11::error_already_set(); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/status_utils.h b/iree/bindings/python/pyiree/status_utils.h deleted file mode 100644 index ec12b01..0000000 --- a/iree/bindings/python/pyiree/status_utils.h +++ /dev/null
@@ -1,67 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_ - -#include "iree/base/api.h" -#include "iree/base/status.h" -#include "pybind11/pytypes.h" - -namespace iree { -namespace python { - -// Converts a failing status to a throwable exception, setting Python -// error information. -// Correct usage is something like: -// if (!status.ok()) { -// throw StatusToPyExc(status); -// } -pybind11::error_already_set StatusToPyExc(const Status& status); - -// Raises a value error with the given message. -// Correct usage: -// throw RaiseValueError("Foobar'd"); -pybind11::error_already_set RaiseValueError(const char* message); - -// Consumes a StatusOr<T>, returning an rvalue reference to the T if the -// status is ok(). Otherwise, throws an exception. -template <typename T> -T&& PyConsumeStatusOr(iree::StatusOr<T>&& sor) { - if (sor.ok()) { - return std::move(*sor); - } - throw StatusToPyExc(sor.status()); -} - -pybind11::error_already_set ApiStatusToPyExc(iree_status_t status, - const char* message); - -static void CheckApiStatus(iree_status_t status, const char* message) { - if (status == IREE_STATUS_OK) { - return; - } - throw ApiStatusToPyExc(status, message); -} - -static void CheckApiNotNull(const void* p, const char* message) { - if (!p) { - throw RaiseValueError(message); - } -} - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_PYIREE_STATUS_UTILS_H_
diff --git a/iree/bindings/python/pyiree/vm.cc b/iree/bindings/python/pyiree/vm.cc deleted file mode 100644 index c97ebce..0000000 --- a/iree/bindings/python/pyiree/vm.cc +++ /dev/null
@@ -1,37 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/bindings/python/pyiree/vm.h" - -#include "iree/bindings/python/pyiree/status_utils.h" - -namespace iree { -namespace python { - -RtModule CreateModuleFromBlob(std::shared_ptr<OpaqueBlob> blob) { - iree_rt_module_t* module; - auto free_fn = OpaqueBlob::CreateFreeFn(blob); - auto status = iree_vm_bytecode_module_create_from_buffer( - {static_cast<const uint8_t*>(blob->data()), blob->size()}, free_fn.first, - free_fn.second, IREE_ALLOCATOR_DEFAULT, &module); - CheckApiStatus(status, "Error creating vm module from blob"); - return RtModule::CreateRetained(module); -} - -void SetupVmBindings(pybind11::module m) { - m.def("create_module_from_blob", CreateModuleFromBlob); -} - -} // namespace python -} // namespace iree
diff --git a/iree/bindings/python/pyiree/vm.h b/iree/bindings/python/pyiree/vm.h deleted file mode 100644 index fcb7982..0000000 --- a/iree/bindings/python/pyiree/vm.h +++ /dev/null
@@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_BINDINGS_PYTHON_PYIREE_VM_H_ -#define IREE_BINDINGS_PYTHON_PYIREE_VM_H_ - -#include "iree/bindings/python/pyiree/binding.h" -#include "iree/bindings/python/pyiree/rt.h" -#include "iree/vm/api.h" - -namespace iree { -namespace python { - -void SetupVmBindings(pybind11::module m); - -} // namespace python -} // namespace iree - -#endif // IREE_BINDINGS_PYTHON_PYIREE_VM_H_
diff --git a/iree/build_defs.bzl b/iree/build_defs.bzl deleted file mode 100644 index d42ebcf..0000000 --- a/iree/build_defs.bzl +++ /dev/null
@@ -1,130 +0,0 @@ -"""Common Bazel definitions for IREE.""" - -load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") -load("@iree_native_python//:build_defs.bzl", "py_extension") -load("@iree_core//build_tools/third_party/glslang:build_defs.bzl", "glsl_vulkan") -load("@rules_python//python:defs.bzl", "py_library") - -NUMPY_DEPS = [] - -def platform_trampoline_deps(basename): - """Produce a list of deps for the given `basename` platform target. - - Example: - "file_mapping" -> ["//iree/base/internal/file_mapping_internal"] - - This is used for compatibility with various methods of including the - library in foreign source control systems. - - Args: - basename: Library name prefix for a library in base/internal. - Returns: - A list of dependencies for depending on the library in a platform - sensitive way. - """ - return [ - "//iree/base/internal:%s_internal" % basename, - ] - -# A platform-sensitive list of copts for the Vulkan loader. -PLATFORM_VULKAN_LOADER_COPTS = select({ - "//iree/hal/vulkan:native_vk": [], - "//iree/hal/vulkan:swiftshader_vk": [], - "//conditions:default": [], -}) - -# A platform-sensitive list of dependencies for non-test targets using Vulkan. -PLATFORM_VULKAN_DEPS = select({ - "//iree/hal/vulkan:native_vk": [], - "//iree/hal/vulkan:swiftshader_vk": [], - "//conditions:default": [], -}) - -# A platform-sensitive list of dependencies for tests using Vulkan. -PLATFORM_VULKAN_TEST_DEPS = [ - "@com_google_googletest//:gtest_main", -] - -def iree_py_library(**kwargs): - """Compatibility py_library which has bazel compatible args.""" - - # This is used when args are needed that are incompatible with upstream. - # Presently, this includes: - # imports - py_library(**kwargs) - -def iree_py_extension(deps = [], **kwargs): - """Delegates to the real py_extension.""" - py_extension( - deps = ["@iree_native_python//:python_headers"] + deps, - **kwargs - ) - -def iree_build_test(name, targets): - """Dummy rule to ensure that targets build. - - This is currently undefined in bazel and is preserved for compatibility. - """ - pass - -def iree_setup_lit_package(data): - """Should be called once per test package that contains globbed lit tests. - - Args: - data: Additional, project specific data deps to add. - """ - - # Bundle together all of the test utilities that are used by tests. - native.filegroup( - name = "lit_test_utilities", - testonly = True, - data = data + [ - "@llvm//:FileCheck", - ], - ) - -def iree_glob_lit_tests( - data = [":lit_test_utilities"], - driver = "//iree/tools:run_lit.sh", - test_file_exts = ["mlir"]): - """Globs lit test files into tests for a package. - - For most packages, the defaults suffice. Packages that include this must - also include a call to iree_setup_lit_package(). - - Args: - data: Data files to include/build. - driver: Test driver. - test_file_exts: File extensions to glob. - """ - for test_file_ext in test_file_exts: - test_files = native.glob([ - "*.%s" % (test_file_ext,), - "**/*.%s" % (test_file_ext,), - ]) - for test_file in test_files: - test_file_location = "$(location %s)" % (test_file,) - native.sh_test( - name = "%s.test" % (test_file,), - size = "small", - srcs = [driver], - data = data + [test_file], - args = [test_file_location], - ) - -# The OSS build currently has issues with generating flatbuffer reflections. -# It is hard-coded to disabled here (and in iree_flatbuffer_cc_library) until triaged/fixed. -FLATBUFFER_SUPPORTS_REFLECTIONS = False - -def iree_flatbuffer_cc_library(**kwargs): - """Wrapper for the flatbuffer_cc_library.""" - - # TODO(laurenzo): The bazel rule for reflections seems broken in OSS - # builds. Fix it and enable by default. - flatbuffer_cc_library( - gen_reflections = False, - **kwargs - ) - -def iree_glsl_vulkan(**kwargs): - glsl_vulkan(**kwargs)
diff --git a/iree/compiler/IR/ConfigOps.cpp b/iree/compiler/IR/ConfigOps.cpp deleted file mode 100644 index 3ec129a..0000000 --- a/iree/compiler/IR/ConfigOps.cpp +++ /dev/null
@@ -1,110 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/ConfigOps.h" - -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/IR/Types.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/STLExtras.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -//===----------------------------------------------------------------------===// -// Generic printers and parsers. -//===----------------------------------------------------------------------===// - -// Parses an op that has no inputs and no outputs. -static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { - if (failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - return success(); -} - -// Prints an op that has no inputs and no outputs. -static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName(); - printer.printOptionalAttrDict(op->getAttrs()); -} - -//===----------------------------------------------------------------------===// -// iree.target_config -//===----------------------------------------------------------------------===// - -void ExecutableTargetConfigOp::build(Builder *builder, OperationState &state, - std::string backend) { - state.addAttribute("backend", builder->getStringAttr(backend)); - ensureTerminator(*state.addRegion(), *builder, state.location); -} - -static ParseResult parseExecutableTargetConfigOp(OpAsmParser &parser, - OperationState &state) { - llvm::SMLoc backendLoc; - StringAttr backendAttr; - if (failed(parser.parseLParen()) || - failed(parser.getCurrentLocation(&backendLoc)) || - failed(parser.parseAttribute(backendAttr, "backend", state.attributes))) { - return failure(); - } - - Region *body = state.addRegion(); - if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) { - return failure(); - } - if (succeeded(parser.parseOptionalKeyword("attributes"))) { - if (failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - } - - ExecutableTargetConfigOp::ensureTerminator(*body, parser.getBuilder(), - state.location); - - return success(); -} - -static void printExecutableTargetConfigOp(OpAsmPrinter &printer, - ExecutableTargetConfigOp op) { - printer << op.getOperationName() << "(" << op.backend() << ")"; - - printer.printRegion(op.body(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - - // Print out executable attributes, if present. - SmallVector<StringRef, 1> ignoredAttrs = { - "backend", - }; - if (op.getAttrs().size() > ignoredAttrs.size()) { - printer << "\n attributes "; - printer.printOptionalAttrDict(op.getAttrs(), ignoredAttrs); - } -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/ConfigOps.cpp.inc" - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/ConfigOps.h b/iree/compiler/IR/ConfigOps.h deleted file mode 100644 index 1cef1f6..0000000 --- a/iree/compiler/IR/ConfigOps.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_CONFIGOPS_H_ -#define IREE_COMPILER_IR_CONFIGOPS_H_ - -#include <cstdint> - -#include "iree/compiler/IR/Types.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/ConfigOps.h.inc" - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_CONFIGOPS_H_
diff --git a/iree/compiler/IR/ConfigOps.td b/iree/compiler/IR/ConfigOps.td deleted file mode 100644 index 21681c5..0000000 --- a/iree/compiler/IR/ConfigOps.td +++ /dev/null
@@ -1,44 +0,0 @@ -// Ops used to declare configuration used by the IREE compiler. -// These allow inline config that follows along the IR they are associated with. -// Multiple config ops are allowed within a single scope to indicate that the -// parent IR node should be processed for multiple targets. - -#ifdef IREE_CONFIG_OPS -#else -#define IREE_CONFIG_OPS - -include "iree/compiler/IR/OpBase.td" - -class IREE_ConfigOp<string mnemonic, list<OpTrait> traits = []> : - Op<IREE_Dialect, mnemonic, traits> { - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ print$cppClass(p, *this); }]; -} - -//===----------------------------------------------------------------------===// -// iree.executable configuration -//===----------------------------------------------------------------------===// - -def IREE_ExecutableTargetConfigOp : IREE_ConfigOp<"target_config", [ - IREE_ExecutableOnly, - SingleBlockImplicitTerminator<"ExecutableTargetConfigEndOp"> -]> { - let arguments = (ins - StrAttr:$backend - ); - - let regions = (region SizedRegion<1>:$body); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState &state, std::string backend">, - ]; -} - -def IREE_ExecutableTargetConfigEndOp : - IREE_ConfigOp<"_target_config_end", [Terminator, IREE_ExecutableTargetConfigOnly]> { - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; -} - -#endif // IREE_CONFIG_OPS
diff --git a/iree/compiler/IR/Dialect.cpp b/iree/compiler/IR/Dialect.cpp deleted file mode 100644 index 0946319..0000000 --- a/iree/compiler/IR/Dialect.cpp +++ /dev/null
@@ -1,90 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" - -#include "iree/compiler/IR/ConfigOps.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/IR/Types.h" -#include "llvm/Support/SourceMgr.h" - -namespace mlir { -namespace iree_compiler { - -static DialectRegistration<IREEDialect> iree_dialect; - -IREEDialect::IREEDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { -#define IREE_ADD_TYPE(NAME, KIND, TYPE) addTypes<TYPE>(); - IREE_TYPE_TABLE(IREE_ADD_TYPE); - -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/Ops.cpp.inc" - >(); -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/ConfigOps.cpp.inc" - >(); -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/StructureOps.cpp.inc" - >(); -} - -//===----------------------------------------------------------------------===// -// Type Parsing -//===----------------------------------------------------------------------===// - -#define IREE_TYPE_PARSER(NAME, KIND, TYPE) \ - static Type parse##TYPE(IREEDialect const &dialect, StringRef spec, \ - Location loc) { \ - spec.consume_front(NAME); \ - return TYPE::get(dialect.getContext()); \ - } -IREE_TYPE_TABLE(IREE_TYPE_PARSER); - -#define IREE_PARSE_TYPE(NAME, KIND, TYPE) \ - if (spec.startswith(NAME)) { \ - return parse##TYPE(*this, spec, loc); \ - } -Type IREEDialect::parseType(StringRef spec, Location loc) const { - IREE_TYPE_TABLE(IREE_PARSE_TYPE); - emitError(loc, "unknown IREE type: ") << spec; - return Type(); -} - -//===----------------------------------------------------------------------===// -// Type Printing -//===----------------------------------------------------------------------===// - -#define IREE_TYPE_PRINTER(NAME, KIND, TYPE) \ - static void print##TYPE(TYPE type, llvm::raw_ostream &os) { os << NAME; } -IREE_TYPE_TABLE(IREE_TYPE_PRINTER); - -#define IREE_PRINT_TYPE(NAME, KIND, TYPE) \ - case KIND: \ - print##TYPE(type.cast<TYPE>(), os); \ - return; -void IREEDialect::printType(Type type, llvm::raw_ostream &os) const { - switch (type.getKind()) { - IREE_TYPE_TABLE(IREE_PRINT_TYPE); - default: - llvm_unreachable("unhandled IREE type"); - } -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/BUILD b/iree/compiler/IR/Interpreter/BUILD deleted file mode 100644 index 63a063d..0000000 --- a/iree/compiler/IR/Interpreter/BUILD +++ /dev/null
@@ -1,75 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -load("@local_config_mlir//:tblgen.bzl", "gentbl") - -filegroup( - name = "td_files", - srcs = glob(["*.td"]), -) - -cc_library( - name = "Interpreter", - srcs = [ - "HLDialect.cpp", - "HLOps.cpp", - "HLOps.cpp.inc", - "LLDialect.cpp", - "LLOps.cpp", - "LLOps.cpp.inc", - "OpWriters.cpp", - ], - hdrs = [ - "HLDialect.h", - "HLOps.h", - "HLOps.h.inc", - "LLDialect.h", - "LLOps.h", - "LLOps.h.inc", - "OpWriters.h", - ], - deps = [ - ":HLOpsGen", - ":LLOpsGen", - "//iree/compiler/IR", - "//iree/compiler/Serialization", - "//iree/compiler/Utils", - "//iree/schemas/bytecode:interpreter_bytecode_v0", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", - ], - alwayslink = 1, -) - -gentbl( - name = "HLOpsGen", - tbl_outs = [ - ("-gen-op-decls", "HLOps.h.inc"), - ("-gen-op-defs", "HLOps.cpp.inc"), - ], - tblgen = "@local_config_mlir//:mlir-tblgen", - td_file = "HLOps.td", - td_srcs = [ - ":td_files", - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//iree/compiler/IR:OpBase.td", - ], -) - -gentbl( - name = "LLOpsGen", - tbl_outs = [ - ("-gen-op-decls", "LLOps.h.inc"), - ("-gen-op-defs", "LLOps.cpp.inc"), - ], - tblgen = "@local_config_mlir//:mlir-tblgen", - td_file = "LLOps.td", - td_srcs = [ - ":td_files", - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//iree/compiler/IR:OpBase.td", - ], -)
diff --git a/iree/compiler/IR/Interpreter/HLDialect.cpp b/iree/compiler/IR/Interpreter/HLDialect.cpp deleted file mode 100644 index 55b3fbb..0000000 --- a/iree/compiler/IR/Interpreter/HLDialect.cpp +++ /dev/null
@@ -1,34 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Interpreter/HLDialect.h" - -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "llvm/Support/SourceMgr.h" - -namespace mlir { -namespace iree_compiler { - -IREEHLInterpreterDialect::IREEHLInterpreterDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/Interpreter/HLOps.cpp.inc" - >(); -} - -static DialectRegistration<IREEHLInterpreterDialect> iree_hl_interp_dialect; - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/HLOps.cpp b/iree/compiler/IR/Interpreter/HLOps.cpp deleted file mode 100644 index c708787..0000000 --- a/iree/compiler/IR/Interpreter/HLOps.cpp +++ /dev/null
@@ -1,246 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Interpreter/HLOps.h" - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace iree_compiler { -namespace IREEInterp { -namespace HL { - -//===----------------------------------------------------------------------===// -// iree_hl_interp.call -//===----------------------------------------------------------------------===// - -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector<OpAsmParser::OperandType, 4> operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(state.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), state.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - state.operands)) { - return failure(); - } - return success(); -} - -static void printCallOp(OpAsmPrinter &p, CallOp op) { - p << "iree_hl_interp.call " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); -} - -FunctionType CallOp::getCalleeType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_hl_interp.call_indirect -//===----------------------------------------------------------------------===// - -static ParseResult parseCallIndirectOp(OpAsmParser &parser, - OperationState &result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector<OpAsmParser::OperandType, 4> operands; - return failure( - parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.resolveOperand(callee, calleeType, result.operands) || - parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result.operands) || - parser.addTypesToList(calleeType.getResults(), result.types)); -} - -static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { - p << "iree_hl_interp.call_indirect "; - p.printOperand(op.getCallee()); - p << '('; - auto operandRange = op.getOperands(); - p.printOperands(++operandRange.begin(), operandRange.end()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCallee()->getType(); -} - -//===----------------------------------------------------------------------===// -// iree_hl_interp.return -//===----------------------------------------------------------------------===// - -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type, 2> types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, state.operands)); -} - -static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p << "iree_hl_interp.return"; - if (op.getNumOperands() > 0) { - p << ' '; - p.printOperands(op.operand_begin(), op.operand_end()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } -} - -//===----------------------------------------------------------------------===// -// iree_hl_interp.br -//===----------------------------------------------------------------------===// - -static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { - Block *dest; - SmallVector<Value *, 4> destOperands; - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - return success(); -} - -static void printBranchOp(OpAsmPrinter &p, BranchOp op) { - p << "iree_hl_interp.br "; - p.printSuccessorAndUseList(op.getOperation(), 0); -} - -Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } - -void BranchOp::setDest(Block *block) { - return getOperation()->setSuccessor(block, 0); -} - -void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); -} - -//===----------------------------------------------------------------------===// -// iree_hl_interp.cond_br -//===----------------------------------------------------------------------===// - -static ParseResult parseCondBranchOp(OpAsmParser &parser, - OperationState &result) { - SmallVector<Value *, 4> destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser.getBuilder().getI1Type(); - if (parser.parseOperand(condInfo) || parser.parseComma() || - parser.resolveOperand(condInfo, int1Ty, result.operands)) { - return parser.emitError(parser.getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - - return success(); -} - -static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { - p << "iree_hl_interp.cond_br "; - p.printOperand(op.getCondition()); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); -} - -//===----------------------------------------------------------------------===// -// iree_hl_interp.clone -//===----------------------------------------------------------------------===// - -OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { - // If this is the only usage, we know the clone is unnecessary. - // TODO(b/135053584) More sophisticated analysis. - if (src()->hasOneUse()) return src(); - return {}; -} - -//===----------------------------------------------------------------------===// -// iree_hl_interp.concat -//===----------------------------------------------------------------------===// - -namespace { -struct ConcatToCopies : public OpRewritePattern<ConcatOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ConcatOp concatOp, - PatternRewriter &rewriter) const override { - auto finalType = concatOp.getResult()->getType().cast<ShapedType>(); - auto loc = concatOp.getLoc(); - std::vector<Value *> dimPieces; - auto dst = - rewriter.create<IREEInterp::HL::AllocHeapOp>(loc, finalType, dimPieces); - - llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0); - auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset); - - auto concatDimension = concatOp.dimension().getZExtValue(); - llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0); - for (auto *src : concatOp.srcs()) { - auto srcShape = src->getType().cast<ShapedType>().getShape(); - auto lengths = createArrayConstant(rewriter, loc, srcShape); - auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices); - rewriter.create<IREEInterp::HL::CopyOp>(loc, src, srcIndices, dst, - dstIndicesOp, lengths); - dstIndices[concatDimension] += srcShape[concatDimension]; - } - - concatOp.replaceAllUsesWith(dst.getResult()); - - return matchSuccess(); - } -}; -} // namespace - -void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert<ConcatToCopies>(context); -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Interpreter/HLOps.cpp.inc" - -} // namespace HL -} // namespace IREEInterp -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/HLOps.h b/iree/compiler/IR/Interpreter/HLOps.h deleted file mode 100644 index e3a4d23..0000000 --- a/iree/compiler/IR/Interpreter/HLOps.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_INTERPRETER_HLOPS_H_ -#define IREE_COMPILER_IR_INTERPRETER_HLOPS_H_ - -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" - -namespace mlir { -namespace iree_compiler { -namespace IREEInterp { -namespace HL { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Interpreter/HLOps.h.inc" - -} // namespace HL -} // namespace IREEInterp -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_INTERPRETER_HLOPS_H_
diff --git a/iree/compiler/IR/Interpreter/HLOps.td b/iree/compiler/IR/Interpreter/HLOps.td deleted file mode 100644 index 3458b45..0000000 --- a/iree/compiler/IR/Interpreter/HLOps.td +++ /dev/null
@@ -1,658 +0,0 @@ -// IREE high-level interpreter op definitions. -// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in -// normal SSA form. -// -// Through lowering these high-level ops are converted to low-level ops in the -// LLOps.td (iree_ll_interp.*). These map 1:1 with the bytecode, -// accept only MemRef types, and generally use output parameters instead of -// return types. -// -// The source of truth for bytecode opcodes is: -// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h - -#ifdef IREE_INTERPRETER_HL_OPS -#else -#define IREE_INTERPRETER_HL_OPS - -#ifdef IREE_OP_BASE -#else -include "iree/compiler/IR/OpBase.td" -#endif // IREE_OP_BASE - -def IREEInterpHL_Dialect : Dialect { - let name = "iree_hl_interp"; - let cppNamespace = "IREEInterp::HL"; -} - -//===----------------------------------------------------------------------===// -// Base op classes -//===----------------------------------------------------------------------===// - -class IREEInterpHL_Op<string mnemonic, list<OpTrait> traits = []> : - Op<IREEInterpHL_Dialect, mnemonic, traits>; - -class IREEInterpHL_PureOp<string mnemonic, list<OpTrait> traits = []> : - IREEInterpHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; - -//===----------------------------------------------------------------------===// -// High-level interpreter ops -//===----------------------------------------------------------------------===// - -def IREEInterpHL_CallOp : IREEInterpHL_Op<"call"> { - let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>); - let results = (outs Variadic<IREEHL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, FuncOp callee," - "ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState &result, StringRef callee," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpHL_CallIndirectOp : IREEInterpHL_Op<"call_indirect"> { - let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands); - let results = (outs Variadic<IREEHL_MemRef>); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *callee," - "ArrayRef<Value *> operands = {}", [{ - result.operands.push_back(callee); - result.addOperands(operands); - result.addTypes(callee->getType().cast<FunctionType>().getResults()); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - Value *getCallee() { return getOperand(0); } - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpHL_ReturnOp : IREEInterpHL_Op<"return", [Terminator]> { - let arguments = (ins Variadic<IREEHL_MemRef>:$operands); - - let builders = [OpBuilder< - "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] - >]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpHL_BranchOp : IREEInterpHL_Op<"br", [Terminator]> { - let arguments = (ins Variadic<IREEHL_MemRef>:$operands); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Block *dest," - "ArrayRef<Value *> operands = {}", [{ - result.addSuccessor(dest, operands); - }]>]; - - let extraClassDeclaration = [{ - Block *getDest(); - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpHL_CondBranchOp : IREEInterpHL_Op<"cond_br", [Terminator]> { - let arguments = (ins - IREEHL_BoolScalar:$condition, - Variadic<IREEHL_MemRef>:$branchOperands - ); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *condition," - "Block *trueDest, ArrayRef<Value *> trueOperands," - "Block *falseDest, ArrayRef<Value *> falseOperands", [{ - result.addOperands(condition); - result.addSuccessor(trueDest, trueOperands); - result.addSuccessor(falseDest, falseOperands); - }]>]; - - let extraClassDeclaration = [{ - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest() { - return getOperation()->getSuccessor(trueIndex); - } - - /// Return the destination if the condition is false. - Block *getFalseDest() { - return getOperation()->getSuccessor(falseIndex); - } - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands() { - return getOperation()->getNumSuccessorOperands(trueIndex); - } - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); - } - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands() { - return getOperation()->getNumSuccessorOperands(falseIndex); - } - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); - } - - private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpHL_CmpIOp : - IREEInterpHL_PureOp<"cmp_i", [SameOperandsAndResultShape, - AllTypesMatch<["lhs", "rhs"]>]> { - let arguments = (ins - I32Attr:$predicate, - IREEHL_IntMemRef:$lhs, - IREEHL_IntMemRef:$rhs - ); - let results = (outs IREEHL_BoolMemRef); -} - -def IREEInterpHL_CmpFOp : - IREEInterpHL_PureOp<"cmp_f", [SameOperandsAndResultShape, - AllTypesMatch<["lhs", "rhs"]>]> { - let arguments = (ins - I32Attr:$predicate, - IREEHL_FloatMemRef:$lhs, - IREEHL_FloatMemRef:$rhs - ); - let results = (outs IREEHL_BoolMemRef); -} - -// TODO(b/142012496): Add trait that enables DCE but not CSE. -def IREEInterpHL_AllocHeapOp : IREEInterpHL_Op<"alloc_heap"> { - // TODO(benvanik): attributes and args. - let arguments = (ins - Variadic<IREEHL_MemRef>:$dim_pieces - ); - let results = (outs - IREEHL_MemRef - ); -} - -def IREEInterpHL_DiscardOp : IREEInterpHL_Op<"discard"> { - let arguments = (ins IREEHL_MemRef); -} - -def IREEInterpHL_RankOp : IREEInterpHL_PureOp<"rank"> { - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_IntScalar); -} - -def IREEInterpHL_DimOp : IREEInterpHL_PureOp<"dim"> { - // TODO(benvanik) add dim attr (I32Attr:$dim) - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_IntScalar); -} - -def IREEInterpHL_ShapeOp : IREEInterpHL_PureOp<"shape"> { - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_1DIntMemRef); -} - -def IREEInterpHL_LengthOp : IREEInterpHL_PureOp<"length"> { - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_IndexScalar); -} - -def IREEInterpHL_SliceOp : - IREEInterpHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>, - AllTypesMatch<["srcIndices", "lengths"]>]> { - let arguments = (ins - IREEHL_MemRef:$src, - IREEHL_1DIndexMemRef:$srcIndices, - IREEHL_1DIndexMemRef:$lengths - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_CopyOp : IREEInterpHL_Op<"copy", [ - AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, - AllRanksMatch<["src", "dst"]>, - // The checks above are redundant with this one, but they give more specific - // error messages. - AllMatch<[ - Rank<"src">.result, - Rank<"dst">.result, - ElementCount<"srcIndices">.result, - ElementCount<"dstIndices">.result, - ElementCount<"lengths">.result - ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">, - AllElementTypesMatch<["src", "dst"]> -]> { - let arguments = (ins - IREEHL_MemRef:$src, - IREEHL_1DIndexMemRef:$srcIndices, - IREEHL_MemRef:$dst, - IREEHL_1DIndexMemRef:$dstIndices, - IREEHL_1DIndexMemRef:$lengths - ); -} - -def IREEInterpHL_CloneOp : - IREEInterpHL_PureOp<"clone", [SameOperandsAndResultType]> { - let arguments = (ins IREEHL_MemRef:$src); - let results = (outs IREEHL_MemRef); - - let hasFolder = 1; -} - -// A pseudo op provided for convenience. This gets canonicalized to a series of -// copies. -def IREEInterpHL_ConcatOp : IREEInterpHL_PureOp<"concat"> { - let arguments = (ins - Variadic<IREEHL_MemRef>:$srcs, - I32Attr:$dimension - ); - let results = (outs IREEHL_MemRef); - - let hasCanonicalizer = 1; -} - -// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops? -def IREEInterpHL_SplitOp : - IREEInterpHL_PureOp<"split", [SameOperandsAndResultElementType]> { - let arguments = (ins IREEHL_MemRef:$src); - let results = (outs Variadic<IREEHL_MemRef>); -} - -def IREEInterpHL_AssignOp : - IREEInterpHL_PureOp<"assign", [SameOperandsAndResultType]> { - let arguments = (ins IREEHL_MemRef:$src); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_CondAssignOp : - IREEInterpHL_PureOp<"cond_assign", - [AllTypesMatch<["lhs", "rhs", "result"]>]> { - let arguments = (ins - IREEHL_BoolScalar:$cond, - IREEHL_MemRef:$lhs, - IREEHL_MemRef:$rhs - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_ReshapeOp : IREEInterpHL_PureOp<"reshape"> { - let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape); - let results = (outs IREEHL_MemRef); -} - -def IREEInterpHL_SelectOp : - IREEInterpHL_PureOp<"select", [AllTypesMatch<["lhs", "rhs", "result"]>]> { - let arguments = (ins - IREEHL_BoolMemRef:$cond, - IREEHL_MemRef:$lhs, - IREEHL_MemRef:$rhs - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_BroadcastOp : - IREEInterpHL_PureOp<"broadcast", - [AllElementTypesMatch<["operand", "result"]>]> { - let arguments = (ins - IREE_ScalarMemRefOf<[AnyType]>:$operand, - IREEHL_1DIntMemRef:$shape - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_PadOp : - IREEInterpHL_PureOp< - "pad", [AllElementTypesMatch<["src", "result", "padding_value"]>]> { - let arguments = (ins - IREEHL_MemRef:$src, - IREEHL_AnyScalar:$padding_value, - IREEHL_1DIndexMemRef:$edge_padding_low, - IREEHL_1DIndexMemRef:$edge_padding_high, - IREEHL_1DIndexMemRef:$interior_padding - ); - - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_TileOp : - IREEInterpHL_PureOp<"tile", [AllElementTypesMatch<["operand", "result"]>]> { - let arguments = (ins - IREEHL_MemRef:$operand, - IREEHL_1DIntMemRef:$shape - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_TransposeOp : - IREEInterpHL_PureOp<"transpose", [ - AllElementTypesMatch<["operand", "result"]>, - AllRanksMatch<["operand", "result"]>, - AllElementCountsMatch<["operand", "result"]> - ]> { - let arguments = (ins - IREEHL_MemRef:$operand, - IREEHL_1DIntMemRef:$permutation - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREEInterpHL_ReverseOp : - IREEInterpHL_PureOp<"reverse", [AllTypesMatch<["operand", "result"]>]> { - let arguments = (ins - IREEHL_MemRef:$operand, - IREEHL_1DIntMemRef:$dims - ); - let results = (outs IREEHL_MemRef:$result); -} - -class IREEInterpHL_UnaryElementwiseOp<string mnemonic, Type type, - list<OpTrait> traits = []> : - IREEInterpHL_PureOp<mnemonic, - !listconcat(traits, [SameOperandsAndResultType])> { - let arguments = (ins type); - let results = (outs type); -} - -class IREEInterpHL_UnaryElementwiseFloatOp<string mnemonic, - list<OpTrait> traits = []> : - IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_FloatMemRef, traits>; - -class IREEInterpHL_UnaryElementwiseIntOp<string mnemonic, - list<OpTrait> traits = []> : - IREEInterpHL_UnaryElementwiseOp<mnemonic, IREEHL_IntMemRef, traits>; - -class IREEInterpHL_BinaryElementwiseOp<string mnemonic, Type type, - list<OpTrait> traits> : - IREEInterpHL_PureOp<mnemonic, - !listconcat(traits, [SameOperandsAndResultType])> { - let arguments = (ins type:$lhs, type:$rhs); - let results = (outs type); -} - -class IREEInterpHL_BinaryElementwiseFloatOp<string mnemonic, - list<OpTrait> traits = []> : - IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_FloatMemRef, - traits>; - -class IREEInterpHL_BinaryElementwiseIntOp<string mnemonic, - list<OpTrait> traits = []> : - IREEInterpHL_BinaryElementwiseOp<mnemonic, IREEHL_IntMemRef, - traits>; - -class IREEInterpHL_TernaryOp<string mnemonic, - Type type = IREEHL_MemRef, - list<OpTrait> traits = []> : - IREEInterpHL_PureOp<mnemonic, traits> { - let arguments = (ins type:$a, type:$b, type:$c); - let results = (outs type); -} - -// TODO(benvanik): add traits for broadcasting support. - -def IREEInterpHL_NotOp : IREEInterpHL_UnaryElementwiseIntOp<"not">; -def IREEInterpHL_AndOp : IREEInterpHL_BinaryElementwiseIntOp<"and">; -def IREEInterpHL_OrOp : IREEInterpHL_BinaryElementwiseIntOp<"or">; -def IREEInterpHL_XorOp : IREEInterpHL_BinaryElementwiseIntOp<"xor">; -def IREEInterpHL_ShiftLeftOp : IREEInterpHL_BinaryElementwiseIntOp<"sll">; -def IREEInterpHL_ShiftRightLogicalOp : IREEInterpHL_BinaryElementwiseIntOp<"srl">; -def IREEInterpHL_ShiftRightArithmeticOp : IREEInterpHL_BinaryElementwiseIntOp<"sra">; - -def IREEInterpHL_AddIOp : IREEInterpHL_BinaryElementwiseIntOp<"add_i">; -def IREEInterpHL_AddFOp : IREEInterpHL_BinaryElementwiseFloatOp<"add_f">; -def IREEInterpHL_SubIOp : IREEInterpHL_BinaryElementwiseIntOp<"sub_i">; -def IREEInterpHL_SubFOp : IREEInterpHL_BinaryElementwiseFloatOp<"sub_f">; -def IREEInterpHL_AbsIOp : IREEInterpHL_UnaryElementwiseIntOp<"abs_i">; -def IREEInterpHL_AbsFOp : IREEInterpHL_UnaryElementwiseFloatOp<"abs_f">; -def IREEInterpHL_MulIOp : IREEInterpHL_BinaryElementwiseIntOp<"mul_i">; -def IREEInterpHL_MulFOp : IREEInterpHL_BinaryElementwiseFloatOp<"mul_f">; -def IREEInterpHL_DivISOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_s">; -def IREEInterpHL_DivIUOp : IREEInterpHL_BinaryElementwiseIntOp<"div_i_u">; -def IREEInterpHL_DivFOp : IREEInterpHL_BinaryElementwiseFloatOp<"div_f">; -def IREEInterpHL_MulAddIOp : IREEInterpHL_TernaryOp<"madd_i", IREEHL_IntMemRef>; -def IREEInterpHL_MulAddFOp : IREEInterpHL_TernaryOp<"madd_f", IREEHL_FloatMemRef>; -def IREEInterpHL_ExpFOp : IREEInterpHL_UnaryElementwiseFloatOp<"exp_f">; -def IREEInterpHL_LogFOp : IREEInterpHL_UnaryElementwiseFloatOp<"log_f">; -def IREEInterpHL_RsqrtFOp : IREEInterpHL_UnaryElementwiseFloatOp<"rsqrt_f">; -def IREEInterpHL_CosFOp : IREEInterpHL_UnaryElementwiseFloatOp<"cos_f">; -def IREEInterpHL_SinFOp : IREEInterpHL_UnaryElementwiseFloatOp<"sin_f">; -def IREEInterpHL_TanhFOp : IREEInterpHL_UnaryElementwiseFloatOp<"tanh_f">; -def IREEInterpHL_Atan2FOp : IREEInterpHL_UnaryElementwiseFloatOp<"atan2_f">; - -def IREEInterpHL_MinISOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_s">; -def IREEInterpHL_MinIUOp : IREEInterpHL_BinaryElementwiseIntOp<"min_i_u">; -def IREEInterpHL_MinFOp : IREEInterpHL_BinaryElementwiseFloatOp<"min_f">; -def IREEInterpHL_MaxISOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_s">; -def IREEInterpHL_MaxIUOp : IREEInterpHL_BinaryElementwiseIntOp<"max_i_u">; -def IREEInterpHL_MaxFOp : IREEInterpHL_BinaryElementwiseFloatOp<"max_f">; -def IREEInterpHL_ClampFOp : IREEInterpHL_TernaryOp<"clamp_f", IREEHL_FloatMemRef>; -def IREEInterpHL_FloorFOp : IREEInterpHL_UnaryElementwiseFloatOp<"floor_f">; -def IREEInterpHL_CeilFOp : IREEInterpHL_UnaryElementwiseFloatOp<"ceil_f">; - -class IREEInterpHL_ConversionOp<string mnemonic, Type inputType, - Type outputType> : - IREEInterpHL_PureOp<mnemonic, [SameOperandsAndResultShape]> { - let arguments = (ins inputType); - let results = (outs outputType); -} - -def IREEInterpHL_ConvertSSOp : - IREEInterpHL_ConversionOp<"convert_s_s", IREEHL_IntMemRef, - IREEHL_IntMemRef>; -def IREEInterpHL_ConvertSUOp : - IREEInterpHL_ConversionOp<"convert_s_u", IREEHL_IntMemRef, - IREEHL_IntMemRef>; -def IREEInterpHL_ConvertSFOp : - IREEInterpHL_ConversionOp<"convert_s_f", IREEHL_IntMemRef, - IREEHL_FloatMemRef>; - -def IREEInterpHL_ConvertUSOp : - IREEInterpHL_ConversionOp<"convert_u_s", IREEHL_IntMemRef, - IREEHL_IntMemRef>; -def IREEInterpHL_ConvertUUOp : - IREEInterpHL_ConversionOp<"convert_u_u", IREEHL_IntMemRef, - IREEHL_IntMemRef>; -def IREEInterpHL_ConvertUFOp : - IREEInterpHL_ConversionOp<"convert_u_f", IREEHL_IntMemRef, - IREEHL_FloatMemRef>; - -def IREEInterpHL_ConvertFSOp : - IREEInterpHL_ConversionOp<"convert_f_s", IREEHL_FloatMemRef, - IREEHL_IntMemRef>; -def IREEInterpHL_ConvertFUOp : - IREEInterpHL_ConversionOp<"convert_f_u", IREEHL_FloatMemRef, - IREEHL_IntMemRef>; -def IREEInterpHL_ConvertFFOp : - IREEInterpHL_ConversionOp<"convert_f_f", IREEHL_FloatMemRef, - IREEHL_FloatMemRef>; - -def IREEInterpHL_MatMulIOp : - IREEInterpHL_PureOp<"matmul_i", - [AllElementTypesMatch<["lhs", "rhs", "result"]>]> { - let arguments = (ins - IREEHL_IntMemRef:$lhs, - IREEHL_IntMemRef:$rhs, - IREEHL_IntMemRef:$multiplier_mantissa, - IREEHL_IntMemRef:$multiplier_exponent - ); - let results = (outs IREEHL_IntMemRef:$result); -} -def IREEInterpHL_MatMulFOp : - IREEInterpHL_PureOp<"matmul_f", [SameOperandsAndResultElementType]> { - let arguments = (ins - IREEHL_FloatMemRef:$lhs, - IREEHL_FloatMemRef:$rhs - ); - let results = (outs IREEHL_FloatMemRef); -} - -def IREEInterpHL_ReduceSumIOp : - IREEInterpHL_PureOp<"reduce_sum_i", - [AllElementTypesMatch<["src", "result", "init"]>]> { - let arguments = (ins - IREEHL_IntMemRef:$src, - IREEHL_IntMemRef:$init, - I32Attr:$dimension - ); - let results = (outs IREEHL_IntMemRef:$result); -} -def IREEInterpHL_ReduceSumFOp : - IREEInterpHL_PureOp<"reduce_sum_f", - [AllElementTypesMatch<["src", "result", "init"]>]> { - let arguments = (ins - IREEHL_FloatMemRef:$src, - IREEHL_FloatMemRef:$init, - I32Attr:$dimension - ); - let results = (outs IREEHL_FloatMemRef:$result); -} -def IREEInterpHL_ReduceMinIOp : - IREEInterpHL_PureOp<"reduce_min_i", - [AllElementTypesMatch<["src", "result", "init"]>]> { - let arguments = (ins - IREEHL_IntMemRef:$src, - IREEHL_IntMemRef:$init, - I32Attr:$dimension - ); - let results = (outs IREEHL_IntMemRef:$result); -} -def IREEInterpHL_ReduceMinFOp : - IREEInterpHL_PureOp<"reduce_min_f", - [AllElementTypesMatch<["src", "result", "init"]>]> { - let arguments = (ins - IREEHL_FloatMemRef:$src, - IREEHL_FloatMemRef:$init, - I32Attr:$dimension - ); - let results = (outs IREEHL_FloatMemRef:$result); -} -def IREEInterpHL_ReduceMaxIOp : - IREEInterpHL_PureOp<"reduce_max_i", - [AllElementTypesMatch<["src", "result", "init"]>]> { - let arguments = (ins - IREEHL_IntMemRef:$src, - IREEHL_IntMemRef:$init, - I32Attr:$dimension - ); - let results = (outs IREEHL_IntMemRef:$result); -} -def IREEInterpHL_ReduceMaxFOp : - IREEInterpHL_PureOp<"reduce_max_f", - [AllElementTypesMatch<["src", "result", "init"]>]> { - let arguments = (ins - IREEHL_FloatMemRef:$src, - IREEHL_FloatMemRef:$init, - I32Attr:$dimension - ); - let results = (outs IREEHL_FloatMemRef:$result); -} - -def IREEInterpHL_TraceOp : IREEInterpHL_Op<"trace"> { - let arguments = (ins Variadic<IREEHL_MemRef>:$srcs); -} - -def IREEInterpHL_CondBreakOp : IREEInterpHL_Op<"cond_break"> { - let arguments = (ins IREEHL_BoolScalar:$cond); -} - -def IREEInterpHL_BreakOp : IREEInterpHL_Op<"break">; - -#endif // IREE_INTERPRETER_HL_OPS
diff --git a/iree/compiler/IR/Interpreter/LLDialect.cpp b/iree/compiler/IR/Interpreter/LLDialect.cpp deleted file mode 100644 index ffe4289..0000000 --- a/iree/compiler/IR/Interpreter/LLDialect.cpp +++ /dev/null
@@ -1,34 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Interpreter/LLDialect.h" - -#include "iree/compiler/IR/Interpreter/LLOps.h" -#include "llvm/Support/SourceMgr.h" - -namespace mlir { -namespace iree_compiler { - -IREELLInterpreterDialect::IREELLInterpreterDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/Interpreter/LLOps.cpp.inc" - >(); -} - -static DialectRegistration<IREELLInterpreterDialect> iree_ll_interp_dialect; - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/LLOps.cpp b/iree/compiler/IR/Interpreter/LLOps.cpp deleted file mode 100644 index 2fbbb64..0000000 --- a/iree/compiler/IR/Interpreter/LLOps.cpp +++ /dev/null
@@ -1,228 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Interpreter/LLOps.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/OpImplementation.h" - -namespace mlir { -namespace iree_compiler { -namespace IREEInterp { -namespace LL { - -//===----------------------------------------------------------------------===// -// iree_ll_interp.call -//===----------------------------------------------------------------------===// - -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector<OpAsmParser::OperandType, 4> operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(state.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), state.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - state.operands)) { - return failure(); - } - return success(); -} - -static void printCallOp(OpAsmPrinter &p, CallOp op) { - p << "iree_ll_interp.call " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); -} - -FunctionType CallOp::getCalleeType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_ll_interp.call_import -//===----------------------------------------------------------------------===// - -static ParseResult parseCallImportOp(OpAsmParser &parser, - OperationState &state) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector<OpAsmParser::OperandType, 4> operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(state.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), state.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - state.operands)) { - return failure(); - } - return success(); -} - -static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) { - p << "iree_ll_interp.call_import " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); -} - -FunctionType CallImportOp::getCalleeType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_ll_interp.call_indirect -//===----------------------------------------------------------------------===// - -static ParseResult parseCallIndirectOp(OpAsmParser &parser, - OperationState &result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector<OpAsmParser::OperandType, 4> operands; - return failure( - parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.resolveOperand(callee, calleeType, result.operands) || - parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result.operands) || - parser.addTypesToList(calleeType.getResults(), result.types)); -} - -static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { - p << "iree_ll_interp.call_indirect "; - p.printOperand(op.getCallee()); - p << '('; - auto operandRange = op.getOperands(); - p.printOperands(++operandRange.begin(), operandRange.end()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCallee()->getType(); -} - -//===----------------------------------------------------------------------===// -// iree_ll_interp.return -//===----------------------------------------------------------------------===// - -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type, 2> types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, state.operands)); -} - -static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p << "iree_ll_interp.return"; - if (op.getNumOperands() > 0) { - p << ' '; - p.printOperands(op.operand_begin(), op.operand_end()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } -} - -//===----------------------------------------------------------------------===// -// iree_ll_interp.br -//===----------------------------------------------------------------------===// - -static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { - Block *dest; - SmallVector<Value *, 4> destOperands; - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - return success(); -} - -static void printBranchOp(OpAsmPrinter &p, BranchOp op) { - p << "iree_ll_interp.br "; - p.printSuccessorAndUseList(op.getOperation(), 0); -} - -Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } - -void BranchOp::setDest(Block *block) { - return getOperation()->setSuccessor(block, 0); -} - -void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); -} - -//===----------------------------------------------------------------------===// -// iree_ll_interp.cond_br -//===----------------------------------------------------------------------===// - -static ParseResult parseCondBranchOp(OpAsmParser &parser, - OperationState &result) { - SmallVector<Value *, 4> destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser.getBuilder().getI1Type(); - if (parser.parseOperand(condInfo) || parser.parseComma() || - parser.resolveOperand(condInfo, int1Ty, result.operands)) { - return parser.emitError(parser.getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - - return success(); -} - -static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { - p << "iree_ll_interp.cond_br "; - p.printOperand(op.getCondition()); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Interpreter/LLOps.cpp.inc" - -} // namespace LL -} // namespace IREEInterp -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/LLOps.h b/iree/compiler/IR/Interpreter/LLOps.h deleted file mode 100644 index 578b33f..0000000 --- a/iree/compiler/IR/Interpreter/LLOps.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_INTERPRETER_LLOPS_H_ -#define IREE_COMPILER_IR_INTERPRETER_LLOPS_H_ - -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" - -namespace mlir { -namespace iree_compiler { -namespace IREEInterp { -namespace LL { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Interpreter/LLOps.h.inc" - -} // namespace LL -} // namespace IREEInterp -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_INTERPRETER_LLOPS_H_
diff --git a/iree/compiler/IR/Interpreter/LLOps.td b/iree/compiler/IR/Interpreter/LLOps.td deleted file mode 100644 index b4ec592..0000000 --- a/iree/compiler/IR/Interpreter/LLOps.td +++ /dev/null
@@ -1,633 +0,0 @@ -// IREE low-level interpreter op definitions. -// These map 1:1 with the bytecode, accept only MemRef types and generally use -// output parameters instead of return types. -// -// The source of truth for bytecode opcodes is: -// https://github.com/google/iree/tree/master/iree/schemas/bytecode/interpreter_bytecode_v0.h - -#ifdef IREE_INTERPRETER_LL_OPS -#else -#define IREE_INTERPRETER_LL_OPS - -#ifdef IREE_OP_BASE -#else -include "iree/compiler/IR/OpBase.td" -#endif // IREE_OP_BASE - -def IREEInterpLL_Dialect : Dialect { - let name = "iree_ll_interp"; - let cppNamespace = "IREEInterp::LL"; -} - -//===----------------------------------------------------------------------===// -// Base op classes -//===----------------------------------------------------------------------===// - -class IREEInterpLL_Op<string mnemonic, list<OpTrait> traits = []> : - Op<IREEInterpLL_Dialect, mnemonic, traits>; - -class IREEInterpLL_PureOp<string mnemonic, list<OpTrait> traits = []> : - IREEInterpLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; - -class IREEInterpLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef, - list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> { - let arguments = (ins type:$input, type:$dst); -} - -class IREEInterpLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef, - list<OpTrait> traits = []> : IREEInterpLL_Op<mnemonic, traits> { - let arguments = (ins type:$lhs, type:$rhs, type:$dst); -} - -class IREEInterpLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef, - list<OpTrait> traits = []> - : IREEInterpLL_Op<mnemonic, traits> { - let arguments = (ins type : $a, type : $b, type : $c, type : $dst); -} - -//===----------------------------------------------------------------------===// -// Low-level interpreter ops -//===----------------------------------------------------------------------===// - -// TODO(benvanik): value attribute. -def IREEInterpLL_ConstantOp : IREEInterpLL_PureOp<"constant"> { - let results = (outs IREELL_MemRef); -} - -def IREEInterpLL_CallOp : IREEInterpLL_Op<"call"> { - let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, FuncOp callee," - "ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState &result, StringRef callee," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -// TODO(benvanik): add verifier that target isExternal. -def IREEInterpLL_CallImportOp : IREEInterpLL_Op<"call_import"> { - let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, FuncOp callee," - "ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState &result, StringRef callee," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpLL_CallIndirectOp : IREEInterpLL_Op<"call_indirect"> { - let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *callee," - "ArrayRef<Value *> operands = {}", [{ - result.operands.push_back(callee); - result.addOperands(operands); - result.addTypes(callee->getType().cast<FunctionType>().getResults()); - }]>]; - - let extraClassDeclaration = [{ - Value *getCallee() { return getOperand(0); } - - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpLL_ReturnOp : IREEInterpLL_Op<"return", [Terminator]> { - let arguments = (ins Variadic<IREELL_MemRef>:$operands); - - let builders = [OpBuilder< - "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] - >]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpLL_BranchOp : IREEInterpLL_Op<"br", [Terminator]> { - let arguments = (ins Variadic<IREELL_MemRef>:$operands); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Block *dest," - "ArrayRef<Value *> operands = {}", [{ - result.addSuccessor(dest, operands); - }]>]; - - let extraClassDeclaration = [{ - Block *getDest(); - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpLL_CondBranchOp : IREEInterpLL_Op<"cond_br", [Terminator]> { - let arguments = (ins - IREELL_BoolScalar:$condition, - Variadic<IREELL_MemRef>:$branchOperands - ); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *condition," - "Block *trueDest, ArrayRef<Value *> trueOperands," - "Block *falseDest, ArrayRef<Value *> falseOperands", [{ - result.addOperands(condition); - result.addSuccessor(trueDest, trueOperands); - result.addSuccessor(falseDest, falseOperands); - }]>]; - - let extraClassDeclaration = [{ - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest() { - return getOperation()->getSuccessor(trueIndex); - } - - /// Return the destination if the condition is false. - Block *getFalseDest() { - return getOperation()->getSuccessor(falseIndex); - } - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands() { - return getOperation()->getNumSuccessorOperands(trueIndex); - } - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); - } - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands() { - return getOperation()->getNumSuccessorOperands(falseIndex); - } - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); - } - - private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREEInterpLL_CmpIOp : IREEInterpLL_Op<"cmp_i"> { - let arguments = (ins - I32Attr:$predicate, - IREELL_IntMemRef:$lhs, - IREELL_IntMemRef:$rhs, - IREELL_BoolMemRef:$dst - ); -} - -def IREEInterpLL_CmpFOp : IREEInterpLL_Op<"cmp_f"> { - let arguments = (ins - I32Attr:$predicate, - IREELL_FloatMemRef:$lhs, - IREELL_FloatMemRef:$rhs, - IREELL_BoolMemRef:$dst - ); -} - -def IREEInterpLL_AllocStaticOp : IREEInterpLL_PureOp<"alloc_static"> { - // TODO(benvanik): attributes and args. - let results = (outs IREELL_MemRef); -} - -def IREEInterpLL_AllocStackOp : IREEInterpLL_PureOp<"alloc_stack"> { - // TODO(benvanik): atributes and args. - let arguments = (ins - Variadic<IREELL_MemRef>:$dim_pieces - ); - let results = (outs - IREELL_MemRef - ); -} - -def IREEInterpLL_AllocStackInitOp : IREEInterpLL_PureOp<"alloc_stack_init"> { - // TODO(benvanik): attributes and args. - let arguments = (ins - Variadic<IREELL_MemRef>:$dim_pieces - ); - let results = (outs - IREELL_MemRef - ); -} - -// TODO(b/142012496): Add trait that enables DCE but not CSE. -def IREEInterpLL_AllocHeapOp : IREEInterpLL_Op<"alloc_heap"> { - // TODO(benvanik): attributes and args. - let arguments = (ins - Variadic<IREELL_MemRef>:$dim_pieces - ); - let results = (outs - IREELL_MemRef - ); -} - -def IREEInterpLL_DiscardOp : IREEInterpLL_Op<"discard"> { - let arguments = (ins IREELL_MemRef); -} - -def IREEInterpLL_RankOp : IREEInterpLL_Op<"rank"> { - let arguments = (ins - IREELL_MemRef:$input, - IREELL_I32Scalar:$dst - ); -} - -def IREEInterpLL_DimOp : IREEInterpLL_Op<"dim"> { - // TODO(benvanik) add dim attr (I32Attr:$dim) - let arguments = (ins - IREELL_MemRef:$input, - IREELL_I32Scalar:$dst - ); -} - -def IREEInterpLL_ShapeOp : IREEInterpLL_Op<"shape"> { - let arguments = (ins - IREELL_MemRef:$input, - IREELL_I32MemRef:$dst - ); -} - -def IREEInterpLL_LengthOp : IREEInterpLL_Op<"length"> { - let arguments = (ins - IREELL_MemRef:$input, - IREELL_I32Scalar:$dst - ); -} - - -def IREEInterpLL_DynamicSliceOp : IREEInterpLL_PureOp<"dynamic_slice"> { - let arguments = (ins - IREELL_MemRef:$src, - IREELL_1DIndexMemRef:$srcIndices, - IREELL_1DIndexMemRef:$lengths - ); - let results = (outs - IREELL_MemRef - ); -} - -// TODO(benvanik): add attribute requirements/types. -def IREEInterpLL_StaticSliceOp : - IREEInterpLL_PureOp<"static_slice", [SameOperandsAndResultElementType]> { - let arguments = (ins IREELL_MemRef:$src); - let results = (outs IREELL_MemRef); -} - -def IREEInterpLL_DynamicCopyOp : IREEInterpLL_Op<"dynamic_copy", [ - AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, -]> { - let arguments = (ins - IREELL_MemRef:$src, - IREELL_1DIndexMemRef:$srcIndices, - IREELL_MemRef:$dst, - IREELL_1DIndexMemRef:$dstIndices, - IREELL_1DIndexMemRef:$lengths - ); -} - -def IREEInterpLL_StaticCopyOp : IREEInterpLL_Op<"static_copy", [ - AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, -]> { - let arguments = (ins - IREELL_MemRef:$src, - I32ElementsAttr:$srcIndices, - IREELL_MemRef:$dst, - I32ElementsAttr:$dstIndices, - I32ElementsAttr:$lengths - ); -} - -def IREEInterpLL_CloneOp : - IREEInterpLL_PureOp<"clone", [SameOperandsAndResultType]> { - let arguments = (ins IREELL_MemRef:$src); - let results = (outs IREELL_MemRef); -} - -// TODO(benvanik): add split dim/size/etc. Maybe make multiple ops? -def IREEInterpLL_SplitOp : IREEInterpLL_PureOp<"split"> { - let arguments = (ins - IREELL_MemRef:$src - ); - let results = (outs - Variadic<IREELL_MemRef> - ); -} - -def IREEInterpLL_AssignOp : - IREEInterpLL_Op<"assign", [SameOperandsAndResultType]> { - let arguments = (ins IREELL_MemRef:$src); - let results = (outs IREELL_MemRef); -} - -def IREEInterpLL_CondAssignOp : IREEInterpLL_Op<"cond_assign"> { - let arguments = (ins - IREELL_BoolScalar:$cond, - IREELL_MemRef:$lhs, - IREELL_MemRef:$rhs - ); - let results = (outs - IREELL_MemRef - ); -} - -def IREEInterpLL_ReshapeOp : IREEInterpLL_Op<"reshape"> { - let arguments = (ins - IREELL_MemRef:$input, - IREELL_1DIntMemRef:$shape - ); - let results = (outs - IREELL_MemRef - ); -} - -def IREEInterpLL_SelectOp : IREEInterpLL_Op<"select"> { - let arguments = (ins - IREELL_MemRef:$cond, - IREELL_MemRef:$lhs, - IREELL_MemRef:$rhs, - IREELL_MemRef:$dst - ); -} - -def IREEInterpLL_PadOp : - IREEInterpLL_Op< - "pad", [AllElementTypesMatch<["src", "dst", "padding_value"]>]> { - let arguments = (ins - IREELL_MemRef:$src, - IREELL_ElementScalar:$padding_value, - IREELL_1DIndexMemRef:$edge_padding_low, - IREELL_1DIndexMemRef:$edge_padding_high, - IREELL_1DIndexMemRef:$interior_padding, - IREELL_MemRef:$dst - ); -} - -def IREEInterpLL_TransposeOp : IREEInterpLL_BinaryOp<"transpose">; - -def IREEInterPLL_ReverseOp : IREEInterpLL_BinaryOp<"reverse">; - -def IREEInterpLL_BroadcastOp : IREEInterpLL_BinaryOp<"broadcast">; - -def IREEInterpLL_TileOp : IREEInterpLL_BinaryOp<"tile">; - -// TODO(benvanik): add traits for broadcasting support. - -def IREEInterpLL_NotOp : IREEInterpLL_UnaryOp<"not">; -def IREEInterpLL_AndOp : IREEInterpLL_BinaryOp<"and">; -def IREEInterpLL_OrOp : IREEInterpLL_BinaryOp<"or">; -def IREEInterpLL_XorOp : IREEInterpLL_BinaryOp<"xor">; -def IREEInterpLL_ShiftLeftOp : IREEInterpLL_BinaryOp<"sll">; -def IREEInterpLL_ShiftRightLogicalOp : IREEInterpLL_BinaryOp<"srl">; -def IREEInterpLL_ShiftRightArithmeticOp : IREEInterpLL_BinaryOp<"sra">; - -def IREEInterpLL_AddIOp : IREEInterpLL_BinaryOp<"add_i", IREELL_IntMemRef>; -def IREEInterpLL_AddFOp : IREEInterpLL_BinaryOp<"add_f", IREELL_FloatMemRef>; -def IREEInterpLL_SubIOp : IREEInterpLL_BinaryOp<"sub_i", IREELL_IntMemRef>; -def IREEInterpLL_SubFOp : IREEInterpLL_BinaryOp<"sub_f", IREELL_FloatMemRef>; -def IREEInterpLL_AbsIOp : IREEInterpLL_UnaryOp<"abs_i", IREELL_IntMemRef>; -def IREEInterpLL_AbsFOp : IREEInterpLL_UnaryOp<"abs_f", IREELL_FloatMemRef>; -def IREEInterpLL_MulIOp : IREEInterpLL_BinaryOp<"mul_i", IREELL_IntMemRef>; -def IREEInterpLL_MulFOp : IREEInterpLL_BinaryOp<"mul_f", IREELL_FloatMemRef>; -def IREEInterpLL_DivISOp : IREEInterpLL_BinaryOp<"div_i_s", IREELL_IntMemRef>; -def IREEInterpLL_DivIUOp : IREEInterpLL_BinaryOp<"div_i_u", IREELL_IntMemRef>; -def IREEInterpLL_DivFOp : IREEInterpLL_BinaryOp<"div_f", IREELL_FloatMemRef>; -def IREEInterpLL_MulAddIOp : IREEInterpLL_BinaryOp<"madd_i", IREELL_IntMemRef>; -def IREEInterpLL_MulAddFOp : IREEInterpLL_BinaryOp<"madd_f", IREELL_FloatMemRef>; -def IREEInterpLL_ExpFOp : IREEInterpLL_UnaryOp<"exp_f", IREELL_FloatMemRef>; -def IREEInterpLL_LogFOp : IREEInterpLL_UnaryOp<"log_f", IREELL_FloatMemRef>; -def IREEInterpLL_RsqrtFOp : IREEInterpLL_UnaryOp<"rsqrt_f", IREELL_FloatMemRef>; -def IREEInterpLL_CosFOp : IREEInterpLL_UnaryOp<"cos_f", IREELL_FloatMemRef>; -def IREEInterpLL_SinFOp : IREEInterpLL_UnaryOp<"sin_f", IREELL_FloatMemRef>; -def IREEInterpLL_TanhFOp : IREEInterpLL_UnaryOp<"tanh_f", IREELL_FloatMemRef>; -def IREEInterpLL_Atan2FOp : IREEInterpLL_UnaryOp<"atan2_f", IREELL_FloatMemRef>; - -def IREEInterpLL_MinISOp : IREEInterpLL_BinaryOp<"min_i_s", IREELL_IntMemRef>; -def IREEInterpLL_MinIUOp : IREEInterpLL_BinaryOp<"min_i_u", IREELL_IntMemRef>; -def IREEInterpLL_MinFOp : IREEInterpLL_BinaryOp<"min_f", IREELL_FloatMemRef>; -def IREEInterpLL_MaxISOp : IREEInterpLL_BinaryOp<"max_i_s", IREELL_IntMemRef>; -def IREEInterpLL_MaxIUOp : IREEInterpLL_BinaryOp<"max_i_u", IREELL_IntMemRef>; -def IREEInterpLL_MaxFOp : IREEInterpLL_BinaryOp<"max_f", IREELL_FloatMemRef>; -def IREEInterpLL_ClampFOp : IREEInterpLL_TernaryOp<"clamp_f", IREELL_FloatMemRef>; -def IREEInterpLL_FloorFOp : IREEInterpLL_UnaryOp<"floor_f", IREELL_FloatMemRef>; -def IREEInterpLL_CeilFOp : IREEInterpLL_UnaryOp<"ceil_f", IREELL_FloatMemRef>; - -def IREEInterpLL_ConvertSSOp : IREEInterpLL_UnaryOp<"convert_s_s", IREELL_MemRef>; -def IREEInterpLL_ConvertSUOp : IREEInterpLL_UnaryOp<"convert_s_u", IREELL_MemRef>; -def IREEInterpLL_ConvertSFOp : IREEInterpLL_UnaryOp<"convert_s_f", IREELL_MemRef>; - -def IREEInterpLL_ConvertUSOp : IREEInterpLL_UnaryOp<"convert_u_s", IREELL_MemRef>; -def IREEInterpLL_ConvertUUOp : IREEInterpLL_UnaryOp<"convert_u_u", IREELL_MemRef>; -def IREEInterpLL_ConvertUFOp : IREEInterpLL_UnaryOp<"convert_u_f", IREELL_MemRef>; - -def IREEInterpLL_ConvertFSOp : IREEInterpLL_UnaryOp<"convert_f_s", IREELL_MemRef>; -def IREEInterpLL_ConvertFUOp : IREEInterpLL_UnaryOp<"convert_f_u", IREELL_MemRef>; -def IREEInterpLL_ConvertFFOp : IREEInterpLL_UnaryOp<"convert_f_f", IREELL_MemRef>; - -def IREEInterpLL_MatMulIOp : IREEInterpLL_Op<"matmul_i"> { - let arguments = (ins - IREELL_IntMemRef:$lhs, - IREELL_IntMemRef:$rhs, - IREELL_IntMemRef:$multiplier_mantissa, - IREELL_IntMemRef:$multiplier_exponent, - IREELL_IntMemRef:$dst - ); -} -def IREEInterpLL_MatMulFOp : IREEInterpLL_Op<"matmul_f"> { - let arguments = (ins - IREELL_FloatMemRef:$lhs, - IREELL_FloatMemRef:$rhs, - IREELL_FloatMemRef:$dst - ); -} - -def IREEInterpLL_ReduceSumIOp : IREEInterpLL_Op<"reduce_sum_i"> { - let arguments = (ins - IREELL_IntMemRef:$src, - IREELL_IntMemRef:$init, - I32Attr:$dimension, - IREELL_IntMemRef:$dst - ); -} -def IREEInterpLL_ReduceSumFOp : IREEInterpLL_Op<"reduce_sum_f"> { - let arguments = (ins - IREELL_FloatMemRef:$src, - IREELL_FloatMemRef:$init, - I32Attr:$dimension, - IREELL_FloatMemRef:$dst - ); -} - -def IREEInterpLL_ReduceMinIOp : IREEInterpLL_Op<"reduce_min_i"> { - let arguments = (ins - IREELL_IntMemRef:$src, - IREELL_IntMemRef:$init, - I32Attr:$dimension, - IREELL_IntMemRef:$dst - ); -} -def IREEInterpLL_ReduceMinFOp : IREEInterpLL_Op<"reduce_min_f"> { - let arguments = (ins - IREELL_FloatMemRef:$src, - IREELL_FloatMemRef:$init, - I32Attr:$dimension, - IREELL_FloatMemRef:$dst - ); -} - -def IREEInterpLL_ReduceMaxIOp : IREEInterpLL_Op<"reduce_max_i"> { - let arguments = (ins - IREELL_IntMemRef:$src, - IREELL_IntMemRef:$init, - I32Attr:$dimension, - IREELL_IntMemRef:$dst - ); -} -def IREEInterpLL_ReduceMaxFOp : IREEInterpLL_Op<"reduce_max_f"> { - let arguments = (ins - IREELL_FloatMemRef:$src, - IREELL_FloatMemRef:$init, - I32Attr:$dimension, - IREELL_FloatMemRef:$dst - ); -} - -def IREEInterpLL_TraceOp : IREEInterpLL_Op<"trace"> { - let arguments = (ins - Variadic<IREELL_MemRef>:$srcs - ); -} - -def IREEInterpLL_CondBreakOp : IREEInterpLL_Op<"cond_break"> { - let arguments = (ins - IREELL_BoolScalar:$cond - ); -} - -def IREEInterpLL_BreakOp : IREEInterpLL_Op<"break">; - -#endif // IREE_INTERPRETER_LL_OPS
diff --git a/iree/compiler/IR/Interpreter/OpWriters.cpp b/iree/compiler/IR/Interpreter/OpWriters.cpp deleted file mode 100644 index 2e07793..0000000 --- a/iree/compiler/IR/Interpreter/OpWriters.cpp +++ /dev/null
@@ -1,261 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Interpreter/OpWriters.h" - -#include "iree/compiler/IR/Interpreter/LLOps.h" -#include "iree/compiler/Serialization/BytecodeWriter.h" -#include "iree/compiler/Utils/Macros.h" -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/TypeUtilities.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -//===----------------------------------------------------------------------===// -// Sequencer ops -//===----------------------------------------------------------------------===// - -LogicalResult writeOp(IREEInterp::LL::ConstantOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConstant)); - auto memrefType = op.getType().dyn_cast<MemRefType>(); - if (!memrefType) { - return op.emitOpError() - << "Constant has an unsupported type; must be a memref: " - << op.getType(); - } - RETURN_IF_FAILURE(writer->WriteConstant(memrefType, op.getAttr("value"))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::CallOp op, BytecodeWriter *writer) { - auto module = op.getOperation()->getParentOfType<ModuleOp>(); - auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); - // TODO(benvanik): transforms to convert Call->CallImport. - // TODO(benvanik): switch with kCallTail if attr exists. - if (callee.isExternal()) { - RETURN_IF_FAILURE( - writer->WriteOpcode(iree::InterpreterOpcode::kCallImport)); - } else { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCall)); - } - RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::CallImportOp op, BytecodeWriter *writer) { - auto module = op.getOperation()->getParentOfType<ModuleOp>(); - auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); - // TODO(benvanik): switch with kCallTail if attr exists. - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCallImport)); - RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::CallIndirectOp op, - BytecodeWriter *writer) { - RETURN_IF_FAILURE( - writer->WriteOpcode(iree::InterpreterOpcode::kCallIndirect)); - RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType())); - RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult WriteConvertOperands(Operation *op, BytecodeWriter *writer) { - auto *src = op->getOperand(0); - RETURN_IF_FAILURE( - writer->WriteTypeIndex(getElementTypeOrSelf(src->getType()))); - RETURN_IF_FAILURE(writer->WriteLocal(src)); - auto *dst = op->getOperand(1); - RETURN_IF_FAILURE( - writer->WriteTypeIndex(getElementTypeOrSelf(dst->getType()))); - RETURN_IF_FAILURE(writer->WriteLocal(dst)); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::ConvertSSOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSS)); - return WriteConvertOperands(op, writer); -} - -LogicalResult writeOp(IREEInterp::LL::ConvertUUOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUU)); - return WriteConvertOperands(op, writer); -} - -LogicalResult writeOp(IREEInterp::LL::ConvertSUOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertSU)); - return WriteConvertOperands(op, writer); -} - -LogicalResult writeOp(IREEInterp::LL::ConvertUSOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kConvertUS)); - return WriteConvertOperands(op, writer); -} - -LogicalResult writeOp(IREEInterp::LL::BranchOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kBranch)); - RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest())); - RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands())); - for (int i = 0; i < op.getNumOperands(); ++i) { - // Copy src->dst. - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i))); - } - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::CondBranchOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCondBranch)); - RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition())); - RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest())); - RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands())); - for (int i = 0; i < op.getNumTrueOperands(); ++i) { - // Copy src->dst. - RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i))); - } - RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest())); - RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands())); - for (int i = 0; i < op.getNumFalseOperands(); ++i) { - // Copy src->dst. - RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i))); - } - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::CmpIOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpI)); - RETURN_IF_FAILURE( - writer->WriteUint8(static_cast<uint8_t>(op.predicate().getZExtValue()))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(0))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(1))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(2))); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::CmpFOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kCmpF)); - RETURN_IF_FAILURE( - writer->WriteUint8(static_cast<uint8_t>(op.predicate().getZExtValue()))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(0))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(1))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(2))); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::AllocHeapOp op, BytecodeWriter *writer) { - auto memrefType = op.getType().cast<MemRefType>(); - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kAllocHeap)); - RETURN_IF_FAILURE(writer->WriteInt32(0)); - RETURN_IF_FAILURE(writer->WriteTypeIndex(memrefType.getElementType())); - RETURN_IF_FAILURE(writer->WriteShapePieces(memrefType)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands())); - RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::StaticCopyOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kStaticCopy)); - RETURN_IF_FAILURE(writer->WriteLocal(op.src())); - RETURN_IF_FAILURE(writer->WriteShapePieces(op.srcIndices())); - RETURN_IF_FAILURE(writer->WriteLocal(op.dst())); - RETURN_IF_FAILURE(writer->WriteShapePieces(op.dstIndices())); - RETURN_IF_FAILURE(writer->WriteShapePieces(op.lengths())); - return success(); -} - -LogicalResult writeReduceOperands(Operation *op, BytecodeWriter *writer, - APInt dimension) { - RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(0))); - RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(1))); - RETURN_IF_FAILURE(writer->WriteInt32(dimension.getZExtValue())); - RETURN_IF_FAILURE(writer->WriteLocal(op->getOperand(2))); - return success(); -} - -LogicalResult writeOp(IREEInterp::LL::ReduceSumIOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumI)); - return writeReduceOperands(op, writer, op.dimension()); -} - -LogicalResult writeOp(IREEInterp::LL::ReduceSumFOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceSumF)); - return writeReduceOperands(op, writer, op.dimension()); -} - -LogicalResult writeOp(IREEInterp::LL::ReduceMinIOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinI)); - return writeReduceOperands(op, writer, op.dimension()); -} - -LogicalResult writeOp(IREEInterp::LL::ReduceMinFOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMinF)); - return writeReduceOperands(op, writer, op.dimension()); -} - -LogicalResult writeOp(IREEInterp::LL::ReduceMaxIOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxI)); - return writeReduceOperands(op, writer, op.dimension()); -} - -LogicalResult writeOp(IREEInterp::LL::ReduceMaxFOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::InterpreterOpcode::kReduceMaxF)); - return writeReduceOperands(op, writer, op.dimension()); -} - -} // namespace - -void registerInterpreterCustomWriters(VMFunctionBuilder *builder) { -#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \ - builder->RegisterCustomWriter( \ - op_type::getOperationName(), \ - +[](Operation *op, BytecodeWriter *writer) { \ - return writeOp(cast<op_type>(op), writer); \ - }); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConstantOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallImportOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CallIndirectOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::BranchOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CondBranchOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSSOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUUOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertSUOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ConvertUSOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpIOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::CmpFOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::AllocHeapOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::StaticCopyOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumIOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceSumFOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinIOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMinFOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxIOp); - REGISTER_CUSTOM_WRITER_IMPL(IREEInterp::LL::ReduceMaxFOp); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Interpreter/OpWriters.h b/iree/compiler/IR/Interpreter/OpWriters.h deleted file mode 100644 index 302fc88..0000000 --- a/iree/compiler/IR/Interpreter/OpWriters.h +++ /dev/null
@@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_ -#define IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_ - -#include "iree/compiler/Serialization/VMFunctionBuilder.h" - -namespace mlir { -namespace iree_compiler { - -// Registers custom op writers with the builder. -// Ops not registered will use the generic writer. -void registerInterpreterCustomWriters(VMFunctionBuilder *builder); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_INTERPRETER_OPWRITERS_H_
diff --git a/iree/compiler/IR/Interpreter/test/BUILD b/iree/compiler/IR/Interpreter/test/BUILD deleted file mode 100644 index 4df56ac..0000000 --- a/iree/compiler/IR/Interpreter/test/BUILD +++ /dev/null
@@ -1,15 +0,0 @@ -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - "//iree/tools:iree-run-mlir", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/IR/Ops.cpp b/iree/compiler/IR/Ops.cpp deleted file mode 100644 index ada40ab..0000000 --- a/iree/compiler/IR/Ops.cpp +++ /dev/null
@@ -1,639 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Ops.h" - -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/SMLoc.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Support/STLExtras.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -//===----------------------------------------------------------------------===// -// iree.constant -//===----------------------------------------------------------------------===// - -static ParseResult parseConstantOp(OpAsmParser &parser, - OperationState &result) { - Attribute valueAttr; - Type type; - if (parser.parseLSquare() || - parser.parseAttribute(valueAttr, "value", result.attributes) || - parser.parseRSquare() || - parser.parseOptionalAttributeDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - return parser.addTypeToList(type, result.types); -} - -static void printConstantOp(OpAsmPrinter &p, ConstantOp &op) { - p << "iree.constant["; - p.printAttribute(op.getValue()); - p << "] "; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); - - p << " : "; - p.printType(op.getType()); -} - -namespace { - -// TODO(gcmn) this is duplicated from MemRefUtils to avoid a circular -// dependency. Extract op-dependent parts of memref utils to allow reuse. -MemRefType convertTypeToMemRef(Type type) { - if (type.isIntOrIndexOrFloat()) { - return MemRefType::get({}, type, {}, 0); - } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { - return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - } else if (auto memRefType = type.dyn_cast<MemRefType>()) { - return MemRefType::get(memRefType.getShape(), memRefType.getElementType()); - } else { - llvm_unreachable("Unconvertable type"); - } -} - -} // namespace - -void ConstantOp::build(Builder *builder, OperationState &state, - ElementsAttr value) { - auto type = convertTypeToMemRef(value.getType()); - return build(builder, state, type, value); -} - -// TODO(b/134575149): enable folder when we store the correct type. -// OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { -// assert(operands.empty() && "constant has no operands"); -// return getValue(); -// } - -//===----------------------------------------------------------------------===// -// iree.tensor_to_memref -//===----------------------------------------------------------------------===// - -static ParseResult parseTensorToMemRefOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operand; - Type operandType; - Type resultType; - if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || - failed(parser.parseColonType(operandType)) || - failed(parser.resolveOperand(operand, operandType, state.operands)) || - failed(parser.parseRParen()) || - failed(parser.parseColonType(resultType)) || - failed(parser.addTypeToList(resultType, state.types))) { - return failure(); - } - return success(); -} - -static void printTensorToMemRefOp(OpAsmPrinter &p, TensorToMemRefOp &op) { - p << "iree.tensor_to_memref("; - p.printOperand(op.getOperand()); - p << " : "; - p.printType(op.getOperand()->getType()); - p << ") : "; - p.printType(op.getType()); -} - -OpFoldResult TensorToMemRefOp::fold(ArrayRef<Attribute> operands) { - if (auto memrefToTensorOp = dyn_cast_or_null<IREE::MemRefToTensorOp>( - getOperand()->getDefiningOp())) { - return memrefToTensorOp.getOperand(); - } - - return {}; -} - -void TensorToMemRefOp::build(Builder *builder, OperationState &state, - Value *arg) { - build(builder, state, convertTypeToMemRef(arg->getType()), arg); -} - -//===----------------------------------------------------------------------===// -// iree.memref_to_tensor -//===----------------------------------------------------------------------===// - -static ParseResult parseMemRefToTensorOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operand; - Type operandType; - Type resultType; - if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || - failed(parser.parseColonType(operandType)) || - failed(parser.resolveOperand(operand, operandType, state.operands)) || - failed(parser.parseRParen()) || - failed(parser.parseColonType(resultType)) || - failed(parser.addTypeToList(resultType, state.types))) { - return failure(); - } - return success(); -} - -static void printMemRefToTensorOp(OpAsmPrinter &p, MemRefToTensorOp &op) { - p << "iree.memref_to_tensor("; - p.printOperand(op.getOperand()); - p << " : "; - p.printType(op.getOperand()->getType()); - p << ") : "; - p.printType(op.getType()); -} - -OpFoldResult MemRefToTensorOp::fold(ArrayRef<Attribute> operands) { - if (auto tensorToMemRefOp = dyn_cast_or_null<IREE::TensorToMemRefOp>( - getOperand()->getDefiningOp())) { - return tensorToMemRefOp.getOperand(); - } - - return {}; -} - -void MemRefToTensorOp::build(Builder *builder, OperationState &state, - Value *arg) { - // TODO(gcmn) Use getTensorType from MemRefUtils when circular dependency can - // be avoided. - auto memRefType = arg->getType().cast<MemRefType>(); - auto tensorType = - RankedTensorType::get(memRefType.getShape(), memRefType.getElementType()); - build(builder, state, tensorType, arg); -} - -//===----------------------------------------------------------------------===// -// iree.scalar_to_memref -//===----------------------------------------------------------------------===// - -static ParseResult parseScalarToMemRefOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operand; - Type operandType; - Type resultType; - if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || - failed(parser.parseColonType(operandType)) || - failed(parser.resolveOperand(operand, operandType, state.operands)) || - failed(parser.parseRParen()) || - failed(parser.parseColonType(resultType)) || - failed(parser.addTypeToList(resultType, state.types))) { - return failure(); - } - return success(); -} - -static void printScalarToMemRefOp(OpAsmPrinter &p, ScalarToMemRefOp &op) { - p << "iree.scalar_to_memref("; - p.printOperand(op.getOperand()); - p << " : "; - p.printType(op.getOperand()->getType()); - p << ") : "; - p.printType(op.getType()); -} - -OpFoldResult ScalarToMemRefOp::fold(ArrayRef<Attribute> operands) { - if (auto memrefToScalarOp = dyn_cast_or_null<IREE::MemRefToScalarOp>( - getOperand()->getDefiningOp())) { - return memrefToScalarOp.getOperand(); - } - - return {}; -} - -void ScalarToMemRefOp::build(Builder *builder, OperationState &state, - Value *arg) { - build(builder, state, convertTypeToMemRef(arg->getType()), arg); -} - -//===----------------------------------------------------------------------===// -// iree.memref_to_scalar -//===----------------------------------------------------------------------===// - -static ParseResult parseMemRefToScalarOp(OpAsmParser &parser, - OperationState &state) { - OpAsmParser::OperandType operand; - Type operandType; - Type resultType; - if (failed(parser.parseLParen()) || failed(parser.parseOperand(operand)) || - failed(parser.parseColonType(operandType)) || - failed(parser.resolveOperand(operand, operandType, state.operands)) || - failed(parser.parseRParen()) || - failed(parser.parseColonType(resultType)) || - failed(parser.addTypeToList(resultType, state.types))) { - return failure(); - } - return success(); -} - -static void printMemRefToScalarOp(OpAsmPrinter &p, MemRefToScalarOp &op) { - p << "iree.memref_to_scalar("; - p.printOperand(op.getOperand()); - p << " : "; - p.printType(op.getOperand()->getType()); - p << ") : "; - p.printType(op.getType()); -} - -OpFoldResult MemRefToScalarOp::fold(ArrayRef<Attribute> operands) { - if (auto scalarToMemRefOp = dyn_cast_or_null<IREE::ScalarToMemRefOp>( - getOperand()->getDefiningOp())) { - return scalarToMemRefOp.getOperand(); - } - - return {}; -} - -void MemRefToScalarOp::build(Builder *builder, OperationState &state, - Value *arg) { - build(builder, state, getElementTypeOrSelf(arg), arg); -} - -//===----------------------------------------------------------------------===// -// iree.dispatch_region -//===----------------------------------------------------------------------===// - -void DispatchRegionOp::build(Builder *builder, OperationState &state, - ArrayRef<Type> resultTypes, Value *workload, - ArrayRef<Value *> operands, - ArrayRef<NamedAttribute> attributes) { - state.addTypes(resultTypes); - state.addOperands({workload}); - state.addOperands(operands); - state.addAttributes(attributes); - state.addRegion(); - state.setOperandListToResizable(); -} - -ParseResult parseDispatchRegionOp(OpAsmParser &parser, OperationState &state) { - // Parse required workload. - OpAsmParser::OperandType workloadArg; - Type workloadArgType; - if (failed(parser.parseLSquare()) || - failed(parser.parseOperand(workloadArg)) || - failed(parser.parseColonType(workloadArgType)) || - failed(parser.parseRSquare()) || - failed(parser.resolveOperand(workloadArg, workloadArgType, - state.operands))) { - return failure(); - } - - // Parse (optional) args. - SmallVector<OpAsmParser::OperandType, 16> regionArgs; - SmallVector<Type, 16> regionArgTypes; - if (failed(parser.parseLParen())) { - return failure(); - } - if (failed(parser.parseOptionalRParen())) { - SmallVector<OpAsmParser::OperandType, 16> regionOperands; - auto argsLoc = parser.getCurrentLocation(); - do { - // Reserve entries in the lists. - regionArgs.emplace_back(); - regionOperands.emplace_back(); - regionArgTypes.emplace_back(); - if (failed(parser.parseRegionArgument(regionArgs.back())) || - failed(parser.parseEqual()) || - failed(parser.parseOperand(regionOperands.back())) || - failed(parser.parseColonType(regionArgTypes.back()))) { - return failure(); - } - } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen()) || - failed(parser.resolveOperands(regionOperands, regionArgTypes, argsLoc, - state.operands))) { - return failure(); - } - } - state.setOperandListToResizable(); - - // Parse (optional) results. - if (failed(parser.parseOptionalColonTypeList(state.types))) { - return failure(); - } - - // Parse region body. - Region *body = state.addRegion(); - if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || - failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - return success(); -} - -void printDispatchRegionOp(OpAsmPrinter &p, DispatchRegionOp op) { - p << "iree.dispatch_region"; - - // Print the workload argument. - p << "["; - p.printOperand(op.getWorkload()); - p << " : "; - p.printType(op.getWorkload()->getType()); - p << "]"; - - // Print the data argument remapping. - p << "("; - interleaveComma( - llvm::zip(op.getBody().front().getArguments(), op.getArgOperands()), p, - [&](std::tuple<BlockArgument *, Value *> it) { - p << *std::get<0>(it) << " = " << *std::get<1>(it); - p << " : "; - p << std::get<1>(it)->getType(); - }); - p << ")"; - - // Print the result types, if any. - if (op.getNumResults() > 0) { - p << " : "; - interleaveComma(op.getResultTypes(), p); - } - - p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{}); -} - -//===----------------------------------------------------------------------===// -// iree.reduction_region -//===----------------------------------------------------------------------===// - -void ReductionRegionOp::build(Builder *builder, OperationState &state, - ArrayRef<Type> resultTypes, Value *workload, - ArrayRef<Value *> operands, - ArrayRef<Value *> initialValues, - ArrayRef<int64_t> dimensions, - ArrayRef<NamedAttribute> attributes) { - state.addTypes(resultTypes); - state.addOperands({workload}); - state.addOperands(operands); - state.addOperands(initialValues); - state.addAttribute( - "dimensions", - DenseIntElementsAttr::get( - RankedTensorType::get({static_cast<int64_t>(dimensions.size())}, - builder->getIntegerType(64)), - dimensions)); - state.addAttributes(attributes); - state.addRegion(); - state.setOperandListToResizable(); -} - -void ReductionRegionOp::build( - Builder *builder, OperationState &state, ArrayRef<Type> resultTypes, - Value *workload, ArrayRef<Value *> operands, - ArrayRef<Value *> initialValues, ArrayRef<int64_t> windowDimensions, - ArrayRef<int64_t> windowStrides, ArrayRef<int64_t> baseDilations, - ArrayRef<int64_t> windowDilations, PaddingMode paddingMode, - ArrayRef<NamedAttribute> attributes) { - state.addTypes(resultTypes); - state.addOperands({workload}); - state.addOperands(operands); - state.addOperands(initialValues); - state.addAttribute( - "window_dimensions", - DenseIntElementsAttr::get( - RankedTensorType::get({static_cast<int64_t>(windowDimensions.size())}, - builder->getIntegerType(64)), - windowDimensions)); - state.addAttribute( - "window_strides", - DenseIntElementsAttr::get( - RankedTensorType::get({static_cast<int64_t>(windowStrides.size())}, - builder->getIntegerType(64)), - windowStrides)); - state.addAttribute( - "base_dilations", - DenseIntElementsAttr::get( - RankedTensorType::get({static_cast<int64_t>(baseDilations.size())}, - builder->getIntegerType(64)), - baseDilations)); - state.addAttribute( - "window_dilations", - DenseIntElementsAttr::get( - RankedTensorType::get({static_cast<int64_t>(windowDilations.size())}, - builder->getIntegerType(64)), - windowDilations)); - state.addAttribute("padding_mode", builder->getI32IntegerAttr( - static_cast<int32_t>(paddingMode))); - state.addAttributes(attributes); - state.addRegion(); - state.setOperandListToResizable(); -} - -ParseResult parseReductionRegionOp(OpAsmParser &parser, OperationState &state) { - OpAsmParser::OperandType workloadArg; - Type workloadArgType; - if (failed(parser.parseLSquare()) || - failed(parser.parseOperand(workloadArg)) || - failed(parser.parseColonType(workloadArgType)) || - failed(parser.parseRSquare()) || - failed(parser.resolveOperand(workloadArg, workloadArgType, - state.operands))) { - return failure(); - } - - SmallVector<OpAsmParser::OperandType, 8> reductionOperands; - Type reductionType; - auto operandsLoc = parser.getCurrentLocation(); - if (failed(parser.parseLParen()) || - failed(parser.parseOperandList(reductionOperands)) || - failed(parser.parseRParen()) || - failed(parser.parseColonType(reductionType)) || - failed(parser.resolveOperands( - reductionOperands, reductionType.cast<FunctionType>().getInputs(), - operandsLoc, state.operands))) { - return failure(); - } - for (auto type : reductionType.cast<FunctionType>().getResults()) { - state.types.push_back(type); - } - state.setOperandListToResizable(); - - SmallVector<OpAsmParser::OperandType, 8> regionArgs; - SmallVector<Type, 8> regionArgTypes; - if (failed(parser.parseKeyword("invocation")) || - failed(parser.parseLParen())) { - return failure(); - } - do { - Type argType; - SmallVector<OpAsmParser::OperandType, 2> reductionRegionArgs; - OpAsmParser::OperandType initialValue; - if (failed(parser.parseLParen()) || - failed(parser.parseOperandList(reductionRegionArgs, 2)) || - failed(parser.parseRParen()) || failed(parser.parseEqual()) || - failed(parser.parseOperand(initialValue)) || - failed(parser.parseColonType(argType)) || - failed(parser.resolveOperand(initialValue, argType, state.operands))) { - return failure(); - } - regionArgs.push_back(reductionRegionArgs[0]); - regionArgTypes.push_back(argType); - regionArgs.push_back(reductionRegionArgs[1]); - regionArgTypes.push_back(argType); - } while (succeeded(parser.parseOptionalComma())); - if (failed(parser.parseRParen())) { - return failure(); - } - - // Parse region body. - Region *body = state.addRegion(); - if (failed(parser.parseRegion(*body, regionArgs, regionArgTypes)) || - failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - - return success(); -} - -void printReductionRegionOp(OpAsmPrinter &p, ReductionRegionOp op) { - p << "iree.reduction_region"; - - // Print the workload argument. - p << "["; - p.printOperand(op.getWorkload()); - p << " : "; - p.printType(op.getWorkload()->getType()); - p << "]"; - - p << "("; - p.printOperands(op.getODSOperands(1)); - p << ")"; - if (op.getNumResults() > 0) { - p << " : ("; - interleaveComma(op.getODSOperands(1), p, - [&](Value *operand) { p.printType(operand->getType()); }); - p << ")"; - p << " -> ("; - interleaveComma(op.getResultTypes(), p); - p << ")"; - } - p << "\n"; - - p << " invocation("; - auto &entryBlock = op.getBody().getBlocks().front(); - int regionArgIndex = 0; - interleaveComma(op.getODSOperands(2), p, [&](Value *operand) { - p << "("; - p.printOperand(entryBlock.getArgument(regionArgIndex++)); - p << ", "; - p.printOperand(entryBlock.getArgument(regionArgIndex++)); - p << ") = "; - p.printOperand(operand); - p << " : "; - p.printType(operand->getType()); - }); - p << ") "; - - p.printRegion(op.getBody(), /*printEntryBlockArgs=*/false); - p.printOptionalAttrDict(op.getAttrs(), - /*elidedAttrs=*/{}); -} - -//===----------------------------------------------------------------------===// -// iree.return -//===----------------------------------------------------------------------===// - -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type, 2> types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, state.operands)); -} - -static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p << "iree.return"; - if (op.getNumOperands() > 0) { - p << ' '; - p.printOperands(op.operand_begin(), op.operand_end()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } -} - -//===----------------------------------------------------------------------===// -// iree.load_input -//===----------------------------------------------------------------------===// - -ParseResult parseLoadInputOp(OpAsmParser &parser, OperationState &state) { - OpAsmParser::OperandType operand; - Type argType; - if (parser.parseLParen() || parser.parseOperand(operand) || - parser.parseColonType(argType) || parser.parseRParen() || - parser.resolveOperand(operand, argType, state.operands)) { - return failure(); - } - Type outputType; - if (parser.parseColonType(outputType) || - parser.addTypeToList(outputType, state.types)) { - return failure(); - } - return success(); -} - -void printLoadInputOp(OpAsmPrinter &printer, Operation *op) { - auto *inputValue = op->getOperand(0); - auto *outputValue = op->getResult(0); - printer << op->getName() << '('; - printer.printOperand(inputValue); - printer << " : "; - printer.printType(inputValue->getType()); - printer << ") : "; - printer.printType(outputValue->getType()); -} - -//===----------------------------------------------------------------------===// -// iree.store_output -//===----------------------------------------------------------------------===// - -ParseResult parseStoreOutputOp(OpAsmParser &parser, OperationState &state) { - OpAsmParser::OperandType op0, op1; - Type argType0, argType1; - if (parser.parseLParen() || parser.parseOperand(op0) || - parser.parseColonType(argType0) || parser.parseComma() || - parser.resolveOperand(op0, argType0, state.operands) || - parser.parseOperand(op1) || parser.parseColonType(argType1) || - parser.parseRParen() || - parser.resolveOperand(op1, argType1, state.operands)) { - return failure(); - } - return success(); -} - -void printStoreOutputOp(OpAsmPrinter &printer, Operation *op) { - auto *inputValue = op->getOperand(0); - auto *outputValue = op->getOperand(1); - printer << op->getName() << '('; - printer.printOperand(inputValue); - printer << " : "; - printer.printType(inputValue->getType()); - printer << ", "; - printer.printOperand(outputValue); - printer << " : "; - printer.printType(outputValue->getType()); - printer << ")"; -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Ops.cpp.inc" - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Ops.h b/iree/compiler/IR/Ops.h deleted file mode 100644 index 34f0bc3..0000000 --- a/iree/compiler/IR/Ops.h +++ /dev/null
@@ -1,36 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_OPS_H_ -#define IREE_COMPILER_IR_OPS_H_ - -#include "iree/compiler/IR/Types.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Ops.h.inc" - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_OPS_H_
diff --git a/iree/compiler/IR/Ops.td b/iree/compiler/IR/Ops.td deleted file mode 100644 index cb54101..0000000 --- a/iree/compiler/IR/Ops.td +++ /dev/null
@@ -1,200 +0,0 @@ -// IREE ops for working with buffers and buffer views. -// These are used by common transforms between the sequencer and interpreter and -// allow us to share some of the common lowering passes from other dialects. - -#ifdef IREE_OPS -#else -#define IREE_OPS - -#ifdef IREE_OP_BASE -#else -include "iree/compiler/IR/OpBase.td" -#endif // IREE_OP_BASE - -class IREE_Op<string mnemonic, list<OpTrait> traits = []> : - Op<IREE_Dialect, mnemonic, traits> { - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ print$cppClass(p, *this); }]; -} - -class IREE_PureOp<string mnemonic, list<OpTrait> traits = []> : - IREE_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; - -// TODO(b/134575149): determine if we want multiple constant op types. -def IREE_ConstantOp : IREE_PureOp<"constant", [ - AllShapesMatch<["value", "result"]>, - AllElementTypesMatch<["value", "result"]> -]> { - let arguments = (ins ElementsAttr:$value); - let results = (outs IREEHL_MemRef:$result); - - // TODO(b/132296600): make tablegen follow the style guide. - let extraClassDeclaration = [{ - Attribute getValue() { return value(); } - }]; - - let builders = [OpBuilder<"Builder*, OperationState&, ElementsAttr">]; - - // TODO(b/134575149): enable folder when we store the correct type. - // let hasFolder = 1; -} - -// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor. -def IREE_TensorToMemRefOp : IREE_PureOp<"tensor_to_memref", [ - SameOperandsAndResultShape, SameOperandsAndResultElementType -]> { - let arguments = (ins AnyTensor); - let results = (outs IREEHL_MemRef); - - let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; - - let hasFolder = 1; -} - -// TODO(b/134671482): remove/move tensor_to_memref/memref_to_tensor. -def IREE_MemRefToTensorOp : IREE_PureOp<"memref_to_tensor", [ - SameOperandsAndResultShape, SameOperandsAndResultElementType -]> { - let arguments = (ins IREEHL_MemRef); - let results = (outs AnyTensor); - let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; - - let hasFolder = 1; -} - -def IREE_ScalarToMemRefOp : IREE_PureOp<"scalar_to_memref", [ - SameOperandsAndResultElementType -]> { - let arguments = (ins IREEHL_Element); - let results = (outs IREEHL_AnyScalar); - - let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; - - let hasFolder = 1; -} - -def IREE_MemRefToScalarOp : IREE_PureOp<"memref_to_scalar", [ - SameOperandsAndResultElementType -]> { - let arguments = (ins IREEHL_AnyScalar); - let results = (outs IREEHL_Element); - - let builders = [OpBuilder<"Builder*, OperationState&, Value*">]; - - let hasFolder = 1; -} - -def IREE_Workload : TensorOf<[AnyInteger]>; - -def IREE_DispatchRegionOp : IREE_PureOp<"dispatch_region"> { - let arguments = (ins - IREE_Workload:$workload, - Variadic<AnyType>:$args - ); - let results = (outs Variadic<AnyType>); - let regions = (region AnyRegion:$body); - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - Value *getWorkload() { return workload(); } - Region& getBody() { return body(); } - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - unsigned mapArgOperandToOpOperand(unsigned i) { return i + 1; } - unsigned getNumArgOperands() { return getNumOperands() - 1; } - Value *getArgOperand(unsigned i) { - return getOperand(mapArgOperandToOpOperand(i)); - } - void setArgOperand(unsigned i, Value *arg) { - setOperand(mapArgOperandToOpOperand(i), arg); - } - - operand_iterator arg_operand_begin() { - return operand_begin() + mapArgOperandToOpOperand(0); - } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState &state," - "ArrayRef<Type> resultTypes, Value *workload," - "ArrayRef<Value *> args," - "ArrayRef<NamedAttribute> attributes = {}">, - ]; -} - -def IREE_ReductionRegionOp : IREE_PureOp<"reduction_region", [ - SameVariadicOperandSize, -]> { - let arguments = (ins - IREE_Workload:$workload, - Variadic<AnyType>:$operands, - Variadic<AnyType>:$initial_values, - OptionalAttr<I64ElementsAttr>:$dimensions, - OptionalAttr<I64ElementsAttr>:$window_dimensions, - OptionalAttr<I64ElementsAttr>:$window_strides, - OptionalAttr<I64ElementsAttr>:$base_dilations, - OptionalAttr<I64ElementsAttr>:$window_dilations, - OptionalAttr<IREE_PaddingModeAttr>:$padding_mode - ); - let results = (outs Variadic<AnyType>); - let regions = (region AnyRegion:$body); - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - Value *getWorkload() { return workload(); } - Region& getBody() { return body(); } - - bool isWindowed() { - return window_dimensions().hasValue(); - } - - PaddingMode getPaddingMode() { - return static_cast<PaddingMode>(padding_mode().getValue()); - } - - unsigned getNumReductionOperands() { return (getNumOperands() - 1) / 2; } - operand_range getReductionOperands() { return getODSOperands(1); } - operand_range getInitialValueOperands() { return getODSOperands(2); } - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState &state," - "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands," - "ArrayRef<Value *> initialValues," - "ArrayRef<int64_t> dimensions," - "ArrayRef<NamedAttribute> attributes = {}">, - OpBuilder<"Builder *builder, OperationState &state," - "ArrayRef<Type> resultTypes, Value *workload, ArrayRef<Value *> operands," - "ArrayRef<Value *> initialValues," - "ArrayRef<int64_t> windowDimensions, ArrayRef<int64_t> windowStrides," - "ArrayRef<int64_t> baseDilations, ArrayRef<int64_t> windowDilations," - "PaddingMode paddingMode," - "ArrayRef<NamedAttribute> attributes = {}">, - ]; -} - -def IREE_ReturnOp : IREE_Op<"return", [Terminator]> { - let arguments = (ins Variadic<AnyType>:$operands); - - let builders = [OpBuilder< - "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] - >]; -} - -def IREE_LoadInputOp : IREE_PureOp<"load_input"> { - let arguments = (ins IREEHL_MemRef:$src); - let results = (outs AnyType); -} - -def IREE_StoreOutputOp : IREE_Op<"store_output"> { - let arguments = (ins AnyType:$src, IREEHL_MemRef:$dst); -} - -#endif // IREE_OPS
diff --git a/iree/compiler/IR/Sequencer/BUILD b/iree/compiler/IR/Sequencer/BUILD deleted file mode 100644 index ba722cb..0000000 --- a/iree/compiler/IR/Sequencer/BUILD +++ /dev/null
@@ -1,76 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -load("@local_config_mlir//:tblgen.bzl", "gentbl") - -filegroup( - name = "td_files", - srcs = glob(["*.td"]), -) - -cc_library( - name = "Sequencer", - srcs = [ - "HLDialect.cpp", - "HLOps.cpp", - "HLOps.cpp.inc", - "LLDialect.cpp", - "LLOps.cpp", - "LLOps.cpp.inc", - "OpWriters.cpp", - ], - hdrs = [ - "HLDialect.h", - "HLOps.h", - "HLOps.h.inc", - "LLDialect.h", - "LLOps.h", - "LLOps.h.inc", - "OpWriters.h", - ], - deps = [ - ":HLOpsGen", - ":LLOpsGen", - "//iree/compiler/IR", - "//iree/compiler/Serialization", - "//iree/compiler/Utils", - "//iree/schemas/bytecode:sequencer_bytecode_v0", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - ], - alwayslink = 1, -) - -gentbl( - name = "HLOpsGen", - tbl_outs = [ - ("-gen-op-decls", "HLOps.h.inc"), - ("-gen-op-defs", "HLOps.cpp.inc"), - ], - tblgen = "@local_config_mlir//:mlir-tblgen", - td_file = "HLOps.td", - td_srcs = [ - ":td_files", - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//iree/compiler/IR:OpBase.td", - ], -) - -gentbl( - name = "LLOpsGen", - tbl_outs = [ - ("-gen-op-decls", "LLOps.h.inc"), - ("-gen-op-defs", "LLOps.cpp.inc"), - ], - tblgen = "@local_config_mlir//:mlir-tblgen", - td_file = "LLOps.td", - td_srcs = [ - ":td_files", - "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//iree/compiler/IR:OpBase.td", - ], -)
diff --git a/iree/compiler/IR/Sequencer/HLDialect.cpp b/iree/compiler/IR/Sequencer/HLDialect.cpp deleted file mode 100644 index 7b96908..0000000 --- a/iree/compiler/IR/Sequencer/HLDialect.cpp +++ /dev/null
@@ -1,34 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/HLDialect.h" - -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "llvm/Support/SourceMgr.h" - -namespace mlir { -namespace iree_compiler { - -IREEHLSequencerDialect::IREEHLSequencerDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/Sequencer/HLOps.cpp.inc" - >(); -} - -static DialectRegistration<IREEHLSequencerDialect> iree_hl_seq_dialect; - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/HLOps.cpp b/iree/compiler/IR/Sequencer/HLOps.cpp deleted file mode 100644 index 63a5d6f..0000000 --- a/iree/compiler/IR/Sequencer/HLOps.cpp +++ /dev/null
@@ -1,379 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/HLOps.h" - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { -namespace IREESeq { -namespace HL { - -namespace { - -static LogicalResult verifyWorkload(Operation *op, Value *workload) { - if (auto workloadType = workload->getType().dyn_cast<MemRefType>()) { - if (workloadType.getNumElements() != 3) { - return op->emitOpError("workload must be specified as (x,y,z) but has ") - << workloadType.getNumElements() - << " elements (type=" << workload->getType() << ")"; - } - return success(); - } - return op->emitOpError( - "workload must be specified as an (x,y,z) memref but has type ") - << workload->getType(); -} - -} // namespace - -//===----------------------------------------------------------------------===// -// iree_hl_seq.call -//===----------------------------------------------------------------------===// - -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector<OpAsmParser::OperandType, 4> operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(state.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), state.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - state.operands)) { - return failure(); - } - return success(); -} - -static void printCallOp(OpAsmPrinter &p, CallOp op) { - p << "iree_hl_seq.call " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); -} - -FunctionType CallOp::getCalleeType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.call_indirect -//===----------------------------------------------------------------------===// - -static ParseResult parseCallIndirectOp(OpAsmParser &parser, - OperationState &result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector<OpAsmParser::OperandType, 4> operands; - return failure( - parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.resolveOperand(callee, calleeType, result.operands) || - parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result.operands) || - parser.addTypesToList(calleeType.getResults(), result.types)); -} - -static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { - p << "iree_hl_seq.call_indirect "; - p.printOperand(op.getCallee()); - p << '('; - auto operandRange = op.getOperands(); - p.printOperands(++operandRange.begin(), operandRange.end()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCallee()->getType(); -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.return -//===----------------------------------------------------------------------===// - -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type, 2> types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, state.operands)); -} - -static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p << "iree_hl_seq.return"; - if (op.getNumOperands() > 0) { - p << ' '; - p.printOperands(op.operand_begin(), op.operand_end()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.br -//===----------------------------------------------------------------------===// - -static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { - Block *dest; - SmallVector<Value *, 4> destOperands; - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - return success(); -} - -static void printBranchOp(OpAsmPrinter &p, BranchOp op) { - p << "iree_hl_seq.br "; - p.printSuccessorAndUseList(op.getOperation(), 0); -} - -Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } - -void BranchOp::setDest(Block *block) { - return getOperation()->setSuccessor(block, 0); -} - -void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.cond_br -//===----------------------------------------------------------------------===// - -static ParseResult parseCondBranchOp(OpAsmParser &parser, - OperationState &result) { - SmallVector<Value *, 4> destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser.getBuilder().getI1Type(); - if (parser.parseOperand(condInfo) || parser.parseComma() || - parser.resolveOperand(condInfo, int1Ty, result.operands)) { - return parser.emitError(parser.getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - - return success(); -} - -static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { - p << "iree_hl_seq.cond_br "; - p.printOperand(op.getCondition()); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.dispatch -//===----------------------------------------------------------------------===// - -static ParseResult parseDispatchOp(OpAsmParser &parser, OperationState &state) { - auto executableLoc = parser.getNameLoc(); - - SymbolRefAttr executableAttr; - SymbolRefAttr entryPointAttr; - FunctionType entryPointType; - if (failed(parser.parseAttribute(executableAttr, "executable", - state.attributes)) || - failed(parser.parseColon()) || failed(parser.parseColon()) || - failed(parser.parseAttribute(entryPointAttr, "entry_point", - state.attributes))) { - return failure(); - } - - OpAsmParser::OperandType workloadArg; - Type workloadArgType; - if (failed(parser.parseLSquare()) || - failed(parser.parseOperand(workloadArg)) || - failed(parser.parseColonType(workloadArgType)) || - failed(parser.parseRSquare()) || - failed(parser.resolveOperand(workloadArg, workloadArgType, - state.operands))) { - return failure(); - } - - SmallVector<OpAsmParser::OperandType, 4> operands; - if (failed( - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || - failed(parser.parseOptionalAttributeDict(state.attributes)) || - failed(parser.parseColonType(entryPointType)) || - failed(parser.addTypesToList(entryPointType.getResults(), state.types)) || - failed(parser.resolveOperands(operands, entryPointType.getInputs(), - executableLoc, state.operands))) { - return failure(); - } - return success(); -} - -static void printDispatchOp(OpAsmPrinter &p, DispatchOp op) { - p << "iree_hl_seq.dispatch " << op.getExecutable() - << "::" << op.getEntryPoint(); - p << "["; - p.printOperand(op.getWorkload()); - p << " : "; - p.printType(op.getWorkload()->getType()); - p << "]("; - p.printOperands(op.getArgOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ - "executable", - "entry_point", - }); - p << " : "; - p.printType(op.getEntryPointType()); -} - -static LogicalResult verifyDispatchOp(DispatchOp op) { - if (failed(verifyWorkload(op, op.getWorkload()))) { - return failure(); - } - return success(); -} - -FunctionType DispatchOp::getEntryPointType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getArgOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.rank -//===----------------------------------------------------------------------===// - -OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) { - Builder builder(getContext()); - if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) { - return builder.getIntegerAttr(builder.getIntegerType(32), - op0.getType().getRank()); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.shape -//===----------------------------------------------------------------------===// - -void ShapeOp::build(Builder *builder, OperationState &state, Value *operand) { - state.addOperands(operand); - int64_t rank = 0; - if (auto shapedType = operand->getType().dyn_cast<ShapedType>()) { - rank = shapedType.getRank(); - } - state.addTypes(MemRefType::get({rank}, builder->getIntegerType(32))); -} - -OpFoldResult ShapeOp::fold(ArrayRef<Attribute> operands) { - Builder builder(getContext()); - if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) { - return DenseIntElementsAttr::get( - RankedTensorType::get({op0.getType().getRank()}, - builder.getIntegerType(32)), - op0.getType().getShape()); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.length -//===----------------------------------------------------------------------===// - -OpFoldResult LengthOp::fold(ArrayRef<Attribute> operands) { - Builder builder(getContext()); - if (auto op0 = operands[0].dyn_cast_or_null<ElementsAttr>()) { - return builder.getIntegerAttr(builder.getIntegerType(32), - op0.getNumElements()); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// iree_hl_seq.concat -//===----------------------------------------------------------------------===// - -namespace { -struct ConcatToCopies : public OpRewritePattern<ConcatOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ConcatOp concatOp, - PatternRewriter &rewriter) const override { - auto finalType = concatOp.getResult()->getType().cast<ShapedType>(); - auto loc = concatOp.getLoc(); - std::vector<Value *> dimPieces; - auto dst = - rewriter.create<IREESeq::HL::AllocHeapOp>(loc, finalType, dimPieces); - - llvm::SmallVector<int64_t, 4> zeroOffset(finalType.getRank(), 0); - auto srcIndices = createArrayConstant(rewriter, loc, zeroOffset); - - auto concatDimension = concatOp.dimension().getZExtValue(); - llvm::SmallVector<int64_t, 4> dstIndices(finalType.getRank(), 0); - for (auto *src : concatOp.srcs()) { - auto srcShape = src->getType().cast<ShapedType>().getShape(); - auto lengths = createArrayConstant(rewriter, loc, srcShape); - auto dstIndicesOp = createArrayConstant(rewriter, loc, dstIndices); - rewriter.create<IREESeq::HL::CopyOp>(loc, src, srcIndices, dst, - dstIndicesOp, lengths); - dstIndices[concatDimension] += srcShape[concatDimension]; - } - - concatOp.replaceAllUsesWith(dst.getResult()); - - return matchSuccess(); - } -}; -} // namespace - -void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert<ConcatToCopies>(context); -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Sequencer/HLOps.cpp.inc" - -} // namespace HL -} // namespace IREESeq -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/HLOps.h b/iree/compiler/IR/Sequencer/HLOps.h deleted file mode 100644 index 9fb6a4d..0000000 --- a/iree/compiler/IR/Sequencer/HLOps.h +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_SEQUENCER_HLOPS_H_ -#define IREE_COMPILER_IR_SEQUENCER_HLOPS_H_ - -#include "iree/compiler/IR/Types.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" - -namespace mlir { -namespace iree_compiler { -namespace IREESeq { -namespace HL { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Sequencer/HLOps.h.inc" - -} // namespace HL -} // namespace IREESeq -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_SEQUENCER_HLOPS_H_
diff --git a/iree/compiler/IR/Sequencer/HLOps.td b/iree/compiler/IR/Sequencer/HLOps.td deleted file mode 100644 index 1a32f76..0000000 --- a/iree/compiler/IR/Sequencer/HLOps.td +++ /dev/null
@@ -1,429 +0,0 @@ -// IREE high-level sequencer op definitions. -// This op set contains pseudo ops, ops that accept non-MemRef types, and ops in -// normal SSA form. -// -// Through lowering these high-level ops are converted to low-level ops in the -// LLOps.td (iree_ll_seq.*). These map 1:1 with the bytecode, accept -// only MemRef types, and generally use output parameters instead of return -// types. -// -// The source of truth for bytecode opcodes is: -// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h - -#ifdef IREE_SEQUENCER_HL_OPS -#else -#define IREE_SEQUENCER_HL_OPS - -#ifdef IREE_OP_BASE -#else -include "iree/compiler/IR/OpBase.td" -#endif // IREE_OP_BASE - -def IREESeqHL_Dialect : Dialect { - let name = "iree_hl_seq"; - let cppNamespace = "IREESeq::HL"; -} - -//===----------------------------------------------------------------------===// -// Base op classes -//===----------------------------------------------------------------------===// - -class IREESeqHL_Op<string mnemonic, list<OpTrait> traits = []> : - Op<IREESeqHL_Dialect, mnemonic, traits>; - -class IREESeqHL_PureOp<string mnemonic, list<OpTrait> traits = []> : - IREESeqHL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; - -//===----------------------------------------------------------------------===// -// High-level sequencer ops -//===----------------------------------------------------------------------===// - -def IREESeqHL_CallOp : IREESeqHL_PureOp<"call"> { - let arguments = (ins SymbolRefAttr:$callee, Variadic<IREEHL_MemRef>); - let results = (outs Variadic<IREEHL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, FuncOp callee," - "ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState &result, StringRef callee," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqHL_CallIndirectOp : IREESeqHL_Op<"call_indirect"> { - let arguments = (ins FunctionType:$callee, Variadic<IREEHL_MemRef>:$operands); - let results = (outs Variadic<IREEHL_MemRef>); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *callee," - "ArrayRef<Value *> operands = {}", [{ - result.operands.push_back(callee); - result.addOperands(operands); - result.addTypes(callee->getType().cast<FunctionType>().getResults()); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - Value *getCallee() { return getOperand(0); } - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqHL_ReturnOp : IREESeqHL_Op<"return", [Terminator]> { - let arguments = (ins Variadic<IREEHL_MemRef>:$operands); - - let builders = [OpBuilder< - "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] - >]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqHL_BranchOp : IREESeqHL_Op<"br", [Terminator]> { - let arguments = (ins Variadic<IREEHL_MemRef>:$operands); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "Builder *, OperationState &result, Block *dest," - "ArrayRef<Value *> operands = {}", [{ - result.addSuccessor(dest, operands); - }]>]; - - let extraClassDeclaration = [{ - Block *getDest(); - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqHL_CondBranchOp : IREESeqHL_Op<"cond_br", [Terminator]> { - let arguments = (ins - IREEHL_BoolScalar:$condition, - Variadic<IREEHL_MemRef>:$branchOperands - ); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *condition, " - "Block *trueDest, ArrayRef<Value *> trueOperands, " - "Block *falseDest, ArrayRef<Value *> falseOperands", [{ - result.addOperands(condition); - result.addSuccessor(trueDest, trueOperands); - result.addSuccessor(falseDest, falseOperands); - }]>]; - - let extraClassDeclaration = [{ - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest() { - return getOperation()->getSuccessor(trueIndex); - } - - /// Return the destination if the condition is false. - Block *getFalseDest() { - return getOperation()->getSuccessor(falseIndex); - } - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands() { - return getOperation()->getNumSuccessorOperands(trueIndex); - } - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); - } - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands() { - return getOperation()->getNumSuccessorOperands(falseIndex); - } - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); - } - - private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqHL_DispatchOp : IREESeqHL_Op<"dispatch"> { - let arguments = (ins - SymbolRefAttr:$executable, - SymbolRefAttr:$entry_point, - IREEHL_IntMemRef:$workload, - Variadic<IREEHL_MemRef>:$operands - ); - let results = (outs Variadic<IREEHL_MemRef>); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "Builder *builder, OperationState &result, StringRef executable," - "StringRef entry_point, Value *workload," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands({workload}); - result.addOperands(operands); - result.addAttribute("executable", builder->getSymbolRefAttr(executable)); - result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getExecutable() { return executable(); } - StringRef getEntryPoint() { return entry_point(); } - FunctionType getEntryPointType(); - - Value *getWorkload() { return getOperand(0); } - - // TODO(b/133879130): make tablegen support variadic operand accessors. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - operand_iterator arg_operand_begin() { return operand_begin() + 1; } - operand_iterator arg_operand_end() { return operand_end(); } - - operand_type_range getArgOperandTypes() { - return {arg_operand_type_begin(), arg_operand_type_end()}; - } - operand_type_iterator arg_operand_type_begin() { - return operand_type_iterator(arg_operand_begin()); - } - operand_type_iterator arg_operand_type_end() { - return operand_type_iterator(arg_operand_end()); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; - let verifier = [{ return verify$cppClass(*this); }]; -} - -// TODO(b/142012496): Add trait that enables DCE but not CSE. -def IREESeqHL_AllocHeapOp : IREESeqHL_Op<"alloc_heap"> { - // TODO(benvanik): attributes and args. - let arguments = (ins Variadic<IREEHL_IntMemRef>:$dim_pieces); - let results = (outs IREEHL_MemRef); -} - -def IREESeqHL_DiscardOp : IREESeqHL_Op<"discard"> { - let arguments = (ins IREEHL_MemRef); -} - -def IREESeqHL_RankOp : IREESeqHL_PureOp<"rank"> { - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_IntScalar); - - let hasFolder = 1; -} - -def IREESeqHL_DimOp : IREESeqHL_PureOp<"dim"> { - // TODO(benvanik) add dim attr (I32Attr:$dim) - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_IntScalar); -} - -def IREESeqHL_ShapeOp : IREESeqHL_PureOp<"shape"> { - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_1DIntMemRef); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "Builder *builder, OperationState &result, Value *operand">]; - - let hasFolder = 1; -} - -def IREESeqHL_LengthOp : IREESeqHL_PureOp<"length"> { - let arguments = (ins IREEHL_MemRef); - let results = (outs IREEHL_IndexScalar); - - let hasFolder = 1; -} - -def IREESeqHL_SliceOp : - IREESeqHL_PureOp<"slice", [AllElementTypesMatch<["src", "result"]>, - AllTypesMatch<["indices", "lengths"]>]> { - let arguments = (ins - IREEHL_MemRef:$src, - IREEHL_1DIndexMemRef:$indices, - IREEHL_1DIndexMemRef:$lengths - ); - let results = (outs IREEHL_MemRef:$result); -} - -def IREESeqHL_CopyOp : IREESeqHL_Op<"copy", [ - AllElementCountsMatch<["srcIndices", "dstIndices", "lengths"]>, - AllRanksMatch<["src", "dst"]>, - // The checks above are redundant with this one, but they give more specific - // error messages. - AllMatch<[ - Rank<"src">.result, - Rank<"dst">.result, - ElementCount<"srcIndices">.result, - ElementCount<"dstIndices">.result, - ElementCount<"lengths">.result - ], "src/dst rank is the same as srcIndices/dstIndices/lengths size">, - AllElementTypesMatch<["src", "dst"]> -]> { - let arguments = (ins - IREEHL_MemRef:$src, - IREEHL_1DIndexMemRef:$srcIndices, - IREEHL_MemRef:$dst, - IREEHL_1DIndexMemRef:$dstIndices, - IREEHL_1DIndexMemRef:$lengths - ); -} - -def IREESeqHL_FillOp : IREESeqHL_Op<"fill"> { - let arguments = (ins - IREEHL_I32Scalar:$value, - IREEHL_MemRef:$dst, - IREEHL_1DIndexMemRef:$dstIndices, - IREEHL_1DIndexMemRef:$lengths - ); -} - -def IREESeqHL_CloneOp : IREESeqHL_PureOp<"clone", [SameOperandsAndResultType]> { - let arguments = (ins IREEHL_MemRef:$src); - let results = (outs IREEHL_MemRef); -} - -// A pseudo op provided for convenience. This gets canonicalized to a series of -// copies. -def IREESeqHL_ConcatOp : IREESeqHL_PureOp<"concat"> { - // TODO(b/135032064) Add type constraints when they support variadic - let arguments = (ins - Variadic<IREEHL_MemRef>:$srcs, - I32Attr:$dimension - ); - let results = (outs IREEHL_MemRef); - - let hasCanonicalizer = 1; -} - -def IREESeqHL_AssignOp : - IREESeqHL_PureOp<"assign", [SameOperandsAndResultType]> { - let arguments = (ins IREEHL_MemRef:$src); - let results = (outs IREEHL_MemRef); -} - -def IREESeqHL_CondAssignOp : IREESeqHL_PureOp<"cond_assign"> { - let arguments = (ins - IREEHL_BoolScalar:$cond, - IREEHL_MemRef:$lhs, - IREEHL_MemRef:$rhs - ); - let results = (outs IREEHL_MemRef); -} - -def IREESeqHL_ReshapeOp : IREESeqHL_PureOp<"reshape"> { - let arguments = (ins IREEHL_MemRef:$src, IREEHL_MemRef:$shape); - let results = (outs IREEHL_MemRef); -} - -def IREESeqHL_TraceOp : IREESeqHL_Op<"trace"> { - let arguments = (ins Variadic<IREEHL_MemRef>:$srcs); -} - -def IREESeqHL_CondBreakOp : IREESeqHL_Op<"cond_break"> { - let arguments = (ins IREEHL_BoolScalar:$cond); -} - -def IREESeqHL_BreakOp : IREESeqHL_Op<"break">; - -#endif // IREE_SEQUENCER_HL_OPS
diff --git a/iree/compiler/IR/Sequencer/LLDialect.cpp b/iree/compiler/IR/Sequencer/LLDialect.cpp deleted file mode 100644 index 872e27f..0000000 --- a/iree/compiler/IR/Sequencer/LLDialect.cpp +++ /dev/null
@@ -1,34 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/LLDialect.h" - -#include "iree/compiler/IR/Sequencer/LLOps.h" -#include "llvm/Support/SourceMgr.h" - -namespace mlir { -namespace iree_compiler { - -IREELLSequencerDialect::IREELLSequencerDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { -#define GET_OP_LIST - addOperations< -#include "iree/compiler/IR/Sequencer/LLOps.cpp.inc" - >(); -} - -static DialectRegistration<IREELLSequencerDialect> iree_ll_seq_dialect; - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/LLOps.cpp b/iree/compiler/IR/Sequencer/LLOps.cpp deleted file mode 100644 index c5b132b..0000000 --- a/iree/compiler/IR/Sequencer/LLOps.cpp +++ /dev/null
@@ -1,671 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/LLOps.h" - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Support/STLExtras.h" - -namespace mlir { -namespace iree_compiler { -namespace IREESeq { -namespace LL { - -namespace { - -static LogicalResult verifyWorkload(Operation *op, Value *workload) { - if (auto workloadType = workload->getType().dyn_cast<MemRefType>()) { - if (workloadType.getNumElements() != 3) { - return op->emitOpError("workload must be specified as (x,y,z) but has ") - << workloadType.getNumElements() - << " elements (type=" << workload->getType() << ")"; - } - return success(); - } - return op->emitOpError( - "workload must be specified as an (x,y,z) memref but has type ") - << workload->getType(); -} - -static LogicalResult verifyWorkload(Operation *op, ElementsAttr workload) { - if (workload.getNumElements() != 3) { - return op->emitOpError("workload must be specified as (x,y,z) but has ") - << workload.getNumElements() << " elements (value=" << workload - << ")"; - } - return success(); -} - -} // namespace - -//===----------------------------------------------------------------------===// -// iree_ll_seq.constant -//===----------------------------------------------------------------------===// - -OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { - return getValue(); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.call -//===----------------------------------------------------------------------===// - -static ParseResult parseCallOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector<OpAsmParser::OperandType, 4> operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(state.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), state.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - state.operands)) { - return failure(); - } - return success(); -} - -static void printCallOp(OpAsmPrinter &p, CallOp op) { - p << "iree_ll_seq.call " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); -} - -FunctionType CallOp::getCalleeType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.call_import -//===----------------------------------------------------------------------===// - -static ParseResult parseCallImportOp(OpAsmParser &parser, - OperationState &state) { - SymbolRefAttr calleeAttr; - FunctionType calleeType; - SmallVector<OpAsmParser::OperandType, 4> operands; - auto calleeLoc = parser.getNameLoc(); - if (parser.parseAttribute(calleeAttr, "callee", state.attributes) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(state.attributes) || - parser.parseColonType(calleeType) || - parser.addTypesToList(calleeType.getResults(), state.types) || - parser.resolveOperands(operands, calleeType.getInputs(), calleeLoc, - state.operands)) { - return failure(); - } - return success(); -} - -static void printCallImportOp(OpAsmPrinter &p, CallImportOp op) { - p << "iree_ll_seq.call_import " << op.getAttr("callee") << '('; - p.printOperands(op.getOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : "; - p.printType(op.getCalleeType()); -} - -FunctionType CallImportOp::getCalleeType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.call_indirect -//===----------------------------------------------------------------------===// - -static ParseResult parseCallIndirectOp(OpAsmParser &parser, - OperationState &result) { - FunctionType calleeType; - OpAsmParser::OperandType callee; - llvm::SMLoc operandsLoc; - SmallVector<OpAsmParser::OperandType, 4> operands; - return failure( - parser.parseOperand(callee) || parser.getCurrentLocation(&operandsLoc) || - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) || - parser.parseOptionalAttributeDict(result.attributes) || - parser.parseColonType(calleeType) || - parser.resolveOperand(callee, calleeType, result.operands) || - parser.resolveOperands(operands, calleeType.getInputs(), operandsLoc, - result.operands) || - parser.addTypesToList(calleeType.getResults(), result.types)); -} - -static void printCallIndirectOp(OpAsmPrinter &p, CallIndirectOp op) { - p << "iree_ll_seq.call_indirect "; - p.printOperand(op.getCallee()); - p << '('; - auto operandRange = op.getOperands(); - p.printOperands(++operandRange.begin(), operandRange.end()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"callee"}); - p << " : " << op.getCallee()->getType(); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.return -//===----------------------------------------------------------------------===// - -static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &state) { - SmallVector<OpAsmParser::OperandType, 2> opInfo; - SmallVector<Type, 2> types; - llvm::SMLoc loc = parser.getCurrentLocation(); - return failure(parser.parseOperandList(opInfo) || - (!opInfo.empty() && parser.parseColonTypeList(types)) || - parser.resolveOperands(opInfo, types, loc, state.operands)); -} - -static void printReturnOp(OpAsmPrinter &p, ReturnOp op) { - p << "iree_ll_seq.return"; - if (op.getNumOperands() > 0) { - p << ' '; - p.printOperands(op.operand_begin(), op.operand_end()); - p << " : "; - interleaveComma(op.getOperandTypes(), p); - } -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.br -//===----------------------------------------------------------------------===// - -static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) { - Block *dest; - SmallVector<Value *, 4> destOperands; - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - return success(); -} - -static void printBranchOp(OpAsmPrinter &p, BranchOp op) { - p << "iree_ll_seq.br "; - p.printSuccessorAndUseList(op.getOperation(), 0); -} - -Block *BranchOp::getDest() { return getOperation()->getSuccessor(0); } - -void BranchOp::setDest(Block *block) { - return getOperation()->setSuccessor(block, 0); -} - -void BranchOp::eraseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(0, index); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.cond_br -//===----------------------------------------------------------------------===// - -static ParseResult parseCondBranchOp(OpAsmParser &parser, - OperationState &result) { - SmallVector<Value *, 4> destOperands; - Block *dest; - OpAsmParser::OperandType condInfo; - - // Parse the condition. - Type int1Ty = parser.getBuilder().getI1Type(); - if (parser.parseOperand(condInfo) || parser.parseComma() || - parser.resolveOperand(condInfo, int1Ty, result.operands)) { - return parser.emitError(parser.getNameLoc(), - "expected condition type was boolean (i1)"); - } - - // Parse the true successor. - if (parser.parseSuccessorAndUseList(dest, destOperands)) return failure(); - result.addSuccessor(dest, destOperands); - - // Parse the false successor. - destOperands.clear(); - if (parser.parseComma() || - parser.parseSuccessorAndUseList(dest, destOperands)) - return failure(); - result.addSuccessor(dest, destOperands); - - return success(); -} - -static void printCondBranchOp(OpAsmPrinter &p, CondBranchOp op) { - p << "iree_ll_interp.cond_br "; - p.printOperand(op.getCondition()); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::trueIndex); - p << ", "; - p.printSuccessorAndUseList(op.getOperation(), CondBranchOp::falseIndex); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.dynamic_dispatch -//===----------------------------------------------------------------------===// - -static ParseResult parseDynamicDispatchOp(OpAsmParser &parser, - OperationState &state) { - auto executableLoc = parser.getNameLoc(); - - SymbolRefAttr executableAttr; - SymbolRefAttr entryPointAttr; - FunctionType entryPointType; - if (failed(parser.parseAttribute(executableAttr, "executable", - state.attributes)) || - failed(parser.parseColon()) || failed(parser.parseColon()) || - failed(parser.parseAttribute(entryPointAttr, "entry_point", - state.attributes))) { - return failure(); - } - - OpAsmParser::OperandType workloadArg; - Type workloadArgType; - if (failed(parser.parseLSquare()) || - failed(parser.parseOperand(workloadArg)) || - failed(parser.parseColonType(workloadArgType)) || - failed(parser.parseRSquare()) || - failed(parser.resolveOperand(workloadArg, workloadArgType, - state.operands))) { - return failure(); - } - - SmallVector<OpAsmParser::OperandType, 4> operands; - if (failed( - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || - failed(parser.parseOptionalAttributeDict(state.attributes)) || - failed(parser.parseColonType(entryPointType)) || - failed(parser.addTypesToList(entryPointType.getResults(), state.types)) || - failed(parser.resolveOperands(operands, entryPointType.getInputs(), - executableLoc, state.operands))) { - return failure(); - } - return success(); -} - -static void printDynamicDispatchOp(OpAsmPrinter &p, DynamicDispatchOp op) { - p << "iree_ll_seq.dynamic_dispatch " << op.getExecutable() - << "::" << op.getEntryPoint(); - p << "["; - p.printOperand(op.getWorkload()); - p << " : "; - p.printType(op.getWorkload()->getType()); - p << "]("; - p.printOperands(op.getArgOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ - "executable", - "entry_point", - }); - p << " : "; - p.printType(op.getEntryPointType()); -} - -static LogicalResult verifyDynamicDispatchOp(DynamicDispatchOp op) { - if (failed(verifyWorkload(op, op.getWorkload()))) { - return failure(); - } - return success(); -} - -FunctionType DynamicDispatchOp::getEntryPointType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getArgOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -namespace { -struct MakeDynamicDispatchOpStatic - : public OpRewritePattern<DynamicDispatchOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(DynamicDispatchOp dynamicDispatchOp, - PatternRewriter &rewriter) const override { - ElementsAttr workloadAttr; - if (!matchPattern(dynamicDispatchOp.getWorkload(), - m_Constant(&workloadAttr))) { - return matchFailure(); - } - - SmallVector<Type, 8> resultTypes{dynamicDispatchOp.getResultTypes()}; - SmallVector<Value *, 8> operands{dynamicDispatchOp.getArgOperands()}; - rewriter.replaceOpWithNewOp<IREESeq::LL::StaticDispatchOp>( - dynamicDispatchOp, dynamicDispatchOp.getExecutable(), - dynamicDispatchOp.getEntryPoint(), workloadAttr, resultTypes, operands); - return matchSuccess(); - } -}; -} // namespace - -void DynamicDispatchOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<MakeDynamicDispatchOpStatic>(context); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.static_dispatch -//===----------------------------------------------------------------------===// - -static ParseResult parseStaticDispatchOp(OpAsmParser &parser, - OperationState &state) { - auto executableLoc = parser.getNameLoc(); - - SymbolRefAttr executableAttr; - SymbolRefAttr entryPointAttr; - FunctionType entryPointType; - if (failed(parser.parseAttribute(executableAttr, "executable", - state.attributes)) || - failed(parser.parseColon()) || failed(parser.parseColon()) || - failed(parser.parseAttribute(entryPointAttr, "entry_point", - state.attributes))) { - return failure(); - } - - ElementsAttr workloadAttr; - if (failed(parser.parseLSquare()) || - failed( - parser.parseAttribute(workloadAttr, "workload", state.attributes)) || - failed(parser.parseRSquare())) { - return failure(); - } - - SmallVector<OpAsmParser::OperandType, 4> operands; - if (failed( - parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren)) || - failed(parser.parseOptionalAttributeDict(state.attributes)) || - failed(parser.parseColonType(entryPointType)) || - failed(parser.addTypesToList(entryPointType.getResults(), state.types)) || - failed(parser.resolveOperands(operands, entryPointType.getInputs(), - executableLoc, state.operands))) { - return failure(); - } - return success(); -} - -static void printStaticDispatchOp(OpAsmPrinter &p, StaticDispatchOp op) { - p << "iree_ll_seq.static_dispatch " << op.getExecutable() - << "::" << op.getEntryPoint(); - p << "["; - p.printAttribute(op.getWorkload()); - p << "]("; - p.printOperands(op.getArgOperands()); - p << ')'; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{ - "executable", - "entry_point", - "workload", - }); - p << " : "; - p.printType(op.getEntryPointType()); -} - -static LogicalResult verifyStaticDispatchOp(StaticDispatchOp op) { - if (failed(verifyWorkload(op, op.getWorkload()))) { - return failure(); - } - return success(); -} - -FunctionType StaticDispatchOp::getEntryPointType() { - SmallVector<Type, 4> resultTypes(getResultTypes()); - SmallVector<Type, 8> argTypes(getArgOperandTypes()); - return FunctionType::get(argTypes, resultTypes, getContext()); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.shape -//===----------------------------------------------------------------------===// - -namespace { -struct FoldShapeOp : public OpRewritePattern<ShapeOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ShapeOp shapeOp, - PatternRewriter &rewriter) const override { - auto memRefType = shapeOp.input()->getType().cast<MemRefType>(); - if (memRefType.hasStaticShape()) { - auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>( - shapeOp.getLoc(), - MemRefType::get({memRefType.getRank()}, rewriter.getIntegerType(64)), - DenseIntElementsAttr::get( - RankedTensorType::get({memRefType.getRank()}, - rewriter.getIntegerType(64)), - memRefType.getShape())); - replaceSubsequentUses(shapeOp, shapeOp.dst(), constantOp.getResult()); - rewriter.eraseOp(shapeOp); - return matchSuccess(); - } - return matchFailure(); - } -}; -} // namespace - -void ShapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert<FoldShapeOp>(context); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.length -//===----------------------------------------------------------------------===// - -namespace { -struct FoldLengthOp : public OpRewritePattern<LengthOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LengthOp lengthOp, - PatternRewriter &rewriter) const override { - auto memRefType = lengthOp.input()->getType().cast<MemRefType>(); - if (memRefType.hasStaticShape()) { - auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>( - lengthOp.getLoc(), MemRefType::get({}, rewriter.getIntegerType(64)), - DenseIntElementsAttr::get( - RankedTensorType::get({}, rewriter.getIntegerType(64)), - {memRefType.getNumElements()})); - replaceSubsequentUses(lengthOp, lengthOp.dst(), constantOp.getResult()); - rewriter.eraseOp(lengthOp); - return matchSuccess(); - } - return matchFailure(); - } -}; -} // namespace - -void LengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert<FoldLengthOp>(context); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.compute_offset -//===----------------------------------------------------------------------===// - -namespace { -struct FoldComputeOffsetOp : public OpRewritePattern<ComputeOffsetOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ComputeOffsetOp computeOffsetOp, - PatternRewriter &rewriter) const override { - ElementsAttr shapeAttr; - ElementsAttr indicesAttr; - if (!matchPattern(computeOffsetOp.shape(), m_Constant(&shapeAttr)) || - !matchPattern(computeOffsetOp.indices(), m_Constant(&indicesAttr))) { - return matchFailure(); - } - - int64_t offset = 0; - for (unsigned i = 0; i < indicesAttr.getNumElements(); ++i) { - int64_t axisOffset = - indicesAttr.getValue({i}).cast<IntegerAttr>().getInt(); - for (unsigned j = i + 1; j < shapeAttr.getNumElements(); ++j) { - axisOffset *= shapeAttr.getValue({j}).cast<IntegerAttr>().getInt(); - } - offset += axisOffset; - } - offset *= computeOffsetOp.elementSize().getZExtValue(); - - auto constantOp = rewriter.create<IREESeq::LL::ConstantOp>( - computeOffsetOp.getLoc(), - MemRefType::get({}, rewriter.getIntegerType(64)), - DenseIntElementsAttr::get( - RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset})); - replaceSubsequentUses(computeOffsetOp, computeOffsetOp.dst(), - constantOp.getResult()); - rewriter.eraseOp(computeOffsetOp); - return matchSuccess(); - } -}; -} // namespace - -void ComputeOffsetOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<FoldComputeOffsetOp>(context); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.compute_range -//===----------------------------------------------------------------------===// - -namespace { -struct FoldComputeRangeOp : public OpRewritePattern<ComputeRangeOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ComputeRangeOp computeRangeOp, - PatternRewriter &rewriter) const override { - ElementsAttr shapeAttr; - ElementsAttr indicesAttr; - ElementsAttr lengthsAttr; - if (!matchPattern(computeRangeOp.shape(), m_Constant(&shapeAttr)) || - !matchPattern(computeRangeOp.indices(), m_Constant(&indicesAttr)) || - !matchPattern(computeRangeOp.lengths(), m_Constant(&lengthsAttr))) { - return matchFailure(); - } - - int64_t offset = 0; - int64_t length = computeRangeOp.elementSize().getZExtValue(); - for (unsigned i = 0; i < indicesAttr.getNumElements(); ++i) { - int64_t axisOffset = - indicesAttr.getValue({i}).cast<IntegerAttr>().getInt(); - for (unsigned j = i + 1; j < shapeAttr.getNumElements(); ++j) { - axisOffset *= shapeAttr.getValue({j}).cast<IntegerAttr>().getInt(); - } - offset += axisOffset; - length *= lengthsAttr.getValue({i}).cast<IntegerAttr>().getInt(); - } - offset *= computeRangeOp.elementSize().getZExtValue(); - - auto offsetConstantOp = rewriter.create<IREESeq::LL::ConstantOp>( - computeRangeOp.getLoc(), - MemRefType::get({}, rewriter.getIntegerType(64)), - DenseIntElementsAttr::get( - RankedTensorType::get({}, rewriter.getIntegerType(64)), {offset})); - replaceSubsequentUses(computeRangeOp, computeRangeOp.dstOffset(), - offsetConstantOp.getResult()); - auto lengthConstantOp = rewriter.create<IREESeq::LL::ConstantOp>( - computeRangeOp.getLoc(), - MemRefType::get({}, rewriter.getIntegerType(64)), - DenseIntElementsAttr::get( - RankedTensorType::get({}, rewriter.getIntegerType(64)), {length})); - replaceSubsequentUses(computeRangeOp, computeRangeOp.dstLength(), - lengthConstantOp.getResult()); - rewriter.eraseOp(computeRangeOp); - return matchSuccess(); - } -}; -} // namespace - -void ComputeRangeOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<FoldComputeRangeOp>(context); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.dynamic_copy -//===----------------------------------------------------------------------===// - -namespace { -struct MakeDynamicCopyOpStatic : public OpRewritePattern<DynamicCopyOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(DynamicCopyOp dynamicCopyOp, - PatternRewriter &rewriter) const override { - ElementsAttr srcOffsetAttr; - ElementsAttr dstOffsetAttr; - ElementsAttr lengthAttr; - if (!matchPattern(dynamicCopyOp.srcOffset(), m_Constant(&srcOffsetAttr)) || - !matchPattern(dynamicCopyOp.dstOffset(), m_Constant(&dstOffsetAttr)) || - !matchPattern(dynamicCopyOp.length(), m_Constant(&lengthAttr))) { - return matchFailure(); - } - - rewriter.replaceOpWithNewOp<IREESeq::LL::StaticCopyOp>( - dynamicCopyOp, dynamicCopyOp.src(), - srcOffsetAttr.getValue({}).cast<IntegerAttr>(), dynamicCopyOp.dst(), - dstOffsetAttr.getValue({}).cast<IntegerAttr>(), - lengthAttr.getValue({}).cast<IntegerAttr>()); - return matchSuccess(); - } -}; -} // namespace - -void DynamicCopyOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<MakeDynamicCopyOpStatic>(context); -} - -//===----------------------------------------------------------------------===// -// iree_ll_seq.dynamic_fill -//===----------------------------------------------------------------------===// - -namespace { -struct MakeDynamicFillOpStatic : public OpRewritePattern<DynamicFillOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(DynamicFillOp dynamicFillOp, - PatternRewriter &rewriter) const override { - ElementsAttr valueAttr; - ElementsAttr dstOffsetAttr; - ElementsAttr lengthAttr; - if (!matchPattern(dynamicFillOp.value(), m_Constant(&valueAttr)) || - !matchPattern(dynamicFillOp.dstOffset(), m_Constant(&dstOffsetAttr)) || - !matchPattern(dynamicFillOp.length(), m_Constant(&lengthAttr))) { - return matchFailure(); - } - - rewriter.replaceOpWithNewOp<IREESeq::LL::StaticFillOp>( - dynamicFillOp, valueAttr.getValue({}).cast<IntegerAttr>(), - dynamicFillOp.dst(), dstOffsetAttr.getValue({}).cast<IntegerAttr>(), - lengthAttr.getValue({}).cast<IntegerAttr>()); - return matchSuccess(); - } -}; -} // namespace - -void DynamicFillOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert<MakeDynamicFillOpStatic>(context); -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Sequencer/LLOps.cpp.inc" - -} // namespace LL -} // namespace IREESeq -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/LLOps.h b/iree/compiler/IR/Sequencer/LLOps.h deleted file mode 100644 index 9a29ba0..0000000 --- a/iree/compiler/IR/Sequencer/LLOps.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_SEQUENCER_LLOPS_H_ -#define IREE_COMPILER_IR_SEQUENCER_LLOPS_H_ - -#include "iree/compiler/IR/Types.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { -namespace IREESeq { -namespace LL { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/Sequencer/LLOps.h.inc" - -} // namespace LL -} // namespace IREESeq -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_SEQUENCER_LLOPS_H_
diff --git a/iree/compiler/IR/Sequencer/LLOps.td b/iree/compiler/IR/Sequencer/LLOps.td deleted file mode 100644 index 6f352fb..0000000 --- a/iree/compiler/IR/Sequencer/LLOps.td +++ /dev/null
@@ -1,579 +0,0 @@ -// IREE low-level sequencer op definitions. -// These map 1:1 with the bytecode, accept only MemRef types and generally use -// output parameters instead of return types. -// -// The source of truth for bytecode opcodes is: -// https://github.com/google/iree/tree/master/iree/schemas/bytecode/sequencer_bytecode_v0.h -// -// Note that in this dialect we cannot use folders: they require that all -// operands are possible to make constants where we use output arguments that -// will never be constant. Instead we can use canonicalization patterns to -// match constant input operands and do the folding by replacing output operands -// with the new values. - -#ifdef IREE_SEQUENCER_LL_OPS -#else -#define IREE_SEQUENCER_LL_OPS - -#ifdef IREE_OP_BASE -#else -include "iree/compiler/IR/OpBase.td" -#endif // IREE_OP_BASE - -def IREESeqLL_Dialect : Dialect { - let name = "iree_ll_seq"; - let cppNamespace = "IREESeq::LL"; -} - -//===----------------------------------------------------------------------===// -// Base op classes -//===----------------------------------------------------------------------===// - -class IREESeqLL_Op<string mnemonic, list<OpTrait> traits = []> : - Op<IREESeqLL_Dialect, mnemonic, traits> { - bit hasCustomSerializer = 0; -} - -class IREESeqLL_PureOp<string mnemonic, list<OpTrait> traits = []> : - IREESeqLL_Op<mnemonic, !listconcat(traits, [NoSideEffect])>; - -class IREESeqLL_UnaryOp<string mnemonic, Type type = IREELL_MemRef, - list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> { - let arguments = (ins type:$input, type:$dst); -} - -class IREESeqLL_BinaryOp<string mnemonic, Type type = IREELL_MemRef, - list<OpTrait> traits = []> : IREESeqLL_Op<mnemonic, traits> { - let arguments = (ins type:$lhs, type:$rhs, type:$dst); -} - -class IREESeqLL_TernaryOp<string mnemonic, Type type = IREELL_MemRef, - list<OpTrait> traits = []> - : IREESeqLL_Op<mnemonic, traits> { - let arguments = (ins type : $a, type : $b, type : $c, type : $dst); -} - -//===----------------------------------------------------------------------===// -// Low-level sequencer ops -//===----------------------------------------------------------------------===// - -def IREESeqLL_ConstantOp : IREESeqLL_PureOp<"constant"> { - let arguments = (ins ElementsAttr:$value); - let results = (outs IREELL_MemRef); - - // TODO(b/132296600): make tablegen follow the style guide. - let extraClassDeclaration = [{ - Attribute getValue() { return value(); } - }]; - - let hasFolder = 1; -} - -def IREESeqLL_CallOp : IREESeqLL_Op<"call"> { - let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, FuncOp callee," - "ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState &result, StringRef callee," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -// TODO(benvanik): add verifier that target isExternal. -def IREESeqLL_CallImportOp : IREESeqLL_Op<"call_import"> { - let arguments = (ins SymbolRefAttr:$callee, Variadic<IREELL_MemRef>); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, FuncOp callee," - "ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(callee.getType().getResults()); - }]>, OpBuilder< - "Builder *builder, OperationState &result, StringRef callee," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands(operands); - result.addAttribute("callee", builder->getSymbolRefAttr(callee)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getCallee() { return callee(); } - FunctionType getCalleeType(); - - // TODO(b/133879130): make tablegen support variadic operand accessors. - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqLL_CallIndirectOp : IREESeqLL_Op<"call_indirect"> { - let arguments = (ins FunctionType:$callee, Variadic<IREELL_MemRef>:$operands); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *callee," - "ArrayRef<Value *> operands = {}", [{ - result.operands.push_back(callee); - result.addOperands(operands); - result.addTypes(callee->getType().cast<FunctionType>().getResults()); - }]>]; - - let extraClassDeclaration = [{ - Value *getCallee() { return getOperand(0); } - - /// Get the argument operands to the called function. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - - operand_iterator arg_operand_begin() { return ++operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqLL_ReturnOp : IREESeqLL_Op<"return", [Terminator]> { - let arguments = (ins Variadic<IREELL_MemRef>:$operands); - - let builders = [OpBuilder< - "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] - >]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqLL_BranchOp : IREESeqLL_Op<"br", [Terminator]> { - let arguments = (ins Variadic<IREELL_MemRef>:$operands); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "Builder *, OperationState &result, Block *dest, " - "ArrayRef<Value *> operands = {}", [{ - result.addSuccessor(dest, operands); - }]>]; - - let extraClassDeclaration = [{ - Block *getDest(); - void setDest(Block *block); - - /// Erase the operand at 'index' from the operand list. - void eraseOperand(unsigned index); - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqLL_CondBranchOp : IREESeqLL_Op<"cond_br", [Terminator]> { - let arguments = (ins - IREELL_BoolScalar:$condition, - Variadic<IREELL_MemRef>:$branchOperands - ); - - let skipDefaultBuilders = 1; - let builders = [OpBuilder< - "Builder *, OperationState &result, Value *condition," - "Block *trueDest, ArrayRef<Value *> trueOperands," - "Block *falseDest, ArrayRef<Value *> falseOperands", [{ - result.addOperands(condition); - result.addSuccessor(trueDest, trueOperands); - result.addSuccessor(falseDest, falseOperands); - }]>]; - - let extraClassDeclaration = [{ - // These are the indices into the dests list. - enum { trueIndex = 0, falseIndex = 1 }; - - // The condition operand is the first operand in the list. - Value *getCondition() { return getOperand(0); } - - /// Return the destination if the condition is true. - Block *getTrueDest() { - return getOperation()->getSuccessor(trueIndex); - } - - /// Return the destination if the condition is false. - Block *getFalseDest() { - return getOperation()->getSuccessor(falseIndex); - } - - // Accessors for operands to the 'true' destination. - Value *getTrueOperand(unsigned idx) { - assert(idx < getNumTrueOperands()); - return getOperand(getTrueDestOperandIndex() + idx); - } - - void setTrueOperand(unsigned idx, Value *value) { - assert(idx < getNumTrueOperands()); - setOperand(getTrueDestOperandIndex() + idx, value); - } - - operand_iterator true_operand_begin() { - return operand_begin() + getTrueDestOperandIndex(); - } - operand_iterator true_operand_end() { - return true_operand_begin() + getNumTrueOperands(); - } - operand_range getTrueOperands() { - return {true_operand_begin(), true_operand_end()}; - } - - unsigned getNumTrueOperands() { - return getOperation()->getNumSuccessorOperands(trueIndex); - } - - /// Erase the operand at 'index' from the true operand list. - void eraseTrueOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(trueIndex, index); - } - - // Accessors for operands to the 'false' destination. - Value *getFalseOperand(unsigned idx) { - assert(idx < getNumFalseOperands()); - return getOperand(getFalseDestOperandIndex() + idx); - } - void setFalseOperand(unsigned idx, Value *value) { - assert(idx < getNumFalseOperands()); - setOperand(getFalseDestOperandIndex() + idx, value); - } - - operand_iterator false_operand_begin() { return true_operand_end(); } - operand_iterator false_operand_end() { - return false_operand_begin() + getNumFalseOperands(); - } - operand_range getFalseOperands() { - return {false_operand_begin(), false_operand_end()}; - } - - unsigned getNumFalseOperands() { - return getOperation()->getNumSuccessorOperands(falseIndex); - } - - /// Erase the operand at 'index' from the false operand list. - void eraseFalseOperand(unsigned index) { - getOperation()->eraseSuccessorOperand(falseIndex, index); - } - - private: - /// Get the index of the first true destination operand. - unsigned getTrueDestOperandIndex() { return 1; } - - /// Get the index of the first false destination operand. - unsigned getFalseDestOperandIndex() { - return getTrueDestOperandIndex() + getNumTrueOperands(); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; -} - -def IREESeqLL_DynamicDispatchOp : IREESeqLL_Op<"dynamic_dispatch"> { - let arguments = (ins - SymbolRefAttr:$executable, - SymbolRefAttr:$entry_point, - IREELL_IntMemRef:$workload, - Variadic<IREELL_MemRef>:$operands - ); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, StringRef executable," - "StringRef entry_point, Value *workload," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addOperands({workload}); - result.addOperands(operands); - result.addAttribute("executable", builder->getSymbolRefAttr(executable)); - result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getExecutable() { return executable(); } - StringRef getEntryPoint() { return entry_point(); } - FunctionType getEntryPointType(); - - Value *getWorkload() { return getOperand(0); } - - // TODO(b/133879130): make tablegen support variadic operand accessors. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - operand_iterator arg_operand_begin() { return operand_begin() + 1; } - operand_iterator arg_operand_end() { return operand_end(); } - - operand_type_range getArgOperandTypes() { - return {arg_operand_type_begin(), arg_operand_type_end()}; - } - operand_type_iterator arg_operand_type_begin() { - return operand_type_iterator(arg_operand_begin()); - } - operand_type_iterator arg_operand_type_end() { - return operand_type_iterator(arg_operand_end()); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; - let verifier = [{ return verify$cppClass(*this); }]; - let hasCanonicalizer = 1; -} - -def IREESeqLL_StaticDispatchOp : IREESeqLL_Op<"static_dispatch"> { - let arguments = (ins - SymbolRefAttr:$executable, - SymbolRefAttr:$entry_point, - I32ElementsAttr:$workload, - Variadic<IREELL_MemRef>:$operands - ); - let results = (outs Variadic<IREELL_MemRef>); - - let builders = [OpBuilder< - "Builder *builder, OperationState &result, StringRef executable," - "StringRef entry_point, ElementsAttr workload," - "ArrayRef<Type> results, ArrayRef<Value *> operands = {}", [{ - result.addAttribute("workload", workload); - result.addOperands(operands); - result.addAttribute("executable", builder->getSymbolRefAttr(executable)); - result.addAttribute("entry_point", builder->getSymbolRefAttr(entry_point)); - result.addTypes(results); - }]>]; - - let extraClassDeclaration = [{ - // TODO(b/132296600): make tablegen follow the style guide. - StringRef getExecutable() { return executable(); } - StringRef getEntryPoint() { return entry_point(); } - FunctionType getEntryPointType(); - - ElementsAttr getWorkload() { return workload(); } - - // TODO(b/133879130): make tablegen support variadic operand accessors. - operand_range getArgOperands() { - return {arg_operand_begin(), arg_operand_end()}; - } - operand_iterator arg_operand_begin() { return operand_begin(); } - operand_iterator arg_operand_end() { return operand_end(); } - - operand_type_range getArgOperandTypes() { - return {arg_operand_type_begin(), arg_operand_type_end()}; - } - operand_type_iterator arg_operand_type_begin() { - return operand_type_iterator(arg_operand_begin()); - } - operand_type_iterator arg_operand_type_end() { - return operand_type_iterator(arg_operand_end()); - } - }]; - - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ return print$cppClass(p, *this); }]; - let verifier = [{ return verify$cppClass(*this); }]; -} - -def IREESeqLL_AllocStaticOp : IREESeqLL_PureOp<"alloc_static"> { - // TODO(benvanik): attributes and args. - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_AllocStackOp : IREESeqLL_PureOp<"alloc_stack"> { - // TODO(benvanik): attributes and args. - let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces); - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_AllocStackInitOp : IREESeqLL_PureOp<"alloc_stack_init"> { - // TODO(benvanik): attributes and args. - let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces); - let results = (outs IREELL_MemRef); -} - -// TODO(b/142012496): Add trait that enables DCE but not CSE. -def IREESeqLL_AllocHeapOp : IREESeqLL_Op<"alloc_heap"> { - // TODO(benvanik): attributes and args. - let arguments = (ins Variadic<IREELL_IntMemRef>:$dim_pieces); - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_DiscardOp : IREESeqLL_Op<"discard"> { - let arguments = (ins IREELL_MemRef); -} - -def IREESeqLL_ShapeOp : IREESeqLL_Op<"shape"> { - let arguments = (ins IREELL_MemRef:$input, IREELL_I32MemRef:$dst); - - let hasCanonicalizer = 1; -} - -def IREESeqLL_LengthOp : IREESeqLL_Op<"length"> { - let arguments = (ins IREELL_MemRef:$input, IREELL_I32Scalar:$dst); - - let hasCanonicalizer = 1; -} - -def IREESeqLL_ComputeOffsetOp : IREESeqLL_Op<"compute_offset"> { - let arguments = (ins - IREELL_1DIntMemRef:$shape, - I8Attr:$elementSize, - IREELL_1DIntMemRef:$indices, - IREELL_I32Scalar:$dst - ); - - let hasCanonicalizer = 1; -} - -def IREESeqLL_ComputeRangeOp : IREESeqLL_Op<"compute_range"> { - let arguments = (ins - IREELL_1DIntMemRef:$shape, - I8Attr:$elementSize, - IREELL_1DIntMemRef:$indices, - IREELL_1DIntMemRef:$lengths, - IREELL_I32Scalar:$dstOffset, - IREELL_I32Scalar:$dstLength - ); - - let hasCanonicalizer = 1; -} - -def IREESeqLL_DynamicSliceOp : IREESeqLL_PureOp<"dynamic_slice", [ - AllElementTypesMatch<["src", "result"]> -]> { - let arguments = (ins - IREELL_MemRef:$src, - IREELL_IntScalar:$offset, - IREELL_IntScalar:$length - ); - let results = (outs IREELL_MemRef:$result); -} - -def IREESeqLL_StaticSliceOp : IREESeqLL_PureOp<"static_slice", [ - AllElementTypesMatch<["src", "result"]> -]> { - let arguments = (ins - IREELL_MemRef:$src, - I64Attr:$offset, - I64Attr:$length - ); - let results = (outs IREELL_MemRef:$result); -} - -def IREESeqLL_DynamicCopyOp : IREESeqLL_Op<"dynamic_copy"> { - let arguments = (ins - IREELL_MemRef:$src, - IREELL_IndexScalar:$srcOffset, - IREELL_MemRef:$dst, - IREELL_IndexScalar:$dstOffset, - IREELL_IndexScalar:$length - ); - - let hasCanonicalizer = 1; -} - -def IREESeqLL_StaticCopyOp : IREESeqLL_Op<"static_copy"> { - let arguments = (ins - IREELL_MemRef:$src, - I64Attr:$srcOffset, - IREELL_MemRef:$dst, - I64Attr:$dstOffset, - I64Attr:$length - ); -} - -def IREESeqLL_DynamicFillOp : IREESeqLL_Op<"dynamic_fill"> { - let arguments = (ins - IREELL_I32Scalar:$value, - IREELL_MemRef:$dst, - IREELL_IndexScalar:$dstOffset, - IREELL_IndexScalar:$length - ); - - let hasCanonicalizer = 1; -} - -def IREESeqLL_StaticFillOp : IREESeqLL_Op<"static_fill"> { - let arguments = (ins - I32Attr:$value, - IREELL_MemRef:$dst, - I64Attr:$dstOffset, - I64Attr:$length - ); -} - -def IREESeqLL_CloneOp : - IREESeqLL_PureOp<"clone", [SameOperandsAndResultType]> { - let arguments = (ins IREELL_MemRef:$src); - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_AssignOp : - IREESeqLL_Op<"assign", [SameOperandsAndResultType]> { - let arguments = (ins IREELL_MemRef:$src); - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_CondAssignOp : IREESeqLL_Op<"cond_assign"> { - let arguments = (ins - IREELL_BoolScalar:$cond, - IREELL_MemRef:$lhs, - IREELL_MemRef:$rhs - ); - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_ReshapeOp : IREESeqLL_Op<"reshape"> { - let arguments = (ins IREELL_MemRef:$input, IREELL_1DIntMemRef:$shape); - let results = (outs IREELL_MemRef); -} - -def IREESeqLL_TraceOp : IREESeqLL_Op<"trace"> { - let arguments = (ins Variadic<IREELL_MemRef>:$srcs); -} - -def IREESeqLL_CondBreakOp : IREESeqLL_Op<"cond_break"> { - let arguments = (ins IREELL_BoolScalar:$cond); -} - -def IREESeqLL_BreakOp : IREESeqLL_Op<"break">; - -#endif // IREE_SEQUENCER_LL_OPS
diff --git a/iree/compiler/IR/Sequencer/OpWriters.cpp b/iree/compiler/IR/Sequencer/OpWriters.cpp deleted file mode 100644 index 592c2d7..0000000 --- a/iree/compiler/IR/Sequencer/OpWriters.cpp +++ /dev/null
@@ -1,266 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/OpWriters.h" - -#include "iree/compiler/IR/Sequencer/LLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Serialization/BytecodeWriter.h" -#include "iree/compiler/Utils/Macros.h" -#include "iree/schemas/bytecode/sequencer_bytecode_v0.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -//===----------------------------------------------------------------------===// -// Sequencer ops -//===----------------------------------------------------------------------===// - -LogicalResult writeOp(IREESeq::LL::ConstantOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kConstant)); - auto memRefType = op.getType().dyn_cast<MemRefType>(); - if (!memRefType) { - return op.emitError() - << "Constant has an unsupported type; must be a memref: " - << op.getType(); - } - RETURN_IF_FAILURE(writer->WriteConstant(memRefType, op.getAttr("value"))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::CallOp op, BytecodeWriter *writer) { - auto module = op.getOperation()->getParentOfType<ModuleOp>(); - auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); - // TODO(benvanik): switch with kCallTail if attr exists. - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall)); - RETURN_IF_FAILURE(writer->WriteFunctionOrdinal(callee)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::CallImportOp op, BytecodeWriter *writer) { - auto module = op.getOperation()->getParentOfType<ModuleOp>(); - auto callee = module.lookupSymbol<FuncOp>(op.getCallee()); - // TODO(benvanik): transforms to convert Call->CallImport. - // TODO(benvanik): switch with kCallTail if attr exists. - if (callee.isExternal()) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallImport)); - } else { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCall)); - } - RETURN_IF_FAILURE(writer->WriteImportOrdinal(callee)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::CallIndirectOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCallIndirect)); - RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getCallee()->getType())); - RETURN_IF_FAILURE(writer->WriteLocal(op.getCallee())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::BranchOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kBranch)); - RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getDest())); - RETURN_IF_FAILURE(writer->WriteCount(op.getNumOperands())); - for (int i = 0; i < op.getNumOperands(); ++i) { - // Copy src->dst. - RETURN_IF_FAILURE(writer->WriteLocal(op.getOperand(i))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getDest()->getArgument(i))); - } - return success(); -} - -LogicalResult writeOp(IREESeq::LL::CondBranchOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kCondBranch)); - RETURN_IF_FAILURE(writer->WriteLocal(op.getCondition())); - RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getTrueDest())); - RETURN_IF_FAILURE(writer->WriteCount(op.getNumTrueOperands())); - for (int i = 0; i < op.getNumTrueOperands(); ++i) { - // Copy src->dst. - RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueOperand(i))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getTrueDest()->getArgument(i))); - } - RETURN_IF_FAILURE(writer->WriteBlockOffset(op.getFalseDest())); - RETURN_IF_FAILURE(writer->WriteCount(op.getNumFalseOperands())); - for (int i = 0; i < op.getNumFalseOperands(); ++i) { - // Copy src->dst. - RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseOperand(i))); - RETURN_IF_FAILURE(writer->WriteLocal(op.getFalseDest()->getArgument(i))); - } - return success(); -} - -LogicalResult writeDispatchOpExecutableRef(Operation *op, StringRef executable, - StringRef entryPoint, - BytecodeWriter *writer) { - auto module = op->getParentOfType<ModuleOp>(); - auto multiArchExecutableOp = - module.lookupSymbol<IREE::MultiArchExecutableOp>(executable); - if (!multiArchExecutableOp) { - return op->emitError() << "Executable @" << executable.str() - << " not found in module"; - } - - auto executableOrdinalAttr = multiArchExecutableOp.getAttr("iree.ordinal") - .dyn_cast_or_null<IntegerAttr>(); - if (!executableOrdinalAttr) { - return op->emitError() << "No ordinal assigned to executable"; - } - int executableOrdinal = executableOrdinalAttr.getInt(); - - // TODO(benvanik): move an export table to the MAE to make this cleaner. - auto executableOp = - cast<IREE::ExecutableOp>(multiArchExecutableOp.getBlock().front()); - auto entryPointOp = - executableOp.getInnerModule().lookupSymbol<FuncOp>(entryPoint); - if (!entryPointOp) { - return op->emitError() << "Entry point @" << entryPoint.str() - << " not found in executable @" << executable.str(); - } - if (!entryPointOp.getAttr("iree.ordinal")) { - return op->emitError() << "No ordinal assigned to entry point"; - } - int entryPointOrdinal = - entryPointOp.getAttr("iree.ordinal").cast<IntegerAttr>().getInt(); - - RETURN_IF_FAILURE(writer->WriteUint32(executableOrdinal)); - RETURN_IF_FAILURE(writer->WriteUint16(entryPointOrdinal)); - - return success(); -} - -LogicalResult writeOp(IREESeq::LL::DynamicDispatchOp op, - BytecodeWriter *writer) { - RETURN_IF_FAILURE( - writer->WriteOpcode(iree::SequencerOpcode::kDynamicDispatch)); - RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(), - op.getEntryPoint(), writer)); - RETURN_IF_FAILURE(writer->WriteLocal(op.getWorkload())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - // TODO(benvanik): support output arg group (or change to tags). - RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::StaticDispatchOp op, - BytecodeWriter *writer) { - RETURN_IF_FAILURE( - writer->WriteOpcode(iree::SequencerOpcode::kStaticDispatch)); - RETURN_IF_FAILURE(writeDispatchOpExecutableRef(op, op.getExecutable(), - op.getEntryPoint(), writer)); - auto workloadAttr = op.getWorkload(); - RETURN_IF_FAILURE( - writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({0}).getInt())); - RETURN_IF_FAILURE( - writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({1}).getInt())); - RETURN_IF_FAILURE( - writer->WriteInt32(workloadAttr.getValue<IntegerAttr>({2}).getInt())); - RETURN_IF_FAILURE(writer->WriteLocals(op.getArgOperands())); - // TODO(benvanik): support output arg group (or change to tags). - RETURN_IF_FAILURE(writer->WriteCount(/*output_arg_count*/ 0)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getResults())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::AllocHeapOp op, BytecodeWriter *writer) { - auto memRefType = op.getType().cast<MemRefType>(); - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kAllocHeap)); - RETURN_IF_FAILURE(writer->WriteInt32(0)); - RETURN_IF_FAILURE(writer->WriteTypeIndex(memRefType.getElementType())); - RETURN_IF_FAILURE(writer->WriteShapePieces(memRefType)); - RETURN_IF_FAILURE(writer->WriteLocals(op.getOperands())); - RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::ComputeRangeOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kComputeRange)); - RETURN_IF_FAILURE(writer->WriteLocal(op.shape())); - RETURN_IF_FAILURE(writer->WriteUint8(op.elementSize().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteLocal(op.indices())); - RETURN_IF_FAILURE(writer->WriteLocal(op.lengths())); - RETURN_IF_FAILURE(writer->WriteLocal(op.dstOffset())); - RETURN_IF_FAILURE(writer->WriteLocal(op.dstLength())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::StaticSliceOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticSlice)); - RETURN_IF_FAILURE(writer->WriteLocal(op.src())); - RETURN_IF_FAILURE(writer->WriteInt32(op.offset().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteTypeIndex(op.getResult()->getType())); - RETURN_IF_FAILURE( - writer->WriteShapePieces(op.getResult()->getType().cast<ShapedType>())); - RETURN_IF_FAILURE(writer->WriteLocal(op.getResult())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::StaticCopyOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticCopy)); - RETURN_IF_FAILURE(writer->WriteLocal(op.src())); - RETURN_IF_FAILURE(writer->WriteInt32(op.srcOffset().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteLocal(op.dst())); - RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue())); - return success(); -} - -LogicalResult writeOp(IREESeq::LL::StaticFillOp op, BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->WriteOpcode(iree::SequencerOpcode::kStaticFill)); - RETURN_IF_FAILURE(writer->WriteInt32(op.value().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteLocal(op.dst())); - RETURN_IF_FAILURE(writer->WriteInt32(op.dstOffset().getZExtValue())); - RETURN_IF_FAILURE(writer->WriteInt32(op.length().getZExtValue())); - return success(); -} - -} // namespace - -void registerSequencerCustomWriters(VMFunctionBuilder *builder) { -#define REGISTER_CUSTOM_WRITER_IMPL(op_type) \ - builder->RegisterCustomWriter( \ - op_type::getOperationName(), \ - +[](Operation *op, BytecodeWriter *writer) { \ - return writeOp(cast<op_type>(op), writer); \ - }); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ConstantOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallImportOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CallIndirectOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::BranchOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::CondBranchOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::DynamicDispatchOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticDispatchOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::AllocHeapOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::ComputeRangeOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticSliceOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticCopyOp); - REGISTER_CUSTOM_WRITER_IMPL(IREESeq::LL::StaticFillOp); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Sequencer/OpWriters.h b/iree/compiler/IR/Sequencer/OpWriters.h deleted file mode 100644 index a0039af..0000000 --- a/iree/compiler/IR/Sequencer/OpWriters.h +++ /dev/null
@@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_ -#define IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_ - -#include "iree/compiler/Serialization/VMFunctionBuilder.h" - -namespace mlir { -namespace iree_compiler { - -// Registers custom op writers with the builder. -// Ops not registered will use the generic writer. -void registerSequencerCustomWriters(VMFunctionBuilder *builder); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_SEQUENCER_OPWRITERS_H_
diff --git a/iree/compiler/IR/Sequencer/test/BUILD b/iree/compiler/IR/Sequencer/test/BUILD deleted file mode 100644 index 4df56ac..0000000 --- a/iree/compiler/IR/Sequencer/test/BUILD +++ /dev/null
@@ -1,15 +0,0 @@ -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - "//iree/tools:iree-run-mlir", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/IR/StructureOps.cpp b/iree/compiler/IR/StructureOps.cpp deleted file mode 100644 index 1ee1c72..0000000 --- a/iree/compiler/IR/StructureOps.cpp +++ /dev/null
@@ -1,250 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/StructureOps.h" - -#include "iree/compiler/IR/Types.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/STLExtras.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -//===----------------------------------------------------------------------===// -// Generic printers and parsers. -//===----------------------------------------------------------------------===// - -// Parses an op that has no inputs and no outputs. -static ParseResult parseNoIOOp(OpAsmParser &parser, OperationState &state) { - if (failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - return success(); -} - -// Prints an op that has no inputs and no outputs. -static void printNoIOOp(Operation *op, OpAsmPrinter &printer) { - printer << op->getName(); - printer.printOptionalAttrDict(op->getAttrs()); -} - -//===----------------------------------------------------------------------===// -// iree.module -//===----------------------------------------------------------------------===// - -void ModuleOp::build(Builder *builder, OperationState &state) { - ensureTerminator(*state.addRegion(), *builder, state.location); -} - -static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { - Region *body = state.addRegion(); - if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) { - return failure(); - } - if (parser.parseOptionalAttributeDict(state.attributes)) { - return failure(); - } - ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location); - return success(); -} - -static void printModuleOp(OpAsmPrinter &printer, Operation *op) { - printer << op->getName(); - printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - printer.printOptionalAttrDict(op->getAttrs()); -} - -//===----------------------------------------------------------------------===// -// iree.multi_arch_executable -//===----------------------------------------------------------------------===// - -void MultiArchExecutableOp::build(Builder *builder, OperationState &state, - StringRef name) { - state.addAttribute(SymbolTable::getSymbolAttrName(), - builder->getStringAttr(name)); - ensureTerminator(*state.addRegion(), *builder, state.location); -} - -static ParseResult parseMultiArchExecutableOp(OpAsmParser &parser, - OperationState &state) { - auto &builder = parser.getBuilder(); - - // Parse the name as a symbol reference attr and then convert to a string. - SymbolRefAttr nameAttr; - if (failed(parser.parseAttribute(nameAttr, SymbolTable::getSymbolAttrName(), - state.attributes))) { - return failure(); - } - state.attributes.back().second = builder.getStringAttr(nameAttr.getValue()); - - if (succeeded(parser.parseOptionalLSquare())) { - IntegerAttr ordinalAttr; - if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32), - "iree.ordinal", state.attributes)) || - failed(parser.parseRSquare())) { - return failure(); - } - } - - if (failed(parser.parseLParen()) || failed(parser.parseRParen())) { - return failure(); - } - - Region *body = state.addRegion(); - if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) { - return failure(); - } - if (succeeded(parser.parseOptionalKeyword("attributes"))) { - if (failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - } - - MultiArchExecutableOp::ensureTerminator(*body, builder, state.location); - - return success(); -} - -static void printMultiArchExecutableOp(OpAsmPrinter &printer, - MultiArchExecutableOp op) { - printer << op.getOperationName() << " @" << op.sym_name(); - if (auto ordinalAttr = - op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) { - printer << "[" << ordinalAttr.getInt() << "]"; - } - printer << "()"; - - printer.printRegion(op.body(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - - // Print out executable attributes, if present. - SmallVector<StringRef, 2> ignoredAttrs = { - SymbolTable::getSymbolAttrName(), - "iree.ordinal", - }; - SmallVector<NamedAttribute, 4> attrs( - llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) { - return llvm::count(ignoredAttrs, attr.first) == 0; - })); - if (!attrs.empty()) { - printer << "\n attributes "; - printer.printOptionalAttrDict(attrs); - } -} - -//===----------------------------------------------------------------------===// -// iree.executable -//===----------------------------------------------------------------------===// - -void ExecutableOp::build(Builder *builder, OperationState &state, - IREE::ExecutableFormat format) { - state.addAttribute("format", - builder->getI32IntegerAttr(static_cast<uint32_t>(format))); - ensureTerminator(*state.addRegion(), *builder, state.location); -} - -static ParseResult parseExecutableOp(OpAsmParser &parser, - OperationState &state) { - auto &builder = parser.getBuilder(); - - if (succeeded(parser.parseOptionalLSquare())) { - IntegerAttr ordinalAttr; - if (failed(parser.parseAttribute(ordinalAttr, builder.getIntegerType(32), - "iree.ordinal", state.attributes)) || - failed(parser.parseRSquare())) { - return failure(); - } - } - - IntegerAttr executableOrdinalAttr; - StringAttr formatAttr; - llvm::SMLoc formatLoc; - if (failed(parser.parseLParen()) || - failed(parser.getCurrentLocation(&formatLoc)) || - failed(parser.parseAttribute(formatAttr, "format", state.attributes))) { - return failure(); - } - auto format = symbolizeExecutableFormat(formatAttr.getValue()); - if (!format.hasValue()) { - return parser.emitError(formatLoc) - << "Unknown executable format " << formatAttr.getValue(); - } - state.attributes.back().second = - builder.getI32IntegerAttr(static_cast<int32_t>(format.getValue())); - - Region *body = state.addRegion(); - if (failed(parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))) { - return failure(); - } - if (succeeded(parser.parseOptionalKeyword("attributes"))) { - if (failed(parser.parseOptionalAttributeDict(state.attributes))) { - return failure(); - } - } - - ExecutableOp::ensureTerminator(*body, parser.getBuilder(), state.location); - - return success(); -} - -static void printExecutableOp(OpAsmPrinter &printer, ExecutableOp op) { - printer << op.getOperationName(); - if (auto ordinalAttr = - op.getAttr("iree.ordinal").dyn_cast_or_null<IntegerAttr>()) { - printer << "[" << ordinalAttr.getInt() << "]"; - } - printer << "("; - auto format = symbolizeExecutableFormat(op.format()); - if (format.hasValue()) { - printer << stringifyExecutableFormat(format.getValue()); - } else { - printer << "INVALID FORMAT"; - } - printer << ")"; - - printer.printRegion(op.body(), /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/false); - - // Print out executable attributes, if present. - SmallVector<StringRef, 2> ignoredAttrs = { - "iree.ordinal", - "format", - }; - SmallVector<NamedAttribute, 4> attrs( - llvm::make_filter_range(op.getAttrs(), [&](const NamedAttribute &attr) { - return llvm::count(ignoredAttrs, attr.first) == 0; - })); - if (!attrs.empty()) { - printer << "\n attributes "; - printer.printOptionalAttrDict(attrs); - } -} - -#define GET_OP_CLASSES -#include "iree/compiler/IR/StructureOps.cpp.inc" - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/StructureOps.h b/iree/compiler/IR/StructureOps.h deleted file mode 100644 index 003310b..0000000 --- a/iree/compiler/IR/StructureOps.h +++ /dev/null
@@ -1,40 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_STRUCTUREOPS_H_ -#define IREE_COMPILER_IR_STRUCTUREOPS_H_ - -#include <cstdint> - -#include "iree/compiler/IR/Types.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/FunctionSupport.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { -namespace IREE { - -#define GET_OP_CLASSES -#include "iree/compiler/IR/StructureOps.h.inc" - -} // namespace IREE -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_STRUCTUREOPS_H_
diff --git a/iree/compiler/IR/StructureOps.td b/iree/compiler/IR/StructureOps.td deleted file mode 100644 index 189d3a4..0000000 --- a/iree/compiler/IR/StructureOps.td +++ /dev/null
@@ -1,122 +0,0 @@ -// Structural ops such as 'module' and 'executable'. -// These are used to organize IREE IR into regions representing ops that act at -// the sequencer level (coarse control flow/scheduling) and ops that perform -// actual work (math/etc) on runtime execution backends. - -#ifdef IREE_STRUCTURE_OPS -#else -#define IREE_STRUCTURE_OPS - -#ifdef IREE_OP_BASE -#else -include "iree/compiler/IR/OpBase.td" -#endif // IREE_OP_BASE - -class IREE_StructureOp<string mnemonic, list<OpTrait> traits = []> : - Op<IREE_Dialect, mnemonic, traits> { - let parser = [{ return parse$cppClass(parser, result); }]; - let printer = [{ print$cppClass(p, *this); }]; -} - -def IREE_ModuleOp : - IREE_StructureOp<"module", [ - SingleBlockImplicitTerminator<"ModuleEndOp">, - NativeOpTrait<"SymbolTable"> - ]> { - let regions = (region SizedRegion<1>:$body); - let extraClassDeclaration = [{ - Block& getBlock() { - return this->getOperation()->getRegion(0).front(); - } - }]; - - let skipDefaultBuilders = 1; - let builders = [OpBuilder<"Builder *, OperationState &state">]; -} - -def IREE_ModuleEndOp : - IREE_StructureOp<"_module_end", [ - IREE_ModuleOnly, - Terminator - ]> { - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; -} - -def IREE_MultiArchExecutableOp : - IREE_StructureOp<"multi_arch_executable", [ - // TODO(benvanik): make iree.module work and make this IREE_ModuleOnly. - SingleBlockImplicitTerminator<"MultiArchExecutableEndOp"> - ]> { - let arguments = (ins - StrAttr:$sym_name, - OptionalAttr<I32Attr>:$ordinal - ); - - let regions = (region SizedRegion<1>:$body); - let extraClassDeclaration = [{ - StringRef getName() { - return this->getOperation()->template getAttrOfType<StringAttr>( - ::mlir::SymbolTable::getSymbolAttrName()).getValue(); - } - - Region& getBody() { - return this->getOperation()->getRegion(0); - } - Block& getBlock() { - return this->getOperation()->getRegion(0).front(); - } - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<"Builder *builder, OperationState &state, StringRef name">, - ]; -} - -def IREE_MultiArchExecutableEndOp : - IREE_StructureOp<"_multi_arch_executable_end", [ - IREE_MultiArchExecutableOnly, - Terminator - ]> { - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; -} - -def IREE_ExecutableOp : - IREE_StructureOp<"executable", [ - SingleBlockImplicitTerminator<"ExecutableEndOp">, - NativeOpTrait<"SymbolTable"> - ]> { - let arguments = (ins - IREE_ExecutableFormatAttr:$format, - OptionalAttr<I32Attr>:$ordinal - ); - - let regions = (region SizedRegion<1>:$body); - let extraClassDeclaration = [{ - Region& getBody() { - return this->getOperation()->getRegion(0); - } - Block& getBlock() { - return this->getOperation()->getRegion(0).front(); - } - ::mlir::ModuleOp getInnerModule() { - return *getBlock().getOps<::mlir::ModuleOp>().begin(); - } - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<[{Builder *builder, OperationState &state, - ExecutableFormat executable_format}]>, - ]; -} - -def IREE_ExecutableEndOp : - IREE_StructureOp<"_executable_end", [Terminator, IREE_ExecutableOnly]> { - let parser = [{ return parseNoIOOp(parser, result); }]; - let printer = [{ printNoIOOp(getOperation(), p); }]; -} - -#endif // IREE_STRUCTURE_OPS
diff --git a/iree/compiler/IR/Traits.cpp b/iree/compiler/IR/Traits.cpp deleted file mode 100644 index 8a7081a..0000000 --- a/iree/compiler/IR/Traits.cpp +++ /dev/null
@@ -1,23 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Traits.h" - -namespace mlir { -namespace iree_compiler { - -// TODO(benvanik): traits. - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Types.cpp b/iree/compiler/IR/Types.cpp deleted file mode 100644 index bc49235..0000000 --- a/iree/compiler/IR/Types.cpp +++ /dev/null
@@ -1,53 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Types.h" - -#include "iree/compiler/IR/Enums.cpp.inc" - -namespace mlir { -namespace iree_compiler { - -// static -DeviceType DeviceType::get(MLIRContext *context) { - return Base::get(context, TypeKind::Device); -} - -// static -DeviceGroupType DeviceGroupType::get(MLIRContext *context) { - return Base::get(context, TypeKind::DeviceGroup); -} - -// static -CommandBufferType CommandBufferType::get(MLIRContext *context) { - return Base::get(context, TypeKind::CommandBuffer); -} - -// static -EventType EventType::get(MLIRContext *context) { - return Base::get(context, TypeKind::Event); -} - -// static -SemaphoreType SemaphoreType::get(MLIRContext *context) { - return Base::get(context, TypeKind::Semaphore); -} - -// static -FenceType FenceType::get(MLIRContext *context) { - return Base::get(context, TypeKind::Fence); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/IR/Types.h b/iree/compiler/IR/Types.h deleted file mode 100644 index da53d53..0000000 --- a/iree/compiler/IR/Types.h +++ /dev/null
@@ -1,115 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_IR_TYPES_H_ -#define IREE_COMPILER_IR_TYPES_H_ - -#include <cstdint> - -#include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/StringSwitch.h" -#include "mlir/IR/Types.h" -#include "mlir/Support/LLVM.h" - -// Order matters. -#include "iree/compiler/IR/Enums.h.inc" - -namespace mlir { -namespace iree_compiler { - -namespace TypeKind { -enum Kind { - Device = Type::FIRST_IREE_TYPE, - DeviceGroup, - CommandBuffer, - Event, - Semaphore, - Fence, -}; -} // namespace TypeKind - -// clang-format off -#define IREE_TYPE_TABLE(map) \ - map("device", TypeKind::Device, DeviceType) \ - map("device_group", TypeKind::DeviceGroup, DeviceGroupType) \ - map("command_buffer", TypeKind::CommandBuffer, CommandBufferType) \ - map("event", TypeKind::Event, EventType) \ - map("semaphore", TypeKind::Semaphore, SemaphoreType) \ - map("fence", TypeKind::Fence, FenceType) -// clang-format on - -// iree.device mapping to a runtime-resolved device type. -class DeviceType : public Type::TypeBase<DeviceType, Type> { - public: - using Base::Base; - - static bool kindof(unsigned kind) { return kind == TypeKind::Device; } - - static DeviceType get(MLIRContext *context); -}; - -// iree.device_group relating multiple iree.device requirements with each other. -class DeviceGroupType : public Type::TypeBase<DeviceGroupType, Type> { - public: - using Base::Base; - - static bool kindof(unsigned kind) { return kind == TypeKind::DeviceGroup; } - - static DeviceGroupType get(MLIRContext *context); -}; - -// iree.command_buffer mapping to an iree::hal::CommandBuffer. -class CommandBufferType : public Type::TypeBase<CommandBufferType, Type> { - public: - using Base::Base; - - static bool kindof(unsigned kind) { return kind == TypeKind::CommandBuffer; } - - static CommandBufferType get(MLIRContext *context); -}; - -// iree.event mapping to an iree::hal::Event. -class EventType : public Type::TypeBase<EventType, Type> { - public: - using Base::Base; - - static bool kindof(unsigned kind) { return kind == TypeKind::Event; } - - static EventType get(MLIRContext *context); -}; - -// iree.semaphore mapping to an iree::hal::Semaphore. -class SemaphoreType : public Type::TypeBase<SemaphoreType, Type> { - public: - using Base::Base; - - static bool kindof(unsigned kind) { return kind == TypeKind::Semaphore; } - - static SemaphoreType get(MLIRContext *context); -}; - -// iree.fence mapping to an iree::hal::Fence. -class FenceType : public Type::TypeBase<FenceType, Type> { - public: - using Base::Base; - - static bool kindof(unsigned kind) { return kind == TypeKind::Fence; } - - static FenceType get(MLIRContext *context); -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_IR_TYPES_H_
diff --git a/iree/compiler/IR/test/BUILD b/iree/compiler/IR/test/BUILD deleted file mode 100644 index 4df56ac..0000000 --- a/iree/compiler/IR/test/BUILD +++ /dev/null
@@ -1,15 +0,0 @@ -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - "//iree/tools:iree-run-mlir", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/Serialization/BUILD b/iree/compiler/Serialization/BUILD deleted file mode 100644 index 6bee31b..0000000 --- a/iree/compiler/Serialization/BUILD +++ /dev/null
@@ -1,43 +0,0 @@ -# Serialization for the VM bytecode. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Serialization", - srcs = [ - "BytecodeTables.cpp", - "BytecodeWriter.cpp", - "VMDeviceTableBuilder.cpp", - "VMExecutableTableBuilder.cpp", - "VMFunctionBuilder.cpp", - "VMFunctionTableBuilder.cpp", - "VMModuleBuilder.cpp", - "VMSourceMapBuilder.cpp", - ], - hdrs = [ - "BytecodeTables.h", - "BytecodeWriter.h", - "VMDeviceTableBuilder.h", - "VMExecutableTableBuilder.h", - "VMFunctionBuilder.h", - "VMFunctionTableBuilder.h", - "VMModuleBuilder.h", - "VMSourceMapBuilder.h", - ], - deps = [ - "//iree/compiler/IR", - "//iree/compiler/Utils", - "//iree/schemas", - "//iree/schemas/bytecode:bytecode_v0", - "//iree/schemas/bytecode:interpreter_bytecode_v0", - "//iree/schemas/bytecode:sequencer_bytecode_v0", - "@com_github_google_flatbuffers//:flatbuffers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - ], -)
diff --git a/iree/compiler/Serialization/BytecodeTables.cpp b/iree/compiler/Serialization/BytecodeTables.cpp deleted file mode 100644 index 7d0cc9d..0000000 --- a/iree/compiler/Serialization/BytecodeTables.cpp +++ /dev/null
@@ -1,73 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/BytecodeTables.h" - -#include "llvm/ADT/STLExtras.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Info tables mapping 1:1 with bytecode ops. -// -// Note that we ensure the table is 256 elements long exactly to make sure -// that unused opcodes are handled gracefully. -#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \ - { \ - name, \ - flags, \ - {operand_encodings}, \ - }, - -static const OpcodeInfo kInterpreterInfoTable[256] = { - IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)}; - -static const OpcodeInfo kSequencerInfoTable[256] = { - IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO)}; - -#undef DECLARE_INFO - -} // namespace - -llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName( - StringRef name) { - for (int i = 0; i < llvm::array_lengthof(kInterpreterInfoTable); ++i) { - if (name == kInterpreterInfoTable[i].mnemonic) { - return static_cast<iree::InterpreterOpcode>(i); - } - } - return llvm::None; -} - -const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode) { - return kInterpreterInfoTable[static_cast<uint8_t>(opcode)]; -} - -llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name) { - for (int i = 0; i < llvm::array_lengthof(kSequencerInfoTable); ++i) { - if (name == kSequencerInfoTable[i].mnemonic) { - return static_cast<iree::SequencerOpcode>(i); - } - } - return llvm::None; -} - -const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode) { - return kSequencerInfoTable[static_cast<uint8_t>(opcode)]; -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/BytecodeTables.h b/iree/compiler/Serialization/BytecodeTables.h deleted file mode 100644 index bee8763..0000000 --- a/iree/compiler/Serialization/BytecodeTables.h +++ /dev/null
@@ -1,52 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_ -#define IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_ - -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" -#include "iree/schemas/bytecode/sequencer_bytecode_v0.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { - -struct OpcodeInfo { - const char* mnemonic = nullptr; - iree::OpcodeFlagBitfield flags = iree::OpcodeFlagBitfield::kDefault; - union { - const char operands_value[8] = {0}; - const iree::OperandEncoding operands[8]; - }; -}; - -// Returns an opcode - if found - for the given interpreter op. -llvm::Optional<iree::InterpreterOpcode> GetInterpreterOpcodeByName( - StringRef name); - -// Returns the info for the given interpreter opcode. -const OpcodeInfo& GetInterpreterOpcodeInfo(iree::InterpreterOpcode opcode); - -// Returns an opcode - if found - for the given sequencer op. -llvm::Optional<iree::SequencerOpcode> GetSequencerOpcodeByName(StringRef name); - -// Returns the info for the given sequencer opcode. -const OpcodeInfo& GetSequencerOpcodeInfo(iree::SequencerOpcode opcode); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_TABLES_H_
diff --git a/iree/compiler/Serialization/BytecodeWriter.cpp b/iree/compiler/Serialization/BytecodeWriter.cpp deleted file mode 100644 index a357c3a..0000000 --- a/iree/compiler/Serialization/BytecodeWriter.cpp +++ /dev/null
@@ -1,334 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/BytecodeWriter.h" - -#include <algorithm> - -#include "iree/compiler/Utils/Macros.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { - -LogicalResult BytecodeWriter::WriteCount(int count) { - if (count > UINT8_MAX) { - // TODO(benvanik): varints? - llvm::errs() << "Too many items: " << count - << "; only 0-UINT8_MAX are supported"; - return failure(); - } - return WriteUint8(static_cast<uint8_t>(count)); -} - -LogicalResult BytecodeWriter::WriteTypeIndex(Type type) { - iree::BuiltinType type_index; - if (type.isInteger(8)) { - type_index = iree::BuiltinType::kI8; - } else if (type.isInteger(16)) { - type_index = iree::BuiltinType::kI16; - } else if (type.isInteger(32)) { - type_index = iree::BuiltinType::kI32; - } else if (type.isInteger(64)) { - type_index = iree::BuiltinType::kI64; - } else if (type.isF16()) { - type_index = iree::BuiltinType::kF16; - } else if (type.isF32()) { - type_index = iree::BuiltinType::kF32; - } else if (type.isF64()) { - type_index = iree::BuiltinType::kF64; - } else { - // TODO(benvanik): support unknown types as BuiltinType::kOpaque? - return emitError(UnknownLoc::get(type.getContext())) - << "Type " << type << " cannot be represented by a builtin type"; - } - return WriteUint8(static_cast<uint8_t>(type_index)); -} - -LogicalResult BytecodeWriter::WriteFunctionOrdinal(FuncOp function) { - auto functionOrdinal = function.getAttrOfType<IntegerAttr>("iree.ordinal"); - if (!functionOrdinal) { - return function.emitError() << "Ordinal not assigned to function"; - } - RETURN_IF_FAILURE(WriteUint32(functionOrdinal.getInt())); - return success(); -} - -LogicalResult BytecodeWriter::WriteImportOrdinal(FuncOp function) { - // For now this is the same as internal function ordinals, though we could - // probably shrink it. - return WriteFunctionOrdinal(function); -} - -LogicalResult BytecodeWriter::WriteConstant(MemRefType memRefType, - Attribute baseAttr) { - // All types are memrefs, so we only need the element type. - RETURN_IF_FAILURE(WriteTypeIndex(memRefType.getElementType())); - - // Write shape (we could optimize this for cases of scalars and such). - RETURN_IF_FAILURE(WriteCount(memRefType.getRank())); - for (int i = 0; i < memRefType.getRank(); ++i) { - RETURN_IF_FAILURE(WriteInt32(memRefType.getDimSize(i))); - } - - if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) { - RETURN_IF_FAILURE( - WriteUint8(static_cast<uint8_t>(iree::ConstantEncoding::kSplat))); - return WriteAttributeData(attr.getSplatValue()); - } - RETURN_IF_FAILURE( - WriteUint8(static_cast<uint8_t>(iree::ConstantEncoding::kDense))); - return WriteAttributeData(baseAttr); -} - -LogicalResult BytecodeWriter::WriteAttributeData(Attribute baseAttr) { - if (auto attr = baseAttr.dyn_cast<BoolAttr>()) { - return WriteUint8(attr.getValue() ? 1 : 0); - } else if (auto attr = baseAttr.dyn_cast<IntegerAttr>()) { - if (attr.getType().isIndex()) { - int32_t value = static_cast<int32_t>(attr.getInt()); - return WriteBytes(&value, 4); - } else { - int bitWidth = attr.getValue().getBitWidth(); - switch (bitWidth) { - case 8: - case 16: - case 32: - case 64: - return WriteBytes(attr.getValue().getRawData(), bitWidth / 8); - default: - return emitError(UnknownLoc::get(baseAttr.getContext())) - << "Bit width for integers must be one of 8,16,32,64; others " - "not implemented: " - << bitWidth; - } - } - } else if (auto attr = baseAttr.dyn_cast<FloatAttr>()) { - int bitWidth = attr.getType().getIntOrFloatBitWidth(); - auto bitcastValue = attr.getValue().bitcastToAPInt(); - switch (bitWidth) { - case 16: - case 32: - case 64: - return WriteBytes(bitcastValue.getRawData(), bitWidth / 8); - default: - return emitError(UnknownLoc::get(baseAttr.getContext())) - << "Bit width for floats must be one of 16,32,64; others " - "not implemented: " - << bitWidth; - } - } else if (auto attr = baseAttr.dyn_cast<StringAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<ArrayAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<AffineMapAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<IntegerSetAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<TypeAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<SymbolRefAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<SplatElementsAttr>()) { - return WriteAttributeData(attr.getSplatValue()); - } else if (auto attr = baseAttr.dyn_cast<DenseIntElementsAttr>()) { - int elementCount = attr.getType().getNumElements(); - if (elementCount == 0) { - return success(); - } - int bitWidth = attr.getType().getElementTypeBitWidth(); - int byteWidth = bitWidth / 8; - auto dst = ReserveBytes(elementCount * byteWidth); - if (dst.empty()) return failure(); - uint8_t *dstPtr = dst.data(); - for (auto element : attr) { - assert(element.getBitWidth() == bitWidth); - std::memcpy(dstPtr, element.getRawData(), byteWidth); - dstPtr += byteWidth; - } - return success(); - } else if (auto attr = baseAttr.dyn_cast<DenseFPElementsAttr>()) { - int elementCount = attr.getType().getNumElements(); - if (elementCount == 0) { - return success(); - } - int bitWidth = attr.getType().getElementTypeBitWidth(); - auto dst = ReserveBytes(elementCount * bitWidth / 8); - if (dst.empty()) return failure(); - uint8_t *dstPtr = dst.data(); - for (auto element : attr) { - auto bitcastValue = element.bitcastToAPInt(); - std::memcpy(dstPtr, bitcastValue.getRawData(), - bitcastValue.getBitWidth() / 8); - dstPtr += bitWidth / 8; - } - return success(); - } else if (auto attr = baseAttr.dyn_cast<DenseElementsAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<OpaqueElementsAttr>()) { - // TODO(benvanik): other attribute encodings. - } else if (auto attr = baseAttr.dyn_cast<SparseElementsAttr>()) { - // TODO(benvanik): other attribute encodings. - } - return emitError(UnknownLoc::get(baseAttr.getContext())) - << "Serializer for attribute kind " - << static_cast<int>(baseAttr.getKind()) << " not implemented"; -} - -Optional<int> BytecodeWriter::LookupLocalOrdinal(Value *value) { - int ordinal; - auto it = localMap_.find(value); - if (it != localMap_.end()) { - ordinal = it->second; - } else { - ordinal = localMap_.size(); - localMap_.insert({value, ordinal}); - } - if (ordinal > UINT16_MAX) { - // TODO(benvanik): varints? - emitError(UnknownLoc::get(value->getContext())) - << "Too many ordinals: " << ordinal - << "; only 0-UINT16_MAX are supported"; - return llvm::None; - } - return ordinal; -} - -LogicalResult BytecodeWriter::PrepareLocal(Value *value) { - if (!LookupLocalOrdinal(value).hasValue()) return failure(); - return success(); -} - -LogicalResult BytecodeWriter::WriteLocal(Value *value) { - auto ordinal = LookupLocalOrdinal(value); - if (!ordinal.hasValue()) { - return failure(); - } - if (ordinal.getValue() > UINT16_MAX) { - // TODO(benvanik): varints? - return emitError(UnknownLoc::get(value->getContext())) - << "Too many locals: " << ordinal.getValue() - << "; only 0-UINT16_MAX are supported"; - } - return WriteUint16(static_cast<uint16_t>(ordinal.getValue())); -} - -LogicalResult BytecodeWriter::WriteLocals( - llvm::iterator_range<Operation::operand_iterator> values) { - int count = std::distance(values.begin(), values.end()); - RETURN_IF_FAILURE(WriteCount(count)); - for (auto *value : values) { - RETURN_IF_FAILURE(WriteLocal(value)); - } - return success(); -} - -LogicalResult BytecodeWriter::WriteLocals( - llvm::iterator_range<Operation::result_iterator> values) { - int count = std::distance(values.begin(), values.end()); - RETURN_IF_FAILURE(WriteCount(count)); - for (auto *value : values) { - RETURN_IF_FAILURE(WriteLocal(value)); - } - return success(); -} - -MutableArrayRef<uint8_t> BytecodeWriter::ReserveBytes(size_t dataLength) { - int offset = bytecode_.size(); - bytecode_.resize(offset + dataLength); - return MutableArrayRef<uint8_t>( - reinterpret_cast<uint8_t *>(bytecode_.data()) + offset, dataLength); -} - -LogicalResult BytecodeWriter::WriteBytes(const void *data, size_t dataLength) { - auto dst = ReserveBytes(dataLength); - if (dataLength != dst.size()) { - return failure(); - } - std::memcpy(dst.data(), data, dst.size()); - return success(); -} - -LogicalResult BytecodeWriter::WriteUint8(uint8_t value) { - return WriteBytes(&value, sizeof(value)); -} - -LogicalResult BytecodeWriter::WriteUint16(uint16_t value) { - return WriteBytes(&value, sizeof(value)); -} - -LogicalResult BytecodeWriter::WriteInt32(int32_t value) { - return WriteBytes(&value, sizeof(value)); -} - -LogicalResult BytecodeWriter::WriteUint32(uint32_t value) { - return WriteBytes(&value, sizeof(value)); -} - -LogicalResult BytecodeWriter::WriteElementsAttrInt32(ElementsAttr attr) { - int elementCount = attr.getType().getNumElements(); - RETURN_IF_FAILURE(WriteCount(elementCount)); - for (auto value : attr.getValues<int32_t>()) { - RETURN_IF_FAILURE(WriteInt32(value)); - } - return success(); -} - -LogicalResult BytecodeWriter::WriteShapePieces(const ShapedType &type) { - RETURN_IF_FAILURE(WriteCount(type.getRank())); - for (int64_t dim : type.getShape()) { - RETURN_IF_FAILURE(WriteInt32(dim)); - } - return success(); -} - -LogicalResult BytecodeWriter::WriteShapePieces(ElementsAttr pieces) { - return WriteElementsAttrInt32(pieces); -} - -LogicalResult BytecodeWriter::MarkBlockOffset(Block *block) { - blockOffsets_[block] = bytecode_.size(); - return success(); -} - -LogicalResult BytecodeWriter::WriteBlockOffset(Block *targetBlock) { - // Reserve space for the offset and stash for later fixup. - blockOffsetFixups_.push_back({targetBlock, bytecode_.size()}); - bytecode_.resize(bytecode_.size() + sizeof(int32_t)); - return success(); -} - -LogicalResult BytecodeWriter::FixupOffsets() { - for (const auto &fixup : blockOffsetFixups_) { - auto it = blockOffsets_.find(fixup.first); - if (it == blockOffsets_.end()) { - llvm::errs() << "Block offset not found: " << fixup.first; - return failure(); - } - std::memcpy(bytecode_.data() + fixup.second, &it->second, sizeof(int32_t)); - } - blockOffsetFixups_.clear(); - return success(); -} - -std::vector<uint8_t> BytecodeWriter::Finish() { - localMap_.clear(); - return std::move(bytecode_); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/BytecodeWriter.h b/iree/compiler/Serialization/BytecodeWriter.h deleted file mode 100644 index dbadd1c..0000000 --- a/iree/compiler/Serialization/BytecodeWriter.h +++ /dev/null
@@ -1,96 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_ -#define IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_ - -#include <cstddef> -#include <utility> -#include <vector> - -#include "iree/compiler/IR/StructureOps.h" -#include "iree/schemas/bytecode/bytecode_v0.h" -#include "llvm/ADT/Optional.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" - -namespace mlir { -namespace iree_compiler { - -class BytecodeWriter { - public: - int offset() const { return bytecode_.size(); } - - int local_count() const { return localMap_.size(); } - - template <typename T> - LogicalResult WriteOpcode(T value) { - static_assert(sizeof(T) == sizeof(uint8_t), "Opcode enum size mismatch"); - return WriteUint8(static_cast<uint8_t>(value)); - } - - LogicalResult WriteCount(int count); - - LogicalResult WriteTypeIndex(Type type); - - LogicalResult WriteFunctionOrdinal(FuncOp function); - LogicalResult WriteImportOrdinal(FuncOp function); - - LogicalResult WriteConstant(MemRefType memRefType, Attribute baseAttr); - LogicalResult WriteAttributeData(Attribute baseAttr); - - llvm::Optional<int> LookupLocalOrdinal(Value *value); - LogicalResult PrepareLocal(Value *value); - LogicalResult WriteLocal(Value *value); - LogicalResult WriteLocals( - llvm::iterator_range<Operation::operand_iterator> values); - LogicalResult WriteLocals( - llvm::iterator_range<Operation::result_iterator> values); - - LogicalResult WriteBytes(const void *data, size_t dataLength); - MutableArrayRef<uint8_t> ReserveBytes(size_t dataLength); - LogicalResult WriteUint8(uint8_t value); - LogicalResult WriteUint16(uint16_t value); - LogicalResult WriteInt32(int32_t value); - LogicalResult WriteUint32(uint32_t value); - - LogicalResult WriteElementsAttrInt32(ElementsAttr attr); - - LogicalResult WriteShapePieces(const ShapedType &type); - LogicalResult WriteShapePieces(ElementsAttr pieces); - - LogicalResult MarkBlockOffset(Block *block); - LogicalResult WriteBlockOffset(Block *targetBlock); - LogicalResult FixupOffsets(); - - std::vector<uint8_t> Finish(); - - private: - std::vector<uint8_t> bytecode_; - - llvm::DenseMap<Value *, int> localMap_; - - llvm::DenseMap<Block *, size_t> blockOffsets_; - std::vector<std::pair<Block *, size_t>> blockOffsetFixups_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_BYTECODE_WRITER_H_
diff --git a/iree/compiler/Serialization/VMDeviceTableBuilder.cpp b/iree/compiler/Serialization/VMDeviceTableBuilder.cpp deleted file mode 100644 index 7703014..0000000 --- a/iree/compiler/Serialization/VMDeviceTableBuilder.cpp +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/VMDeviceTableBuilder.h" - -namespace mlir { -namespace iree_compiler { - -VMDeviceTableBuilder::VMDeviceTableBuilder( - ::flatbuffers::FlatBufferBuilder *fbb) - : fbb_(fbb) {} - -LogicalResult VMDeviceTableBuilder::AddDevice( - ::flatbuffers::Offset<iree::DeviceDef> deviceDef) { - deviceDefs_.push_back(deviceDef); - return success(); -} - -LogicalResult VMDeviceTableBuilder::AddDeviceGroup( - ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef) { - deviceGroupDefs_.push_back(deviceGroupDef); - return success(); -} - -::flatbuffers::Offset<iree::DeviceTableDef> VMDeviceTableBuilder::Finish() { - auto devicesOffset = fbb_->CreateVector(deviceDefs_); - auto deviceGroupsOffset = fbb_->CreateVector(deviceGroupDefs_); - iree::DeviceTableDefBuilder dtdb(*fbb_); - dtdb.add_devices(devicesOffset); - dtdb.add_device_groups(deviceGroupsOffset); - return dtdb.Finish(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/VMDeviceTableBuilder.h b/iree/compiler/Serialization/VMDeviceTableBuilder.h deleted file mode 100644 index 9170baf..0000000 --- a/iree/compiler/Serialization/VMDeviceTableBuilder.h +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_ -#define IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_ - -#include "flatbuffers/flatbuffers.h" -#include "iree/schemas/device_table_def_generated.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { - -class VMDeviceTableBuilder { - public: - explicit VMDeviceTableBuilder(::flatbuffers::FlatBufferBuilder *fbb); - - LogicalResult AddDevice(::flatbuffers::Offset<iree::DeviceDef> deviceDef); - - LogicalResult AddDeviceGroup( - ::flatbuffers::Offset<iree::DeviceGroupDef> deviceGroupDef); - - ::flatbuffers::Offset<iree::DeviceTableDef> Finish(); - - private: - ::flatbuffers::FlatBufferBuilder *fbb_; - std::vector<::flatbuffers::Offset<iree::DeviceDef>> deviceDefs_; - std::vector<::flatbuffers::Offset<iree::DeviceGroupDef>> deviceGroupDefs_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_VMDEVICETABLEBUILDER_H_
diff --git a/iree/compiler/Serialization/VMExecutableTableBuilder.cpp b/iree/compiler/Serialization/VMExecutableTableBuilder.cpp deleted file mode 100644 index 7a5004a..0000000 --- a/iree/compiler/Serialization/VMExecutableTableBuilder.cpp +++ /dev/null
@@ -1,41 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/VMExecutableTableBuilder.h" - -namespace mlir { -namespace iree_compiler { - -VMExecutableTableBuilder::VMExecutableTableBuilder( - ::flatbuffers::FlatBufferBuilder *fbb) - : fbb_(fbb) {} - -LogicalResult VMExecutableTableBuilder::AddMultiArchExecutable( - ::flatbuffers::Offset<iree::MultiArchExecutableDef> - multiArchExecutableDef) { - multiArchExecutableDefs_.push_back(multiArchExecutableDef); - return success(); -} - -::flatbuffers::Offset<iree::ExecutableTableDef> -VMExecutableTableBuilder::Finish() { - auto multiArchExecutablesOffset = - fbb_->CreateVector(multiArchExecutableDefs_); - iree::ExecutableTableDefBuilder etdb(*fbb_); - etdb.add_multi_arch_executables(multiArchExecutablesOffset); - return etdb.Finish(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/VMExecutableTableBuilder.h b/iree/compiler/Serialization/VMExecutableTableBuilder.h deleted file mode 100644 index 1125a4d..0000000 --- a/iree/compiler/Serialization/VMExecutableTableBuilder.h +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_ -#define IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_ - -#include "flatbuffers/flatbuffers.h" -#include "iree/schemas/executable_table_def_generated.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { - -class VMExecutableTableBuilder { - public: - explicit VMExecutableTableBuilder(::flatbuffers::FlatBufferBuilder *fbb); - - LogicalResult AddMultiArchExecutable( - ::flatbuffers::Offset<iree::MultiArchExecutableDef> - multiArchExecutableDef); - - ::flatbuffers::Offset<iree::ExecutableTableDef> Finish(); - - private: - ::flatbuffers::FlatBufferBuilder *fbb_; - std::vector<::flatbuffers::Offset<iree::MultiArchExecutableDef>> - multiArchExecutableDefs_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_VM_EXECUTABLE_TABLE_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMFunctionBuilder.cpp b/iree/compiler/Serialization/VMFunctionBuilder.cpp deleted file mode 100644 index cc2fee2..0000000 --- a/iree/compiler/Serialization/VMFunctionBuilder.cpp +++ /dev/null
@@ -1,359 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/VMFunctionBuilder.h" - -#include "flatbuffers/flatbuffers.h" -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Serialization/BytecodeTables.h" -#include "iree/compiler/Utils/Macros.h" -#include "iree/schemas/bytecode/bytecode_v0.h" -#include "iree/schemas/type_def_generated.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Module.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -LogicalResult WriteGenericIreeOp(Block *block, Operation *op, - BytecodeWriter *writer) { - // Strip the dialect name from the op name and lookup the opcode. - // TODO(benvanik): adjust for supporting sequencer opcodes. - - auto opName = op->getName().getStringRef(); - auto dialect = op->getDialect(); - if (!dialect) { - return op->emitOpError() << "Op does not belong to a registered dialect"; - } - - auto dialectNamespace = dialect->getNamespace(); - std::unique_ptr<OpcodeInfo> operandInfo; - auto strippedOpName = opName.substr(opName.find('.') + 1).str(); - if (dialectNamespace == "iree_ll_seq") { - auto opcode = GetSequencerOpcodeByName(strippedOpName); - if (!opcode.hasValue()) { - return op->emitOpError() - << "No sequencer opcode found for op; is it a pseudo op?"; - } - RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue())); - operandInfo = - std::make_unique<OpcodeInfo>(GetSequencerOpcodeInfo(opcode.getValue())); - } else if (dialectNamespace == "iree_ll_interp" || - // TODO(gcmn) remove special case for IREE dialect? - dialectNamespace == IREEDialect::getDialectNamespace()) { - auto opcode = GetInterpreterOpcodeByName(strippedOpName); - if (!opcode.hasValue()) { - return op->emitOpError() - << "No interpreter opcode found for op; is it a pseudo op?"; - } - RETURN_IF_FAILURE(writer->WriteOpcode(opcode.getValue())); - operandInfo = std::make_unique<OpcodeInfo>( - GetInterpreterOpcodeInfo(opcode.getValue())); - } else { - return op->emitOpError() - << "Op belongs to unknown dialect " << dialectNamespace.str(); - } - // Write inputs and outputs based on the bytecode encoding. - int operandIndex = 0; - int resultIndex = 0; - for (int i = 0; i < llvm::array_lengthof(operandInfo->operands); ++i) { - auto op_encoding = operandInfo->operands[i]; - if (op_encoding == iree::OperandEncoding::kNone) break; - switch (op_encoding) { - case iree::OperandEncoding::kInputSlot: - case iree::OperandEncoding::kOutputSlot: { - auto *value = op->getOperand(operandIndex++); - RETURN_IF_FAILURE(writer->WriteLocal(value)); - break; - } - case iree::OperandEncoding::kVariadicInputSlots: - case iree::OperandEncoding::kVariadicOutputSlots: { - int count = op->getNumOperands() - operandIndex; - RETURN_IF_FAILURE(writer->WriteCount(count)); - for (; count; --count) { - auto *value = op->getOperand(operandIndex++); - RETURN_IF_FAILURE(writer->WriteLocal(value)); - } - break; - } - case iree::OperandEncoding::kResultSlot: { - auto *value = op->getResult(resultIndex++); - RETURN_IF_FAILURE(writer->WriteLocal(value)); - break; - } - case iree::OperandEncoding::kVariadicResultSlots: { - int count = op->getNumResults() - resultIndex; - RETURN_IF_FAILURE(writer->WriteCount(count)); - for (; count; --count) { - auto *value = op->getResult(resultIndex++); - RETURN_IF_FAILURE(writer->WriteLocal(value)); - } - break; - } - case iree::OperandEncoding::kConstant: - case iree::OperandEncoding::kFunctionOrdinal: - case iree::OperandEncoding::kBlockOffset: - case iree::OperandEncoding::kTypeIndex: - case iree::OperandEncoding::kIndex: - case iree::OperandEncoding::kIndexList: - case iree::OperandEncoding::kCmpIPredicate: - case iree::OperandEncoding::kCmpFPredicate: - return op->emitOpError() - << "Operand encoding " << static_cast<char>(op_encoding) - << " not supported by generic writer for " << opName.str(); - return failure(); - default: - return op->emitOpError() - << "Operand encoding " << static_cast<char>(op_encoding) << " (" - << static_cast<int>(op_encoding) << ") not recognized (typo?)"; - } - } - - return success(); -} - -} // namespace - -VMFunctionBuilder::VMFunctionBuilder(FuncOp function, - VMFunctionTableBuilder *functionTable, - ::flatbuffers::FlatBufferBuilder *fbb) - : context_(function.getContext()), - function_(function), - functionTable_(functionTable), - fbb_(fbb) {} - -void VMFunctionBuilder::RegisterCustomWriter(StringRef operationName, - CustomWriterFn writerFn) { - customWriters_.insert({operationName, writerFn}); -} - -LogicalResult VMFunctionBuilder::ConvertBytecode() { - BytecodeWriter writer; - sourceMap_ = {}; - - RETURN_IF_FAILURE(BeginFunction(function_, &writer)); - for (auto &block : function_.getBlocks()) { - RETURN_IF_FAILURE(BeginBlock(&block, &writer)); - for (auto &op : block.getOperations()) { - if (failed(WriteOperation(&block, &op, &writer))) { - op.emitError() << "Unable to serialize operation"; - return failure(); - } - } - RETURN_IF_FAILURE(EndBlock(&block, block.getTerminator(), &writer)); - } - RETURN_IF_FAILURE(EndFunction(function_, &writer)); - - int localCount = writer.local_count(); - auto bodyBytes = writer.Finish(); - auto bodyOffset = fbb_->CreateVector( - reinterpret_cast<const int8_t *>(bodyBytes.data()), bodyBytes.size()); - iree::BytecodeDefBuilder bdb(*fbb_); - bdb.add_local_count(localCount); - bdb.add_contents(bodyOffset); - bytecodeDef_ = bdb.Finish(); - - return success(); -} - -::flatbuffers::Offset<iree::FunctionDef> VMFunctionBuilder::Finish() { - using TypeDefVector = - ::flatbuffers::Vector<::flatbuffers::Offset<iree::TypeDef>>; - - const auto &functionType = function_.getType(); - std::vector<::flatbuffers::Offset<iree::TypeDef>> inputs; - for (const auto &type : functionType.getInputs()) { - auto typeOffset = SerializeType(type, fbb_); - if (typeOffset.IsNull()) return {}; - inputs.push_back(typeOffset); - } - ::flatbuffers::Offset<TypeDefVector> inputsOffset; - if (!inputs.empty()) { - inputsOffset = fbb_->CreateVector(inputs); - } - - std::vector<::flatbuffers::Offset<iree::TypeDef>> results; - for (const auto &type : functionType.getResults()) { - auto typeOffset = SerializeType(type, fbb_); - if (typeOffset.IsNull()) return {}; - results.push_back(typeOffset); - } - ::flatbuffers::Offset<TypeDefVector> resultsOffset; - if (!results.empty()) { - resultsOffset = fbb_->CreateVector(results); - } - iree::FunctionTypeDefBuilder ftb(*fbb_); - ftb.add_inputs(inputsOffset); - ftb.add_results(resultsOffset); - auto functionTypeOffset = ftb.Finish(); - - // TODO(benvanik): strip names of internal functions. - auto nameOffset = fbb_->CreateString(function_.getName().str()); - iree::FunctionDefBuilder fdb(*fbb_); - fdb.add_name(nameOffset); - fdb.add_type(functionTypeOffset); - fdb.add_bytecode(bytecodeDef_); - return fdb.Finish(); -} - -LogicalResult VMFunctionBuilder::BeginFunction(FuncOp function, - BytecodeWriter *writer) { - // Assign value slots for all arguments and results. - // Keeping them at the front will make it easier to find during debugging - // and makes spans easier to compute at runtime. - for (auto argument : function.getArguments()) { - RETURN_IF_FAILURE(writer->PrepareLocal(argument)); - } - return success(); -} - -LogicalResult VMFunctionBuilder::EndFunction(FuncOp function, - BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->FixupOffsets()); - return success(); -} - -LogicalResult VMFunctionBuilder::BeginBlock(Block *block, - BytecodeWriter *writer) { - RETURN_IF_FAILURE(writer->MarkBlockOffset(block)); - return success(); -} - -LogicalResult VMFunctionBuilder::EndBlock(Block *block, Operation *op, - BytecodeWriter *writer) { - return success(); -} - -LogicalResult VMFunctionBuilder::WriteOperation(Block *block, Operation *baseOp, - BytecodeWriter *writer) { - if (!baseOp->getLoc().isa<UnknownLoc>()) { - sourceMap_.locations.push_back({writer->offset(), baseOp->getLoc()}); - } - - // Check registered writers first to allow overrides. - auto writerIt = customWriters_.find(baseOp->getName().getStringRef()); - if (writerIt != customWriters_.end()) { - return writerIt->second(baseOp, writer); - } - - // Fallback to using the generic writer. - if (baseOp->getAbstractOperation()->dialect.getNamespace().startswith( - "iree")) { - RETURN_IF_FAILURE(WriteGenericIreeOp(block, baseOp, writer)); - } else { - return baseOp->emitError() - << "Unsupported op " << baseOp->getName().getStringRef().str() - << "; incorrectly outlined or not yet implemented"; - } - return success(); -} - -::flatbuffers::Offset<iree::TypeDef> VMFunctionBuilder::SerializeType( - Type type, ::flatbuffers::FlatBufferBuilder *fbb) { - ::flatbuffers::Offset<void> typeDefUnion; - iree::TypeDefUnion typeUnionType; - if (auto memRefType = type.dyn_cast<MemRefType>()) { - auto memRefTypeOffset = SerializeMemRefType(memRefType, fbb_); - if (memRefTypeOffset.IsNull()) return {}; - typeDefUnion = memRefTypeOffset.Union(); - typeUnionType = iree::TypeDefUnion::MemRefTypeDef; - } else if (auto deviceType = type.dyn_cast<DeviceType>()) { - typeDefUnion = iree::CreateDeviceTypeDef(*fbb).Union(); - typeUnionType = iree::TypeDefUnion::DeviceTypeDef; - } else if (auto commandBufferType = type.dyn_cast<CommandBufferType>()) { - typeDefUnion = iree::CreateCommandBufferTypeDef(*fbb).Union(); - typeUnionType = iree::TypeDefUnion::CommandBufferTypeDef; - } else if (auto eventType = type.dyn_cast<EventType>()) { - typeDefUnion = iree::CreateEventTypeDef(*fbb).Union(); - typeUnionType = iree::TypeDefUnion::EventTypeDef; - } else if (auto semaphoreType = type.dyn_cast<SemaphoreType>()) { - typeDefUnion = iree::CreateSemaphoreTypeDef(*fbb).Union(); - typeUnionType = iree::TypeDefUnion::SemaphoreTypeDef; - } else if (auto fenceType = type.dyn_cast<FenceType>()) { - typeDefUnion = iree::CreateFenceTypeDef(*fbb).Union(); - typeUnionType = iree::TypeDefUnion::FenceTypeDef; - } else { - function_.emitError() << "Function " << function_.getName().str() - << " has unsupported I/O with type " << type; - return {}; - } - - iree::TypeDefBuilder tdb(*fbb); - tdb.add_type_union_type(typeUnionType); - tdb.add_type_union(typeDefUnion); - return tdb.Finish(); -} - -::flatbuffers::Offset<iree::MemRefTypeDef> -VMFunctionBuilder::SerializeMemRefType(const MemRefType &type, - ::flatbuffers::FlatBufferBuilder *fbb) { - auto elementTypeOffset = SerializeElementType(type.getElementType(), fbb); - if (elementTypeOffset.IsNull()) return {}; - std::vector<int> shape; - for (int dim : type.getShape()) { - shape.push_back(dim); - } - auto shapeOffset = fbb->CreateVector(shape); - iree::MemRefTypeDefBuilder tb(*fbb); - tb.add_element_type(elementTypeOffset); - tb.add_shape(shapeOffset); - tb.add_memory_space(type.getMemorySpace()); - return tb.Finish(); -} - -::flatbuffers::Offset<iree::ElementTypeDef> -VMFunctionBuilder::SerializeElementType(const Type &genericType, - ::flatbuffers::FlatBufferBuilder *fbb) { - ::flatbuffers::Offset<void> typeDefUnion; - iree::ElementTypeDefUnion typeUnionType; - if (auto type = genericType.dyn_cast<FloatType>()) { - iree::FloatTypeDefBuilder tb(*fbb); - tb.add_width(type.getWidth()); - typeDefUnion = tb.Finish().Union(); - typeUnionType = iree::ElementTypeDefUnion::FloatTypeDef; - } else if (auto type = genericType.dyn_cast<IntegerType>()) { - iree::IntegerTypeDefBuilder tb(*fbb); - tb.add_width(type.getWidth()); - typeDefUnion = tb.Finish().Union(); - typeUnionType = iree::ElementTypeDefUnion::IntegerTypeDef; - } else if (auto type = genericType.dyn_cast<OpaqueType>()) { - auto dialectOffset = fbb->CreateString(type.getDialectNamespace().c_str()); - auto typeDataOffset = fbb->CreateString(type.getTypeData().data()); - iree::UnknownTypeDefBuilder tb(*fbb); - tb.add_dialect(dialectOffset); - tb.add_type_data(typeDataOffset); - typeDefUnion = tb.Finish().Union(); - typeUnionType = iree::ElementTypeDefUnion::UnknownTypeDef; - } else { - function_.emitError() - << "Unimplemented type encoding: " << genericType - << "; ensure IREE lowering passes are converting types to the IREE " - "set"; - return {}; - } - - iree::ElementTypeDefBuilder tdb(*fbb); - tdb.add_type_union_type(typeUnionType); - tdb.add_type_union(typeDefUnion); - return tdb.Finish(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/VMFunctionBuilder.h b/iree/compiler/Serialization/VMFunctionBuilder.h deleted file mode 100644 index 129afeb..0000000 --- a/iree/compiler/Serialization/VMFunctionBuilder.h +++ /dev/null
@@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_ -#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_ - -#include "iree/compiler/Serialization/BytecodeWriter.h" -#include "iree/compiler/Serialization/VMFunctionTableBuilder.h" -#include "iree/compiler/Serialization/VMSourceMapBuilder.h" -#include "iree/schemas/bytecode_def_generated.h" -#include "iree/schemas/function_def_generated.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { - -class VMFunctionBuilder { - public: - using CustomWriterFn = - std::function<LogicalResult(Operation *, BytecodeWriter *writer)>; - - VMFunctionBuilder(FuncOp function, VMFunctionTableBuilder *functionTable, - ::flatbuffers::FlatBufferBuilder *fbb); - ~VMFunctionBuilder() = default; - - void RegisterCustomWriter(StringRef operationName, CustomWriterFn writerFn); - - const VMFunctionSourceMap &source_map() const { return sourceMap_; } - - LogicalResult ConvertBytecode(); - - ::flatbuffers::Offset<iree::FunctionDef> Finish(); - - ::flatbuffers::Offset<iree::TypeDef> SerializeType( - Type type, ::flatbuffers::FlatBufferBuilder *fbb); - ::flatbuffers::Offset<iree::MemRefTypeDef> SerializeMemRefType( - const MemRefType &genericType, ::flatbuffers::FlatBufferBuilder *fbb); - ::flatbuffers::Offset<iree::ElementTypeDef> SerializeElementType( - const Type &genericType, ::flatbuffers::FlatBufferBuilder *fbb); - - private: - LogicalResult BeginFunction(FuncOp function, BytecodeWriter *writer); - LogicalResult EndFunction(FuncOp function, BytecodeWriter *writer); - LogicalResult BeginBlock(Block *block, BytecodeWriter *writer); - LogicalResult EndBlock(Block *block, Operation *op, BytecodeWriter *writer); - - LogicalResult WriteOperation(Block *block, Operation *baseOp, - BytecodeWriter *writer); - - llvm::StringMap<CustomWriterFn> customWriters_; - - MLIRContext *context_; - FuncOp function_; - VMFunctionTableBuilder *functionTable_; - ::flatbuffers::FlatBufferBuilder *fbb_; - ::flatbuffers::Offset<iree::BytecodeDef> bytecodeDef_; - VMFunctionSourceMap sourceMap_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMFunctionTableBuilder.cpp b/iree/compiler/Serialization/VMFunctionTableBuilder.cpp deleted file mode 100644 index 090b520..0000000 --- a/iree/compiler/Serialization/VMFunctionTableBuilder.cpp +++ /dev/null
@@ -1,87 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/VMFunctionTableBuilder.h" - -#include "iree/compiler/Serialization/VMSourceMapBuilder.h" -#include "llvm/Support/raw_ostream.h" - -namespace mlir { -namespace iree_compiler { - -VMFunctionTableBuilder::VMFunctionTableBuilder( - ::flatbuffers::FlatBufferBuilder *fbb) - : fbb_(fbb) {} - -bool VMFunctionTableBuilder::IsFunctionDeclared(FuncOp funcOp) { - return functionSet_.count(funcOp.getName()) != 0; -} - -LogicalResult VMFunctionTableBuilder::DeclareFunction(FuncOp funcOp, - LinkageType linkageType) { - if (functionSet_.count(funcOp.getName())) { - return funcOp.emitError() << "Function has already been declared/defined"; - } - auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal"); - if (!functionOrdinal) { - return funcOp.emitError() << "Ordinal not assigned to function"; - } - int ordinal = functionOrdinal.getInt(); - functionDefs_.resize( - std::max(functionDefs_.size(), static_cast<size_t>(ordinal) + 1u)); - functionSourceMaps_.resize( - std::max(functionDefs_.size(), static_cast<size_t>(ordinal) + 1u)); - functionSet_.insert({funcOp.getName()}); - switch (linkageType) { - case LinkageType::kInternal: - break; - case LinkageType::kImport: - importIndices_.push_back(ordinal); - break; - case LinkageType::kExport: - exportIndices_.push_back(ordinal); - break; - } - return success(); -} - -LogicalResult VMFunctionTableBuilder::DefineFunction( - FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef, - VMFunctionSourceMap functionSourceMap) { - auto functionOrdinal = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal"); - if (!functionOrdinal) { - return funcOp.emitError() << "Ordinal not assigned to function"; - } - int ordinal = functionOrdinal.getInt(); - if (!functionDefs_[ordinal].IsNull()) { - return funcOp.emitOpError() << "Function has already been defined"; - } - functionDefs_[ordinal] = functionDef; - functionSourceMaps_[ordinal] = std::move(functionSourceMap); - return success(); -} - -::flatbuffers::Offset<iree::FunctionTableDef> VMFunctionTableBuilder::Finish() { - auto functionsOffset = fbb_->CreateVector(functionDefs_); - auto importsOffset = fbb_->CreateVector(importIndices_); - auto exportsOffset = fbb_->CreateVector(exportIndices_); - iree::FunctionTableDefBuilder ftdb(*fbb_); - ftdb.add_functions(functionsOffset); - ftdb.add_imports(importsOffset); - ftdb.add_exports(exportsOffset); - return ftdb.Finish(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/VMFunctionTableBuilder.h b/iree/compiler/Serialization/VMFunctionTableBuilder.h deleted file mode 100644 index fc87abf..0000000 --- a/iree/compiler/Serialization/VMFunctionTableBuilder.h +++ /dev/null
@@ -1,75 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_ -#define IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_ - -#include <string> -#include <vector> - -#include "flatbuffers/flatbuffers.h" -#include "iree/compiler/Serialization/VMSourceMapBuilder.h" -#include "iree/schemas/function_def_generated.h" -#include "iree/schemas/function_table_def_generated.h" -#include "llvm/ADT/StringSet.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" - -namespace mlir { -namespace iree_compiler { - -enum class LinkageType { - kInternal, - kImport, - kExport, -}; - -class VMFunctionTableBuilder { - public: - explicit VMFunctionTableBuilder(::flatbuffers::FlatBufferBuilder *fbb); - - int max_function_ordinal() const { return functionDefs_.size(); } - - ArrayRef<VMFunctionSourceMap> function_source_maps() { - return llvm::makeArrayRef(functionSourceMaps_); - } - - // Returns true if |funcOp| has already been declared in the table. - bool IsFunctionDeclared(FuncOp funcOp); - - // Declares |funcOp| with the given |linkageType|. - // Fails if the function has already been declared or defined. - LogicalResult DeclareFunction(FuncOp funcOp, LinkageType linkageType); - - // Defines |funcOp| using the given |functionDef|. - LogicalResult DefineFunction( - FuncOp funcOp, ::flatbuffers::Offset<iree::FunctionDef> functionDef, - VMFunctionSourceMap functionSourceMap); - - ::flatbuffers::Offset<iree::FunctionTableDef> Finish(); - - private: - ::flatbuffers::FlatBufferBuilder *fbb_; - llvm::StringSet<> functionSet_; - std::vector<::flatbuffers::Offset<iree::FunctionDef>> functionDefs_; - std::vector<VMFunctionSourceMap> functionSourceMaps_; - std::vector<int> importIndices_; - std::vector<int> exportIndices_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_VM_FUNCTION_TABLE_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMModuleBuilder.cpp b/iree/compiler/Serialization/VMModuleBuilder.cpp deleted file mode 100644 index 50cea2e..0000000 --- a/iree/compiler/Serialization/VMModuleBuilder.cpp +++ /dev/null
@@ -1,70 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/VMModuleBuilder.h" - -#include "iree/schemas/executable_table_def_generated.h" - -namespace mlir { -namespace iree_compiler { - -VMModuleBuilder::VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb) - : fbb_(fbb), - deviceTable_(fbb), - functionTable_(fbb), - executableTable_(fbb), - sourceMap_(fbb) {} - -::flatbuffers::Offset<iree::ModuleDef> VMModuleBuilder::Finish() { - auto nameOffset = fbb_->CreateString("module"); - auto deviceTableOffset = deviceTable_.Finish(); - if (deviceTableOffset.IsNull()) return {}; - auto functionTableOffset = functionTable_.Finish(); - if (functionTableOffset.IsNull()) return {}; - auto executableTableOffset = executableTable_.Finish(); - if (executableTableOffset.IsNull()) return {}; - - for (int function_ordinal = 0; - function_ordinal < functionTable_.function_source_maps().size(); - ++function_ordinal) { - if (failed(sourceMap_.AddFunction( - function_ordinal, - functionTable_.function_source_maps()[function_ordinal]))) { - return {}; - } - } - auto sourceMapOffset = - sourceMap_.Finish(functionTable_.max_function_ordinal()); - if (sourceMapOffset.IsNull()) return {}; - - iree::ModuleDefBuilder mdb(*fbb_); - mdb.add_name(nameOffset); - mdb.add_device_table(deviceTableOffset); - mdb.add_function_table(functionTableOffset); - mdb.add_executable_table(executableTableOffset); - mdb.add_source_map(sourceMapOffset); - return mdb.Finish(); -} - -std::vector<uint8_t> VMModuleBuilder::Serialize( - ::flatbuffers::Offset<iree::ModuleDef> module_def) { - FinishModuleDefBuffer(*fbb_, module_def); - std::vector<uint8_t> bytes; - bytes.resize(fbb_->GetSize()); - std::memcpy(bytes.data(), fbb_->GetBufferPointer(), bytes.size()); - return bytes; -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/VMModuleBuilder.h b/iree/compiler/Serialization/VMModuleBuilder.h deleted file mode 100644 index 1f991fa..0000000 --- a/iree/compiler/Serialization/VMModuleBuilder.h +++ /dev/null
@@ -1,57 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_ -#define IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_ - -#include <vector> - -#include "flatbuffers/flatbuffers.h" -#include "iree/compiler/Serialization/VMDeviceTableBuilder.h" -#include "iree/compiler/Serialization/VMExecutableTableBuilder.h" -#include "iree/compiler/Serialization/VMFunctionTableBuilder.h" -#include "iree/compiler/Serialization/VMSourceMapBuilder.h" -#include "iree/schemas/module_def_generated.h" - -namespace mlir { -namespace iree_compiler { - -class VMModuleBuilder { - public: - explicit VMModuleBuilder(::flatbuffers::FlatBufferBuilder *fbb); - - ::flatbuffers::FlatBufferBuilder *fbb() const { return fbb_; } - VMDeviceTableBuilder *device_table() { return &deviceTable_; } - VMFunctionTableBuilder *function_table() { return &functionTable_; } - VMExecutableTableBuilder *executable_table() { return &executableTable_; } - VMSourceMapBuilder *source_map() { return &sourceMap_; } - - ::flatbuffers::Offset<iree::ModuleDef> Finish(); - - std::vector<uint8_t> Serialize( - ::flatbuffers::Offset<iree::ModuleDef> module_def); - - private: - ::flatbuffers::FlatBufferBuilder *fbb_; - - VMDeviceTableBuilder deviceTable_; - VMFunctionTableBuilder functionTable_; - VMExecutableTableBuilder executableTable_; - VMSourceMapBuilder sourceMap_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_VM_MODULE_BUILDER_H_
diff --git a/iree/compiler/Serialization/VMSourceMapBuilder.cpp b/iree/compiler/Serialization/VMSourceMapBuilder.cpp deleted file mode 100644 index c0b0637..0000000 --- a/iree/compiler/Serialization/VMSourceMapBuilder.cpp +++ /dev/null
@@ -1,164 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Serialization/VMSourceMapBuilder.h" - -#include "flatbuffers/flatbuffers.h" -#include "iree/schemas/source_map_def_generated.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Identifier.h" -#include "mlir/IR/Location.h" - -namespace mlir { -namespace iree_compiler { - -VMSourceMapBuilder::VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb) - : fbb_(fbb) {} - -int VMSourceMapBuilder::GetUniqueString(std::string value) { - auto it = stringTableMap_.find(value); - if (it != stringTableMap_.end()) { - return it->second; - } - int stringIndex = stringTable_.size(); - stringTableMap_.insert({value, stringIndex}); - stringTable_.push_back(std::move(value)); - return stringIndex; -} - -LogicalResult VMSourceMapBuilder::AddFunction( - int functionOrdinal, VMFunctionSourceMap functionSourceMap) { - if (functionMaps_.size() <= functionOrdinal) { - functionMaps_.resize(functionOrdinal + 1); - } - functionMaps_[functionOrdinal] = std::move(functionSourceMap); - return success(); -} - -::flatbuffers::Offset<iree::SourceMapDef> VMSourceMapBuilder::Finish( - int maxFunctionOrdinal) { - // NOTE: we always ensure the source map table is the same size as the - // function table so that lookups at runtime can be validated once at load - // time (ensuring the tables match up) instead of on each lookup. - if (maxFunctionOrdinal < functionMaps_.size()) { - llvm::errs() << "Max function ordinal defined as " << maxFunctionOrdinal - << " but there are " << functionMaps_.size() - << " function source maps present"; - return {}; - } - functionMaps_.resize(maxFunctionOrdinal); - - std::vector<::flatbuffers::Offset<iree::FunctionSourceMapDef>> functionDefs; - functionDefs.resize(maxFunctionOrdinal); - for (int i = 0; i < functionMaps_.size(); ++i) { - const auto &functionMap = functionMaps_[i]; - functionDefs[i] = SerializeVMFunctionSourceMap(functionMap); - if (functionDefs[i].IsNull()) return {}; - } - - auto functionTableOffset = fbb_->CreateVector(functionDefs); - auto stringTableOffset = fbb_->CreateVectorOfStrings(stringTable_); - iree::SourceMapDefBuilder smdb(*fbb_); - smdb.add_function_table(functionTableOffset); - smdb.add_string_table(stringTableOffset); - return smdb.Finish(); -} - -::flatbuffers::Offset<iree::FunctionSourceMapDef> -VMSourceMapBuilder::SerializeVMFunctionSourceMap( - const VMFunctionSourceMap &functionMap) { - if (functionMap.locations.empty()) { - // Empty table. This ensures that we still have a non-null value in the - // function table but doesn't waste much space. - iree::FunctionSourceMapDefBuilder fsmdb(*fbb_); - return fsmdb.Finish(); - } - - LocationOffsetTable locationOffsetTable; - std::vector<iree::BytecodeSourceLocation> bytecodeMap; - for (const auto &offset_location : functionMap.locations) { - int locationIndex = - SerializeLocation(offset_location.second, &locationOffsetTable); - bytecodeMap.push_back({offset_location.first, locationIndex}); - } - auto locationTableOffset = - fbb_->CreateVector(locationOffsetTable.locationDefs); - auto bytecodeMapOffset = fbb_->CreateVectorOfStructs(bytecodeMap); - - iree::FunctionSourceMapDefBuilder fsmdb(*fbb_); - fsmdb.add_location_table(locationTableOffset); - fsmdb.add_bytecode_map(bytecodeMapOffset); - return fsmdb.Finish(); -} - -int VMSourceMapBuilder::SerializeLocation( - const Location &location, LocationOffsetTable *locationOffsetTable) { - auto existingIt = locationOffsetTable->locationMap.find(location); - if (existingIt != locationOffsetTable->locationMap.end()) { - return existingIt->getSecond(); - } - - iree::LocationDefUnion locationUnionType; - ::flatbuffers::Offset<void> locationUnionOffset; - if (auto fileLoc = location.dyn_cast<FileLineColLoc>()) { - locationUnionType = iree::LocationDefUnion::FileLocationDef; - int filenameIndex = GetUniqueString(fileLoc.getFilename().str()); - iree::FileLocationDefBuilder lb(*fbb_); - lb.add_filename(filenameIndex); - lb.add_line(fileLoc.getLine()); - lb.add_column(fileLoc.getColumn()); - locationUnionOffset = lb.Finish().Union(); - } else if (auto nameLoc = location.dyn_cast<NameLoc>()) { - locationUnionType = iree::LocationDefUnion::NameLocationDef; - int nameIndex = GetUniqueString(nameLoc.getName().str()); - iree::NameLocationDefBuilder lb(*fbb_); - lb.add_name(nameIndex); - locationUnionOffset = lb.Finish().Union(); - } else if (auto callSiteLoc = location.dyn_cast<CallSiteLoc>()) { - locationUnionType = iree::LocationDefUnion::CallSiteLocationDef; - int calleeIndex = - SerializeLocation(callSiteLoc.getCallee(), locationOffsetTable); - int callerIndex = - SerializeLocation(callSiteLoc.getCaller(), locationOffsetTable); - iree::CallSiteLocationDefBuilder lb(*fbb_); - lb.add_callee_location(calleeIndex); - lb.add_caller_location(callerIndex); - locationUnionOffset = lb.Finish().Union(); - } else if (auto fusedLoc = location.dyn_cast<FusedLoc>()) { - locationUnionType = iree::LocationDefUnion::FusedLocationDef; - std::vector<int> locationIndices; - locationIndices.reserve(fusedLoc.getLocations().size()); - for (const auto &child_loc : fusedLoc.getLocations()) { - int child_index = SerializeLocation(child_loc, locationOffsetTable); - locationIndices.push_back(child_index); - } - auto locationIndicesOffset = fbb_->CreateVector(locationIndices); - iree::FusedLocationDefBuilder lb(*fbb_); - lb.add_locations(locationIndicesOffset); - locationUnionOffset = lb.Finish().Union(); - } else { - llvm_unreachable("Unimplemented location kind"); - } - - iree::LocationDefBuilder ldb(*fbb_); - ldb.add_location_union_type(locationUnionType); - ldb.add_location_union(locationUnionOffset); - int locationIndex = locationOffsetTable->locationDefs.size(); - locationOffsetTable->locationDefs.push_back(ldb.Finish()); - locationOffsetTable->locationMap.insert({location, locationIndex}); - return locationIndex; -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Serialization/VMSourceMapBuilder.h b/iree/compiler/Serialization/VMSourceMapBuilder.h deleted file mode 100644 index 6c202ac..0000000 --- a/iree/compiler/Serialization/VMSourceMapBuilder.h +++ /dev/null
@@ -1,64 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_ -#define IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_ - -#include <vector> - -#include "flatbuffers/flatbuffers.h" -#include "iree/schemas/source_map_def_generated.h" -#include "llvm/ADT/StringMap.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" - -namespace mlir { -namespace iree_compiler { - -struct VMFunctionSourceMap { - std::vector<std::pair<int, Location>> locations; -}; - -class VMSourceMapBuilder { - public: - explicit VMSourceMapBuilder(::flatbuffers::FlatBufferBuilder *fbb); - - LogicalResult AddFunction(int functionOrdinal, - VMFunctionSourceMap functionSourceMap); - - ::flatbuffers::Offset<iree::SourceMapDef> Finish(int maxFunctionOrdinal); - - private: - struct LocationOffsetTable { - std::vector<::flatbuffers::Offset<iree::LocationDef>> locationDefs; - llvm::DenseMap<Location, int> locationMap; - }; - - int GetUniqueString(std::string value); - - ::flatbuffers::Offset<iree::FunctionSourceMapDef> - SerializeVMFunctionSourceMap(const VMFunctionSourceMap &functionMap); - int SerializeLocation(const Location &location, - LocationOffsetTable *locationOffsetTable); - - ::flatbuffers::FlatBufferBuilder *fbb_; - std::vector<std::string> stringTable_; - llvm::StringMap<int> stringTableMap_; - std::vector<VMFunctionSourceMap> functionMaps_; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_SERIALIZATION_VM_SOURCE_MAP_BUILDER_H_
diff --git a/iree/compiler/Transforms/AggressiveOpElimination.cpp b/iree/compiler/Transforms/AggressiveOpElimination.cpp deleted file mode 100644 index 31c706b..0000000 --- a/iree/compiler/Transforms/AggressiveOpElimination.cpp +++ /dev/null
@@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <deque> -#include <memory> - -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "iree/compiler/IR/Interpreter/LLOps.h" -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/Sequencer/LLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "mlir/Analysis/Dominance.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -template <typename T> -struct EraseUnused : public OpRewritePattern<T> { - using OpRewritePattern<T>::OpRewritePattern; - PatternMatchResult matchAndRewrite(T op, - PatternRewriter &rewriter) const override { - if (op.use_empty()) { - rewriter.eraseOp(op); - return this->matchSuccess(); - } - return this->matchFailure(); - } -}; - -void populateAggressiveOpEliminationPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns.insert<EraseUnused<LoadOp>, EraseUnused<AllocOp>, - EraseUnused<IREESeq::HL::AllocHeapOp>, - EraseUnused<IREESeq::LL::AllocHeapOp>, - EraseUnused<IREEInterp::HL::AllocHeapOp>, - EraseUnused<IREEInterp::LL::AllocHeapOp>>(ctx); -} - -} // namespace - -// TODO(b/142012496) Make these be handled by normal DCE. -class AggressiveOpEliminationPass - : public FunctionPass<AggressiveOpEliminationPass> { - public: - void runOnFunction() override { - OwningRewritePatternList patterns; - populateAggressiveOpEliminationPatterns(patterns, &getContext()); - - applyPatternsGreedily(getFunction(), patterns); - } -}; - -std::unique_ptr<OpPassBase<FuncOp>> createAggressiveOpEliminationPass() { - return std::make_unique<AggressiveOpEliminationPass>(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/BUILD b/iree/compiler/Transforms/BUILD deleted file mode 100644 index 4060276..0000000 --- a/iree/compiler/Transforms/BUILD +++ /dev/null
@@ -1,41 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Transforms", - srcs = [ - "AggressiveOpElimination.cpp", - "AssignFunctionOrdinals.cpp", - "ConvertFromTupleCallingConvention.cpp", - "ConvertToMemRefCallingConvention.cpp", - "DropUnreachableFunctions.cpp", - "DropUnusedExecutables.cpp", - "LegalizeTypeStorage.cpp", - "LowerStdToIreeDialect.cpp", - "LowerXLAToIreeDialect.cpp", - ], - hdrs = [ - "ConversionUtils.h", - "Passes.h", - "Rewrites.h", - ], - deps = [ - "//iree/compiler/IR", - "//iree/compiler/IR/Interpreter", - "//iree/compiler/IR/Sequencer", - "//iree/compiler/Utils", - "@llvm//:support", - "@local_config_mlir//:Analysis", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - ], - alwayslink = 1, -)
diff --git a/iree/compiler/Transforms/ConversionUtils.h b/iree/compiler/Transforms/ConversionUtils.h deleted file mode 100644 index 18ccd6e..0000000 --- a/iree/compiler/Transforms/ConversionUtils.h +++ /dev/null
@@ -1,101 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_ -#define IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_ - -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace iree_compiler { - -template <typename SrcOp, typename DstOp> -struct UnaryOpLowering : public OpConversionPattern<SrcOp> { - using OpConversionPattern<SrcOp>::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - SrcOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *value = loadAccessValue(op.getLoc(), operands[0], rewriter); - value = wrapAsMemRef(value, op, rewriter); - - auto dstType = convertTypeToMemRef(op.getResult()); - auto dstOp = rewriter.create<DstOp>(op.getLoc(), dstType, value); - auto result = dstOp.getResult(); - result = wrapAsTensor(result, op, rewriter); - - rewriter.replaceOp( - op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); - return this->matchSuccess(); - } -}; - -template <typename SrcOp, typename DstOp> -struct BinaryOpLowering : public OpConversionPattern<SrcOp> { - using OpConversionPattern<SrcOp>::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - SrcOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *lhsValue = loadAccessValue(op.getLoc(), operands[0], rewriter); - auto *rhsValue = loadAccessValue(op.getLoc(), operands[1], rewriter); - auto dstType = convertTypeToMemRef(op.getResult()); - - lhsValue = wrapAsMemRef(lhsValue, op, rewriter); - rhsValue = wrapAsMemRef(rhsValue, op, rewriter); - - auto midOp = - rewriter.create<DstOp>(op.getLoc(), dstType, lhsValue, rhsValue); - auto result = midOp.getResult(); - result = wrapAsTensor(result, op, rewriter); - - rewriter.replaceOp( - op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); - return this->matchSuccess(); - } -}; - -template <typename SrcOp, typename DstOp> -struct TernaryOpLowering : public OpConversionPattern<SrcOp> { - using OpConversionPattern<SrcOp>::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - SrcOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *aValue = loadAccessValue(op.getLoc(), operands[0], rewriter); - auto *bValue = loadAccessValue(op.getLoc(), operands[1], rewriter); - auto *cValue = loadAccessValue(op.getLoc(), operands[2], rewriter); - - aValue = wrapAsMemRef(aValue, op, rewriter); - bValue = wrapAsMemRef(bValue, op, rewriter); - cValue = wrapAsMemRef(cValue, op, rewriter); - - auto dstType = convertTypeToMemRef(op.getResult()); - auto dstOp = - rewriter.create<DstOp>(op.getLoc(), dstType, aValue, bValue, cValue); - auto result = dstOp.getResult(); - result = wrapAsTensor(result, op, rewriter); - - rewriter.replaceOp( - op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); - return this->matchSuccess(); - } -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSFORMS_CONVERSIONUTILS_H_
diff --git a/iree/compiler/Transforms/ConvertToMemRefCallingConvention.cpp b/iree/compiler/Transforms/ConvertToMemRefCallingConvention.cpp deleted file mode 100644 index 6bebd56..0000000 --- a/iree/compiler/Transforms/ConvertToMemRefCallingConvention.cpp +++ /dev/null
@@ -1,398 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -void copyOperationAttrs(Operation *oldOp, Operation *newOp) { - for (const auto &oldAttr : oldOp->getAttrs()) { - newOp->setAttr(oldAttr.first, oldAttr.second); - } -} - -FunctionType getMemRefFunctionType(FunctionType type) { - Builder builder(type.getContext()); - llvm::SmallVector<Type, 8> replacementInputs; - for (auto type : type.getInputs()) { - auto memRefType = convertTypeToMemRef(type); - if (!memRefType) { - return nullptr; - } - replacementInputs.push_back(memRefType); - } - llvm::SmallVector<Type, 8> replacementResults; - for (auto type : type.getResults()) { - auto memRefType = convertTypeToMemRef(type); - if (!memRefType) { - return nullptr; - } - replacementResults.push_back(memRefType); - } - return builder.getFunctionType(replacementInputs, replacementResults); -} - -bool insertLoad(BlockArgument *oldArg, BlockArgument *newArg, - OpBuilder &builder, BlockAndValueMapping *mapping) { - auto loc = oldArg->getOwner()->getParent()->getLoc(); - - // If old arg was a memref we don't need to change anything. We still need - // to remap so that the use lists match through conversion, though. - if (oldArg->getType().isa<MemRefType>()) { - mapping->map(oldArg, newArg); - return false; - } else if (oldArg->getType().isa<TensorType>()) { - auto castOp = builder.create<IREE::MemRefToTensorOp>(loc, newArg); - mapping->map(oldArg, castOp.getResult()); - return false; - } - - // Insert the load we'll use to unbox the value. - auto loadedValue = builder.create<LoadOp>(loc, newArg, ArrayRef<Value *>{}); - mapping->map(oldArg, loadedValue); - - return false; -} - -bool insertLoad(Operation *oldOp, Value *oldValue, Value *newValue, - OpBuilder &builder, BlockAndValueMapping *mapping) { - // If old value was a memref we don't need to change anything. - if (oldValue->getType().isa<MemRefType>()) { - mapping->map(oldValue, newValue); - return false; - } else if (oldValue->getType().isa<TensorType>()) { - auto castOp = - builder.create<IREE::MemRefToTensorOp>(oldOp->getLoc(), newValue); - mapping->map(oldValue, castOp.getResult()); - return false; - } - - assert(newValue->getType().isa<MemRefType>()); - - // Insert the load we'll use to unbox the value. - auto loadedValue = - builder.create<LoadOp>(oldOp->getLoc(), newValue, ArrayRef<Value *>{}); - mapping->map(oldValue, loadedValue); - - return false; -} - -Value *insertStore(Operation *oldOp, Value *oldValue, OpBuilder &builder, - BlockAndValueMapping *mapping) { - auto *newValue = mapping->lookupOrNull(oldValue); - if (!newValue) { - return nullptr; - } - - // If the previous value was already a memref we don't need to change - // anything. - // TODO(benvanik): ensure indices make sense. - if (oldValue->getType().isa<MemRefType>()) { - return newValue; - } else if (oldValue->getType().isa<TensorType>()) { - auto castOp = - builder.create<IREE::TensorToMemRefOp>(oldOp->getLoc(), newValue); - return castOp.getResult(); - } - - // Look back up and see if we can find the memref the value was loaded from. - if (auto *sourceMemRef = resolveValueToSourceMemRef(oldValue, oldOp)) { - return mapping->lookupOrNull(sourceMemRef); - } - - // Allocate the memref to store the value. - auto newStorage = builder.create<AllocOp>( - oldOp->getLoc(), convertTypeToMemRef(oldValue->getType())); - - // Insert the store we'll use to box the value. - builder.create<StoreOp>(oldOp->getLoc(), newValue, newStorage, - ArrayRef<Value *>{}); - - return newStorage; -} - -bool convertCallOp(CallOp *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - llvm::SmallVector<Value *, 4> newArgs; - for (auto *oldArg : oldOp->getOperands()) { - auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); - if (!newArg) { - return true; - } - newArgs.push_back(newArg); - } - - SmallVector<Type, 4> resultTypes; - for (auto oldType : oldOp->getOperation()->getResultTypes()) { - resultTypes.push_back(convertTypeToMemRef(oldType)); - } - auto newOp = builder.create<CallOp>(oldOp->getLoc(), oldOp->getCallee(), - resultTypes, newArgs); - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - for (int i = 0; i < newOp.getNumResults(); ++i) { - auto *oldResult = oldOp->getResult(i); - auto *newResult = newOp.getResult(i); - if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder, - mapping)) { - return true; - } - } - - return false; -} - -bool convertCallIndirectOp(CallIndirectOp *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - // TODO(benvanik): support wrapping callee values. - oldOp->emitError("CallIndirectOp not yet supported"); - return true; -#if 0 - llvm::SmallVector<Value *, 4> newArgs; - for (auto *oldArg : oldOp->getArgOperands()) { - auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); - if (!newArg) { - return true; - } - newArgs.push_back(newArg); - } - - auto newOp = builder.create<CallIndirectOp>(oldOp->getLoc(), - oldOp->getCallee(), newArgs); - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - for (int i = 0; i < newOp.getNumResults(); ++i) { - auto *oldResult = oldOp->getResult(i); - auto *newResult = newOp.getResult(i); - if (insertLoad(oldOp->getOperation(), oldResult, newResult, builder, - mapping)) { - return true; - } - } - - return false; -#endif // 0 -} - -bool convertReturnOp(Operation *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - BlockAndValueMapping returnMapping; - for (auto *oldArg : oldOp->getOperands()) { - auto *newArg = insertStore(oldOp, oldArg, builder, mapping); - if (!newArg) { - return true; - } - returnMapping.map(oldArg, newArg); - } - - builder.clone(*oldOp, returnMapping); - return false; -} - -bool convertBranchOp(BranchOp *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - llvm::SmallVector<Value *, 4> newArgs; - for (auto *oldArg : oldOp->getOperands()) { - auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); - if (!newArg) { - return true; - } - newArgs.push_back(newArg); - } - - auto *dest = mapping->lookupOrNull(oldOp->getDest()); - if (!dest) { - oldOp->emitError("Destination block mapping not found"); - return true; - } - - auto newOp = builder.create<BranchOp>(oldOp->getLoc(), dest, newArgs); - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - return false; -} - -bool convertCondBranchOp(CondBranchOp *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - llvm::SmallVector<Value *, 4> trueArgs; - for (auto *oldArg : oldOp->getTrueOperands()) { - auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); - if (!newArg) { - return true; - } - trueArgs.push_back(newArg); - } - llvm::SmallVector<Value *, 4> falseArgs; - for (auto *oldArg : oldOp->getFalseOperands()) { - auto *newArg = insertStore(oldOp->getOperation(), oldArg, builder, mapping); - if (!newArg) { - return true; - } - falseArgs.push_back(newArg); - } - - auto *trueDest = mapping->lookupOrNull(oldOp->getTrueDest()); - if (!trueDest) { - oldOp->emitError("True destination block mapping not found"); - return true; - } - auto *falseDest = mapping->lookupOrNull(oldOp->getFalseDest()); - if (!falseDest) { - oldOp->emitError("False destination block mapping not found"); - return true; - } - - // Lowering will take care of the condition store. - auto *newCondition = mapping->lookupOrNull(oldOp->getCondition()); - if (!newCondition) { - oldOp->emitError("Condition value mapping not found"); - return false; - } - - auto newOp = builder.create<CondBranchOp>( - oldOp->getLoc(), newCondition, trueDest, trueArgs, falseDest, falseArgs); - copyOperationAttrs(oldOp->getOperation(), newOp.getOperation()); - - return false; -} - -bool convertOperation(Operation *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - if (isa<ConstantOp>(oldOp)) { - builder.clone(*oldOp, *mapping); - return false; - } else if (auto callOp = dyn_cast<CallOp>(oldOp)) { - return convertCallOp(&callOp, builder, mapping); - } else if (auto callIndirectOp = dyn_cast<CallIndirectOp>(oldOp)) { - return convertCallIndirectOp(&callIndirectOp, builder, mapping); - } else if (isa<ReturnOp>(oldOp) || isa<IREE::ReturnOp>(oldOp)) { - return convertReturnOp(oldOp, builder, mapping); - } else if (auto branchOp = dyn_cast<BranchOp>(oldOp)) { - return convertBranchOp(&branchOp, builder, mapping); - } else if (auto condBranchOp = dyn_cast<CondBranchOp>(oldOp)) { - return convertCondBranchOp(&condBranchOp, builder, mapping); - } else { - builder.clone(*oldOp, *mapping); - return false; - } -} - -bool convertFunction(FuncOp oldFunc, FuncOp newFunc) { - OpBuilder builder(newFunc.getBody()); - BlockAndValueMapping mapping; - - // Create new blocks matching the expected arguments of the old ones. - // This sets up the block mappings to enable us to reference blocks forward - // during conversion. - newFunc.getBlocks().clear(); - for (auto &oldBlock : oldFunc.getBlocks()) { - auto *newBlock = builder.createBlock(&newFunc.getBody()); - for (auto *oldArg : oldBlock.getArguments()) { - // Replace the block args with memrefs. - auto memRefType = convertTypeToMemRef(oldArg->getType()); - if (!memRefType) return true; - auto *newArg = newBlock->addArgument(memRefType); - - // Insert loads to preserve type, if needed. - // This will replace all uses of the oldArg with the loaded value from - // newArg so that the block contents are still using unwrapped values. - if (insertLoad(oldArg, newArg, builder, &mapping)) { - return true; - } - } - mapping.map(&oldBlock, newBlock); - } - - // Convert all ops in the blocks. - for (auto &oldBlock : oldFunc.getBlocks()) { - builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock)); - for (auto &oldOp : oldBlock.getOperations()) { - if (convertOperation(&oldOp, builder, &mapping)) { - return true; - } - } - } - - return false; -} - -} // namespace - -class ConvertToMemRefCallingConventionPass - : public ModulePass<ConvertToMemRefCallingConventionPass> { - public: - void runOnModule() override { - auto module = getModule(); - - // Build a list of (oldFunc, newFunc) for all functions we need to - // replace. This will ensure that when we go to convert function bodies we - // have only new functions defined. - std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions; - - for (auto oldFunc : module.getOps<FuncOp>()) { - // Create the replacement function, ensuring that we copy attributes. - auto functionType = getMemRefFunctionType(oldFunc.getType()); - if (!functionType) { - return signalPassFailure(); - } - - auto newFunc = FuncOp::create(oldFunc.getLoc(), oldFunc.getName(), - functionType, oldFunc.getDialectAttrs()); - convertedFunctions.push_back({oldFunc, newFunc}); - - // Perform the actual body conversion now. - if (convertFunction(oldFunc, newFunc)) { - return signalPassFailure(); - } - } - - // Replace functions in the module. - for (auto &pair : convertedFunctions) { - pair.first.erase(); - module.push_back(pair.second); - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> -createConvertToMemRefCallingConventionPass() { - return std::make_unique<ConvertToMemRefCallingConventionPass>(); -} - -static PassRegistration<ConvertToMemRefCallingConventionPass> pass( - "convert-to-memref-calling-convention", - "Convert functions to use a memref-based calling convention."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/DropUnreachableFunctions.cpp b/iree/compiler/Transforms/DropUnreachableFunctions.cpp deleted file mode 100644 index 6a5748c..0000000 --- a/iree/compiler/Transforms/DropUnreachableFunctions.cpp +++ /dev/null
@@ -1,67 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/ModuleUtils.h" -#include "llvm/ADT/SetVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -// Drops all functions in a module that are not reachable by functions with the -// "iree.module.export" attribute. -class DropUnreachableModuleFunctionsPass - : public ModulePass<DropUnreachableModuleFunctionsPass> { - public: - void runOnModule() override { - dropUnusedFunctions(getModule(), {"iree.module.export"}); - } -}; - -// Drops all functions in a module that are not reachable by functions with the -// "iree.executable.export" attribute. -class DropUnreachableExecutableFunctionsPass - : public ModulePass<DropUnreachableExecutableFunctionsPass> { - public: - void runOnModule() override { - dropUnusedFunctions(getModule(), {"iree.executable.export"}); - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> -createDropUnreachableModuleFunctionsPass() { - return std::make_unique<DropUnreachableModuleFunctionsPass>(); -} - -std::unique_ptr<OpPassBase<ModuleOp>> -createDropUnreachableExecutableFunctionsPass() { - return std::make_unique<DropUnreachableExecutableFunctionsPass>(); -} - -static PassRegistration<DropUnreachableModuleFunctionsPass> moduleFunctionsPass( - "iree-drop-unreachable-module-functions", - "Drop all functions not reachable from an exported function"); - -static PassRegistration<DropUnreachableExecutableFunctionsPass> - executableFunctionsPass( - "iree-drop-unreachable-executable-functions", - "Drop all functions not reachable from an exported function"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/DropUnusedExecutables.cpp b/iree/compiler/Transforms/DropUnusedExecutables.cpp deleted file mode 100644 index e10f81c..0000000 --- a/iree/compiler/Transforms/DropUnusedExecutables.cpp +++ /dev/null
@@ -1,62 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "llvm/ADT/SetVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -// Drops all executables in a module that are not used by any dispatch -// sequencer op. -class DropUnusedExecutablesPass : public ModulePass<DropUnusedExecutablesPass> { - public: - void runOnModule() override { - DenseSet<StringRef> usedExecutableNames; - for (auto funcOp : getModule().getOps<FuncOp>()) { - funcOp.walk([&](IREESeq::HL::DispatchOp op) { - usedExecutableNames.insert(op.getExecutable()); - }); - } - DenseSet<Operation *> deadExecutables; - for (auto executableOp : - getModule().getOps<IREE::MultiArchExecutableOp>()) { - if (usedExecutableNames.count(executableOp.getName()) == 0) { - deadExecutables.insert(executableOp); - } - } - for (auto executableOp : deadExecutables) { - executableOp->erase(); - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createDropUnusedExecutablesPass() { - return std::make_unique<DropUnusedExecutablesPass>(); // NOLINT -} - -static PassRegistration<DropUnusedExecutablesPass> executableFunctionsPass( - "iree-drop-unused-executables", - "Drop all executables not reachable from a dispatch/reduce op."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/BUILD b/iree/compiler/Transforms/Interpreter/BUILD deleted file mode 100644 index e5816fa..0000000 --- a/iree/compiler/Transforms/Interpreter/BUILD +++ /dev/null
@@ -1,41 +0,0 @@ -# Transforms specific to the IREE interpreter. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Interpreter", - srcs = [ - "ExpandReductionsToOps.cpp", - "LowerInterpreterDialect.cpp", - "LowerStdToInterpreterDialect.cpp", - "LowerToInterpreterDialect.cpp", - "LowerXLAToInterpreterDialect.cpp", - "MakeExecutableABI.cpp", - ], - hdrs = [ - "Passes.h", - "Rewrites.h", - ], - deps = [ - "//iree/compiler/IR", - "//iree/compiler/IR/Interpreter", - "//iree/compiler/Serialization", - "//iree/compiler/Transforms", - "//iree/compiler/Utils", - "//iree/schemas/bytecode:interpreter_bytecode_v0", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot", - ], - alwayslink = 1, -)
diff --git a/iree/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp b/iree/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp deleted file mode 100644 index 41706fe..0000000 --- a/iree/compiler/Transforms/Interpreter/ExpandReductionsToOps.cpp +++ /dev/null
@@ -1,216 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Interpreter/HLDialect.h" -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Transforms/ConversionUtils.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -LogicalResult convertReductionOp(FuncOp entryPoint, FuncOp applyFunc, - Operation *elementOp, OpBuilder &builder) { - // Ensure that this op is pass-through and does not interact with any other - // ops within the function. - // TODO(b/139313439): support fused reductions. - for (auto *operand : elementOp->getOperands()) { - if (operand->getDefiningOp() != nullptr) { - return elementOp->emitOpError() - << "Fused reductions are not supported (operand not sourced from " - "block args)"; - } - } - for (auto *result : elementOp->getResults()) { - for (auto *user : result->getUsers()) { - if (!user->isKnownTerminator()) { - return elementOp->emitOpError() << "Fused reductions are not supported " - "(result used by non-terminator)"; - } - } - } - - // Determine the index of the args we care about. We'll use these to match up - // the operands of the entry point with our application. - // Our arguments are expanded tuples like <lhs0, rhs0>, <lhs1, rhs1>, so this - // index gets the offset * 2. - auto &applyEntryBlock = applyFunc.getBlocks().front(); - int setIndex = std::distance(applyEntryBlock.args_begin(), - llvm::find(applyEntryBlock.getArguments(), - elementOp->getOperand(0))) / - 2; - - // Map to the args from the entry point. - auto &entryPointEntryBlock = entryPoint.getBlocks().front(); - Value *srcArg = entryPointEntryBlock.getArgument(setIndex); - Value *initArg = entryPointEntryBlock.getArgument( - applyFunc.getNumArguments() / 2 + setIndex); - Value *dstArg = - entryPointEntryBlock.getArgument(applyFunc.getNumArguments() + setIndex); - auto dstType = dstArg->getType().cast<ShapedType>(); - Type elementType = dstType.getElementType(); - auto loc = elementOp->getLoc(); - auto dimensionAttr = entryPoint.getAttrOfType<IntegerAttr>( - "iree.executable.reduction.dimension"); - - Operation *expandedOp = nullptr; - if (isa<IREEInterp::HL::AddFOp>(elementOp) || - isa<IREEInterp::HL::AddIOp>(elementOp)) { - if (elementType.isa<FloatType>()) { - expandedOp = builder.create<IREEInterp::HL::ReduceSumFOp>( - loc, dstType, srcArg, initArg, dimensionAttr); - } else { - expandedOp = builder.create<IREEInterp::HL::ReduceSumIOp>( - loc, dstType, srcArg, initArg, dimensionAttr); - } - } else if (isa<IREEInterp::HL::MinFOp>(elementOp) || - isa<IREEInterp::HL::MinISOp>(elementOp) || - isa<IREEInterp::HL::MinIUOp>(elementOp)) { - if (elementType.isa<FloatType>()) { - expandedOp = builder.create<IREEInterp::HL::ReduceMinFOp>( - loc, dstType, srcArg, initArg, dimensionAttr); - } else { - expandedOp = builder.create<IREEInterp::HL::ReduceMinIOp>( - loc, dstType, srcArg, initArg, dimensionAttr); - } - } else if (isa<IREEInterp::HL::MaxFOp>(elementOp) || - isa<IREEInterp::HL::MaxISOp>(elementOp) || - isa<IREEInterp::HL::MaxIUOp>(elementOp)) { - if (elementType.isa<FloatType>()) { - expandedOp = builder.create<IREEInterp::HL::ReduceMaxFOp>( - loc, dstType, srcArg, initArg, dimensionAttr); - } else { - expandedOp = builder.create<IREEInterp::HL::ReduceMaxIOp>( - loc, dstType, srcArg, initArg, dimensionAttr); - } - } - if (!expandedOp) { - return elementOp->emitOpError() - << "No matching expanded reduction op for elemental op"; - } - llvm::SmallVector<int64_t, 4> zeroOffset(dstType.getRank(), 0); - auto zeroIndices = createArrayConstant(builder, loc, zeroOffset); - auto lengths = createArrayConstant(builder, loc, dstType.getShape()); - builder.create<IREEInterp::HL::CopyOp>( - loc, expandedOp->getResult(0), zeroIndices, dstArg, zeroIndices, lengths); - - return success(); -} - -// Replaces the given elemental |funcOp| with a widened reduction. -LogicalResult expandReductionFunction(FuncOp entryFunc) { - if (!entryFunc.empty()) { - return entryFunc.emitError() - << "Function has already been expanded or has existing contents"; - } else if (!entryFunc.getAttr("iree.executable.reduction.dimension")) { - return entryFunc.emitError() << "Windowed reductions are not yet supported"; - } - auto applySym = - entryFunc.getAttrOfType<SymbolRefAttr>("iree.executable.reduction.apply"); - if (!applySym) { - return entryFunc.emitError() << "No reduction application function defined"; - } - auto applyFunc = entryFunc.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>( - applySym.getValue()); - if (!applyFunc) { - return entryFunc.emitError() - << "Unable to find apply function " << applySym; - } - - auto *entryBlock = entryFunc.addEntryBlock(); - OpBuilder builder(entryBlock); - - if (applyFunc.getBlocks() - .front() - .walk([&](Operation *op) { - if (!op->isKnownTerminator()) { - if (failed( - convertReductionOp(entryFunc, applyFunc, op, builder))) { - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }) - .wasInterrupted()) { - return applyFunc.emitError() << "Unable to convert apply func"; - } - - builder.create<IREE::ReturnOp>(builder.getUnknownLoc()); - - // Remove the apply function as we have inlined it. - applyFunc.erase(); - entryFunc.removeAttr("iree.executable.reduction.apply"); - entryFunc.removeAttr("iree.executable.reduction.dimension"); - - return success(); -} - -// Limited lowering of reductions to fat reduce_* ops. -// -// The specific subset this supports is: -// * 'min', 'max', and 'add' computations, with function names matching the -// computation -// * one op per reduction (no fusions yet). -// Note: computations and shapes are not validated. -// -// TODO(b/139410773): Implement more generally, supporting custom computations. -class ExpandReductionsToOpsPass : public ModulePass<ExpandReductionsToOpsPass> { - public: - void runOnModule() override { - auto module = getModule(); - SmallVector<FuncOp, 4> reductionFuncs; - for (auto funcOp : module.getOps<FuncOp>()) { - if (funcOp.getAttr("iree.executable.reduction.apply")) { - reductionFuncs.push_back(funcOp); - } - } - for (auto funcOp : reductionFuncs) { - if (failed(expandReductionFunction(funcOp))) { - return signalPassFailure(); - } - } - } -}; - -} // namespace - -std::unique_ptr<OpPassBase<ModuleOp>> createExpandReductionsToOpsPass() { - return std::make_unique<ExpandReductionsToOpsPass>(); -} - -static PassRegistration<ExpandReductionsToOpsPass> pass( - "iree-expand-reductions-to-ops", - "Expands IREE reduction functions to their interpreter ops"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp deleted file mode 100644 index 22b6a8a..0000000 --- a/iree/compiler/Transforms/Interpreter/LowerInterpreterDialect.cpp +++ /dev/null
@@ -1,253 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Interpreter/HLDialect.h" -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "iree/compiler/IR/Interpreter/LLDialect.h" -#include "iree/compiler/IR/Interpreter/LLOps.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Serialization/BytecodeTables.h" -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Allocator.h" -#include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -struct LowerBranchOpPattern - : public OpRewritePattern<IREEInterp::HL::BranchOp> { - using OpRewritePattern<IREEInterp::HL::BranchOp>::OpRewritePattern; - - PatternMatchResult matchAndRewrite(IREEInterp::HL::BranchOp op, - PatternRewriter &rewriter) const { - SmallVector<Value *, 8> operands{op.getOperation()->getOperands()}; - - rewriter.replaceOpWithNewOp<IREEInterp::LL::BranchOp>(op, op.getDest(), - operands); - return matchSuccess(); - } -}; - -struct LowerCondCondBranchOpPattern - : public OpRewritePattern<IREEInterp::HL::CondBranchOp> { - using OpRewritePattern<IREEInterp::HL::CondBranchOp>::OpRewritePattern; - - PatternMatchResult matchAndRewrite(IREEInterp::HL::CondBranchOp op, - PatternRewriter &rewriter) const { - SmallVector<Value *, 8> trueOperands{op.getTrueOperands()}; - SmallVector<Value *, 8> falseOperands{op.getFalseOperands()}; - - rewriter.replaceOpWithNewOp<IREEInterp::LL::CondBranchOp>( - op, op.getCondition(), op.getTrueDest(), trueOperands, - op.getFalseDest(), falseOperands); - return matchSuccess(); - } -}; - -// Returns true if the op defined by |opName| (like 'iree_ll_interp.reshape') -// uses output operands for results (like iree_ll_interp.add_i) or returns real -// results. -bool opTakesOutputOperands(llvm::StringRef opName) { - if (!opName.consume_front("iree_ll_interp.")) { - assert(false && "op not part of IREE LL Interpreter dialect"); - return false; - } - auto opcode = GetInterpreterOpcodeByName(opName.str()); - assert(opcode.hasValue() && "op has no corresponding opcode"); - const auto &info = GetInterpreterOpcodeInfo(opcode.getValue()); - for (auto &operand : info.operands) { - if (operand == iree::OperandEncoding::kOutputSlot || - operand == iree::OperandEncoding::kVariadicOutputSlots) { - return true; - } - } - return false; -} - -template <typename SrcOp, typename DstOp> -class SimpleOpLowering : public OpRewritePattern<SrcOp> { - using OpRewritePattern<SrcOp>::OpRewritePattern; - - PatternMatchResult matchAndRewrite(SrcOp op, - PatternRewriter &rewriter) const { - SmallVector<Value *, 8> operands{op.getOperation()->getOperands()}; - - // Most ops take results as output operands to populate during execution. - // Certain ops, like reshape, return references to existing memrefs and - // should still retain their results. - if (!opTakesOutputOperands(DstOp::getOperationName())) { - SmallVector<Type, 8> resultTypes{op.getOperation()->getResultTypes()}; - - rewriter.replaceOpWithNewOp<DstOp>(op, resultTypes, operands, - op.getAttrs()); - return this->matchSuccess(); - } - - SmallVector<Value *, 4> replacementValues; - for (Value *result : op.getOperation()->getResults()) { - auto memRefType = result->getType().cast<MemRefType>(); - if (!memRefType.hasStaticShape()) { - // TODO(benvanik): real thing here - dynamic shaping required. - // This should emit a shape calculation based on the operation. Most - // are likely simple and by running DCE after this we can clean up - // parts that are static or unused. - op.emitOpError() << "uses unsupported dynamic shapes"; - return this->matchFailure(); - } - ArrayRef<Value *> dim_pieces; - auto allocOp = rewriter.create<IREEInterp::LL::AllocHeapOp>( - op.getLoc(), memRefType, dim_pieces); - operands.push_back(allocOp); - replacementValues.push_back(allocOp); - } - ArrayRef<Type> resultTypes; - rewriter.create<DstOp>(op.getLoc(), resultTypes, operands, op.getAttrs()); - rewriter.replaceOp(op, replacementValues); - return this->matchSuccess(); - } -}; - -} // namespace - -void populateInterpreterLoweringPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns.insert<LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx); - patterns.insert< - SimpleOpLowering<IREE::ConstantOp, IREEInterp::LL::ConstantOp>, - SimpleOpLowering<IREEInterp::HL::CopyOp, IREEInterp::LL::DynamicCopyOp>, - SimpleOpLowering<IREEInterp::HL::SliceOp, - IREEInterp::LL::DynamicSliceOp>>(ctx); -#define SAME_NAME_SIMPLE_PATTERN(op_name) \ - SimpleOpLowering<IREEInterp::HL::op_name, IREEInterp::LL::op_name> - // clang-format off - patterns.insert< - SAME_NAME_SIMPLE_PATTERN(AssignOp), - SAME_NAME_SIMPLE_PATTERN(AbsFOp), - SAME_NAME_SIMPLE_PATTERN(AbsIOp), - SAME_NAME_SIMPLE_PATTERN(AddFOp), - SAME_NAME_SIMPLE_PATTERN(AddIOp), - SAME_NAME_SIMPLE_PATTERN(AllocHeapOp), - SAME_NAME_SIMPLE_PATTERN(AndOp), - SAME_NAME_SIMPLE_PATTERN(Atan2FOp), - SAME_NAME_SIMPLE_PATTERN(BreakOp), - SAME_NAME_SIMPLE_PATTERN(BroadcastOp), - SAME_NAME_SIMPLE_PATTERN(CallOp), - SAME_NAME_SIMPLE_PATTERN(CallIndirectOp), - SAME_NAME_SIMPLE_PATTERN(CeilFOp), - SAME_NAME_SIMPLE_PATTERN(ClampFOp), - SAME_NAME_SIMPLE_PATTERN(CloneOp), - SAME_NAME_SIMPLE_PATTERN(CmpFOp), - SAME_NAME_SIMPLE_PATTERN(CmpIOp), - SAME_NAME_SIMPLE_PATTERN(CondAssignOp), - SAME_NAME_SIMPLE_PATTERN(ConvertSSOp), - SAME_NAME_SIMPLE_PATTERN(ConvertUUOp), - SAME_NAME_SIMPLE_PATTERN(ConvertSUOp), - SAME_NAME_SIMPLE_PATTERN(ConvertUSOp), - SAME_NAME_SIMPLE_PATTERN(CondBreakOp), - SAME_NAME_SIMPLE_PATTERN(CosFOp), - SAME_NAME_SIMPLE_PATTERN(DimOp), - SAME_NAME_SIMPLE_PATTERN(DivFOp), - SAME_NAME_SIMPLE_PATTERN(DivISOp), - SAME_NAME_SIMPLE_PATTERN(DivIUOp), - SAME_NAME_SIMPLE_PATTERN(ExpFOp), - SAME_NAME_SIMPLE_PATTERN(LogFOp), - SAME_NAME_SIMPLE_PATTERN(RsqrtFOp), - SAME_NAME_SIMPLE_PATTERN(FloorFOp), - SAME_NAME_SIMPLE_PATTERN(LengthOp), - SAME_NAME_SIMPLE_PATTERN(MatMulFOp), - SAME_NAME_SIMPLE_PATTERN(MatMulIOp), - SAME_NAME_SIMPLE_PATTERN(MaxFOp), - SAME_NAME_SIMPLE_PATTERN(MaxISOp), - SAME_NAME_SIMPLE_PATTERN(MaxIUOp), - SAME_NAME_SIMPLE_PATTERN(MinFOp), - SAME_NAME_SIMPLE_PATTERN(MinISOp), - SAME_NAME_SIMPLE_PATTERN(MinIUOp), - SAME_NAME_SIMPLE_PATTERN(MulAddFOp), - SAME_NAME_SIMPLE_PATTERN(MulAddIOp), - SAME_NAME_SIMPLE_PATTERN(MulFOp), - SAME_NAME_SIMPLE_PATTERN(MulIOp), - SAME_NAME_SIMPLE_PATTERN(NotOp), - SAME_NAME_SIMPLE_PATTERN(OrOp), - SAME_NAME_SIMPLE_PATTERN(PadOp), - SAME_NAME_SIMPLE_PATTERN(RankOp), - SAME_NAME_SIMPLE_PATTERN(ReduceSumIOp), - SAME_NAME_SIMPLE_PATTERN(ReduceSumFOp), - SAME_NAME_SIMPLE_PATTERN(ReduceMinIOp), - SAME_NAME_SIMPLE_PATTERN(ReduceMinFOp), - SAME_NAME_SIMPLE_PATTERN(ReduceMaxIOp), - SAME_NAME_SIMPLE_PATTERN(ReduceMaxFOp), - SAME_NAME_SIMPLE_PATTERN(ReshapeOp), - SAME_NAME_SIMPLE_PATTERN(ReturnOp), - SAME_NAME_SIMPLE_PATTERN(SelectOp), - SAME_NAME_SIMPLE_PATTERN(ShapeOp), - SAME_NAME_SIMPLE_PATTERN(ShiftLeftOp), - SAME_NAME_SIMPLE_PATTERN(ShiftRightArithmeticOp), - SAME_NAME_SIMPLE_PATTERN(ShiftRightLogicalOp), - SAME_NAME_SIMPLE_PATTERN(SinFOp), - SAME_NAME_SIMPLE_PATTERN(SplitOp), - SAME_NAME_SIMPLE_PATTERN(SubFOp), - SAME_NAME_SIMPLE_PATTERN(SubIOp), - SAME_NAME_SIMPLE_PATTERN(TanhFOp), - SAME_NAME_SIMPLE_PATTERN(TileOp), - SAME_NAME_SIMPLE_PATTERN(TraceOp), - SAME_NAME_SIMPLE_PATTERN(TransposeOp), - SAME_NAME_SIMPLE_PATTERN(ReverseOp), - SAME_NAME_SIMPLE_PATTERN(XorOp)>(ctx); - // clang-format on -#undef SAME_NAME_SIMPLE_PATTERN -} - -namespace { -class LowerInterpreterDialectPass - : public FunctionPass<LowerInterpreterDialectPass> { - public: - void runOnFunction() override { - OwningRewritePatternList patterns; - populateInterpreterLoweringPatterns(patterns, &getContext()); - - ConversionTarget target(getContext()); - target.addLegalDialect<IREELLInterpreterDialect>(); - target.addLegalOp<FuncOp, IREE::ReturnOp>(); - if (failed(applyFullConversion(getFunction(), target, patterns))) { - return signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr<OpPassBase<FuncOp>> createLowerInterpreterDialectPass() { - return std::make_unique<LowerInterpreterDialectPass>(); -} - -static PassRegistration<LowerInterpreterDialectPass> pass( - "lower-iree-interpreter-hl-to-ll", "Lowers IREE HL ops to IREE LL ops"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp deleted file mode 100644 index 9263b2a..0000000 --- a/iree/compiler/Transforms/Interpreter/LowerStdToInterpreterDialect.cpp +++ /dev/null
@@ -1,303 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Interpreter/HLDialect.h" -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "iree/compiler/IR/Interpreter/LLDialect.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Transforms/ConversionUtils.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/Support/Allocator.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -struct CallOpLowering : public OpConversionPattern<CallOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - CallOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto callOp = cast<CallOp>(op); - auto calleeType = callOp.getCalleeType(); - rewriter.replaceOpWithNewOp<IREEInterp::HL::CallOp>( - op, callOp.getCallee(), calleeType.getResults(), operands); - return matchSuccess(); - } -}; - -struct CallIndirectOpLowering : public OpConversionPattern<CallIndirectOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - CallIndirectOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto callOp = cast<CallIndirectOp>(op); - rewriter.replaceOpWithNewOp<IREEInterp::HL::CallIndirectOp>( - op, callOp.getCallee(), operands); - return matchSuccess(); - } -}; - -struct ReturnOpLowering : public OpConversionPattern<ReturnOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - ReturnOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREEInterp::HL::ReturnOp>(op, operands); - return matchSuccess(); - } -}; - -struct BranchOpLowering : public OpConversionPattern<BranchOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - BranchOp op, ArrayRef<Value *> properOperands, - ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREEInterp::HL::BranchOp>(op, destinations[0], - operands[0]); - return this->matchSuccess(); - } -}; - -struct CondBranchOpLowering : public OpConversionPattern<CondBranchOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - CondBranchOp op, ArrayRef<Value *> properOperands, - ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, - ConversionPatternRewriter &rewriter) const override { - auto *condValue = loadAccessValue(op.getLoc(), properOperands[0], rewriter); - rewriter.replaceOpWithNewOp<IREEInterp::HL::CondBranchOp>( - op, condValue, destinations[IREEInterp::HL::CondBranchOp::trueIndex], - operands[IREEInterp::HL::CondBranchOp::trueIndex], - destinations[IREEInterp::HL::CondBranchOp::falseIndex], - operands[IREEInterp::HL::CondBranchOp::falseIndex]); - return this->matchSuccess(); - } -}; - -template <typename SrcOp, typename DstOp> -struct CompareOpLowering : public OpConversionPattern<SrcOp> { - using OpConversionPattern<SrcOp>::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - SrcOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto lhValue = loadAccessValue(op.getLoc(), operands[0], rewriter); - auto rhValue = loadAccessValue(op.getLoc(), operands[1], rewriter); - - lhValue = wrapAsMemRef(lhValue, op, rewriter); - rhValue = wrapAsMemRef(rhValue, op, rewriter); - - // TODO(benvanik): map predicate to stable value. - auto predicate = - rewriter.getI32IntegerAttr(static_cast<int32_t>(op.getPredicate())); - - auto dstType = convertTypeToMemRef(op.getResult()); - auto midOp = rewriter.create<DstOp>(op.getLoc(), dstType, predicate, - lhValue, rhValue); - - auto result = wrapAsTensor(midOp.getResult(), op, rewriter); - rewriter.replaceOp( - op, {loadResultValue(op.getLoc(), op.getType(), result, rewriter)}); - return this->matchSuccess(); - } -}; - -struct CmpIOpLowering - : public CompareOpLowering<CmpIOp, IREEInterp::HL::CmpIOp> { - using CompareOpLowering::CompareOpLowering; -}; - -struct CmpFOpLowering - : public CompareOpLowering<CmpFOp, IREEInterp::HL::CmpFOp> { - using CompareOpLowering::CompareOpLowering; -}; - -struct AllocOpLowering : public OpConversionPattern<AllocOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - AllocOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - // TODO(benvanik): replace with length computation. - rewriter.replaceOpWithNewOp<IREEInterp::HL::AllocHeapOp>(op, op.getType(), - operands); - return matchSuccess(); - } -}; - -struct DeallocOpLowering : public OpConversionPattern<DeallocOp> { - using OpConversionPattern::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - DeallocOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREEInterp::HL::DiscardOp>(op, operands[0]); - return matchSuccess(); - } -}; - -struct LoadOpLowering : public OpRewritePattern<LoadOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LoadOp loadOp, - PatternRewriter &rewriter) const override { - if (loadOp.getMemRefType().getRank() != 0) { - loadOp.emitError() << "Cannot lower load of non-scalar"; - return matchFailure(); - } - ArrayRef<Value *> dimPieces; - auto dst = - rewriter - .create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(), dimPieces) - .getResult(); - auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {}); - rewriter.create<IREEInterp::HL::CopyOp>( - loadOp.getLoc(), loadOp.getMemRef(), - /*srcIndices=*/emptyArrayMemref, dst, - /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref); - - rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst); - - return matchSuccess(); - } -}; - -struct StoreOpLowering : public OpRewritePattern<StoreOp> { - using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(StoreOp storeOp, - PatternRewriter &rewriter) const override { - if (storeOp.getMemRefType().getRank() != 0) { - storeOp.emitError() << "Cannot lower store of non-scalar"; - return matchFailure(); - } - - auto src = rewriter.create<IREE::ScalarToMemRefOp>( - storeOp.getLoc(), storeOp.getValueToStore()); - - auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {}); - rewriter.replaceOpWithNewOp<IREEInterp::HL::CopyOp>( - storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(), - /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref); - - return matchSuccess(); - } -}; - -#define UNARY_OP_LOWERING(StdOpType, IREEOpType) \ - struct StdOpType##Lowering : public UnaryOpLowering<StdOpType, IREEOpType> { \ - using UnaryOpLowering::UnaryOpLowering; \ - }; - -#define BINARY_OP_LOWERING(StdOpType, IREEOpType) \ - struct StdOpType##Lowering \ - : public BinaryOpLowering<StdOpType, IREEOpType> { \ - using BinaryOpLowering::BinaryOpLowering; \ - }; - -#define TERNARY_OP_LOWERING(StdOpType, IREEOpType) \ - struct StdOpType##Lowering \ - : public TernaryOpLowering<StdOpType, IREEOpType> { \ - using TernaryOpLowering::TernaryOpLowering; \ - }; - -// UNARY_OP_LOWERING(RankOp, IREEInterp::HL::RankOp); -UNARY_OP_LOWERING(DimOp, IREEInterp::HL::DimOp); -// UNARY_OP_LOWERING(ShapeOp, IREEInterp::HL::ShapeOp); -// UNARY_OP_LOWERING(LengthOp, IREEInterp::HL::LengthOp); - -// UNARY_OP_LOWERING(NotOp, IREEInterp::HL::NotOp); -BINARY_OP_LOWERING(AndOp, IREEInterp::HL::AndOp); -BINARY_OP_LOWERING(OrOp, IREEInterp::HL::OrOp); -// BINARY_OP_LOWERING(XorOp, IREEInterp::HL::XorOp); -// BINARY_OP_LOWERING(ShiftLeftOp, IREEInterp::HL::ShiftLeftOp); -// BINARY_OP_LOWERING(ShiftRightLogicalOp, IREEInterp::HL::ShiftRightLogicalOp); -// BINARY_OP_LOWERING(ShiftRightArithmeticOp, -// IREEInterp::HL::ShiftRightArithmeticOp); - -BINARY_OP_LOWERING(AddIOp, IREEInterp::HL::AddIOp); -BINARY_OP_LOWERING(AddFOp, IREEInterp::HL::AddFOp); -BINARY_OP_LOWERING(SubIOp, IREEInterp::HL::SubIOp); -BINARY_OP_LOWERING(SubFOp, IREEInterp::HL::SubFOp); -// UNARY_OP_LOWERING(AbsIOp, IREEInterp::HL::AbsIOp); -// UNARY_OP_LOWERING(AbsFOp, IREEInterp::HL::AbsFOp); -BINARY_OP_LOWERING(MulIOp, IREEInterp::HL::MulIOp); -BINARY_OP_LOWERING(MulFOp, IREEInterp::HL::MulFOp); -BINARY_OP_LOWERING(DivISOp, IREEInterp::HL::DivISOp); -BINARY_OP_LOWERING(DivIUOp, IREEInterp::HL::DivIUOp); -BINARY_OP_LOWERING(DivFOp, IREEInterp::HL::DivFOp); -// BINARY_OP_LOWERING(MulAddIOp, IREEInterp::HL::MulAddIOp); -// BINARY_OP_LOWERING(MulAddFOp, IREEInterp::HL::MulAddFOp); -// UNARY_OP_LOWERING(ExpFOp, IREEInterp::HL::ExpFOp); -// UNARY_OP_LOWERING(LogFOp, IREEInterp::HL::LogFOp); -// UNARY_OP_LOWERING(RsqrtFOp, IREEInterp::HL::RsqrtFOp); -// UNARY_OP_LOWERING(CosFOp, IREEInterp::HL::CosFOp); -// UNARY_OP_LOWERING(SinFOp, IREEInterp::HL::SinFOp); -// UNARY_OP_LOWERING(TanhFOp, IREEInterp::HL::TanhFOp); -// UNARY_OP_LOWERING(Atan2FOp, IREEInterp::HL::Atan2FOp); - -// BINARY_OP_LOWERING(MinISOp, IREEInterp::HL::MinISOp); -// BINARY_OP_LOWERING(MinIUOp, IREEInterp::HL::MinIUOp); -// BINARY_OP_LOWERING(MinFOp, IREEInterp::HL::MinFOp); -// BINARY_OP_LOWERING(MaxISOp, IREEInterp::HL::MaxISOp); -// BINARY_OP_LOWERING(MaxIUOp, IREEInterp::HL::MaxIUOp); -// BINARY_OP_LOWERING(MaxFOp, IREEInterp::HL::MaxFOp); -// TERNARY_OP_LOWERING(ClampFOp, IREEInterp::HL::ClampFOp); -// UNARY_OP_LOWERING(FloorFOp, IREEInterp::HL::FloorFOp); -// UNARY_OP_LOWERING(CeilFOp, IREEInterp::HL::CeilFOp); - -} // namespace - -void populateLowerStdToInterpreterPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns.insert< - // Control flow. - CallOpLowering, CallIndirectOpLowering, ReturnOpLowering, - BranchOpLowering, CondBranchOpLowering, CmpIOpLowering, CmpFOpLowering, - // Memory management. - AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering, - // Shape operations. - DimOpLowering, - // Logical ops. - AndOpLowering, OrOpLowering, - // Arithmetic ops. - AddIOpLowering, AddFOpLowering, SubIOpLowering, SubFOpLowering, - MulIOpLowering, MulFOpLowering, DivISOpLowering, DivIUOpLowering, - DivFOpLowering>(ctx); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp deleted file mode 100644 index 19524b0..0000000 --- a/iree/compiler/Transforms/Interpreter/LowerToInterpreterDialect.cpp +++ /dev/null
@@ -1,63 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Interpreter/HLDialect.h" -#include "iree/compiler/IR/Interpreter/LLDialect.h" -#include "iree/compiler/Transforms/Interpreter/Rewrites.h" -#include "iree/compiler/Transforms/Rewrites.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" - -namespace mlir { -namespace iree_compiler { -namespace { - -class LowerToInterpreterDialectPass - : public FunctionPass<LowerToInterpreterDialectPass> { - public: - void runOnFunction() override { - OwningRewritePatternList patterns; - auto* ctx = &getContext(); - xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx); - xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx); - populateLowerStdToIreePatterns(patterns, ctx); - populateLowerStdToInterpreterPatterns(patterns, ctx); - populateLowerXlaToIreePatterns(patterns, ctx); - populateLowerXlaToInterpreterPatterns(patterns, ctx); - - ConversionTarget target(getContext()); - target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>(); - target.addLegalOp<FuncOp, ReturnOp>(); - if (failed(applyFullConversion(getFunction(), target, patterns))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr<OpPassBase<FuncOp>> createLowerToInterpreterDialectPass() { - return std::make_unique<LowerToInterpreterDialectPass>(); -} - -static PassRegistration<LowerToInterpreterDialectPass> pass( - "lower-to-iree-interpreter", - "Convert all ops to the IREE interpreter dialect"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp b/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp deleted file mode 100644 index 7bbb0d6..0000000 --- a/iree/compiler/Transforms/Interpreter/LowerXLAToInterpreterDialect.cpp +++ /dev/null
@@ -1,565 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Interpreter/HLDialect.h" -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Transforms/ConversionUtils.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// TODO(suderman): tablegen this? or something a bit more flexible. - -#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \ - struct XlaOpType##Lowering \ - : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ - using UnaryOpLowering::UnaryOpLowering; \ - }; - -#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \ - struct XlaOpType##Lowering \ - : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ - using TernaryOpLowering::TernaryOpLowering; \ - }; - -UNARY_OP_LOWERING(CopyOp, IREEInterp::HL::CloneOp); -UNARY_OP_LOWERING(ExpOp, IREEInterp::HL::ExpFOp); -UNARY_OP_LOWERING(LogOp, IREEInterp::HL::LogFOp); -UNARY_OP_LOWERING(FloorOp, IREEInterp::HL::FloorFOp); -UNARY_OP_LOWERING(RsqrtOp, IREEInterp::HL::RsqrtFOp); -UNARY_OP_LOWERING(TanhOp, IREEInterp::HL::TanhFOp); -TERNARY_OP_LOWERING(SelectOp, IREEInterp::HL::SelectOp); - -#undef UNARY_OP_LOWERING -#undef TERNARY_OP_LOWERING - -template <typename T> -static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter, - Location loc, Value *input, - MemRefType targetType) { - auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape()); - return rewriter.create<T>(loc, targetType, input, shapeOp); -} - -static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op, - Value *tensor) { - return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op, - rewriter); -} - -template <typename SrcOp> -class XlaOpLowering : public OpConversionPattern<SrcOp> { - using OpConversionPattern<SrcOp>::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - SrcOp srcOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - SmallVector<Value *, 4> memrefOperands; - for (auto operand : operands) { - memrefOperands.push_back(inputAsMemref(rewriter, srcOp, operand)); - } - - if (auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter)) { - rewriter.replaceOp(srcOp, - wrapAsTensor(dstOp->getResult(0), srcOp, rewriter)); - return this->matchSuccess(); - } - return this->matchFailure(); - } - - protected: - virtual Operation *rewriteInternal( - SrcOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?"); - } -}; - -struct BroadcastInDimOpLowering - : public XlaOpLowering<xla_hlo::BroadcastInDimOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::BroadcastInDimOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *inputValue = operands[0]; - auto inputType = inputValue->getType().cast<MemRefType>(); - auto finalType = convertTypeToMemRef(*op); - - // Reshape to scalar and broadcast. - auto createFinal = createShapeTargetingOp<IREEInterp::HL::BroadcastOp>; - llvm::SmallVector<int64_t, 6> intermediateShape{}; - - // Or reshape to final rank and tile. - if (inputType.getNumElements() != 1) { - createFinal = createShapeTargetingOp<IREEInterp::HL::TileOp>; - - intermediateShape = llvm::SmallVector<int64_t, 6>(finalType.getRank(), 1); - auto inputShape = inputType.getShape(); - auto dimensions = op->broadcast_dimensions(); - for (size_t i = 0; i < inputType.getRank(); ++i) { - auto index = dimensions->getValue(i).cast<IntegerAttr>().getInt(); - intermediateShape[index] = inputShape[i]; - } - } - - auto intermediateType = - MemRefType::get(intermediateShape, inputType.getElementType()); - auto reshapeOp = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( - rewriter, op->getLoc(), inputValue, intermediateType); - return createFinal(rewriter, op->getLoc(), reshapeOp->getResult(0), - finalType); - } -}; - -struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto finalType = convertTypeToMemRef(*op); - - return rewriter.create<IREEInterp::HL::ConcatOp>( - op->getLoc(), finalType, operands, - rewriter.getI32IntegerAttr(op->dimension().getZExtValue())); - } -}; - -struct DotOpLowering : public XlaOpLowering<xla_hlo::DotOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::DotOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *lhsValue = operands[0]; - auto *rhsValue = operands[1]; - - auto finalType = convertTypeToMemRef(*op); - auto elementType = finalType.getElementType(); - if (!elementType.isa<FloatType>()) { - op->emitOpError("xla_hlo.dot only supports floating point values"); - } - - Operation *matMulOp = rewriter - .create<IREEInterp::HL::MatMulFOp>( - op->getLoc(), finalType, lhsValue, rhsValue) - .getOperation(); - return matMulOp; - } -}; - -struct DynamicUpdateSliceOpLowering - : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto operand = operands[0]; - auto update = operands[1]; - - auto updateType = update->getType().cast<ShapedType>(); - Value *lengthConstant = - createArrayConstant(rewriter, op->getLoc(), updateType.getShape()); - - auto startIndices = makeArrayRef(operands).drop_front(2); - const int rank = startIndices.size(); - llvm::SmallVector<Value *, 4> valuesToConcat; - valuesToConcat.reserve(startIndices.size()); - auto type = getElementTypeOrSelf(startIndices.front()); - - // To generate the offset matrix we need to convert the variadic tensors - // into a reshaped and concated value. - for (auto index : startIndices) { - auto reshapedIndex = rewriter.create<IREEInterp::HL::ReshapeOp>( - op->getLoc(), MemRefType::get({1}, type), index, - createArrayConstant(rewriter, op->getLoc(), {1})); - valuesToConcat.push_back(reshapedIndex); - } - - auto dstOffset = rewriter - .create<IREEInterp::HL::ConcatOp>( - op->getLoc(), MemRefType::get({rank}, type), - valuesToConcat, rewriter.getI32IntegerAttr(0)) - .getResult(); - - llvm::SmallVector<int64_t, 4> zero_offset; - zero_offset.resize(updateType.getRank(), 0); - auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset); - - auto copiedOperand = rewriter.create<IREEInterp::HL::CloneOp>( - op->getLoc(), operand->getType(), operand); - - rewriter - .create<IREEInterp::HL::CopyOp>(op->getLoc(), update, srcOffset, - copiedOperand, dstOffset, - lengthConstant) - .getOperation(); - - return copiedOperand; - } -}; - -template <typename XlaOpType, typename IreeFloatOpType, typename IreeIntOpType> -struct BinaryFloatIntOpLowering : public XlaOpLowering<XlaOpType> { - using XlaOpLowering<XlaOpType>::XlaOpLowering; - - Operation *rewriteInternal( - XlaOpType *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *lhs = operands[0]; - auto *rhs = operands[1]; - auto inputType = lhs->getType().cast<MemRefType>(); - auto elementType = inputType.getElementType(); - - if (elementType.isa<FloatType>()) { - return rewriter.create<IreeFloatOpType>(op->getLoc(), inputType, lhs, - rhs); - } - - return rewriter.create<IreeIntOpType>(op->getLoc(), inputType, lhs, rhs); - } -}; - -struct MaxOpLowering - : public BinaryFloatIntOpLowering<xla_hlo::MaxOp, IREEInterp::HL::MaxFOp, - IREEInterp::HL::MaxISOp> { - using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering; -}; - -struct MinOpLowering - : public BinaryFloatIntOpLowering<xla_hlo::MinOp, IREEInterp::HL::MinFOp, - IREEInterp::HL::MinISOp> { - using BinaryFloatIntOpLowering::BinaryFloatIntOpLowering; -}; - -struct ConvertLowering : public XlaOpLowering<xla_hlo::ConvertOp> { - using XlaOpLowering<xla_hlo::ConvertOp>::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::ConvertOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *operand = operands[0]; - auto *result = op->getResult(); - - auto operandType = operand->getType().cast<MemRefType>().getElementType(); - auto resultType = result->getType().cast<ShapedType>().getElementType(); - - auto newResultType = convertTypeToMemRef(result); - -#define ConvertCase(InType, OutType, NewOp) \ - { \ - if (operandType.isa<InType>() && resultType.isa<OutType>()) { \ - return rewriter.create<NewOp>(op->getLoc(), newResultType, operand); \ - } \ - } - ConvertCase(IntegerType, IntegerType, IREEInterp::HL::ConvertSSOp); - ConvertCase(IntegerType, FloatType, IREEInterp::HL::ConvertSFOp); - ConvertCase(FloatType, IntegerType, IREEInterp::HL::ConvertFSOp); - ConvertCase(FloatType, FloatType, IREEInterp::HL::ConvertFFOp); -#undef ConvertCase - - return nullptr; - } -}; - -// Lowers a subset of gathers along axis 0 that are really just a slice and -// reshape. -struct GatherOpLowering : public OpConversionPattern<xla_hlo::GatherOp> { - using OpConversionPattern::OpConversionPattern; - - // TODO(gcmn): This only handles a minimal number of cases. When XLA - // redefines gather to be simpler, lower it properly. - PatternMatchResult matchAndRewrite( - xla_hlo::GatherOp gatherOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - if (gatherOp.index_vector_dim() != 0) { - gatherOp.emitRemark() - << "Couldn't lower gather with index_vector_dim != 0"; - return matchFailure(); - } - if (gatherOp.start_index_map().getType().getRank() != 1 || - gatherOp.start_index_map().getValue(0).cast<IntegerAttr>().getValue() != - 0) { - gatherOp.emitRemark() - << "Couldn't lower gather with start_index_map != [0]"; - return matchFailure(); - } - if (gatherOp.collapsed_slice_dims().getType().getRank() != 1 || - gatherOp.collapsed_slice_dims() - .getValue(0) - .cast<IntegerAttr>() - .getValue() != 0) { - gatherOp.emitRemark() - << "Couldn't lower gather with collapsed_dims != [0]"; - return matchFailure(); - } - - auto resultType = gatherOp.getResult()->getType().cast<RankedTensorType>(); - if (gatherOp.offset_dims().getType().getNumElements() != - resultType.getRank()) { - gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != " - "[0,...,rank of output]"; - return matchFailure(); - } - for (auto it : llvm::enumerate(gatherOp.offset_dims())) { - if (it.index() != it.value()) { - gatherOp.emitRemark() << "Couldn't lower gather with offset_dims != " - "[0,...,rank of output]"; - return matchFailure(); - } - } - - for (auto it : llvm::enumerate(resultType.getShape())) { - if (gatherOp.slice_sizes() - .getValue(it.index() + 1) - .cast<IntegerAttr>() - .getValue() != it.value()) { - gatherOp.emitRemark() - << "Couldn't lower gather with slice_sizes not [1] + final shape"; - return matchFailure(); - } - } - - auto inputType = gatherOp.operand()->getType().cast<RankedTensorType>(); - - auto startIndices = - inputAsMemref(rewriter, gatherOp, gatherOp.start_indices()); - auto startIndicesType = startIndices->getType().cast<MemRefType>(); - if (startIndicesType.getNumElements() != inputType.getRank()) { - auto extraDims = inputType.getRank() - startIndicesType.getNumElements(); - auto elementType = startIndicesType.getElementType(); - - if (startIndicesType.getRank() != 1) { - startIndices = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( - rewriter, gatherOp.getLoc(), startIndices, - MemRefType::get({1}, elementType)) - ->getResult(0); - } - - llvm::SmallVector<int64_t, 4> zeroes; - zeroes.resize(extraDims, 0); - - auto elementsAttr = DenseIntElementsAttr::get( - RankedTensorType::get(zeroes.size(), elementType), - llvm::makeArrayRef(zeroes)); - - auto extraStartIndices = - rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), elementsAttr); - - auto memrefOutputType = - MemRefType::get({inputType.getRank()}, elementType); - - SmallVector<Value *, 2> valuesToConcat = {startIndices, - extraStartIndices}; - startIndices = rewriter.create<IREEInterp::HL::ConcatOp>( - gatherOp.getLoc(), memrefOutputType, valuesToConcat, - rewriter.getI32IntegerAttr(0)); - } - - auto sliceSizeValues = gatherOp.slice_sizes().getValues<int64_t>(); - std::vector<int64_t> sliceSizes = {sliceSizeValues.begin(), - sliceSizeValues.end()}; - auto dstType = MemRefType::get(sliceSizes, inputType.getElementType()); - - auto src = inputAsMemref(rewriter, gatherOp, gatherOp.operand()); - std::vector<Value *> dim_pieces; - auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>( - gatherOp.getLoc(), dstType, dim_pieces); - auto lengths = rewriter.create<IREE::ConstantOp>(gatherOp.getLoc(), - gatherOp.slice_sizes()); - llvm::SmallVector<int64_t, 4> zero_offset; - zero_offset.resize(dstType.getRank(), 0); - auto dstIndices = - createArrayConstant(rewriter, gatherOp.getLoc(), zero_offset); - - rewriter.create<IREEInterp::HL::CopyOp>( - gatherOp.getLoc(), src, startIndices, dst, dstIndices, lengths); - - auto reshaped = createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( - rewriter, gatherOp.getLoc(), dst, convertTypeToMemRef(gatherOp)); - rewriter.replaceOp( - gatherOp, wrapAsTensor(reshaped->getResult(0), gatherOp, rewriter)); - - return matchSuccess(); - } -}; - -struct SliceOpLowering : public XlaOpLowering<xla_hlo::SliceOp> { - using XlaOpLowering<xla_hlo::SliceOp>::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::SliceOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - // XLA slice has value semantics, whereas the IREE slice creates a view. We - // lower it to a copy if all strides are one which may be transformed to a - // slice by later optimizations. - auto isNotOne = [](APInt stride) { return stride != 1; }; - if (llvm::any_of(op->strides(), isNotOne)) { - op->emitRemark() << "Could not lower slice op with non-singular strides"; - return nullptr; - } - - auto finalType = convertTypeToMemRef(*op); - auto src = operands[0]; - std::vector<Value *> dim_pieces; - auto dst = rewriter.create<IREEInterp::HL::AllocHeapOp>( - op->getLoc(), finalType, dim_pieces); - auto srcIndices = - rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices()); - auto lengths = - createArrayConstant(rewriter, op->getLoc(), finalType.getShape()); - - llvm::SmallVector<int64_t, 4> zero_offset; - zero_offset.resize(finalType.getRank(), 0); - auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset); - - rewriter.create<IREEInterp::HL::CopyOp>(op->getLoc(), src, srcIndices, dst, - dstIndices, lengths); - return dst; - } -}; - -struct PadOpLowering : public XlaOpLowering<xla_hlo::PadOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::PadOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *src = operands[0]; - auto *paddingValue = operands[1]; - - // TODO(b/140836672) Support negative padding - for (int i = 0; i < op->edge_padding_high().getNumElements(); ++i) { - if (op->edge_padding_high().getValue<IntegerAttr>(i).getInt() < 0 || - op->edge_padding_low().getValue<IntegerAttr>(i).getInt() < 0) { - op->emitRemark() << "Could not lower pad op with negative padding"; - return nullptr; - } - } - - auto edgePaddingLowOp = - rewriter.create<IREE::ConstantOp>(op->getLoc(), op->edge_padding_low()); - auto edgePaddingHighOp = rewriter.create<IREE::ConstantOp>( - op->getLoc(), op->edge_padding_high()); - auto interiorPaddingOp = - rewriter.create<IREE::ConstantOp>(op->getLoc(), op->interior_padding()); - - return rewriter.create<IREEInterp::HL::PadOp>( - op->getLoc(), convertTypeToMemRef(*op), src, paddingValue, - edgePaddingLowOp, edgePaddingHighOp, interiorPaddingOp); - } -}; - -struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - return createShapeTargetingOp<IREEInterp::HL::ReshapeOp>( - rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op)); - } -}; - -struct TransposeOpLowering : public XlaOpLowering<xla_hlo::TransposeOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::TransposeOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto permutationOp = - rewriter.create<IREE::ConstantOp>(op->getLoc(), op->permutation()); - - return rewriter.create<IREEInterp::HL::TransposeOp>( - op->getLoc(), convertTypeToMemRef(*op), operands[0], permutationOp); - } -}; - -struct ReverseOpLowering : public XlaOpLowering<xla_hlo::ReverseOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::ReverseOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto reverseOp = - rewriter.create<IREE::ConstantOp>(op->getLoc(), op->dimensions()); - - return rewriter.create<IREEInterp::HL::ReverseOp>( - op->getLoc(), convertTypeToMemRef(*op), operands[0], reverseOp); - } -}; - -} // namespace - -void populateLowerXlaToInterpreterPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns - .insert<BroadcastInDimOpLowering, ConcatOpLowering, ConvertLowering, - CopyOpLowering, DotOpLowering, DynamicUpdateSliceOpLowering, - ExpOpLowering, FloorOpLowering, GatherOpLowering, LogOpLowering, - MaxOpLowering, MinOpLowering, PadOpLowering, ReshapeOpLowering, - ReverseOpLowering, RsqrtOpLowering, SelectOpLowering, - SliceOpLowering, TransposeOpLowering, TanhOpLowering>(ctx); -} - -namespace { -// Just for testing these passes. -// TODO(b/141337493) can we get rid of this pass entirely? -class LowerXLAToInterpreterDialectPass - : public FunctionPass<LowerXLAToInterpreterDialectPass> { - public: - void runOnFunction() override { - OwningRewritePatternList patterns; - populateLowerXlaToInterpreterPatterns(patterns, &getContext()); - - ConversionTarget target(getContext()); - target.addLegalDialect<IREEHLInterpreterDialect, IREEDialect>(); - if (failed(applyPartialConversion(getFunction(), target, patterns))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -static PassRegistration<LowerXLAToInterpreterDialectPass> pass( - "lower-xla-to-iree-interpreter", - "Convert all XLA functions to the IREE dialect"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/MakeExecutableABI.cpp b/iree/compiler/Transforms/Interpreter/MakeExecutableABI.cpp deleted file mode 100644 index e304ba5..0000000 --- a/iree/compiler/Transforms/Interpreter/MakeExecutableABI.cpp +++ /dev/null
@@ -1,147 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Interpreter/HLOps.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Replaces a load_input op with valid IR that loads the input value. -LogicalResult replaceLoadInputOp(IREE::LoadInputOp bindOp) { - OpBuilder builder(bindOp); - - Value *newValue = nullptr; - auto dstType = bindOp.getResult()->getType(); - if (dstType.isa<TensorType>()) { - auto castOp = - builder.create<IREE::MemRefToTensorOp>(bindOp.getLoc(), bindOp.src()); - newValue = castOp.getResult(); - } else if (dstType.isIntOrIndexOrFloat()) { - auto loadOp = builder.create<LoadOp>(bindOp.getLoc(), dstType, bindOp.src(), - ArrayRef<Value *>{}); - newValue = loadOp.getResult(); - } else { - return bindOp.emitError() - << "Unsupported input destination type " << dstType; - } - - bindOp.replaceAllUsesWith(newValue); - bindOp.erase(); - - return success(); -} - -// Replaces a store_output op with valid IR that stores the output value. -LogicalResult replaceStoreOutputOp(IREE::StoreOutputOp bindOp) { - OpBuilder builder(bindOp); - - auto srcType = bindOp.src()->getType(); - if (srcType.isa<MemRefType>()) { - // Already stored into the output. - } else if (srcType.isa<TensorType>()) { - auto castOp = - builder.create<IREE::TensorToMemRefOp>(bindOp.getLoc(), bindOp.src()); - - // Insert a copy to our output parameter. - auto dst = bindOp.dst()->getType().cast<ShapedType>(); - if (!dst.hasStaticShape()) { - return bindOp.emitError() - << "Dynamic output args are not yet implemented"; - } - - auto zeroValues = llvm::SmallVector<int64_t, 4>(dst.getRank()); - auto zeros = createArrayConstant(builder, bindOp.getLoc(), zeroValues); - auto lengths = - createArrayConstant(builder, bindOp.getLoc(), dst.getShape()); - builder.create<IREEInterp::HL::CopyOp>(bindOp.getLoc(), castOp.getResult(), - zeros, bindOp.dst(), zeros, lengths); - } else if (srcType.isIntOrIndexOrFloat()) { - builder.create<StoreOp>(bindOp.getLoc(), bindOp.src(), bindOp.dst(), - ArrayRef<Value *>{}); - } else { - return bindOp.emitError() << "Unsupported output src type " << srcType; - } - - bindOp.erase(); - - return success(); -} - -// Strips iree.bind_* ops from |func|. -LogicalResult stripBindingOps(FuncOp func) { - // Find iree.load_input ops to replace with memref_to_tensor if needed. - SmallVector<IREE::LoadInputOp, 8> bindInputOps; - func.walk([&](IREE::LoadInputOp bindOp) { bindInputOps.push_back(bindOp); }); - for (auto &bindOp : bindInputOps) { - if (failed(replaceLoadInputOp(bindOp))) { - return failure(); - } - } - - // Find iree.store_output ops and replace with tensor_to_memref if needed. - SmallVector<IREE::StoreOutputOp, 8> bindOutputOps; - func.walk( - [&](IREE::StoreOutputOp bindOp) { bindOutputOps.push_back(bindOp); }); - for (auto &bindOp : bindOutputOps) { - if (failed(replaceStoreOutputOp(bindOp))) { - return failure(); - } - } - - return success(); -} - -} // namespace - -// Finds iree.executable.export functions and fixes up bindings. -// For the interpreter this really just means stripping the bind ops entirely. -class MakeExecutableABIPass : public ModulePass<MakeExecutableABIPass> { - public: - void runOnModule() override { - auto module = getModule(); - for (auto func : module.getOps<FuncOp>()) { - if (func.getAttr("iree.executable.export")) { - if (failed(stripBindingOps(func))) { - return signalPassFailure(); - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createMakeExecutableABIPass() { - return std::make_unique<MakeExecutableABIPass>(); -} - -static PassRegistration<MakeExecutableABIPass> pass( - "iree-make-executable-abi", - "Makes functions match the IREE dispatch executable ABI."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Interpreter/test/BUILD b/iree/compiler/Transforms/Interpreter/test/BUILD deleted file mode 100644 index 59cfc89..0000000 --- a/iree/compiler/Transforms/Interpreter/test/BUILD +++ /dev/null
@@ -1,16 +0,0 @@ -# Tests for lowering MLIR in various dialects to IREE interpreter bytecode. - -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/Interpreter/test/xla/BUILD b/iree/compiler/Transforms/Interpreter/test/xla/BUILD deleted file mode 100644 index 8ad177e..0000000 --- a/iree/compiler/Transforms/Interpreter/test/xla/BUILD +++ /dev/null
@@ -1,17 +0,0 @@ -# Tests specific to lowering XLA to IREE. - -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - "//iree/tools:iree-run-mlir", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/LegalizeTypeStorage.cpp b/iree/compiler/Transforms/LegalizeTypeStorage.cpp deleted file mode 100644 index ab80e05..0000000 --- a/iree/compiler/Transforms/LegalizeTypeStorage.cpp +++ /dev/null
@@ -1,145 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/DenseSet.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -bool convertOperation(Operation *oldOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - OperationState state(oldOp->getLoc(), oldOp->getName()); - if (oldOp->getNumSuccessors() == 0) { - // Non-branching operations can just add all the operands. - for (auto *oldOperand : oldOp->getOperands()) { - state.operands.push_back(mapping->lookupOrDefault(oldOperand)); - } - } else { - // We add the operands separated by nullptr's for each successor. - unsigned firstSuccOperand = oldOp->getNumSuccessors() - ? oldOp->getSuccessorOperandIndex(0) - : oldOp->getNumOperands(); - auto opOperands = oldOp->getOpOperands(); - unsigned i = 0; - for (; i != firstSuccOperand; ++i) { - state.operands.push_back(mapping->lookupOrDefault(opOperands[i].get())); - } - for (unsigned succ = 0, e = oldOp->getNumSuccessors(); succ != e; ++succ) { - state.successors.push_back( - mapping->lookupOrDefault(oldOp->getSuccessor(succ))); - // Add sentinel to delineate successor operands. - state.operands.push_back(nullptr); - // Remap the successors operands. - for (auto *operand : oldOp->getSuccessorOperands(succ)) { - state.operands.push_back(mapping->lookupOrDefault(operand)); - } - } - } - for (const auto &oldType : oldOp->getResultTypes()) { - state.types.push_back(legalizeType(oldType)); - } - state.attributes = {oldOp->getAttrs().begin(), oldOp->getAttrs().end()}; - auto newOp = builder.createOperation(state); - for (int i = 0; i < newOp->getNumResults(); ++i) { - mapping->map(oldOp->getResult(i), newOp->getResult(i)); - } - return false; -} - -bool convertFunction(FuncOp oldFunction, FuncOp newFunction) { - OpBuilder builder(newFunction.getBody()); - BlockAndValueMapping mapping; - - // Create new blocks matching the expected arguments of the old ones. - // This sets up the block mappings to enable us to reference blocks forward - // during conversion. - newFunction.getBlocks().clear(); - for (auto &oldBlock : oldFunction.getBlocks()) { - auto *newBlock = builder.createBlock(&newFunction.getBody()); - mapping.map(&oldBlock, newBlock); - for (auto *oldArg : oldBlock.getArguments()) { - auto *newArg = newBlock->addArgument(legalizeType(oldArg->getType())); - mapping.map(oldArg, newArg); - } - } - - // Convert all ops in the blocks. - for (auto &oldBlock : oldFunction.getBlocks()) { - builder.setInsertionPointToEnd(mapping.lookupOrNull(&oldBlock)); - for (auto &oldOp : oldBlock.getOperations()) { - if (convertOperation(&oldOp, builder, &mapping)) { - return true; - } - } - } - - return false; -} - -} // namespace - -class LegalizeTypeStoragePass : public ModulePass<LegalizeTypeStoragePass> { - public: - void runOnModule() override { - auto module = getModule(); - - // Build a list of (oldFunction, newFunction) for all functions we need to - // replace. This will ensure that when we go to convert function bodies we - // have only new functions defined. - std::vector<std::pair<FuncOp, FuncOp>> convertedFunctions; - - for (auto oldFunction : module.getOps<FuncOp>()) { - // Create the replacement function, ensuring that we copy attributes. - auto newFunction = FuncOp::create( - oldFunction.getLoc(), oldFunction.getName(), - legalizeType(oldFunction.getType()).cast<FunctionType>(), - oldFunction.getDialectAttrs()); - convertedFunctions.push_back({oldFunction, newFunction}); - - // Perform the actual body conversion now that we have proper signatures. - if (convertFunction(oldFunction, newFunction)) { - return signalPassFailure(); - } - } - - // Replace functions in the module. - for (auto &pair : convertedFunctions) { - pair.first.erase(); - module.push_back(pair.second); - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createLegalizeTypeStoragePass() { - return std::make_unique<LegalizeTypeStoragePass>(); -} - -static PassRegistration<LegalizeTypeStoragePass> pass( - "iree-legalize-type-storage", - "Legalizes types to ones supported by the IREE VM."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/LowerStdToIreeDialect.cpp b/iree/compiler/Transforms/LowerStdToIreeDialect.cpp deleted file mode 100644 index b2aef0c..0000000 --- a/iree/compiler/Transforms/LowerStdToIreeDialect.cpp +++ /dev/null
@@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -struct ConstantOpLowering : public OpRewritePattern<ConstantOp> { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(ConstantOp op, - PatternRewriter &rewriter) const override { - if (auto elementsValue = op.getValue().dyn_cast<ElementsAttr>()) { - auto ireeConst = - rewriter.create<IREE::ConstantOp>(op.getLoc(), elementsValue); - - auto result = wrapAsTensor(ireeConst.getResult(), op, rewriter); - rewriter.replaceOp(op, result); - return matchSuccess(); - } - - auto type = op.getValue().getType(); - if (!type.isIntOrFloat()) { - return matchFailure(); - } - auto elementsValue = - DenseElementsAttr::get(RankedTensorType::get({}, type), op.getValue()); - auto ireeConst = - rewriter.create<IREE::ConstantOp>(op.getLoc(), elementsValue); - rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(op, ireeConst); - return matchSuccess(); - } -}; - -struct ExtractElementOpLowering : public OpRewritePattern<ExtractElementOp> { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(ExtractElementOp op, - PatternRewriter &rewriter) const override { - Value *memRefInput = - wrapAsMemRef(loadAccessValue(op.getLoc(), op.getAggregate(), rewriter), - op, rewriter); - - SmallVector<Value *, 4> indices = {op.indices().begin(), - op.indices().end()}; - rewriter.replaceOpWithNewOp<LoadOp>(op, memRefInput, indices); - return matchSuccess(); - } -}; - -} // namespace - -void populateLowerStdToIreePatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns.insert<ConstantOpLowering, ExtractElementOpLowering>(ctx); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/LowerXLAToIreeDialect.cpp b/iree/compiler/Transforms/LowerXLAToIreeDialect.cpp deleted file mode 100644 index 7a7f87f..0000000 --- a/iree/compiler/Transforms/LowerXLAToIreeDialect.cpp +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "mlir/IR/PatternMatch.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -struct ConstOpLowering : public OpRewritePattern<xla_hlo::ConstOp> { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(xla_hlo::ConstOp op, - PatternRewriter &rewriter) const override { - auto ireeConst = rewriter.create<IREE::ConstantOp>(op.getLoc(), op.value()); - rewriter.replaceOp(op, wrapAsTensor(ireeConst, op, rewriter)); - return matchSuccess(); - } -}; - -} // namespace - -void populateLowerXlaToIreePatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns.insert<ConstOpLowering>(ctx); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp b/iree/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp deleted file mode 100644 index 467903b..0000000 --- a/iree/compiler/Transforms/Sequencer/AssignExecutableOrdinals.cpp +++ /dev/null
@@ -1,75 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "llvm/ADT/DenseMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -class AssignExecutableOrdinalsPass - : public ModulePass<AssignExecutableOrdinalsPass> { - public: - void runOnModule() override { - Builder builder(getModule()); - int nextExecutableOrdinal = 0; - for (auto multiArchExecutableOp : - getModule().getOps<IREE::MultiArchExecutableOp>()) { - multiArchExecutableOp.setAttr( - "iree.ordinal", builder.getI32IntegerAttr(nextExecutableOrdinal++)); - - // We'll scan for all entry points in the first executable. Then on all - // other executables we can reuse the ordinals (ensuring that iteration - // order does not matter). - llvm::DenseMap<StringRef, FuncOp> entryPointMap; - for (auto executableOp : - multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) { - executableOp.setAttr("iree.ordinal", - multiArchExecutableOp.getAttr("iree.ordinal")); - int nextEntryPointOrdinal = 0; - for (auto funcOp : executableOp.getInnerModule().getOps<FuncOp>()) { - if (!funcOp.getAttr("iree.executable.export")) continue; - auto it = entryPointMap.find(funcOp.getName()); - if (it == entryPointMap.end()) { - funcOp.setAttr("iree.ordinal", - builder.getI32IntegerAttr(nextEntryPointOrdinal++)); - entryPointMap.insert({funcOp.getName(), funcOp}); - } else { - funcOp.setAttr("iree.ordinal", it->second.getAttr("iree.ordinal")); - } - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createAssignExecutableOrdinalsPass() { - return std::make_unique<AssignExecutableOrdinalsPass>(); -} - -static PassRegistration<AssignExecutableOrdinalsPass> pass( - "iree-assign-executable-ordinals", - "Assigns executable and entry point ordinals"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp b/iree/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp deleted file mode 100644 index 2fcd6e3..0000000 --- a/iree/compiler/Transforms/Sequencer/AssignExecutableWorkloadAttrs.cpp +++ /dev/null
@@ -1,125 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Sequencer/LLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "llvm/ADT/StringMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -struct WorkloadInfo { - SmallVector<ElementsAttr, 4> staticWorkloads; - SmallVector<Value *, 4> dynamicWorkloads; -}; - -// Finds all dispatches and records their workload attributes mapped by -// (executable ordinal, entry point ordinal). -llvm::StringMap<llvm::StringMap<WorkloadInfo>> gatherExecutableWorkloadInfos( - ModuleOp moduleOp) { - llvm::StringMap<llvm::StringMap<WorkloadInfo>> workloadInfos; - for (auto funcOp : moduleOp.getOps<FuncOp>()) { - funcOp.walk([&](IREESeq::LL::DynamicDispatchOp op) { - auto &workloadInfo = - workloadInfos[op.getExecutable()][op.getEntryPoint()]; - workloadInfo.dynamicWorkloads.push_back(op.getWorkload()); - }); - funcOp.walk([&](IREESeq::LL::StaticDispatchOp op) { - auto &workloadInfo = - workloadInfos[op.getExecutable()][op.getEntryPoint()]; - for (auto existingWorkloadAttr : workloadInfo.staticWorkloads) { - if (existingWorkloadAttr == op.getWorkload()) { - return; // Already present, ignore. - } - } - workloadInfo.staticWorkloads.push_back(op.getWorkload()); - }); - } - return workloadInfos; -} - -// Adds attributes to the given executable entry point describing the workload -// info to the backends that will be processing them. -LogicalResult attributeExecutableEntryPointWorkload( - FuncOp entryPointOp, const WorkloadInfo &workloadInfo) { - if (!workloadInfo.dynamicWorkloads.empty()) { - return entryPointOp.emitError() << "Dynamic workloads not yet supported"; - } - if (workloadInfo.staticWorkloads.size() != 1) { - return entryPointOp.emitError() << "Static workload sizes differ in shape"; - } - - // Easy because we just support static workloads now. - // When this code is adapted to support dynamic workloads we'll want to put - // a pair of attrs describing which dimensions may be static and which args - // have the dynamic values to reference. - entryPointOp.setAttr("iree.executable.workload", - workloadInfo.staticWorkloads.front()); - - return success(); -} - -} // namespace - -class AssignExecutableWorkloadAttrsPass - : public ModulePass<AssignExecutableWorkloadAttrsPass> { - public: - void runOnModule() override { - Builder builder(getModule()); - - // Find all dispatches and capture their workload information. - // We store this information by executable and then entry point ordinal. - auto executableWorkloadInfos = gatherExecutableWorkloadInfos(getModule()); - - // Process each executable with the workload information. - for (auto &executableIt : executableWorkloadInfos) { - auto multiArchExecutableOp = cast<IREE::MultiArchExecutableOp>( - getModule().lookupSymbol(executableIt.first())); - for (auto executableOp : - multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) { - for (auto &entryPointIt : executableIt.second) { - auto funcOp = cast<FuncOp>( - executableOp.getInnerModule().lookupSymbol(entryPointIt.first())); - if (failed(attributeExecutableEntryPointWorkload( - funcOp, entryPointIt.second))) { - return signalPassFailure(); - } - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> -createAssignExecutableWorkloadAttrsPass() { - return std::make_unique<AssignExecutableWorkloadAttrsPass>(); -} - -static PassRegistration<AssignExecutableWorkloadAttrsPass> pass( - "iree-assign-executable-workload-attrs", - "Assigns executable entrypoint workload attributes"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/BUILD b/iree/compiler/Transforms/Sequencer/BUILD deleted file mode 100644 index 777e28b..0000000 --- a/iree/compiler/Transforms/Sequencer/BUILD +++ /dev/null
@@ -1,46 +0,0 @@ -# Transforms specific to the IREE sequencer. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Sequencer", - srcs = [ - "AssignExecutableOrdinals.cpp", - "AssignExecutableWorkloadAttrs.cpp", - "FoldCompatibleDispatchRegions.cpp", - "IdentifyDispatchRegions.cpp", - "IdentifyReductionRegions.cpp", - "LegalizeInputs.cpp", - "LowerSequencerDialect.cpp", - "LowerStdToSequencerDialect.cpp", - "LowerToSequencerDialect.cpp", - "LowerXLAToSequencerDialect.cpp", - "OutlineDispatchRegions.cpp", - "OutlineReductionRegions.cpp", - "RematerializeDispatchConstants.cpp", - ], - hdrs = [ - "Passes.h", - "Rewrites.h", - ], - deps = [ - "//iree/compiler/IR", - "//iree/compiler/IR/Sequencer", - "//iree/compiler/Transforms", - "//iree/compiler/Utils", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_lower_general_dot", - ], -)
diff --git a/iree/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp b/iree/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp deleted file mode 100644 index d4ea4a0..0000000 --- a/iree/compiler/Transforms/Sequencer/FoldCompatibleDispatchRegions.cpp +++ /dev/null
@@ -1,63 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/DispatchUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -// Identifies dispatch regions that have compatible workloads and folds them. -// This relies on CSE having deduped workloads to simplify the logic to simply -// looking for dispatch regions using the same values. -class FoldCompatibleDispatchRegionsPass - : public FunctionPass<FoldCompatibleDispatchRegionsPass> { - public: - void runOnFunction() override { - auto func = getFunction(); - for (auto &block : func) { - if (failed(mergeBlockDispatchRegions(func, &block))) { - return signalPassFailure(); - } - } - } -}; - -std::unique_ptr<OpPassBase<FuncOp>> createFoldCompatibleDispatchRegionsPass() { - return std::make_unique<FoldCompatibleDispatchRegionsPass>(); -} - -static PassRegistration<FoldCompatibleDispatchRegionsPass> pass( - "iree-fold-compatible-dispatch-regions", - "Folds dispatch regions that have compatible workloads."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp b/iree/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp deleted file mode 100644 index 93be8a9..0000000 --- a/iree/compiler/Transforms/Sequencer/IdentifyDispatchRegions.cpp +++ /dev/null
@@ -1,259 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <algorithm> - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/DispatchUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Returns true if the given |op| can be dispatched in all cases. -// Other passes may handle special cases of these ops but this initial -// identification is conservative. -bool isDispatchableOp(Operation *op) { - if (op->getDialect() && op->getDialect()->getNamespace().startswith("iree")) { - // Ignore things we've already produced as they should only relate to - // sequencer operations. - return false; - } else if (op->isKnownTerminator()) { - // Currently we skip all terminators as we want to leave them in the block - // to keep it valid. Future folding passes may take care of them if they are - // worth bringing into the dispatch region. - return false; - } else if (isa<CallOp>(op)) { - // This may be handled by a control-flow folding pass later once we have - // done our initial analysis and know what functions are compatible. - return false; - } else if (isa<CallIndirectOp>(op)) { - // Indirect calls are not supported in dispatch code. - return false; - } else if (isa<AllocOp>(op)) { - // Allocations are sequencer ops. - // Note that we could support static allocations (convert to stack/etc). - return false; - } else if (isa<ConstantOp>(op)) { - // Constants are handled in the RematerializeDispatchConstants pass. - // We do that independently so that we can more easily see the use of - // constants across all dispatches instead of just on an individual basis - // as we do here. - return false; - } else if (isa<xla_hlo::DynamicUpdateSliceOp>(op)) { - // TODO(benvanik): lower these to the sequencer dialect prior to ID'ing. - return false; - } - return true; -} - -// Returns true if the given |op| can have other ops fused into it. -// This is sketchy and it'd be nice to define this as an op property instead. -// -// What we are looking for in foldable ops is whether the execution of the op -// when fused has some possible benefit (or at least, a non-negative cost). -// Eventually we want to allow backends to vote on this and allow multiple -// folding strategies within the same executable. For now we just hardcode what -// we know for the ops we have. -// -// Preconditions: isDispatchableOp(op) == true. -bool isFusionRootOp(Operation *op) { - if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { - // We have hand-written kernels for these right now we want to stand alone. - // When we do a bit more magic we should allow these ops to fold. - return false; - } - return true; -} - -// Returns true if the given |op| can be fused into other ops. -// -// Ops that perform narrowing on shapes (such as reduction ops) should not -// generally be fused with other downstream ops (probably...). This avoids -// potential oversampling and indexing issues and allows backends to perform -// more efficient rooted cascading reduction dispatches. -// -// Preconditions: isDispatchableOp(op) == true. -bool isFusableOp(Operation *op) { - if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { - return false; - } else if (isa<xla_hlo::ReduceOp>(op)) { - // Reduction is usually a dedicated root operation - we can shove things in - // the front of it but not behind. - return false; - } - return true; -} - -// Puts all of the |unsortedOps| into |sortedOps| in an arbitrary topological -// order. -// https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search -// -// Preconditions: |unsortedOps| has no cycles within the set of ops. -std::vector<Operation *> sortOpsTopologically( - const llvm::SetVector<Operation *> &unsortedOps) { - llvm::SetVector<Operation *> unmarkedOps; - unmarkedOps.insert(unsortedOps.begin(), unsortedOps.end()); - llvm::SetVector<Operation *> markedOps; - - using VisitFn = std::function<void(Operation * op)>; - VisitFn visit = [&](Operation *op) { - if (markedOps.count(op) > 0) return; - for (auto *result : op->getResults()) { - for (auto *user : result->getUsers()) { - // Don't visit ops not in our set. - if (unsortedOps.count(user) == 0) continue; - visit(user); - } - } - markedOps.insert(op); - }; - - while (!unmarkedOps.empty()) { - auto *op = unmarkedOps.pop_back_val(); - visit(op); - } - - auto sortedOps = markedOps.takeVector(); - std::reverse(sortedOps.begin(), sortedOps.end()); - return sortedOps; -} - -// Recursively traverses the IR DAG along the operand edges to find ops we are -// able to fuse and appends them to |subgraph|. -void gatherFusionOps(Operation *op, llvm::SetVector<Operation *> *subgraph) { - // Skip ops that are used outside of the subgraph we are building. - for (auto *result : op->getResults()) { - if (result->use_empty() || result->hasOneUse()) continue; - for (auto *user : result->getUsers()) { - if (subgraph->count(user) == 0) { - // Op that consumes the result is not (yet) in the subgraph. - // For now we'll ignore these as it may represent a fork that we don't - // want to join too early. - return; - } - } - } - - // Walk backward up to ops providing our input operands. - for (auto *operand : op->getOperands()) { - auto *sourceOp = operand->getDefiningOp(); - if (!sourceOp) continue; - if (subgraph->count(sourceOp) == 0) { - if (isDispatchableOp(sourceOp) && isFusableOp(sourceOp)) { - gatherFusionOps(sourceOp, subgraph); - } - } - } - - subgraph->insert(op); -} - -// Finds all ops that can be fused together with the given |rootOp| by searching -// backwards in the op order through input edges. -// Returns a topologically sorted list of all fused ops with |rootOp| at the -// end. -std::vector<Operation *> findFusionSubgraphFromRoot(Operation *rootOp) { - if (!isFusionRootOp(rootOp)) { - return {rootOp}; - } - llvm::SetVector<Operation *> subgraph; - subgraph.insert(rootOp); - gatherFusionOps(rootOp, &subgraph); - return sortOpsTopologically(subgraph); -} - -// Identifies ranges of dispatchable ops and moves them into dispatch regions. -LogicalResult identifyBlockDispatchRegions(FuncOp func, Block *block) { - // Fixed point iteration until we can no longer fuse anything. - bool didFindAnyNewRegions; - do { - // Iterate in reverse so we root further along in the op list. - didFindAnyNewRegions = false; - for (auto &rootOp : llvm::reverse(*block)) { - if (!isDispatchableOp(&rootOp)) { - // Op should remain at the sequencer level. - continue; - } - - // Attempt to find all operations, including rootOp, that can be fused. - // The ops will be sorted in topological order with rootOp as the last op. - // Worst case we may end up with a subgraph of only the rootOp. - auto fusedSubgraph = findFusionSubgraphFromRoot(&rootOp); - - // Compute the workload based on the output shape. - // When variadic all output shapes match so we can just take the first. - auto *workload = calculateWorkload(&rootOp, rootOp.getResult(0)); - - // Try to build a dispatch region from this root. - if (failed(buildDispatchRegion(func, block, workload, fusedSubgraph))) { - return failure(); - } - - // Successfully created a dispatch region from the ops and we must now - // start over again as we've likely trashed the whole block structure. - didFindAnyNewRegions = true; - break; - } - } while (didFindAnyNewRegions); - return success(); -} - -} // namespace - -// Identifies dispatchable ops and moves them into iree.dispatch_regions. -// Some ops, such as call, will be deferred until following passes. -class IdentifyDispatchRegionsPass - : public FunctionPass<IdentifyDispatchRegionsPass> { - public: - void runOnFunction() override { - auto func = getFunction(); - for (auto &block : func) { - if (failed(identifyBlockDispatchRegions(func, &block))) { - return signalPassFailure(); - } - } - } -}; - -std::unique_ptr<OpPassBase<FuncOp>> createIdentifyDispatchRegionsPass() { - return std::make_unique<IdentifyDispatchRegionsPass>(); -} - -static PassRegistration<IdentifyDispatchRegionsPass> pass( - "iree-identify-dispatch-regions", - "Conservatively identifies dispatch regions in functions."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp b/iree/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp deleted file mode 100644 index 40d4ac2..0000000 --- a/iree/compiler/Transforms/Sequencer/IdentifyReductionRegions.cpp +++ /dev/null
@@ -1,163 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <algorithm> - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Utils/DispatchUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Builds a new iree.reduction_region with the given |invocationRegion|. -// The new region will be inserted after |originalOp|. -// -// All |invocationRegion| ops must be compatible with the |workload| specified -// as they will all be dispatched with the same workgroup structure. The -// |invocationRegion| will not be modified. -LogicalResult buildReductionRegion(Operation *originalOp, - ArrayRef<Value *> operands, - ArrayRef<Value *> initialValues, - ArrayRef<int64_t> dimensions, - Region &invocationRegion) { - OpBuilder parentBuilder(originalOp); - - // Compute the workload based on the output shape. - // When variadic all output shapes match so we can just take the first. - auto *workload = calculateWorkload(originalOp, originalOp->getResult(0)); - - // Build the region op and add it to the parent block. - SmallVector<Type, 4> resultTypes{originalOp->getResultTypes()}; - auto reductionRegionOp = parentBuilder.create<IREE::ReductionRegionOp>( - originalOp->getLoc(), resultTypes, workload, operands, initialValues, - dimensions); - - // Create the block and setup the arg mapping for captured values. - BlockAndValueMapping mapping; - invocationRegion.cloneInto(&reductionRegionOp.getBody(), mapping); - - // Replace xla_hlo.return -> iree.return. - OpBuilder regionBuilder(reductionRegionOp.getBody()); - reductionRegionOp.walk([&](xla_hlo::ReturnOp returnOp) { - regionBuilder.setInsertionPoint(returnOp); - SmallVector<Value *, 4> returnValues(returnOp.getOperands()); - regionBuilder.create<IREE::ReturnOp>(returnOp.getLoc(), returnValues); - returnOp.erase(); - }); - - // Replace usage of values with the results of the region. - for (int i = 0; i < originalOp->getNumResults(); ++i) { - originalOp->getResult(i)->replaceAllUsesWith( - reductionRegionOp.getResult(i)); - } - - return success(); -} - -// Converts an xla_hlo::ReduceOp to a reduction region and inlines the target -// computation into the region body. -LogicalResult buildReductionRegionFromXLAReduceOp(xla_hlo::ReduceOp reduceOp) { - SmallVector<Value *, 4> operands(reduceOp.getOperands()); - OperandAdaptor<xla_hlo::ReduceOp> adaptor(operands); - - SmallVector<int64_t, 4> dimensions; - for (auto dim : reduceOp.dimensions().getIntValues()) { - dimensions.push_back(dim.getSExtValue()); - } - - // Create the iree.reduction_region. - if (failed(buildReductionRegion(reduceOp, adaptor.operands(), - adaptor.init_values(), dimensions, - reduceOp.body()))) { - return failure(); - } - - // Remove original XLA reduction op. - reduceOp.erase(); - - return success(); -} - -// Identifies reduction ops and moves them into reduction regions. -LogicalResult identifyBlockReductionRegions(FuncOp funcOp, Block *block) { - // Fixed point iteration until we can no longer fuse anything. - bool didFindAnyNewRegions; - do { - // Iterate in reverse so we root further along in the op list. - didFindAnyNewRegions = false; - for (auto &rootOp : llvm::reverse(*block)) { - if (auto reduceOp = dyn_cast<xla_hlo::ReduceOp>(rootOp)) { - if (failed(buildReductionRegionFromXLAReduceOp(reduceOp))) { - return failure(); - } - - // Successfully created a dispatch region from the ops and we must now - // start over again as we've likely trashed the whole block structure. - didFindAnyNewRegions = true; - break; - } - } - } while (didFindAnyNewRegions); - return success(); -} - -} // namespace - -// Identifies reduction ops and moves their targets into iree.reduction_regions. -class IdentifyReductionRegionsPass - : public ModulePass<IdentifyReductionRegionsPass> { - public: - void runOnModule() override { - for (auto funcOp : getModule().getOps<FuncOp>()) { - for (auto &block : funcOp) { - if (failed(identifyBlockReductionRegions(funcOp, &block))) { - return signalPassFailure(); - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createIdentifyReductionRegionsPass() { - return std::make_unique<IdentifyReductionRegionsPass>(); // NOLINT -} - -static PassRegistration<IdentifyReductionRegionsPass> pass( - "iree-identify-reduction-regions", - "Identifies reduction regions based on input reduction ops."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LegalizeInputs.cpp b/iree/compiler/Transforms/Sequencer/LegalizeInputs.cpp deleted file mode 100644 index 5d79791..0000000 --- a/iree/compiler/Transforms/Sequencer/LegalizeInputs.cpp +++ /dev/null
@@ -1,53 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/Transforms/Rewrites.h" -#include "iree/compiler/Transforms/Sequencer/Rewrites.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" - -namespace mlir { -namespace iree_compiler { -namespace { - -class LegalizeInputOpsPass - : public FunctionPass<LegalizeInputOpsPass> { - public: - void runOnFunction() override { - OwningRewritePatternList patterns; - xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); - - ConversionTarget target(getContext()); - target.addLegalDialect<xla_hlo::XlaHloDialect, StandardOpsDialect>(); - target.addLegalOp<FuncOp, ReturnOp>(); - target.addIllegalOp<xla_hlo::DotGeneralOp, xla_hlo::WhileOp>(); - if (failed(applyFullConversion(getFunction(), target, patterns))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr<OpPassBase<FuncOp>> createLegalizeInputOpsPass() { - return std::make_unique<LegalizeInputOpsPass>(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp deleted file mode 100644 index ea3efa6..0000000 --- a/iree/compiler/Transforms/Sequencer/LowerSequencerDialect.cpp +++ /dev/null
@@ -1,305 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Sequencer/HLDialect.h" -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/Sequencer/LLDialect.h" -#include "iree/compiler/IR/Sequencer/LLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -template <typename SrcOp> -class SequencerLoweringPattern : public OpConversionPattern<SrcOp> { - public: - SequencerLoweringPattern(MLIRContext *context, TypeConverter &typeConverter) - : OpConversionPattern<SrcOp>(context), typeConverter_(typeConverter) {} - - protected: - TypeConverter &typeConverter_; -}; - -// Returns an integer scalar memref containing the offset specified by |indices| -// within |type|. -Value *computeOffset(Location loc, Value *reference, Value *indices, - OpBuilder &builder) { - auto referenceType = reference->getType().cast<ShapedType>(); - auto *shapeMemRef = builder - .create<IREESeq::LL::AllocHeapOp>( - loc, - MemRefType::get({referenceType.getRank()}, - builder.getIntegerType(32)), - ArrayRef<Value *>{}) - .getResult(); - builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef); - auto *resultMemRef = - builder - .create<IREESeq::LL::AllocHeapOp>( - loc, MemRefType::get({}, builder.getIntegerType(32)), - ArrayRef<Value *>{}) - .getResult(); - auto elementSizeAttr = builder.getIntegerAttr( - builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8); - builder.create<IREESeq::LL::ComputeOffsetOp>( - loc, shapeMemRef, elementSizeAttr, indices, resultMemRef); - return resultMemRef; -} - -// Returns a tuple of (offset, length) integer scalar memrefs with the range -// specified by |indices| and |lengths| within |type|. -std::pair<Value *, Value *> computeRange(Location loc, Value *reference, - Value *indices, Value *lengths, - OpBuilder &builder) { - auto referenceType = reference->getType().cast<ShapedType>(); - auto *shapeMemRef = builder - .create<IREESeq::LL::AllocHeapOp>( - loc, - MemRefType::get({referenceType.getRank()}, - builder.getIntegerType(32)), - ArrayRef<Value *>{}) - .getResult(); - builder.create<IREESeq::LL::ShapeOp>(loc, reference, shapeMemRef); - auto *offsetMemRef = - builder - .create<IREESeq::LL::AllocHeapOp>( - loc, MemRefType::get({}, builder.getIntegerType(32)), - ArrayRef<Value *>{}) - .getResult(); - auto *lengthMemRef = - builder - .create<IREESeq::LL::AllocHeapOp>( - loc, MemRefType::get({}, builder.getIntegerType(32)), - ArrayRef<Value *>{}) - .getResult(); - auto elementSizeAttr = builder.getIntegerAttr( - builder.getIntegerType(8), referenceType.getElementTypeBitWidth() / 8); - builder.create<IREESeq::LL::ComputeRangeOp>(loc, shapeMemRef, elementSizeAttr, - indices, lengths, offsetMemRef, - lengthMemRef); - return {offsetMemRef, lengthMemRef}; -} - -struct LowerSliceOpPattern - : public SequencerLoweringPattern<IREESeq::HL::SliceOp> { - using SequencerLoweringPattern::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - IREESeq::HL::SliceOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - OperandAdaptor<IREESeq::HL::SliceOp> operandAdaptor(operands); - auto range = computeRange(op.getLoc(), operandAdaptor.src(), - operandAdaptor.indices(), - operandAdaptor.lengths(), rewriter); - rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicSliceOp>( - op, typeConverter_.convertType(op.getType()), - ArrayRef<Value *>{operandAdaptor.src(), range.first, range.second}, - op.getAttrs()); - return matchSuccess(); - } -}; - -struct LowerShapeOpPattern - : public SequencerLoweringPattern<IREESeq::HL::ShapeOp> { - using SequencerLoweringPattern::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - IREESeq::HL::ShapeOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto *shapeMemRef = - rewriter - .create<IREESeq::LL::AllocHeapOp>( - op.getLoc(), - MemRefType::get({op.getType().cast<ShapedType>().getRank()}, - rewriter.getIntegerType(64)), - ArrayRef<Value *>{}) - .getResult(); - op.replaceAllUsesWith(shapeMemRef); - rewriter.replaceOpWithNewOp<IREESeq::LL::ShapeOp>(op, operands[0], - shapeMemRef); - return matchSuccess(); - } -}; - -struct LowerCopyOpPattern - : public SequencerLoweringPattern<IREESeq::HL::CopyOp> { - using SequencerLoweringPattern::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - IREESeq::HL::CopyOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - OperandAdaptor<IREESeq::HL::CopyOp> operandAdaptor(operands); - auto *srcOffsetMemRef = - computeOffset(op.getLoc(), operandAdaptor.src(), - operandAdaptor.srcIndices(), rewriter); - auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(), - operandAdaptor.dstIndices(), - operandAdaptor.lengths(), rewriter); - rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicCopyOp>( - op, operandAdaptor.src(), srcOffsetMemRef, operandAdaptor.dst(), - dstRange.first, dstRange.second); - return matchSuccess(); - } -}; - -struct LowerFillOpPattern - : public SequencerLoweringPattern<IREESeq::HL::FillOp> { - using SequencerLoweringPattern::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - IREESeq::HL::FillOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - OperandAdaptor<IREESeq::HL::FillOp> operandAdaptor(operands); - auto dstRange = computeRange(op.getLoc(), operandAdaptor.dst(), - operandAdaptor.dstIndices(), - operandAdaptor.lengths(), rewriter); - rewriter.replaceOpWithNewOp<IREESeq::LL::DynamicFillOp>( - op, operandAdaptor.value(), operandAdaptor.dst(), dstRange.first, - dstRange.second); - return matchSuccess(); - } -}; - -struct LowerBranchOpPattern - : public SequencerLoweringPattern<IREESeq::HL::BranchOp> { - using SequencerLoweringPattern< - IREESeq::HL::BranchOp>::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - IREESeq::HL::BranchOp op, ArrayRef<Value *> properOperands, - ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREESeq::LL::BranchOp>(op, destinations[0], - operands[0]); - return matchSuccess(); - } -}; - -struct LowerCondCondBranchOpPattern - : public SequencerLoweringPattern<IREESeq::HL::CondBranchOp> { - using SequencerLoweringPattern< - IREESeq::HL::CondBranchOp>::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - IREESeq::HL::CondBranchOp op, ArrayRef<Value *> properOperands, - ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREESeq::LL::CondBranchOp>( - op, properOperands[0], - destinations[IREESeq::HL::CondBranchOp::trueIndex], - operands[IREESeq::HL::CondBranchOp::trueIndex], - destinations[IREESeq::HL::CondBranchOp::falseIndex], - operands[IREESeq::HL::CondBranchOp::falseIndex]); - return matchSuccess(); - } -}; - -// Rewrites an op into one with all the same operands, results, and attributes. -// Operands and results in the ops must have the same order and attributes must -// have the same name. They must also be constructed properly by the default -// builders. -template <typename SRC, typename DST> -struct LowerIdenticalOpPattern : public SequencerLoweringPattern<SRC> { - using SequencerLoweringPattern<SRC>::SequencerLoweringPattern; - - PatternMatchResult matchAndRewrite( - SRC op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - SmallVector<Type, 8> originalResultTypes{ - op.getOperation()->getResultTypes()}; - SmallVector<Type, 8> resultTypes; - if (failed(this->typeConverter_.convertTypes(originalResultTypes, - resultTypes))) { - op.emitOpError() << "Failed to convert result types"; - return this->matchFailure(); - } - rewriter.replaceOpWithNewOp<DST>(op, resultTypes, operands, op.getAttrs()); - return this->matchSuccess(); - } -}; - -} // namespace - -class LowerSequencerDialectPass : public ModulePass<LowerSequencerDialectPass> { - public: - void runOnModule() override { - auto *ctx = &getContext(); - LLTypeConverter typeConverter(ctx); - OwningRewritePatternList patterns; - patterns.insert< - LowerIdenticalOpPattern<IREE::ConstantOp, IREESeq::LL::ConstantOp>, - LowerIdenticalOpPattern<IREESeq::HL::DispatchOp, - IREESeq::LL::DynamicDispatchOp>, - LowerShapeOpPattern, LowerCopyOpPattern, LowerSliceOpPattern, - LowerBranchOpPattern, LowerCondCondBranchOpPattern>(ctx, typeConverter); -#define IDENTICAL_OP_LOWERING(op_name) \ - LowerIdenticalOpPattern<IREESeq::HL::op_name, IREESeq::LL::op_name> - patterns.insert< - IDENTICAL_OP_LOWERING(AllocHeapOp), IDENTICAL_OP_LOWERING(CloneOp), - IDENTICAL_OP_LOWERING(ReshapeOp), IDENTICAL_OP_LOWERING(CallOp), - IDENTICAL_OP_LOWERING(ReturnOp)>(ctx, typeConverter); -#undef IDENTICAL_OP_LOWERING - - mlir::populateFuncOpTypeConversionPattern(patterns, ctx, typeConverter); - ConversionTarget target(*ctx); - target.addLegalDialect<IREELLSequencerDialect>(); - target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { - return typeConverter.isSignatureLegal(op.getType()); - }); - - // TODO(b/142791494): The conversion framework will recurse into the - // executable if we just call it on the top-level module. This can't be a - // function pass because type conversion replaces the original functions. - auto funcsIt = getModule().getOps<FuncOp>(); - SmallVector<Operation *, 4> funcs(funcsIt.begin(), funcsIt.end()); - - if (failed(applyFullConversion(funcs, target, patterns, &typeConverter))) { - return signalPassFailure(); - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createLowerSequencerDialectPass() { - return std::make_unique<LowerSequencerDialectPass>(); -} - -static PassRegistration<LowerSequencerDialectPass> pass( - "iree-lower-sequencer-dialect", - "Lowers the IREE HL sequencer dialect to the LL sequencer dialect."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp deleted file mode 100644 index 3f839ab..0000000 --- a/iree/compiler/Transforms/Sequencer/LowerStdToSequencerDialect.cpp +++ /dev/null
@@ -1,204 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Sequencer/HLDialect.h" -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -template <typename T> -class SequencerConversionPattern : public OpConversionPattern<T> { - using OpConversionPattern<T>::OpConversionPattern; -}; - -struct CallOpLowering : public SequencerConversionPattern<CallOp> { - using SequencerConversionPattern::SequencerConversionPattern; - - PatternMatchResult matchAndRewrite( - CallOp callOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - SmallVector<Type, 4> resultTypes(callOp.getResultTypes()); - rewriter.replaceOpWithNewOp<IREESeq::HL::CallOp>(callOp, callOp.getCallee(), - resultTypes, operands); - - return matchSuccess(); - } -}; - -struct CallIndirectOpLowering - : public SequencerConversionPattern<CallIndirectOp> { - using SequencerConversionPattern::SequencerConversionPattern; - - PatternMatchResult matchAndRewrite( - CallIndirectOp callOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREESeq::HL::CallIndirectOp>( - callOp, callOp.getCallee(), operands); - return matchSuccess(); - } -}; - -struct ReturnOpLowering : public SequencerConversionPattern<ReturnOp> { - using SequencerConversionPattern::SequencerConversionPattern; - - PatternMatchResult matchAndRewrite( - ReturnOp returnOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - SmallVector<Value *, 4> newOperands; - newOperands.reserve(operands.size()); - for (auto *operand : operands) { - newOperands.push_back(wrapAsMemRef(operand, returnOp, rewriter)); - } - rewriter.replaceOpWithNewOp<IREESeq::HL::ReturnOp>(returnOp, newOperands); - return matchSuccess(); - } -}; - -struct BranchOpLowering : public SequencerConversionPattern<BranchOp> { - using SequencerConversionPattern::SequencerConversionPattern; - PatternMatchResult matchAndRewrite( - BranchOp branchOp, ArrayRef<Value *> properOperands, - ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREESeq::HL::BranchOp>( - branchOp, destinations[0], operands[0]); - return this->matchSuccess(); - } -}; - -struct CondBranchOpLowering : public SequencerConversionPattern<CondBranchOp> { - using SequencerConversionPattern::SequencerConversionPattern; - PatternMatchResult matchAndRewrite( - CondBranchOp condBranchOp, ArrayRef<Value *> properOperands, - ArrayRef<Block *> destinations, ArrayRef<ArrayRef<Value *>> operands, - ConversionPatternRewriter &rewriter) const override { - auto *condValue = - loadAccessValue(condBranchOp.getLoc(), properOperands[0], rewriter); - rewriter.replaceOpWithNewOp<IREESeq::HL::CondBranchOp>( - condBranchOp, condValue, - destinations[IREESeq::HL::CondBranchOp::trueIndex], - operands[IREESeq::HL::CondBranchOp::trueIndex], - destinations[IREESeq::HL::CondBranchOp::falseIndex], - operands[IREESeq::HL::CondBranchOp::falseIndex]); - return this->matchSuccess(); - } -}; - -struct AllocOpLowering : public SequencerConversionPattern<AllocOp> { - using SequencerConversionPattern::SequencerConversionPattern; - PatternMatchResult matchAndRewrite( - AllocOp allocOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - // TODO(benvanik): replace with length computation. - rewriter.replaceOpWithNewOp<IREESeq::HL::AllocHeapOp>( - allocOp, allocOp.getType(), operands); - return matchSuccess(); - } -}; - -struct DeallocOpLowering : public SequencerConversionPattern<DeallocOp> { - using SequencerConversionPattern::SequencerConversionPattern; - PatternMatchResult matchAndRewrite( - DeallocOp deallocOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<IREESeq::HL::DiscardOp>(deallocOp, operands[0]); - return matchSuccess(); - } -}; - -struct LoadOpLowering : public SequencerConversionPattern<LoadOp> { - using SequencerConversionPattern::SequencerConversionPattern; - PatternMatchResult matchAndRewrite( - LoadOp loadOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - if (loadOp.getMemRefType().getRank() != 0) { - loadOp.emitError() << "Cannot lower load of non-scalar"; - return matchFailure(); - } - ArrayRef<Value *> dimPieces; - auto dst = rewriter.create<AllocOp>(loadOp.getLoc(), loadOp.getMemRefType(), - dimPieces); - auto emptyArrayMemref = createArrayConstant(rewriter, loadOp.getLoc(), {}); - rewriter.create<IREESeq::HL::CopyOp>(loadOp.getLoc(), loadOp.getMemRef(), - /*srcIndices=*/emptyArrayMemref, dst, - /*dstIndices=*/emptyArrayMemref, - /*lengths=*/emptyArrayMemref); - - rewriter.replaceOpWithNewOp<IREE::MemRefToScalarOp>(loadOp, dst); - - return matchSuccess(); - } -}; - -struct StoreOpLowering : public SequencerConversionPattern<StoreOp> { - using SequencerConversionPattern::SequencerConversionPattern; - PatternMatchResult matchAndRewrite( - StoreOp storeOp, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - if (storeOp.getMemRefType().getRank() != 0) { - storeOp.emitError() << "Cannot lower store of non-scalar"; - return matchFailure(); - } - - auto src = rewriter.create<IREE::ScalarToMemRefOp>( - storeOp.getLoc(), storeOp.getValueToStore()); - - auto emptyArrayMemref = createArrayConstant(rewriter, storeOp.getLoc(), {}); - rewriter.replaceOpWithNewOp<IREESeq::HL::CopyOp>( - storeOp, src, /*srcIndices=*/emptyArrayMemref, storeOp.getMemRef(), - /*dstIndices=*/emptyArrayMemref, /*lengths=*/emptyArrayMemref); - - return matchSuccess(); - } -}; - -} // namespace - -void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns, - MLIRContext *context) { - patterns.insert< - // Control flow. - CallOpLowering, CallIndirectOpLowering, ReturnOpLowering, - BranchOpLowering, CondBranchOpLowering, - // Memory management. - AllocOpLowering, DeallocOpLowering, LoadOpLowering, StoreOpLowering>( - context); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp deleted file mode 100644 index 658120b..0000000 --- a/iree/compiler/Transforms/Sequencer/LowerToSequencerDialect.cpp +++ /dev/null
@@ -1,62 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Sequencer/HLDialect.h" -#include "iree/compiler/IR/Sequencer/LLDialect.h" -#include "iree/compiler/Transforms/Rewrites.h" -#include "iree/compiler/Transforms/Sequencer/Rewrites.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" - -namespace mlir { -namespace iree_compiler { -namespace { - -class LowerToSequencerDialectPass - : public FunctionPass<LowerToSequencerDialectPass> { - public: - void runOnFunction() override { - OwningRewritePatternList patterns; - auto* ctx = &getContext(); - xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, ctx); - xla_hlo::PopulateXlaToStdPatterns(&patterns, ctx); - populateLowerStdToIreePatterns(patterns, ctx); - populateLowerStdToSequencerPatterns(patterns, ctx); - populateLowerXlaToIreePatterns(patterns, ctx); - populateLowerXlaToSequencerPatterns(patterns, ctx); - - ConversionTarget target(getContext()); - target.addLegalDialect<IREEHLSequencerDialect, IREEDialect>(); - target.addLegalOp<FuncOp>(); - if (failed(applyFullConversion(getFunction(), target, patterns))) { - return signalPassFailure(); - } - } -}; - -} // namespace - -std::unique_ptr<OpPassBase<FuncOp>> createLowerToSequencerDialectPass() { - return std::make_unique<LowerToSequencerDialectPass>(); -} - -static PassRegistration<LowerToSequencerDialectPass> pass( - "lower-to-iree-sequencer", "Convert all ops to the IREE sequencer dialect"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp b/iree/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp deleted file mode 100644 index 785015c..0000000 --- a/iree/compiler/Transforms/Sequencer/LowerXLAToSequencerDialect.cpp +++ /dev/null
@@ -1,224 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/IR/Dialect.h" -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Sequencer/HLDialect.h" -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Transforms/ConversionUtils.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/OpCreationUtils.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// TODO(suderman): tablegen this? or something a bit more flexible. - -#define UNARY_OP_LOWERING(XlaOpType, IREEOpType) \ - struct XlaOpType##Lowering \ - : public UnaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ - using UnaryOpLowering::UnaryOpLowering; \ - }; - -#define TERNARY_OP_LOWERING(XlaOpType, IREEOpType) \ - struct XlaOpType##Lowering \ - : public TernaryOpLowering<xla_hlo::XlaOpType, IREEOpType> { \ - using TernaryOpLowering::TernaryOpLowering; \ - }; - -UNARY_OP_LOWERING(CopyOp, IREESeq::HL::CloneOp); - -#undef UNARY_OP_LOWERING -#undef TERNARY_OP_LOWERING - -template <typename T> -static Operation *createShapeTargetingOp(ConversionPatternRewriter &rewriter, - Location loc, Value *input, - MemRefType targetType) { - auto shapeOp = createArrayConstant(rewriter, loc, targetType.getShape()); - return rewriter.create<T>(loc, targetType, input, shapeOp); -} - -static Value *inputAsMemref(ConversionPatternRewriter &rewriter, Operation *op, - Value *tensor) { - return wrapAsMemRef(loadAccessValue(op->getLoc(), tensor, rewriter), op, - rewriter); -} - -template <typename SrcOp> -class XlaOpLowering : public OpConversionPattern<SrcOp> { - public: - using OpConversionPattern<SrcOp>::OpConversionPattern; - - PatternMatchResult matchAndRewrite( - SrcOp op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto srcOp = cast<SrcOp>(op); - - SmallVector<Value *, 4> memrefOperands; - for (auto operand : operands) { - memrefOperands.push_back(inputAsMemref(rewriter, op, operand)); - } - - auto dstOp = rewriteInternal(&srcOp, memrefOperands, rewriter); - rewriter.replaceOp(op, wrapAsTensor(dstOp->getResult(0), srcOp, rewriter)); - return this->matchSuccess(); - } - - protected: - virtual Operation *rewriteInternal( - SrcOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("unimplemented rewrite, did you mean rewriteTerminator?"); - } -}; - -struct ConcatOpLowering : public XlaOpLowering<xla_hlo::ConcatenateOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::ConcatenateOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto finalType = convertTypeToMemRef(*op); - - return rewriter.create<IREESeq::HL::ConcatOp>( - op->getLoc(), finalType, operands, - rewriter.getI32IntegerAttr(op->dimension().getZExtValue())); - } -}; - -struct DynamicUpdateSliceLowering - : public XlaOpLowering<xla_hlo::DynamicUpdateSliceOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::DynamicUpdateSliceOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - auto operand = operands[0]; - auto update = operands[1]; - - auto updateType = update->getType().cast<ShapedType>(); - Value *lengthConstant = - createArrayConstant(rewriter, op->getLoc(), updateType.getShape()); - - auto startIndices = makeArrayRef(operands).drop_front(2); - const int rank = startIndices.size(); - llvm::SmallVector<Value *, 4> valuesToConcat; - valuesToConcat.reserve(startIndices.size()); - auto type = getElementTypeOrSelf(startIndices.front()); - - // To generate the offset matrix we need to convert the variadic tensors - // into a reshaped and concated value. - for (auto index : startIndices) { - auto reshapedIndex = rewriter.create<IREESeq::HL::ReshapeOp>( - op->getLoc(), MemRefType::get({1}, type), index, - createArrayConstant(rewriter, op->getLoc(), {1})); - valuesToConcat.push_back(reshapedIndex); - } - - auto dstOffset = rewriter - .create<IREESeq::HL::ConcatOp>( - op->getLoc(), MemRefType::get({rank}, type), - valuesToConcat, rewriter.getI32IntegerAttr(0)) - .getResult(); - - llvm::SmallVector<int64_t, 4> zero_offset; - zero_offset.resize(updateType.getRank(), 0); - auto srcOffset = createArrayConstant(rewriter, op->getLoc(), zero_offset); - - auto copiedOperand = rewriter.create<IREESeq::HL::CloneOp>( - op->getLoc(), operand->getType(), operand); - - rewriter - .create<IREESeq::HL::CopyOp>(op->getLoc(), update, srcOffset, - copiedOperand, dstOffset, lengthConstant) - .getOperation(); - - return copiedOperand; - } -}; - -struct SliceLowering : public XlaOpLowering<xla_hlo::SliceOp> { - using XlaOpLowering::XlaOpLowering; - Operation *rewriteInternal( - xla_hlo::SliceOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - // XLA slice has value semantics, whereas the IREE slice creates a view. We - // lower it to a copy if all strides are one which may be transformed to a - // slice by later optimizations. - auto isNotOne = [](APInt stride) { return stride != 1; }; - if (llvm::any_of(op->strides(), isNotOne)) { - op->emitRemark() << "Could not lower slice op with non-singular strides"; - return nullptr; - } - - auto finalType = convertTypeToMemRef(*op); - - auto src = operands[0]; - std::vector<Value *> dim_pieces; - auto dst = rewriter.create<IREESeq::HL::AllocHeapOp>(op->getLoc(), - finalType, dim_pieces); - auto srcIndices = - rewriter.create<IREE::ConstantOp>(op->getLoc(), op->start_indices()); - auto lengths = - createArrayConstant(rewriter, op->getLoc(), finalType.getShape()); - - llvm::SmallVector<int64_t, 4> zero_offset; - zero_offset.resize(finalType.getRank(), 0); - auto dstIndices = createArrayConstant(rewriter, op->getLoc(), zero_offset); - - rewriter.create<IREESeq::HL::CopyOp>(op->getLoc(), src, srcIndices, dst, - dstIndices, lengths); - return dst; - } -}; - -struct ReshapeOpLowering : public XlaOpLowering<xla_hlo::ReshapeOp> { - using XlaOpLowering::XlaOpLowering; - - Operation *rewriteInternal( - xla_hlo::ReshapeOp *op, ArrayRef<Value *> operands, - ConversionPatternRewriter &rewriter) const override { - return createShapeTargetingOp<IREESeq::HL::ReshapeOp>( - rewriter, op->getLoc(), operands[0], convertTypeToMemRef(*op)); - } -}; - -} // namespace - -void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx) { - patterns.insert<ConcatOpLowering, CopyOpLowering, DynamicUpdateSliceLowering, - ReshapeOpLowering, SliceLowering>(ctx); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp b/iree/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp deleted file mode 100644 index fff3e8c..0000000 --- a/iree/compiler/Transforms/Sequencer/OutlineDispatchRegions.cpp +++ /dev/null
@@ -1,242 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <utility> - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Utils/DispatchUtils.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/SetVector.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Inserts a load from a wrapped memref (as inserted via insertDispatcherStore). -// Returns the value in the original type. -Value *insertDispatcheeLoad(Operation *op, Type originalType, Value *value, - OpBuilder &builder) { - // If old value was a memref we don't need to change anything. - if (originalType.isa<MemRefType>()) { - return value; - } - - auto loadInputOp = - builder.create<IREE::LoadInputOp>(op->getLoc(), originalType, value); - value->replaceAllUsesWith(loadInputOp.getResult()); - loadInputOp.setOperand(value); - return loadInputOp.getResult(); -} - -// Marshals args and results as buffers for the given region. -// Beyond inserting the appropriate tensor-to-memref ops we avoid mutating the -// interior of the dispatch region as much as possible. -LogicalResult marshalDispatchSite(IREE::DispatchRegionOp regionOp) { - auto &entryBlock = regionOp.getBody().getBlocks().front(); - OpBuilder dispatcherBuilder(regionOp); - OpBuilder dispatcheeBuilder(&entryBlock, entryBlock.begin()); - - // Wrap input operands and unwrap in the entry block. - SmallVector<Value *, 8> newArgs; - for (int i = 0; i < regionOp.getNumArgOperands(); ++i) { - // Wrap the input outside of the region. - auto *blockArg = entryBlock.getArgument(i); - Type originalType = blockArg->getType(); - auto *originalArg = regionOp.getArgOperand(i); - auto *wrappedArg = - insertDispatcherStore(regionOp, originalArg, dispatcherBuilder); - newArgs.push_back(wrappedArg); - blockArg->setType(wrappedArg->getType()); - - // Unwrap the block arg value and replace all of the uses with the newly - // unwrapped value. - insertDispatcheeLoad(regionOp, originalType, blockArg, dispatcheeBuilder); - } - - // Allocate output arguments and replace the return values with those. - SmallVector<Type, 8> newResults; - SmallVector<std::pair<int, Value *>, 8> resultIndicesToOutputArgs; - SmallVector<int, 8> deadResultIndices; - SmallVector<std::pair<Value *, Value *>, 8> replacedResults; - for (int i = 0; i < regionOp.getNumResults(); ++i) { - auto *result = regionOp.getResult(i); - auto convertedType = convertTypeToMemRef(result->getType()); - - // Allocate output buffer in the dispatcher to pass in to the region. - Value *allocatedValue = allocateDispatchOutputBuffer( - regionOp.getLoc(), convertedType, dispatcherBuilder); - if (!allocatedValue) { - regionOp.emitError("unable to allocate result value"); - return failure(); - } - newArgs.push_back(allocatedValue); - - auto *newBlockArg = entryBlock.addArgument(allocatedValue->getType()); - resultIndicesToOutputArgs.push_back({i, newBlockArg}); - - // NOTE: right now we always replace results. If we want to allow return - // values we can avoid killing them here. - deadResultIndices.push_back(i); - replacedResults.push_back({result, allocatedValue}); - } - - // Remove dead results from return statements. - regionOp.walk([&](IREE::ReturnOp returnOp) { - // Replace the results we were returning with stores to output arguments. - OpBuilder builder(returnOp); - for (auto resultToArg : resultIndicesToOutputArgs) { - auto *value = returnOp.getOperand(resultToArg.first); - auto *outputArg = resultToArg.second; - builder.create<IREE::StoreOutputOp>(returnOp.getLoc(), value, outputArg); - } - - // Filter out the results that are now dead. - SmallVector<Value *, 8> newOperands(returnOp.getOperands()); - for (int i = deadResultIndices.size() - 1; i >= 0; --i) { - newOperands.erase(newOperands.begin() + deadResultIndices[i]); - } - returnOp.getOperation()->setOperands(newOperands); - }); - - // Clone the region op with the new args/results. - auto newRegionOp = dispatcherBuilder.create<IREE::DispatchRegionOp>( - regionOp.getLoc(), newResults, regionOp.getWorkload(), newArgs); - newRegionOp.getBody().takeBody(regionOp.getBody()); - - // Marshal back the results by replacing uses of the original with loads from - // the new output arg. - for (auto &it : replacedResults) { - insertDispatcherLoad(regionOp, it.first, it.second, dispatcherBuilder); - } - - // Remove original region. - regionOp.erase(); - - return success(); -} - -// Converts a dispatch_region into a dispatch to the outlined region function. -LogicalResult convertToDispatchOp(IREE::DispatchRegionOp regionOp, - IREE::MultiArchExecutableOp executable, - FuncOp entryPoint) { - // Insert at the same place as the original region. - OpBuilder dispatcherBuilder(regionOp); - - // Ensure workload is a memref. - auto *workload = - wrapAsMemRef(regionOp.getWorkload(), regionOp, dispatcherBuilder); - - // Create the dispatch op to the executable function. - SmallVector<Value *, 8> operandValues(regionOp.getArgOperands()); - auto dispatchOp = dispatcherBuilder.create<IREESeq::HL::DispatchOp>( - regionOp.getLoc(), executable.getName(), entryPoint.getName(), workload, - entryPoint.getType().getResults(), operandValues); - - // Replace uses of the existing results with the new results. - for (int i = 0; i < regionOp.getNumResults(); ++i) { - regionOp.getResult(i)->replaceAllUsesWith(dispatchOp.getResult(i)); - } - - // Erase original region. - regionOp.erase(); - - return success(); -} - -// Outlines a dispatch region into an iree.multi_arch_executable. -LogicalResult outlineDispatchRegion(IREE::DispatchRegionOp regionOp, - int outlinedRegionOrdinal) { - // Build function type matching 1:1 with the region signature. - SmallVector<Type, 8> operandTypes; - for (auto *arg : regionOp.getArgOperands()) { - operandTypes.push_back(arg->getType()); - } - SmallVector<Type, 8> resultTypes(regionOp.getResultTypes()); - auto functionType = - FunctionType::get(operandTypes, resultTypes, regionOp.getContext()); - - // Create the executable with the region cloned into it. - IREE::MultiArchExecutableOp multiArchExecutable; - FuncOp outlinedFunc; - std::tie(multiArchExecutable, outlinedFunc) = createRegionExecutable( - regionOp, functionType, - "_dispatch_" + std::to_string(outlinedRegionOrdinal)); - outlinedFunc.setAttr("iree.executable.export", - UnitAttr::get(regionOp.getContext())); - - // Finally convert the dispatch region into a dispatch to the outlined func. - return convertToDispatchOp(regionOp, multiArchExecutable, outlinedFunc); -} - -} // namespace - -class OutlineDispatchRegionsPass - : public ModulePass<OutlineDispatchRegionsPass> { - public: - void runOnModule() override { - auto module = getModule(); - - ModuleManager moduleManager(module); - auto funcs = module.getOps<FuncOp>(); - SmallVector<FuncOp, 4> funcOps(funcs.begin(), funcs.end()); - for (auto func : funcOps) { - // Perform marshaling of the dispatcher and dispatchee I/O. - // This inserts the required stores and loads to make everything memrefs - // and adds the iree.load_input/iree.store_output ops to the dispatchee. - if (func.walk([&](IREE::DispatchRegionOp op) { - if (failed(marshalDispatchSite(op))) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }) - .wasInterrupted()) { - return signalPassFailure(); - } - - // Outline all of the iree.dispatch_region ops in this function. - SmallVector<IREE::DispatchRegionOp, 8> dispatchRegionOps; - func.walk( - [&](IREE::DispatchRegionOp op) { dispatchRegionOps.push_back(op); }); - for (int i = 0; i < dispatchRegionOps.size(); ++i) { - if (failed(outlineDispatchRegion(dispatchRegionOps[i], i))) { - return signalPassFailure(); - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createOutlineDispatchRegionsPass() { - return std::make_unique<OutlineDispatchRegionsPass>(); -} - -static PassRegistration<OutlineDispatchRegionsPass> pass( - "iree-outline-dispatch-regions", - "Outlines dispatch regions into standalone functions"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp b/iree/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp deleted file mode 100644 index 9636d1b..0000000 --- a/iree/compiler/Transforms/Sequencer/OutlineReductionRegions.cpp +++ /dev/null
@@ -1,307 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <utility> - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/Sequencer/HLOps.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Utils/DispatchUtils.h" -#include "iree/compiler/Utils/MemRefUtils.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Determines the shapes involved with reducing this dimension. -SmallVector<int64_t, 4> calculateResultShape(Value *input, - int windowDimension) { - SmallVector<int64_t, 4> resultShape; - for (auto it : - llvm::enumerate(input->getType().cast<ShapedType>().getShape())) { - if (it.index() != windowDimension) { - resultShape.push_back(it.value()); - } - } - return resultShape; -} - -// Creates an executable that holds the given elemental reduction region. -// The executable will have an entry point taking the specified reduction values -// and writing the results to output arguments. -std::pair<IREE::MultiArchExecutableOp, FuncOp> createReductionExecutable( - IREE::ReductionRegionOp regionOp, int outlinedRegionOrdinal, - int separatedReductionIndex, int reductionDimension, - SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs) { - Builder builder(regionOp.getContext()); - - // Build function type matching 1:1 with the region signature. - SmallVector<Type, 8> elementalOperandTypes; - SmallVector<Type, 8> elementalResultTypes; - for (auto *arg : regionOp.getInitialValueOperands()) { - // (in0, in1) -> out0 - elementalOperandTypes.push_back(arg->getType()); - elementalOperandTypes.push_back(arg->getType()); - elementalResultTypes.push_back(arg->getType()); - } - auto elementalFunctionType = FunctionType::get( - elementalOperandTypes, elementalResultTypes, regionOp.getContext()); - - // Create the executable with the region cloned into it. - IREE::MultiArchExecutableOp multiArchExecutable; - FuncOp elementalFunc; - std::tie(multiArchExecutable, elementalFunc) = createRegionExecutable( - regionOp, elementalFunctionType, - "_reduce_" + std::to_string(outlinedRegionOrdinal) + "_dim_" + - std::to_string(separatedReductionIndex)); - - // Create a new entry point that we can use with the signature for this - // dimension. - SmallVector<Type, 8> allOperandTypes; - auto inputTypes = - llvm::map_range(inputs, [](Value *value) { return value->getType(); }); - allOperandTypes.append(inputTypes.begin(), inputTypes.end()); - auto initialValueTypes = llvm::map_range( - initialValues, [](Value *value) { return value->getType(); }); - allOperandTypes.append(initialValueTypes.begin(), initialValueTypes.end()); - for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) { - auto shapedType = resultType.value().cast<ShapedType>(); - allOperandTypes.push_back(MemRefType::get( - calculateResultShape(inputs[resultType.index()], reductionDimension), - shapedType.getElementType())); - } - auto entryFuncType = FunctionType::get(allOperandTypes, ArrayRef<Type>{}, - regionOp.getContext()); - auto entryFunc = - FuncOp::create(regionOp.getLoc(), - (elementalFunc.getName() + "_entry").str(), entryFuncType); - entryFunc.setAttr("iree.executable.export", - UnitAttr::get(regionOp.getContext())); - elementalFunc.getOperation()->getBlock()->push_back(entryFunc); - entryFunc.getOperation()->moveBefore(elementalFunc); - entryFunc.setAttr("iree.executable.reduction", - UnitAttr::get(regionOp.getContext())); - entryFunc.setAttr("iree.executable.reduction.apply", - builder.getSymbolRefAttr(elementalFunc)); - - return {multiArchExecutable, entryFunc}; -} - -// Converts a reduction_region into a dispatch to the outlined region function -// for a single reduction dimension. -// Returns the results of the reduction or empty if the construction fails. -SmallVector<Value *, 4> convertToDispatchOp( - IREE::ReductionRegionOp regionOp, IREE::MultiArchExecutableOp executable, - FuncOp entryFunc, int reductionDimension, - SmallVector<Value *, 4> initialValues, SmallVector<Value *, 4> inputs, - OpBuilder &dispatcherBuilder) { - // Allocate output args and replace the return values with those. - SmallVector<Value *, 4> resultValues; - for (auto resultType : llvm::enumerate(regionOp.getResultTypes())) { - // Allocate output buffer in the dispatcher to pass in to the region. - auto shapedType = resultType.value().cast<ShapedType>(); - Value *allocatedValue = allocateDispatchOutputBuffer( - regionOp.getLoc(), - MemRefType::get(calculateResultShape(inputs[resultType.index()], - reductionDimension), - shapedType.getElementType()), - dispatcherBuilder); - if (!allocatedValue) { - regionOp.emitError("unable to allocate result value"); - return {}; - } - resultValues.push_back(allocatedValue); - } - - // Calculate workload from the result shape. - auto *workload = - wrapAsMemRef(calculateWorkload(regionOp, resultValues.front()), regionOp, - dispatcherBuilder); - - // Create the reduce op to the executable function. - std::vector<Value *> allOperands; - allOperands.insert(allOperands.end(), inputs.begin(), inputs.end()); - allOperands.insert(allOperands.end(), initialValues.begin(), - initialValues.end()); - allOperands.insert(allOperands.end(), resultValues.begin(), - resultValues.end()); - dispatcherBuilder.create<IREESeq::HL::DispatchOp>( - regionOp.getLoc(), executable.getName(), entryFunc.getName(), workload, - ArrayRef<Type>{}, allOperands); - - return resultValues; -} - -// Outlines a reduction region into one or more iree.multi_arch_executables. -// This separates the reduction into multiple dispatches, one for each reduction -// dimension (thankfully XLA's operation semantics state this is ok). We then -// special case the first dispatch such that it takes the constant initial -// values so that we don't have to materialize a buffer for them. -LogicalResult outlineReductionRegion(IREE::ReductionRegionOp regionOp, - int outlinedRegionOrdinal) { - // Insert at the same place as the original region. - OpBuilder dispatcherBuilder(regionOp); - - // Wrap input operands in memrefs. - SmallVector<Value *, 4> initialValues{llvm::map_range( - regionOp.getInitialValueOperands(), [&](Value *originalArg) { - return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder); - })}; - SmallVector<Value *, 4> temps{ - llvm::map_range(regionOp.getReductionOperands(), [&](Value *originalArg) { - return insertDispatcherStore(regionOp, originalArg, dispatcherBuilder); - })}; - - // Create one dispatch per dimension being reduced. - // We'll do this by chaining the original input through with the temporary - // reduction results. The results we end up with will be the originally - // requested shape and we can just substitute them. - if (regionOp.isWindowed()) { - auto windowDimensions = regionOp.window_dimensions().getValue(); - auto windowStrides = regionOp.window_strides().getValue(); - auto baseDilations = regionOp.base_dilations().getValue(); - auto windowDilations = regionOp.window_dilations().getValue(); - SmallVector<std::tuple<int64_t, int64_t, int64_t, int64_t>, 4> - sortedWindowAttrs; - for (uint64_t i = 0; i < windowDimensions.getNumElements(); ++i) { - int64_t windowDimension = - windowDimensions.getValue<IntegerAttr>({i}).getInt(); - int64_t windowStride = windowStrides.getValue<IntegerAttr>({i}).getInt(); - int64_t baseDilation = baseDilations.getValue<IntegerAttr>({i}).getInt(); - int64_t windowDilation = - windowDilations.getValue<IntegerAttr>({i}).getInt(); - sortedWindowAttrs.push_back( - {windowDimension, windowStride, baseDilation, windowDilation}); - } - llvm::sort(sortedWindowAttrs, - [](std::tuple<int64_t, int64_t, int64_t, int64_t> a, - std::tuple<int64_t, int64_t, int64_t, int64_t> b) { - return std::get<0>(a) - std::get<0>(b); - }); - for (auto windowAttrs : llvm::enumerate(sortedWindowAttrs)) { - int64_t windowDimension = std::get<0>(windowAttrs.value()); - int64_t windowStride = std::get<1>(windowAttrs.value()); - int64_t baseDilation = std::get<2>(windowAttrs.value()); - int64_t windowDilation = std::get<3>(windowAttrs.value()); - IREE::MultiArchExecutableOp multiArchExecutable; - FuncOp entryFunc; - std::tie(multiArchExecutable, entryFunc) = createReductionExecutable( - regionOp, outlinedRegionOrdinal, windowAttrs.index(), windowDimension, - initialValues, temps); - entryFunc.setAttr("iree.executable.reduction.padding_mode", - dispatcherBuilder.getI32IntegerAttr( - regionOp.padding_mode().getValue())); - entryFunc.setAttr("iree.executable.reduction.window_dimension", - dispatcherBuilder.getI32IntegerAttr(windowDimension)); - entryFunc.setAttr("iree.executable.reduction.window_stride", - dispatcherBuilder.getI32IntegerAttr(windowStride)); - entryFunc.setAttr("iree.executable.reduction.base_dilation", - dispatcherBuilder.getI32IntegerAttr(baseDilation)); - entryFunc.setAttr("iree.executable.reduction.window_dilation", - dispatcherBuilder.getI32IntegerAttr(windowDilation)); - temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc, - windowDimension, initialValues, - std::move(temps), dispatcherBuilder); - if (temps.empty()) { - return regionOp.emitOpError() - << "Failed to construct reduction for windowed dimension " - << windowDimension; - } - } - } else { - auto dimensions = regionOp.dimensions().getValue(); - SmallVector<int64_t, 4> sortedDimensions; - for (uint64_t i = 0; i < dimensions.getNumElements(); ++i) { - sortedDimensions.push_back( - dimensions.getValue<IntegerAttr>({i}).getInt()); - } - llvm::sort(sortedDimensions, [](int64_t a, int64_t b) { return a - b; }); - for (auto dimension : llvm::enumerate(sortedDimensions)) { - IREE::MultiArchExecutableOp multiArchExecutable; - FuncOp entryFunc; - std::tie(multiArchExecutable, entryFunc) = createReductionExecutable( - regionOp, outlinedRegionOrdinal, dimension.index(), dimension.value(), - initialValues, temps); - entryFunc.setAttr("iree.executable.reduction.dimension", - dispatcherBuilder.getI32IntegerAttr(dimension.value())); - temps = convertToDispatchOp(regionOp, multiArchExecutable, entryFunc, - dimension.value(), initialValues, - std::move(temps), dispatcherBuilder); - if (temps.empty()) { - return regionOp.emitOpError() - << "Failed to construct reduction for dimension " - << dimension.value(); - } - } - } - for (auto it : llvm::enumerate(regionOp.getResults())) { - insertDispatcherLoad(regionOp, it.value(), temps[it.index()], - dispatcherBuilder); - } - - // Erase original region. - regionOp.erase(); - - return success(); -} - -} // namespace - -class OutlineReductionRegionsPass - : public ModulePass<OutlineReductionRegionsPass> { - public: - void runOnModule() override { - auto module = getModule(); - - ModuleManager moduleManager(module); - auto funcs = module.getOps<FuncOp>(); - SmallVector<FuncOp, 4> funcOps(funcs.begin(), funcs.end()); - for (auto func : funcOps) { - // Outline all of the iree.reduction_region ops in this function. - std::vector<IREE::ReductionRegionOp> reductionRegionOps; - func.walk([&](IREE::ReductionRegionOp op) { - reductionRegionOps.push_back(op); - }); - for (int i = 0; i < reductionRegionOps.size(); ++i) { - if (failed(outlineReductionRegion(reductionRegionOps[i], i))) { - return signalPassFailure(); - } - } - } - } -}; - -std::unique_ptr<OpPassBase<ModuleOp>> createOutlineReductionRegionsPass() { - return std::make_unique<OutlineReductionRegionsPass>(); // NOLINT -} - -static PassRegistration<OutlineReductionRegionsPass> pass( - "iree-outline-reduction-regions", - "Outlines reduction regions into standalone functions"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp b/iree/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp deleted file mode 100644 index c541e4e..0000000 --- a/iree/compiler/Transforms/Sequencer/RematerializeDispatchConstants.cpp +++ /dev/null
@@ -1,149 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <algorithm> - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/DispatchUtils.h" -#include "llvm/Support/Debug.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Chosen randomly for now. We can measure and see what makes sense. -constexpr int64_t kMaxRematerializedConstantSizeInBytes = 1 * 1024; - -// Returns true if the constant value is under a certain threshold. -// This threshold is fixed for all backends as a value that is assumed small -// enough to be worth inlining possibly several times (at the cost of binary -// bloat). -bool isConstantSmall(ConstantOp constantOp) { - if (auto shapedType = constantOp.getType().dyn_cast<ShapedType>()) { - return shapedType.getSizeInBits() / 8 <= - kMaxRematerializedConstantSizeInBytes; - } - - // Assume anything unshaped is small. This may not always be true in custom - // dialects but is in std for now. - return true; -} - -// Returns true if the dispatch region is allowed to have constants inside. -// Certain regions that may get replaced or turned into kernel imports shouldn't -// have the constants moved into them as they'll just get lost. -bool canDispatchRegionContainConstants( - IREE::DispatchRegionOp dispatchRegionOp) { - for (auto &block : dispatchRegionOp.getBody()) { - for (auto &op : block) { - if (isa<xla_hlo::DotOp>(&op)) { - return false; - } - } - } - return true; -} - -// Rematerializes a constant inside of all dispatch regions that use it. -// Afterward the constant is only removed if there are no other uses within the -// non-dispatch block (such as by sequencer ops). -LogicalResult rematerializeConstantInDispatchRegions(ConstantOp constantOp) { - Value *constantValue = constantOp.getResult(); - SmallVector<IREE::DispatchRegionOp, 4> usingRegionOps; - for (auto *user : constantValue->getUsers()) { - if (auto dispatchRegionOp = dyn_cast<IREE::DispatchRegionOp>(user)) { - // Ensure this isn't just the workload and is used as an arg. - if (std::find(dispatchRegionOp.arg_operand_begin(), - dispatchRegionOp.arg_operand_end(), - constantValue) != dispatchRegionOp.arg_operand_end()) { - if (canDispatchRegionContainConstants(dispatchRegionOp)) { - usingRegionOps.push_back(dispatchRegionOp); - } - } - } - } - for (auto &dispatchRegionOp : usingRegionOps) { - if (failed(inlineDispatchRegionOperandsUsingValue(dispatchRegionOp, - constantValue))) { - return failure(); - } - } - - // Remove if there are no other uses within the block. - if (constantOp.use_empty()) { - constantOp.erase(); - } - - return success(); -} - -} // namespace - -// Finds constant arguments to dispatch regions that are too small to be worth -// putting into constant pools. This prevents things like a CSE'd scalar -// constant of 0.0 being passed by reference to a bunch of regions. Later -// backend-specific passes running on the dispatch regions may also be able to -// improve their constant propagation chances by having the full constant value -// available. -// -// Note that this currently only operates at the block level. Constants that are -// pushed across branches are assumed to have been rematerialized within blocks -// already, but if that isn't the case then this pass can be extended to do -// that. -class RematerializeDispatchConstantsPass - : public FunctionPass<RematerializeDispatchConstantsPass> { - public: - void runOnFunction() override { - for (auto &block : getFunction()) { - SmallVector<ConstantOp, 8> smallConstantOps; - for (auto constantOp : block.getOps<ConstantOp>()) { - if (isConstantSmall(constantOp)) { - smallConstantOps.push_back(constantOp); - } - } - // Note: we iterate in reverse so that the rematerialized constants appear - // in the same order they did originally (as insertion is at the top). - for (auto constantOp : llvm::reverse(smallConstantOps)) { - if (failed(rematerializeConstantInDispatchRegions(constantOp))) { - return signalPassFailure(); - } - } - } - } -}; - -std::unique_ptr<OpPassBase<FuncOp>> createRematerializeDispatchConstantsPass() { - return std::make_unique<RematerializeDispatchConstantsPass>(); -} - -static PassRegistration<RematerializeDispatchConstantsPass> pass( - "iree-rematerialize-dispatch-constants", - "Rematerializes small previously-CSE'd constants into dispatch regions."); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Transforms/Sequencer/Rewrites.h b/iree/compiler/Transforms/Sequencer/Rewrites.h deleted file mode 100644 index 87fa105..0000000 --- a/iree/compiler/Transforms/Sequencer/Rewrites.h +++ /dev/null
@@ -1,40 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_ -#define IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_ - -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/PatternMatch.h" - -namespace mlir { -namespace iree_compiler { - -// Adds rewrite patterns for lowering IREE Sequencer HL ops (iree_hl_seq.*) -// to LL ops (iree_ll_seq.*). -void populateSequencerLoweringPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx); - -// Adds rewrite patterns for lowering xla_hlo ops to Sequencer HL ops. -void populateLowerXlaToSequencerPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx); - -// Adds rewrite patterns for lowering standard ops to Sequencer HL ops. -void populateLowerStdToSequencerPatterns(OwningRewritePatternList &patterns, - MLIRContext *ctx); -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSFORMS_SEQUENCER_REWRITES_H_
diff --git a/iree/compiler/Transforms/Sequencer/test/BUILD b/iree/compiler/Transforms/Sequencer/test/BUILD deleted file mode 100644 index 74a48a3..0000000 --- a/iree/compiler/Transforms/Sequencer/test/BUILD +++ /dev/null
@@ -1,16 +0,0 @@ -# Tests specific to the sequencer. - -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/Transforms/test/BUILD b/iree/compiler/Transforms/test/BUILD deleted file mode 100644 index 79700a8..0000000 --- a/iree/compiler/Transforms/test/BUILD +++ /dev/null
@@ -1,16 +0,0 @@ -# Tests for common transforms. - -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/Interpreter/BUILD b/iree/compiler/Translation/Interpreter/BUILD deleted file mode 100644 index 8dff827..0000000 --- a/iree/compiler/Translation/Interpreter/BUILD +++ /dev/null
@@ -1,31 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Interpreter", - srcs = ["InterpreterExecutableTranslation.cpp"], - hdrs = ["InterpreterExecutableTranslation.h"], - deps = [ - "//iree/compiler/IR", - "//iree/compiler/IR/Interpreter", - "//iree/compiler/Serialization", - "//iree/compiler/Transforms", - "//iree/compiler/Transforms/Interpreter", - "//iree/compiler/Utils", - "//iree/schemas", - "@com_github_google_flatbuffers//:flatbuffers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", - "@local_config_mlir//:Translation", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", - ], - alwayslink = 1, -)
diff --git a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp b/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp deleted file mode 100644 index d21d6e0..0000000 --- a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.cpp +++ /dev/null
@@ -1,289 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h" - -#include <cstdint> -#include <iostream> -#include <vector> - -#include "flatbuffers/flatbuffers.h" -#include "flatbuffers/minireflect.h" -#include "iree/compiler/IR/ConfigOps.h" -#include "iree/compiler/IR/Interpreter/OpWriters.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Serialization/VMFunctionBuilder.h" -#include "iree/compiler/Serialization/VMFunctionTableBuilder.h" -#include "iree/compiler/Serialization/VMModuleBuilder.h" -#include "iree/compiler/Transforms/Interpreter/Passes.h" -#include "iree/compiler/Transforms/Passes.h" -#include "iree/compiler/Utils/Macros.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "iree/compiler/Utils/TranslationUtils.h" -#include "iree/schemas/executable_def_generated.h" -#include "iree/schemas/module_def_generated.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Module.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Translation.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Builds a pass pipeline that optimizes and legalizes the module to the form -// expected by translation. -void buildLegalizeInputPassPipeline(PassManager *passManager) { - // Standard passes that shake out a lot of garbage. - // Some may have been run prior to translation but this ensures we are always - // in a known state. - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createLoopFusionPass()); - passManager->addPass(createLoopInvariantCodeMotionPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createSimplifyAffineStructuresPass()); - passManager->addPass(createCSEPass()); - passManager->addPass(createCanonicalizerPass()); - - // Eliminate ops we don't care about based on a lack of side-effects. - // IREE does not guarantee exception/error behavior of dead ops. - passManager->addPass(createAggressiveOpEliminationPass()); - - // Expand uses of tuples into independent args/results. - passManager->addPass(createConvertFromTupleCallingConventionPass()); - passManager->addPass(createCanonicalizerPass()); -} - -// Builds a pass pipeline that converts functions to the iree_hl_interp dialect. -void buildInterpreterConversionPassPipeline(PassManager *passManager) { - // We don't need the IREE binding ops anymore, as we match the calling - // convention exactly (we're the same VM). - passManager->addPass(createMakeExecutableABIPass()); - - // Convert to the memref calling convention and optimize away as many - // loads and stores as we can prior to progressing. - passManager->addPass(createConvertToMemRefCallingConventionPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - - // Convert various dialects to IREE opcodes and cleanup leftover conversions. - passManager->addPass(createLowerToInterpreterDialectPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createAggressiveOpEliminationPass()); - - // Widen reduction functions (that have iree.executable.reduction attrs) to - // use their primitive IREE ops. - passManager->addPass(createExpandReductionsToOpsPass()); - - // Convert any uses of index to int32_t (as we explicitly don't want to - // support dynamic index width). - // This also looks for other weird types (i1, etc). - passManager->addPass(createLegalizeTypeStoragePass()); - - // Perform any last-minute optimizations to trim down the IR. - passManager->addPass(createAggressiveOpEliminationPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createLoopFusionPass()); - passManager->addPass(createLoopInvariantCodeMotionPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createCSEPass()); - passManager->addPass(createCanonicalizerPass()); - - // Drop all functions that are not reachable. - passManager->addPass(createDropUnreachableExecutableFunctionsPass()); -} - -// Builds a pass pipeline that lowers the iree_hl_interp dialect to the -// iree_ll_interp dialect and prepares for serialization. -void buildInterpreterLoweringPassPipeline(PassManager *passManager) { - // Lower iree_hl_interp -> iree_ll_interp. - passManager->addPass(createLowerInterpreterDialectPass()); - - // Assign ordinals used by the bytecode to reference executables and - // functions. - passManager->addPass(createAssignFunctionOrdinalsPass()); -} - -class InterpreterTranslator { - public: - explicit InterpreterTranslator(ExecutableTranslationOptions options) - : options_(options) {} - - const ExecutableTranslationOptions &options() const { return options_; } - - std::unique_ptr<iree::ExecutableDefT> translateExecutable( - IREE::ExecutableOp executableOp); - - private: - LogicalResult translateExecutableModule(IREE::ExecutableOp executableOp, - ModuleOp moduleOp, - VMModuleBuilder *moduleBuilder); - LogicalResult declareFunction(FuncOp function, - VMModuleBuilder *moduleBuilder); - LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder); - - ExecutableTranslationOptions options_; -}; - -std::unique_ptr<iree::ExecutableDefT> -InterpreterTranslator::translateExecutable(IREE::ExecutableOp executableOp) { - auto moduleOp = executableOp.getInnerModule(); - - // Run all passes to go from input to the iree_ll_interp dialect. - auto executableConversionPasses = - createPassManager(moduleOp.getContext(), options()); - buildLegalizeInputPassPipeline(executableConversionPasses.get()); - buildInterpreterConversionPassPipeline(executableConversionPasses.get()); - buildInterpreterLoweringPassPipeline(executableConversionPasses.get()); - if (failed(runPassPipeline(options(), executableConversionPasses.get(), - moduleOp))) { - executableOp.emitError() << "Failed to run conversion passes"; - return {}; - } - - // Build the module bytecode. - ::flatbuffers::FlatBufferBuilder fbb; - VMModuleBuilder moduleBuilder(&fbb); - if (failed( - translateExecutableModule(executableOp, moduleOp, &moduleBuilder))) { - executableOp.emitError() << "Failed to translate executable module"; - return {}; - } - auto moduleDef = moduleBuilder.Finish(); - if (moduleDef.IsNull()) { - moduleOp.emitError() << "Failed to verify completed module def"; - return {}; - } - auto bytes = moduleBuilder.Serialize(moduleDef); - if (bytes.empty()) { - moduleOp.emitError() << "Failed to serialize final module def"; - return {}; - } - - OpBuilder builder(executableOp); - executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>( - IREE::ExecutableFormat::IreeBytecode))); - - auto executableDef = std::make_unique<iree::ExecutableDefT>(); - executableDef->format = - static_cast<uint32_t>(IREE::ExecutableFormat::IreeBytecode); - executableDef->supported_features = iree::ExecutableFeature::kDebugging; - executableDef->contents = std::move(bytes); - return executableDef; -} - -LogicalResult InterpreterTranslator::translateExecutableModule( - IREE::ExecutableOp executableOp, ModuleOp moduleOp, - VMModuleBuilder *moduleBuilder) { - // Declare functions first so that we get stable indices during declaration - // (as call ops need to use the function table). - for (auto function : moduleOp.getOps<FuncOp>()) { - RETURN_IF_FAILURE(declareFunction(function, moduleBuilder)); - } - - // Define functions now that all functions have been declared. - for (auto function : moduleOp.getOps<FuncOp>()) { - RETURN_IF_FAILURE(defineFunction(function, moduleBuilder)); - } - - return success(); -} - -LogicalResult InterpreterTranslator::declareFunction( - FuncOp function, VMModuleBuilder *moduleBuilder) { - auto *functionTable = moduleBuilder->function_table(); - if (functionTable->IsFunctionDeclared(function)) { - // Already declared. - return success(); - } - - LinkageType linkageType; - if (function.isExternal()) { - linkageType = LinkageType::kImport; - } else if (function.getAttr("iree.executable.export")) { - linkageType = LinkageType::kExport; - } else { - linkageType = LinkageType::kInternal; - } - if (failed(functionTable->DeclareFunction(function, linkageType))) { - return function.emitError() << "Unable to declare function"; - } - - // Import functions must have their definition defined here so we get their - // type. Internal and export functions will be defined during conversion. - if (linkageType == LinkageType::kImport) { - VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), - moduleBuilder->fbb()); - auto functionOffset = functionBuilder.Finish(); - if (functionOffset.IsNull()) { - return function.emitError() - << "Failed to create import function bytecode"; - } - RETURN_IF_FAILURE( - functionTable->DefineFunction(function, functionOffset, {})); - } - - return success(); -} - -LogicalResult InterpreterTranslator::defineFunction( - FuncOp function, VMModuleBuilder *moduleBuilder) { - VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), - moduleBuilder->fbb()); - registerInterpreterCustomWriters(&functionBuilder); - RETURN_IF_FAILURE(functionBuilder.ConvertBytecode()); - auto functionOffset = functionBuilder.Finish(); - if (functionOffset.IsNull()) { - return function.emitError() << "Failed to serialize function"; - } - RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction( - function, functionOffset, functionBuilder.source_map())); - return success(); -} - -} // namespace - -llvm::Optional<ExecutableTranslationResult> -translateExecutableToInterpreterExecutable( - ArrayRef<IREE::ExecutableOp> executableOps, - ExecutableTranslationOptions options) { - InterpreterTranslator translator(options); - ExecutableTranslationResult translationResult; - for (auto executableOp : llvm::make_early_inc_range(executableOps)) { - auto executableDef = translator.translateExecutable(executableOp); - if (!executableDef) { - executableOp.emitError() << "Failed to translate one or more executables"; - return llvm::None; - } - translationResult.executable_defs.push_back(std::move(executableDef)); - } - return translationResult; -} - -static ExecutableTranslationRegistration - InterpreterExecutableTranslationRegistration( - "interpreter-bytecode", translateExecutableToInterpreterExecutable); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h b/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h deleted file mode 100644 index 41949c7..0000000 --- a/iree/compiler/Translation/Interpreter/InterpreterExecutableTranslation.h +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_ -#define IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_ - -#include <vector> - -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/TranslationUtils.h" -#include "mlir/IR/Module.h" - -namespace mlir { -namespace iree_compiler { - -// Translates an MLIR module into a bytecode interpreter executable. -// These executables are stored as IREE modules as defined in the -// https://github.com/google/iree/tree/master/iree/schemas/module_def.fbs -// schema. -llvm::Optional<ExecutableTranslationResult> -translateExecutableToInterpreterExecutable( - ArrayRef<IREE::ExecutableOp> executableOps, - ExecutableTranslationOptions options = {}); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_INTERPRETER_INTERPRETEREXECUTABLETRANSLATION_H_
diff --git a/iree/compiler/Translation/SPIRV/AffineExprCodegen.h b/iree/compiler/Translation/SPIRV/AffineExprCodegen.h deleted file mode 100644 index b21182b..0000000 --- a/iree/compiler/Translation/SPIRV/AffineExprCodegen.h +++ /dev/null
@@ -1,143 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- AffineExprCodegen.h -------------------------------------*- C++//-*-===// -// -// Code-generation for Affine Expression. -// -//===----------------------------------------------------------------------===// -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H -#define IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H - -#include "iree/compiler/Translation/SPIRV/XLAIndexPropagation.h" -#include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/IR/AffineExprVisitor.h" - -namespace mlir { -namespace iree_compiler { - -/// Codegenerator for affine expressions. -class AffineExprCodegen : public AffineExprVisitor<AffineExprCodegen, Value *> { - public: - explicit AffineExprCodegen(spirv::ModuleOp module, - IndexComputationCache &tensorIndices) - : builder(module.getContext()), - location(module.getLoc()), - tensorIndices(tensorIndices) {} - - Value *visitAddExpr(AffineBinaryOpExpr expr) { - auto operand1 = getValueInternal(expr.getLHS()); - auto operand2 = getValueInternal(expr.getRHS()); - return builder.create<spirv::IAddOp>(location, operand1, operand2); - } - Value *visitMulExpr(AffineBinaryOpExpr expr) { - auto operand1 = getValueInternal(expr.getLHS()); - auto operand2 = getValueInternal(expr.getRHS()); - return builder.create<spirv::IMulOp>(location, operand1, operand2); - } - Value *visitModExpr(AffineBinaryOpExpr expr) { - auto operand1 = getValueInternal(expr.getLHS()); - auto operand2 = getValueInternal(expr.getRHS()); - return builder.create<spirv::SModOp>(location, operand1, operand2); - } - Value *visitFloorDivExpr(AffineBinaryOpExpr expr) { - auto operand1 = getValueInternal(expr.getLHS()); - auto operand2 = getValueInternal(expr.getRHS()); - return builder.create<spirv::SDivOp>(location, operand1, operand2); - } - Value *visitCeilDivExpr(AffineBinaryOpExpr expr) { - // TODO(ravishankarm): Implement ceil div expr codegen. - llvm_unreachable("Unimplemented affine AffineCeilDivExpr codegen"); - return nullptr; - } - Value *visitConstantExpr(AffineConstantExpr expr) { - return builder.create<spirv::ConstantOp>( - location, builder.getIntegerType(32), - builder.getI32IntegerAttr(expr.getValue())); - } - Value *visitDimExpr(AffineDimExpr expr) { - return threadDimToDstValue.lookup(expr.getPosition()); - } - Value *visitSymbolExpr(AffineSymbolExpr expr) { - // TODO(ravishankarm): Implement symbol expr codegen. - llvm_unreachable("Unimplemented affine AffineSymbolExpr codegen"); - return nullptr; - } - - /// Set the value that contains the workitem ID along a particular - /// dimension. 0 -> x-dimension, 1 -> y-dimension, etc. - void setDimDstValue(unsigned dimID, Value *value) { - threadDimToDstValue[dimID] = value; - } - - /// Generates the scalar value for a affine expression. - Value *getValue(AffineExpr expr, OpBuilder::InsertPoint ip, Location loc) { - auto &val = exprToDstValue[expr]; - if (!val) { - location = loc; - builder.restoreInsertionPoint(ip); - val = visit(expr); - } - return val; - } - - /// Returns a list of indices of a particular tensor in the source dialect - /// needed within the dispatch function (obtained from the - /// IndexComputationCache) - SmallVector<AffineMap, 4> getIndices(Value *value) { - SmallVector<AffineMap, 4> indices; - for (auto &index : tensorIndices[value]) { - indices.push_back(index.first); - } - return indices; - } - - /// For a given tensor in the source dialect and index, return the index of - /// all operands needed to compute the result. - ArrayRef<AffineMap> getOperandIndices(Value *value, AffineMap index) { - return tensorIndices[value][index]; - } - - private: - /// Returns the Value corresponding to the AffineExpr `expr` by either - /// previously generated value for the same index, or by generating the value. - /// This version assumes the insertion point/Location has already been set. - Value *getValueInternal(AffineExpr expr) { - auto &val = exprToDstValue[expr]; - if (!val) { - val = visit(expr); - } - return val; - } - - OpBuilder builder; - - Location location; - - /// Map from launch dimension to scalar value. - DenseMap<unsigned, Value *> threadDimToDstValue; - - /// Cache of affine expression to scalar value. TODO(ravishankarm) : Might - /// need to be changed if we are handling control flow within the dispatch - /// function. - DenseMap<AffineExpr, Value *> exprToDstValue; - - /// Map from tensor value in source dialect to list of indices of the tensor - /// needed within a workitem to compute the results of the dispatch function. - IndexComputationCache &tensorIndices; -}; -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_AFFINEEXPRCODGEN_H
diff --git a/iree/compiler/Translation/SPIRV/BUILD b/iree/compiler/Translation/SPIRV/BUILD deleted file mode 100644 index b0ad00b..0000000 --- a/iree/compiler/Translation/SPIRV/BUILD +++ /dev/null
@@ -1,50 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "SPIRV", - srcs = [ - "AffineExprCodegen.h", - "EmbeddedKernels.cpp", - "IREEIndexComputation.cpp", - "IREEToSPIRV.cpp", - "IREEToSPIRVPass.cpp", - "IndexComputation.cpp", - "SPIRVExecutableTranslation.cpp", - "SPIRVLowering.cpp", - "SPIRVLowering.h", - "XLAIndexPropagation.cpp", - ], - hdrs = [ - "EmbeddedKernels.h", - "IREEIndexComputation.h", - "IREEToSPIRV.h", - "IREEToSPIRVPass.h", - "IndexComputation.h", - "SPIRVExecutableTranslation.h", - "XLAIndexPropagation.h", - ], - deps = [ - "//iree/compiler/IR", - "//iree/compiler/Translation/SPIRV/Kernels", - "//iree/compiler/Utils", - "//iree/schemas", - "//iree/schemas:spirv_executable_def_cc_fbs", - "@com_github_google_flatbuffers//:flatbuffers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:SPIRVDialect", - "@local_config_mlir//:SPIRVDialectRegistration", - "@local_config_mlir//:SPIRVSerialization", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", - "@local_config_mlir//:Translation", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - ], - alwayslink = 1, -)
diff --git a/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp b/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp deleted file mode 100644 index d50f3cf..0000000 --- a/iree/compiler/Translation/SPIRV/EmbeddedKernels.cpp +++ /dev/null
@@ -1,219 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Translation/SPIRV/EmbeddedKernels.h" - -#include "iree/compiler/Translation/SPIRV/Kernels/Kernels.h" -#include "iree/schemas/spirv_executable_def_generated.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Module.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Reads the SPIR-V code for the embedded kernel with the given file name. -// If the kernel under Kernels/ is 'matmul.comp' then |kernelName| would be -// 'matmul.spv' (because it's been compiled). -std::vector<uint32_t> readEmbeddedKernelCode(std::string kernelName) { - auto *fileToc = spirv_kernels::Kernels_create(); - for (int i = 0; i < spirv_kernels::Kernels_size(); ++i) { - if (std::strcmp(fileToc[i].name, kernelName.c_str()) == 0) { - std::vector<uint32_t> code; - code.resize(fileToc[i].size / 4); - std::memcpy(code.data(), fileToc[i].data, fileToc[i].size); - return code; - } - } - return {}; -} - -// Adds a storage buffer binding to the descriptor set layout. -void addDescriptorSetLayoutBinding(uint32_t binding, - iree::VkDescriptorSetLayoutDefT *dsl) { - auto bindingDef = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>(); - bindingDef->binding = binding; - bindingDef->descriptor_count = 1; - bindingDef->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER - bindingDef->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT - dsl->bindings.push_back(std::move(bindingDef)); -} - -// Adds a specialization map entry for |constant_id| set to a 4-byte int value. -void addSpecializationMapEntry( - uint32_t constant_id, uint32_t value, - iree::VkSpecializationInfoDefT *specializationInfoDef) { - auto specValue = std::make_unique<iree::VkSpecializationMapEntryDefT>(); - specValue->constant_id = constant_id; - specValue->uint32_value = value; - specializationInfoDef->map_entries.push_back(std::move(specValue)); -} - -LogicalResult buildReductionExecutable(IREE::ExecutableOp executableOp, - FuncOp entryFuncOp, - iree::SpirVExecutableDefT *out_def) { - auto funcType = entryFuncOp.getType(); - auto arg0 = funcType.getInput(0).cast<ShapedType>(); - if (!arg0.getElementType().isF32()) { - // When we do other types we'll need other shaders. - return entryFuncOp.emitOpError() - << "Only floating point reduction is implemented"; - } - - auto module = executableOp.getInnerModule(); - auto applyFuncAttr = entryFuncOp.getAttrOfType<SymbolRefAttr>( - "iree.executable.reduction.apply"); - auto applyFuncOp = module.lookupSymbol(applyFuncAttr.getValue()); - - // TODO(benvanik): specialize (template on shapes/types/etc). - std::string kernelName = "reduce_untiled.spv"; - llvm::Optional<uint32_t> operationId; - applyFuncOp->walk([&](Operation *op) { - if (isa<xla_hlo::AddOp>(op)) { - operationId = 0; - } else if (isa<xla_hlo::MaxOp>(op)) { - operationId = 1; - } else if (isa<xla_hlo::MinOp>(op)) { - operationId = 2; - } - }); - if (!operationId.hasValue()) { - applyFuncOp->dump(); - return applyFuncOp->emitOpError() << "Unsupported reduction operator"; - } - - out_def->tag = "__reduce__"; - out_def->entry_points = {"main"}; - - out_def->code = readEmbeddedKernelCode(kernelName); - - // arg0, arg1, ret0 - auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>(); - pipelineLayoutDef->buffer_binding_set = 0; - auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>(); - addDescriptorSetLayoutBinding(0, dsl.get()); - addDescriptorSetLayoutBinding(1, dsl.get()); - addDescriptorSetLayoutBinding(2, dsl.get()); - pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl)); - out_def->pipeline_layout = std::move(pipelineLayoutDef); - - // See the shader source for documentation on the values of A/B/C/R. - int64_t reductionDimension = - entryFuncOp - .getAttrOfType<IntegerAttr>("iree.executable.reduction.dimension") - .getInt(); - uint32_t r = arg0.getDimSize(reductionDimension); - uint32_t a = 1; - for (int i = 0; i < reductionDimension; ++i) { - a *= arg0.getDimSize(i); - } - uint32_t b = 1; - for (int i = reductionDimension + 1; i < arg0.getRank(); ++i) { - b *= arg0.getDimSize(i); - } - uint32_t c = b; - - auto specializationInfoDef = - std::make_unique<iree::VkSpecializationInfoDefT>(); - addSpecializationMapEntry(/*kOperationId*/ 100, operationId.getValue(), - specializationInfoDef.get()); - addSpecializationMapEntry(/*kA*/ 101, a, specializationInfoDef.get()); - addSpecializationMapEntry(/*kB*/ 102, b, specializationInfoDef.get()); - addSpecializationMapEntry(/*kC*/ 103, c, specializationInfoDef.get()); - addSpecializationMapEntry(/*kR*/ 104, r, specializationInfoDef.get()); - out_def->specialization_info = std::move(specializationInfoDef); - - return success(); -} - -// Builds a SPIR-V executable from a well-known matmul executable. -// |out_def| will be populated with all required information for serialization. -LogicalResult buildMatMulExecutable(IREE::ExecutableOp executableOp, - FuncOp entryFuncOp, xla_hlo::DotOp dotOp, - iree::SpirVExecutableDefT *out_def) { - auto arg0 = dotOp.getOperand(0)->getType().cast<ShapedType>(); - auto arg1 = dotOp.getOperand(1)->getType().cast<ShapedType>(); - - out_def->tag = "__matmul__"; - out_def->entry_points = {"main"}; - - // TODO(benvanik): specialize (template on shapes/types/etc). - out_def->code = readEmbeddedKernelCode("matmul.spv"); - - // arg0, arg1, ret0 - auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>(); - pipelineLayoutDef->buffer_binding_set = 0; - auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>(); - addDescriptorSetLayoutBinding(0, dsl.get()); - addDescriptorSetLayoutBinding(1, dsl.get()); - addDescriptorSetLayoutBinding(2, dsl.get()); - pipelineLayoutDef->descriptor_set_layouts.push_back(std::move(dsl)); - out_def->pipeline_layout = std::move(pipelineLayoutDef); - - // Shapes of [arg0, arg1, ret0]. - // arg0 = [b0, m, k] - // arg1 = [b0, k, n] - // ret0 = [b0, m, n] - // Note that we handle both batched (rank 3) and unbatched (rank 2). - uint32_t m = arg0.getRank() == 3 ? arg0.getDimSize(1) : arg0.getDimSize(0); - uint32_t k = arg0.getRank() == 3 ? arg0.getDimSize(2) : arg0.getDimSize(1); - uint32_t n = arg1.getRank() == 3 ? arg1.getDimSize(2) : arg1.getDimSize(1); - auto specializationInfoDef = - std::make_unique<iree::VkSpecializationInfoDefT>(); - addSpecializationMapEntry(/*kMatrixM*/ 100, m, specializationInfoDef.get()); - addSpecializationMapEntry(/*kMatrixK*/ 101, k, specializationInfoDef.get()); - addSpecializationMapEntry(/*kMatrixN*/ 102, n, specializationInfoDef.get()); - out_def->specialization_info = std::move(specializationInfoDef); - - return success(); -} - -} // namespace - -bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp, - iree::SpirVExecutableDefT *out_def) { - auto module = executableOp.getInnerModule(); - for (auto funcOp : module.getOps<FuncOp>()) { - if (funcOp.getAttr("iree.executable.reduction")) { - if (failed(buildReductionExecutable(executableOp, funcOp, out_def))) { - executableOp.emitOpError() << "Failed to splat in the reduction kernel"; - return false; - } - return true; - } - - for (auto &block : funcOp) { - for (auto &op : block) { - if (isa<xla_hlo::ConvOp>(&op)) { - executableOp.emitOpError() << "Conv not yet implemented"; - return false; - } else if (auto dotOp = dyn_cast_or_null<xla_hlo::DotOp>(&op)) { - if (failed(buildMatMulExecutable(executableOp, funcOp, dotOp, - out_def))) { - executableOp.emitOpError() - << "Failed to splat in the matmul kernel"; - return false; - } - return true; - } - } - } - } - return false; -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/EmbeddedKernels.h b/iree/compiler/Translation/SPIRV/EmbeddedKernels.h deleted file mode 100644 index 951a5ff..0000000 --- a/iree/compiler/Translation/SPIRV/EmbeddedKernels.h +++ /dev/null
@@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_ -#define IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_ - -#include "flatbuffers/flatbuffers.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/schemas/spirv_executable_def_generated.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { - -// Tries to match the |executableOp| against an embedded kernel and if matched -// will populate |out_def| with the kernel. -// Returns true if the kernel matched and was populated. -bool tryEmbeddedKernelRewrite(IREE::ExecutableOp executableOp, - iree::SpirVExecutableDefT* out_def); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_EMBEDDEDKERNELS_H_
diff --git a/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp b/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp deleted file mode 100644 index 54ad2f5..0000000 --- a/iree/compiler/Translation/SPIRV/IREEIndexComputation.cpp +++ /dev/null
@@ -1,107 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- IREEIndexComputation.cpp --------------------------------*- C++//-*-===// -// -// Implementaiton of Index Propagation for IREE statements that are used in -// dispatch functions. -// -//===----------------------------------------------------------------------===// -#include "iree/compiler/Translation/SPIRV/IREEIndexComputation.h" - -namespace mlir { -namespace iree_compiler { - -//===----------------------------------------------------------------------===// -// IREELoadInputOp -//===----------------------------------------------------------------------===// - -LogicalResult IREELoadIndexPropagation::propagateIndexMap( - Operation *operation, IndexComputationCache &indexMap) const { - auto loadOp = cast<IREE::LoadInputOp>(operation); - auto result = operation->getResult(0); - auto src = loadOp.src(); - auto resultType = result->getType().dyn_cast<RankedTensorType>(); - auto srcType = src->getType().dyn_cast<MemRefType>(); - if (!resultType || !srcType || resultType.getShape() != srcType.getShape()) { - return loadOp.emitError( - "mismatch in shape of the result tensor and source memref"); - } - // Initialize the storage for the src. - indexMap[src]; - for (auto &resultIndexMap : indexMap[operation->getResult(0)]) { - indexMap[src][resultIndexMap.first]; - resultIndexMap.second.push_back(resultIndexMap.first); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// IREEStoreOutputOp -//===----------------------------------------------------------------------===// - -LogicalResult IREEStoreIndexPropagation::propagateIndexMap( - Operation *operation, IndexComputationCache &indexMap) const { - auto storeOp = cast<IREE::StoreOutputOp>(operation); - auto src = storeOp.src(); - auto srcType = src->getType().dyn_cast<ShapedType>(); - if (!srcType || !srcType.hasStaticShape()) { - return storeOp.emitError( - "can only handle store with src being tensor of static shape"); - } - - SmallVector<int64_t, 3> launchSize; - if (failed(getLaunchSize(operation, launchSize))) { - return failure(); - } - - // The launch dimensions are [x, y, z] co-ordinates. The reverse of this is - // used to determine the location of the tensor element computed by a - // workitem. The choice is failry arbitrary but is done to enable the common - // case where consecutive workitems compute "logically" adjacent tensor - // elements. - Builder builder(storeOp.getContext()); - SmallVector<AffineExpr, 4> affineExprs; - int64_t numElements = 1; - for (size_t i = launchSize.size(); i > 0; --i) { - // If launchSize along any dimension is 1, just use 0 for the index. This is - // not just an optimization. If you have an output of type memref<f32> which - // is lowered to !spv.ptr<!spv.struct<f32>, StorageBuffer> with launchSize - // <1>, then spv.AccessChain requires the indices to be a constant. - if (launchSize[i - 1] == 1) { - affineExprs.push_back(builder.getAffineConstantExpr(0)); - } else { - affineExprs.push_back(builder.getAffineDimExpr(i - 1)); - } - numElements *= launchSize[i - 1]; - } - auto launchMap = AffineMap::get(launchSize.size(), 0, affineExprs); - - // The stored tensor can be a reshape of the launch dimension. It still - // retains the requirement that each workitem is computing a single element - // of the stored tensor. - AffineMap srcMap; - SmallVector<int64_t, 3> revLaunchSize(reverse(launchSize)); - if (numElements != srcType.getNumElements() || - failed(getReshapeOperandMap(builder, launchMap, revLaunchSize, - srcType.getShape(), srcMap))) { - return storeOp.emitError( - "unable to map from launch id to element to compute within a " - "workitem"); - } - indexMap[src][srcMap]; - return success(); -} -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IREEIndexComputation.h b/iree/compiler/Translation/SPIRV/IREEIndexComputation.h deleted file mode 100644 index 13d9216..0000000 --- a/iree/compiler/Translation/SPIRV/IREEIndexComputation.h +++ /dev/null
@@ -1,92 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- IREEIndexComputation.h ----------------------------------*- C++//-*-===// -// -// Index Propagation for IREE statements that are used in dispatch functions. -// -//===----------------------------------------------------------------------===// -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_H -#define IREE_COMPILER_TRANSLATION_SPIRV_H - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Translation/SPIRV/XLAIndexPropagation.h" -#include "mlir/IR/Function.h" - -namespace mlir { -namespace iree_compiler { - -/// Gets the launch size associated with the dispatch function that this op is -/// part of. -inline LogicalResult getLaunchSize(Operation *op, - SmallVectorImpl<int64_t> &launchSize) { - auto funcOp = op->getParentOfType<FuncOp>(); - if (!funcOp || !funcOp.getAttr("iree.executable.export")) { - return op->emitError( - "expected operation to be in dispatch function to get launch size"); - } - auto workloadAttr = - funcOp.getAttrOfType<DenseElementsAttr>("iree.executable.workload"); - if (!workloadAttr) { - op->emitError( - "unable to find workload size, missing attribute " - "iree.executable.workload in dispatch function"); - } - launchSize.clear(); - for (auto value : workloadAttr.getValues<APInt>()) { - launchSize.push_back(value.getSExtValue()); - } - // Drop trailing ones. - auto dropFrom = launchSize.size() - 1; - while (dropFrom > 0 && launchSize[dropFrom] == 1) { - --dropFrom; - } - if (dropFrom > 0) { - launchSize.erase(std::next(launchSize.begin(), dropFrom + 1), - launchSize.end()); - } - return success(); -} - -/// Index propagation for iree.load_input operation. This operation is -/// essentially a copy from a memref to a tensor. So just copy the index map to -/// the memref operand from the result tensor. -class IREELoadIndexPropagation final - : public IndexPropagationOp<IREE::LoadInputOp> { - public: - using IndexPropagationOp<IREE::LoadInputOp>::IndexPropagationOp; - - LogicalResult propagateIndexMap( - Operation *operation, IndexComputationCache &indexMap) const override; -}; - -/// Index propagation for iree.store_output operation. The launch size is -/// assumed to match the shape of the tensor that is being stored. This -/// operation acts as a seed for the index propogation. Each workitem is assumed -/// to compute a single element of this tensor. The range of the index map is -/// the reverse of the launch dimension. -class IREEStoreIndexPropagation final - : public IndexPropagationOp<IREE::StoreOutputOp> { - public: - using IndexPropagationOp<IREE::StoreOutputOp>::IndexPropagationOp; - - LogicalResult propagateIndexMap( - Operation *operation, IndexComputationCache &indexMap) const override; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_H
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRV.cpp b/iree/compiler/Translation/SPIRV/IREEToSPIRV.cpp deleted file mode 100644 index 909ce62..0000000 --- a/iree/compiler/Translation/SPIRV/IREEToSPIRV.cpp +++ /dev/null
@@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- IREEToSPIRV.cpp -----------------------------------------*- C++//-*-===// -// -// Translation of IREE statements in dispatch functions to SPIR-V. -// -//===----------------------------------------------------------------------===// -#include "iree/compiler/Translation/SPIRV/IREEToSPIRV.h" - -namespace mlir { -namespace iree_compiler { - -/// IREE::LoadInputOp is essentially a memcpy. Just update the `valueCache` with -/// the value of the operand. -LogicalResult IREELoadOpSPIRVLowering::lowerOperation( - Operation *op, OpBuilder &builder, AffineMap index, - ArrayRef<Value *> operands, ValueCache &valueCache) const { - auto loadOp = cast<IREE::LoadInputOp>(op); - auto result = loadOp.getResult(); - valueCache.setOperandDstValue(result, index, operands[0]); - return success(); -} - -/// IREE::StoreOp needs to write to the spv.globalVariable created for the -/// memref that holds the result of the dispatch function. -LogicalResult IREEStoreOpSPIRVLowering::lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { - auto storeOp = cast<IREE::StoreOutputOp>(op); - auto src = storeOp.src(); - auto indices = affineExprCodegen.getIndices(src); - if (indices.size() != 1) { - return storeOp.emitError( - "expected to compute a single element of the tensor that is stored " - "into the output memref"); - } - auto var = inputBuffers.lookup(storeOp.dst()); - if (!var) { - return storeOp.emitError( - "unable to find spv.globalVariable that corresponds to the dst memref"); - } - auto ptr = genPointerOffset(builder, storeOp.getLoc(), affineExprCodegen, - indices[0], var); - auto scalarValue = valueCache.getOperandDstValue(src, indices[0]); - builder.create<spirv::StoreOp>(storeOp.getLoc(), ptr, scalarValue, - /*memory_access = */ nullptr, - /*alignment = */ nullptr); - return success(); -} - -/// IREE::ReturnOp in dispatch functions lowered to SPIR-V should have no -/// operands. -LogicalResult IREEReturnOpSPIRVLowering::lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { - builder.create<spirv::ReturnOp>(op->getLoc()); - return success(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRV.h b/iree/compiler/Translation/SPIRV/IREEToSPIRV.h deleted file mode 100644 index 461e3f3..0000000 --- a/iree/compiler/Translation/SPIRV/IREEToSPIRV.h +++ /dev/null
@@ -1,69 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- IREEToSPIRV.h -------------------------------------------*- C++//-*-===// -// -// Translation of IREE statements in dispatch functions to SPIR-V. -// -//===----------------------------------------------------------------------===// -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H -#define IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Translation/SPIRV/SPIRVLowering.h" - -namespace mlir { -namespace iree_compiler { - -/// Translation of iree.load_input operation. -class IREELoadOpSPIRVLowering final - : public SPIRVOpLowering<IREE::LoadInputOp> { - public: - using SPIRVOpLowering<IREE::LoadInputOp>::SPIRVOpLowering; - - LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, ArrayRef<Value *> operands, - ValueCache &valueCache) const override; -}; - -/// Translation of iree.return operation. -class IREEReturnOpSPIRVLowering final : public SPIRVOpLowering<IREE::ReturnOp> { - public: - using SPIRVOpLowering<IREE::ReturnOp>::SPIRVOpLowering; - - LogicalResult lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &codegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override; -}; - -/// Translation of iree.store_output operation. -class IREEStoreOpSPIRVLowering final - : public SPIRVOpLowering<IREE::StoreOutputOp> { - public: - using SPIRVOpLowering<IREE::StoreOutputOp>::SPIRVOpLowering; - - LogicalResult lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &codegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_IREETOSPIRV_H
diff --git a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp b/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp deleted file mode 100644 index 7f1d75c..0000000 --- a/iree/compiler/Translation/SPIRV/IREEToSPIRVPass.cpp +++ /dev/null
@@ -1,196 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- IREEToSPIRVPass.cpp -------------------------------------*- C++//-*-===// -// -// Pass to translate iree executables for vulkan-spirv. -// -//===----------------------------------------------------------------------===// -#include "iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h" - -#include "iree/compiler/Translation/SPIRV/IREEIndexComputation.h" -#include "iree/compiler/Translation/SPIRV/IREEToSPIRV.h" -#include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/Dialect/SPIRV/SPIRVTypes.h" -#include "mlir/Dialect/StandardOps/Ops.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -class IREEToSPIRVPass : public ModulePass<IREEToSPIRVPass> { - void runOnModule() override; -}; - -} // namespace - -void IREEToSPIRVPass::runOnModule() { - auto module = getModule(); - OpBuilder builder(module.getBodyRegion()); - - // Initialize the index computation. - IndexPropagationList<IndexPropagationOp<ConstantOp>, - // IREE-specific ops: - IndexPropagationOp<IREE::ReturnOp>, - IREELoadIndexPropagation, IREEStoreIndexPropagation, - // Standard dialect unary elementwise ops: - NoBroadcastPwOpIndexPropagation<SIToFPOp>, - NoBroadcastPwOpIndexPropagation<SignExtendIOp>, - // Standard dialect binary elementwise ops: - NoBroadcastPwOpIndexPropagation<AddFOp>, - NoBroadcastPwOpIndexPropagation<AddIOp>, - NoBroadcastPwOpIndexPropagation<AndOp>, - NoBroadcastPwOpIndexPropagation<CmpFOp>, - NoBroadcastPwOpIndexPropagation<CmpIOp>, - NoBroadcastPwOpIndexPropagation<DivFOp>, - NoBroadcastPwOpIndexPropagation<DivISOp>, - NoBroadcastPwOpIndexPropagation<DivIUOp>, - NoBroadcastPwOpIndexPropagation<MulFOp>, - NoBroadcastPwOpIndexPropagation<MulIOp>, - NoBroadcastPwOpIndexPropagation<OrOp>, - NoBroadcastPwOpIndexPropagation<RemFOp>, - NoBroadcastPwOpIndexPropagation<RemISOp>, - NoBroadcastPwOpIndexPropagation<RemIUOp>, - NoBroadcastPwOpIndexPropagation<SubFOp>, - NoBroadcastPwOpIndexPropagation<SubFOp>, - NoBroadcastPwOpIndexPropagation<SubIOp>, - NoBroadcastPwOpIndexPropagation<TruncateIOp>, - NoBroadcastPwOpIndexPropagation<XOrOp>, - NoBroadcastPwOpIndexPropagation<ZeroExtendIOp>, - // XLA unary elementwise ops: - NoBroadcastPwOpIndexPropagation<xla_hlo::AbsOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::CeilOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::ConvertOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::CosOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::ExpOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::FloorOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::LogOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::NegOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::RsqrtOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::SignOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::TanhOp>, - // XLA binary elementwise ops: - NoBroadcastPwOpIndexPropagation<xla_hlo::AddOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::AndOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::DivOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::MaxOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::MinOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::MulOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::SubOp>, - // XLA other ops: - // TODO(ravishankarm): conv, dot. - // TODO(ravishankarm): gather. - // TODO(ravishankarm): pad. - // TODO(ravishankarm): slice. - NoBroadcastPwOpIndexPropagation<xla_hlo::CopyOp>, - ReshapeOpIndexPropagation<xla_hlo::ReshapeOp>, - NoBroadcastPwOpIndexPropagation<xla_hlo::SelectOp>, - XLABroadcastOpIndexPropagation, - XLABroadcastInDimOpIndexPropagation, - XLAReverseOpIndexPropagation, - XLATransposeOpIndexPropagation> - indexPropagation; - - // Initialize the spir-v codegenerator. - SPIRVCodegen< - ConstantOpSPIRVLowering, - // IREE-specific ops: - IREELoadOpSPIRVLowering, IREEReturnOpSPIRVLowering, - IREEStoreOpSPIRVLowering, - // Standard dialect unary elementwise ops: - // Standard dialect binary elementwise ops: - SPIRVPwOpLowering<AddFOp, spirv::FAddOp>, - SPIRVPwOpLowering<DivFOp, spirv::FDivOp>, - SPIRVPwOpLowering<MulFOp, spirv::FMulOp>, - SPIRVPwOpLowering<SubFOp, spirv::FSubOp>, - SPIRVPwOpLowering<AddIOp, spirv::IAddOp>, - SPIRVPwOpLowering<DivISOp, spirv::SDivOp>, - SPIRVPwOpLowering<MulIOp, spirv::IMulOp>, - SPIRVPwOpLowering<SubIOp, spirv::ISubOp>, - // XLA unary elementwise ops: - SPIRVPwOpLowering<xla_hlo::AbsOp, spirv::GLSLSAbsOp, spirv::GLSLFAbsOp>, - SPIRVPwOpLowering<xla_hlo::CeilOp, spirv::GLSLCeilOp>, - // TODO(ravishankarm): xla_hlo::ConvertOp - SPIRVPwOpLowering<xla_hlo::CosOp, spirv::GLSLCosOp>, - SPIRVPwOpLowering<xla_hlo::ExpOp, spirv::GLSLExpOp>, - SPIRVPwOpLowering<xla_hlo::FloorOp, spirv::GLSLFloorOp>, - SPIRVPwOpLowering<xla_hlo::LogOp, spirv::GLSLLogOp>, - SPIRVPwOpLowering<xla_hlo::NegOp, spirv::FNegateOp>, - SPIRVPwOpLowering<xla_hlo::RsqrtOp, spirv::GLSLInverseSqrtOp>, - SPIRVPwOpLowering<xla_hlo::SignOp, spirv::GLSLSSignOp, - spirv::GLSLFSignOp>, - SPIRVPwOpLowering<xla_hlo::TanhOp, spirv::GLSLTanhOp>, - // XLA binary elementwise ops: - SPIRVPwOpLowering<xla_hlo::AddOp, spirv::IAddOp, spirv::FAddOp>, - SPIRVPwOpLowering<xla_hlo::AndOp, spirv::LogicalAndOp>, - SPIRVPwOpLowering<xla_hlo::DivOp, spirv::FDivOp>, - SPIRVPwOpLowering<xla_hlo::MaxOp, spirv::GLSLSMaxOp, spirv::GLSLFMaxOp>, - SPIRVPwOpLowering<xla_hlo::MinOp, spirv::GLSLSMinOp, spirv::GLSLFMinOp>, - SPIRVPwOpLowering<xla_hlo::MulOp, spirv::IMulOp, spirv::FMulOp>, - SPIRVPwOpLowering<xla_hlo::SubOp, spirv::ISubOp, spirv::FSubOp>, - // XLA other ops: - CmpFOpSPIRVLowering, - SPIRVPwOpLowering<xla_hlo::SelectOp, spirv::SelectOp>, - SPIRVIndexOpLowering<xla_hlo::BroadcastOp>, - SPIRVIndexOpLowering<xla_hlo::BroadcastInDimOp>, - SPIRVIndexOpLowering<xla_hlo::CopyOp>, - SPIRVIndexOpLowering<xla_hlo::ReshapeOp>, - SPIRVIndexOpLowering<xla_hlo::ReverseOp>, - SPIRVIndexOpLowering<xla_hlo::TransposeOp>> - spirvCodegen; - - // Create a spirv.module Op. - auto spvModule = builder.create<spirv::ModuleOp>( - module.getLoc(), - builder.getI32IntegerAttr( - static_cast<int32_t>(spirv::AddressingModel::Logical)), - builder.getI32IntegerAttr( - static_cast<int32_t>(spirv::MemoryModel::GLSL450))); - SmallVector<StringRef, 2> caps; - caps.push_back(spirv::stringifyCapability(spirv::Capability::Shader)); - spvModule.setAttr("capabilities", builder.getStrArrayAttr(caps)); - SmallVector<StringRef, 2> exts; - exts.push_back("SPV_KHR_storage_buffer_storage_class"); - spvModule.setAttr("extensions", builder.getStrArrayAttr(exts)); - - for (auto funcOp : module.getOps<FuncOp>()) { - // TODO(ravishankarm): FuncOps in executable that are not dispatch functions - // are not lowered to SPIR-V. Fix this limitation. - if (!funcOp.getAttr("iree.executable.export")) continue; - - IndexComputationCache indexMap; - if (failed(indexPropagation.propagate(funcOp.getBody(), indexMap))) { - return signalPassFailure(); - } - // dumpIndexCache(indexMap); - - ValueCache valueCache; - AffineExprCodegen affineExprCodegen(spvModule, indexMap); - if (failed(spirvCodegen.codegen(spvModule, funcOp, affineExprCodegen, - valueCache))) { - return signalPassFailure(); - } - } -} - -std::unique_ptr<OpPassBase<ModuleOp>> createIREEToSPIRVPass() { - return std::make_unique<IREEToSPIRVPass>(); -} -static PassRegistration<IREEToSPIRVPass> pass( - "convert-iree-to-spirv", - "Convert IREE dispatch functions to SPIR-V dialect"); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation.cpp b/iree/compiler/Translation/SPIRV/IndexComputation.cpp deleted file mode 100644 index fdb8464..0000000 --- a/iree/compiler/Translation/SPIRV/IndexComputation.cpp +++ /dev/null
@@ -1,269 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- IndexComputation.cpp ------------------------------------*- C++//-*-===// -// -// For an IREE dispatch function, compute the map from workitem ID to index of -// tensor computed within that workitem. -// -//===----------------------------------------------------------------------===// -#include "iree/compiler/Translation/SPIRV/IndexComputation.h" - -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/raw_ostream.h" - -static llvm::cl::opt<bool> doAffineExprSimplify( - "simplify-spirv-affine-exprs", - llvm::cl::desc("Simplify affine expressions during code-generation."), - llvm::cl::init(true)); - -namespace mlir { -namespace iree_compiler { - -//===----------------------------------------------------------------------===// -// Reshape Utility Functions -//===----------------------------------------------------------------------===// - -namespace { -/// Handles shapes for scalars. Shape of scalars are represented as empty vetor, -/// i.e. {}. Its easier to do index propogation to handle the scalar as vector -/// of size 1. -inline SmallVector<int64_t, 4> handleIfScalar(ArrayRef<int64_t> shape) { - SmallVector<int64_t, 4> resultShape; - if (shape.empty()) { - return {1}; - } - return SmallVector<int64_t, 4>(shape.begin(), shape.end()); -} - -/// Reshapes are often used to either add a dimension of size 1 or remove a -/// dimension of size 1. Recognizing such cases can make the code-generation -/// easier. The AffineMap needs to either add a constant 0 in the range for such -/// added dimensions or drop those dimensions. -inline LogicalResult getAffineExprForAddOrRemoveDimension( - Builder &builder, ArrayRef<AffineExpr> resultExprs, - ArrayRef<int64_t> resultShape, ArrayRef<int64_t> operandShape, - SmallVectorImpl<AffineExpr> &operandExprs) { - auto resultIndex = resultShape.size(); - auto operandIndex = operandShape.size(); - operandExprs.resize(operandShape.size()); - // Try to match up the dimensions of the operand and result by ignoring any - // dimensions of size of 1 that are introduced. - while (resultIndex > 0 && operandIndex > 0) { - if (resultShape[resultIndex - 1] == -1 || - operandShape[operandIndex - 1] == -1) { - return failure(); - } - if (resultShape[resultIndex - 1] == operandShape[operandIndex - 1]) { - operandExprs[operandIndex - 1] = resultExprs[resultIndex - 1]; - resultIndex--; - operandIndex--; - continue; - } - if (resultShape[resultIndex - 1] == 1) { - // This is a dimension that is added on the operand. This affine - // expression corresponding to this dimension is dropped. - resultIndex--; - continue; - } - if (operandShape[operandIndex - 1] == 1) { - // This is a dimension of size 1 of the operand that is dropped. Add a - // constant expr 0. - operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0); - operandIndex--; - continue; - } - return failure(); - } - // Any remaining dimensions should be 1. - while (resultIndex > 0) { - if (resultShape[resultIndex - 1] != 1) { - return failure(); - } - resultIndex--; - } - while (operandIndex > 0) { - if (operandShape[operandIndex - 1] != 1) { - return failure(); - } - // This is a dimension of size 1 that is dropped. Add a constant expression - // 0. - operandExprs[operandIndex - 1] = builder.getAffineConstantExpr(0); - operandIndex--; - } - return success(); -} - -/// Constructs the strides of an array assuming a row-major packed layout. -// TODO(ravishankarm): This assumes the shape are static. When using dynamic -// shapes, parameters of each dimension can be used to construct AffineExpr for -// strides along each dimension. Note that multiplying two symbolic constants is -// technically not affine, but you could use another symbol to represent the -// product, so it should be still representable as affine exprs. -inline LogicalResult getRowMajorPackedStrides( - Builder &builder, ArrayRef<int64_t> shape, - SmallVectorImpl<AffineExpr> &strides) { - strides.resize(shape.size()); - int64_t stride = 1; - for (auto dim : enumerate(reverse(shape))) { - if (dim.value() < 0) { - // TODO(ravishankarm) : Better error message. - return failure(); - } - strides[shape.size() - 1 - dim.index()] = - builder.getAffineConstantExpr(stride); - stride *= dim.value(); - } - return success(); -} - -/// Linearizes the index of the result position accessed using the shape of the -/// result tensor and delinearizes it to get the position of the operand. -inline LogicalResult getAffineExprForReshape( - Builder &builder, unsigned numDims, unsigned numSymbols, - ArrayRef<AffineExpr> resultExprs, ArrayRef<int64_t> resultShape, - ArrayRef<int64_t> operandShape, SmallVectorImpl<AffineExpr> &operandExprs) { - // To linearize the index, assume that the memory is laid out in - // packed-row-major layout based on the shape. - // TODO(ravishankarm) : When there is stride information, use that to map from - // index to memory location. - SmallVector<AffineExpr, 4> resultStrides; - if (failed(getRowMajorPackedStrides(builder, resultShape, resultStrides))) { - return failure(); - } - AffineExpr linearizedExpr; - for (auto index : enumerate(resultExprs)) { - auto val = getAffineBinaryOpExpr(AffineExprKind::Mul, index.value(), - resultStrides[index.index()]); - if (doAffineExprSimplify) { - val = simplifyAffineExpr(val, numDims, numSymbols); - } - linearizedExpr = (index.index() ? getAffineBinaryOpExpr(AffineExprKind::Add, - linearizedExpr, val) - : val); - if (doAffineExprSimplify) { - linearizedExpr = simplifyAffineExpr(val, numDims, numSymbols); - } - } - - // Unlinearize the index, assuming row-major-packed layout. - // TODO(ravishankarm) : When there is stride information, use that to map from - // memory location to index. - SmallVector<AffineExpr, 4> operandStrides; - if (failed(getRowMajorPackedStrides(builder, operandShape, operandStrides))) { - return failure(); - } - operandExprs.resize(operandStrides.size()); - for (auto stride : enumerate(operandStrides)) { - if (stride.index() == operandStrides.size() - 1) { - operandExprs[stride.index()] = linearizedExpr; - break; - } - auto expr = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, linearizedExpr, - stride.value()); - operandExprs[stride.index()] = - (doAffineExprSimplify ? simplifyAffineExpr(expr, numDims, numSymbols) - : expr); - - linearizedExpr = getAffineBinaryOpExpr(AffineExprKind::Mod, linearizedExpr, - stride.value()); - if (doAffineExprSimplify) { - linearizedExpr = simplifyAffineExpr(linearizedExpr, numDims, numSymbols); - } - } - return success(); -} -} // namespace - -LogicalResult getReshapeOperandMap(Builder &builder, AffineMap resultIndexMap, - ArrayRef<int64_t> resultShapeRef, - ArrayRef<int64_t> operandShapeRef, - AffineMap &operandIndexMap) { - auto resultShape = handleIfScalar(resultShapeRef); - auto operandShape = handleIfScalar(operandShapeRef); - auto resultExprs = resultIndexMap.getResults(); - assert(resultShape.size() == resultExprs.size() && - "Ranks of the Domain of index map and result must be the same"); - SmallVector<AffineExpr, 4> operandExprs; - if (failed(getAffineExprForAddOrRemoveDimension( - builder, resultExprs, resultShape, operandShape, operandExprs)) && - failed(getAffineExprForReshape( - builder, resultIndexMap.getNumDims(), resultIndexMap.getNumSymbols(), - resultExprs, resultShape, operandShape, operandExprs))) { - return failure(); - } - assert(operandExprs.size() == operandShape.size() && - "expected as many exprs for the operand as the rank of the operand"); - operandIndexMap = - AffineMap::get(resultIndexMap.getNumDims(), - resultIndexMap.getNumSymbols(), operandExprs); - - return success(); -} - -LogicalResult IndexPropagation::propagateIndexMap( - Operation *op, IndexComputationCache &indexMap) const { - if (op->getNumResults() == 0) { - // Nothing to do for this op. - return success(); - } - if (op->getNumResults() != 1) { - return op->emitError( - "default index propagation handles case with a single-return value"); - } - // Initialize the storage for all the operands. - for (auto arg : op->getOperands()) { - indexMap[arg]; - } - for (auto &resultIndexMap : indexMap[op->getResult(0)]) { - SmallVector<AffineMap, 4> operandIndices; - if (failed(this->propagateIndexMap(op, resultIndexMap.first, - operandIndices))) { - return failure(); - } - assert(operandIndices.size() == op->getNumOperands() && - "Expected as many indices as operands"); - for (auto arg : enumerate(op->getOperands())) { - indexMap[arg.value()][operandIndices[arg.index()]]; - resultIndexMap.second.push_back(operandIndices[arg.index()]); - } - } - return success(); -} - -void dumpIndexCache(IndexComputationCache &indexMap) { - for (auto &el : indexMap) { - // llvm::errs() << "Value : " << *(el.first); - // llvm::errs().flush(); - if (isa<OpResult>(el.first)) { - llvm::errs() << "Operation : " << el.first->getDefiningOp()->getName(); - } else if (isa<BlockArgument>(el.first)) { - llvm::errs() << "BlockArgument"; - } - for (auto &used : el.second) { - llvm::errs() << "\n\t" << used.first << " : ["; - std::string sep = ""; - for (auto &operand : used.second) { - llvm::errs() << sep << operand; - sep = ", "; - } - llvm::errs() << "]"; - } - llvm::errs() << "\n"; - } - llvm::errs() << "\n"; -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl b/iree/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl deleted file mode 100644 index 49f03aa..0000000 --- a/iree/compiler/Translation/SPIRV/Kernels/spirv_utils.bzl +++ /dev/null
@@ -1,32 +0,0 @@ -"""Utilities for handling hand-written SPIR-V files.""" - -load("//iree:build_defs.bzl", "iree_glsl_vulkan") -load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data") - -def spirv_kernel_cc_library(name, srcs): - """Compiles GLSL files into SPIR-V binaries and embeds them in a cc_library. - - Args: - name: cc_library name to depend on. - srcs: a list of GLSL source files. - """ - spv_files = [] - for src in srcs: - spv_name = src.split(".")[-2] - iree_glsl_vulkan( - name = spv_name, - srcs = [src], - ) - spv_files.append(spv_name + ".spv") - native.filegroup( - name = name + "_files", - srcs = spv_files, - ) - cc_embed_data( - name = name, - srcs = spv_files, - cc_file_output = name + ".cc", - h_file_output = name + ".h", - cpp_namespace = "mlir::iree_compiler::spirv_kernels", - flatten = True, - )
diff --git a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp deleted file mode 100644 index a014d26..0000000 --- a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.cpp +++ /dev/null
@@ -1,314 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h" - -#include <cstdint> -#include <iostream> -#include <map> -#include <vector> - -#include "flatbuffers/flatbuffers.h" -#include "iree/compiler/Translation/SPIRV/EmbeddedKernels.h" -#include "iree/compiler/Translation/SPIRV/IREEToSPIRVPass.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "iree/compiler/Utils/TranslationUtils.h" -#include "iree/schemas/executable_def_generated.h" -#include "iree/schemas/spirv_executable_def_generated.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/Dialect/SPIRV/Serialization.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Module.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Translation.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -class SPIRVTranslator { - public: - explicit SPIRVTranslator(ExecutableTranslationOptions options) - : options_(options) {} - - const ExecutableTranslationOptions &options() const { return options_; } - - // Returns a populated ExecutableDef or nullptr if translation is - // unsuccessful. - std::unique_ptr<iree::ExecutableDefT> translateExecutable( - IREE::ExecutableOp executableOp); - - private: - // Returns a list of entry point names matching the expected export ordinals. - std::vector<std::string> populateEntryPointNames( - IREE::ExecutableOp executableOp); - - // Translates the input module into the SPIR-V dialect and returns the - // serialized code words or empty if translation failed. - std::vector<uint32_t> translateAndSerializeShaderModule( - IREE::ExecutableOp executableOp); - - // Returns a pipeline layout definition based on the bindings required. - std::unique_ptr<iree::VkPipelineLayoutDefT> populatePipelineLayout( - spirv::ModuleOp spirvModuleOp); - - ExecutableTranslationOptions options_; -}; - -std::unique_ptr<iree::ExecutableDefT> SPIRVTranslator::translateExecutable( - IREE::ExecutableOp executableOp) { - // Try first to match against an embedded kernel (such as matmul) and - // otherwise fall back to generating the kernel. - iree::SpirVExecutableDefT spirvExecutableDef; - if (!tryEmbeddedKernelRewrite(executableOp, &spirvExecutableDef)) { - // The sequencer and runtime use ordinals instead of names. We provide the - // list of entry point names here that are then passed in - // VkShaderModuleCreateInfo. - spirvExecutableDef.entry_points = populateEntryPointNames(executableOp); - - // Translate the module and generate the SPIR-V code. - // The module is expected to be modified and must contain the metadata - // required to enable the following information needed for the - // SpirVExecutableDef to be extracted. - spirvExecutableDef.code = translateAndSerializeShaderModule(executableOp); - if (spirvExecutableDef.code.empty()) { - executableOp.emitError() - << "Failed to translate and serialize SPIR-V executable"; - return {}; - } - - // Reflect against the entry thunk to identify the required pipeline - // layout based on binding information. This is used by the runtime to - // create the VkPipelineLayout. - for (auto spirvModuleOp : - executableOp.getBlock().getOps<spirv::ModuleOp>()) { - spirvExecutableDef.pipeline_layout = - populatePipelineLayout(spirvModuleOp); - if (!spirvExecutableDef.pipeline_layout) { - spirvModuleOp.emitError() - << "Failed to generate pipeline for SPIR-V module"; - return {}; - } - break; - } - } - - // Pack the executable definition and get the bytes with the proper header. - // The header is used to verify the contents at runtime. - ::flatbuffers::FlatBufferBuilder fbb; - auto executableOffset = - iree::SpirVExecutableDef::Pack(fbb, &spirvExecutableDef); - iree::FinishSpirVExecutableDefBuffer(fbb, executableOffset); - std::vector<uint8_t> bytes; - bytes.resize(fbb.GetSize()); - std::memcpy(bytes.data(), fbb.GetBufferPointer(), bytes.size()); - - OpBuilder builder(executableOp); - executableOp.setAttr("format", builder.getI32IntegerAttr(static_cast<int32_t>( - IREE::ExecutableFormat::SpirV))); - - auto executableDef = std::make_unique<iree::ExecutableDefT>(); - executableDef->format = static_cast<uint32_t>(IREE::ExecutableFormat::SpirV); - executableDef->contents = std::move(bytes); - return executableDef; -} - -std::vector<std::string> SPIRVTranslator::populateEntryPointNames( - IREE::ExecutableOp executableOp) { - auto module = executableOp.getInnerModule(); - DenseMap<unsigned, StringRef> entryPoints; - for (auto funcOp : module.getOps<FuncOp>()) { - if (!funcOp.getAttr("iree.executable.export")) continue; - auto ordinalAttr = funcOp.getAttrOfType<IntegerAttr>("iree.ordinal"); - entryPoints[ordinalAttr.getInt()] = funcOp.getName(); - } - std::vector<std::string> entryPointNames(entryPoints.size()); - for (auto &entry : entryPoints) { - entryPointNames[entry.first] = entry.second.str(); - } - return entryPointNames; -} - -std::vector<uint32_t> SPIRVTranslator::translateAndSerializeShaderModule( - IREE::ExecutableOp executableOp) { - auto module = executableOp.getInnerModule(); - - // We can use the workload hint to know what the expected dispatch workload - // is. If we want to remap this to make more sense for the operations we are - // performing we can do that here. - // - // Note that workloads are computed per entry point. There may be some - // dimensions of the workload that are static (in which case workloadAttr will - // have non-dynamic dims) and others that need to be taken from an argument - // shape (in which case workloadRef is the argument ordinal to take dynamic - // dimensions from). - // TODO(benvanik): make it just an arg instead? iree.workload special op? - // TODO(benvanik): instead of FuncOp have an iree.entry_point op with these. - for (auto funcOp : module.getOps<FuncOp>()) { - // TODO(ravishankarm): FuncOps in executable that are not dispatch functions - // are not lowered to SPIR-V. Fix this limitation. - if (!funcOp.getAttr("iree.executable.export")) continue; - auto workloadAttr = - funcOp.getAttrOfType<ElementsAttr>("iree.executable.workload"); - auto workloadRefAttr = - funcOp.getAttrOfType<IntegerAttr>("iree.executable.workload_ref"); - std::array<int32_t, 3> staticWorkloadDims = {-1, -1, -1}; - if (workloadAttr) { - for (unsigned i = 0; i < 3; ++i) { - if (auto dimAttr = - workloadAttr.getValue({i}).dyn_cast_or_null<IntegerAttr>()) { - staticWorkloadDims[i] = dimAttr.getInt(); - } - } - } - std::array<BlockArgument *, 3> dynamicWorkloadDimRefs; - if (workloadRefAttr) { - for (unsigned i = 0; i < 3; ++i) { - if (staticWorkloadDims[i] == -1) { - dynamicWorkloadDimRefs[i] = - funcOp.getArgument(workloadRefAttr.getInt()); - } - } - } - - // Now staticWorkloadDims will have non-negative values for known dimensions - // and any dim with -1 will need to be pulled from the corresponding shape - // dimension of dynamicWorkloadDimRefs. - - // TODO(b/137868263): use this information to map from workgroup to - // invocation and perform indexing. - } - - // Lower module to spirv::ModuleOp. - auto spirvGenPasses = createPassManager(module.getContext(), options()); - spirvGenPasses->addPass(xla_hlo::createLegalizeToStdPass()); - spirvGenPasses->addPass(createIREEToSPIRVPass()); - if (failed(runPassPipeline(options(), spirvGenPasses.get(), module))) { - executableOp.emitError() << "Failed to generate spv.module"; - return {}; - } - - auto spvModules = module.getOps<spirv::ModuleOp>(); - if (std::distance(spvModules.begin(), spvModules.end()) != 1) { - executableOp.emitError() - << "Expected a single spv.module for an IREE executable op"; - return {}; - } - - // Serialize the spirv::ModuleOp into the binary that we will embed in the - // final flatbuffer. - std::vector<uint32_t> spvBinaries; - for (auto spvModule : spvModules) { - SmallVector<uint32_t, 256> spvBinary; - if (failed(spirv::serialize(spvModule, spvBinary))) { - executableOp.emitError() << "Failed to serialize spv.module"; - return {}; - } - spvBinaries.insert(spvBinaries.end(), spvBinary.begin(), spvBinary.end()); - - // Clone the module into executableOp directly. - auto clonedModule = spvModule.clone(); - executableOp.getBlock().getOperations().insert( - std::prev(executableOp.getBlock().getOperations().end()), clonedModule); - } - // Remove the original code. - module.erase(); - - return spvBinaries; -} - -std::unique_ptr<iree::VkPipelineLayoutDefT> -SPIRVTranslator::populatePipelineLayout(spirv::ModuleOp spirvModuleOp) { - // NOTE: we currently make some assumptions about this based on the expected - // ABI of the runtime. If we wanted to support more general shaders with more - // complex I/O we'd need to find a better way to communicate this through the - // VkPipelineLayoutDef. - auto pipelineLayoutDef = std::make_unique<iree::VkPipelineLayoutDefT>(); - pipelineLayoutDef->buffer_binding_set = 0; - - // Build a set of descriptor_set -> binding -> variable. - // This makes it easier to write out the descriptor in a logical order, even - // though this is not strictly required. - int64_t maxDescriptorSetOrdinal = -1; - std::map<int32_t, std::map<int32_t, spirv::GlobalVariableOp>> descriptorSets; - for (auto globalVar : - spirvModuleOp.getBlock().getOps<spirv::GlobalVariableOp>()) { - auto descriptorSetAttr = - globalVar.getAttrOfType<IntegerAttr>("descriptor_set"); - auto bindingAttr = globalVar.getAttrOfType<IntegerAttr>("binding"); - if (!descriptorSetAttr || !bindingAttr) { - // Not something the runtime cares about. - continue; - } - maxDescriptorSetOrdinal = - std::max(descriptorSetAttr.getInt(), maxDescriptorSetOrdinal); - auto &descriptorSet = descriptorSets[descriptorSetAttr.getInt()]; - descriptorSet[bindingAttr.getInt()] = globalVar; - } - - // Create the individual layout and binding defs. - pipelineLayoutDef->descriptor_set_layouts.resize(maxDescriptorSetOrdinal + 1); - for (auto &descriptorSetBindings : descriptorSets) { - int32_t descriptorSet = descriptorSetBindings.first; - auto dsl = std::make_unique<iree::VkDescriptorSetLayoutDefT>(); - - for (auto &globalVarBinding : descriptorSetBindings.second) { - auto binding = std::make_unique<iree::VkDescriptorSetLayoutBindingDefT>(); - binding->binding = globalVarBinding.first; - binding->descriptor_count = 1; - // TODO(benvanik): pull from type info. - binding->descriptor_type = 7; // VK_DESCRIPTOR_TYPE_STORAGE_BUFFER - binding->stage_flags = 0x00000020; // VK_SHADER_STAGE_COMPUTE_BIT - dsl->bindings.push_back(std::move(binding)); - } - - pipelineLayoutDef->descriptor_set_layouts[descriptorSet] = std::move(dsl); - } - - return pipelineLayoutDef; -} - -} // namespace - -llvm::Optional<ExecutableTranslationResult> -translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps, - ExecutableTranslationOptions options) { - SPIRVTranslator translator(options); - ExecutableTranslationResult translationResult; - for (auto executableOp : llvm::make_early_inc_range(executableOps)) { - auto executableDef = translator.translateExecutable(executableOp); - if (!executableDef) { - executableOp.emitError() << "Failed to translate one or more executables"; - return llvm::None; - } - translationResult.executable_defs.push_back(std::move(executableDef)); - } - return translationResult; -} - -static ExecutableTranslationRegistration SPIRVExecutableTranslationRegistration( - "vulkan-spirv", translateExecutableToSPIRVExecutable); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h b/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h deleted file mode 100644 index 68e3d5d..0000000 --- a/iree/compiler/Translation/SPIRV/SPIRVExecutableTranslation.h +++ /dev/null
@@ -1,38 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_ -#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_ - -#include <vector> - -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/Utils/TranslationUtils.h" -#include "mlir/IR/Module.h" - -namespace mlir { -namespace iree_compiler { - -// Translates an MLIR module into a SPIR-V executable. -// These executables are stored as FlatBuffers in the -// https://github.com/google/iree/tree/master/iree/schemas/spirv_executable_def.fbs -// schema. -llvm::Optional<ExecutableTranslationResult> -translateExecutableToSPIRVExecutable(ArrayRef<IREE::ExecutableOp> executableOps, - ExecutableTranslationOptions options = {}); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVEXECUTABLETRANSLATION_H_
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp b/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp deleted file mode 100644 index a1f8e35..0000000 --- a/iree/compiler/Translation/SPIRV/SPIRVLowering.cpp +++ /dev/null
@@ -1,131 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- SPIRVLowering.cpp ---------------------------------------*- C++//-*-===// -// -// SPIR-V Code-generation for XLA-HLO Ops within IREE Dispatch functions -// -//===----------------------------------------------------------------------===// -#include "iree/compiler/Translation/SPIRV/SPIRVLowering.h" - -namespace mlir { -namespace iree_compiler { -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// -LogicalResult ConstantOpSPIRVLowering::lowerOperation( - Operation *op, OpBuilder &builder, AffineMap index, ArrayRef<Value *>, - ValueCache &valueCache) const { - auto constOp = cast<ConstantOp>(op); - auto attr = constOp.value().dyn_cast<DenseElementsAttr>(); - if (!attr || !attr.isSplat()) { - return op->emitError( - "unhandled constant lowering unless value is a splat dense element " - "attribute"); - } - auto resultType = constOp.getResult()->getType(); - Type resultElemType; - if (resultType.isIntOrFloat()) { - resultElemType = resultType; - } else if (auto shapedType = resultType.dyn_cast<ShapedType>()) { - resultElemType = shapedType.getElementType(); - } else { - return op->emitError("unhandled result type of constant : ") << resultType; - } - Attribute constVal = attr.getSplatValue(); - auto spirvConstOp = - builder.create<spirv::ConstantOp>(op->getLoc(), resultElemType, constVal); - valueCache.setOperandDstValue(constOp.getResult(), index, - spirvConstOp.getResult()); - return success(); -} - -//===----------------------------------------------------------------------===// -// CmpFOp -//===----------------------------------------------------------------------===// -LogicalResult CmpFOpSPIRVLowering::lowerOperation( - Operation *op, OpBuilder &builder, AffineMap index, - ArrayRef<Value *> operands, ValueCache &valueCache) const { - if (operands.size() != 2) { - return op->emitError("expected two operands in spir-v lowering of CmpFOp"); - } - Operation *spirvOp = nullptr; - auto opInfo = op->getAttrOfType<IntegerAttr>(CmpFOp::getPredicateAttrName()); - if (!opInfo) { - return op->emitError("expected CmpFOp to contain ") - << CmpFOp::getPredicateAttrName() << " attribute"; - } - auto boolType = builder.getI1Type(); - auto predicateVal = static_cast<CmpFPredicate>(opInfo.getInt()); - switch (predicateVal) { -#define DISPATCH(caseLabel, opName) \ - case caseLabel: \ - spirvOp = builder.create<opName>(op->getLoc(), boolType, operands[0], \ - operands[1]); \ - break; - - DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); - DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp); - DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp); - DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp); - DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp); - DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp); - DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp); - DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp); - DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp); - DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp); - DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp); - DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp); - -#undef DISPATCH - - default: - return op->emitError("unhandled predicate attribute for SPIR-V lowering"); - } - valueCache.setOperandDstValue(op->getResult(0), index, spirvOp->getResult(0)); - return success(); -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// -LogicalResult ReturnOpSPIRVLowering::lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { - auto returnOp = cast<ReturnOp>(op); - if (returnOp.getNumOperands() != 1) { - return returnOp.emitError( - "unhandled lowering of return statement with multiple returns"); - } - auto returnTensor = returnOp.getOperand(0); - auto indices = affineExprCodegen.getIndices(returnTensor); - if (indices.size() != 1) { - return returnOp.emitError( - "expected to compute a single element of the return tensor"); - } - assert(outputBuffers.size() == 1 && "Expected a single output buffer"); - auto var = outputBuffers[0]; - auto ptr = genPointerOffset(builder, returnOp.getLoc(), affineExprCodegen, - indices[0], var); - auto scalarVal = valueCache.getOperandDstValue(returnTensor, indices[0]); - builder.create<spirv::StoreOp>(returnOp.getLoc(), ptr, scalarVal, - /*memory_access = */ nullptr, - /*alignment = */ nullptr); - builder.create<spirv::ReturnOp>(returnOp.getLoc()); - return success(); -} -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/SPIRVLowering.h b/iree/compiler/Translation/SPIRV/SPIRVLowering.h deleted file mode 100644 index f6de5d0..0000000 --- a/iree/compiler/Translation/SPIRV/SPIRVLowering.h +++ /dev/null
@@ -1,591 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- SPIRVLowering.h -----------------------------------------*- C++//-*-===// -// -// SPIR-V Code-generation for tensor operations within IREE Dispatch functions -// -//===----------------------------------------------------------------------===// -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H -#define IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H - -#include "iree/compiler/Translation/SPIRV/AffineExprCodegen.h" -#include "mlir/Dialect/SPIRV/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/SPIRVOps.h" -#include "mlir/Support/StringExtras.h" - -namespace mlir { -namespace iree_compiler { - -class ValueCache { - public: - Value *getOperandDstValue(Value *value, AffineMap index) { - return convertedValueMap.lookup(value).lookup(index); - } - - void setOperandDstValue(Value *value, AffineMap index, Value *scalar) { - convertedValueMap[value][index] = scalar; - } - - private: - DenseMap<Value *, DenseMap<AffineMap, Value *>> convertedValueMap; -}; - -/// Base class for lowering tensor operations in the dispatch function to SPIR-V -/// op. -class SPIRVLowering { - public: - virtual ~SPIRVLowering() = default; - virtual StringRef getOpName() = 0; - /// This method (in the derived class) should generate the scalar operation - /// corresponding the the tensor operation `op` to generate the value of the - /// result tensor at a particular `index`. The scalar value of the operands - /// needed to compute this value is passed in within `operands`. The methods - /// have to insert the scalar result value of the generated operation into the - /// `valueCache`. - virtual LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, - ArrayRef<Value *> operands, - ValueCache &valueCache) const { - return failure(); - } - - /// This method (in the derived class) should generate the scalar operations - /// corresponding to the tensor operation `op`. This should be implemented - /// when the `op` has no result value, typically store operations and return - /// operations. - virtual LogicalResult lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const { - return failure(); - } -}; - -/// Base class that gets the opName for the operation. -template <typename OpTy> -class SPIRVOpLowering : public SPIRVLowering { - public: - using SPIRVLowering::SPIRVLowering; - virtual ~SPIRVOpLowering<OpTy>() {} - StringRef getOpName() override { return OpTy::getOperationName(); } -}; - -/// SPIR-V lowering for ConstantOp. -class ConstantOpSPIRVLowering final : public SPIRVOpLowering<ConstantOp> { - public: - using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering; - LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, ArrayRef<Value *> operands, - ValueCache &valueCache) const override; -}; - -/// SPIR-V lowering for CmpFOp. -class CmpFOpSPIRVLowering final : public SPIRVOpLowering<CmpFOp> { - public: - using SPIRVOpLowering<CmpFOp>::SPIRVOpLowering; - - LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, ArrayRef<Value *> operands, - ValueCache &valueCache) const override; -}; - -/// SPIR-V lowering for Min/Max operations. -template <typename OpTy, typename CmpOpTy, typename CmpFOpTy> -class CmpSelectOpSPIRVLowering final : public SPIRVOpLowering<OpTy> { - public: - using SPIRVOpLowering<OpTy>::SPIRVOpLowering; - LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, ArrayRef<Value *> operands, - ValueCache &valueCache) const override { - if (op->getNumOperands() != 2) { - return op->emitError( - "unhandled SPIR-V lowering for more than 2 operands"); - } - assert(operands.size() == op->getNumOperands() && - "expected as many operands for the replacement as the original " - "instruction"); - auto cmpSelectOp = cast<OpTy>(op); - auto result = cmpSelectOp.getResult(); - auto resultTy = result->getType().template dyn_cast<ShapedType>(); - if (!resultTy) { - return op->emitError( - "unhandled lowering of operations that don't return a " - "ShapedType"); - } - auto elementTy = resultTy.getElementType(); - auto boolTy = builder.getI1Type(); - Operation *cmpOp = nullptr; - if (elementTy.template isa<FloatType>()) { - cmpOp = builder.create<CmpFOpTy>(op->getLoc(), boolTy, operands, - ArrayRef<NamedAttribute>()); - } else { - cmpOp = builder.create<CmpOpTy>(op->getLoc(), boolTy, operands, - ArrayRef<NamedAttribute>()); - } - auto selectOp = builder.create<spirv::SelectOp>( - op->getLoc(), operands[0]->getType(), cmpOp->getResult(0), operands[0], - operands[1]); - valueCache.setOperandDstValue(op->getResult(0), index, - selectOp.getResult()); - return success(); - } -}; - -/// This class is the general template used to emit scalar instruction -/// corresponding for point-wise operations. Assumes that the original -/// instruction has a single result value of type ShapedType. -/// TODO(ravishankarm) : In XLA-HLO, the same operations is used for -/// integer/float tensor operations. So allow this op to take an additional op -/// type as a template parameter to handle such cases. Find a better way to do -/// this. -template <typename OpTy, typename ReplacementOpTy, - typename FloatOpTy = ReplacementOpTy> -class SPIRVPwOpLowering final : public SPIRVOpLowering<OpTy> { - public: - using SPIRVOpLowering<OpTy>::SPIRVOpLowering; - - LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, - ArrayRef<Value *> scalarOperands, - ValueCache &valueCache) const override { - // TODO(ravishankarm) : This check should really be a static_assert. See if - // that can be changed. - if (op->getNumOperands() == 0) { - return op->emitError("expected op to have at least one operand"); - } - auto pwOp = cast<OpTy>(op); - auto result = pwOp.getResult(); - auto resultType = result->getType().template dyn_cast<ShapedType>(); - if (!resultType) { - return op->emitError( - "unhandled lowering of operations that don't return a " - "ShapedType"); - } - auto elementType = resultType.getElementType(); - Operation *scalarOp = nullptr; - if (elementType.template isa<IntegerType>()) { - scalarOp = builder - .create<ReplacementOpTy>(op->getLoc(), elementType, - scalarOperands, - ArrayRef<NamedAttribute>()) - .getOperation(); - } else { - scalarOp = - builder - .create<FloatOpTy>(op->getLoc(), elementType, scalarOperands, - ArrayRef<NamedAttribute>()) - .getOperation(); - } - if (!scalarOp) { - return op->emitError("unable to lower operation"); - } - valueCache.setOperandDstValue(pwOp.getResult(), index, - scalarOp->getResult(0)); - return success(); - } -}; - -/// This class is the general template used to emit scalar instruction for index -/// transformation instructions like transpose. Assumes a single result value -/// and a single operand -template <typename OpTy> -class SPIRVIndexOpLowering final : public SPIRVOpLowering<OpTy> { - public: - using SPIRVOpLowering<OpTy>::SPIRVOpLowering; - - LogicalResult lowerOperation(Operation *op, OpBuilder &builder, - AffineMap index, - ArrayRef<Value *> scalarOperands, - ValueCache &valueCache) const override { - if (op->getNumOperands() != 1) { - return op->emitError( - "unhandled lowering of index transformation operation with multiple " - "operands"); - } - auto indexOp = cast<OpTy>(op); - valueCache.setOperandDstValue(indexOp.getResult(), index, - scalarOperands[0]); - return success(); - } -}; - -/// Ggenerates spv.AccessChain instruction to get the pointer value at a given -/// location of a spv.globalVariable. -inline Value *genPointerOffset(OpBuilder &builder, Location loc, - AffineExprCodegen &affineExprCodegen, - AffineMap indexMap, - spirv::GlobalVariableOp &var) { - auto basePtr = builder.create<spirv::AddressOfOp>( - loc, var.type(), builder.getSymbolRefAttr(var.sym_name())); - auto varPtrType = var.type().cast<spirv::PointerType>().getPointeeType(); - // The variable has to be a struct type with a single element. - assert(varPtrType.isa<spirv::StructType>() && - "expected variable type to be a spv.ptr<spv.struct<...>>"); - auto varStructType = varPtrType.cast<spirv::StructType>(); - assert(varStructType.getNumElements() == 1 && - "expected variable type to be a spv.ptr of spv.struct with a single " - "element"); - auto varType = varStructType.getElementType(0); - - SmallVector<Value *, 2> accessIndex; - /// For scalar values, the index-map computed with already map to the 0-th - /// element. For arrays, they map to the position accessed. So just for arrays - /// we need to add an extra 0 to index into the struct. - if (varType.isa<spirv::ArrayType>() || - varType.isa<spirv::RuntimeArrayType>()) { - auto i32Type = builder.getIntegerType(32); - auto zero = builder.create<spirv::ConstantOp>(loc, i32Type, - builder.getI32IntegerAttr(0)); - accessIndex.push_back(zero); - } - for (auto indexExpr : indexMap.getResults()) { - accessIndex.push_back(affineExprCodegen.getValue( - indexExpr, builder.saveInsertionPoint(), loc)); - } - return builder.create<spirv::AccessChainOp>(loc, basePtr, accessIndex); -} - -/// Lower return statements during SPIR-V codegeneration. -class ReturnOpSPIRVLowering : public SPIRVOpLowering<ReturnOp> { - public: - using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering; - - LogicalResult lowerOperation( - Operation *op, OpBuilder &builder, AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, - DenseMap<Value *, spirv::GlobalVariableOp> &inputBuffers, - ArrayRef<spirv::GlobalVariableOp> outputBuffers) const override; -}; - -/// Class to drive the SPIRV code-generation. -template <typename... Ts> -class SPIRVCodegen { - using OpCodegenListT = llvm::StringMap<std::unique_ptr<SPIRVLowering>>; - - public: - explicit SPIRVCodegen() { insert(); } - - LogicalResult codegen(spirv::ModuleOp &spirvModule, FuncOp &fn, - AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache) { - if (fn.getBlocks().size() != 1) { - return emitError( - fn.getLoc(), - "unimplemeneted handling multiple blocks within a function"); - } - - OpBuilder builder(spirvModule.body()); - // Create the entry function and generate global invocation ID. Creates a - // global variable for all inputs and output tensors. - return createEntryFn(builder, fn, affineExprCodegen, valueCache); - } - - private: - /// Helper method to create the entry function. Creates global variables for - /// all inputs and outputs. Inserts the spv.EntryPoint operations as well. - LogicalResult createEntryFn(OpBuilder &builder, FuncOp &fn, - AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache) { - auto loc = fn.getLoc(); - // TODO(ravishankarm) : This should actually be part of the SPIR-V - // conversion framework in MLIR core. Move it there. - auto convertType = [&loc](Type t, - spirv::PointerType &varType) -> LogicalResult { - auto shapedType = t.dyn_cast<ShapedType>(); - if (!shapedType) { - return emitError(loc, "expected ShapedType argument"); - } - auto elementType = shapedType.getElementType(); - if (!elementType.isIntOrFloat()) { - return emitError(loc, "unhandled element type ") - << elementType << " while lowering to SPIR-V"; - } - int64_t stride = elementType.getIntOrFloatBitWidth() / 8; - for (auto dim : reverse(shapedType.getShape())) { - if (dim <= 0) { - return emitError(loc, "expected tensor dimensions to be non-zero"); - } - elementType = spirv::ArrayType::get( - elementType, dim, - static_cast<spirv::ArrayType::LayoutInfo>(stride)); - stride *= dim; - } - // TODO(ravishankarm): Verify that the type of the variable passes - // spirv-val. - varType = spirv::PointerType::get( - spirv::StructType::get(elementType, - static_cast<spirv::StructType::LayoutInfo>(0)), - spirv::StorageClass::StorageBuffer); - return success(); - }; - - // Convert functions arguments and return values to - // spirv::GlobalVariables. All global variables are given a descriptor set - // of 0 and binding is the argument number. - auto fnType = fn.getType(); - auto descriptorSetAttrName = convertToSnakeCase( - stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingAttrName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); - for (auto argType : enumerate(fnType.getInputs())) { - spirv::PointerType varType; - if (failed(convertType(argType.value(), varType))) { - return failure(); - } - auto varName = - fn.getName().str() + "_arg_" + std::to_string(argType.index()); - auto var = builder.create<spirv::GlobalVariableOp>( - loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr); - // Set descriptor_set to 0. - var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0)); - // Set binding to argument number. - var.setAttr(bindingAttrName, builder.getI32IntegerAttr(argType.index())); - - inputArgToVariable[fn.getArgument(argType.index())] = var; - } - for (auto resType : enumerate(fnType.getResults())) { - spirv::PointerType varType; - if (failed(convertType(resType.value(), varType))) { - return failure(); - } - auto varName = - fn.getName().str() + "_res_" + std::to_string(resType.index()); - auto var = builder.create<spirv::GlobalVariableOp>( - loc, TypeAttr::get(varType), builder.getStringAttr(varName), nullptr); - // Set descriptor_set to 0. - var.setAttr(descriptorSetAttrName, builder.getI32IntegerAttr(0)); - // Set binding to (result number + num arguments) - var.setAttr( - bindingAttrName, - builder.getI32IntegerAttr(fnType.getNumInputs() + resType.index())); - - resultIndexToVariable.push_back(var); - } - - auto entryFnType = - builder.getFunctionType(ArrayRef<Type>(), ArrayRef<Type>()); - auto entryFn = builder.create<FuncOp>(loc, fn.getName(), entryFnType, - ArrayRef<NamedAttribute>()); - - // Start a scope to create an insertion guard to reset the builder once the - // function is lowered. - { - OpBuilder::InsertionGuard funcInsertGuard(builder); - builder.setInsertionPointToStart(entryFn.addEntryBlock()); - - // Create the Global invocation ID. - if (failed(createGlobalInvocationID(builder, fn.getLoc(), - affineExprCodegen))) { - return failure(); - } - - if (failed(lowerFunction(builder, fn, entryFn, affineExprCodegen, - valueCache))) { - return failure(); - } - } - - // Create the entry point instructions for the entry function. - if (failed(createEntryPoint(builder, loc, entryFn))) { - return failure(); - } - return success(); - } - - /// Creates the global variable for GlobalInvocationID, and gets the ID at x, - /// y and z dimensions. - LogicalResult createGlobalInvocationID(OpBuilder &builder, Location loc, - AffineExprCodegen &affineExprCodegen) { - auto moduleOp = builder.getInsertionBlock() - ->getParentOp() - ->getParentOfType<spirv::ModuleOp>(); - OpBuilder moduleBuilder(moduleOp.body()); - auto i32Type = builder.getIntegerType(32); - auto idType = VectorType::get(3, i32Type); - auto ptrIdType = - spirv::PointerType::get(idType, spirv::StorageClass::Input); - auto globalInvocationID = moduleBuilder.create<spirv::GlobalVariableOp>( - loc, TypeAttr::get(ptrIdType), - builder.getStringAttr("globalInvocationID"), nullptr); - globalInvocationID.setAttr( - convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)), - builder.getStringAttr( - spirv::stringifyBuiltIn(spirv::BuiltIn::GlobalInvocationId))); - interface.push_back( - builder.getSymbolRefAttr(globalInvocationID.sym_name())); - - auto globalInvocationIDPtr = builder.create<spirv::AddressOfOp>( - loc, ptrIdType, - builder.getSymbolRefAttr(globalInvocationID.getOperation())); - auto id = builder.create<spirv::LoadOp>(loc, idType, globalInvocationIDPtr, - nullptr, nullptr); - auto id_x = builder.create<spirv::CompositeExtractOp>( - loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(0))); - auto id_y = builder.create<spirv::CompositeExtractOp>( - loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(1))); - auto id_z = builder.create<spirv::CompositeExtractOp>( - loc, i32Type, id, builder.getArrayAttr(builder.getI32IntegerAttr(2))); - affineExprCodegen.setDimDstValue(0, id_x); - affineExprCodegen.setDimDstValue(1, id_y); - affineExprCodegen.setDimDstValue(2, id_z); - return success(); - } - - /// Method to load the values of globalVariables corresponding to the - /// arguments of the dispatch function at all indices needed within the - /// dispatch function. - LogicalResult initArgValues(OpBuilder &builder, Location loc, - AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, Value *origArg) { - for (auto indexMap : affineExprCodegen.getIndices(origArg)) { - auto var = inputArgToVariable.lookup(origArg); - if (!var) { - return emitError( - loc, "undefined SPIR-V global variable for tensor argument"); - } - auto ptr = - genPointerOffset(builder, loc, affineExprCodegen, indexMap, var); - auto elementType = - ptr->getType().template cast<spirv::PointerType>().getPointeeType(); - auto val = builder.create<spirv::LoadOp>(loc, elementType, ptr, - /*memory_access =*/nullptr, - /*alignment = */ nullptr); - valueCache.setOperandDstValue(origArg, indexMap, val); - } - return success(); - } - - /// Adds the spv.EntryPointOp and records all the interface variables used in - /// the entryFn. - LogicalResult createEntryPoint(OpBuilder &builder, Location loc, - FuncOp entryFn) { - builder.create<spirv::EntryPointOp>( - loc, - builder.getI32IntegerAttr( - static_cast<int32_t>(spirv::ExecutionModel::GLCompute)), - builder.getSymbolRefAttr(entryFn), builder.getArrayAttr(interface)); - builder.create<spirv::ExecutionModeOp>( - loc, builder.getSymbolRefAttr(entryFn), - builder.getI32IntegerAttr( - static_cast<int32_t>(spirv::ExecutionMode::LocalSize)), - builder.getI32ArrayAttr({1, 1, 1})); - interface.clear(); - return success(); - } - - /// Lowers the body of the function in the original dialect to SPIR-V dialect. - LogicalResult lowerFunction(OpBuilder &builder, FuncOp fn, FuncOp entryFn, - AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache) { - for (auto arg : fn.getArguments()) { - // Load values of the argument at all indices needed for computation - // within the dispatch function. - if (failed(initArgValues(builder, fn.getLoc(), affineExprCodegen, - valueCache, arg))) { - return failure(); - } - } - - for (auto &block : fn) { - for (auto &op : block) { - // Lower individual operations. - if (failed( - lowerOperation(builder, affineExprCodegen, valueCache, &op))) { - return failure(); - } - } - } - return success(); - } - - /// Dispatches the lowering of tensor operation to SPIR-V scalar - /// operation. - LogicalResult lowerOperation(OpBuilder &builder, - AffineExprCodegen &affineExprCodegen, - ValueCache &valueCache, Operation *op) { - auto opName = op->getName().getStringRef(); - if (!opCodegenList.count(opName)) { - return op->emitError("unhandled codegen"); - } - if (op->getNumResults() > 1) { - return op->emitError("unhandled codegen for multiple result values"); - } - - // Zero return case. - if (!op->getNumResults()) { - return opCodegenList[opName]->lowerOperation( - op, builder, affineExprCodegen, valueCache, inputArgToVariable, - resultIndexToVariable); - } - - // Single return case. - auto resultTensor = op->getResult(0); - auto indices = affineExprCodegen.getIndices(resultTensor); - for (auto &index : indices) { - auto operandIndices = - affineExprCodegen.getOperandIndices(resultTensor, index); - SmallVector<Value *, 2> scalarOperands; - for (auto arg : llvm::enumerate(op->getOperands())) { - auto scalarArg = valueCache.getOperandDstValue( - arg.value(), operandIndices[arg.index()]); - if (!scalarArg) { - return op->emitError("argument ") - << arg.index() << " has no scalar value"; - } - scalarOperands.push_back(scalarArg); - } - if (failed(opCodegenList[opName]->lowerOperation( - op, builder, index, scalarOperands, valueCache))) { - return failure(); - } - } - return success(); - } - - void insert() { - std::vector<std::unique_ptr<SPIRVLowering>> objs; - using dummy = int[]; - (void)dummy{0, (objs.emplace_back(std::make_unique<Ts>()), 0)...}; - for (auto &elem : objs) { - StringRef opName = elem->getOpName(); - opCodegenList.try_emplace(opName, std::move(elem)); - } - } - - /// List of classes that implement the operation lowering from tensor - /// operations to SPIR-V. - OpCodegenListT opCodegenList; - - /// I/O interface for the entry function containing global variables that are - /// used by the entire function call tree. - SmallVector<Attribute, 4> interface; - - /// Mapping from argument of the dispatch function in tensor dialect to the - /// corresponding spv.globalVariable. - DenseMap<Value *, spirv::GlobalVariableOp> inputArgToVariable; - - /// List of spv.globalVariables created for tensors returned by the dispatch - /// function in tensor dialects. - SmallVector<spirv::GlobalVariableOp, 1> resultIndexToVariable; - - /// GlobalInvocationID variable. - spirv::GlobalVariableOp globalInvocationID; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_SPIRVLOWERING_H
diff --git a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp deleted file mode 100644 index 50cb8ec..0000000 --- a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.cpp +++ /dev/null
@@ -1,126 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- XLAIndexPropagation.cpp ---------------------------------*- C++//-*-===// -// -// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all -// tensors needed to produce the value of the result tensors at a particlar -// index. -// -//===----------------------------------------------------------------------===// - -#include "iree/compiler/Translation/SPIRV/XLAIndexPropagation.h" - -namespace mlir { -namespace iree_compiler { - -//===----------------------------------------------------------------------===// -// BroadcastInDimOp -//===----------------------------------------------------------------------===// - -LogicalResult XLABroadcastInDimOpIndexPropagation::propagateIndexMap( - Operation *operation, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &indexMap) const { - auto broadcastOp = cast<xla_hlo::BroadcastInDimOp>(operation); - auto broadcastDim = broadcastOp.broadcast_dimensions(); - - Builder builder(operation->getContext()); - if (!broadcastDim) { - // This is a scalar. So all indices map to the same element. - AffineMap scalarMap = - AffineMap::get(resultIndex.getNumDims(), resultIndex.getNumSymbols(), - builder.getAffineConstantExpr(0)); - indexMap.push_back(scalarMap); - return success(); - } - - // Handle non-scalar cases. - auto dimensions = broadcastDim->getValues<int64_t>(); - SmallVector<AffineExpr, 4> exprs; - for (auto resultExpr : enumerate(resultIndex.getResults())) { - if (llvm::any_of(dimensions, [&resultExpr](int64_t dim) { - return dim == resultExpr.index(); - })) { - exprs.push_back(resultExpr.value()); - } - } - auto operandMap = AffineMap::get(resultIndex.getNumDims(), - resultIndex.getNumSymbols(), exprs); - indexMap.push_back(operandMap); - return success(); -} - -//===----------------------------------------------------------------------===// -// BroadcastOp -//===----------------------------------------------------------------------===// - -// For broadcast op, just drop the first N expressions of the resultIndex, where -// N is the number of elements in broadcast_sizes attribute. -LogicalResult XLABroadcastOpIndexPropagation::propagateIndexMap( - Operation *operation, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &indexMap) const { - auto broadcastOp = cast<xla_hlo::BroadcastOp>(operation); - auto broadcastDim = broadcastOp.broadcast_sizes(); - - SmallVector<AffineExpr, 4> exprs; - for (auto i : llvm::seq<size_t>( - broadcastDim.getType().getShape()[0], - operation->getResult(0)->getType().cast<ShapedType>().getRank())) { - exprs.push_back(resultIndex.getResult(i)); - } - - Builder builder(operation->getContext()); - if (exprs.empty()) { - // The result is a scalar. Just add a constant expr 0. - exprs.push_back(builder.getAffineConstantExpr(0)); - } - auto operandMap = AffineMap::get(resultIndex.getNumDims(), - resultIndex.getNumSymbols(), exprs); - indexMap.push_back(operandMap); - return success(); -} - -//===----------------------------------------------------------------------===// -// ReverseOp -//===----------------------------------------------------------------------===// - -LogicalResult XLAReverseOpIndexPropagation::propagateIndexMap( - Operation *op, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &indexMap) const { - auto reverseOp = cast<xla_hlo::ReverseOp>(op); - DenseSet<unsigned> dimensions; - for (auto index : reverseOp.dimensions()) { - dimensions.insert(index.getZExtValue()); - } - return propagateIndexMapImpl(op, dimensions, resultIndex, indexMap); -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -LogicalResult XLATransposeOpIndexPropagation::propagateIndexMap( - Operation *op, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &indexMap) const { - auto transposeOp = cast<xla_hlo::TransposeOp>(op); - // Compute the affine map that represents the permutation. - SmallVector<unsigned, 4> permutation; - for (auto index : transposeOp.permutation()) { - permutation.push_back(index.getZExtValue()); - } - return propagateIndexMapImpl(op, permutation, resultIndex, indexMap); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h b/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h deleted file mode 100644 index a20b0f1..0000000 --- a/iree/compiler/Translation/SPIRV/XLAIndexPropagation.h +++ /dev/null
@@ -1,112 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//===- XLAIndexPropagation.h -----------------------------------*- C++//-*-===// -// -// For an IREE dispatch function in XLA-HLO dialect, compute the indices of all -// tensors needed to produce the value of the result tensors at a particlar -// index. -// -//===----------------------------------------------------------------------===// -#ifndef IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H -#define IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H - -#include "iree/compiler/Translation/SPIRV/IndexComputation.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Function.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -class XLABroadcastInDimOpIndexPropagation final - : public IndexPropagationOp<xla_hlo::BroadcastInDimOp> { - public: - using IndexPropagationOp<xla_hlo::BroadcastInDimOp>::IndexPropagationOp; - - LogicalResult propagateIndexMap( - Operation *operation, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &operandIndices) const override; -}; - -// For broadcast op, just drop the first N expressions of the resultIndex, where -// N is the number of elements in broadcast_sizes attribute. -class XLABroadcastOpIndexPropagation final - : public IndexPropagationOp<xla_hlo::BroadcastOp> { - public: - using IndexPropagationOp<xla_hlo::BroadcastOp>::IndexPropagationOp; - - LogicalResult propagateIndexMap( - Operation *operation, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &operandIndices) const override; -}; - -/// For return ops, it is assumed that each thread is computing the value of one -/// element of the returned tensor. -template <typename OpTy> -class ReturnOpIndexPropagation : public IndexPropagationOp<OpTy> { - public: - using IndexPropagationOp<OpTy>::IndexPropagationOp; - - LogicalResult propagateIndexMap( - Operation *operation, IndexComputationCache &indexMap) const override { - if (operation->getNumOperands() != 1) { - return operation->emitError("unhandled multiple return values"); - } - auto returnValue = operation->getOperand(0); - auto returnType = returnValue->getType().cast<RankedTensorType>(); - auto returnRank = returnType.getRank(); - if (returnRank > 3) { - return operation->emitError("unhandled return tensor of dimension ") - << returnType.getShape().size(); - } - // Have as many symbols as the rank of the input tensor. These symbols map - // to GlobalInvocationID along the three dimensions. - Builder builder(operation->getContext()); - SmallVector<AffineExpr, 4> affineExprs; - for (size_t i = returnRank; i > 0; --i) { - affineExprs.push_back(builder.getAffineDimExpr(i - 1)); - } - indexMap[operation->getOperand(0)] - [AffineMap::get(returnRank, 0, affineExprs)]; - return success(); - } -}; - -/// Index propogation for XLA Reverse. -class XLAReverseOpIndexPropagation final - : public ReverseOpIndexPropagation<xla_hlo::ReverseOp> { - public: - using ReverseOpIndexPropagation< - xla_hlo::ReverseOp>::ReverseOpIndexPropagation; - LogicalResult propagateIndexMap( - Operation *op, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &indexMap) const override; -}; - -/// Index propogation for XLA Transpose. -class XLATransposeOpIndexPropagation final - : public TransposeOpIndexPropagation<xla_hlo::TransposeOp> { - public: - using TransposeOpIndexPropagation< - xla_hlo::TransposeOp>::TransposeOpIndexPropagation; - LogicalResult propagateIndexMap( - Operation *op, AffineMap resultIndex, - SmallVectorImpl<AffineMap> &indexMap) const override; -}; - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SPIRV_XLAINDEXPROPOGATION_H
diff --git a/iree/compiler/Translation/SPIRV/test/BUILD b/iree/compiler/Translation/SPIRV/test/BUILD deleted file mode 100644 index 79700a8..0000000 --- a/iree/compiler/Translation/SPIRV/test/BUILD +++ /dev/null
@@ -1,16 +0,0 @@ -# Tests for common transforms. - -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_setup_lit_package( - data = [ - "//iree/tools:iree-opt", - ], -) - -iree_glob_lit_tests()
diff --git a/iree/compiler/Translation/Sequencer/BUILD b/iree/compiler/Translation/Sequencer/BUILD deleted file mode 100644 index 4245b19..0000000 --- a/iree/compiler/Translation/Sequencer/BUILD +++ /dev/null
@@ -1,33 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Sequencer", - srcs = ["SequencerModuleTranslation.cpp"], - hdrs = ["SequencerModuleTranslation.h"], - deps = [ - "//iree/base:status", - "//iree/compiler/IR", - "//iree/compiler/IR/Sequencer", - "//iree/compiler/Serialization", - "//iree/compiler/Transforms", - "//iree/compiler/Transforms/Sequencer", - "//iree/compiler/Utils", - "//iree/hal:executable_format", - "//iree/schemas", - "@com_github_google_flatbuffers//:flatbuffers", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:Support", - "@local_config_mlir//:Transforms", - "@local_config_mlir//:Translation", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_legalize_control_flow", - ], - alwayslink = 1, -)
diff --git a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp b/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp deleted file mode 100644 index 7184f6b..0000000 --- a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.cpp +++ /dev/null
@@ -1,505 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h" - -#include <cstdint> -#include <iostream> -#include <memory> -#include <unordered_map> -#include <vector> - -#include "flatbuffers/flatbuffers.h" -#include "flatbuffers/minireflect.h" -#include "iree/base/status.h" -#include "iree/compiler/IR/ConfigOps.h" -#include "iree/compiler/IR/Sequencer/OpWriters.h" -#include "iree/compiler/IR/StructureOps.h" -#include "iree/compiler/IR/Types.h" -#include "iree/compiler/Serialization/VMFunctionBuilder.h" -#include "iree/compiler/Serialization/VMFunctionTableBuilder.h" -#include "iree/compiler/Serialization/VMModuleBuilder.h" -#include "iree/compiler/Transforms/Passes.h" -#include "iree/compiler/Transforms/Sequencer/Passes.h" -#include "iree/compiler/Utils/Macros.h" -#include "iree/compiler/Utils/OpUtils.h" -#include "iree/compiler/Utils/TranslationUtils.h" -#include "iree/hal/executable_format.h" -#include "iree/schemas/executable_def_generated.h" -#include "iree/schemas/executable_table_def_generated.h" -#include "iree/schemas/module_def_generated.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ToolOutputFile.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Module.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/Passes.h" -#include "mlir/Translation.h" -#include "tensorflow/compiler/mlir/xla/transforms/passes.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Builds a pass pipeline that optimizes and legalizes the module to the form -// expected by partitioning. -void buildLegalizeInputPassPipeline(PassManager *passManager) { - // Convert to the subset of XLA HLO and Standard dialects supported as IREE - // input. In particular, move from XLA HLO to standard control flow. - passManager->addPass(xla_hlo::createLegalizeControlFlowPass()); - passManager->addPass(createLegalizeInputOpsPass()); - - // Standard passes that shake out a lot of garbage. - // Some may have been run prior to translation but this ensures we are always - // in a known state. - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createLoopFusionPass()); - passManager->addPass(createLoopInvariantCodeMotionPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createSimplifyAffineStructuresPass()); - passManager->addPass(createCSEPass()); - passManager->addPass(createCanonicalizerPass()); - - // Expand uses of tuples into independent args/results. - passManager->addPass(createConvertFromTupleCallingConventionPass()); - passManager->addPass(createCanonicalizerPass()); -} - -// Builds a pass pipeline that partitions the module into sequencer functions -// and executables ready to be translated. -void buildPartitioningPassPipeline(PassManager *passManager) { - // Find reduction ops and create iree.reduction_regions. We do this prior to - // performing dispatch region identification so that we can build as big of - // fused reduction regions as possible. The remaining ops will be put into - // dispatch regions. - passManager->addPass(createIdentifyReductionRegionsPass()); - passManager->addPass(createCSEPass()); - - // Create all of the dispatch regions, CSE their workloads, and fold. - passManager->addPass(createIdentifyDispatchRegionsPass()); - passManager->addPass(createCSEPass()); - passManager->addPass(createFoldCompatibleDispatchRegionsPass()); - - // Note that as we are rematerializing things here it's critical we do not run - // the canonicalizer/CSE between now and when we outline - otherwise it'll - // undo all of our work! - passManager->addPass(createRematerializeDispatchConstantsPass()); - - // Outline the dispatch regions into their own functions. This separates the - // sequencer functions performing dispatches from the dispatchees. - passManager->addPass(createOutlineDispatchRegionsPass()); - passManager->addPass(createOutlineReductionRegionsPass()); - - // Cleanup identity sequencer tensor-to-memref ops that clutter up the IR. - passManager->addPass(createCanonicalizerPass()); - - // Drop all functions that are no longer reachable. - // This is important as many of the functions remaining are probably - // dispatchable and unused now that we've outlined them executables. - passManager->addPass(createDropUnreachableModuleFunctionsPass()); - - // Drop all unused executables. - // Note that we need to have dropped unreachable functions first otherwise - // references could keep executables that are unreachable from exported - // functions alive. - passManager->addPass(createDropUnusedExecutablesPass()); -} - -// Builds a pass pipeline that converts sequencer functions to the iree_seq.hl -// dialect. -void buildSequencerConversionPassPipeline(PassManager *passManager) { - passManager->addPass(createConvertToMemRefCallingConventionPass()); - - // Convert ops that are supported by the sequencer directly to the sequencer - // dialect. The ops that remain should be only those that can be moved into - // dispatch regions. - passManager->addPass(createLowerToSequencerDialectPass()); - - // Cleanup identity sequencer tensor-to-memref ops and other memory accesses - // that clutter up the IR. - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - - // Eliminate ops we don't care about based on a lack of side-effects. - // IREE does not guarantee exception/error behavior of dead ops. - passManager->addPass(createAggressiveOpEliminationPass()); - - // Perform any last-minute optimizations to trim down the IR. - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - passManager->addPass(createCSEPass()); -} - -// Builds a pass pipeline that lowers the iree_seq.hl dialect to the iree_seq.ll -// dialect and prepares for serialization. -void buildSequencerLoweringPassPipeline(PassManager *passManager) { - // Lower iree_hl_seq -> iree_ll_seq. - passManager->addPass(createLowerSequencerDialectPass()); - passManager->addPass(createCanonicalizerPass()); - passManager->addPass(createMemRefDataFlowOptPass()); - passManager->addPass(createAggressiveOpEliminationPass()); - - // Assign ordinals used by the bytecode to reference executables and - // functions. - passManager->addPass(createAssignFunctionOrdinalsPass()); - passManager->addPass(createAssignExecutableOrdinalsPass()); - - // Plumb workload information down into executable entry points. This allows - // the backends to calculate their workgroup sizes, indexing, etc. - passManager->addPass(createAssignExecutableWorkloadAttrsPass()); -} - -// Inserts one or more iree.executable_target_config ops based on the -// translation options. -void insertTargetConfigOps(const ModuleTranslationOptions &options, - OpBuilder &builder) { - llvm::StringSet<> targetBackends; - if (options.target_backends.empty()) { - // Add all backends when none are explicitly provided. - targetBackends.insert(getExecutableTranslationRegistry().keys().begin(), - getExecutableTranslationRegistry().keys().end()); - } else { - for (auto &targetBackend : options.target_backends) { - for (auto &matchedBackend : - matchExecutableTranslationBackendNames(targetBackend)) { - targetBackends.insert(matchedBackend); - } - } - } - for (auto &targetBackend : targetBackends) { - builder.create<IREE::ExecutableTargetConfigOp>(builder.getUnknownLoc(), - targetBackend.getKey()); - } -} - -class SequencerTranslator { - public: - explicit SequencerTranslator(ModuleTranslationOptions options) - : options_(options) {} - - const ModuleTranslationOptions &options() const { return options_; } - - std::vector<uint8_t> translateModule(ModuleOp module); - - private: - LogicalResult translateMultiArchExecutable( - IREE::MultiArchExecutableOp executableOp, VMModuleBuilder *moduleBuilder); - - LogicalResult translateSequencerModule(ModuleOp module, - VMModuleBuilder *moduleBuilder); - LogicalResult declareFunction(FuncOp function, - VMModuleBuilder *moduleBuilder); - LogicalResult defineFunction(FuncOp function, VMModuleBuilder *moduleBuilder); - - ModuleTranslationOptions options_; -}; - -std::vector<uint8_t> SequencerTranslator::translateModule(ModuleOp module) { - // Run one large set of passes to get to a partitioned module. - auto partitioningPasses = createPassManager(module.getContext(), options()); - buildLegalizeInputPassPipeline(partitioningPasses.get()); - buildPartitioningPassPipeline(partitioningPasses.get()); - if (failed(runPassPipeline(options(), partitioningPasses.get(), module))) { - module.emitError() << "Failed to run partitioning passes"; - return {}; - } - - // Run the sequencer-specific conversion passes on the module. - auto sequencerConversionPasses = - createPassManager(module.getContext(), options()); - buildSequencerConversionPassPipeline(sequencerConversionPasses.get()); - if (failed(runPassPipeline(options(), sequencerConversionPasses.get(), - module))) { - module.emitError() << "Failed to run sequencer conversion passes"; - return {}; - } - - // Lower sequencer functions to their final form. - auto sequencerLoweringPasses = - createPassManager(module.getContext(), options()); - buildSequencerLoweringPassPipeline(sequencerLoweringPasses.get()); - if (failed( - runPassPipeline(options(), sequencerLoweringPasses.get(), module))) { - module.emitError() << "Failed to run sequencer lowering passes"; - return {}; - } - - // Perform translation on all executables. - // We then know exactly what executable formats we have and can query them to - // see if we need to do any additional processing (such as to support better - // types/etc). - ::flatbuffers::FlatBufferBuilder fbb; - VMModuleBuilder moduleBuilder(&fbb); - for (auto multiArchExecutableOp : - module.getOps<IREE::MultiArchExecutableOp>()) { - if (failed(translateMultiArchExecutable(multiArchExecutableOp, - &moduleBuilder))) { - module.emitError() << "Failed to translate multi-arch-executable"; - return {}; - } - } - - // Build the module bytecode. - if (failed(translateSequencerModule(module, &moduleBuilder))) { - module.emitError() << "Unable to translate sequencer module"; - return {}; - } - auto moduleDef = moduleBuilder.Finish(); - if (moduleDef.IsNull()) { - module.emitError() << "Failed to verify completed module def"; - return {}; - } - auto bytes = moduleBuilder.Serialize(moduleDef); - if (bytes.empty()) { - module.emitError() << "Failed to serialize final module def"; - return {}; - } - return bytes; -} - -LogicalResult SequencerTranslator::translateMultiArchExecutable( - IREE::MultiArchExecutableOp multiArchExecutableOp, - VMModuleBuilder *moduleBuilder) { - auto &fbb = *moduleBuilder->fbb(); - - // Find the unspecified executable. This is the template from which we will - // translate to other targets. - IREE::ExecutableOp templateExecutableOp; - for (auto executableOp : - multiArchExecutableOp.getBlock().getOps<IREE::ExecutableOp>()) { - if (executableOp.format() == - static_cast<uint32_t>(IREE::ExecutableFormat::Unspecified)) { - templateExecutableOp = executableOp; - break; - } - } - if (!templateExecutableOp) { - // Fine for there to be no unspecified executable - just ignore. - return success(); - } - int entryPointCount = 0; - for (auto func : templateExecutableOp.getInnerModule().getOps<FuncOp>()) { - if (func.getAttr("iree.executable.export")) { - ++entryPointCount; - } - } - - // For now we just add target config ops based on options. In the future we - // could do this earlier via an analysis pass determining which targets should - // be used for each executable. - OpBuilder configBuilder(templateExecutableOp); - configBuilder.setInsertionPointToStart(&templateExecutableOp.getBlock()); - insertTargetConfigOps(options(), configBuilder); - - // Find all target configs and bucket them into the backends that will - // translate them. This way we can batch the translations and possibly enable - // backends to dedupe some things. - DenseMap<StringRef, std::vector<IREE::ExecutableTargetConfigOp>> - backendTargetConfigOps; - for (auto targetConfigOp : templateExecutableOp.getBlock() - .getOps<IREE::ExecutableTargetConfigOp>()) { - auto &targetConfigOps = backendTargetConfigOps[targetConfigOp.backend()]; - targetConfigOps.push_back(targetConfigOp); - } - if (backendTargetConfigOps.empty()) { - // There are no target configs - which likely means we've already translated - // this in a previous pass. - return success(); - } - - ExecutableTranslationOptions translationOptions; - translationOptions.CopyFrom(options()); - - // Invoke each backend translator on the template executables to produce new - // executables. The backends may produce any number of executables that we - // then merge back in to the iree.multi_arch_executable and the module - // flatbuffer. - std::vector<std::unique_ptr<iree::ExecutableDefT>> translatedExecutableDefs; - for (auto it : backendTargetConfigOps) { - const auto &backendKey = it.first; - const auto &targetConfigOps = it.second; - - // Find the translator to use in the registry. It must have been linked in - // and the name must match what is used in the registration macro. - auto translateExecutableFn = - getExecutableTranslationRegistry().lookup(backendKey); - if (!translateExecutableFn) { - return multiArchExecutableOp.emitError() - << "No registered backend found for target '" << backendKey.str() - << "'; ensure it is linked in to your binary (have: " - << llvm::join(getExecutableTranslationRegistry().keys(), ", ") - << ")"; - } - - // Clone the executable for each config so that the translator is allowed to - // modify it in-place. - // We also need to strip all of the other configs so that the translator - // backend only sees the one for each of its configs. - OpBuilder builder(&multiArchExecutableOp.getBlock()); - builder.setInsertionPoint(multiArchExecutableOp.getBlock().getTerminator()); - SmallVector<IREE::ExecutableOp, 4> clonedExecutableOps; - for (auto targetConfigOp : targetConfigOps) { - auto executableCloneOp = cast<IREE::ExecutableOp>( - builder.clone(*templateExecutableOp.getOperation())); - for (auto existingTargetConfigOp : llvm::make_early_inc_range( - executableCloneOp.getBlock() - .getOps<IREE::ExecutableTargetConfigOp>())) { - existingTargetConfigOp.erase(); - } - OpBuilder configBuilder(executableCloneOp); - configBuilder.setInsertionPointToStart(&executableCloneOp.getBlock()); - configBuilder.clone(*targetConfigOp.getOperation()); - clonedExecutableOps.push_back(executableCloneOp); - } - - // Perform translation on all of the backend-specific targets. - // Note that the results here may not have the same number of executables we - // started with if the backend either couldn't satisfy some of the requests - // or decided to dedupe or expand certain ones. - auto translationResults = - translateExecutableFn(clonedExecutableOps, translationOptions); - if (!translationResults.hasValue()) { - return multiArchExecutableOp.emitError() - << "Failed to translate executable with backend " << backendKey; - } - for (auto &executableDef : translationResults.getValue().executable_defs) { - translatedExecutableDefs.push_back(std::move(executableDef)); - } - } - - // Remove configs from the template executable so that if we are called again - // we don't re-translate. - for (auto targetConfigOp : llvm::make_early_inc_range( - templateExecutableOp.getBlock() - .getOps<IREE::ExecutableTargetConfigOp>())) { - targetConfigOp.erase(); - } - - // Create multi-arch executable with all of the target-specific executables. - iree::MultiArchExecutableDefT maedf; - maedf.name = multiArchExecutableOp.getName(); - maedf.entry_point_count = entryPointCount; - maedf.executables = std::move(translatedExecutableDefs); - auto maedfOffset = iree::MultiArchExecutableDef::Pack(fbb, &maedf); - RETURN_IF_FAILURE( - moduleBuilder->executable_table()->AddMultiArchExecutable(maedfOffset)); - - return success(); -} - -LogicalResult SequencerTranslator::translateSequencerModule( - ModuleOp module, VMModuleBuilder *moduleBuilder) { - // Declare functions. This must happen first so that we get stable indices - // during declaration (as call ops need to use the function table). - for (auto function : module.getOps<FuncOp>()) { - RETURN_IF_FAILURE(declareFunction(function, moduleBuilder)); - } - - // Define functions and convert their bodies to bytecode. - for (auto function : module.getOps<FuncOp>()) { - RETURN_IF_FAILURE(defineFunction(function, moduleBuilder)); - } - - return success(); -} - -LogicalResult SequencerTranslator::declareFunction( - FuncOp function, VMModuleBuilder *moduleBuilder) { - auto *functionTable = moduleBuilder->function_table(); - if (functionTable->IsFunctionDeclared(function)) { - // Already declared. - return success(); - } - - LinkageType linkageType; - if (function.isExternal()) { - linkageType = LinkageType::kImport; - } else if (function.getAttr("iree.module.export")) { - linkageType = LinkageType::kExport; - } else { - linkageType = LinkageType::kInternal; - } - if (failed(functionTable->DeclareFunction(function, linkageType))) { - return function.emitError() - << "Unable to declare function " << function.getName(); - } - - // Import functions must have their definition defined here so we get their - // type. Internal and export functions will be defined during conversion. - if (linkageType == LinkageType::kImport) { - VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), - moduleBuilder->fbb()); - auto functionOffset = functionBuilder.Finish(); - if (functionOffset.IsNull()) { - return function.emitError() - << "Failed to create import function bytecode"; - } - RETURN_IF_FAILURE( - functionTable->DefineFunction(function, functionOffset, {})); - } - - return success(); -} - -LogicalResult SequencerTranslator::defineFunction( - FuncOp function, VMModuleBuilder *moduleBuilder) { - VMFunctionBuilder functionBuilder(function, moduleBuilder->function_table(), - moduleBuilder->fbb()); - registerSequencerCustomWriters(&functionBuilder); - RETURN_IF_FAILURE(functionBuilder.ConvertBytecode()); - auto functionOffset = functionBuilder.Finish(); - if (functionOffset.IsNull()) { - return function.emitError() << "Failed to convert function to bytecode"; - } - RETURN_IF_FAILURE(moduleBuilder->function_table()->DefineFunction( - function, functionOffset, functionBuilder.source_map())); - return success(); -} - -} // namespace - -std::vector<uint8_t> translateMlirToIreeSequencerModule( - ModuleOp module, ModuleTranslationOptions options) { - SequencerTranslator translator(options); - return translator.translateModule(module); -} - -LogicalResult translateMlirToIreeSequencerModuleFile( - ModuleOp module, llvm::raw_ostream &output) { - ModuleTranslationOptions options; - SequencerTranslator translator(options); - auto bytecodeModule = translator.translateModule(module); - if (bytecodeModule.empty()) { - return emitError(UnknownLoc::get(module.getContext()), - "failed to translate module"); - } - - output.write(reinterpret_cast<const char *>(bytecodeModule.data()), - bytecodeModule.size()); - return success(); -} - -static TranslateFromMLIRRegistration MlirToIreeSequencerModuleTranslate( - "mlir-to-iree-module", translateMlirToIreeSequencerModuleFile); - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h b/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h deleted file mode 100644 index 0b861a3..0000000 --- a/iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h +++ /dev/null
@@ -1,36 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_ -#define IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_ - -#include <vector> - -#include "iree/compiler/Utils/TranslationUtils.h" -#include "mlir/IR/Module.h" - -namespace mlir { -namespace iree_compiler { - -// Translates an MLIR module in a compatible IREE input dialect (such as XLA HLO -// and/or Std) into an IREE Module. Executables will be lowered based on the -// provided configuration. -// Returns an empty vector on translation failure. -std::vector<uint8_t> translateMlirToIreeSequencerModule( - ModuleOp module, ModuleTranslationOptions options = {}); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_TRANSLATION_SEQUENCER_SEQUENCERMODULETRANSLATION_H_
diff --git a/iree/compiler/Utils/BUILD b/iree/compiler/Utils/BUILD deleted file mode 100644 index 8ff56d5..0000000 --- a/iree/compiler/Utils/BUILD +++ /dev/null
@@ -1,41 +0,0 @@ -# Utilities for working with IREE MLIR types. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "Utils", - srcs = [ - "DispatchUtils.cpp", - "MemRefUtils.cpp", - "ModuleUtils.cpp", - "OpCreationUtils.cpp", - "OpUtils.cpp", - "TranslationUtils.cpp", - "TypeConversionUtils.cpp", - ], - hdrs = [ - "DispatchUtils.h", - "Macros.h", - "MemRefUtils.h", - "ModuleUtils.h", - "OpCreationUtils.h", - "OpUtils.h", - "TranslationUtils.h", - "TypeConversionUtils.h", - ], - deps = [ - "//iree/compiler/IR", - "//iree/schemas", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:Transforms", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - ], -)
diff --git a/iree/compiler/Utils/DispatchUtils.cpp b/iree/compiler/Utils/DispatchUtils.cpp deleted file mode 100644 index 9e3fb12..0000000 --- a/iree/compiler/Utils/DispatchUtils.cpp +++ /dev/null
@@ -1,726 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/DispatchUtils.h" - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/Utils/TypeConversionUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Utils.h" -#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" - -namespace mlir { -namespace iree_compiler { - -Value *calculateWorkload(Operation *op, Value *baseOperand) { - OpBuilder builder(op); - - std::array<int32_t, 3> workload = {1, 1, 1}; - - // TODO(b/139353314): lookup/calculate based on type/etc. - auto resultType = baseOperand->getType(); - if (auto shapedType = resultType.dyn_cast<ShapedType>()) { - if (!shapedType.hasStaticShape()) { - op->emitOpError() << "Dynamic shapes not yet supported"; - return nullptr; - } - auto shape = shapedType.getShape(); - // Drop the trailing ones from the shape. - while (shape.size() > 1 && shape.back() == 1) { - shape = shape.drop_back(); - } - if (shape.size() <= 3) { - // Maps to XYZ (possibly with 1's for unused dimensions). - for (auto dim : enumerate(shape)) { - workload[shape.size() - 1 - dim.index()] = dim.value(); - } - } else { - // Need to flatten the shape to fit XYZ. For now we just squash from LHS. - workload[2] = 1; - for (int i = 0; i < shape.size(); ++i) { - workload[2] *= shape[i]; - } - workload[1] = shape[shape.size() - 2]; - workload[0] = shape.back(); - } - } - - // TODO(b/139353314): optimize workload layout. - - auto constantType = RankedTensorType::get({3}, builder.getIntegerType(32)); - return builder.create<ConstantOp>( - op->getLoc(), constantType, - DenseIntElementsAttr::get<int32_t>(constantType, workload)); -} - -bool isTriviallyDispatchable(FuncOp func) { - if (func.empty()) return false; - auto &block = func.front(); - if (block.getOperations().size() != 2) return false; - auto &op0 = block.front(); - auto &op1 = block.back(); - auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op0); - auto returnOp = dyn_cast<ReturnOp>(op1); - if (!regionOp || !returnOp || - regionOp.getNumResults() != returnOp.getNumOperands()) { - return false; - } - for (int i = 0; i < regionOp.getNumResults(); ++i) { - if (regionOp.getResult(i) != returnOp.getOperand(i)) return false; - } - return true; -} - -namespace { - -// Returns the set of values that must be captured for use by |ops| and the -// set of values defined by |ops| that are used outside of the set. -LogicalResult analyzeOpRangeValues( - const llvm::SmallDenseSet<Operation *> &opSet, - llvm::SetVector<Value *> *capturedValues, - llvm::SetVector<Value *> *escapingValues) { - for (auto *op : opSet) { - for (auto *value : op->getOperands()) { - if (!llvm::is_contained(opSet, value->getDefiningOp())) { - // Op is using a value not in the ops set, ensure we capture it. - capturedValues->insert(value); - } - } - for (auto *value : op->getResults()) { - for (auto &use : value->getUses()) { - if (!llvm::is_contained(opSet, use.getOwner())) { - // An op outside of the ops set is using the value, needs to escape. - escapingValues->insert(value); - } - } - } - } - return success(); -} - -} // namespace - -LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, - Value *workload, ArrayRef<Operation *> ops) { - // Fused location with all ops. - SmallVector<Location, 16> opLocs; - for (auto *op : ops) { - opLocs.push_back(op->getLoc()); - } - auto regionLoc = FusedLoc::get(opLocs, func.getContext()); - - // Get a list of values that we need to capture and values that escape the - // region and need to be returned. - llvm::SmallDenseSet<Operation *> opSet; - opSet.reserve(ops.size()); - opSet.insert(ops.begin(), ops.end()); - llvm::SetVector<Value *> capturedValues; - llvm::SetVector<Value *> escapingValues; - if (failed(analyzeOpRangeValues(opSet, &capturedValues, &escapingValues))) { - return failure(); - } - SmallVector<Type, 8> escapingTypes; - for (auto *value : escapingValues) escapingTypes.push_back(value->getType()); - - // Build the region op and add it to the parent block. - OpBuilder parentBuilder(parentBlock); - parentBuilder.setInsertionPoint(ops.back()); - auto dispatchRegionOp = parentBuilder.create<IREE::DispatchRegionOp>( - regionLoc, escapingTypes, workload, capturedValues.getArrayRef()); - - // Create the block and setup the arg mapping for captured values. - auto *regionBlock = new Block(); - dispatchRegionOp.getBody().push_back(regionBlock); - OpBuilder regionBuilder(regionBlock); - BlockAndValueMapping mapping; - for (auto *capturedValue : capturedValues) { - auto *blockArg = regionBlock->addArgument(capturedValue->getType()); - mapping.map(capturedValue, blockArg); - } - - // Clone ops into the new region block. - for (auto *op : ops) { - // Note that this updates the mapping with the new values (so at the end - // we have those new values). - regionBuilder.clone(*op, mapping); - } - - // Return results (as we need a terminator in our block). - // These are all of the values that escape our region. - SmallVector<Value *, 8> resultValues; - for (auto *oldValue : escapingValues) { - resultValues.push_back(mapping.lookupOrDefault(oldValue)); - } - regionBuilder.create<IREE::ReturnOp>(opLocs.back(), resultValues); - - // Replace usage of values with the results of the region. - for (int i = 0; i < escapingValues.size(); ++i) { - escapingValues[i]->replaceAllUsesWith(dispatchRegionOp.getResult(i)); - } - - // Remove original ops from the parent region. - for (auto it = ops.rbegin(); it != ops.rend(); ++it) { - (*it)->erase(); - } - - return success(); -} - -namespace { - -// Replaces |returnOp| with a clone including |newOperands| appended. -LogicalResult appendReturnOperands(IREE::ReturnOp returnOp, - ArrayRef<Value *> newOperands) { - // Insert prior to the original return. - OpBuilder builder(returnOp); - - // Clone with new args. - SmallVector<Value *, 8> operands; - operands.reserve(returnOp.getNumOperands() + newOperands.size()); - operands.append(returnOp.operand_begin(), returnOp.operand_end()); - operands.append(newOperands.begin(), newOperands.end()); - builder.create<IREE::ReturnOp>(returnOp.getLoc(), operands); - - // Remove original. - returnOp.erase(); - - return success(); -} - -// Replaces |regionOp| with a clone including |newArgs| and |newResults|. -IREE::DispatchRegionOp appendRegionArgsAndResults( - IREE::DispatchRegionOp ®ionOp, ArrayRef<Value *> newArgs, - ArrayRef<Value *> newResults, Location otherLoc) { - // Insert prior to the original region. - OpBuilder builder(regionOp); - - // Location is original region + new region location (both probably fused). - SmallVector<Location, 2> fusedLocs = {regionOp.getLoc(), otherLoc}; - auto fusedLoc = FusedLoc::get(fusedLocs, regionOp.getContext()); - - // Clone with new results. - SmallVector<Value *, 8> operands; - operands.append(regionOp.getArgOperands().begin(), - regionOp.getArgOperands().end()); - operands.append(newArgs.begin(), newArgs.end()); - SmallVector<Type, 8> resultTypes; - resultTypes.append(regionOp.result_type_begin(), regionOp.result_type_end()); - for (auto *newResult : newResults) { - resultTypes.push_back(newResult->getType()); - } - auto newRegionOp = builder.create<IREE::DispatchRegionOp>( - fusedLoc, resultTypes, regionOp.getWorkload(), operands, - regionOp.getAttrs()); - newRegionOp.getBody().takeBody(regionOp.getBody()); - - // Replace uses of original values with the new values. - for (int i = 0; i < regionOp.getNumResults(); ++i) { - regionOp.getResult(i)->replaceAllUsesWith(newRegionOp.getResult(i)); - } - - // Erase the original region. - regionOp.erase(); - - return newRegionOp; -} - -// Removes results that are not used from the dispatch region. -// Returns the new operation. There may be unused ops in the region but DCE -// should take care of that later. -IREE::DispatchRegionOp removeUnusedResults(IREE::DispatchRegionOp regionOp) { - // Find return value within the region. - auto ®ionBlock = regionOp.getBody().getBlocks().front(); - auto returnOp = dyn_cast<IREE::ReturnOp>(regionBlock.getTerminator()); - if (!returnOp) { - regionBlock.getParent()->getParentOfType<FuncOp>().emitError() - << "Block does not contain an iree.return op"; - } - - // Calculate new return values. - SmallVector<Type, 8> newReturnTypes; - SmallVector<Value *, 8> newReturnValues; - SmallVector<Value *, 8> newRegionResults; - for (int i = 0; i < returnOp.getNumOperands(); ++i) { - auto *resultValue = regionOp.getResult(i); - if (!resultValue->use_empty()) { - // Still has uses so we will preserve it. - newReturnTypes.push_back(resultValue->getType()); - newReturnValues.push_back(returnOp.getOperand(i)); - newRegionResults.push_back(resultValue); - } - } - - // Update return op operands. We can do this in-place as we are only shrinking - // the list. - returnOp.getOperation()->setOperands(newReturnValues); - - // Insert prior to the original region. - OpBuilder builder(regionOp); - - // Clone with new results. - SmallVector<Value *, 8> operands(regionOp.getArgOperands()); - auto newRegionOp = builder.create<IREE::DispatchRegionOp>( - regionOp.getLoc(), newReturnTypes, regionOp.getWorkload(), operands, - regionOp.getAttrs()); - newRegionOp.getBody().takeBody(regionOp.getBody()); - - // Replace uses of original values with the new values. - for (int i = 0; i < newRegionResults.size(); ++i) { - newRegionResults[i]->replaceAllUsesWith(newRegionOp.getResult(i)); - } - - // Erase the original region. - regionOp.erase(); - - return newRegionOp; -} - -// Returns true if |lhs| and |rhs| have either an identical workload or one that -// is compatible. -bool areDispatchRegionWorkloadsCompatible(IREE::DispatchRegionOp &lhs, - IREE::DispatchRegionOp &rhs) { - // TODO(benvanik): more sophisticated checking; right now it's just identical. - return lhs.getWorkload() == rhs.getWorkload(); -} - -// Returns true if |value| depends in any way on |op| through any path. -// Only works if the operations are within the same block. -bool doesValueDependOnOperation(Value *value, Operation *op) { - if (!value->getDefiningOp()) { - return false; - } else if (value->getDefiningOp() == op) { - return true; - } else if (value->getDefiningOp()->isBeforeInBlock(op)) { - // Can't depend on |op| as it is defined prior to it. - return false; - } - for (auto *operand : value->getDefiningOp()->getOperands()) { - if (doesValueDependOnOperation(operand, op)) { - return true; - } - } - return true; -} - -// Returns true if |rhs| transitively depends on any out of |lhs|. -// |rhs| may depend directly on the results of |lhs| but no other ops in the -// parent block will use the results prior to |rhs|. -bool areDispatchRegionsTransitivelyDependent(IREE::DispatchRegionOp &lhs, - IREE::DispatchRegionOp &rhs) { - for (auto *arg : rhs.getArgOperands()) { - if (arg->getDefiningOp() != lhs && doesValueDependOnOperation(arg, lhs)) { - // Transitively dependent - boo - can't merge yet. - return true; - } - } - return false; -} - -// Returns true if the dispatch region contains only a single block. -// This is because our merge isn't very smart and will not preserve the CFG -// right now. We can fix this when needed. -bool isDispatchRegionMergable(IREE::DispatchRegionOp ®ionOp) { - // Disallow merging of dispatch regions containing matmuls and other big ops. - // We do this to allow backends to lower the big op as entirely isolated such - // that substituting library calls is easier. - for (auto &block : regionOp.getBody().getBlocks()) { - for (auto &op : block) { - if (isa<xla_hlo::DotOp>(op) || isa<xla_hlo::ConvOp>(op)) { - return false; - } - } - } - return regionOp.getBody().getBlocks().size() == 1; -} - -// Merges |rhs| into |lhs| and returns the new |lhs| op. -// Precondition: !areDispatchRegionsTransitivelyDependent -IREE::DispatchRegionOp mergeDispatchRegions(IREE::DispatchRegionOp &lhs, - IREE::DispatchRegionOp &rhs) { - auto &lhsBlock = lhs.getBody().front(); - auto &rhsBlock = rhs.getBody().front(); - - // Find the values used as return values in the lhs. - // We'll need to replace the uses in rhs with these. - auto lhsReturnOp = cast<IREE::ReturnOp>(lhsBlock.getTerminator()); - SmallVector<Value *, 8> lhsReturnValues; - lhsReturnValues.reserve(lhsReturnOp.getNumOperands()); - lhsReturnValues.append(lhsReturnOp.operand_begin(), - lhsReturnOp.operand_end()); - - // Find the values used as return values in the rhs. - // We'll add these to the results of the lhs region. - auto rhsReturnOp = cast<IREE::ReturnOp>(rhsBlock.getTerminator()); - SmallVector<Value *, 8> rhsReturnValues; - rhsReturnValues.reserve(rhsReturnOp.getNumOperands()); - rhsReturnValues.append(rhsReturnOp.operand_begin(), - rhsReturnOp.operand_end()); - - // Compute new args. - BlockAndValueMapping mapping; - SmallVector<Value *, 8> newArgs; - for (int rhsOpIdx = 0; rhsOpIdx < rhs.getNumArgOperands(); ++rhsOpIdx) { - bool didElide = false; - // Find if the rhs arg already exists on the lhs and dedupe. - for (int lhsOpIdx = 0; lhsOpIdx < lhs.getNumArgOperands(); ++lhsOpIdx) { - if (rhs.getArgOperand(rhsOpIdx) == lhs.getArgOperand(lhsOpIdx)) { - mapping.map(rhsBlock.getArgument(rhsOpIdx), - lhsBlock.getArgument(lhsOpIdx)); - didElide = true; - break; - } - } - // Find if the arg has a direct dependency on the results of the lhs. - for (int lhsResultIdx = 0; lhsResultIdx < lhs.getNumResults(); - ++lhsResultIdx) { - if (rhs.getArgOperand(rhsOpIdx) == lhs.getResult(lhsResultIdx)) { - // Direct dependency; can elide. We'll skip adding it to the new region - // args and instead just remap it later. - mapping.map(rhsBlock.getArgument(rhsOpIdx), - lhsReturnValues[lhsResultIdx]); - didElide = true; - break; - } - } - if (!didElide) { - // Add to the lhs block. - auto *oldArg = rhs.getOperand(rhsOpIdx + 1); - auto *newArg = lhsBlock.addArgument(oldArg->getType()); - mapping.map(rhsBlock.getArgument(rhsOpIdx), newArg); - newArgs.push_back(oldArg); - } - } - - OpBuilder regionBuilder(&lhsBlock); - - // Copy ops (replacing any args as needed). - // Note that we need to insert prior to the terminator. - regionBuilder.setInsertionPoint(lhsReturnOp); - for (auto &op : rhsBlock) { - // Note that this updates the mapping with the new values (so at the end - // we have those new values). - // - // We avoid the return op here as we have already merged it above. - if (!op.isKnownTerminator()) { - regionBuilder.clone(op, mapping); - } - } - - // Compute new results and add to both region and return op. - SmallVector<Value *, 8> newResults; - for (auto *rhsResult : rhsReturnValues) { - newResults.push_back(mapping.lookupOrDefault(rhsResult)); - } - if (failed(appendReturnOperands(lhsReturnOp, newResults))) { - return nullptr; - } - auto newRegionOp = - appendRegionArgsAndResults(lhs, newArgs, newResults, rhs.getLoc()); - - // Replace uses of original values with the new values. - for (int i = 0; i < rhs.getNumResults(); ++i) { - rhs.getResult(i)->replaceAllUsesWith( - newRegionOp.getResult(lhsReturnValues.size() + i)); - } - - // Remove rhs region. - rhs.erase(); - - // Remove results from the lhs that aren't used anymore as they may have been - // elided when we merged as only the rhs was using them. - newRegionOp = removeUnusedResults(newRegionOp); - - return newRegionOp; -} - -} // namespace - -LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock) { - SmallVector<IREE::DispatchRegionOp, 8> mergableRegions; - for (auto &op : *parentBlock) { - if (auto regionOp = dyn_cast<IREE::DispatchRegionOp>(op)) { - if (isDispatchRegionMergable(regionOp)) { - mergableRegions.push_back(regionOp); - } else { - regionOp.emitRemark( - "Unable to merge into following iree.dispatch_regions; " - "contains non-trivial control flow"); - } - } - } - for (int i = 0; i < mergableRegions.size(); ++i) { - if (!mergableRegions[i]) continue; - auto &lhs = mergableRegions[i]; - for (int j = i + 1; j < mergableRegions.size(); ++j) { - if (!mergableRegions[j]) continue; - auto &rhs = mergableRegions[j]; - if (!areDispatchRegionWorkloadsCompatible(lhs, rhs) || - areDispatchRegionsTransitivelyDependent(lhs, rhs)) { - continue; - } - if (!isDispatchRegionMergable(rhs)) { - // TODO(b/134675461): support non-trivial control flow. - rhs.emitRemark( - "Unable to merge into previous iree.dispatch_region; " - "contains non-trivial control flow"); - } - mergableRegions[i] = mergeDispatchRegions(lhs, rhs); - if (!mergableRegions[i]) { - return failure(); - } - mergableRegions[j] = nullptr; - --i; // Try again to see if there are subsequent regions to merge. - break; - } - } - - return success(); -} - -namespace { - -// Recursively clones the given |sourceOp| and returns the newly cloned op. -Operation *recursivelyCloneOp(Operation *sourceOp, OpBuilder &builder, - BlockAndValueMapping *mapping) { - // Note that we dedupe required operands in the case of multiple arguments - // coming from the same source operation. - SmallPtrSet<Operation *, 4> operandOps; - for (auto *operand : sourceOp->getOperands()) { - operandOps.insert(operand->getDefiningOp()); - } - for (auto *operandOp : operandOps) { - recursivelyCloneOp(operandOp, builder, mapping); - } - return builder.clone(*sourceOp, *mapping); -} - -// Clones the |sourceValue| op tree into |targetBlock|. -// |mapping| is used to lookup existing values that may be present in the block -// such as block arguments or already cloned ancestor ops. |mapping| will be -// updated as the tree is cloned. -Value *cloneOpTreeIntoBlock(Value *sourceValue, Block *targetBlock, - BlockAndValueMapping *mapping) { - // If the op has already been cloned we can just reuse that. - // This happens if multiple arguments reference the same trees. - if (auto *existingValue = mapping->lookupOrNull(sourceValue)) { - return existingValue; - } - - OpBuilder builder(targetBlock); - builder.setInsertionPointToStart(targetBlock); - auto *sourceOp = sourceValue->getDefiningOp(); - auto *clonedOp = recursivelyCloneOp(sourceOp, builder, mapping); - - // Return only the result matching our source value (in the case of multiple - // results). - int resultIndex = std::distance( - sourceOp->result_begin(), - std::find(sourceOp->result_begin(), sourceOp->result_end(), sourceValue)); - return clonedOp->getResult(resultIndex); -} - -} // namespace - -LogicalResult inlineDispatchRegionOperandsUsingValue( - IREE::DispatchRegionOp dispatchRegionOp, Value *value) { - // Find all args that are using this value. - SmallVector<unsigned, 4> argIndices; - for (auto arg : llvm::enumerate(dispatchRegionOp.getArgOperands())) { - if (arg.value() == value) { - argIndices.push_back(arg.index()); - } - } - if (argIndices.empty()) { - // Not used? Wasteful call! - return success(); - } - - // Clone the value (and the ops required to create it) into the entry block. - auto &entryBlock = dispatchRegionOp.getBody().getBlocks().front(); - BlockAndValueMapping mapping; - auto *clonedValue = cloneOpTreeIntoBlock(value, &entryBlock, &mapping); - - // Replace all uses of the inner operand with the new value. - for (unsigned argIndex : argIndices) { - entryBlock.getArgument(argIndex)->replaceAllUsesWith(clonedValue); - } - - // Remove the dispatch region args and the block args that have been - // replaced. - for (unsigned argIndex : llvm::reverse(argIndices)) { - dispatchRegionOp.getOperation()->eraseOperand( - dispatchRegionOp.mapArgOperandToOpOperand(argIndex)); - entryBlock.eraseArgument(argIndex); - } - - return success(); -} - -namespace { - -// Recursively finds all reachable functions from the given |rootFunc| and adds -// them to the |reachableFuncs| set. -// -// Note that indirect calls are not supported, however we don't allow those in -// dispatch regions anyway so they should not be present here. -LogicalResult findReachableFunctions(Operation *rootFunc, - llvm::SetVector<FuncOp> &reachableFuncs) { - bool allCallsValid = true; - rootFunc->walk([&](CallOp op) { - auto callee = rootFunc->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>( - op.getCallee()); - if (!callee.getAttr("iree.dispatchable")) { - allCallsValid = false; - rootFunc->emitError() << callee.getName().str() << " is not dispatchable"; - return; - } - if (reachableFuncs.insert(callee)) { - findReachableFunctions(callee, reachableFuncs); - } - }); - return success(allCallsValid); -} - -} // namespace - -std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable( - Operation *op, FunctionType functionType, StringRef symbolSuffix) { - // Create the function and take the region body directly. - // NOTE: this will get uniquified if we have multiple in the same block. - auto parentFunc = op->getParentOfType<FuncOp>(); - std::string functionName = - (parentFunc.getName().str() + "_rgn" + symbolSuffix).str(); - auto outlinedFunc = FuncOp::create(op->getLoc(), functionName, functionType); - BlockAndValueMapping mapping; - op->getRegion(0).cloneInto(&outlinedFunc.getBody(), mapping); - - // Gather all reachable functions. - llvm::SetVector<FuncOp> reachableFuncs; - findReachableFunctions(outlinedFunc, reachableFuncs); - - // Create the multi-arch executable that will contain the outlined region. - // NOTE: this will get uniquified if we have multiple in the same block. - auto parentModule = parentFunc.getParentOfType<ModuleOp>(); - OpBuilder parentModuleBuilder(parentModule); - parentModuleBuilder.setInsertionPoint(parentFunc); - std::string executableName = - (parentFunc.getName().str() + "_ex" + symbolSuffix).str(); - auto multiArchExecutable = - parentModuleBuilder.create<IREE::MultiArchExecutableOp>( - outlinedFunc.getLoc(), executableName); - - // Create the executable op initially unspecified so that later - // transformations can compile it to various formats. - OpBuilder multiArchExecutableBuilder(multiArchExecutable); - multiArchExecutableBuilder.setInsertionPointToStart( - &multiArchExecutable.getBlock()); - auto executable = multiArchExecutableBuilder.create<IREE::ExecutableOp>( - outlinedFunc.getLoc(), IREE::ExecutableFormat::Unspecified); - - // Create the inner ModuleOp that contains the original functions. We need - // to provide this shim as some ops (like std.call) look for the - // containing module to provide symbol resolution. - OpBuilder executableBuilder(executable); - executableBuilder.setInsertionPointToStart(&executable.getBlock()); - auto innerModule = executableBuilder.create<ModuleOp>(outlinedFunc.getLoc()); - - // TODO(b/137674142): make an ExecutableEntryPointOp and convert the - // entry thunk into that format. - innerModule.push_back(outlinedFunc); - - // Copy all reachable functions into the executable. - // Linker passes may dedupe these later on. - for (auto reachableFunc : reachableFuncs) { - auto clonedFunc = reachableFunc.clone(); - clonedFunc.removeAttr("iree.dispatchable"); - innerModule.push_back(clonedFunc); - } - - return std::make_pair(multiArchExecutable, outlinedFunc); -} - -Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder) { - if (!value) { - return nullptr; - } - - // If the previous value was already a memref we don't need to change - // anything. - // TODO(benvanik): ensure indices make sense. - if (value->getType().isa<MemRefType>()) { - return value; - } else if (value->getType().isa<TensorType>()) { - auto castOp = builder.create<IREE::TensorToMemRefOp>(op->getLoc(), value); - return castOp.getResult(); - } - - // Allocate the memref to store the value. - auto newStorage = builder.create<AllocOp>( - op->getLoc(), convertTypeToMemRef(value->getType())); - - // Insert the store we'll use to box the value. - builder.create<StoreOp>(op->getLoc(), value, newStorage, ArrayRef<Value *>{}); - - return newStorage; -} - -Value *insertDispatcherLoad(Operation *op, Value *originalValue, - Value *allocatedValue, OpBuilder &builder) { - // If old value was a memref we don't need to change anything. - if (originalValue->getType().isa<MemRefType>()) { - return allocatedValue; - } else if (originalValue->getType().isa<TensorType>()) { - auto castOp = - builder.create<IREE::MemRefToTensorOp>(op->getLoc(), allocatedValue); - originalValue->replaceAllUsesWith(castOp.getResult()); - return castOp.getResult(); - } - - // Insert the load we'll use to unbox the value. - auto loadOp = - builder.create<LoadOp>(op->getLoc(), allocatedValue, ArrayRef<Value *>{}); - originalValue->replaceAllUsesWith(loadOp); - return loadOp; -} - -// TODO(benvanik): enough information to walk into dispatch region and compute -// shape when not static. -Value *allocateDispatchOutputBuffer(Location loc, MemRefType type, - OpBuilder &builder) { - // TODO(benvanik): allocation algorithm: - // - synthesize shape logic (magic) [[ for now assume fixed shapes ]] - // - insert shape logic above region - // - rely on folding to merge multiple calculations together - // - unranked = death, need to be able to alloc shape outputs - // - insert alloc - SmallVector<Value *, 4> dimPieces; - return builder.create<AllocOp>(loc, type, dimPieces); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Utils/DispatchUtils.h b/iree/compiler/Utils/DispatchUtils.h deleted file mode 100644 index 80f9be1..0000000 --- a/iree/compiler/Utils/DispatchUtils.h +++ /dev/null
@@ -1,92 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for dispatch region and function manipulation. -// These are shared between all dispatchable types such as the standard -// iree.dispatch_region as well as dispatch-related types like -// iree.reduction_region. - -#ifndef IREE_COMPILER_UTILS_DISPATCHUTILS_H_ -#define IREE_COMPILER_UTILS_DISPATCHUTILS_H_ - -#include <utility> - -#include "iree/compiler/IR/Ops.h" -#include "iree/compiler/IR/StructureOps.h" -#include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Value.h" - -namespace mlir { -namespace iree_compiler { - -// Calculates the workload for |op| based on the op type. -Value *calculateWorkload(Operation *op, Value *baseOperand); - -// Returns true if the func is trivially dispatchable, meaning that: -// - it contains a single block -// - it contains a single dispatch region -// - it contains a return op directly returning the dispatch region results -bool isTriviallyDispatchable(FuncOp func); - -// Builds a new iree.dispatch_region with the given |ops|. -// The region will capture all required values and return all values used -// outside of the |ops| provided. The region will be inserted at the location of -// the last operation in the set. -// -// All |ops| must be compatible with the |workload| specified as they will all -// be dispatched with the same workgroup structure. -// TODO(benvanik): ensure we want to insert at end. Maybe front? -LogicalResult buildDispatchRegion(FuncOp func, Block *parentBlock, - Value *workload, ArrayRef<Operation *> ops); - -// Merges multiple dispatch regions within a block into the same region, -// if possible. Operations may be reordered if it's possible to merge more while -// still obeying data dependencies. -LogicalResult mergeBlockDispatchRegions(FuncOp func, Block *parentBlock); - -// Inlines use of the given |value| from outside of a dispatch region to inside -// of it and removes the argument. Supports multiple arguments that reference -// |value| and will clone the entire value tree. -LogicalResult inlineDispatchRegionOperandsUsingValue( - IREE::DispatchRegionOp dispatchRegionOp, Value *value); - -// Creates an iree.multi_arch_executable containing an iree.executable with an -// exported function containing the body region of |op|. Created executables -// will be named for their original function concatenated with |symbolSuffix|. -std::pair<IREE::MultiArchExecutableOp, FuncOp> createRegionExecutable( - Operation *op, FunctionType functionType, StringRef symbolSuffix); - -// Inserts a conversion of an arbitrary |value| to a memref, possibly by way of -// wrapping in an allocation. -// Returns a new memref containing the value or an alias to |value|. -Value *insertDispatcherStore(Operation *op, Value *value, OpBuilder &builder); - -// Inserts a load from a wrapped memref. -// Returns the value in the original type or an alias to the |value| memref. -Value *insertDispatcherLoad(Operation *op, Value *originalValue, - Value *allocatedValue, OpBuilder &builder); - -// TODO(benvanik): enough information to walk into dispatch region and compute -// shape when not static. -Value *allocateDispatchOutputBuffer(Location loc, MemRefType type, - OpBuilder &builder); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_UTILS_DISPATCHUTILS_H_
diff --git a/iree/compiler/Utils/MemRefUtils.cpp b/iree/compiler/Utils/MemRefUtils.cpp deleted file mode 100644 index acb8822..0000000 --- a/iree/compiler/Utils/MemRefUtils.cpp +++ /dev/null
@@ -1,94 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/MemRefUtils.h" - -#include <cassert> - -#include "iree/compiler/IR/Ops.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { -Value *resolveValueToSourceMemRef(Value *value, Operation *useOp) { - // TODO(benvanik): implement this for real; this is naive but enough for our - // simple load patterns. - auto *defInstr = value->getDefiningOp(); - if (auto loadOp = dyn_cast_or_null<LoadOp>(defInstr)) { - // TODO(benvanik): support views. - return loadOp.getMemRef(); - } - return nullptr; -} - -Value *wrapAsTensor(Value *value, Operation *srcOp, OpBuilder &builder) { - if (srcOp->getResult(0)->getType().isa<TensorType>()) { - if (isa_and_nonnull<IREE::TensorToMemRefOp>(value->getDefiningOp())) { - return value->getDefiningOp()->getOperand(0); - } - auto newOp = builder.create<IREE::MemRefToTensorOp>(srcOp->getLoc(), value); - value = newOp.getResult(); - } - return value; -} - -Value *wrapAsMemRef(Value *value, Operation *srcOp, OpBuilder &builder) { - if (value->getType().isa<TensorType>()) { - if (isa_and_nonnull<IREE::MemRefToTensorOp>(value->getDefiningOp())) { - return value->getDefiningOp()->getOperand(0); - } - auto newOp = builder.create<IREE::TensorToMemRefOp>(srcOp->getLoc(), value); - value = newOp.getResult(); - } - return value; -} - -Value *loadAccessValue(Location location, Value *operand, OpBuilder &builder) { - if (operand->getType().isa<MemRefType>() || - operand->getType().isa<TensorType>()) { - return operand; - } - - auto memRefType = MemRefType::get({}, operand->getType()); - if (auto loadOp = dyn_cast_or_null<LoadOp>(operand->getDefiningOp())) { - // TODO(benvanik): handle creating views. - if (loadOp.getMemRefType() == memRefType) { - return loadOp.getMemRef(); - } - } - - auto allocOp = builder.create<AllocOp>(location, memRefType); - builder.create<StoreOp>(location, operand, allocOp.getResult(), - ArrayRef<Value *>{}); - return allocOp.getResult(); -} - -Value *loadResultValue(Location location, const Type &originalType, - Value *result, OpBuilder &builder) { - if (originalType.isa<MemRefType>()) { - return result; - } else if (auto tensorType = originalType.dyn_cast<TensorType>()) { - return result; - } - - auto loadOp = builder.create<LoadOp>(location, result, ArrayRef<Value *>{}); - return loadOp.getResult(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Utils/ModuleUtils.cpp b/iree/compiler/Utils/ModuleUtils.cpp deleted file mode 100644 index 230fdb1..0000000 --- a/iree/compiler/Utils/ModuleUtils.cpp +++ /dev/null
@@ -1,98 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/ModuleUtils.h" - -#include "llvm/ADT/SetVector.h" -#include "mlir/IR/Function.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Finds a list of functions with the given |attrName| and adds them to |funcs|. -void findFunctionsWithAttr(ModuleOp module, const char *attrName, - llvm::SetVector<FuncOp> &funcs) { - for (auto func : module.getOps<FuncOp>()) { - if (func.getAttr(attrName)) { - funcs.insert(func); - } - } -} - -// Inserts functions reachable directly from |func| to |usedFuncs|. -void insertUsedFunctions(ModuleOp module, FuncOp func, - DenseSet<FuncOp> *usedFuncs, - std::vector<FuncOp> *toSearch) { - auto onCalledFunction = [&](StringRef calleeName) { - auto calleeFunc = module.lookupSymbol<FuncOp>(calleeName); - if (usedFuncs->insert(calleeFunc).second) { - // New function found! Add to queue for searching. - toSearch->push_back(calleeFunc); - } - }; - for (auto &block : func) { - for (auto &op : block) { - // TODO(benvanik): replace with iree_hl.call check. - if (auto calleeAttr = op.getAttr("callee")) { - onCalledFunction(calleeAttr.cast<SymbolRefAttr>().getValue()); - } - } - } -} - -// Returns a set containing the names of all functions used by the given -// |rootFuncs| list. -DenseSet<FuncOp> findUsedFunctions(ModuleOp module, - ArrayRef<FuncOp> rootFuncs) { - // Breadth-first search. - DenseSet<FuncOp> usedFuncs; - usedFuncs.insert(rootFuncs.begin(), rootFuncs.end()); - std::vector<FuncOp> toSearch = {rootFuncs.begin(), rootFuncs.end()}; - while (!toSearch.empty()) { - auto func = toSearch.back(); - toSearch.pop_back(); - insertUsedFunctions(module, func, &usedFuncs, &toSearch); - } - return usedFuncs; -} - -} // namespace - -void dropUnusedFunctions(ModuleOp module, ArrayRef<const char *> keepAttrs) { - // Find all of the exported functions we'll treat as roots. - llvm::SetVector<FuncOp> rootFuncs; - for (auto keepAttr : keepAttrs) { - findFunctionsWithAttr(module, keepAttr, rootFuncs); - } - - // Find the full set of all used functions reachable from the given rootFuncs. - // This set will contain the rootFuncs. - auto usedFuncs = findUsedFunctions(module, rootFuncs.getArrayRef()); - - // Drop all unused functions. - std::vector<FuncOp> deadFuncs; - for (auto func : module.getOps<FuncOp>()) { - if (!llvm::is_contained(usedFuncs, func)) { - deadFuncs.push_back(func); - } - } - for (auto func : deadFuncs) { - func.erase(); - } -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Utils/OpCreationUtils.cpp b/iree/compiler/Utils/OpCreationUtils.cpp deleted file mode 100644 index d624310..0000000 --- a/iree/compiler/Utils/OpCreationUtils.cpp +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/OpCreationUtils.h" - -#include <cstdint> - -#include "iree/compiler/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -ElementsAttr elementsAttrFromArray(OpBuilder &builder, - ArrayRef<int64_t> elements) { - return DenseIntElementsAttr::get( - RankedTensorType::get(elements.size(), builder.getIntegerType(64)), - elements); -} - -} // namespace - -IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc, - llvm::ArrayRef<int64_t> elements) { - auto elementsAttr = elementsAttrFromArray(builder, elements); - return builder.create<IREE::ConstantOp>(loc, elementsAttr); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Utils/OpCreationUtils.h b/iree/compiler/Utils/OpCreationUtils.h deleted file mode 100644 index 7cdd701..0000000 --- a/iree/compiler/Utils/OpCreationUtils.h +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utility functions related to the creation of new operations. Where possible, -// use custom builders. These helpers are for situations where a custom builder -// is not appropriate. - -#ifndef IREE_COMPILER_UTILS_OPCREATIONUTILS_H_ -#define IREE_COMPILER_UTILS_OPCREATIONUTILS_H_ - -#include <cstdint> - -#include "iree/compiler/IR/Ops.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/Support/LLVM.h" - -namespace mlir { -namespace iree_compiler { - -IREE::ConstantOp createArrayConstant(OpBuilder &builder, Location loc, - llvm::ArrayRef<int64_t> elements); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_UTILS_OPCREATIONUTILS_H_
diff --git a/iree/compiler/Utils/OpUtils.cpp b/iree/compiler/Utils/OpUtils.cpp deleted file mode 100644 index 5ff88b5..0000000 --- a/iree/compiler/Utils/OpUtils.cpp +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/OpUtils.h" - -namespace mlir { -namespace iree_compiler { - -void removeDeadOperations(llvm::SetVector<Operation *> &deadOperations) { - while (!deadOperations.empty()) { - auto *op = deadOperations.front(); - deadOperations.erase(deadOperations.begin()); - for (auto *operand : op->getOperands()) { - // TODO(benvanik): add check for op side effects. - if (operand->hasOneUse()) { - deadOperations.insert(operand->getDefiningOp()); - } - } - op->erase(); - } -} - -void replaceSubsequentUses(Operation *userOp, Value *oldValue, - Value *newValue) { - for (auto &use : oldValue->getUses()) { - if (userOp->isBeforeInBlock(use.getOwner())) { - use.set(newValue); - } - } -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Utils/TranslationUtils.cpp b/iree/compiler/Utils/TranslationUtils.cpp deleted file mode 100644 index 033485e..0000000 --- a/iree/compiler/Utils/TranslationUtils.cpp +++ /dev/null
@@ -1,139 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/TranslationUtils.h" - -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" - -namespace mlir { -namespace iree_compiler { - -namespace { - -// Returns the static registry of translator names to translation functions. -llvm::StringMap<TranslateExecutableFn> - &getMutableExecutableTranslationRegistry() { - static llvm::StringMap<TranslateExecutableFn> registry; - return registry; -} - -// Returns true if the given |value| matches |pattern| (normal * and ? rules). -bool matchPattern(StringRef value, StringRef pattern) { - size_t nextCharIndex = pattern.find_first_of("*?"); - if (nextCharIndex == std::string::npos) { - return value == pattern; - } else if (nextCharIndex > 0) { - if (value.substr(0, nextCharIndex) != pattern.substr(0, nextCharIndex)) { - return false; - } - value = value.substr(nextCharIndex); - pattern = pattern.substr(nextCharIndex); - } - char patternChar = pattern[0]; - if (value.empty() && pattern.empty()) { - return true; - } else if (patternChar == '*' && pattern.size() > 1 && value.empty()) { - return false; - } else if (patternChar == '*' && pattern.size() == 1) { - return true; - } else if (patternChar == '?' || value[0] == patternChar) { - return matchPattern(value.substr(1), pattern.substr(1)); - } else if (patternChar == '*') { - return matchPattern(value, pattern.substr(1)) || - matchPattern(value.substr(1), pattern); - } - return false; -} - -// Force enables IR printing on the |passManager|. -void enableIRPrinting(PassManager *passManager) { - auto notVerifier = [](Pass *pass) { - return pass->getName() != "FunctionVerifier" && - pass->getName() != "ModuleVerifier"; - }; - bool printModuleScope = false; - passManager->enableIRPrinting(/*shouldPrintBeforePass=*/{}, - /*shouldPrintAfterPass=*/notVerifier, - printModuleScope, llvm::dbgs()); - passManager->disableMultithreading(); -} - -} // namespace - -ExecutableTranslationRegistration::ExecutableTranslationRegistration( - llvm::StringRef name, const TranslateExecutableFn &fn) { - auto ®istry = getMutableExecutableTranslationRegistry(); - if (registry.find(name) != registry.end()) { - llvm::report_fatal_error( - "Attempting to overwrite an existing translation function"); - } - assert(fn && "Attempting to register an empty translation function"); - registry[name] = fn; -} - -const llvm::StringMap<TranslateExecutableFn> - &getExecutableTranslationRegistry() { - return getMutableExecutableTranslationRegistry(); -} - -std::vector<std::string> matchExecutableTranslationBackendNames( - llvm::StringRef pattern) { - std::vector<std::string> matches; - for (auto &entry : getExecutableTranslationRegistry()) { - if (matchPattern(entry.getKey(), pattern)) { - matches.push_back(entry.getKey().str()); - } - } - return matches; -} - -std::unique_ptr<PassManager> createPassManager( - MLIRContext *ctx, const TranslationOptions &translationOptions) { - std::unique_ptr<PassManager> passManager(new PassManager(ctx)); - - // Enable IR printing/timing/etc from command line options. - registerPassManagerCLOptions(); - applyPassManagerCLOptions(*passManager); - - // Override with programmatic options. - if (translationOptions.print_mlir) { - enableIRPrinting(passManager.get()); - } - - return passManager; -} - -LogicalResult runPassPipeline(const TranslationOptions &translationOptions, - PassManager *passManager, ModuleOp module) { - if (translationOptions.print_mlir) { - module.dump(); - } - - // Run on the module. - if (failed(passManager->run(module))) { - return failure(); - } - - if (translationOptions.print_mlir) { - module.dump(); - } - - return success(); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/compiler/Utils/TranslationUtils.h b/iree/compiler/Utils/TranslationUtils.h deleted file mode 100644 index a6aaefa..0000000 --- a/iree/compiler/Utils/TranslationUtils.h +++ /dev/null
@@ -1,109 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_ -#define IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_ - -#include <functional> -#include <memory> - -#include "iree/compiler/IR/StructureOps.h" -#include "iree/schemas/executable_def_generated.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace iree_compiler { - -// Common translation options for diagnostics and debugging. -struct TranslationOptions { - // Enables MLIR IR printing during translation. - // This can be specified via the -print-ir-before-all and -print-ir-after-all - // command line flags or overridden programmatically via this flag. - bool print_mlir = false; - - void CopyFrom(const TranslationOptions &other) { - print_mlir = other.print_mlir; - } -}; - -// Options for iree.module translation for diagnostics and debugging. -struct ModuleTranslationOptions : public TranslationOptions { - // Defines which backend translators will be used to translate executables. - // If empty then all linked in translators will be used. - // TODO(benvanik): extend to allow specifying entire config blobs via mlir. - std::vector<std::string> target_backends; -}; - -// Options for iree.executable translation for diagnostics and debugging. -// Target configuration is sourced from the iree.target_config op within the -// iree.executable. -struct ExecutableTranslationOptions : public TranslationOptions {}; - -// Results of a translation operation. -// May contain zero or more executable defs depending on translation options, -// defined target configs, and support. -struct ExecutableTranslationResult { - std::vector<std::unique_ptr<iree::ExecutableDefT>> executable_defs; -}; - -// Registered function that given a set of |executableOps| containing one -// or more iree.executables will produce zero or more serialized executables. -// -// Each iree.executable provided contains one iree.executable_target_config with -// backend-specific translation information. The translator can decide whether -// to translate each independently, group them together, etc. -// -// The provided |executableOps| can be mutated by the callee and will be -// preserved for debugging after translation. If any executable in -// |executableOps| is not used by the translator then it should be erased. -using TranslateExecutableFn = - std::function<llvm::Optional<ExecutableTranslationResult>( - ArrayRef<IREE::ExecutableOp> executableOps, - ExecutableTranslationOptions options)>; - -// Registers an executable translation function. -struct ExecutableTranslationRegistration { - ExecutableTranslationRegistration(llvm::StringRef name, - const TranslateExecutableFn &fn); -}; - -// Returns a read-only reference to the translator registry. -const llvm::StringMap<TranslateExecutableFn> - &getExecutableTranslationRegistry(); - -// Returns executable translation backend names matching the given pattern. -// This accepts wildcards for any delimited value. For example, 'foo-*-bar' will -// match 'foo-123-bar' and 'foo-456-bar' and 'foo-10?' will match 'foo-101' and -// 'foo-102'. -std::vector<std::string> matchExecutableTranslationBackendNames( - llvm::StringRef pattern); - -// Creates a new pass manager initialized with the given options. -std::unique_ptr<PassManager> createPassManager( - MLIRContext *ctx, const TranslationOptions &translationOptions); - -// Runs an initialized set of passes on the given module. -LogicalResult runPassPipeline(const TranslationOptions &translationOptions, - PassManager *passManager, ModuleOp module); - -} // namespace iree_compiler -} // namespace mlir - -#endif // IREE_COMPILER_UTILS_TRANSLATIONUTILS_H_
diff --git a/iree/compiler/Utils/TypeConversionUtils.cpp b/iree/compiler/Utils/TypeConversionUtils.cpp deleted file mode 100644 index ae4e11a..0000000 --- a/iree/compiler/Utils/TypeConversionUtils.cpp +++ /dev/null
@@ -1,74 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/compiler/Utils/TypeConversionUtils.h" - -#include <cassert> - -#include "iree/compiler/IR/Ops.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { -namespace iree_compiler { - -Type legalizeType(Type type) { - if (type.isIndex()) { - return IntegerType::get(kIndexBitWidth, type.getContext()); - } else if (type.isInteger(1)) { - return IntegerType::get(kBoolBitWidth, type.getContext()); - } else if (auto memRefType = type.dyn_cast<MemRefType>()) { - return MemRefType::get(memRefType.getShape(), - legalizeType(memRefType.getElementType())); - } else if (auto functionType = type.dyn_cast<FunctionType>()) { - llvm::SmallVector<Type, 4> inputs; - for (const auto &oldType : functionType.getInputs()) { - inputs.push_back(legalizeType(oldType)); - } - llvm::SmallVector<Type, 4> results; - for (const auto &oldType : functionType.getResults()) { - results.push_back(legalizeType(oldType)); - } - return FunctionType::get(inputs, results, type.getContext()); - } - return type; -} - -Type LLTypeConverter::convertType(Type type) { return legalizeType(type); } - -MemRefType convertTypeToMemRef(Type type) { - if (type.isIntOrIndexOrFloat()) { - return MemRefType::get({}, type, {}, 0); - } else if (auto tensorType = type.dyn_cast<RankedTensorType>()) { - return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - } else if (auto memRefType = type.dyn_cast<MemRefType>()) { - return memRefType; - } else { - llvm_unreachable("Unconvertable type"); - } -} - -MemRefType convertTypeToMemRef(Value *value) { - return convertTypeToMemRef(value->getType()); -} - -Type MemRefTypeConverter::convertType(Type type) { - return convertTypeToMemRef(type); -} - -} // namespace iree_compiler -} // namespace mlir
diff --git a/iree/hal/BUILD b/iree/hal/BUILD deleted file mode 100644 index 880f528..0000000 --- a/iree/hal/BUILD +++ /dev/null
@@ -1,377 +0,0 @@ -# HAL (Hardware Abstraction Layer). -# Subdirectories contain implementations for different hardware and -# software backends. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "allocator", - srcs = ["allocator.cc"], - hdrs = ["allocator.h"], - deps = [ - ":buffer", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "api", - srcs = ["api.cc"], - hdrs = [ - "api.h", - "api_detail.h", - ], - visibility = ["//visibility:public"], - deps = [ - ":api_hdrs", - ":buffer", - ":buffer_view", - ":fence", - ":heap_buffer", - ":semaphore", - "//iree/base:api", - "//iree/base:api_util", - "//iree/base:shape", - "//iree/base:tracing", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "api_hdrs", - hdrs = ["api.h"], - deps = [ - "//iree/base:api_hdrs", - ], -) - -cc_library( - name = "buffer", - srcs = ["buffer.cc"], - hdrs = ["buffer.h"], - deps = [ - ":resource", - "//iree/base:bitfield", - "//iree/base:logging", - "//iree/base:source_location", - "//iree/base:status", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - ], -) - -cc_test( - name = "buffer_test", - srcs = [ - "buffer_mapping_test.cc", - "buffer_test.cc", - ], - deps = [ - ":buffer", - ":heap_buffer", - "//iree/base:status", - "//iree/base:status_matchers", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "buffer_view", - srcs = ["buffer_view.cc"], - hdrs = ["buffer_view.h"], - deps = [ - ":buffer", - "//iree/base:shape", - "//iree/base:source_location", - "//iree/base:status", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "buffer_view_test", - srcs = [ - "buffer_view_test.cc", - ], - deps = [ - ":buffer", - ":buffer_view", - ":heap_buffer", - "//iree/base:status", - "//iree/base:status_matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "buffer_view_string_util", - srcs = ["buffer_view_string_util.cc"], - hdrs = ["buffer_view_string_util.h"], - deps = [ - ":allocator", - ":buffer_view", - ":heap_buffer", - "//iree/base:source_location", - "//iree/base:status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - ], -) - -cc_test( - name = "buffer_view_string_util_test", - srcs = ["buffer_view_string_util_test.cc"], - deps = [ - ":buffer_view_string_util", - "//iree/base:status", - "//iree/base:status_matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "command_buffer", - srcs = ["command_buffer.cc"], - hdrs = ["command_buffer.h"], - deps = [ - ":allocator", - ":buffer", - ":buffer_view", - ":event", - ":executable", - ":resource", - "//iree/base:bitfield", - "//iree/base:shape", - "//iree/base:status", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "command_buffer_validation", - srcs = ["command_buffer_validation.cc"], - hdrs = ["command_buffer_validation.h"], - deps = [ - ":command_buffer", - "//iree/base:logging", - "//iree/base:status", - ], -) - -cc_library( - name = "command_queue", - hdrs = ["command_queue.h"], - deps = [ - ":command_buffer", - ":fence", - ":semaphore", - "//iree/base:bitfield", - "//iree/base:status", - "//iree/base:time", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "deferred_buffer", - srcs = ["deferred_buffer.cc"], - hdrs = ["deferred_buffer.h"], - deps = [ - ":allocator", - ":buffer", - "//iree/base:status", - ], -) - -cc_test( - name = "deferred_buffer_test", - srcs = ["deferred_buffer_test.cc"], - deps = [ - ":deferred_buffer", - ":heap_buffer", - "//iree/base:status_matchers", - "//iree/hal/testing:mock_allocator", - "@com_google_absl//absl/memory", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "device", - hdrs = ["device.h"], - deps = [ - ":allocator", - ":buffer", - ":command_queue", - ":device_info", - ":event", - ":executable_cache", - ":semaphore", - "//iree/base:status", - "//iree/base:time", - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "device_info", - hdrs = ["device_info.h"], - deps = [ - "//iree/base:bitfield", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "device_manager", - srcs = ["device_manager.cc"], - hdrs = ["device_manager.h"], - deps = [ - ":allocator", - ":buffer", - ":command_queue", - ":device", - ":device_placement", - ":executable_format", - ":fence", - ":heap_buffer", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:time", - "//iree/base:tracing", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "device_placement", - hdrs = ["device_placement.h"], -) - -cc_library( - name = "driver", - hdrs = ["driver.h"], - deps = [ - ":device", - ":device_info", - "//iree/base:status", - ], -) - -cc_library( - name = "driver_registry", - srcs = ["driver_registry.cc"], - hdrs = ["driver_registry.h"], - deps = [ - ":driver", - "//iree/base:init", - "//iree/base:status", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - ], -) - -cc_library( - name = "event", - hdrs = ["event.h"], - deps = [ - ":resource", - ], -) - -cc_library( - name = "executable", - hdrs = ["executable.h"], - deps = [":resource"], -) - -cc_library( - name = "executable_cache", - srcs = ["executable_cache.cc"], - hdrs = ["executable_cache.h"], - deps = [ - ":executable", - ":executable_format", - ":executable_spec", - "//iree/base:bitfield", - "//iree/base:ref_ptr", - "//iree/base:status", - ], -) - -cc_library( - name = "executable_format", - hdrs = ["executable_format.h"], - deps = [ - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "executable_spec", - hdrs = ["executable_spec.h"], - deps = [ - ":executable_format", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "fence", - hdrs = ["fence.h"], - deps = [ - ":resource", - "//iree/base:status", - ], -) - -cc_library( - name = "heap_buffer", - srcs = ["heap_buffer.cc"], - hdrs = ["heap_buffer.h"], - deps = [ - ":allocator", - ":buffer", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal/host:host_buffer", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "resource", - hdrs = ["resource.h"], - deps = [ - "//iree/base:ref_ptr", - ], -) - -cc_library( - name = "semaphore", - hdrs = ["semaphore.h"], - deps = [ - ":resource", - "@com_google_absl//absl/types:variant", - ], -) - -cc_library( - name = "stack_trace", - hdrs = ["stack_trace.h"], -)
diff --git a/iree/hal/allocator.cc b/iree/hal/allocator.cc deleted file mode 100644 index 7232804..0000000 --- a/iree/hal/allocator.cc +++ /dev/null
@@ -1,77 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/allocator.h" - -#include <cstdint> -#include <cstdlib> -#include <string> -#include <utility> - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -bool Allocator::CanUseBuffer(Buffer* buffer, - BufferUsageBitfield intended_usage) const { - return CanUseBufferLike(buffer->allocator(), buffer->memory_type(), - buffer->usage(), intended_usage); -} - -StatusOr<ref_ptr<Buffer>> Allocator::AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) { - if (AnyBitSet(source_buffer->usage() & BufferUsage::kConstant) && - CanUseBuffer(source_buffer.get(), buffer_usage)) { - // Buffer can be used directly by the device. - return source_buffer; - } - - IREE_TRACE_SCOPE0("Allocator::AllocateConstant"); - - // We need to map so we can copy into it. - buffer_usage |= BufferUsage::kMapping; - // It will be constant after we write it. - buffer_usage |= BufferUsage::kConstant; - - MemoryTypeBitfield memory_type = - MemoryType::kDeviceLocal | MemoryType::kHostVisible; - ASSIGN_OR_RETURN(auto device_buffer, Allocate(memory_type, buffer_usage, - source_buffer->byte_length())); - ASSIGN_OR_RETURN(auto source_mapping, - source_buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - RETURN_IF_ERROR(device_buffer->WriteData(0, source_mapping.data(), - source_mapping.byte_length())); - return device_buffer; -} - -StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - const void* data, - size_t data_length) { - return WrapMutable(memory_type, MemoryAccess::kRead, buffer_usage, - const_cast<void*>(data), data_length); -} - -StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Allocator does not support wrapping host memory"; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/allocator.h b/iree/hal/allocator.h deleted file mode 100644 index ef56240..0000000 --- a/iree/hal/allocator.h +++ /dev/null
@@ -1,138 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_ALLOCATOR_H_ -#define IREE_HAL_ALLOCATOR_H_ - -#include <cstddef> -#include <memory> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -// Allocates buffers for a particular device memory space. -// -// Buffers allocated are only guaranteed to work with the driver that the -// allocator services. Any attempt to use buffers on drivers they were not -// allocated from must first be checked with CanUseBuffer. -// -// Thread-safe. -class Allocator { - public: - virtual ~Allocator() = default; - - // Returns true if the device can use the given buffer for the provided usage. - // For buffers allocated from this allocator it's expected that the result - // will always be true. For buffers that originate from another allocator - // there may be limited support for cross-device usage. - // - // Returning false indicates that the buffer must be transferred externally - // into a buffer compatible with the device this allocator services. - bool CanUseBuffer(Buffer* buffer, BufferUsageBitfield intended_usage) const; - virtual bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const = 0; - - // Returns true if the allocator can allocate a buffer with the given - // attributes. - virtual bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const = 0; - - // Adjusts allocation parameters to be compatible with the allocator. - // Certain allocators may require particular memory types to function. By - // adjusting the parameters prior to allocation callers can be sure they are - // able to successfully Allocate a buffer later on with the same parameters. - virtual Status MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const { - return OkStatus(); - } - - // Allocates a buffer from the allocator. - // Fails if the memory type requested for the given usage cannot be serviced. - // Callers can use CanAllocate to decide their memory use strategy. - // - // The memory type of the buffer returned may differ from the requested value - // if the device can provide more functionality; for example, if requesting - // MemoryType::kHostVisible but the memory is really host cached you may get - // a buffer back with MemoryType::kHostVisible | MemoryType::kHostCached. The - // only requirement is that the buffer satisfy the required bits. - virtual StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) = 0; - - // Allocates a buffer from the allocator for use as a constant value. - // The provided |source_buffer| may be returned if the device can use it - // directly and otherwise will be copied. - virtual StatusOr<ref_ptr<Buffer>> AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer); - - // Wraps an existing host heap allocation in a buffer. - // Ownership of the host allocation remains with the caller and the memory - // must remain valid for so long as the Buffer may be in use. - // Will have MemoryType::kHostLocal in most cases and may not be usable - // by the device. - // - // The inference optimizer makes assumptions about buffer aliasing based on - // Buffer instances and because of this wrapping the same host buffer in - // multiple Buffers will create potential memory aliasing issues that can be - // difficult to track down. There's no checking as to whether a host buffer - // has already been wrapped so it's best for callers to ensure this is never - // possible (the simplest way being to never use Wrap and always just allocate - // new Buffers). - // - // Fails if the allocator cannot access host memory in this way. - StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - const void* data, size_t data_length); - virtual StatusOr<ref_ptr<Buffer>> WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length); - template <typename T> - StatusOr<ref_ptr<Buffer>> Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - absl::Span<const T> data); - template <typename T> - StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - absl::Span<T> data); -}; - -// Inline functions and template definitions follow: - -template <typename T> -StatusOr<ref_ptr<Buffer>> Allocator::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - absl::Span<const T> data) { - return Wrap(memory_type, buffer_usage, data.data(), data.size() * sizeof(T)); -} - -template <typename T> -StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, absl::Span<T> data) { - return WrapMutable(memory_type, allowed_access, buffer_usage, data.data(), - data.size() * sizeof(T)); -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_ALLOCATOR_H_
diff --git a/iree/hal/api.cc b/iree/hal/api.cc deleted file mode 100644 index 79c1a27..0000000 --- a/iree/hal/api.cc +++ /dev/null
@@ -1,439 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/api.h" - -#include "iree/base/api.h" -#include "iree/base/api_util.h" -#include "iree/base/shape.h" -#include "iree/base/tracing.h" -#include "iree/hal/api_detail.h" -#include "iree/hal/buffer.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/fence.h" -#include "iree/hal/heap_buffer.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { - -//===----------------------------------------------------------------------===// -// iree::hal::Buffer -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t iree_hal_buffer_subspan( - iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_buffer_subspan"); - - if (!out_buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_buffer = nullptr; - - if (!buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - auto handle = add_ref(reinterpret_cast<Buffer*>(buffer)); - - IREE_API_ASSIGN_OR_RETURN(auto new_handle, - Buffer::Subspan(handle, byte_offset, byte_length)); - - *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(new_handle.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_retain(iree_hal_buffer_t* buffer) { - IREE_TRACE_SCOPE0("iree_hal_buffer_retain"); - auto* handle = reinterpret_cast<Buffer*>(buffer); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_release(iree_hal_buffer_t* buffer) { - IREE_TRACE_SCOPE0("iree_hal_buffer_release"); - auto* handle = reinterpret_cast<Buffer*>(buffer); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_device_size_t -iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer) { - IREE_TRACE_SCOPE0("iree_hal_buffer_byte_length"); - const auto* handle = reinterpret_cast<const Buffer*>(buffer); - CHECK(handle) << "NULL buffer handle"; - return handle->byte_length(); -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_zero"); - auto* handle = reinterpret_cast<Buffer*>(buffer); - if (!buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR(handle->Fill8(byte_offset, byte_length, 0)); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_host_size_t pattern_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_fill"); - auto* handle = reinterpret_cast<Buffer*>(buffer); - if (!buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR( - handle->Fill(byte_offset, byte_length, pattern, pattern_length)); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_read_data( - iree_hal_buffer_t* buffer, iree_device_size_t source_offset, - void* target_buffer, iree_device_size_t data_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_read_data"); - auto* handle = reinterpret_cast<Buffer*>(buffer); - if (!buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR( - handle->ReadData(source_offset, target_buffer, data_length)); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_write_data( - iree_hal_buffer_t* buffer, iree_device_size_t target_offset, - const void* source_buffer, iree_device_size_t data_length) { - IREE_TRACE_SCOPE0("iree_hal_buffer_write_data"); - auto* handle = reinterpret_cast<Buffer*>(buffer); - if (!buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR( - handle->WriteData(target_offset, source_buffer, data_length)); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_map( - iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, - iree_device_size_t element_offset, iree_device_size_t element_length, - iree_hal_mapped_memory_t* out_mapped_memory) { - IREE_TRACE_SCOPE0("iree_hal_buffer_map"); - - if (!out_mapped_memory) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_mapped_memory, 0, sizeof(*out_mapped_memory)); - - auto* buffer_handle = reinterpret_cast<Buffer*>(buffer); - if (!buffer_handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - IREE_API_ASSIGN_OR_RETURN( - auto mapping, buffer_handle->MapMemory<uint8_t>( - static_cast<MemoryAccessBitfield>(memory_access), - element_offset, element_length)); - - static_assert(sizeof(iree_hal_mapped_memory_t::reserved) >= - sizeof(MappedMemory<uint8_t>), - "C mapped memory struct must have large enough storage for the " - "matching C++ struct"); - auto* mapping_storage = - reinterpret_cast<MappedMemory<uint8_t>*>(out_mapped_memory->reserved); - *mapping_storage = std::move(mapping); - - out_mapped_memory->contents = {const_cast<uint8_t*>(mapping_storage->data()), - mapping_storage->size()}; - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_unmap( - iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory) { - IREE_TRACE_SCOPE0("iree_hal_buffer_map"); - auto* buffer_handle = reinterpret_cast<Buffer*>(buffer); - if (!buffer_handle || !mapped_memory) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto* mapping = - reinterpret_cast<MappedMemory<uint8_t>*>(mapped_memory->reserved); - mapping->reset(); - - std::memset(mapped_memory, 0, sizeof(*mapped_memory)); - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::hal::HeapBuffer -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_host_size_t allocation_size, iree_allocator_t contents_allocator, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate"); - - if (!out_buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_buffer = nullptr; - - if (!allocation_size) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto handle = HeapBuffer::Allocate( - static_cast<MemoryTypeBitfield>(memory_type), - static_cast<BufferUsageBitfield>(usage), allocation_size); - - *out_buffer = reinterpret_cast<iree_hal_buffer_t*>( - static_cast<Buffer*>(handle.release())); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_hal_memory_access_t allowed_access, iree_byte_span_t contents, - iree_allocator_t contents_allocator, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_heap_buffer_allocate_copy"); - - if (!out_buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_buffer = nullptr; - - if (!contents.data || !contents.data_length) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto handle = HeapBuffer::AllocateCopy( - static_cast<BufferUsageBitfield>(usage), - static_cast<MemoryAccessBitfield>(allowed_access), contents.data, - contents.data_length); - - *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(handle.release()); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( - iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, iree_byte_span_t contents, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_hal_heap_buffer_wrap"); - - if (!out_buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_buffer = nullptr; - - if (!contents.data || !contents.data_length) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto handle = - HeapBuffer::WrapMutable(static_cast<MemoryTypeBitfield>(memory_type), - static_cast<MemoryAccessBitfield>(allowed_access), - static_cast<BufferUsageBitfield>(usage), - contents.data, contents.data_length); - - *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(handle.release()); - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::hal::BufferView -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t iree_hal_buffer_view_create( - iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_create"); - - if (!out_buffer_view) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_buffer_view = nullptr; - - if (!buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } else if (shape.rank > kMaxRank || element_size <= 0) { - return IREE_STATUS_OUT_OF_RANGE; - } - - // Allocate and initialize the iree_hal_buffer_view struct. - iree_hal_buffer_view* handle = nullptr; - IREE_API_RETURN_IF_API_ERROR(allocator.alloc( - allocator.self, sizeof(*handle), reinterpret_cast<void**>(&handle))); - new (handle) iree_hal_buffer_view(); - handle->allocator = allocator; - - handle->impl.buffer = add_ref(reinterpret_cast<Buffer*>(buffer)); - handle->impl.shape = {shape.dims, shape.rank}; - handle->impl.element_size = element_size; - - *out_buffer_view = reinterpret_cast<iree_hal_buffer_view_t*>(handle); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_retain"); - auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_release"); - auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_view_assign( - iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer, - iree_shape_t shape, int8_t element_size) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_assign"); - auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->impl.buffer.reset(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_reset"); - auto* handle = reinterpret_cast<iree_hal_buffer_view*>(buffer_view); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->impl.buffer.reset(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer( - const iree_hal_buffer_view_t* buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_buffer"); - const auto* handle = - reinterpret_cast<const iree_hal_buffer_view*>(buffer_view); - CHECK(handle) << "NULL buffer_view handle"; - return reinterpret_cast<iree_hal_buffer_t*>(handle->impl.buffer.get()); -} - -IREE_API_EXPORT iree_status_t iree_hal_buffer_view_shape( - const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_shape"); - - if (!out_shape) { - return IREE_STATUS_INVALID_ARGUMENT; - } - out_shape->rank = 0; - - const auto* handle = - reinterpret_cast<const iree_hal_buffer_view*>(buffer_view); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - const auto& shape = handle->impl.shape; - return ToApiShape(shape, out_shape); -} - -IREE_API_EXPORT int8_t -iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view) { - IREE_TRACE_SCOPE0("iree_hal_buffer_view_element_size"); - const auto* handle = - reinterpret_cast<const iree_hal_buffer_view*>(buffer_view); - CHECK(handle) << "NULL buffer_view handle"; - return handle->impl.element_size; -} - -//===----------------------------------------------------------------------===// -// iree::hal::Semaphore -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t -iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_retain"); - auto* handle = reinterpret_cast<Semaphore*>(semaphore); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t -iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore) { - IREE_TRACE_SCOPE0("iree_hal_semaphore_release"); - auto* handle = reinterpret_cast<Semaphore*>(semaphore); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::hal::Fence -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t iree_hal_fence_retain(iree_hal_fence_t* fence) { - IREE_TRACE_SCOPE0("iree_hal_fence_retain"); - auto* handle = reinterpret_cast<Fence*>(fence); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t iree_hal_fence_release(iree_hal_fence_t* fence) { - IREE_TRACE_SCOPE0("iree_hal_fence_release"); - auto* handle = reinterpret_cast<Fence*>(fence); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/api.h b/iree/hal/api.h deleted file mode 100644 index 6347519..0000000 --- a/iree/hal/api.h +++ /dev/null
@@ -1,366 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// See iree/base/api.h for documentation on the API conventions used. - -#ifndef IREE_HAL_API_H_ -#define IREE_HAL_API_H_ - -#include <stdint.h> - -#include "iree/base/api.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -//===----------------------------------------------------------------------===// -// Types and Enums -//===----------------------------------------------------------------------===// - -typedef struct iree_hal_buffer iree_hal_buffer_t; -typedef struct iree_hal_buffer_view iree_hal_buffer_view_t; -typedef struct iree_hal_semaphore iree_hal_semaphore_t; -typedef struct iree_hal_fence iree_hal_fence_t; - -// Reference to a buffer's mapped memory. -typedef struct { - // Contents of the buffer. Behavior is undefined if an access is performed - // whose type was not specified during mapping. - iree_byte_span_t contents; - - // Used internally - do not modify. - uint64_t reserved[8]; -} iree_hal_mapped_memory_t; - -// A bitfield specifying properties for a memory type. -typedef enum { - IREE_HAL_MEMORY_TYPE_NONE = 0, - - // Memory is lazily allocated by the device and only exists transiently. - // This is the optimal mode for memory used only within a single command - // buffer. Transient buffers, even if they have - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE set, should be treated as device-local - // and opaque as they may have no memory attached to them outside of the time - // they are being evaluated on devices. - // - // This flag can be treated as a hint in most cases; allocating a buffer with - // it set _may_ return the same as if it had not be set. Certain allocation - // routines may use the hint to more tightly control reuse or defer wiring the - // memory. - IREE_HAL_MEMORY_TYPE_TRANSIENT = 1 << 0, - - // Memory allocated with this type can be mapped for host access using - // iree_hal_buffer_map. - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE = 1 << 1, - - // The host cache management commands MappedMemory::Flush and - // MappedMemory::Invalidate are not needed to flush host writes - // to the device or make device writes visible to the host, respectively. - IREE_HAL_MEMORY_TYPE_HOST_COHERENT = 1 << 2, - - // Memory allocated with this type is cached on the host. Host memory - // accesses to uncached memory are slower than to cached memory, however - // uncached memory is always host coherent. MappedMemory::Flush must be used - // to ensure the device has visibility into any changes made on the host and - // Invalidate must be used to ensure the host has visibility into any changes - // made on the device. - IREE_HAL_MEMORY_TYPE_HOST_CACHED = 1 << 3, - - // Memory is accessible as normal host allocated memory. - IREE_HAL_MEMORY_TYPE_HOST_LOCAL = - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_COHERENT, - - // Memory allocated with this type is visible to the device for execution. - // Being device visible does not mean the same thing as - // IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL. Though an allocation may be visible to - // the device and therefore useable for execution it may require expensive - // mapping or implicit transfers. - IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE = 1 << 4, - - // Memory allocated with this type is the most efficient for device access. - // Devices may support using memory that is not device local via - // IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE but doing so can incur non-trivial - // performance penalties. Device local memory, on the other hand, is - // guaranteed to be fast for all operations. - IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL = - IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE | (1 << 5), -} iree_hal_memory_type_t; - -// A bitfield specifying how memory will be accessed in a mapped memory region. -typedef enum { - // Memory is not mapped. - IREE_HAL_MEMORY_ACCESS_NONE = 0, - // Memory will be read. - // If a buffer is only mapped for reading it may still be possible to write to - // it but the results will be undefined (as it may present coherency issues). - IREE_HAL_MEMORY_ACCESS_READ = 1 << 0, - // Memory will be written. - // If a buffer is only mapped for writing it may still be possible to read - // from it but the results will be undefined or incredibly slow (as it may - // be mapped by the driver as uncached). - IREE_HAL_MEMORY_ACCESS_WRITE = 1 << 1, - // Memory will be discarded prior to mapping. - // The existing contents will be undefined after mapping and must be written - // to ensure validity. - IREE_HAL_MEMORY_ACCESS_DISCARD = 1 << 2, - // Memory will be discarded and completely overwritten in a single operation. - IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE = - IREE_HAL_MEMORY_ACCESS_WRITE | IREE_HAL_MEMORY_ACCESS_DISCARD, - // Memory may have any operation performed on it. - IREE_HAL_MEMORY_ACCESS_ALL = IREE_HAL_MEMORY_ACCESS_READ | - IREE_HAL_MEMORY_ACCESS_WRITE | - IREE_HAL_MEMORY_ACCESS_DISCARD, -} iree_hal_memory_access_t; - -// Bitfield that defines how a buffer is intended to be used. -// Usage allows the driver to appropriately place the buffer for more -// efficient operations of the specified types. -typedef enum { - IREE_HAL_BUFFER_USAGE_NONE = 0, - - // The buffer, once defined, will not be mapped or updated again. - // This should be used for uniform parameter values such as runtime - // constants for executables. Doing so may allow drivers to inline values or - // represent them in command buffers more efficiently (avoiding memory reads - // or swapping, etc). - IREE_HAL_BUFFER_USAGE_CONSTANT = 1 << 0, - - // The buffer can be used as the source or target of a transfer command - // (CopyBuffer, UpdateBuffer, etc). - // - // If |IREE_HAL_BUFFER_USAGE_MAPPING| is not specified drivers may safely - // assume that the host may never need visibility of this buffer as all - // accesses will happen via command buffers. - IREE_HAL_BUFFER_USAGE_TRANSFER = 1 << 1, - - // The buffer can be mapped by the host application for reading and writing. - // - // As mapping may require placement in special address ranges or system - // calls to enable visibility the driver can use the presence (or lack of) - // this flag to perform allocation-type setup and avoid initial mapping - // overhead. - IREE_HAL_BUFFER_USAGE_MAPPING = 1 << 2, - - // The buffer can be provided as an input or output to an executable. - // Buffers of this type may be directly used by drivers during dispatch. - IREE_HAL_BUFFER_USAGE_DISPATCH = 1 << 3, - - // Buffer may be used for any operation. - IREE_HAL_BUFFER_USAGE_ALL = IREE_HAL_BUFFER_USAGE_TRANSFER | - IREE_HAL_BUFFER_USAGE_MAPPING | - IREE_HAL_BUFFER_USAGE_DISPATCH, -} iree_hal_buffer_usage_t; - -//===----------------------------------------------------------------------===// -// iree::hal::Buffer -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Returns a reference to a subspan of the |buffer|. -// If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer after -// |byte_offset| (possibly 0) will be selected. -// -// The parent buffer will remain alive for the lifetime of the subspan -// returned. If the subspan is a small portion this may cause additional -// memory to remain allocated longer than required. -// -// Returns the given |buffer| if the requested span covers the entire range. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_subspan( - iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer); - -// Retains the given |buffer| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_retain(iree_hal_buffer_t* buffer); - -// Releases the given |buffer| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_release(iree_hal_buffer_t* buffer); - -// Returns the size in bytes of the buffer. -IREE_API_EXPORT iree_device_size_t IREE_API_CALL -iree_hal_buffer_byte_length(const iree_hal_buffer_t* buffer); - -// Sets a range of the buffer to binary zero. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_zero(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length); - -// Sets a range of the buffer to the given value. -// Only |pattern_length| values with 1, 2, or 4 bytes are supported. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_fill(iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_host_size_t pattern_length); - -// Reads a block of data from the buffer at the given offset. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_read_data( - iree_hal_buffer_t* buffer, iree_device_size_t source_offset, - void* target_buffer, iree_device_size_t data_length); - -// Writes a block of byte data into the buffer at the given offset. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_write_data( - iree_hal_buffer_t* buffer, iree_device_size_t target_offset, - const void* source_buffer, iree_device_size_t data_length); - -// Maps the buffer to be accessed as a host pointer into |out_mapped_memory|. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_map( - iree_hal_buffer_t* buffer, iree_hal_memory_access_t memory_access, - iree_device_size_t element_offset, iree_device_size_t element_length, - iree_hal_mapped_memory_t* out_mapped_memory); - -// Unmaps the buffer as was previously mapped to |mapped_memory|. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_unmap( - iree_hal_buffer_t* buffer, iree_hal_mapped_memory_t* mapped_memory); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::hal::HeapBuffer -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Allocates a zeroed host heap buffer of the given size. -// The buffer contents will be allocated with |contents_allocator| while -// |allocator| is used for the iree_hal_buffer_t. -// -// Returns a buffer allocated with malloc that may not be usable by devices -// without copies. |memory_type| should be set to -// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_host_size_t allocation_size, iree_allocator_t contents_allocator, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer); - -// Allocates a host heap buffer with a copy of the given data. -// The buffer contents will be allocated with |contents_allocator| while -// |allocator| is used for the iree_hal_buffer_t. -// -// Returns a buffer allocated with malloc that may not be usable by devices -// without copies. |memory_type| should be set to -// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_allocate_copy( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t usage, - iree_hal_memory_access_t allowed_access, iree_byte_span_t contents, - iree_allocator_t contents_allocator, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer); - -// Wraps an existing host heap allocation in a buffer. -// Ownership of the host allocation remains with the caller and the memory -// must remain valid for so long as the iree_hal_buffer_t may be in use. -// -// Returns a buffer allocated with malloc that may not be usable by devices -// without copies. |memory_type| should be set to -// IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_heap_buffer_wrap( - iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, iree_byte_span_t contents, - iree_allocator_t allocator, iree_hal_buffer_t** out_buffer); - -// TODO(benvanik): add a wrap that takes an allocator just for the buffer. - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::hal::BufferView -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Creates a buffer view with the given |buffer|, which may be nullptr. -// |out_buffer_view| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_create( - iree_hal_buffer_t* buffer, iree_shape_t shape, int8_t element_size, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_buffer_view); - -// Retains the given |buffer_view| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_view_retain(iree_hal_buffer_view_t* buffer_view); - -// Releases the given |buffer_view| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view); - -// Sets the buffer view to point at the new |buffer| with the given metadata. -// To clear a buffer_view to empty use iree_hal_buffer_view_reset. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_assign( - iree_hal_buffer_view_t* buffer_view, iree_hal_buffer_t* buffer, - iree_shape_t shape, int8_t element_size); - -// Resets the buffer view to have an empty buffer and shape. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_buffer_view_reset(iree_hal_buffer_view_t* buffer_view); - -// Returns the buffer underlying the buffer view. -// The caller must retain the returned buffer if they want to continue using it. -IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL -iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view); - -// Returns the shape of the buffer view in |out_shape|. -// If there is not enough space in |out_shape| to store all dimensions then -// IREE_STATUS_OUT_OF_RANGE is returned and |out_shape|.rank is set to the rank. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_buffer_view_shape( - const iree_hal_buffer_view_t* buffer_view, iree_shape_t* out_shape); - -// Returns the size of each element in the buffer view in bytes. -IREE_API_EXPORT int8_t IREE_API_CALL -iree_hal_buffer_view_element_size(const iree_hal_buffer_view_t* buffer_view); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::hal::Semaphore -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Retains the given |semaphore| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_retain(iree_hal_semaphore_t* semaphore); - -// Releases the given |semaphore| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_semaphore_release(iree_hal_semaphore_t* semaphore); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::hal::Fence -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Retains the given |fence| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_fence_retain(iree_hal_fence_t* fence); - -// Releases the given |fence| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_fence_release(iree_hal_fence_t* fence); - -#endif // IREE_API_NO_PROTOTYPES - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_HAL_API_H_
diff --git a/iree/hal/api_detail.h b/iree/hal/api_detail.h deleted file mode 100644 index 9bc3047..0000000 --- a/iree/hal/api_detail.h +++ /dev/null
@@ -1,42 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Additional definitions for internal users of the api. This should only -// be included from internal implementation files. - -#ifndef IREE_HAL_API_DETAIL_H_ -#define IREE_HAL_API_DETAIL_H_ - -#include "iree/hal/api.h" -#include "iree/hal/buffer_view.h" - -namespace iree { -namespace hal { - -// In the API, buffer views are ref objects, and this allows parts of the -// API outside of the HAL to work with them. -struct iree_hal_buffer_view : public RefObject<iree_hal_buffer_view> { - BufferView impl; - iree_allocator_t allocator; - - static void Delete(iree_hal_buffer_view* ptr) { - ptr->impl.buffer.reset(); - ptr->allocator.free(ptr->allocator.self, ptr); - } -}; - -} // namespace hal -} // namespace iree - -#endif
diff --git a/iree/hal/buffer.cc b/iree/hal/buffer.cc deleted file mode 100644 index 49a0602..0000000 --- a/iree/hal/buffer.cc +++ /dev/null
@@ -1,549 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/buffer.h" - -#include <algorithm> -#include <atomic> -#include <cstdint> -#include <cstring> -#include <sstream> - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/types/variant.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -#if HAS_IREE_BUFFER_DEBUG_NAME -namespace { -// Used for diagnostic purposes only as a default buffer name. -std::atomic<int> next_buffer_id_{0}; -} // namespace -#endif // HAS_IREE_BUFFER_DEBUG_NAME - -std::string MemoryTypeString(MemoryTypeBitfield memory_type) { - return FormatBitfieldValue(memory_type, - { - // Combined: - {MemoryType::kHostLocal, "kHostLocal"}, - {MemoryType::kDeviceLocal, "kDeviceLocal"}, - // Separate: - {MemoryType::kTransient, "kTransient"}, - {MemoryType::kHostVisible, "kHostVisible"}, - {MemoryType::kHostCoherent, "kHostCoherent"}, - {MemoryType::kHostCached, "kHostCached"}, - {MemoryType::kDeviceVisible, "kDeviceVisible"}, - }); -} - -std::string MemoryAccessString(MemoryAccessBitfield memory_access) { - return FormatBitfieldValue(memory_access, - { - // Combined: - {MemoryAccess::kAll, "kAll"}, - {MemoryAccess::kDiscardWrite, "kDiscardWrite"}, - // Separate: - {MemoryAccess::kRead, "kRead"}, - {MemoryAccess::kWrite, "kWrite"}, - {MemoryAccess::kDiscard, "kDiscard"}, - }); -} - -std::string BufferUsageString(BufferUsageBitfield buffer_usage) { - return FormatBitfieldValue(buffer_usage, - { - // Combined: - {BufferUsage::kAll, "kAll"}, - // Separate: - {BufferUsage::kConstant, "kConstant"}, - {BufferUsage::kTransfer, "kTransfer"}, - {BufferUsage::kMapping, "kMapping"}, - {BufferUsage::kDispatch, "kDispatch"}, - }); -} - -// Special router for buffers that just reference other buffers. -// We keep this out of the base Buffer so that it's a bit easier to track -// delegation. -class SubspanBuffer : public Buffer { - public: - SubspanBuffer(ref_ptr<Buffer> parent_buffer, device_size_t byte_offset, - device_size_t byte_length) - : Buffer(parent_buffer->allocator(), parent_buffer->memory_type(), - parent_buffer->allowed_access(), parent_buffer->usage(), - parent_buffer->allocation_size(), byte_offset, byte_length) { - allocated_buffer_ = parent_buffer.get(); - parent_buffer_ = std::move(parent_buffer); - } - - protected: - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override { - return parent_buffer_->FillImpl(byte_offset, byte_length, pattern, - pattern_length); - } - - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override { - return parent_buffer_->ReadDataImpl(source_offset, data, data_length); - } - - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override { - return parent_buffer_->WriteDataImpl(target_offset, data, data_length); - } - - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override { - return parent_buffer_->CopyDataImpl(target_offset, source_buffer, - source_offset, data_length); - } - - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override { - return parent_buffer_->MapMemoryImpl(mapping_mode, memory_access, - local_byte_offset, local_byte_length, - out_data); - } - - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override { - return parent_buffer_->UnmapMemoryImpl(local_byte_offset, local_byte_length, - data); - } - - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override { - return parent_buffer_->InvalidateMappedMemoryImpl(local_byte_offset, - local_byte_length); - } - - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override { - return parent_buffer_->FlushMappedMemoryImpl(local_byte_offset, - local_byte_length); - } -}; - -// static -StatusOr<ref_ptr<Buffer>> Buffer::Subspan(const ref_ptr<Buffer>& buffer, - device_size_t byte_offset, - device_size_t byte_length) { - RETURN_IF_ERROR(buffer->CalculateRange(byte_offset, byte_length, &byte_offset, - &byte_length)); - if (byte_offset == 0 && byte_length == buffer->byte_length()) { - // Asking for the same buffer. - return add_ref(buffer); - } - - // To avoid heavy nesting of subspans that just add indirection we go to the - // parent buffer directly. If we wanted better accounting (to track where - // buffers came from) we'd want to avoid this but I'm not sure that's worth - // the super deep indirection that could arise. - if (buffer->allocated_buffer() != buffer.get()) { - CHECK(buffer->parent_buffer_); - return Buffer::Subspan(buffer->parent_buffer_, byte_offset, byte_length); - } else { - return {make_ref<SubspanBuffer>(add_ref(buffer), byte_offset, byte_length)}; - } -} - -// static -Buffer::Overlap Buffer::TestOverlap( - Buffer* lhs_buffer, device_size_t lhs_offset, device_size_t lhs_length, - Buffer* rhs_buffer, device_size_t rhs_offset, device_size_t rhs_length) { - if (lhs_buffer->allocated_buffer() != rhs_buffer->allocated_buffer()) { - // Not even the same buffers. - return Overlap::kDisjoint; - } - // Resolve offsets into the underlying allocation. - device_size_t lhs_alloc_offset = lhs_buffer->byte_offset() + lhs_offset; - device_size_t rhs_alloc_offset = rhs_buffer->byte_offset() + rhs_offset; - device_size_t lhs_alloc_length = lhs_length == kWholeBuffer - ? lhs_buffer->byte_length() - lhs_offset - : lhs_length; - device_size_t rhs_alloc_length = rhs_length == kWholeBuffer - ? rhs_buffer->byte_length() - rhs_offset - : rhs_length; - if (!lhs_alloc_length || !rhs_alloc_length) { - return Overlap::kDisjoint; - } - if (lhs_alloc_offset == rhs_alloc_offset && - lhs_alloc_length == rhs_alloc_length) { - return Overlap::kComplete; - } - return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset && - rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset - ? Overlap::kPartial - : Overlap::kDisjoint; -} - -// static -bool Buffer::DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, - device_size_t lhs_length, Buffer* rhs_buffer, - device_size_t rhs_offset, device_size_t rhs_length) { - return TestOverlap(lhs_buffer, lhs_offset, lhs_length, rhs_buffer, rhs_offset, - rhs_length) != Overlap::kDisjoint; -} - -Buffer::Buffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length) - : allocated_buffer_(const_cast<Buffer*>(this)), - allocator_(allocator), - memory_type_(memory_type), - allowed_access_(allowed_access), - usage_(usage), - allocation_size_(allocation_size), - byte_offset_(byte_offset), - byte_length_(byte_length) { -#if HAS_IREE_BUFFER_DEBUG_NAME - // Default name for logging. - // It'd be nice to defer this until it's required but that would require - // synchronization or something. - const char* debug_name_prefix = ""; - if ((memory_type_ & MemoryType::kHostLocal) == MemoryType::kHostLocal) { - debug_name_prefix = "host_buffer_"; - } else if ((memory_type_ & MemoryType::kDeviceLocal) == - MemoryType::kDeviceLocal) { - // TODO(benvanik): include allocator ID to differentiate devices. - debug_name_prefix = "device_buffer_"; - } - debug_name_ = absl::StrCat(debug_name_prefix, next_buffer_id_++); -#endif // HAS_IREE_BUFFER_DEBUG_NAME -} - -Buffer* Buffer::allocated_buffer() const noexcept { - Buffer* allocated_buffer = allocated_buffer_; - while (allocated_buffer != this && - allocated_buffer != allocated_buffer->allocated_buffer()) { - allocated_buffer = allocated_buffer->allocated_buffer(); - } - return allocated_buffer; -} - -std::string Buffer::DebugString() const { - std::ostringstream stream; - stream << allocated_buffer()->debug_name() << "[" - << (allocation_size() == kWholeBuffer - ? "?" - : std::to_string(allocation_size())) - << "]."; - if (AnyBitSet(memory_type() & MemoryType::kTransient)) stream << "Z"; - if ((memory_type() & MemoryType::kHostLocal) == MemoryType::kHostLocal) { - stream << "h"; - } else { - if (AnyBitSet(memory_type() & MemoryType::kHostVisible)) stream << "v"; - if (AnyBitSet(memory_type() & MemoryType::kHostCoherent)) stream << "x"; - if (AnyBitSet(memory_type() & MemoryType::kHostCached)) stream << "c"; - } - if ((memory_type() & MemoryType::kDeviceLocal) == MemoryType::kDeviceLocal) { - stream << "D"; - } else { - if (AnyBitSet(memory_type() & MemoryType::kDeviceVisible)) stream << "V"; - } - stream << "."; - if (AnyBitSet(usage() & BufferUsage::kConstant)) stream << "c"; - if (AnyBitSet(usage() & BufferUsage::kTransfer)) stream << "t"; - if (AnyBitSet(usage() & BufferUsage::kMapping)) stream << "m"; - if (AnyBitSet(usage() & BufferUsage::kDispatch)) stream << "d"; - if (byte_offset_ || byte_length_ != allocation_size_) { - stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) - << ")"; - } - return stream.str(); -} - -std::string Buffer::DebugStringShort() const { - // TODO(benvanik): figure out what's most useful here. Maybe a long variant? - std::ostringstream stream; - stream << allocated_buffer()->debug_name() << "[" - << (allocation_size() == kWholeBuffer - ? "?" - : std::to_string(allocation_size())) - << "]"; - if (byte_offset_ || byte_length_ != allocation_size_) { - stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) - << ")"; - } - return stream.str(); -} - -Status Buffer::ValidateCompatibleMemoryType( - MemoryTypeBitfield memory_type) const { - if ((memory_type_ & memory_type) != memory_type) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is not compatible with the requested " - "operation; buffer has " - << MemoryTypeString(memory_type_) << ", operation requires " - << MemoryTypeString(memory_type); - } - return OkStatus(); -} - -Status Buffer::ValidateAccess(MemoryAccessBitfield memory_access) const { - if (!AnyBitSet(memory_access & - (MemoryAccess::kRead | MemoryAccess::kWrite))) { - // No actual access bits defined. - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Memory access must specify one or more of kRead or kWrite"; - } else if ((allowed_access_ & memory_access) != memory_access) { - // Bits must match exactly. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "The buffer does not support the requested access type; buffer " - "allows " - << MemoryAccessString(allowed_access_) << ", operation requires " - << MemoryAccessString(memory_access); - } - return OkStatus(); -} - -Status Buffer::ValidateUsage(BufferUsageBitfield usage) const { - if ((usage_ & usage) != usage) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage was not specified when the buffer was " - "allocated; buffer allows " - << BufferUsageString(usage_) << ", operation requires " - << BufferUsageString(usage); - } - return OkStatus(); -} - -Status Buffer::CalculateRange(device_size_t base_offset, - device_size_t max_length, device_size_t offset, - device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length) { - // Check if the start of the range runs off the end of the buffer. - if (offset > max_length) { - *out_adjusted_offset = 0; - if (out_adjusted_length) *out_adjusted_length = 0; - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address off the end of the valid buffer " - "range (offset=" - << offset << ", length=" << length - << ", buffer byte_length=" << max_length << ")"; - } - - // Handle length as kWholeBuffer by adjusting it (if allowed). - if (length == kWholeBuffer && !out_adjusted_length) { - *out_adjusted_offset = 0; - return InvalidArgumentErrorBuilder(IREE_LOC) - << "kWholeBuffer may only be used with buffer ranges, not external " - "pointer ranges"; - } - - // Calculate the real ranges adjusted for our region within the allocation. - device_size_t adjusted_offset = base_offset + offset; - device_size_t adjusted_length = - length == kWholeBuffer ? max_length - offset : length; - if (adjusted_length == 0) { - // Fine to have a zero length. - *out_adjusted_offset = adjusted_offset; - if (out_adjusted_length) *out_adjusted_length = adjusted_length; - return OkStatus(); - } - - // Check if the end runs over the allocation. - device_size_t end = offset + adjusted_length - 1; - if (end >= max_length) { - *out_adjusted_offset = 0; - if (out_adjusted_length) *out_adjusted_length = 0; - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address outside of the valid buffer " - "range (offset=" - << offset << ", adjusted_length=" << adjusted_length - << ", end=" << end << ", buffer byte_length=" << max_length << ")"; - } - - *out_adjusted_offset = adjusted_offset; - if (out_adjusted_length) *out_adjusted_length = adjusted_length; - return OkStatus(); -} - -Status Buffer::CalculateRange(device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length) const { - return CalculateRange(byte_offset_, byte_length_, offset, length, - out_adjusted_offset, out_adjusted_length); -} - -Status Buffer::CalculateLocalRange(device_size_t max_length, - device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length) { - return CalculateRange(0, max_length, offset, length, out_adjusted_offset, - out_adjusted_length); -} - -Status Buffer::Fill(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) { - // If not host visible we'll need to issue command buffers. - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - RETURN_IF_ERROR( - CalculateRange(byte_offset, byte_length, &byte_offset, &byte_length)); - if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Fill patterns must be 1, 2, or 4 bytes"; - } - if ((byte_offset % pattern_length) != 0 || - (byte_length % pattern_length) != 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Attempting to fill a range with " << pattern_length - << " byte values that is not " - "aligned (offset=" - << byte_offset << ", length=" << byte_length << ")"; - } - if (byte_length == 0) { - return OkStatus(); // No-op. - } - const uint32_t kZero = 0; - if (std::memcmp(pattern, &kZero, pattern_length) == 0) { - // We can turn all-zero values into single-byte fills as that can be much - // faster on devices (doing a fill8 vs fill32). - pattern_length = 1; - } - return FillImpl(byte_offset, byte_length, pattern, pattern_length); -} - -Status Buffer::ReadData(device_size_t source_offset, void* data, - device_size_t data_length) { - // If not host visible we'll need to issue command buffers. - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - RETURN_IF_ERROR(CalculateRange(source_offset, data_length, &source_offset)); - if (data_length == 0) { - return OkStatus(); // No-op. - } - return ReadDataImpl(source_offset, data, data_length); -} - -Status Buffer::WriteData(device_size_t target_offset, const void* data, - device_size_t data_length) { - // If not host visible we'll need to issue command buffers. - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - RETURN_IF_ERROR(CalculateRange(target_offset, data_length, &target_offset)); - if (data_length == 0) { - return OkStatus(); // No-op. - } - return WriteDataImpl(target_offset, data, data_length); -} - -Status Buffer::CopyData(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - RETURN_IF_ERROR( - source_buffer->ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(source_buffer->ValidateAccess(MemoryAccess::kRead)); - RETURN_IF_ERROR(source_buffer->ValidateUsage(BufferUsage::kMapping)); - - // We need to validate both buffers. - device_size_t source_data_length = data_length; - device_size_t target_data_length = data_length; - device_size_t adjusted_source_offset; - RETURN_IF_ERROR(source_buffer->CalculateRange( - source_offset, source_data_length, &adjusted_source_offset, - &source_data_length)); - RETURN_IF_ERROR(CalculateRange(target_offset, target_data_length, - &target_offset, &target_data_length)); - device_size_t adjusted_data_length; - if (data_length == kWholeBuffer) { - // Whole buffer copy requested - that could mean either, so take the min. - adjusted_data_length = std::min(source_data_length, target_data_length); - } else { - // Specific length requested - validate that we have matching lengths. - CHECK_EQ(source_data_length, target_data_length); - adjusted_data_length = source_data_length; - } - - // Elide zero length copies. - if (adjusted_data_length == 0) { - return OkStatus(); - } - - // Check for overlap. - if (this == source_buffer && - adjusted_source_offset <= target_offset + adjusted_data_length && - target_offset <= adjusted_source_offset + adjusted_data_length) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Source and target ranges overlap within the same buffer"; - } - - return CopyDataImpl(target_offset, source_buffer, source_offset, - adjusted_data_length); -} - -Status Buffer::MapMemory(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t* byte_offset, device_size_t* byte_length, - void** out_data) { - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(ValidateAccess(memory_access)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - RETURN_IF_ERROR( - CalculateRange(*byte_offset, *byte_length, byte_offset, byte_length)); - *out_data = nullptr; - return MapMemoryImpl(mapping_mode, memory_access, *byte_offset, *byte_length, - out_data); -} - -Status Buffer::UnmapMemory(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) { - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return UnmapMemoryImpl(local_byte_offset, local_byte_length, data); -} - -Status Buffer::InvalidateMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length) { - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible)); - if (AnyBitSet(memory_type_ & MemoryType::kHostCoherent)) { - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is coherent and invalidation is not required"; - } - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length); -} - -Status Buffer::FlushMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length) { - RETURN_IF_ERROR(ValidateCompatibleMemoryType(MemoryType::kHostVisible | - MemoryType::kHostCached)); - RETURN_IF_ERROR(ValidateUsage(BufferUsage::kMapping)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return FlushMappedMemoryImpl(local_byte_offset, local_byte_length); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/buffer.h b/iree/hal/buffer.h deleted file mode 100644 index 20065e6..0000000 --- a/iree/hal/buffer.h +++ /dev/null
@@ -1,903 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Allocated memory buffer wrapper type and utilities. -// -// Buffers are the basic unit of memory used by the inference system. They may -// be allocated such that they are accessible from the host (normal C++ code -// running on the main CPU), a particular device (such as an accelerator) or -// family of devices, or from some mix of all of those. -// -// The type of memory a buffer is allocated within has implications on it's -// performance and lifetime. For example if an application attempts to use a -// host-allocated buffer (MemoryType::kHostLocal) on an accelerator with -// discrete memory the accelerator may either be unable to access the memory or -// take a non-trivial performance hit when attempting to do so (involving -// setting up kernel mappings, doing DMA transfers, etc). Likewise, trying to -// access a device-allocated buffer (MemoryType::kDeviceLocal) may incur similar -// overhead or not be possible at all. This may be due to restrictions in the -// memory visibility, address spaces, mixed endianness or pointer widths, -// and other weirdness. -// -// The memory types (defined by a bitfield of MemoryType values) that a -// particular context (host or device) may use vary from device to device and -// must be queried by the application when allocating buffers. It's strongly -// recommended that the most specific memory type be set as possible. For -// example allocating a buffer with MemoryType::kHostCoherent even when it will -// never be used in a way that requires coherency may occupy address space -// reservations or memory mapping that would otherwise not be needed. -// -// As buffers may sometimes not be accessible from the host the base Buffer type -// does not allow for direct void* access and instead buffers must be either -// manipulated using utility functions (such as ReadData or WriteData) or by -// mapping them into a host-accessible address space via MapMemory. Buffer must -// be unmapped before any command may use it. -// -// Buffers may map (roughly) 1:1 with an allocation either from the host heap or -// a device. Buffer::Subspan can be used to reference subspans of buffers like -// absl::Span - though unlike absl::Span the returned Buffer holds a reference -// to the parent buffer. - -#ifndef IREE_HAL_BUFFER_H_ -#define IREE_HAL_BUFFER_H_ - -#include <cstddef> -#include <cstdint> -#include <memory> -#include <string> -#include <utility> - -#include "absl/types/span.h" -#include "absl/types/variant.h" -#include "iree/base/bitfield.h" -#include "iree/base/logging.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/hal/resource.h" - -// Only enable debug names in non-opt modes (unless the user forces it on). -#if !defined(NDEBUG) && !defined(HAS_IREE_BUFFER_DEBUG_NAME) -#define HAS_IREE_BUFFER_DEBUG_NAME 1 -#endif // !NDEBUG - -namespace iree { - -// std::size_t equivalent that is the size as used on device. -// As the device may have a larger memory address space than the host we treat -// all byte offsets as this type instead of the host-specified size_t. -using device_size_t = uint64_t; - -// When used as a length value in functions causes the length to be the entire -// remaining buffer from the specified offset. -constexpr device_size_t kWholeBuffer = ~0ull; - -} // namespace iree - -namespace iree { -namespace hal { - -class Allocator; -template <typename T> -class MappedMemory; - -// A bitfield specifying properties for a memory type. -enum class MemoryType : uint32_t { - kNone = 0, - - // Memory is lazily allocated by the device and only exists transiently. - // This is the optimal mode for memory used only within a single command - // buffer. Transient buffers, even if they have kHostVisible set, should be - // treated as device-local and opaque as they may have no memory attached to - // them outside of the time they are being evaluated on devices. - // - // This flag can be treated as a hint in most cases; allocating a buffer with - // it set _may_ return the same as if it had not be set. Certain allocation - // routines may use the hint to more tightly control reuse or defer wiring the - // memory. - kTransient = 1 << 0, - - // Memory allocated with this type can be mapped for host access using - // Buffer::MapMemory. - kHostVisible = 1 << 1, - - // The host cache management commands MappedMemory::Flush and - // MappedMemory::Invalidate are not needed to flush host writes - // to the device or make device writes visible to the host, respectively. - kHostCoherent = 1 << 2, - - // Memory allocated with this type is cached on the host. Host memory - // accesses to uncached memory are slower than to cached memory, however - // uncached memory is always host coherent. MappedMemory::Flush must be used - // to ensure the device has visibility into any changes made on the host and - // Invalidate must be used to ensure the host has visibility into any changes - // made on the device. - kHostCached = 1 << 3, - - // Memory is accessible as normal host allocated memory. - kHostLocal = kHostVisible | kHostCoherent, - - // Memory allocated with this type is visible to the device for execution. - // Being device visible does not mean the same thing as kDeviceLocal. Though - // an allocation may be visible to the device and therefore useable for - // execution it may require expensive mapping or implicit transfers. - kDeviceVisible = 1 << 4, - - // Memory allocated with this type is the most efficient for device access. - // Devices may support using memory that is not device local via - // kDeviceVisible but doing so can incur non-trivial performance penalties. - // Device local memory, on the other hand, is guaranteed to be fast for all - // operations. - kDeviceLocal = kDeviceVisible | (1 << 5), -}; -IREE_BITFIELD(MemoryType); -using MemoryTypeBitfield = MemoryType; -std::string MemoryTypeString(MemoryTypeBitfield memory_type); - -// A bitfield specifying how memory will be accessed in a mapped memory region. -enum class MemoryAccess : uint32_t { - // Memory is not mapped. - kNone = 0, - - // Memory will be read. - // If a buffer is only mapped for reading it may still be possible to write to - // it but the results will be undefined (as it may present coherency issues). - kRead = 1 << 0, - - // Memory will be written. - // If a buffer is only mapped for writing it may still be possible to read - // from it but the results will be undefined or incredibly slow (as it may - // be mapped by the driver as uncached). - kWrite = 1 << 1, - - // Memory will be discarded prior to mapping. - // The existing contents will be undefined after mapping and must be written - // to ensure validity. - kDiscard = 1 << 2, - - // Memory will be discarded and completely overwritten in a single operation. - kDiscardWrite = kWrite | kDiscard, - - // Memory may have any operation performed on it. - kAll = kRead | kWrite | kDiscard, -}; -IREE_BITFIELD(MemoryAccess); -using MemoryAccessBitfield = MemoryAccess; -std::string MemoryAccessString(MemoryAccessBitfield memory_access); - -// Bitfield that defines how a buffer is intended to be used. -// Usage allows the driver to appropriately place the buffer for more -// efficient operations of the specified types. -enum class BufferUsage { - kNone = 0, - - // The buffer, once defined, will not be mapped or updated again. - // This should be used for uniform parameter values such as runtime - // constants for executables. Doing so may allow drivers to inline values or - // represent them in command buffers more efficiently (avoiding memory reads - // or swapping, etc). - kConstant = 1 << 0, - - // The buffer can be used as the source or target of a transfer command - // (CopyBuffer, UpdateBuffer, etc). - // - // If |kMapping| is not specified drivers may safely assume that the host - // may never need visibility of this buffer as all accesses will happen via - // command buffers. - kTransfer = 1 << 1, - - // The buffer can be mapped by the host application for reading and writing. - // - // As mapping may require placement in special address ranges or system - // calls to enable visibility the driver can use the presence (or lack of) - // this flag to perform allocation-type setup and avoid initial mapping - // overhead. - kMapping = 1 << 2, - - // The buffer can be provided as an input or output to an executable. - // Buffers of this type may be directly used by drivers during dispatch. - kDispatch = 1 << 3, - - // Buffer may be used for any operation. - kAll = kTransfer | kMapping | kDispatch, -}; -IREE_BITFIELD(BufferUsage); -using BufferUsageBitfield = BufferUsage; -std::string BufferUsageString(BufferUsageBitfield buffer_usage); - -// A memory buffer. -// Buffers have a specific memory_type that is used to describe the capabilities -// and behavior of the backing memory of the buffer. Buffers may be any mix of -// host-accessible, host-coherent, or device-accessible for various usages. -// Depending on these memory types the buffers may be mapped for access on the -// host as memory though certain restrictions may be imposed. -// -// See MemoryType for more information about the types and what operations they -// support. -class Buffer : public Resource { - public: - // Returns a reference to a subspan of the buffer. - // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after - // |byte_offset| (possibly 0) will be selected. - // - // The parent buffer will remain alive for the lifetime of the subspan - // returned. If the subspan is a small portion this may cause additional - // memory to remain allocated longer than required. - // - // Returns the given |buffer| if the requested span covers the entire range. - static StatusOr<ref_ptr<Buffer>> Subspan(const ref_ptr<Buffer>& buffer, - device_size_t byte_offset, - device_size_t byte_length); - - // Overlap test results. - enum class Overlap { - // No overlap between the two buffers. - kDisjoint, - // Partial overlap between the two buffers. - kPartial, - // Complete overlap between the two buffers (they are the same). - kComplete, - }; - - // Tests whether the given buffers overlap, including support for subspans. - // kWholeBuffer may be used for |lhs_length| and/or |rhs_length| to use the - // lengths of those buffers, respectively. - static Overlap TestOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, - device_size_t lhs_length, Buffer* rhs_buffer, - device_size_t rhs_offset, - device_size_t rhs_length); - - // Returns true if the two buffer ranges overlap at all. - static bool DoesOverlap(Buffer* lhs_buffer, device_size_t lhs_offset, - device_size_t lhs_length, Buffer* rhs_buffer, - device_size_t rhs_offset, device_size_t rhs_length); - - // Disallow copies (as copying requires real work). - Buffer(const Buffer&) = delete; - Buffer& operator=(const Buffer&) = delete; - - ~Buffer() override = default; - -#if HAS_IREE_BUFFER_DEBUG_NAME - // Optionally populated name useful for logging a persistent name for the - // buffer. - absl::string_view debug_name() const { return debug_name_; } - void set_debug_name(std::string debug_name) { - debug_name_ = std::move(debug_name); - } -#else - absl::string_view debug_name() const { return ""; } - void set_debug_name(std::string debug_name) {} -#endif // HAS_IREE_BUFFER_DEBUG_NAME - - // Memory allocator this buffer was allocated from. - // May be nullptr if the buffer has no particular allocator and should be - // assumed to be allocated from the host heap. - constexpr Allocator* allocator() const { - return allocated_buffer_ == this ? allocator_ - : allocated_buffer_->allocator(); - } - - // Memory type this buffer is allocated from. - MemoryTypeBitfield memory_type() const { return memory_type_; } - - // Memory access operations allowed on the buffer. - MemoryAccessBitfield allowed_access() const { return allowed_access_; } - - // Bitfield describing how the buffer is to be used. - BufferUsageBitfield usage() const { return usage_; } - - // Returns the underlying buffer that represents the allocated memory for the - // Buffer. In most cases this is the buffer itself but for buffer subspan - // references it will point to the parent buffer. - Buffer* allocated_buffer() const noexcept; - - // Size of the resource memory allocation in bytes. - // This may be rounded up from the originally requested size or the ideal - // size for the resource based on device restrictions. - constexpr device_size_t allocation_size() const { - return allocated_buffer_ == this ? allocation_size_ - : allocated_buffer_->allocation_size(); - } - - // Range within the underlying allocation this buffer occupies. - // For buffers that map 1:1 with an allocation this should be - // [0, allocation_size()), however may still differ if the allocation needed - // to be aligned. - // - // The offset is most often manipulated by Subspan, however it's important to - // note that the offset may not be what was passed to Subspan as it refers to - // the offset in the original ancestor buffer, not the buffer from which the - // subspan was taken. - constexpr device_size_t byte_offset() const noexcept { return byte_offset_; } - constexpr device_size_t byte_length() const noexcept { return byte_length_; } - - // TODO(benvanik): add debug_name. - - // Returns a longer debug string describing the buffer and its attributes. - std::string DebugString() const; - // Returns a short debug string describing the buffer. - std::string DebugStringShort() const; - - // Sets a range of the buffer to the given value. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // If |byte_length| is kWholeBuffer the remaining bytes in the buffer after - // |byte_offset| (possibly 0) will be filled. - // - // The |byte_offset| and |byte_length| must be aligned to the size of the fill - // value. Multi-byte values will be written in host order for host buffers and - // device order for device buffers. - // - // Only |pattern_length| values with 1, 2, or 4 bytes are supported. - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status Fill(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length); - template <typename T> - Status Fill8(device_size_t byte_offset, device_size_t byte_length, T value); - template <typename T> - Status Fill16(device_size_t byte_offset, device_size_t byte_length, T value); - template <typename T> - Status Fill32(device_size_t byte_offset, device_size_t byte_length, T value); - template <typename T> - Status Fill8(T value); - template <typename T> - Status Fill16(T value); - template <typename T> - Status Fill32(T value); - - // Reads a block of byte data from the resource at the given offset. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // - // Fails if the read could not be performed; either the bounds are out of - // range or the memory type does not support reading in this way. - Status ReadData(device_size_t source_offset, void* data, - device_size_t data_length); - - // Writes a block of byte data into the resource at the given offset. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status WriteData(device_size_t target_offset, const void* data, - device_size_t data_length); - - // Copies data from the provided source_buffer into the buffer. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // The source and destination may be the same buffer but the ranges must not - // overlap (a la memcpy). - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status CopyData(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, device_size_t data_length); - Status CopyData(device_size_t target_offset, Buffer* source_buffer) { - return CopyData(target_offset, source_buffer, 0, kWholeBuffer); - } - - // Maps the resource memory for direct access from the host. - // This requires that the resource was allocated with - // MemoryType::kHostVisible and BufferUsage::kMapping. - // - // If MemoryType::kHostCoherent was not specified then explicit - // Invalidate and Flush calls must be used to control visibility of the data - // on the device. If MemoryType::kHostCached is not set callers must not - // attempt to read from the mapped memory as doing so may produce undefined - // results and/or ultra slow reads. - // - // If the MemoryAccess::kDiscard bit is set when mapping for writes the caller - // guarantees that they will be overwriting all data in the mapped range. This - // is used as a hint to the device that the prior contents are no longer - // required and can enable optimizations that save on synchronization and - // readback. Note however that it is strictly a hint and the contents are not - // guaranteed to be zeroed during mapping. - // - // This allows mapping the memory as a C++ type. Care must be taken to ensure - // the data layout in C++ matches the expected data layout in the executables - // that consume this data. For simple primitives like uint8_t or float this is - // usually not a problem however struct packing may have many restrictions. - // - // The returned mapping should be unmapped when it is no longer required. - // Unmapping does not implicitly flush. - // - // Fails if the memory could not be mapped due to mapping exhaustion, invalid - // arguments, or unsupported memory types. - // - // Example: - // ASSIGN_OR_RETURN(auto mapping, buffer->MapForRead<MyStruct>()); - // mapping[5].foo = 3; - // std::memcpy(mapping.data(), source_data, mapping.size()); - // mapping.reset(); - template <typename T> - StatusOr<MappedMemory<T>> MapMemory( - MemoryAccessBitfield memory_access, device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer); - - protected: - template <typename T> - friend class MappedMemory; - - // Defines the mode of a MapMemory operation. - enum class MappingMode { - // The call to MapMemory will always be matched with UnmapMemory. - kScoped, - }; - - Buffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length); - - // Allows subclasses to override the allowed access bits. - // This should only be done when known safe by the allocation scheme. - void set_allowed_access(MemoryAccessBitfield allowed_access) { - allowed_access_ = allowed_access; - } - - // Sets a range of the buffer to the given value. - // State and parameters have already been validated. For the >8bit variants - // the offset and length have already been validated to be aligned to the - // natural alignment of the type. - virtual Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, - device_size_t pattern_length) = 0; - - // Reads a block of byte data from the resource at the given offset. - // State and parameters have already been validated. - virtual Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) = 0; - - // Writes a block of byte data into the resource at the given offset. - // State and parameters have already been validated. - virtual Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) = 0; - - // Copies a block of byte data into the resource at the given offset. - // State and parameters have already been validated. - virtual Status CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) = 0; - - // Maps memory directly. - // The output data pointer will be properly aligned to the start of the data. - // |local_byte_offset| and |local_byte_length| are the adjusted values that - // should map into the local space of the buffer. - // - // Fails if the memory could not be mapped (invalid access type, invalid - // range, or unsupported memory type). - // State and parameters have already been validated. - virtual Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) = 0; - - // Unmaps previously mapped memory. - // No-op if the memory is not mapped. As this is often used in destructors - // we can't rely on failures here propagating with anything but CHECK/DCHECK. - // State and parameters have already been validated. - virtual Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, - void* data) = 0; - - // Invalidates ranges of non-coherent memory from the host caches. - // Use this before reading from non-coherent memory. - // This guarantees that device writes to the memory ranges provided are - // visible on the host. - // This is only required for memory types without kHostCoherent set. - // State and parameters have already been validated. - virtual Status InvalidateMappedMemoryImpl( - device_size_t local_byte_offset, device_size_t local_byte_length) = 0; - - // Flushes ranges of non-coherent memory from the host caches. - // Use this after writing to non-coherent memory. - // This guarantees that host writes to the memory ranges provided are made - // available for device access. - // This is only required for memory types without kHostCoherent set. - // State and parameters have already been validated. - virtual Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) = 0; - - // Validates the given buffer range and adjusts the offset and length if the - // provided length is kWholeBuffer or the buffer is offset within its - // allocation. This calculates the range in the given domain without adjusting - // to any particular buffer base offsets. - static Status CalculateLocalRange(device_size_t max_length, - device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length); - - private: - friend class Allocator; - - // This is not great and deserves cleanup. - friend class DeferredBuffer; - friend class SubspanBuffer; - friend class HeapBuffer; - - // Maps memory directly. - // The byte offset and byte length may be adjusted for device alignment. - // The output data pointer will be properly aligned to the start of the data. - // Fails if the memory could not be mapped (invalid access type, invalid - // range, or unsupported memory type). - Status MapMemory(MappingMode mapping_mode, MemoryAccessBitfield memory_access, - device_size_t* byte_offset, device_size_t* byte_length, - void** out_data); - - // Unmaps previously mapped memory. - // No-op if the memory is not mapped. As this is often used in destructors - // we can't rely on failures here propagating with anything but CHECK/DCHECK. - Status UnmapMemory(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data); - - // Invalidates ranges of non-coherent memory from the host caches. - // Use this before reading from non-coherent memory. - // This guarantees that device writes to the memory ranges provided are - // visible on the host. - // This is only required for memory types without kHostCoherent set. - Status InvalidateMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length); - - // Flushes ranges of non-coherent memory from the host caches. - // Use this after writing to non-coherent memory. - // This guarantees that host writes to the memory ranges provided are made - // available for device access. - // This is only required for memory types without kHostCoherent set. - Status FlushMappedMemory(device_size_t local_byte_offset, - device_size_t local_byte_length); - - // Returns a failure if the memory type the buffer was allocated from is not - // compatible with the given type. - Status ValidateCompatibleMemoryType(MemoryTypeBitfield memory_type) const; - // Returns a failure if the buffer memory type or usage disallows the given - // access type. - Status ValidateAccess(MemoryAccessBitfield memory_access) const; - // Returns a failure if the buffer was not allocated for the given usage. - Status ValidateUsage(BufferUsageBitfield usage) const; - // Validates the given buffer range and optionally adjusts the offset and - // length if the provided length is kWholeBuffer or the buffer is offset - // within its allocation. - static Status CalculateRange(device_size_t base_offset, - device_size_t max_length, device_size_t offset, - device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length = nullptr); - Status CalculateRange(device_size_t offset, device_size_t length, - device_size_t* out_adjusted_offset, - device_size_t* out_adjusted_length = nullptr) const; - - // Points to either this or parent_buffer_.get(). - Buffer* allocated_buffer_ = nullptr; - - Allocator* allocator_ = nullptr; - MemoryTypeBitfield memory_type_ = MemoryType::kNone; - MemoryAccessBitfield allowed_access_ = MemoryAccess::kNone; - BufferUsageBitfield usage_ = BufferUsage::kNone; - - device_size_t allocation_size_ = 0; - device_size_t byte_offset_ = 0; - device_size_t byte_length_ = 0; - -#if HAS_IREE_BUFFER_DEBUG_NAME - // Friendly name for the buffer used in DebugString. May be set by the app or - // auto generated. - std::string debug_name_; -#endif // HAS_IREE_BUFFER_DEBUG_NAME - - // Defined when this buffer is a subspan of another buffer. - ref_ptr<Buffer> parent_buffer_; -}; - -// A memory mapping RAII object. -// The mapping will stay active until it is reset and will retain the buffer. -template <typename T> -class MappedMemory { - public: - using unspecified_bool_type = const T* MappedMemory<T>::*; - - MappedMemory() = default; - MappedMemory(MemoryAccessBitfield access, ref_ptr<Buffer> buffer, - device_size_t byte_offset, device_size_t byte_length, - device_size_t element_size, T* data); - - // Allow moving but disallow copying as the mapping is stateful. - MappedMemory(MappedMemory&& rhs) noexcept; - MappedMemory& operator=(MappedMemory&& rhs) noexcept; - MappedMemory(const MappedMemory&) = delete; - MappedMemory& operator=(const MappedMemory&) = delete; - - ~MappedMemory(); - - // The buffer resource that this mapping references. - const ref_ptr<Buffer>& buffer() const noexcept { return buffer_; } - // Offset, in bytes, into the resource allocation. - // This value is *informative only*, as it may vary from device to device. - device_size_t byte_offset() const noexcept { return byte_offset_; } - // Length, in bytes, of the resource mapping. - // This may be larger than the originally requested length due to alignment. - // This value is *informative only*, as it may vary from device to device. - device_size_t byte_length() const noexcept { return byte_length_; } - - // True if the mapping is empty. - bool empty() const noexcept { return element_size_ == 0; } - // The size of the mapping as requested in elements. - size_t size() const noexcept { return static_cast<size_t>(element_size_); } - - // Returns a read-only pointer to the mapped memory. - // This will be nullptr if the mapping failed or the mapping is not readable. - const T* data() const noexcept; - absl::Span<const T> contents() const noexcept { return {data(), size()}; } - - // Returns a mutable pointer to the mapped memory. - // This will be nullptr if the mapping failed or the mapping is not writable. - // If the mapping was not made with read access it may still be possible to - // read from this memory but behavior is undefined. - T* mutable_data() noexcept; - absl::Span<T> mutable_contents() noexcept { return {mutable_data(), size()}; } - - // Equivalent to absl::Span::subspan(). - // May return a 0-length span. - // Fails if the buffer is not mapped or not mapped for the requested access. - StatusOr<absl::Span<const T>> Subspan( - device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer) const noexcept; - StatusOr<absl::Span<T>> MutableSubspan( - device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer) noexcept; - - // Accesses an element in the mapped memory. - // Must be called with a valid index in [0, size()). - const T& operator[](device_size_t i) const noexcept { return data_[i]; } - - // Invalidates a range of non-coherent elements from the host caches. - Status Invalidate(device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer) const; - - // Flushes a range of non-coherent elements from the host caches. - Status Flush(device_size_t element_offset = 0, - device_size_t element_length = kWholeBuffer); - - // Unmaps the mapped memory. - // The memory will not be implicitly flushed when unmapping. - void reset(); - - private: - Status ValidateAccess(MemoryAccessBitfield memory_access) const; - Status CalculateDataRange(device_size_t element_offset, - device_size_t element_length, - device_size_t* out_adjusted_element_offset, - device_size_t* out_adjusted_element_length) const; - - MemoryAccessBitfield access_ = MemoryAccess::kNone; - ref_ptr<Buffer> buffer_; - device_size_t byte_offset_ = 0; - device_size_t byte_length_ = 0; - device_size_t element_size_ = 0; - T* data_ = nullptr; -}; - -// Inline functions and template definitions follow: - -template <typename T> -Status Buffer::Fill8(device_size_t byte_offset, device_size_t byte_length, - T value) { - auto sized_value = reinterpret_cast<uint8_t*>(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template <typename T> -Status Buffer::Fill16(device_size_t byte_offset, device_size_t byte_length, - T value) { - auto sized_value = reinterpret_cast<uint16_t*>(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template <typename T> -Status Buffer::Fill32(device_size_t byte_offset, device_size_t byte_length, - T value) { - auto sized_value = reinterpret_cast<uint32_t*>(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template <typename T> -Status Buffer::Fill8(T value) { - return Fill8(0, kWholeBuffer, value); -} - -template <typename T> -Status Buffer::Fill16(T value) { - return Fill16(0, kWholeBuffer, value); -} - -template <typename T> -Status Buffer::Fill32(T value) { - return Fill32(0, kWholeBuffer, value); -} - -template <typename T> -StatusOr<MappedMemory<T>> Buffer::MapMemory(MemoryAccessBitfield memory_access, - device_size_t element_offset, - device_size_t element_length) { - device_size_t byte_offset = element_offset * sizeof(T); - device_size_t byte_length = element_length == kWholeBuffer - ? kWholeBuffer - : element_length * sizeof(T); - void* data = nullptr; - RETURN_IF_ERROR(MapMemory(MappingMode::kScoped, memory_access, &byte_offset, - &byte_length, &data)); - return MappedMemory<T>{ - memory_access, add_ref(this), byte_offset, - byte_length, byte_length / sizeof(T), static_cast<T*>(data)}; -} - -template <typename T> -MappedMemory<T>::MappedMemory(MemoryAccessBitfield access, - ref_ptr<Buffer> buffer, device_size_t byte_offset, - device_size_t byte_length, - device_size_t element_size, T* data) - : access_(access), - buffer_(std::move(buffer)), - byte_offset_(byte_offset), - byte_length_(byte_length), - element_size_(element_size), - data_(data) {} - -template <typename T> -MappedMemory<T>::MappedMemory(MappedMemory<T>&& rhs) noexcept - : access_(rhs.access_), - buffer_(std::move(rhs.buffer_)), - byte_offset_(rhs.byte_offset_), - byte_length_(rhs.byte_length_), - element_size_(rhs.element_size_), - data_(rhs.data_) { - rhs.access_ = MemoryAccess::kNone; - rhs.buffer_.reset(); - rhs.byte_offset_ = 0; - rhs.byte_length_ = 0; - rhs.element_size_ = 0; - rhs.data_ = nullptr; -} - -template <typename T> -MappedMemory<T>& MappedMemory<T>::operator=(MappedMemory<T>&& rhs) noexcept { - if (this != &rhs) { - reset(); - access_ = rhs.access_; - buffer_ = std::move(rhs.buffer_); - byte_offset_ = rhs.byte_offset_; - byte_length_ = rhs.byte_length_; - element_size_ = rhs.element_size_; - data_ = rhs.data_; - - rhs.access_ = MemoryAccess::kNone; - rhs.buffer_.reset(); - rhs.byte_offset_ = 0; - rhs.byte_length_ = 0; - rhs.element_size_ = 0; - rhs.data_ = nullptr; - } - return *this; -} - -template <typename T> -MappedMemory<T>::~MappedMemory() { - // Unmap (if needed) - note that we can't fail gracefully here :( - reset(); -} - -template <typename T> -const T* MappedMemory<T>::data() const noexcept { - if (!data_ || !AnyBitSet(access_ & MemoryAccess::kRead)) { - return nullptr; - } - return data_; -} - -template <typename T> -T* MappedMemory<T>::mutable_data() noexcept { - if (!data_ || !AnyBitSet(access_ & MemoryAccess::kWrite)) { - return nullptr; - } - return data_; -} - -template <typename T> -Status MappedMemory<T>::ValidateAccess( - MemoryAccessBitfield memory_access) const { - if (!data_) { - return FailedPreconditionErrorBuilder(IREE_LOC) << "Buffer is not mapped"; - } else if (!AnyBitSet(access_ & memory_access)) { - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer is not mapped for the desired access"; - } - return OkStatus(); -} - -template <typename T> -Status MappedMemory<T>::CalculateDataRange( - device_size_t element_offset, device_size_t element_length, - device_size_t* out_adjusted_element_offset, - device_size_t* out_adjusted_element_length) const { - RETURN_IF_ERROR(Buffer::CalculateLocalRange( - element_size_ * sizeof(T), element_offset * sizeof(T), - element_length == kWholeBuffer ? kWholeBuffer - : element_length * sizeof(T), - out_adjusted_element_offset, out_adjusted_element_length)); - *out_adjusted_element_offset /= sizeof(T); - *out_adjusted_element_length /= sizeof(T); - return OkStatus(); -} - -template <typename T> -inline StatusOr<absl::Span<const T>> MappedMemory<T>::Subspan( - device_size_t element_offset, device_size_t element_length) const noexcept { - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); - RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - return absl::Span<const T>(data_ + element_offset, element_length); -} - -template <typename T> -inline StatusOr<absl::Span<T>> MappedMemory<T>::MutableSubspan( - device_size_t element_offset, device_size_t element_length) noexcept { - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - return absl::Span<T>(data_ + element_offset, element_length); -} - -template <typename T> -Status MappedMemory<T>::Invalidate(device_size_t element_offset, - device_size_t element_length) const { - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kRead)); - RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - if (!element_length) return OkStatus(); - return buffer_->InvalidateMappedMemory( - byte_offset_ + element_offset * sizeof(T), element_length * sizeof(T)); -} - -template <typename T> -Status MappedMemory<T>::Flush(device_size_t element_offset, - device_size_t element_length) { - RETURN_IF_ERROR(ValidateAccess(MemoryAccess::kWrite)); - RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - if (!element_length) return OkStatus(); - return buffer_->FlushMappedMemory(byte_offset_ + element_offset * sizeof(T), - element_length * sizeof(T)); -} - -template <typename T> -void MappedMemory<T>::reset() { - if (!buffer_) return; - // TODO(benvanik): better handling of errors? may be fine to always warn. - buffer_->UnmapMemory(byte_offset_, byte_length_, data_).IgnoreError(); - buffer_.reset(); - access_ = MemoryAccess::kNone; - byte_offset_ = 0; - byte_length_ = 0; - element_size_ = 0; - data_ = nullptr; -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_BUFFER_H_
diff --git a/iree/hal/buffer_mapping_test.cc b/iree/hal/buffer_mapping_test.cc deleted file mode 100644 index 9b7e9ec..0000000 --- a/iree/hal/buffer_mapping_test.cc +++ /dev/null
@@ -1,539 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for the MemoryMapping RAII wrapper. -// This uses a mock buffer implementation such that it is only testing -// MemoryMapping and not any real underlying memory mapping behavior. - -#include <cstdint> -#include <memory> -#include <utility> - -#include "absl/types/span.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { -class Allocator; - -namespace { - -using ::testing::_; -using ::testing::DoAll; -using ::testing::Return; -using ::testing::SetArgPointee; - -static void* const kValidPtr = reinterpret_cast<void*>(0xBEEFCAFEF00D1234ull); - -class MockBuffer : public Buffer { - public: - using Buffer::MappingMode; - - MockBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, - 0, allocation_size) {} - - MOCK_METHOD4(FillImpl, - Status(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length)); - - MOCK_METHOD3(ReadDataImpl, Status(device_size_t source_offset, void* data, - device_size_t data_length)); - MOCK_METHOD3(WriteDataImpl, - Status(device_size_t target_offset, const void* data, - device_size_t data_length)); - MOCK_METHOD4(CopyDataImpl, - Status(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, device_size_t data_length)); - - MOCK_METHOD5(MapMemoryImpl, - Status(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, void** out_data)); - MOCK_METHOD3(UnmapMemoryImpl, - Status(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data)); - MOCK_METHOD2(InvalidateMappedMemoryImpl, - Status(device_size_t local_byte_offset, - device_size_t local_byte_length)); - MOCK_METHOD2(FlushMappedMemoryImpl, Status(device_size_t local_byte_offset, - device_size_t local_byte_length)); -}; - -TEST(MemoryMappingTest, MapWholeBuffer) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mapping.reset(); -} - -TEST(MemoryMappingTest, MapPartialBuffer) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 4, 12, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 4, 12)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(4, 12, kValidPtr)) - .WillOnce(Return(OkStatus())); - mapping.reset(); -} - -TEST(MemoryMappingTest, EmptyHandle) { - MappedMemory<uint8_t> mm_a; - MappedMemory<uint8_t> mm_b; - mm_a = std::move(mm_b); - EXPECT_EQ(nullptr, mm_a.buffer()); - EXPECT_EQ(0, mm_a.byte_offset()); - EXPECT_EQ(0, mm_a.byte_length()); - EXPECT_TRUE(mm_a.empty()); - EXPECT_EQ(0, mm_a.size()); - EXPECT_EQ(nullptr, mm_a.data()); - EXPECT_EQ(nullptr, mm_a.mutable_data()); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Subspan().status())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.MutableSubspan().status())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Invalidate())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Flush())); - mm_a.reset(); -} - -TEST(MemoryMappingTest, MoveHandle) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_a, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Should be able to move the handle around without having any calls. - auto mm_b = std::move(mm_a); - mm_a = std::move(mm_b); - mm_b = std::move(mm_a); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_b.reset(); -} - -TEST(MemoryMappingTest, ReadOnlyAccess) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kRead, BufferUsage::kAll, 128); - - // Should succeed to map for reading. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Non-mutable access is fine. - EXPECT_EQ(kValidPtr, mm_r.data()); - ASSERT_OK_AND_ASSIGN(auto span, mm_r.Subspan()); - (void)span; - - // Read-only mappings should not be able to get mutable access. - EXPECT_EQ(nullptr, mm_r.mutable_data()); - EXPECT_TRUE(IsPermissionDenied(mm_r.MutableSubspan().status())); - - // Read-only mappings should not be able to call Flush. - EXPECT_TRUE(IsPermissionDenied(mm_r.Flush())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); - - // Should fail to map for writing. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status())); -} - -TEST(MemoryMappingTest, ReadWriteAccess) { - auto buffer = std::make_shared<MockBuffer>( - nullptr, MemoryType::kHostLocal, - MemoryAccess::kRead | MemoryAccess::kWrite, BufferUsage::kAll, 128); - - // Should succeed to map for reading and/or writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead | MemoryAccess::kWrite, - 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN( - auto mm_rw, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kWrite)); - - // Everything valid. - EXPECT_EQ(kValidPtr, mm_rw.data()); - ASSERT_OK_AND_ASSIGN(auto span, mm_rw.Subspan()); - EXPECT_EQ(kValidPtr, mm_rw.mutable_data()); - ASSERT_OK_AND_ASSIGN(span, mm_rw.MutableSubspan()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_rw.reset(); - - // Should fail to map for discard. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status())); -} - -TEST(MemoryMappingTest, WriteOnlyAccess) { - auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kWrite, - BufferUsage::kAll, 128); - - // Should succeed to map for writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - - // Mutable access is valid. - EXPECT_EQ(kValidPtr, mm_w.mutable_data()); - ASSERT_OK_AND_ASSIGN(auto span, mm_w.MutableSubspan()); - (void)span; - - // Write-only mappings should not be able to get non-mutable access. - EXPECT_EQ(nullptr, mm_w.data()); - EXPECT_TRUE(IsPermissionDenied(mm_w.Subspan().status())); - - // Write-only mappings should not be able to call Invalidate. - EXPECT_TRUE(IsPermissionDenied(mm_w.Invalidate())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); - - // Should fail to map for reading. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status())); - - // Should fail to map for discard. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status())); -} - -TEST(MemoryMappingTest, WriteDiscardAccess) { - auto buffer = std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kDiscardWrite, - BufferUsage::kAll, 128); - - // Should succeed to map for writing with discard. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kDiscardWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_dw, - buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_dw.reset(); - - // Should also be ok to map for just writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); - - // Should fail to map for reading. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(MemoryAccess::kRead).status())); -} - -TEST(MemoryMappingTest, Subspan) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Request some valid ranges and ensure the byte offsets are correct. - ASSERT_OK_AND_ASSIGN(auto ss, mm_r.Subspan()); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_EQ(128, ss.size()); - ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, 2)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(2, ss.size()); - ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, kWholeBuffer)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(28, ss.size()); - - // Zero length ranges are fine. - ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(0, 0)); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_TRUE(ss.empty()); - ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, 0)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, kWholeBuffer)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, SubspanOutOfRange) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Try some invalid ranges that would overrun the span. - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 0).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 2).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, kWholeBuffer).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(100, 1234).status())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, MutableSubspan) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - - // Request some valid ranges and ensure the byte offsets are correct. - ASSERT_OK_AND_ASSIGN(auto ss, mm_w.MutableSubspan()); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_EQ(128, ss.size()); - ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, 2)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(2, ss.size()); - ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, kWholeBuffer)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(28, ss.size()); - - // Zero length ranges are fine. - ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(0, 0)); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_TRUE(ss.empty()); - ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, 0)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, kWholeBuffer)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, MutableSubspanOutOfRange) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - - // Try some invalid ranges that would overrun the span. - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 0).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 2).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, kWholeBuffer).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(100, 1234).status())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, ElementOperator) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Just verify we are getting the expected pointer back. - EXPECT_EQ(kValidPtr, &mm_r[0]); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, Invalidate) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Invalidate a few ways. - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(0, 128)) - .WillOnce(Return(OkStatus())); - EXPECT_OK(mm_r.Invalidate()); - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 2)) - .WillOnce(Return(OkStatus())); - EXPECT_OK(mm_r.Invalidate(100, 2)); - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 28)) - .WillOnce(Return(OkStatus())); - EXPECT_OK(mm_r.Invalidate(100, kWholeBuffer)); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, InvalidateOutOfRange) { - auto buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_r, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - - // Try to invalidate invalid ranges. - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 0))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 12345))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, kWholeBuffer))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1, 1234))); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, InvalidateBadMode) { - // Invalidate is not required on coherent memory. - auto coherent_buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostLocal, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*coherent_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kRead, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN( - auto mm_r, coherent_buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - EXPECT_TRUE(IsPermissionDenied(mm_r.Invalidate())); - EXPECT_CALL(*coherent_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, Flush) { - auto buffer = std::make_shared<MockBuffer>( - nullptr, MemoryType::kHostVisible | MemoryType::kHostCached, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - - // Flush a few ways. - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(0, 128)) - .WillOnce(Return(OkStatus())); - EXPECT_OK(mm_w.Flush()); - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 2)) - .WillOnce(Return(OkStatus())); - EXPECT_OK(mm_w.Flush(100, 2)); - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 28)) - .WillOnce(Return(OkStatus())); - EXPECT_OK(mm_w.Flush(100, kWholeBuffer)); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, FlushOutOfRange) { - auto buffer = std::make_shared<MockBuffer>( - nullptr, MemoryType::kHostVisible | MemoryType::kHostCached, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN(auto mm_w, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - - // Try to flush invalid ranges. - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 0))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 12345))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, kWholeBuffer))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1, 1234))); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, FlushBadMode) { - // Flush is not required on uncached memory. - auto uncached_buffer = - std::make_shared<MockBuffer>(nullptr, MemoryType::kHostVisible, - MemoryAccess::kAll, BufferUsage::kAll, 128); - EXPECT_CALL(*uncached_buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - MemoryAccess::kWrite, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - ASSERT_OK_AND_ASSIGN( - auto mm_w, uncached_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - EXPECT_TRUE(IsPermissionDenied(mm_w.Flush())); - EXPECT_CALL(*uncached_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/buffer_test.cc b/iree/hal/buffer_test.cc deleted file mode 100644 index 1d9fced..0000000 --- a/iree/hal/buffer_test.cc +++ /dev/null
@@ -1,1000 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for the shared buffer functionality and host heap buffers. -// This does not test device-specific buffer implementations; see the device -// code for associated tests. - -#include "iree/hal/buffer.h" - -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/heap_buffer.h" - -namespace iree { -namespace hal { -namespace { - -using ::testing::_; -using ::testing::ElementsAre; -using ::testing::Eq; -using ::testing::Not; - -TEST(BufferTest, Allocate) { - auto buffer = - HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 14); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(MemoryAccess::kAll, buffer->allowed_access()); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - - // We don't currently do any padding on the host. - // Other implementations may differ. - EXPECT_GE(14, buffer->allocation_size()); - EXPECT_EQ(0, buffer->byte_offset()); - EXPECT_EQ(14, buffer->byte_length()); - - // Data should be zeroed by default. - std::vector<uint8_t> zero_data(buffer->allocation_size()); - std::vector<uint8_t> actual_data(buffer->allocation_size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(zero_data)); -} - -TEST(BufferTest, AllocateZeroLength) { - auto buffer = - HeapBuffer::Allocate(BufferUsage::kTransfer | BufferUsage::kMapping, 0); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_EQ(0, buffer->allocation_size()); -} - -TEST(BufferTest, AllocateCopy) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_GE(src_data.size(), buffer->allocation_size()); - - // Data should have been copied. - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data and ensure it is not reflected in the buffer. - src_data[0] = 0x88; - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Not(Eq(src_data))); -} - -TEST(BufferTest, AllocateCopyZeroLength) { - std::vector<uint8_t> src_data; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(0, buffer->allocation_size()); -} - -TEST(BufferTest, AllocateCopyTyped) { - std::vector<int32_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - absl::MakeConstSpan(src_data)); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_GE(src_data.size() * sizeof(int32_t), buffer->allocation_size()); - - // Data should have been copied. - std::vector<int32_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), - actual_data.size() * sizeof(int32_t))); - EXPECT_THAT(actual_data, Eq(src_data)); -} - -TEST(BufferTest, WrapConstant) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::Wrap(MemoryType::kHostLocal, - BufferUsage::kTransfer | BufferUsage::kMapping, - absl::MakeConstSpan(src_data)); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_EQ(src_data.size(), buffer->allocation_size()); - - // src_data and buffer should match after the wrapping. - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data directly. - src_data[0] = 123; - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Attempts to modify the buffer should fail. - std::vector<uint8_t> new_data = {3, 2, 1, 0}; - EXPECT_TRUE(IsPermissionDenied( - buffer->WriteData(0, new_data.data(), new_data.size()))); -} - -TEST(BufferTest, WrapMutable) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::WrapMutable( - MemoryType::kHostLocal, MemoryAccess::kAll, - BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data)); - EXPECT_EQ(MemoryType::kHostLocal, buffer->memory_type()); - EXPECT_EQ(BufferUsage::kTransfer | BufferUsage::kMapping, buffer->usage()); - EXPECT_EQ(src_data.size(), buffer->allocation_size()); - - // src_data and buffer should match after the wrapping. - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data directly. - src_data[0] = 123; - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data via the Buffer and ensure reflected in src_data. - std::vector<uint8_t> new_data = {3, 2, 1, 0}; - EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); - EXPECT_THAT(src_data, Eq(new_data)); -} - -TEST(BufferTest, WrapExternal) { - // This is not fully supported yet, but does let us verify that the validation - // of memory types is working. - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::Wrap(MemoryType::kDeviceLocal, BufferUsage::kAll, - absl::MakeConstSpan(src_data)); - EXPECT_EQ(MemoryType::kDeviceLocal, buffer->memory_type()); - - // Should fail (for now) as the buffer is not host visible. - EXPECT_TRUE(IsPermissionDenied(buffer->Fill8(0, kWholeBuffer, 0x99u))); -} - -TEST(BufferTest, DoesOverlap) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - - // A buffer should overlap with itself. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - parent_buffer.get(), 1, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - parent_buffer.get(), 0, 1)); - - // Zero length buffers never overlap. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 1, 1, - parent_buffer.get(), 1, 0)); - - // Subspans should offset within their allocation. - ASSERT_OK_AND_ASSIGN(auto subspan_buffer_0, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_OK_AND_ASSIGN(auto subspan_buffer_1, - Buffer::Subspan(parent_buffer, 2, 2)); - EXPECT_FALSE(Buffer::DoesOverlap(subspan_buffer_0.get(), 0, 1, - subspan_buffer_1.get(), 0, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(subspan_buffer_0.get(), 1, 1, - subspan_buffer_1.get(), 0, 1)); - - // Mixing subspans and normal buffers. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - subspan_buffer_0.get(), 0, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 1, 2, - subspan_buffer_0.get(), 1, 1)); - - // Independent buffers should not be able to overlap. - auto other_buffer = HeapBuffer::Allocate(BufferUsage::kAll, 128); - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, kWholeBuffer, - other_buffer.get(), 0, kWholeBuffer)); -} - -TEST(BufferTest, Subspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - - // Create a subspan of the buffer. - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_TRUE(subspan_buffer); - EXPECT_EQ(1, subspan_buffer->byte_offset()); - EXPECT_EQ(2, subspan_buffer->byte_length()); - - // Modifications to either buffer should appear in the other. - EXPECT_OK(subspan_buffer->Fill8(1, kWholeBuffer, 0xFFu)); - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xFF, 3)); - - // Subspans should be able to create subspans. - // NOTE: offset is from the original buffer. - ASSERT_OK_AND_ASSIGN(auto subsubspan_buffer, - Buffer::Subspan(subspan_buffer, 1, 1)); - ASSERT_TRUE(subsubspan_buffer); - EXPECT_EQ(2, subsubspan_buffer->byte_offset()); - EXPECT_EQ(1, subsubspan_buffer->byte_length()); - - // Zero length subspans are fine. - ASSERT_OK_AND_ASSIGN(auto zero_subspan_buffer, - Buffer::Subspan(parent_buffer, 0, 0)); - ASSERT_TRUE(zero_subspan_buffer); - EXPECT_EQ(0, zero_subspan_buffer->byte_offset()); - EXPECT_EQ(0, zero_subspan_buffer->byte_length()); - - // Subspan with kWholeBuffer should get the remaining size (or zero). - ASSERT_OK_AND_ASSIGN(auto whole_subspan_buffer, - Buffer::Subspan(parent_buffer, 1, kWholeBuffer)); - ASSERT_TRUE(whole_subspan_buffer); - EXPECT_EQ(1, whole_subspan_buffer->byte_offset()); - EXPECT_EQ(3, whole_subspan_buffer->byte_length()); - - // Zero length subspans are fine. - ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, 0)); - ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, kWholeBuffer)); -} - -TEST(BufferTest, SubspanIdentity) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - - // Asking for a subspan of the entire buffer should return the same buffer. - // Mostly an optimization. - EXPECT_EQ(parent_buffer.get(), - Buffer::Subspan(parent_buffer, 0, kWholeBuffer).ValueOrDie().get()); - EXPECT_EQ(parent_buffer.get(), - Buffer::Subspan(parent_buffer, 0, 4).ValueOrDie().get()); -} - -TEST(BufferTest, SubspanOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - - // Create a subspan of the buffer. - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_TRUE(subspan_buffer); - EXPECT_EQ(1, subspan_buffer->byte_offset()); - EXPECT_EQ(2, subspan_buffer->byte_length()); - - // Try to make subspans from invalid ranges. - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 5, 0).status())); - EXPECT_TRUE( - IsOutOfRange(Buffer::Subspan(parent_buffer, 5, kWholeBuffer).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 4, 1).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 0, 123).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 1, 2).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 0, 44).status())); -} - -TEST(BufferTest, Fill8) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0)); - - // Fill with a sentinel. - EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); - - // Verify data. - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); - - // Zero fills are fine. - EXPECT_OK(buffer->Fill8(0, 0, 0x44u)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); - - // Fill the remaining parts of the buffer by using kWholeBuffer. - EXPECT_OK(buffer->Fill8(2, kWholeBuffer, 0x55u)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55)); - - // Fill a small region of the buffer. - EXPECT_OK(buffer->Fill8(1, 1, 0x66u)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55)); - - // Whole buffer helper. - EXPECT_OK(buffer->Fill8(0x99u)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99)); -} - -TEST(BufferTest, Fill8OutOfRange) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 444, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 1, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); - - // Ensure nothing happened with the bad ranges. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); -} - -TEST(BufferTest, Fill8BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE( - IsPermissionDenied(nonmapping_buffer->Fill8(0, kWholeBuffer, 0x99u))); - - // Fail to fill constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->Fill8(0, kWholeBuffer, 0x99u))); -} - -TEST(BufferTest, Fill8Subspan) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 5); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 3)); - EXPECT_OK(subspan_buffer->Fill8(2, kWholeBuffer, 0xDDu)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0xDD, 0)); -} - -TEST(BufferTest, Fill16) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); - - // Fill with a sentinel. - EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); - - // Verify data. - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); - - // Zero fills are fine. - EXPECT_OK(buffer->Fill16(0, 0, 0x5566u)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); - - // Fill the remaining parts of the buffer by using kWholeBuffer. - auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8); - EXPECT_OK(aligned_buffer->Fill16(4, kWholeBuffer, 0x5566u)); - std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size()); - EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55)); - - // Whole buffer helper. - EXPECT_OK(aligned_buffer->Fill16(0x5566u)); - EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55)); -} - -TEST(BufferTest, Fill16OutOfRange) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 444, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 4, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); -} - -TEST(BufferTest, Fill16Unaligned) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with unaligned ranges. - EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(1, 4, 0x5566u))); - EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(0, 5, 0x5566u))); -} - -TEST(BufferTest, Fill16BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE( - IsPermissionDenied(nonmapping_buffer->Fill16(0, kWholeBuffer, 0x99AAu))); - - // Fail to fill constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->Fill16(0, kWholeBuffer, 0x99AAu))); -} - -TEST(BufferTest, Fill16Subspan) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); - - // Test on subspan. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 2, 4)); - EXPECT_OK(subspan_buffer->Fill16(2, kWholeBuffer, 0xAABBu)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x22, 0x11, 0x22, 0x11, 0xBB, 0xAA, 0, 0, 0)); -} - -TEST(BufferTest, Fill32) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); - - // Fill with a sentinel. - EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); - - // Verify data. - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); - - // Zero fills are fine. - EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); - - // Fill the remaining parts of the buffer by using kWholeBuffer. - auto aligned_buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 8); - EXPECT_OK(aligned_buffer->Fill32(4, kWholeBuffer, 0x55667788u)); - std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size()); - EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55)); - - // Whole buffer helper. - EXPECT_OK(aligned_buffer->Fill32(0x55667788u)); - EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55)); -} - -TEST(BufferTest, Fill32OutOfRange) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 444, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 4, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); -} - -TEST(BufferTest, Fill32Unaligned) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Try to fill with unaligned ranges. - EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(1, 4, 0x55667788u))); - EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(0, 5, 0x55667788u))); -} - -TEST(BufferTest, Fill32BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu))); - - // Fail to fill constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kMapping, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->Fill32(0, kWholeBuffer, 0x99AABBCCu))); -} - -TEST(BufferTest, Fill32Subspan) { - auto buffer = HeapBuffer::Allocate(BufferUsage::kMapping, 9); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); - - // Test on subspan. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 4, 4)); - EXPECT_OK(subspan_buffer->Fill32(0, kWholeBuffer, 0xAABBCCDDu)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0xDD, 0xCC, 0xBB, 0xAA, 0)); -} - -TEST(BufferTest, ReadData) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Read the data back. - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Reading zero bytes is valid. - std::vector<uint8_t> zero_data(0); - EXPECT_OK(buffer->ReadData(1, zero_data.data(), 0)); - - // Read a portion of the data. - std::vector<uint8_t> partial_data(2); - EXPECT_OK(buffer->ReadData(1, partial_data.data(), 2)); - EXPECT_THAT(partial_data, ElementsAre(1, 2)); -} - -TEST(BufferTest, ReadDataOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Try to read out of range. - std::vector<uint8_t> partial_data(2); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(0, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 1))); - EXPECT_TRUE(IsInvalidArgument( - buffer->ReadData(0, partial_data.data(), kWholeBuffer))); -} - -TEST(BufferTest, ReadDataBadMode) { - // Fail to read buffers not supporting mapping. - std::vector<uint8_t> actual_data(1); - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->ReadData(0, actual_data.data(), 1))); -} - -TEST(BufferTest, ReadDataSubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector<uint8_t> subspan_data(1); - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); - EXPECT_OK(subspan_buffer->ReadData(1, subspan_data.data(), 1)); - EXPECT_THAT(subspan_data, ElementsAre(2)); -} - -TEST(BufferTest, WriteData) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Read the data back - should still match. - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Write over the entire buffer. - std::vector<uint8_t> new_data = {10, 20, 30, 40}; - EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(new_data)); - - // Writing zero bytes is valid. - std::vector<uint8_t> zero_data; - EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(new_data)); - - // Write over a portion of the buffer. - std::vector<uint8_t> partial_data = {99}; - EXPECT_OK(buffer->WriteData(1, partial_data.data(), partial_data.size())); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40)); -} - -TEST(BufferTest, WriteDataOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Try to write out of range. - std::vector<uint8_t> partial_data = {99}; - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(0, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 1))); - EXPECT_TRUE(IsInvalidArgument( - buffer->WriteData(0, partial_data.data(), kWholeBuffer))); -} - -TEST(BufferTest, WriteDataBadMode) { - std::vector<uint8_t> actual_data(4); - - // Fail to write buffers not supporting mapping. - auto nonmapping_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->WriteData(0, actual_data.data(), 1))); - - // Fail to write to constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->WriteData(0, actual_data.data(), 2))); -} - -TEST(BufferTest, WriteDataSubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = - HeapBuffer::AllocateCopy(BufferUsage::kTransfer | BufferUsage::kMapping, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector<uint8_t> subspan_data = {0xAA}; - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); - EXPECT_OK(subspan_buffer->WriteData(1, subspan_data.data(), 1)); - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xAA, 3)); -} - -TEST(BufferTest, CopyData) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Copy of length 0 should not change the dest buffer. - EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0)); - std::vector<uint8_t> actual_data(dst_data.size()); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(dst_data)); - - // Copy a subrange of the buffer. - EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4)); - - // Copy the entire buffer using kWholeBuffer. This will adjust sizes - // to ensure that the min buffer is taken. We test both src and dst buffer - // offset/length calculations (note that some may end up as 0 copies). - EXPECT_OK(dst_buffer->CopyData(3, src_buffer.get(), 0, kWholeBuffer)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1)); - EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 2, kWholeBuffer)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1)); - EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 3, kWholeBuffer)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1)); - EXPECT_OK(dst_buffer->CopyData(4, src_buffer.get(), 0, kWholeBuffer)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0)); -} - -TEST(BufferTest, CopyDataOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Try to copy out of range of source and dest. - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 0, 1))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(4, src_buffer.get(), 0, 4))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 1))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 0, 123))); - EXPECT_TRUE( - IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 123, 123))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 0))); -} - -TEST(BufferTest, CopyDataOverlapping) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Test overlap. Non-overlapping regions should be fine, otherwise fail. - std::vector<uint8_t> actual_data(dst_data.size()); - EXPECT_OK(dst_buffer->CopyData(0, dst_buffer.get(), 4, 1)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(4, 1, 2, 3, 4)); - EXPECT_TRUE( - IsInvalidArgument(dst_buffer->CopyData(2, dst_buffer.get(), 0, 3))); - EXPECT_TRUE( - IsInvalidArgument(dst_buffer->CopyData(0, dst_buffer.get(), 0, 3))); -} - -TEST(BufferTest, CopyDataBadMode) { - // Both source and target buffers must support mapping. - auto nonmapping_src_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - auto nonmapping_dst_buffer = HeapBuffer::Allocate(BufferUsage::kTransfer, 4); - EXPECT_TRUE(IsPermissionDenied(nonmapping_dst_buffer->CopyData( - 0, nonmapping_src_buffer.get(), 0, kWholeBuffer))); - EXPECT_TRUE(IsPermissionDenied(nonmapping_src_buffer->CopyData( - 0, nonmapping_dst_buffer.get(), 0, kWholeBuffer))); - - // Fail to copy into to constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = - HeapBuffer::Wrap(MemoryType::kHostLocal, BufferUsage::kTransfer, - absl::MakeConstSpan(const_data)); - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->CopyData(0, src_buffer.get(), 0, kWholeBuffer))); -} - -TEST(BufferTest, CopyDataSubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = - HeapBuffer::AllocateCopy(BufferUsage::kMapping | BufferUsage::kTransfer, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Test on subspan. - std::vector<uint8_t> actual_data(dst_data.size()); - ASSERT_OK_AND_ASSIGN(auto subspan_src_buffer, - Buffer::Subspan(src_buffer, 1, 3)); - ASSERT_OK_AND_ASSIGN(auto subspan_dst_buffer, - Buffer::Subspan(dst_buffer, 2, 3)); - EXPECT_OK(subspan_dst_buffer->CopyData(1, subspan_src_buffer.get(), 1, 2)); - EXPECT_OK(dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 2, 2, 3)); -} - -// NOTE: more tests related specifically to MappedMemory are in -// buffer_mapping_test.cc. This tests the MapMemory operation and enough to -// ensure the memory was mapped to the correct range and the HostBuffer and -// SubspanBuffer work as intended for basic usage. -TEST(BufferTest, MapMemory) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // 0-length mappings are valid. - ASSERT_OK_AND_ASSIGN(auto mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 0, 0)); - EXPECT_TRUE(mapping.empty()); - EXPECT_EQ(0, mapping.size()); - EXPECT_EQ(0, mapping.byte_length()); - EXPECT_NE(nullptr, mapping.data()); - ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan()); - EXPECT_TRUE(span.empty()); - mapping.reset(); - - // Map the whole buffer for reading. - ASSERT_OK_AND_ASSIGN(mapping, buffer->MapMemory<uint8_t>(MemoryAccess::kRead, - 0, kWholeBuffer)); - EXPECT_EQ(src_data.size(), mapping.size()); - ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); - EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6)); - mapping.reset(); - - // Map a portion of the buffer for reading. - ASSERT_OK_AND_ASSIGN(mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead, 1, 2)); - EXPECT_EQ(2, mapping.size()); - ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); - EXPECT_THAT(span, ElementsAre(1, 2)); - mapping.reset(); -} - -TEST(BufferTest, MapMemoryNonByte) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map the buffer as non-byte values. - // Note that we'll round down to the number of valid elements at the - // alignment. - ASSERT_OK_AND_ASSIGN(auto mapping16, - buffer->MapMemory<uint16_t>(MemoryAccess::kRead)); - EXPECT_EQ(3, mapping16.size()); - EXPECT_LE(6, mapping16.byte_length()); - ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan()); - EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504)); - mapping16.reset(); -} - -TEST(BufferTest, MapMemoryOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test invalid mapping ranges. - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 0, 123).status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 5, 1231).status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 6, kWholeBuffer) - .status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(MemoryAccess::kRead, 1236, 1).status())); -} - -TEST(BufferTest, MapMemoryBadMode) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto read_buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kRead, - src_data.data(), src_data.size()); - ASSERT_TRUE(read_buffer); - - // Test mapping the read-only buffer for writing. - EXPECT_TRUE(IsPermissionDenied( - read_buffer->MapMemory<uint8_t>(MemoryAccess::kWrite).status())); - EXPECT_TRUE(IsPermissionDenied( - read_buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite).status())); - EXPECT_TRUE(IsPermissionDenied( - read_buffer - ->MapMemory<uint8_t>(MemoryAccess::kRead | MemoryAccess::kDiscard) - .status())); - EXPECT_TRUE(IsInvalidArgument( - read_buffer->MapMemory<uint8_t>(MemoryAccess::kNone).status())); -} - -TEST(BufferTest, MapMemoryWrite) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map and modify the data. We should see it when we read back. - ASSERT_OK_AND_ASSIGN(auto mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kWrite, 1, 2)); - auto mutable_data = mapping.mutable_data(); - mutable_data[0] = 0xAA; - mutable_data[1] = 0xBB; - mapping.reset(); - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6)); -} - -TEST(BufferTest, MapMemoryDiscard) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map for discard. Note that we can't really rely on the value of the data - // so we just trust that it's been discarded. It's a hint, anyway. We can be - // sure that the data we didn't want to discard is the same though. - std::vector<uint8_t> actual_data(src_data.size()); - ASSERT_OK_AND_ASSIGN(auto mapping, buffer->MapMemory<uint8_t>( - MemoryAccess::kDiscardWrite, 1, 2)); - EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6)); - mapping.reset(); -} - -TEST(BufferTest, MapMemorySubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto parent_buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, MemoryAccess::kAll, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 3)); - ASSERT_OK_AND_ASSIGN(auto mapping, subspan_buffer->MapMemory<uint8_t>( - MemoryAccess::kDiscardWrite, 1, 2)); - auto* mutable_data = mapping.mutable_data(); - mutable_data[0] = 0xCC; - mutable_data[1] = 0xDD; - mapping.reset(); - - std::vector<uint8_t> actual_data(src_data.size()); - EXPECT_OK(parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6)); - - // Just here to make coverage happy; they are currently no-ops on the host. - // buffer_mapping_test.cc contains tests that ensure they are called - // correctly. - std::vector<uint8_t> external_data = {0, 1, 2, 3, 4}; - auto external_buffer = HeapBuffer::WrapMutable( - MemoryType::kHostVisible | MemoryType::kHostCached, MemoryAccess::kAll, - BufferUsage::kAll, absl::MakeSpan(external_data)); - ASSERT_OK_AND_ASSIGN(auto external_subspan_buffer, - Buffer::Subspan(external_buffer, 0, 1)); - ASSERT_OK_AND_ASSIGN( - mapping, external_subspan_buffer->MapMemory<uint8_t>(MemoryAccess::kAll)); - EXPECT_OK(mapping.Invalidate()); - EXPECT_OK(mapping.Flush()); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/buffer_view.cc b/iree/hal/buffer_view.cc deleted file mode 100644 index 75687cb..0000000 --- a/iree/hal/buffer_view.cc +++ /dev/null
@@ -1,180 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/buffer_view.h" - -#include "absl/container/inlined_vector.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -namespace { -// Pretty prints an array, e.g. [1, 2, 3, 4] -inline std::string PrettyPrint(absl::Span<const int32_t> arr) { - return "[" + absl::StrJoin(arr, ",") + "]"; -} -} // namespace - -// static -bool BufferView::Equal(const BufferView& lhs, const BufferView& rhs) { - return lhs.buffer.get() == rhs.buffer.get() && - lhs.element_size == rhs.element_size && lhs.shape == rhs.shape; -} - -std::string BufferView::DebugStringShort() const { - if (element_size == 0) { - return "Ø"; - } - return shape.empty() ? std::to_string(element_size) - : absl::StrCat(absl::StrJoin(shape.subspan(), "x"), "x", - element_size); -} - -StatusOr<device_size_t> BufferView::CalculateOffset( - absl::Span<const int32_t> indices) const { - if (indices.empty()) { - return 0; - } else if (shape.empty() || indices.size() > shape.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Indices " << PrettyPrint(indices) - << " out of bounds of the rank of buffer_view " - << DebugStringShort(); - } - device_size_t offset = 0; - for (int i = 0; i < indices.size(); ++i) { - if (indices[i] >= shape[i]) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Indices[" << i << "]=" << indices[i] - << " out of bounds of buffer_view " << DebugStringShort(); - } - device_size_t axis_offset = indices[i]; - for (int j = i + 1; j < shape.size(); ++j) { - axis_offset *= shape[j]; - } - offset += axis_offset; - } - offset *= element_size; - return offset; -} - -StatusOr<BufferView> BufferView::Slice( - absl::Span<const int32_t> start_indices, - absl::Span<const int32_t> lengths) const { - if (start_indices.size() != shape.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Slice start_indices " << PrettyPrint(start_indices) - << " do not match rank of buffer_view " << DebugStringShort(); - } - if (start_indices.size() != lengths.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Slice start_indices " << PrettyPrint(start_indices) - << " and lengths " << PrettyPrint(lengths) - << " are not the same size"; - } - - // Buffer::Subspan only support contiguous memory. To ensure that this slice - // only requests such, we validate that the offset in the buffer between the - // start and end indices is the same as the requested size of the slice. - absl::InlinedVector<int32_t, 6> end_indices(lengths.size()); - device_size_t subspan_length = element_size; - for (int i = 0; i < lengths.size(); ++i) { - subspan_length *= lengths[i]; - end_indices[i] = start_indices[i] + lengths[i] - 1; - } - - ASSIGN_OR_RETURN(auto start_byte_offset, CalculateOffset(start_indices)); - // Also validates the ends are in bounds. - ASSIGN_OR_RETURN(auto end_byte_offset, CalculateOffset(end_indices)); - - auto offset_length = end_byte_offset - start_byte_offset + element_size; - if (subspan_length != offset_length) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Slice for non-contiguous region of memory unimplemented. " - "start_indices: " - << PrettyPrint(start_indices) << " lengths: " << PrettyPrint(lengths) - << " " << subspan_length << " " << offset_length << " " - << PrettyPrint(end_indices); - } - - ASSIGN_OR_RETURN(auto new_buffer, - Buffer::Subspan(buffer, start_byte_offset, subspan_length)); - return BufferView(std::move(new_buffer), Shape(lengths), element_size); -} - -// static -Status BufferView::Copy(BufferView* src, - absl::Span<const int32_t> src_start_indices, - BufferView* dst, - absl::Span<const int32_t> dst_start_indices, - absl::Span<const int32_t> lengths) { - if (src_start_indices.size() != src->shape.size() || - dst_start_indices.size() != dst->shape.size() || - src_start_indices.size() != lengths.size() || - dst_start_indices.size() != lengths.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Src/dst shape/size mismatch: src=" << src->DebugStringShort() - << ", dst=" << dst->DebugStringShort() - << ", src_indices=" << PrettyPrint(src_start_indices) - << ", dst_indices=" << PrettyPrint(dst_start_indices) - << ", lengths=" << PrettyPrint(lengths); - } - - // Copies only support contiguous memory. To ensure that this copy - // only requests such, we validate that the offset in the buffer between the - // start and end indices is the same as the requested size of the copy. - absl::InlinedVector<int32_t, 4> src_end_indices(lengths.size()); - absl::InlinedVector<int32_t, 4> dst_end_indices(lengths.size()); - device_size_t total_length = src->element_size; - for (int i = 0; i < lengths.size(); ++i) { - total_length *= lengths[i]; - src_end_indices[i] = src_start_indices[i] + lengths[i] - 1; - dst_end_indices[i] = dst_start_indices[i] + lengths[i] - 1; - } - - ASSIGN_OR_RETURN(auto src_start_byte_offset, - src->CalculateOffset(src_start_indices)); - ASSIGN_OR_RETURN(auto src_end_byte_offset, - src->CalculateOffset(src_end_indices)); - ASSIGN_OR_RETURN(auto dst_start_byte_offset, - dst->CalculateOffset(dst_start_indices)); - ASSIGN_OR_RETURN(auto dst_end_byte_offset, - dst->CalculateOffset(dst_end_indices)); - - auto src_length = - src_end_byte_offset - src_start_byte_offset + src->element_size; - auto dst_length = - dst_end_byte_offset - dst_start_byte_offset + dst->element_size; - if (src_length != dst_length || src_length != total_length) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Copy for non-contiguous region of memory unimplemented: " - << src->DebugStringShort() << ", dst=" << dst->DebugStringShort() - << ", src_indices=" << PrettyPrint(src_start_indices) - << ", dst_indices=" << PrettyPrint(dst_start_indices) - << ", lengths=" << PrettyPrint(lengths); - } - - RETURN_IF_ERROR(dst->buffer->CopyData(dst_start_byte_offset, - src->buffer.get(), - src_start_byte_offset, total_length)); - - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h deleted file mode 100644 index 308a0bb..0000000 --- a/iree/hal/buffer_view.h +++ /dev/null
@@ -1,108 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_BUFFER_VIEW_H_ -#define IREE_HAL_BUFFER_VIEW_H_ - -#include <memory> -#include <ostream> - -#include "iree/base/shape.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -struct BufferView { - // Returns true if the given buffer_views are exactly equal. - static bool Equal(const BufferView& lhs, const BufferView& rhs); - - BufferView() = default; - BufferView(ref_ptr<Buffer> buffer, Shape shape, int8_t element_size) noexcept - : buffer(std::move(buffer)), shape(shape), element_size(element_size) {} - - BufferView(const BufferView& other) noexcept - : buffer(add_ref(other.buffer)), - shape(other.shape), - element_size(other.element_size) {} - BufferView& operator=(const BufferView& other) noexcept { - buffer = add_ref(other.buffer); - shape = other.shape; - element_size = other.element_size; - return *this; - } - BufferView(BufferView&& other) noexcept - : buffer(std::move(other.buffer)), - shape(other.shape), - element_size(other.element_size) {} - BufferView& operator=(BufferView&& other) noexcept { - buffer = std::move(other.buffer); - shape = other.shape; - element_size = other.element_size; - return *this; - } - - // Returns a string useful for printing debug messages. - std::string DebugStringShort() const; - - // Total length of the valid view range in bytes. - device_size_t byte_length() const { - return shape.element_count() * element_size; - } - - // TODO(b/134586626): remove this when byte ranges are encoded in IR. - // Calculates a byte offset into the buffer_view at the given dimension - // indices. - StatusOr<device_size_t> CalculateOffset( - absl::Span<const int32_t> indices) const; - - // TODO(b/134586626): remove this when byte ranges are encoded in IR. - // Returns a view onto the given range of the buffer underlying this view. The - // returned view starts at the offset indicated by |start_indices| and has a - // shape of |lengths|. - // Only contiguous regions of memory are supported at the moment. - StatusOr<BufferView> Slice(absl::Span<const int32_t> start_indices, - absl::Span<const int32_t> lengths) const; - - // TODO(b/134586626): remove this when byte ranges are encoded in IR. - static Status Copy(BufferView* src, - absl::Span<const int32_t> src_start_indices, - BufferView* dst, - absl::Span<const int32_t> dst_start_indices, - absl::Span<const int32_t> lengths); - - ref_ptr<Buffer> buffer; - Shape shape; - int8_t element_size; - // TODO(benvanik): strides. -}; - -inline bool operator==(const BufferView& a, const BufferView& b) { - return BufferView::Equal(a, b); -} - -inline bool operator!=(const BufferView& a, const BufferView& b) { - return !(a == b); -} - -inline std::ostream& operator<<(std::ostream& stream, - const BufferView& buffer_view) { - stream << buffer_view.DebugStringShort(); - return stream; -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_BUFFER_VIEW_H_
diff --git a/iree/hal/buffer_view_string_util.cc b/iree/hal/buffer_view_string_util.cc deleted file mode 100644 index ad95743..0000000 --- a/iree/hal/buffer_view_string_util.cc +++ /dev/null
@@ -1,542 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/buffer_view_string_util.h" - -#include <functional> -#include <sstream> -#include <type_traits> - -#include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/strings/strip.h" -#include "absl/types/optional.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/hal/heap_buffer.h" - -namespace iree { -namespace hal { - -namespace { - -/* clang-format off */ -constexpr char kHexValue[256] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, 0, 0, 0, 0, // '0'..'9' - 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'A'..'F' - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 'a'..'f' - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 -}; -/* clang-format on */ - -template <typename T> -void HexStringToBytes(const char* from, T to, ptrdiff_t num) { - for (int i = 0; i < num; i++) { - to[i] = (kHexValue[from[i * 2] & 0xFF] << 4) + - (kHexValue[from[i * 2 + 1] & 0xFF]); - } -} - -constexpr char kHexTable[513] = - "000102030405060708090a0b0c0d0e0f" - "101112131415161718191a1b1c1d1e1f" - "202122232425262728292a2b2c2d2e2f" - "303132333435363738393a3b3c3d3e3f" - "404142434445464748494a4b4c4d4e4f" - "505152535455565758595a5b5c5d5e5f" - "606162636465666768696a6b6c6d6e6f" - "707172737475767778797a7b7c7d7e7f" - "808182838485868788898a8b8c8d8e8f" - "909192939495969798999a9b9c9d9e9f" - "a0a1a2a3a4a5a6a7a8a9aaabacadaeaf" - "b0b1b2b3b4b5b6b7b8b9babbbcbdbebf" - "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf" - "d0d1d2d3d4d5d6d7d8d9dadbdcdddedf" - "e0e1e2e3e4e5e6e7e8e9eaebecedeeef" - "f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff"; - -template <typename T> -void BytesToHexString(const unsigned char* src, T dest, ptrdiff_t num) { - auto dest_ptr = &dest[0]; - for (auto src_ptr = src; src_ptr != (src + num); ++src_ptr, dest_ptr += 2) { - const char* hex_p = &kHexTable[*src_ptr * 2]; - std::copy(hex_p, hex_p + 2, dest_ptr); - } -} - -// Returns true if the given type is represented as binary hex data. -bool IsBinaryType(absl::string_view type_str) { - return !type_str.empty() && absl::ascii_isdigit(type_str[0]); -} - -// Parses binary hex data. -Status ParseBinaryData(absl::string_view data_str, Buffer* buffer) { - data_str = absl::StripAsciiWhitespace(data_str); - ASSIGN_OR_RETURN(auto mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kDiscardWrite)); - auto contents = mapping.mutable_contents(); - size_t dst_i = 0; - size_t src_i = 0; - while (src_i < data_str.size() && dst_i < contents.size()) { - char c = data_str[src_i]; - if (absl::ascii_isspace(c) || c == ',') { - ++src_i; - continue; - } - if (src_i + 1 >= data_str.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid input hex data (offset=" << src_i << ")"; - } - HexStringToBytes(data_str.data() + src_i, contents.data() + dst_i, 1); - src_i += 2; - ++dst_i; - } - if (dst_i < contents.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Too few elements to fill type; expected " << contents.size() - << " but only read " << dst_i; - } else if (data_str.size() - src_i > 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Input data string contains more elements than the underlying " - "buffer (" - << contents.size() << ")"; - } - return OkStatus(); -} - -// Prints binary hex data. -Status PrintBinaryData(int element_size, Buffer* buffer, size_t max_entries, - std::ostream* stream) { - max_entries *= element_size; // Counting bytes, but treat them as elements. - ASSIGN_OR_RETURN(auto mapping, - buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - auto contents = mapping.contents(); - char hex_buffer[8 * 2]; - for (size_t i = 0; i < std::min(max_entries, mapping.size()); - i += element_size) { - if (i > 0) *stream << " "; - BytesToHexString(contents.data() + i, hex_buffer, element_size); - *stream << hex_buffer; - } - if (mapping.size() > max_entries) *stream << "..."; - return OkStatus(); -} - -template <typename ElementType, typename Enabled = void> -struct SimpleStrToValue { - absl::optional<ElementType> operator()(absl::string_view text) const = delete; -}; - -template <typename IntegerType> -struct SimpleStrToValue< - IntegerType, - typename std::enable_if<(sizeof(IntegerType) < 4), void>::type> { - absl::optional<IntegerType> operator()(absl::string_view text) const { - int32_t value; - return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} - : absl::nullopt; - } -}; - -template <typename IntegerType> -struct SimpleStrToValue< - IntegerType, - typename std::enable_if<(sizeof(IntegerType) >= 4), void>::type> { - absl::optional<IntegerType> operator()(absl::string_view text) const { - IntegerType value; - return absl::SimpleAtoi(text, &value) ? absl::optional<IntegerType>{value} - : absl::nullopt; - } -}; - -template <> -struct SimpleStrToValue<float, void> { - absl::optional<float> operator()(absl::string_view text) const { - float value; - return absl::SimpleAtof(text, &value) ? absl::optional<float>{value} - : absl::nullopt; - } -}; - -template <> -struct SimpleStrToValue<double, void> { - absl::optional<double> operator()(absl::string_view text) const { - double value; - return absl::SimpleAtod(text, &value) ? absl::optional<double>{value} - : absl::nullopt; - } -}; - -template <typename T> -Status ParseNumericalDataElement(absl::string_view data_str, size_t token_start, - size_t token_end, absl::Span<T> contents, - int dst_i) { - if (dst_i >= contents.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Input data string contains more elements than the underlying " - "buffer (" - << contents.size() << ")"; - } - auto element_str = data_str.substr(token_start, token_end - token_start + 1); - auto element = SimpleStrToValue<T>()(element_str); - if (!element.has_value()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unable to parse element " << dst_i << " = '" << element_str - << "'"; - } - contents[dst_i] = element.value(); - return OkStatus(); -} - -template <typename T> -Status ParseNumericalDataAsType(absl::string_view data_str, Buffer* buffer) { - ASSIGN_OR_RETURN(auto mapping, - buffer->MapMemory<T>(MemoryAccess::kDiscardWrite)); - auto contents = mapping.mutable_contents(); - size_t src_i = 0; - size_t dst_i = 0; - size_t token_start = std::string::npos; - while (src_i < data_str.size()) { - char c = data_str[src_i++]; - bool is_separator = - absl::ascii_isspace(c) || c == ',' || c == '[' || c == ']'; - if (token_start == std::string::npos) { - if (!is_separator) { - token_start = src_i - 1; - } - continue; - } else if (token_start != std::string::npos && !is_separator) { - continue; - } - RETURN_IF_ERROR(ParseNumericalDataElement<T>(data_str, token_start, - src_i - 2, contents, dst_i++)); - token_start = std::string::npos; - } - if (token_start != std::string::npos) { - RETURN_IF_ERROR(ParseNumericalDataElement<T>( - data_str, token_start, data_str.size() - 1, contents, dst_i++)); - } - if (dst_i < contents.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Input data string contains fewer elements than the underlying " - "buffer (expected " - << contents.size() << ")"; - } - return OkStatus(); -} - -// Parses numerical data (ints, floats, etc) in some typed form. -Status ParseNumericalData(absl::string_view type_str, - absl::string_view data_str, Buffer* buffer) { - if (type_str == "i8") { - return ParseNumericalDataAsType<int8_t>(data_str, buffer); - } else if (type_str == "u8") { - return ParseNumericalDataAsType<uint8_t>(data_str, buffer); - } else if (type_str == "i16") { - return ParseNumericalDataAsType<int16_t>(data_str, buffer); - } else if (type_str == "u16") { - return ParseNumericalDataAsType<uint16_t>(data_str, buffer); - } else if (type_str == "i32") { - return ParseNumericalDataAsType<int32_t>(data_str, buffer); - } else if (type_str == "u32") { - return ParseNumericalDataAsType<uint32_t>(data_str, buffer); - } else if (type_str == "i64") { - return ParseNumericalDataAsType<int64_t>(data_str, buffer); - } else if (type_str == "u64") { - return ParseNumericalDataAsType<uint64_t>(data_str, buffer); - } else if (type_str == "f32") { - return ParseNumericalDataAsType<float>(data_str, buffer); - } else if (type_str == "f64") { - return ParseNumericalDataAsType<double>(data_str, buffer); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported type: " << type_str; - } -} - -template <typename T> -void PrintElementList(const Shape& shape, absl::Span<const T> data, - size_t* max_entries, std::ostream* stream) { - if (shape.empty()) { - // Scalar value. - PrintElementList({1}, data, max_entries, stream); - return; - } else if (shape.size() == 1) { - // Leaf dimension; output data. - size_t max_count = std::min(*max_entries, static_cast<size_t>(shape[0])); - *stream << absl::StrJoin(data.subspan(0, max_count), " "); - if (max_count < shape[0]) { - *stream << "..."; - } - *max_entries -= max_count; - } else { - // Nested; recurse into next dimension. - Shape nested_shape = Shape(shape.subspan(1)); - size_t length = nested_shape.element_count(); - size_t offset = 0; - for (int i = 0; i < shape[0]; ++i) { - *stream << "["; - PrintElementList<T>(nested_shape, data.subspan(offset, length), - max_entries, stream); - offset += length; - *stream << "]"; - } - } -} - -template <typename T> -Status PrintNumericalDataAsType(const Shape& shape, Buffer* buffer, - size_t max_entries, std::ostream* stream) { - ASSIGN_OR_RETURN(auto mapping, buffer->MapMemory<T>(MemoryAccess::kRead)); - PrintElementList(shape, mapping.contents(), &max_entries, stream); - return OkStatus(); -} - -// Prints numerical data (ints, floats, etc) from some typed form. -Status PrintNumericalData(const Shape& shape, absl::string_view type_str, - Buffer* buffer, size_t max_entries, - std::ostream* stream) { - if (type_str == "i8") { - return PrintNumericalDataAsType<int8_t>(shape, buffer, max_entries, stream); - } else if (type_str == "u8") { - return PrintNumericalDataAsType<uint8_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "i16") { - return PrintNumericalDataAsType<int16_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "u16") { - return PrintNumericalDataAsType<uint16_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "i32") { - return PrintNumericalDataAsType<int32_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "u32") { - return PrintNumericalDataAsType<uint32_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "i64") { - return PrintNumericalDataAsType<int64_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "u64") { - return PrintNumericalDataAsType<uint64_t>(shape, buffer, max_entries, - stream); - } else if (type_str == "f32") { - return PrintNumericalDataAsType<float>(shape, buffer, max_entries, stream); - } else if (type_str == "f64") { - return PrintNumericalDataAsType<double>(shape, buffer, max_entries, stream); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported type: " << type_str; - } -} - -} // namespace - -StatusOr<int> GetTypeElementSize(absl::string_view type_str) { - if (type_str.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "Type is empty"; - } else if (IsBinaryType(type_str)) { - // If the first character is a digit then we are dealign with binary data. - // The type is just the number of bytes per element. - int element_size = 0; - if (!absl::SimpleAtoi(type_str, &element_size)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unable to parse element size type '" << type_str << "'"; - } - return element_size; - } - // We know that our types are single characters followed by bit counts. - // If we start to support other types we may need to do something more clever. - int bit_count = 0; - if (!absl::SimpleAtoi(type_str.substr(1), &bit_count)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unable to parse type bit count from '" << type_str - << "'; expecting something like 'i32'"; - } - return bit_count / 8; -} - -StatusOr<Shape> ParseShape(absl::string_view shape_str) { - std::vector<int> dims; - for (auto dim_str : absl::StrSplit(shape_str, 'x', absl::SkipWhitespace())) { - int dim_value = 0; - if (!absl::SimpleAtoi(dim_str, &dim_value)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid shape dimension '" << dim_str - << "' while parsing shape '" << shape_str << "'"; - } - dims.push_back(dim_value); - } - return Shape{dims}; -} - -StatusOr<BufferView> ParseBufferViewFromString( - absl::string_view buffer_view_str, hal::Allocator* allocator) { - // Strip whitespace that may come along (linefeeds/etc). - buffer_view_str = absl::StripAsciiWhitespace(buffer_view_str); - if (buffer_view_str.empty()) { - // Empty lines denote empty buffer_views. - return BufferView{}; - } - - // Split into the components we can work with: shape, type, and data. - absl::string_view shape_and_type_str; - absl::string_view data_str; - auto equal_index = buffer_view_str.find('='); - if (equal_index == std::string::npos) { - // Treat a lack of = as defaulting the data to zeros. - shape_and_type_str = buffer_view_str; - } else { - shape_and_type_str = buffer_view_str.substr(0, equal_index); - data_str = buffer_view_str.substr(equal_index + 1); - } - absl::string_view shape_str; - absl::string_view type_str; - auto last_x_index = shape_and_type_str.rfind('x'); - if (last_x_index == std::string::npos) { - // Scalar. - type_str = shape_and_type_str; - } else { - // Has a shape. - shape_str = shape_and_type_str.substr(0, last_x_index); - type_str = shape_and_type_str.substr(last_x_index + 1); - } - - // Populate BufferView metadata required for allocation. - BufferView result; - ASSIGN_OR_RETURN(result.element_size, GetTypeElementSize(type_str)); - ASSIGN_OR_RETURN(result.shape, ParseShape(shape_str)); - - // Allocate the host buffer. - size_t allocation_size = result.shape.element_count() * result.element_size; - if (allocator) { - ASSIGN_OR_RETURN( - result.buffer, - allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible, - BufferUsage::kAll | BufferUsage::kConstant, - allocation_size)); - } else { - result.buffer = HeapBuffer::Allocate( - MemoryType::kHostLocal, BufferUsage::kAll | BufferUsage::kConstant, - allocation_size); - } - - if (!data_str.empty()) { - // Parse the data from the string right into the buffer. - if (IsBinaryType(type_str)) { - // Parse as binary hex. - RETURN_IF_ERROR(ParseBinaryData(data_str, result.buffer.get())); - } else { - // Parse as some nicely formatted type. - RETURN_IF_ERROR( - ParseNumericalData(type_str, data_str, result.buffer.get())); - } - } - - return result; -} - -StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str) { - char str_char = str.empty() ? '?' : str[0]; - switch (str_char) { - case 'b': - return BufferViewPrintMode::kBinary; - case 'i': - return BufferViewPrintMode::kSignedInteger; - case 'u': - return BufferViewPrintMode::kUnsignedInteger; - case 'f': - return BufferViewPrintMode::kFloatingPoint; - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported output type '" << str << "'"; - } -} - -StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view, - BufferViewPrintMode print_mode, - size_t max_entries) { - std::string result; - RETURN_IF_ERROR( - PrintBufferViewToString(buffer_view, print_mode, max_entries, &result)); - return result; -} - -Status PrintBufferViewToString(const BufferView& buffer_view, - BufferViewPrintMode print_mode, - size_t max_entries, std::string* out_result) { - std::ostringstream stream; - RETURN_IF_ERROR( - PrintBufferViewToStream(buffer_view, print_mode, max_entries, &stream)); - *out_result = stream.str(); - return OkStatus(); -} - -Status PrintBufferViewToStream(const BufferView& buffer_view, - BufferViewPrintMode print_mode, - size_t max_entries, std::ostream* stream) { - if (!buffer_view.buffer) { - // No buffer means the buffer_view is empty. We use the empty string to - // denote this (as we have no useful information). - return OkStatus(); - } - - // Pick a type based on the element size and the printing mode. - std::string type_str; - switch (print_mode) { - case BufferViewPrintMode::kBinary: - type_str = std::to_string(buffer_view.element_size); - break; - case BufferViewPrintMode::kSignedInteger: - absl::StrAppend(&type_str, "i", buffer_view.element_size * 8); - break; - case BufferViewPrintMode::kUnsignedInteger: - absl::StrAppend(&type_str, "u", buffer_view.element_size * 8); - break; - case BufferViewPrintMode::kFloatingPoint: - absl::StrAppend(&type_str, "f", buffer_view.element_size * 8); - break; - } - - // [shape]x[type]= prefix (taking into account scalar values). - *stream << absl::StrJoin(buffer_view.shape.begin(), buffer_view.shape.end(), - "x"); - if (!buffer_view.shape.empty()) *stream << "x"; - *stream << type_str; - *stream << "="; - - if (IsBinaryType(type_str)) { - return PrintBinaryData(buffer_view.element_size, buffer_view.buffer.get(), - max_entries, stream); - } else { - return PrintNumericalData(buffer_view.shape, type_str, - buffer_view.buffer.get(), max_entries, stream); - } -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/buffer_view_string_util.h b/iree/hal/buffer_view_string_util.h deleted file mode 100644 index 6e9ca2c..0000000 --- a/iree/hal/buffer_view_string_util.h +++ /dev/null
@@ -1,95 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for working with BufferView data, mostly useful for testing. -// These functions allow for conversion between types, parsing and printing, and -// basic comparisons. -// -// The canonical BufferView string format is: -// [shape]x[type]=value,value,... -// For example: -// 2x2xi32=0,1,2,3 -// Characters like [] are optional and will be ignored during parsing: -// 2x2xi32=[[0 1][2 3]] -// -// The type may be one of the following: -// * 1/2/4/8 = 1/2/4/8 byte elements in binary hex format. -// * i8/u8 = signed/unsigned 8-bit integers. -// * i16/u16 = signed/unsigned 16-bit integers. -// * i32/u32 = signed/unsigned 32-bit integers. -// * i64/u64 = signed/unsigned 64-bit integers. -// * f32 = 32-bit floating-point number. -// * f64 = 64-bit floating-point number. - -#ifndef IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_ -#define IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_ - -#include <ostream> -#include <string> - -#include "absl/strings/string_view.h" -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer_view.h" - -namespace iree { -namespace hal { - -// Returns the size, in bytes, of the given type. -StatusOr<int> GetTypeElementSize(absl::string_view type_str); - -// Returns a Shape parsed from the given NxMx... string. -StatusOr<Shape> ParseShape(absl::string_view shape_str); - -// Parses a BufferView encoded in a string. -// If an |allocator| is provided the buffer will be allocated as host-local and -// device-visible. Otherwise, buffers will be host-local. -// The format accepted matches that produced by PrintBufferViewToString. -StatusOr<BufferView> ParseBufferViewFromString( - absl::string_view buffer_view_str, hal::Allocator* allocator = nullptr); - -// Defines how the elements within a BufferView are interpreted during printing. -enum class BufferViewPrintMode { - // Interpret the data as if it were serialized bytes. - // In this mode no conversion is performed and the bytes in memory are printed - // as hex in groupings based on the element size. Shortened to 'b'. - kBinary, - // Interpret elements as signed integers; shortened to 'i'. - kSignedInteger, - // Interpret elements as unsigned integers; shortened to 'u'. - kUnsignedInteger, - // Interpret elements as floating-point values; shortened to 'f'. - kFloatingPoint, -}; - -// Returns the BufferViewPrintMode based on the shortened char in |str|. -StatusOr<BufferViewPrintMode> ParseBufferViewPrintMode(absl::string_view str); - -// Prints a BufferView to a string encoded in the canonical format. -StatusOr<std::string> PrintBufferViewToString(const BufferView& buffer_view, - BufferViewPrintMode print_mode, - size_t max_entries); -Status PrintBufferViewToString(const BufferView& buffer_view, - BufferViewPrintMode print_mode, - size_t max_entries, std::string* out_result); - -// Prints a BufferView to a string stream encoded in the canonical format. -Status PrintBufferViewToStream(const BufferView& buffer_view, - BufferViewPrintMode print_mode, - size_t max_entries, std::ostream* stream); - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_BUFFER_VIEW_STRING_UTIL_H_
diff --git a/iree/hal/buffer_view_string_util_test.cc b/iree/hal/buffer_view_string_util_test.cc deleted file mode 100644 index 11902ba..0000000 --- a/iree/hal/buffer_view_string_util_test.cc +++ /dev/null
@@ -1,186 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/buffer_view_string_util.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" - -namespace iree { -namespace hal { -namespace { - -using ::testing::ElementsAre; - -template <typename T> -StatusOr<std::vector<T>> ReadBuffer(const ref_ptr<Buffer>& buffer) { - std::vector<T> result; - result.resize(buffer->byte_length() / sizeof(T)); - RETURN_IF_ERROR( - buffer->ReadData(0, result.data(), result.size() * sizeof(T))); - return result; -} - -TEST(BufferViewUtilTest, GetTypeElementSize) { - EXPECT_EQ(1, GetTypeElementSize("1").ValueOrDie()); - EXPECT_EQ(7, GetTypeElementSize("7").ValueOrDie()); - EXPECT_EQ(4, GetTypeElementSize("i32").ValueOrDie()); - EXPECT_EQ(8, GetTypeElementSize("f64").ValueOrDie()); - - EXPECT_FALSE(GetTypeElementSize("").ok()); - EXPECT_FALSE(GetTypeElementSize(" ").ok()); - EXPECT_FALSE(GetTypeElementSize("a").ok()); - EXPECT_FALSE(GetTypeElementSize("ib").ok()); - EXPECT_FALSE(GetTypeElementSize("i").ok()); - EXPECT_FALSE(GetTypeElementSize("i543ff").ok()); -} - -TEST(BufferViewUtilTest, ParseShape) { - EXPECT_EQ((Shape{}), ParseShape("").ValueOrDie()); - EXPECT_EQ((Shape{1}), ParseShape("1").ValueOrDie()); - EXPECT_EQ((Shape{1, 2}), ParseShape("1x2").ValueOrDie()); - EXPECT_EQ((Shape{1, 2}), ParseShape(" 1 x 2 ").ValueOrDie()); - - EXPECT_FALSE(ParseShape("abc").ok()); - EXPECT_FALSE(ParseShape("1xf").ok()); - EXPECT_FALSE(ParseShape("1xff23").ok()); -} - -TEST(BufferViewUtilTest, ParseBufferViewFromStringEmpty) { - // Empty string = empty buffer_view. - ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString("")); - EXPECT_EQ(nullptr, m0.buffer.get()); - EXPECT_EQ(Shape{}, m0.shape); - EXPECT_EQ(0, m0.element_size); - - // No = means no data. - ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x2xf32")); - EXPECT_EQ(4 * 2 * 4, m1.buffer->allocation_size()); - EXPECT_EQ(Shape({4, 2}), m1.shape); - EXPECT_EQ(4, m1.element_size); - EXPECT_THAT(ReadBuffer<float>(m1.buffer).ValueOrDie(), - ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); - - // No data after = means no data. - ASSERT_OK_AND_ASSIGN(auto m2, ParseBufferViewFromString("4x2xf32=")); - EXPECT_EQ(4 * 2 * 4, m2.buffer->allocation_size()); - EXPECT_EQ(Shape({4, 2}), m2.shape); - EXPECT_EQ(4, m2.element_size); - EXPECT_THAT(ReadBuffer<float>(m2.buffer).ValueOrDie(), - ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); -} - -TEST(BufferViewUtilTest, ParseBufferViewFromStringBinary) { - ASSERT_OK_AND_ASSIGN(auto m0, ParseBufferViewFromString("4x1=00 01 02 03")); - EXPECT_EQ(Shape({4}), m0.shape); - EXPECT_EQ(1, m0.element_size); - EXPECT_THAT(ReadBuffer<uint8_t>(m0.buffer).ValueOrDie(), - ElementsAre(0, 1, 2, 3)); - - // Whitespace shouldn't matter. - ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString("4x1=00,010203")); - EXPECT_EQ(Shape({4}), m1.shape); - EXPECT_EQ(1, m1.element_size); - EXPECT_THAT(ReadBuffer<uint8_t>(m1.buffer).ValueOrDie(), - ElementsAre(0, 1, 2, 3)); - - // Should fail on malformed hex bytes. - EXPECT_FALSE(ParseBufferViewFromString("4x1=1").ok()); - EXPECT_FALSE(ParseBufferViewFromString("4x1=00003").ok()); - EXPECT_FALSE(ParseBufferViewFromString("4x1=%0123%\1").ok()); - EXPECT_FALSE(ParseBufferViewFromString("4x1=00010203040506").ok()); -} - -TEST(BufferViewUtilTest, ParseBufferViewFromStringAllowBrackets) { - ASSERT_OK_AND_ASSIGN(auto m0, - ParseBufferViewFromString("4xi16=[[0][ 1 ][2]][3]")); - EXPECT_EQ(Shape({4}), m0.shape); - EXPECT_EQ(2, m0.element_size); - EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(), - ElementsAre(0, 1, 2, 3)); -} - -TEST(BufferViewUtilTest, ParseBufferViewFromStringInteger) { - // Signed int16. - ASSERT_OK_AND_ASSIGN(auto m0, - ParseBufferViewFromString("4xi16=0 12345 65535 -2")); - EXPECT_EQ(Shape({4}), m0.shape); - EXPECT_EQ(2, m0.element_size); - EXPECT_THAT(ReadBuffer<int16_t>(m0.buffer).ValueOrDie(), - ElementsAre(0, 12345, -1, -2)); - - // Unsigned int16. - ASSERT_OK_AND_ASSIGN(auto m1, - ParseBufferViewFromString("4xu16=0 12345 65535 -2")); - EXPECT_EQ(Shape({4}), m1.shape); - EXPECT_EQ(2, m1.element_size); - EXPECT_THAT(ReadBuffer<uint16_t>(m1.buffer).ValueOrDie(), - ElementsAre(0, 12345, 65535, 65534)); - - // Mixing separator types is ok. - ASSERT_OK_AND_ASSIGN(auto m2, - ParseBufferViewFromString("4xu16=0, 12345, 65535, -2")); - EXPECT_EQ(Shape({4}), m2.shape); - EXPECT_EQ(2, m2.element_size); - EXPECT_THAT(ReadBuffer<uint16_t>(m2.buffer).ValueOrDie(), - ElementsAre(0, 12345, 65535, 65534)); - - // Should fail on malformed integers bytes and out of bounds values. - EXPECT_FALSE(ParseBufferViewFromString("4xi32=asodfj").ok()); - EXPECT_FALSE(ParseBufferViewFromString("4xi32=0 1 2 3 4").ok()); -} - -TEST(BufferViewUtilTest, ParseBufferViewFromStringFloat) { - // Float. - ASSERT_OK_AND_ASSIGN(auto m0, - ParseBufferViewFromString("4xf32=0 1.0 1234 -2.0e-5")); - EXPECT_EQ(Shape({4}), m0.shape); - EXPECT_EQ(4, m0.element_size); - EXPECT_THAT(ReadBuffer<float>(m0.buffer).ValueOrDie(), - ElementsAre(0.0f, 1.0f, 1234.0f, -2.0e-5f)); - - // Double. - ASSERT_OK_AND_ASSIGN(auto m1, ParseBufferViewFromString( - "4xf64=0 1.0 123456789012345 -2.0e-5")); - EXPECT_EQ(Shape({4}), m1.shape); - EXPECT_EQ(8, m1.element_size); - EXPECT_THAT(ReadBuffer<double>(m1.buffer).ValueOrDie(), - ElementsAre(0.0, 1.0, 123456789012345.0, -2.0e-5)); - - // Should fail on malformed floats and out of bounds values. - EXPECT_FALSE(ParseBufferViewFromString("4xf32=asodfj").ok()); - EXPECT_FALSE(ParseBufferViewFromString("4xf32=0").ok()); - EXPECT_FALSE(ParseBufferViewFromString("4xf32=0 1 2 3 4").ok()); -} - -TEST(BufferViewUtilTest, ParseBufferViewPrintMode) { - EXPECT_EQ(BufferViewPrintMode::kBinary, - ParseBufferViewPrintMode("b").ValueOrDie()); - EXPECT_EQ(BufferViewPrintMode::kSignedInteger, - ParseBufferViewPrintMode("i").ValueOrDie()); - EXPECT_EQ(BufferViewPrintMode::kUnsignedInteger, - ParseBufferViewPrintMode("u").ValueOrDie()); - EXPECT_EQ(BufferViewPrintMode::kFloatingPoint, - ParseBufferViewPrintMode("f").ValueOrDie()); - - EXPECT_FALSE(ParseBufferViewPrintMode("").ok()); - EXPECT_FALSE(ParseBufferViewPrintMode("s").ok()); - EXPECT_FALSE(ParseBufferViewPrintMode("asdfasdf").ok()); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/buffer_view_test.cc b/iree/hal/buffer_view_test.cc deleted file mode 100644 index 18dcbcf..0000000 --- a/iree/hal/buffer_view_test.cc +++ /dev/null
@@ -1,285 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/buffer_view.h" - -#include <numeric> -#include <vector> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/buffer.h" -#include "iree/hal/heap_buffer.h" - -namespace iree { -namespace hal { -namespace { - -template <typename T> -BufferView MakeView(const std::vector<T> src_data, Shape shape) { - auto parent_buffer = HeapBuffer::AllocateCopy( - BufferUsage::kTransfer | BufferUsage::kMapping, absl::MakeSpan(src_data)); - - return BufferView(std::move(parent_buffer), shape, sizeof(T)); -} - -template <typename T> -std::vector<T> ReadData(BufferView view) { - std::vector<T> data(view.shape.element_count()); - EXPECT_OK(view.buffer->ReadData(0, data.data(), data.size() * sizeof(T))); - return data; -} - -TEST(BufferViewTest, SliceWholeBuffer) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {0, 0}; - std::vector<int32_t> lengths = {2, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_TRUE(BufferView::Equal(parent_view, slice)) - << "original parent_view " << parent_view.DebugStringShort() - << " and whole slice " << slice.DebugStringShort() << " are not equal"; -} - -TEST(BufferViewTest, SliceSingleRow) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({2, 3})); -} - -TEST(BufferViewTest, SliceRowStart) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7}; - Shape shape = {2, 4}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {1, 3}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({4, 5, 6})); -} - -TEST(BufferViewTest, SliceRowEnd) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7}; - Shape shape = {2, 4}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 1}; - std::vector<int32_t> lengths = {1, 3}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6, 7})); -} - -TEST(BufferViewTest, SliceRowMiddle) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7}; - Shape shape = {2, 4}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 1}; - std::vector<int32_t> lengths = {1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({5, 6})); -} - -TEST(BufferViewTest, SliceMultiRow) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8}; - Shape shape = {3, 3}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {2, 3}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({3, 4, 5, 6, 7, 8})); -} - -TEST(BufferViewTest, SliceHighRank) { - std::vector<uint8_t> src_data(81); - std::iota(src_data.begin(), src_data.end(), 0); - Shape shape = {3, 3, 3, 3}; - - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 2, 2, 1}; - std::vector<int32_t> lengths = {1, 1, 1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({52, 53})); -} - -TEST(BufferViewTest, SliceModifySlice) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_OK(slice.buffer->Fill8(0, kWholeBuffer, 0xFFu)); - - auto parent_data = ReadData<uint8_t>(parent_view); - EXPECT_EQ(parent_data, std::vector<uint8_t>({0, 1, 0xFFu, 0xFFu})); -} - -TEST(BufferViewTest, SliceModifyParent) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_OK(parent_view.buffer->Fill8(0, kWholeBuffer, 0xFFu)); - - EXPECT_EQ(ReadData<uint8_t>(slice), std::vector<uint8_t>({0xFFu, 0xFFu})); -} - -TEST(BufferViewTest, SliceMultiByteElementWholeBuffer) { - const std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3}; - - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {0, 0}; - std::vector<int32_t> lengths = {2, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_TRUE(BufferView::Equal(parent_view, slice)) - << "original parent_view " << parent_view.DebugStringShort() - << " and whole slice " << slice.DebugStringShort() << " are not equal"; -} - -TEST(BufferViewTest, SliceShapeAndElementSize) { - std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - EXPECT_EQ(slice.shape, Shape(lengths)); - EXPECT_EQ(slice.element_size, 4); -} - -TEST(BufferViewTest, SliceMultiByteElement) { - std::vector<int32_t> src_data = {INT32_MAX, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {1, 2}; - ASSERT_OK_AND_ASSIGN(auto slice, parent_view.Slice(start_indices, lengths)); - - EXPECT_EQ(ReadData<int32_t>(slice), std::vector<int32_t>({2, 3})); -} - -TEST(BufferViewTest, SliceIndexBadRank) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {0}; - std::vector<int32_t> lengths = {2}; - EXPECT_TRUE( - IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); -} - -TEST(BufferViewTest, SliceIndexLengthMismatch) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {0, 0}; - std::vector<int32_t> lengths = {2}; - EXPECT_TRUE( - IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); -} - -TEST(BufferViewTest, SliceIndicesOutOfBounds) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {0, 3}; - std::vector<int32_t> lengths = {1, 1}; - EXPECT_TRUE( - IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); -} - -TEST(BufferViewTest, SliceLengthsOutOfBounds) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - Shape shape = {2, 2}; - - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {0, 0}; - std::vector<int32_t> lengths = {1, 3}; - EXPECT_TRUE( - IsInvalidArgument(parent_view.Slice(start_indices, lengths).status())); -} - -TEST(BufferViewTest, SliceNonContiguous) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8}; - Shape shape = {3, 3}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 1}; - std::vector<int32_t> lengths = {2, 2}; - EXPECT_TRUE( - IsUnimplemented(parent_view.Slice(start_indices, lengths).status())); -} - -TEST(BufferViewTest, SliceNonContiguousMultiRowLeft) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6, 7, 8}; - Shape shape = {3, 3}; - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0}; - std::vector<int32_t> lengths = {2, 1}; - EXPECT_TRUE( - IsUnimplemented(parent_view.Slice(start_indices, lengths).status())); -} - -TEST(BufferViewTest, SliceHighRankNonContiguous) { - std::vector<uint8_t> src_data(81); - std::iota(src_data.begin(), src_data.end(), 0); - Shape shape = {3, 3, 3, 3}; - - auto parent_view = MakeView(src_data, shape); - - std::vector<int32_t> start_indices = {1, 0, 2, 1}; - std::vector<int32_t> lengths = {1, 2, 1, 2}; - EXPECT_TRUE( - IsUnimplemented(parent_view.Slice(start_indices, lengths).status())); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/command_buffer.cc b/iree/hal/command_buffer.cc deleted file mode 100644 index d83810a..0000000 --- a/iree/hal/command_buffer.cc +++ /dev/null
@@ -1,29 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { - -std::string CommandCategoryString(CommandCategoryBitfield categories) { - return FormatBitfieldValue(categories, - { - {CommandCategory::kTransfer, "kTransfer"}, - {CommandCategory::kDispatch, "kDispatch"}, - }); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/command_buffer.h b/iree/hal/command_buffer.h deleted file mode 100644 index 9b32232..0000000 --- a/iree/hal/command_buffer.h +++ /dev/null
@@ -1,383 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_COMMAND_BUFFER_H_ -#define IREE_HAL_COMMAND_BUFFER_H_ - -#include <cstdint> - -#include "iree/base/bitfield.h" -#include "iree/base/shape.h" -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/event.h" -#include "iree/hal/executable.h" -#include "iree/hal/resource.h" - -namespace iree { -namespace hal { - -// A bitfield specifying the mode of operation for a command buffer. -enum class CommandBufferMode : uint32_t { - // Command buffer will be submitted once and never used again. - // This may enable in-place patching of command buffers that reduce overhead - // when it's known that command buffers will not be reused. - kOneShot = 1 << 0, -}; -IREE_BITFIELD(CommandBufferMode); -using CommandBufferModeBitfield = CommandBufferMode; -std::string CommandBufferModeString(CommandBufferModeBitfield mode); - -// A bitfield specifying the category of commands in a command queue. -enum class CommandCategory : uint32_t { - // Command is considered a transfer operation (memcpy, etc). - kTransfer = 1 << 0, - // Command is considered a dispatch operation (dispatch/execute). - kDispatch = 1 << 1, -}; -IREE_BITFIELD(CommandCategory); -using CommandCategoryBitfield = CommandCategory; -std::string CommandCategoryString(CommandCategoryBitfield categories); - -// Bitfield specifying which execution stage a brarrier should start/end at. -// -// Maps to VkPipelineStageFlagBits. -enum class ExecutionStage : uint32_t { - // Top of the pipeline when commands are initially issued by the device. - kCommandIssue = 1 << 0, - // Stage of the pipeline when dispatch parameter data is consumed. - kCommandProcess = 1 << 1, - // Stage where dispatch commands execute. - kDispatch = 1 << 2, - // Stage where transfer (copy/clear/fill/etc) commands execute. - kTransfer = 1 << 3, - // Final stage in the pipeline when commands are retired on the device. - kCommandRetire = 1 << 4, - // Pseudo-stage for read/writes by the host. Not executed on device. - kHost = 1 << 5, -}; -IREE_BITFIELD(ExecutionStage); -using ExecutionStageBitfield = ExecutionStage; - -// Bitfield specifying which scopes will access memory and how. -// -// Maps to VkAccessFlagBits. -enum class AccessScope : uint32_t { - // Read access to indirect command data as part of an indirect dispatch. - kIndirectCommandRead = 1 << 0, - // Constant uniform buffer reads by the device. - kConstantRead = 1 << 1, - // Storage buffer reads by dispatch commands. - kDispatchRead = 1 << 2, - // Storage buffer writes by dispatch commands. - kDispatchWrite = 1 << 3, - // Source of a transfer operation. - kTransferRead = 1 << 4, - // Target of a transfer operation. - kTransferWrite = 1 << 5, - // Read operation by the host through mapped memory. - kHostRead = 1 << 6, - // Write operation by the host through mapped memory. - kHostWrite = 1 << 7, - // External/non-specific read. - kMemoryRead = 1 << 8, - // External/non-specific write. - kMemoryWrite = 1 << 9, -}; -IREE_BITFIELD(AccessScope); -using AccessScopeBitfield = AccessScope; - -// Defines a global memory barrier. -// These are cheaper to encode than buffer-specific barriers but may cause -// stalls and bubbles in device pipelines if applied too broadly. Prefer them -// over equivalently large sets of buffer-specific barriers (such as when -// completely changing execution contexts). -// -// Maps to VkMemoryBarrier. -struct MemoryBarrier { - // All access scopes prior-to the barrier (inclusive). - AccessScopeBitfield source_scope; - // All access scopes following the barrier (inclusive). - AccessScopeBitfield target_scope; -}; - -// Defines a memory barrier that applies to a range of a specific buffer. -// Use of these (vs. global memory barriers) provides fine-grained execution -// ordering to device command processors and allows for more aggressive -// reordering. -// -// Maps to VkBufferMemoryBarrier. -struct BufferBarrier { - // All access scopes prior-to the barrier (inclusive). - AccessScopeBitfield source_scope; - // All access scopes following the barrier (inclusive). - AccessScopeBitfield target_scope; - // Buffer the barrier is restricted to. - // The barrier will apply to the entire physical device allocation. - Buffer* buffer = nullptr; - // Relative offset/length within |buffer| (which may itself be mapped into the - // device allocation at an offset). - device_size_t offset = 0; - device_size_t length = kWholeBuffer; -}; - -// Represents a binding to a buffer with a set of attributes. -// This may be used by drivers to validate alignment. -struct BufferBinding { - // Access rights of the buffer contents by the executable. - MemoryAccessBitfield access = MemoryAccess::kAll; - - // The buffer this binding references. - // The buffer is not retained by the binding and must be kept alive externally - // for the duration it is in use by the queue. - Buffer* buffer = nullptr; - - // Shape of the buffer contents. - Shape shape; - - // Size of each element within the buffer, in bytes. - int8_t element_size = 0; - - BufferBinding() = default; - BufferBinding(MemoryAccessBitfield access, Buffer* buffer) - : access(access), buffer(buffer) {} - BufferBinding(MemoryAccessBitfield access, Buffer* buffer, Shape shape, - int8_t element_size) - : access(access), - buffer(buffer), - shape(shape), - element_size(element_size) {} - BufferBinding(MemoryAccessBitfield access, const BufferView& buffer_view) - : access(access), - buffer(buffer_view.buffer.get()), - shape(buffer_view.shape), - element_size(buffer_view.element_size) {} -}; - -// Wraps parameters for a Dispatch request. -struct DispatchRequest { - // Executable prepared for use on the device. - // The executable must remain alive until all in-flight dispatch requests - // that use it have completed. - Executable* executable = nullptr; - - // Executable entry point ordinal. - int entry_point = 0; - - // TODO(benvanik): predication. - - // Static workload parameters defining the X, Y, and Z workgroup counts. - std::array<int32_t, 3> workload; - - // An optional buffer containing the dynamic workload to dispatch. - // The contents need not be available at the time of recording but must be - // made visible prior to execution of the dispatch command. - // - // Buffer contents are expected to be 3 int32 values defining the X, Y, and Z - // workgroup counts. - // - // The buffer must have been allocated with BufferUsage::kDispatch and be - // of MemoryType::kDeviceVisible. - Buffer* workload_buffer = nullptr; - - // A list of buffers that contain the execution inputs/outputs. - // Order is dependent on executable arg layout. - // - // Buffers must have been allocated with BufferUsage::kDispatch and be - // of MemoryType::kDeviceVisible. - absl::Span<const BufferBinding> bindings; - - // TODO(benvanik): push-constant equivalent (uniforms, etc). -}; - -// Asynchronous command buffer recording interface. -// Commands are recorded by the implementation for later submission to command -// queues. -// -// Buffers and synchronization objects referenced must remain valid and not be -// modified or read while there are commands in-flight. The usual flow is to -// populate input buffers, Dispatch using those buffers, wait on a Fence until -// the buffers are guaranteed to no longer be in use, and then reuse or release -// the buffers. -// -// Errors that can be recognized when operations are enqueued will be returned -// immediately, such as invalid argument errors. Errors that can only be -// determined at execution time will be returned on fences. Once a failure -// occurs the device queue will enter an error state that invalidates all -// operations on the device queue (as ordering is not strict and any may still -// be in-flight). In this case the user of the device queue should treat all -// in-flight operations as cancelled and fully reset themselves. Other device -// queues that may be waiting on events from the device queue will also enter -// error states. Only once a user has acknowledged and cleared the error state -// with a Reset the queue will become usable, and otherwise all operations will -// return errors. -// -// Command buffers are thread-compatible. Use multiple command buffers if trying -// to record commands from multiple threads. Command buffers must not be mutated -// between when they have are submitted for execution on a queue and when the -// fence fires indicating the completion of their execution. -class CommandBuffer : public Resource { - public: - virtual CommandBuffer* impl() { return this; } - - // Device allocator that commands encoded into the buffer share compatibility - // with. - Allocator* allocator() const { return allocator_; } - - // Command buffer operation mode. - CommandBufferModeBitfield mode() const { return mode_; } - - // Command categories that may be recorded into the buffer. - CommandCategoryBitfield command_categories() const { - return command_categories_; - } - - // True if the command buffer is between a Begin/End recording block. - virtual bool is_recording() const = 0; - - // Resets and begins recording into the command buffer, clearing all - // previously recorded contents. - // The command buffer must not be in-flight. - virtual Status Begin() = 0; - - // Ends recording into the command buffer. - // This must be called prior to submitting the command buffer for execution. - virtual Status End() = 0; - - // TODO(benvanik): annotations for debugging and tracing: - // enter/exit - // stack frame manipulation - // explicit timers? or profiling buffer? - - // TODO(b/138719910): cross-queue and external acquire/release. - // virtual Status AcquireBuffer() = 0; - // virtual Status ReleaseBuffer() = 0; - - // Defines a memory dependency between commands recorded before and after the - // barrier. One or more memory or buffer barriers can be specified to indicate - // between which stages or buffers the dependencies exist. - virtual Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) = 0; - - // Sets an event to the signaled state. - // |source_stage_mask| specifies when the event is signaled. - // - // Events are only valid within a single command buffer. Events can only be - // used on non-transfer queues. - virtual Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) = 0; - - // Resets an event to the non-signaled state. - // |source_stage_mask| specifies when the event is unsignaled. - // - // Events are only valid within a single command buffer. Events can only be - // used on non-transfer queues. - virtual Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) = 0; - - // Waits for one or more events to be signaled and defines a memory dependency - // between the synchronization scope of the signal operations and the commands - // following the wait. - // - // |source_stage_mask| must include ExecutionStage::kHost for Event::Signal to - // be visibile. - // - // Events are only valid within a single command buffer. Events remain - // signaled even after waiting and must be reset to be reused. Events can only - // be used on non-transfer queues. - virtual Status WaitEvents( - absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) = 0; - - // Fills the target buffer with the given repeating value. - // Expects that value_length is one of 1, 2, or 4 and that the offset and - // length are aligned to the natural alignment of the value. - // The target buffer must be compatible with the devices owned by this - // device queue and be allocated with BufferUsage::kTransfer. - virtual Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) = 0; - - // Hints to the device queue that the given buffer will not be used again. - // After encoding a discard the buffer contents will be considered undefined. - // This is because the discard may be used to elide write backs to host memory - // or aggressively reuse the allocation for other purposes. - // - // For buffers allocated with MemoryType::kTransient this may allow - // the device queue to reclaim the memory used by the buffer earlier than - // otherwise possible. - virtual Status DiscardBuffer(Buffer* buffer) = 0; - - // Updates a range of the given target buffer from the source host memory. - // The source host memory is copied immediately into the command buffer and - // occupies command buffer space. It is strongly recommended that large buffer - // updates are performed via CopyBuffer where there is the possibility of a - // zero-copy path. - // The |source_buffer| may be releaed by the caller immediately after this - // call returns. - // The |target_buffer| must be compatible with the devices owned by this - // device queue and be allocated with BufferUsage::kTransfer. - virtual Status UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) = 0; - - // Copies a range of one buffer to another. - // Both buffers must be compatible with the devices owned by this device - // queue and be allocated with BufferUsage::kTransfer. Though the source and - // target buffer may be the same the ranges must not overlap (as with memcpy). - // - // This can be used to perform device->host, host->device, and device->device - // copies. - virtual Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) = 0; - - // Dispatches an execution request. - // The request may execute overlapped with any other transfer operation or - // dispatch made within the same barrier-defined sequence. - // - // The executable specified must be registered for use with the device driver - // owning this queue. It must not be unregistered until all requests that use - // it have completed. - // - // Fails if the queue does not support dispatch operations (as indicated by - // can_dispatch). - virtual Status Dispatch(const DispatchRequest& dispatch_request) = 0; - - protected: - CommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : allocator_(allocator), - mode_(mode), - command_categories_(command_categories) {} - - private: - Allocator* const allocator_; - const CommandBufferModeBitfield mode_; - const CommandCategoryBitfield command_categories_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_COMMAND_BUFFER_H_
diff --git a/iree/hal/command_buffer_validation.cc b/iree/hal/command_buffer_validation.cc deleted file mode 100644 index c78580f..0000000 --- a/iree/hal/command_buffer_validation.cc +++ /dev/null
@@ -1,403 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/command_buffer_validation.h" - -#include "iree/base/logging.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -namespace { - -// Command buffer validation shim. -// Wraps an existing command buffer to provide in-depth validation during -// recording. This should be enabled whenever the command buffer is being driven -// by unsafe code or when early and readable diagnostics are needed. -class ValidatingCommandBuffer : public CommandBuffer { - public: - explicit ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl); - ~ValidatingCommandBuffer() override; - - CommandBuffer* impl() override { return impl_.get(); } - - bool is_recording() const override; - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status WaitEvents(absl::Span<Event*> events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - Status DiscardBuffer(Buffer* buffer) override; - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - Status Dispatch(const DispatchRequest& dispatch_request) override; - - private: - // Returns a failure if the queue does not support the given caps. - Status ValidateCategories(CommandCategoryBitfield required_categories) const; - // Returns a failure if the memory type the buffer was allocated from is not - // compatible with the given type. - Status ValidateCompatibleMemoryType(Buffer* buffer, - MemoryTypeBitfield memory_type) const; - // Returns a failure if the buffer memory type or usage disallows the given - // access type. - Status ValidateAccess(Buffer* buffer, - MemoryAccessBitfield memory_access) const; - // Returns a failure if the buffer was not allocated for the given usage. - Status ValidateUsage(Buffer* buffer, BufferUsageBitfield usage) const; - // Validates that the range provided is within the given buffer. - Status ValidateRange(Buffer* buffer, device_size_t byte_offset, - device_size_t byte_length) const; - - ref_ptr<CommandBuffer> impl_; -}; - -ValidatingCommandBuffer::ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl) - : CommandBuffer(impl->allocator(), impl->mode(), - impl->command_categories()), - impl_(std::move(impl)) {} - -ValidatingCommandBuffer::~ValidatingCommandBuffer() = default; - -bool ValidatingCommandBuffer::is_recording() const { - return impl_->is_recording(); -} - -Status ValidatingCommandBuffer::Begin() { - DVLOG(3) << "CommandBuffer::Begin()"; - if (impl_->is_recording()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Command buffer is already recording"; - } - return impl_->Begin(); -} - -Status ValidatingCommandBuffer::End() { - DVLOG(3) << "CommandBuffer::End()"; - if (!impl_->is_recording()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Command buffer is not recording"; - } - return impl_->End(); -} - -Status ValidatingCommandBuffer::ValidateCategories( - CommandCategoryBitfield required_categories) const { - if (!AllBitsSet(command_categories(), required_categories)) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Operation requires categories " - << CommandCategoryString(required_categories) - << " but buffer only supports " - << CommandCategoryString(command_categories()); - } - return OkStatus(); -} - -Status ValidatingCommandBuffer::ValidateCompatibleMemoryType( - Buffer* buffer, MemoryTypeBitfield memory_type) const { - if ((buffer->memory_type() & memory_type) != memory_type) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is not compatible with the requested " - "operation; buffer has " - << MemoryTypeString(buffer->memory_type()) << ", operation requires " - << MemoryTypeString(memory_type); - } - return OkStatus(); -} - -Status ValidatingCommandBuffer::ValidateAccess( - Buffer* buffer, MemoryAccessBitfield memory_access) const { - if ((buffer->allowed_access() & memory_access) != memory_access) { - // Bits must match exactly. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "The buffer does not support the requested access type; buffer " - "allows " - << MemoryAccessString(buffer->allowed_access()) - << ", operation requires " << MemoryAccessString(memory_access); - } - return OkStatus(); -} - -// Returns a failure if the buffer was not allocated for the given usage. -Status ValidatingCommandBuffer::ValidateUsage(Buffer* buffer, - BufferUsageBitfield usage) const { - if (!allocator()->CanUseBuffer(buffer, usage)) { - // Buffer cannot be used on the queue for the given usage. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage of " << buffer->DebugString() - << " is not supported for the buffer on this queue; " - "buffer allows " - << BufferUsageString(buffer->usage()) << ", queue requires " - << BufferUsageString(usage); - } - - if ((buffer->usage() & usage) != usage) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage was not specified when the buffer was " - "allocated; buffer allows " - << BufferUsageString(buffer->usage()) << ", operation requires " - << BufferUsageString(usage); - } - - return OkStatus(); -} - -// Validates that the range provided is within the given buffer. -Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer, - device_size_t byte_offset, - device_size_t byte_length) const { - // Check if the start of the range runs off the end of the buffer. - if (byte_offset > buffer->byte_length()) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address off the end of the valid buffer " - "range (offset=" - << byte_offset << ", length=" << byte_length - << ", buffer byte_length=" << buffer->byte_length() << ")"; - } - - if (byte_length == 0) { - // Fine to have a zero length. - return OkStatus(); - } - - // Check if the end runs over the allocation. - device_size_t end = byte_offset + byte_length; - if (end > buffer->byte_length()) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address outside of the valid buffer " - "range (offset=" - << byte_offset << ", length=" << byte_length - << ", end(inc)=" << (end - 1) - << ", buffer byte_length=" << buffer->byte_length() << ")"; - } - - return OkStatus(); -} - -Status ValidatingCommandBuffer::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - DVLOG(3) << "CommandBuffer::ExecutionBarrier(...)"; - - // TODO(benvanik): additional synchronization validation. - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer | - CommandCategory::kDispatch)); - - return impl_->ExecutionBarrier(source_stage_mask, target_stage_mask, - memory_barriers, buffer_barriers); -} - -Status ValidatingCommandBuffer::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - DVLOG(3) << "CommandBuffer::SignalEvent(...)"; - - // TODO(benvanik): additional synchronization validation. - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); - - return impl_->SignalEvent(event, source_stage_mask); -} - -Status ValidatingCommandBuffer::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - DVLOG(3) << "CommandBuffer::ResetEvent(...)"; - - // TODO(benvanik): additional synchronization validation. - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); - - return impl_->ResetEvent(event, source_stage_mask); -} - -Status ValidatingCommandBuffer::WaitEvents( - absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - DVLOG(3) << "CommandBuffer::WaitEvents(...)"; - - // TODO(benvanik): additional synchronization validation. - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); - - return impl_->WaitEvents(events, source_stage_mask, target_stage_mask, - memory_barriers, buffer_barriers); -} - -Status ValidatingCommandBuffer::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - DVLOG(3) << "CommandBuffer::FillBuffer(" << target_buffer->DebugString() - << ", " << target_offset << ", " << length << ", ??, " - << pattern_length << ")"; - - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); - RETURN_IF_ERROR( - ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible)); - RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite)); - RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer)); - RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length)); - - // Ensure the value length is supported. - if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Fill value length is not one of the supported values " - "(pattern_length=" - << pattern_length << ")"; - } - - // Ensure the offset and length have an alignment matching the value length. - if ((target_offset % pattern_length) != 0 || (length % pattern_length) != 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Fill offset and/or length do not match the natural alignment of " - "the fill value (target_offset=" - << target_offset << ", length=" << length - << ", pattern_length=" << pattern_length << ")"; - } - - return impl_->FillBuffer(target_buffer, target_offset, length, pattern, - pattern_length); -} - -Status ValidatingCommandBuffer::DiscardBuffer(Buffer* buffer) { - DVLOG(3) << "CommandBuffer::DiscardBuffer(" << buffer->DebugString() << ")"; - - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); - RETURN_IF_ERROR( - ValidateCompatibleMemoryType(buffer, MemoryType::kDeviceVisible)); - RETURN_IF_ERROR(ValidateUsage(buffer, BufferUsage::kNone)); - - return impl_->DiscardBuffer(buffer); -} - -Status ValidatingCommandBuffer::UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - DVLOG(3) << "CommandBuffer::UpdateBuffer(" << source_buffer << ", " - << source_offset << ", " << target_buffer->DebugString() << ", " - << target_offset << ", " << length << ")"; - - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); - RETURN_IF_ERROR( - ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible)); - RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite)); - RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer)); - RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length)); - - return impl_->UpdateBuffer(source_buffer, source_offset, target_buffer, - target_offset, length); -} - -Status ValidatingCommandBuffer::CopyBuffer(Buffer* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - DVLOG(3) << "CommandBuffer::CopyBuffer(" << source_buffer->DebugString() - << ", " << source_offset << ", " << target_buffer->DebugString() - << ", " << target_offset << ", " << length << ")"; - - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer)); - - // At least source or destination must be device-visible to enable - // host->device, device->host, and device->device. - // TODO(b/117338171): host->host copies. - if (!AnyBitSet(source_buffer->memory_type() & MemoryType::kDeviceVisible) && - !AnyBitSet(target_buffer->memory_type() & MemoryType::kDeviceVisible)) { - return PermissionDeniedErrorBuilder(IREE_LOC) - << "At least one buffer must be device-visible for a copy; " - "source_buffer=" - << MemoryTypeString(source_buffer->memory_type()) - << ", target_buffer=" - << MemoryTypeString(target_buffer->memory_type()); - } - - RETURN_IF_ERROR(ValidateAccess(source_buffer, MemoryAccess::kRead)); - RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite)); - RETURN_IF_ERROR(ValidateUsage(source_buffer, BufferUsage::kTransfer)); - RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer)); - RETURN_IF_ERROR(ValidateRange(source_buffer, source_offset, length)); - RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length)); - - // Check for overlap - just like memcpy we don't handle that. - if (Buffer::TestOverlap(source_buffer, source_offset, length, target_buffer, - target_offset, - length) != Buffer::Overlap::kDisjoint) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Source and target ranges overlap within the same buffer"; - } - - return impl_->CopyBuffer(source_buffer, source_offset, target_buffer, - target_offset, length); -} - -Status ValidatingCommandBuffer::Dispatch( - const DispatchRequest& dispatch_request) { - DVLOG(3) << "CommandBuffer::Dispatch(?)"; - - RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch)); - - // Validate all buffers referenced have compatible memory types, access - // rights, and usage. - for (const auto& binding : dispatch_request.bindings) { - RETURN_IF_ERROR(ValidateCompatibleMemoryType(binding.buffer, - MemoryType::kDeviceVisible)) - << "input buffer: " << MemoryAccessString(binding.access) << " " - << binding.buffer->DebugStringShort(); - RETURN_IF_ERROR(ValidateAccess(binding.buffer, binding.access)); - RETURN_IF_ERROR(ValidateUsage(binding.buffer, BufferUsage::kDispatch)); - // TODO(benvanik): validate it matches the executable expectations. - // TODO(benvanik): validate buffer contains enough data for shape+size. - } - - // TODO(benvanik): validate no aliasing? - - return impl_->Dispatch(dispatch_request); -} - -} // namespace - -ref_ptr<CommandBuffer> WrapCommandBufferWithValidation( - ref_ptr<CommandBuffer> impl) { - return make_ref<ValidatingCommandBuffer>(std::move(impl)); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/command_buffer_validation.h b/iree/hal/command_buffer_validation.h deleted file mode 100644 index 036f132..0000000 --- a/iree/hal/command_buffer_validation.h +++ /dev/null
@@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_COMMAND_BUFFER_VALIDATION_H_ -#define IREE_HAL_COMMAND_BUFFER_VALIDATION_H_ - -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { - -// Wraps an existing command buffer to provide in-depth validation during -// recording. This should be enabled whenever the command buffer is being driven -// by unsafe code or when early and readable diagnostics are needed. -ref_ptr<CommandBuffer> WrapCommandBufferWithValidation( - ref_ptr<CommandBuffer> impl); - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_H_
diff --git a/iree/hal/command_queue.h b/iree/hal/command_queue.h deleted file mode 100644 index 0e8eb5d..0000000 --- a/iree/hal/command_queue.h +++ /dev/null
@@ -1,119 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_COMMAND_QUEUE_H_ -#define IREE_HAL_COMMAND_QUEUE_H_ - -#include <cstdint> -#include <string> - -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "iree/base/bitfield.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/fence.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { - -// A batch of command buffers with synchronization information for submission. -struct SubmissionBatch { - // Semaphores that must be signaled prior to the execution of any command - // buffer in this submission. For TimelineSemaphores the specified payload - // must be reached or exceeded. - absl::Span<const SemaphoreValue> wait_semaphores; - - // Command buffers that will execute in this batch. - // The command buffers will begin execution in order but may complete out of - // order. - absl::Span<CommandBuffer* const> command_buffers; - - // Semaphores to signal after execution of all command buffers complete. - // TimelineSemaphores will be set to the maximum of the specified payload or - // their current payload. - absl::Span<const SemaphoreValue> signal_semaphores; -}; - -// Asynchronous command execution queue. -// -// CommandQueues may capture device status at Fence barriers, including -// information about device state such as thermal throttling. This information -// is a snapshot of the state at the time the fence was signaled and not -// necessarily live at the time of the application query. -// -// Command queues are thread-safe and submissions may occur from multiple -// threads. -class CommandQueue { - public: - virtual ~CommandQueue() = default; - - // Name of the queue used for logging purposes. - // Try to keep at 4 characters total for prettier logging. - const std::string& name() const { return name_; } - - // Capabilities of the command queue. - CommandCategoryBitfield supported_categories() const { - return supported_categories_; - } - - // Whether this queue may be used for transfer commands. - bool can_transfer() const { - return AllBitsSet(supported_categories_, CommandCategory::kTransfer); - } - - // Whether this queue may be used for dispatch commands. - bool can_dispatch() const { - return AllBitsSet(supported_categories_, CommandCategory::kDispatch); - } - - // Submits one or more command batches for execution on the queue. - // Dependencies between |batches| on BinarySemaphores must be sorted in order - // such that all semaphores are signaled prior to any waits on them. - // Dependencies between TimelineSemaphores may occur in any order. - // - // The provided |fence| will be signaled when all |batches| have retired. - virtual Status Submit(absl::Span<const SubmissionBatch> batches, - FenceValue fence) = 0; - inline Status Submit(const SubmissionBatch& batch, FenceValue fence) { - return Submit(absl::MakeConstSpan(&batch, 1), std::move(fence)); - } - - // Blocks until all outstanding requests have been completed. - // This is equivalent to having waited on all outstanding fences. - // Implicitly calls Flush to ensure delayed requests are scheduled. - // - // If the command queue has encountered an error during submission at any - // point it will be returned here (repeatedly). - virtual Status WaitIdle(absl::Time deadline) = 0; - inline Status WaitIdle(absl::Duration timeout) { - return WaitIdle(RelativeTimeoutToDeadline(timeout)); - } - inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); } - - protected: - CommandQueue(std::string name, CommandCategoryBitfield supported_categories) - : name_(std::move(name)), supported_categories_(supported_categories) {} - - const std::string name_; - const CommandCategoryBitfield supported_categories_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_COMMAND_QUEUE_H_
diff --git a/iree/hal/dawn/BUILD b/iree/hal/dawn/BUILD deleted file mode 100644 index 2c6293a..0000000 --- a/iree/hal/dawn/BUILD +++ /dev/null
@@ -1,72 +0,0 @@ -# HAL implementation using Dawn and SPIR-V executables. -# https://dawn.googlesource.com/dawn - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "dawn_device", - srcs = ["dawn_device.cc"], - hdrs = ["dawn_device.h"], - deps = [ - "//iree/base:memory", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_queue", - "//iree/hal:device", - "//iree/hal:executable_cache", - "//iree/hal:fence", - "//iree/hal/host:host_local_allocator", - "//third_party/dawn:dawn_headers", - "//third_party/dawn:dawn_native", - "//third_party/dawn:dawn_static_proc", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "dawn_driver", - srcs = ["dawn_driver.cc"], - hdrs = ["dawn_driver.h"], - deps = [ - ":dawn_device", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:device_info", - "//iree/hal:driver", - "//third_party/dawn:dawn_headers", - "//third_party/dawn:dawn_native", - "//third_party/dawn:dawn_static_proc", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - ], -) - -# TODO(scotttodd): Use SwiftShader to test Vulkan backend -cc_test( - name = "dawn_driver_test", - srcs = ["dawn_driver_test.cc"], - deps = [ - ":dawn_driver", - "//iree/base:status_matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "dawn_driver_module", - srcs = ["dawn_driver_module.cc"], - deps = [ - ":dawn_driver", - "//iree/base:init", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:driver_registry", - "@com_google_absl//absl/flags:flag", - ], - alwayslink = 1, -)
diff --git a/iree/hal/dawn/dawn_device.cc b/iree/hal/dawn/dawn_device.cc deleted file mode 100644 index 7888a72..0000000 --- a/iree/hal/dawn/dawn_device.cc +++ /dev/null
@@ -1,139 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dawn/dawn_device.h" - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/executable_cache.h" -#include "iree/hal/fence.h" - -namespace iree { -namespace hal { -namespace dawn { - -namespace { - -// ExecutableCache implementation that compiles but does nothing. -// This will be replaced with something functional soon. -class NoopExecutableCache final : public ExecutableCache { - public: - explicit NoopExecutableCache() {} - ~NoopExecutableCache() override = default; - - bool CanPrepareFormat(ExecutableFormat format) const override { - return false; - } - - StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override { - return UnimplementedErrorBuilder(IREE_LOC) << "PrepareExecutable NYI"; - } -}; - -} // namespace - -DawnDevice::DawnDevice(const DeviceInfo& device_info, - ::dawn::Device backend_device) - : Device(device_info), backend_device_(backend_device) { - IREE_TRACE_SCOPE0("DawnDevice::ctor"); - - // TODO(scotttodd): construct command queues, perform other initialization - - // Log some basic device info. - std::string backend_type_str; - auto* adapter = - static_cast<dawn_native::Adapter*>(device_info.driver_handle()); - switch (adapter->GetBackendType()) { - case dawn_native::BackendType::D3D12: - backend_type_str = "D3D12"; - break; - case dawn_native::BackendType::Metal: - backend_type_str = "Metal"; - break; - case dawn_native::BackendType::Null: - backend_type_str = "Null"; - break; - case dawn_native::BackendType::OpenGL: - backend_type_str = "OpenGL"; - break; - case dawn_native::BackendType::Vulkan: - backend_type_str = "Vulkan"; - break; - } - LOG(INFO) << "Created DawnDevice '" << device_info.name() << "' (" - << backend_type_str << ")"; -} - -DawnDevice::~DawnDevice() = default; - -std::shared_ptr<ExecutableCache> DawnDevice::CreateExecutableCache() { - return std::make_shared<NoopExecutableCache>(); -} - -StatusOr<ref_ptr<CommandBuffer>> DawnDevice::CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) { - return UnimplementedErrorBuilder(IREE_LOC) << "CreateCommandBuffer NYI"; -} - -StatusOr<ref_ptr<Event>> DawnDevice::CreateEvent() { - return UnimplementedErrorBuilder(IREE_LOC) << "CreateEvent NYI"; -} - -StatusOr<ref_ptr<BinarySemaphore>> DawnDevice::CreateBinarySemaphore( - bool initial_value) { - IREE_TRACE_SCOPE0("DawnDevice::CreateBinarySemaphore"); - - return UnimplementedErrorBuilder(IREE_LOC) - << "Binary semaphores not yet implemented"; -} - -StatusOr<ref_ptr<TimelineSemaphore>> DawnDevice::CreateTimelineSemaphore( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("DawnDevice::CreateTimelineSemaphore"); - - return UnimplementedErrorBuilder(IREE_LOC) - << "Timeline semaphores not yet implemented"; -} - -StatusOr<ref_ptr<Fence>> DawnDevice::CreateFence(uint64_t initial_value) { - IREE_TRACE_SCOPE0("DawnDevice::CreateFence"); - - return UnimplementedErrorBuilder(IREE_LOC) << "CreateFence NYI"; -} - -Status DawnDevice::WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) { - IREE_TRACE_SCOPE0("DawnDevice::WaitAllFences"); - - return UnimplementedErrorBuilder(IREE_LOC) << "WaitAllFences NYI"; -} - -StatusOr<int> DawnDevice::WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Time deadline) { - IREE_TRACE_SCOPE0("DawnDevice::WaitAnyFence"); - - return UnimplementedErrorBuilder(IREE_LOC) << "WaitAnyFence NYI"; -} - -Status DawnDevice::WaitIdle(absl::Time deadline) { - return UnimplementedErrorBuilder(IREE_LOC) << "WaitIdle"; -} - -} // namespace dawn -} // namespace hal -} // namespace iree
diff --git a/iree/hal/dawn/dawn_device.h b/iree/hal/dawn/dawn_device.h deleted file mode 100644 index 117c352..0000000 --- a/iree/hal/dawn/dawn_device.h +++ /dev/null
@@ -1,78 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DAWN_DAWN_DEVICE_H_ -#define IREE_HAL_DAWN_DAWN_DEVICE_H_ - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/device.h" -#include "iree/hal/host/host_local_allocator.h" -#include "third_party/dawn/src/include/dawn/dawncpp.h" -#include "third_party/dawn/src/include/dawn_native/DawnNative.h" - -namespace iree { -namespace hal { -namespace dawn { - -class DawnDevice final : public Device { - public: - explicit DawnDevice(const DeviceInfo& device_info, - ::dawn::Device backend_device); - ~DawnDevice() override; - - Allocator* allocator() const override { return &allocator_; } - - absl::Span<CommandQueue*> dispatch_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - absl::Span<CommandQueue*> transfer_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - std::shared_ptr<ExecutableCache> CreateExecutableCache() override; - - StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr<ref_ptr<Event>> CreateEvent() override; - - StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( - bool initial_value) override; - StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( - uint64_t initial_value) override; - - StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override; - Status WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) override; - StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Time deadline) override; - - Status WaitIdle(absl::Time deadline) override; - - private: - mutable HostLocalAllocator allocator_; - mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_; - - ::dawn::Device backend_device_; -}; - -} // namespace dawn -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DAWN_DAWN_DEVICE_H_
diff --git a/iree/hal/dawn/dawn_driver.cc b/iree/hal/dawn/dawn_driver.cc deleted file mode 100644 index 9e9067a..0000000 --- a/iree/hal/dawn/dawn_driver.cc +++ /dev/null
@@ -1,120 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dawn/dawn_driver.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/dawn/dawn_device.h" -#include "iree/hal/device_info.h" - -namespace iree { -namespace hal { -namespace dawn { - -namespace { - -// Populates device information from the given dawn_native::Adapter. -StatusOr<DeviceInfo> PopulateDeviceInfo(dawn_native::Adapter* adapter) { - // TODO(scotttodd): Query these for each backend or implement? - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - - // TODO(scotttodd): more clever/sanitized device naming. - std::string device_name = absl::StrCat("dawn-", adapter->GetPCIInfo().name); - - return DeviceInfo(device_name, supported_features, - reinterpret_cast<void*>(adapter)); -} - -} // namespace - -DawnDriver::DawnDriver() : Driver("dawn") { - dawn_instance_ = absl::make_unique<dawn_native::Instance>(); -} - -DawnDriver::~DawnDriver() = default; - -StatusOr<std::vector<DeviceInfo>> DawnDriver::EnumerateAvailableDevices() { - IREE_TRACE_SCOPE0("DawnDriver::EnumerateAvailableDevices"); - - if (dawn_backend_adapters_.empty()) { - // Discover adapters (i.e. devices and their associated backend APIs). - // Retain the list of adapters so pointers are valid for the lifetime of - // this object. - dawn_instance_->DiscoverDefaultAdapters(); - dawn_backend_adapters_ = dawn_instance_->GetAdapters(); - } else { - // Assume that the list of adapters does not change. This is not guaranteed - // to be true, but we also don't want to invalidate pointers by requesting - // a new list each time. If the list of available devices would change, - // tearing down and creating a new DawnDriver may be your best option. - } - - // Convert to our HAL structure. - std::vector<DeviceInfo> device_infos; - device_infos.reserve(dawn_backend_adapters_.size()); - for (auto& adapter : dawn_backend_adapters_) { - // TODO(scotttodd): if we fail should we just ignore the device in the list? - ASSIGN_OR_RETURN(auto device_info, PopulateDeviceInfo(&adapter)); - device_infos.push_back(std::move(device_info)); - } - return device_infos; -} - -StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDefaultDevice() { - IREE_TRACE_SCOPE0("DawnDriver::CreateDefaultDevice"); - - // Query available devices. - ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices()); - if (available_devices.empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No devices are available"; - } - - // Create the first non-null device, if any. - for (const auto& device : available_devices) { - auto* adapter = static_cast<dawn_native::Adapter*>(device.driver_handle()); - if (adapter->GetBackendType() != dawn_native::BackendType::Null) { - return CreateDevice(device); - } - } - - // Otherwise create the first null device. - return CreateDevice(available_devices.front()); -} - -StatusOr<std::shared_ptr<Device>> DawnDriver::CreateDevice( - const DeviceInfo& device_info) { - IREE_TRACE_SCOPE0("DawnDriver::CreateDevice"); - - auto* adapter = - static_cast<dawn_native::Adapter*>(device_info.driver_handle()); - ::DawnDevice c_backend_device = adapter->CreateDevice(); - if (!c_backend_device) { - return InternalErrorBuilder(IREE_LOC) << "Failed to create a Dawn device"; - } - DawnProcTable backend_procs = dawn_native::GetProcs(); - dawnSetProcs(&backend_procs); - ::dawn::Device backend_device = ::dawn::Device::Acquire(c_backend_device); - - return std::make_shared<DawnDevice>(device_info, backend_device); -} - -} // namespace dawn -} // namespace hal -} // namespace iree
diff --git a/iree/hal/dawn/dawn_driver.h b/iree/hal/dawn/dawn_driver.h deleted file mode 100644 index 4c79c25..0000000 --- a/iree/hal/dawn/dawn_driver.h +++ /dev/null
@@ -1,50 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DAWN_DAWN_DRIVER_H_ -#define IREE_HAL_DAWN_DAWN_DRIVER_H_ - -#include <memory> -#include <vector> - -#include "iree/hal/driver.h" -#include "third_party/dawn/src/include/dawn/dawncpp.h" -#include "third_party/dawn/src/include/dawn_native/DawnNative.h" - -namespace iree { -namespace hal { -namespace dawn { - -class DawnDriver final : public Driver { - public: - DawnDriver(); - ~DawnDriver() override; - - StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; - - StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override; - - StatusOr<std::shared_ptr<Device>> CreateDevice( - const DeviceInfo& device_info) override; - - private: - std::unique_ptr<dawn_native::Instance> dawn_instance_; - std::vector<dawn_native::Adapter> dawn_backend_adapters_; -}; - -} // namespace dawn -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DAWN_DAWN_DRIVER_H_
diff --git a/iree/hal/dawn/dawn_driver_module.cc b/iree/hal/dawn/dawn_driver_module.cc deleted file mode 100644 index 2fce4bf..0000000 --- a/iree/hal/dawn/dawn_driver_module.cc +++ /dev/null
@@ -1,41 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <memory> - -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/dawn/dawn_driver.h" -#include "iree/hal/driver_registry.h" - -namespace iree { -namespace hal { -namespace dawn { -namespace { - -StatusOr<std::shared_ptr<Driver>> CreateDawnDriver() { - return std::make_shared<DawnDriver>(); -} - -} // namespace -} // namespace dawn -} // namespace hal -} // namespace iree - -IREE_REGISTER_MODULE_INITIALIZER(iree_hal_dawn_driver, { - QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( - "dawn", ::iree::hal::dawn::CreateDawnDriver)); -}); -IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_dawn_driver);
diff --git a/iree/hal/dawn/dawn_driver_test.cc b/iree/hal/dawn/dawn_driver_test.cc deleted file mode 100644 index 07a1561..0000000 --- a/iree/hal/dawn/dawn_driver_test.cc +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/dawn/dawn_driver.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status_matchers.h" - -namespace iree { -namespace hal { -namespace dawn { -namespace { - -TEST(DawnDriverTest, CreateDefaultDevice) { - DawnDriver dawn_driver; - ASSERT_OK_AND_ASSIGN(auto default_device, dawn_driver.CreateDefaultDevice()); -} - -TEST(DawnDriverTest, EnumerateDevicesAndCreate) { - DawnDriver dawn_driver; - - ASSERT_OK_AND_ASSIGN(auto available_devices, - dawn_driver.EnumerateAvailableDevices()); - ASSERT_GT(available_devices.size(), 0); - - ASSERT_OK_AND_ASSIGN(auto first_device, - dawn_driver.CreateDevice(available_devices[0])); -} - -} // namespace -} // namespace dawn -} // namespace hal -} // namespace iree
diff --git a/iree/hal/deferred_buffer.cc b/iree/hal/deferred_buffer.cc deleted file mode 100644 index 3414436..0000000 --- a/iree/hal/deferred_buffer.cc +++ /dev/null
@@ -1,162 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/deferred_buffer.h" - -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -DeferredBuffer::DeferredBuffer(Allocator* allocator, - MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, - device_size_t byte_length) - : Buffer(allocator, memory_type, allowed_access, usage, 0, 0, byte_length) { -} - -DeferredBuffer::~DeferredBuffer() = default; - -Status DeferredBuffer::GrowByteLength(device_size_t new_byte_length) { - if (parent_buffer_) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Attempting to set min allocation size while bound to an " - "allocation"; - } - if (byte_length_ != kWholeBuffer && new_byte_length < byte_length_) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Attempting to shrink a buffer to " << new_byte_length - << " when it has a minimum size of " << byte_length_; - } - byte_length_ = new_byte_length; - return OkStatus(); -} - -Status DeferredBuffer::BindAllocation(ref_ptr<Buffer> allocated_buffer, - device_size_t byte_offset, - device_size_t byte_length) { - // We can only be bound to allocations that are compatible with our specified - // allocator and usage. - if (!allocator_->CanUseBuffer(allocated_buffer.get(), usage())) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Allocation is not compatible with the allocator specified for " - "the deferred buffer"; - } - - // Calculate the range in the allocated_buffer that we are interested in. - RETURN_IF_ERROR(Buffer::CalculateRange(0, allocated_buffer->byte_length(), - byte_offset, byte_length, &byte_offset, - &byte_length)); - - // Verify that we have enough bytes for what we've promised. - if (byte_length < byte_length_) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Allocation range is too small; min_allocation_size=" - << byte_length_ << " but the range of " << byte_offset << "-" - << (byte_offset + byte_length - 1) << " (" << byte_length - << "b) is too small"; - } - - allocated_buffer_ = allocated_buffer.get(); - parent_buffer_ = std::move(allocated_buffer); - byte_offset_ = byte_offset; - return OkStatus(); -} - -void DeferredBuffer::ResetAllocation() { - allocated_buffer_ = this; - parent_buffer_.reset(); - byte_offset_ = 0; -} - -StatusOr<Buffer*> DeferredBuffer::ResolveAllocation() const { - // If you get errors here then someone allocated the buffer with - // MemoryType::kTransient and you are trying to use it outside of the time - // it is actually allocated (such as during CommandBuffer evaluation). If - // you need to use the buffer in non-transient ways then allocate the buffer - // without the MemoryType::kTransient flag. - if (!parent_buffer_) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Attempting to use a transient buffer prior to allocation: " - << DebugString(); - } - return parent_buffer_.get(); -} - -Status DeferredBuffer::FillImpl(device_size_t byte_offset, - device_size_t byte_length, const void* pattern, - device_size_t pattern_length) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->FillImpl(byte_offset, byte_length, pattern, - pattern_length); -} - -Status DeferredBuffer::ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->ReadDataImpl(source_offset, data, data_length); -} - -Status DeferredBuffer::WriteDataImpl(device_size_t target_offset, - const void* data, - device_size_t data_length) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->WriteDataImpl(target_offset, data, data_length); -} - -Status DeferredBuffer::CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->CopyDataImpl(target_offset, source_buffer, - source_offset, data_length); -} - -Status DeferredBuffer::MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->MapMemoryImpl(mapping_mode, memory_access, - local_byte_offset, local_byte_length, - out_data); -} - -Status DeferredBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, - void* data) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->UnmapMemoryImpl(local_byte_offset, local_byte_length, - data); -} - -Status DeferredBuffer::InvalidateMappedMemoryImpl( - device_size_t local_byte_offset, device_size_t local_byte_length) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->InvalidateMappedMemoryImpl(local_byte_offset, - local_byte_length); -} - -Status DeferredBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - ASSIGN_OR_RETURN(auto* allocated_buffer, ResolveAllocation()); - return allocated_buffer->FlushMappedMemoryImpl(local_byte_offset, - local_byte_length); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/deferred_buffer.h b/iree/hal/deferred_buffer.h deleted file mode 100644 index aeb19ad..0000000 --- a/iree/hal/deferred_buffer.h +++ /dev/null
@@ -1,106 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEFERRED_BUFFER_H_ -#define IREE_HAL_DEFERRED_BUFFER_H_ - -#include <cstddef> -#include <memory> -#include <utility> - -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -// A Buffer that can have its underlying allocation changed at runtime. -// Unbound buffers act as a way to logically group dependent ranges of memory -// without needing to have allocated that memory yet. -// -// Usage: -// // Setup two spans referencing ranges of a deferred buffer. -// auto deferred_buffer = std::make_shared<DeferredBuffer>(..., 200); -// ASSIGN_OR_RETURN(auto span0, Buffer::Subspan(deferred_buffer, 0, 100)); -// ASSIGN_OR_RETURN(auto span1, Buffer::Subspan(deferred_buffer, 100, 100)); -// -// // Attempting to access |deferred_buffer| or |span0| or |span1| will fail. -// // ERROR: span0->Fill(false); -// -// // Now allocate a real buffer to serve as storage for the data. -// ASSIGN_OR_RETURN(auto allocated_buffer, Buffer::Allocate(..., 200)); -// RETURN_IF_ERROR(deferred_buffer->BindAllocation( -// allocated_buffer, 0, kWholeBuffer)); -// -// // And now we can use the spans. -// RETURN_IF_ERROR(span0->Fill(false)); -// -// // If at some point we want to detach the buffer from the allocation (so we -// // can use a different allocation, reuse the memory, etc). -// deferred_buffer->ResetAllocation(); -// -// Thread-compatible. Attempting to rebind the allocation while other threads -// are using the buffer will lead to undefined behavior. -class DeferredBuffer : public Buffer { - public: - DeferredBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t byte_length); - ~DeferredBuffer() override; - - // Grows the minimum allocation size of the buffer to |new_byte_length|. - // Attempting to bind an allocation less than this size will fail. This must - // only be called when the buffer is not bound to an allocation. - Status GrowByteLength(device_size_t new_byte_length); - - // Binds or rebinds the deferred buffer to an allocated buffer. - Status BindAllocation(ref_ptr<Buffer> allocated_buffer, - device_size_t byte_offset, device_size_t byte_length); - - // Resets the deferred buffer to have no binding. - void ResetAllocation(); - - private: - // Resolves the allocated buffer that this subspan references into. - // This will fail if the buffer has not yet been bound to an allocation or - // the allocated buffer has not been committed. - StatusOr<Buffer*> ResolveAllocation() const; - - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override; - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override; - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override; - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override; - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override; - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override; - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEFERRED_BUFFER_H_
diff --git a/iree/hal/deferred_buffer_test.cc b/iree/hal/deferred_buffer_test.cc deleted file mode 100644 index dd8b968..0000000 --- a/iree/hal/deferred_buffer_test.cc +++ /dev/null
@@ -1,174 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/deferred_buffer.h" - -#include "absl/memory/memory.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/heap_buffer.h" -#include "iree/hal/testing/mock_allocator.h" - -namespace iree { -namespace hal { -namespace { - -using ::iree::hal::testing::MockAllocator; -using ::testing::_; -using ::testing::Return; - -// Tests properties of unbound buffers. -TEST(DeferredBufferTest, Unbound) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - EXPECT_EQ(&allocator, deferred_buffer->allocator()); - EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(0, deferred_buffer->allocation_size()); - EXPECT_EQ(0, deferred_buffer->byte_offset()); - EXPECT_EQ(100, deferred_buffer->byte_length()); -} - -// Tests that binding verifies allocators are compatible. -TEST(DeferredBufferTest, AllocatorCheck) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL( - allocator, - CanUseBufferLike(real_buffer->allocator(), real_buffer->memory_type(), - real_buffer->usage(), BufferUsage::kAll)) - .WillOnce(Return(false)); - EXPECT_TRUE(IsInvalidArgument( - deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100))); -} - -// Tests that binding verifies allocation sizes. -TEST(DeferredBufferTest, SizeCheck) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - - EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 100)); - EXPECT_EQ(256, deferred_buffer->allocation_size()); - EXPECT_EQ(10, deferred_buffer->byte_offset()); - EXPECT_EQ(100, deferred_buffer->byte_length()); - EXPECT_OK( - deferred_buffer->BindAllocation(add_ref(real_buffer), 10, kWholeBuffer)); - EXPECT_EQ(256, deferred_buffer->allocation_size()); - EXPECT_EQ(10, deferred_buffer->byte_offset()); - EXPECT_EQ(100, deferred_buffer->byte_length()); - - EXPECT_TRUE(IsOutOfRange( - deferred_buffer->BindAllocation(add_ref(real_buffer), 200, 100))); - EXPECT_TRUE(IsOutOfRange(deferred_buffer->BindAllocation(add_ref(real_buffer), - 200, kWholeBuffer))); - EXPECT_TRUE(IsOutOfRange( - deferred_buffer->BindAllocation(add_ref(real_buffer), 10, 10))); -} - -// Tests resizing buffers after they have been allocated. -TEST(DeferredBufferTest, Resizing) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - - // Grow. - EXPECT_EQ(100, deferred_buffer->byte_length()); - EXPECT_OK(deferred_buffer->GrowByteLength(150)); - EXPECT_EQ(150, deferred_buffer->byte_length()); - - // Shrinking should fail. - EXPECT_TRUE(IsInvalidArgument(deferred_buffer->GrowByteLength(5))); - - // Growing should fail if bound. - EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 150)); - EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->GrowByteLength(100))); -} - -// Tests binding and rebinding behavior. -TEST(DeferredBufferTest, Rebinding) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - - // Safe to reset when not bound. - deferred_buffer->ResetAllocation(); - EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(0, deferred_buffer->allocation_size()); - - EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100)); - EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(256, deferred_buffer->allocation_size()); - deferred_buffer->ResetAllocation(); - EXPECT_EQ(deferred_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(0, deferred_buffer->allocation_size()); - EXPECT_OK(deferred_buffer->BindAllocation(add_ref(real_buffer), 0, 100)); - EXPECT_EQ(real_buffer.get(), deferred_buffer->allocated_buffer()); - EXPECT_EQ(256, deferred_buffer->allocation_size()); -} - -// Tests normal usage of bound buffers. -TEST(DeferredBufferTest, BoundUsage) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - auto real_buffer = - HeapBuffer::Allocate(MemoryType::kHostLocal, BufferUsage::kAll, 256); - EXPECT_CALL(allocator, CanUseBufferLike(_, _, _, _)) - .WillRepeatedly(Return(true)); - EXPECT_OK(deferred_buffer->BindAllocation(std::move(real_buffer), 0, 100)); - - EXPECT_FALSE(deferred_buffer->DebugString().empty()); - EXPECT_FALSE(deferred_buffer->DebugStringShort().empty()); - - EXPECT_OK(deferred_buffer->Fill8(0, 10, 0xFF)); -} - -// Tests that unbound buffers fail to perform any buffer actions. -TEST(DeferredBufferTest, UnboundUsage) { - MockAllocator allocator; - auto deferred_buffer = absl::make_unique<DeferredBuffer>( - &allocator, MemoryType::kHostLocal, MemoryAccess::kAll, BufferUsage::kAll, - 100); - EXPECT_FALSE(deferred_buffer->DebugString().empty()); - EXPECT_FALSE(deferred_buffer->DebugStringShort().empty()); - - EXPECT_TRUE(IsFailedPrecondition(deferred_buffer->Fill8(0, 10, 0xFF))); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/device.h b/iree/hal/device.h deleted file mode 100644 index ff97a8a..0000000 --- a/iree/hal/device.h +++ /dev/null
@@ -1,165 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEVICE_H_ -#define IREE_HAL_DEVICE_H_ - -#include <memory> - -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/device_info.h" -#include "iree/hal/event.h" -#include "iree/hal/executable_cache.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { - -class Device { - public: - virtual ~Device() = default; - - // Information about device capabilities. - const DeviceInfo& info() const { return device_info_; } - - // TODO(benvanik): status (thermal, power mode, etc). - - // TODO(benvanik): throttling adjustment/power profile. - - // TODO(benvanik): control (suspend/resume, delay, etc). - - // An allocator providing buffers usable by the device. - // This allocator may be shared with other devices in the same family. - virtual Allocator* allocator() const = 0; - - // Returns a list of all general-purpose dispatch queues provided by the - // device. In general these map 1:1 with independent execution contexts, - // though some devices may hide that and expose only a single queue that is - // scheduled internally. - virtual absl::Span<CommandQueue*> dispatch_queues() const = 0; - - // Returns a list of transfer queues provided by the device. These queues may - // perform transfer operations asynchronously with respect to execution on the - // dispatch queues. For large sequences of transfer operations always prefer - // using one of these queues. - // Note that if the device does not support a dedicated transfer queue this - // list may be the same as (or a subset of) dispatch_queues. - virtual absl::Span<CommandQueue*> transfer_queues() const = 0; - - // TODO(b/137153339): accept initial cache data. - // Creates a device-specific cache for executables prepared for dispatch. - // The cache manages executable compilation, caching (on disk or in memory), - // and lifetime. Users can decide to use one or more caches to allow differing - // lifetimes (such as unloading modules), persistent on disk caching of only - // specific hot executables, etc. - // - // Returns a thread-safe cache that must remain alive until all executables - // using the cache are no longer in-flight. - virtual std::shared_ptr<ExecutableCache> CreateExecutableCache() = 0; - - // Creates a command buffer for recording commands to submit to queues owned - // by this device. The command buffer may come from a pool but will be reset - // prior to being returned to the caller. - virtual StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) = 0; - - // Creates an event for recording into command buffers. - // The returned event object is only usable with this device and events must - // only be used to synchronize within the same queue. - virtual StatusOr<ref_ptr<Event>> CreateEvent() = 0; - - // Creates a binary semaphore that can be used with command queues owned by - // this device. To use the semaphores with other devices or instances they - // must first be exported. - virtual StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( - bool initial_value) = 0; - - // Creates a timeline semaphore that can be used with command queues owned by - // this device. To use the semaphores with other devices or instances they - // must first be exported. - virtual StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( - uint64_t initial_value) = 0; - - // Creates a fence that can be used with command queues owned by this device. - // To use the fences with other devices or instances they must first be - // exported. - virtual StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) = 0; - - // TODO(benvanik): import/export semaphore utilities. - // TODO(benvanik): import/export fence utilities. - // TODO(benvanik): fences to wait handles. - - // Blocks the caller until all passed |fences| reach or exceed the specified - // payload values or the |deadline| elapses. All |fences| must be created from - // this device (or be imported into it). - // - // Returns success if the wait is successful and all fences have been - // signaled. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all fences - // having been signaled. Note that a subset of the |fences| may have been - // signaled and each can be queried to see which ones. - virtual Status WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) = 0; - inline Status WaitAllFences(absl::Span<const FenceValue> fences, - absl::Duration timeout) { - return WaitAllFences(fences, RelativeTimeoutToDeadline(timeout)); - } - - // Blocks the caller until at least one of the |fences| reaches or exceeds the - // specified payload value or the |deadline| elapses. All |fences| must be - // created from this device (or be imported into it). - // - // Returns an arbitrary index into |fences| of a fence that was signaled. Note - // that more than one fence may have been signaled and all of the other - // |fences| should be queried or waited on again until waits for them - // succeed. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any fences - // having been signaled. - virtual StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Time deadline) = 0; - inline StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Duration timeout) { - return WaitAnyFence(fences, RelativeTimeoutToDeadline(timeout)); - } - - // Blocks until all outstanding requests on all queues have been - // completed. This is equivalent to having waited on all outstanding - // fences. - virtual Status WaitIdle(absl::Time deadline) = 0; - inline Status WaitIdle(absl::Duration timeout) { - return WaitIdle(RelativeTimeoutToDeadline(timeout)); - } - inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); } - - protected: - explicit Device(DeviceInfo device_info) - : device_info_(std::move(device_info)) {} - - private: - const DeviceInfo device_info_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEVICE_H_
diff --git a/iree/hal/device_info.h b/iree/hal/device_info.h deleted file mode 100644 index 7a6c6d3..0000000 --- a/iree/hal/device_info.h +++ /dev/null
@@ -1,90 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEVICE_INFO_H_ -#define IREE_HAL_DEVICE_INFO_H_ - -#include <cstdint> -#include <string> -#include <utility> - -#include "iree/base/bitfield.h" - -namespace iree { -namespace hal { - -// Describes features supported by the device. -// These flags indicate the availability of features that may be enabled at the -// request of the calling application. Note that certain features may disable -// runtime optimizations or require compilation flags to ensure the required -// metadata is present in executables. -enum class DeviceFeature : uint32_t { - kNone = 0, - - // Device supports executable debugging. - // When present executables *may* be compiled with - // ExecutableCachingMode::kEnableDebugging and will have usable debugging - // related methods. Note that if the input executables do not have embedded - // debugging information they still may not be able to perform disassembly or - // fine-grained breakpoint insertion. - kDebugging = 1 << 0, - - // Device supports executable coverage information. - // When present executables *may* be compiled with - // ExecutableCachingMode::kEnableCoverage and will produce coverage buffers - // during dispatch. Note that input executables must have partial embedded - // debug information to allow mapping back to source offsets. - kCoverage = 1 << 1, - - // Device supports executable and command queue profiling. - // When present executables *may* be compiled with - // ExecutableCachingMode::kEnableProfiling and will produce profiling buffers - // during dispatch. Note that input executables must have partial embedded - // debug information to allow mapping back to source offsets. - kProfiling = 1 << 2, -}; -IREE_BITFIELD(DeviceFeature); -using DeviceFeatureBitfield = DeviceFeature; - -// TODO(benvanik): device info (caps, physical mappings, etc). -class DeviceInfo { - public: - DeviceInfo(std::string name, DeviceFeatureBitfield supported_features, - void* driver_handle = nullptr) - : name_(std::move(name)), - supported_features_(supported_features), - driver_handle_(driver_handle) {} - - const std::string& name() const { return name_; } - - // Features supported by the device. - DeviceFeatureBitfield supported_features() const { - return supported_features_; - } - - // Opaque handle used by drivers to correlate this device with their internal - // listing. This handle will not be valid across driver instances or outside - // of the current process. - void* driver_handle() const { return driver_handle_; } - - private: - const std::string name_; - const DeviceFeatureBitfield supported_features_; - void* driver_handle_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEVICE_INFO_H_
diff --git a/iree/hal/device_manager.cc b/iree/hal/device_manager.cc deleted file mode 100644 index 6f8d9f4..0000000 --- a/iree/hal/device_manager.cc +++ /dev/null
@@ -1,201 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/device_manager.h" - -#include <algorithm> - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/heap_buffer.h" - -namespace iree { -namespace hal { - -DeviceManager::DeviceManager() = default; - -DeviceManager::~DeviceManager() { - IREE_TRACE_SCOPE0("DeviceManager::dtor"); - WaitIdle().IgnoreError(); -} - -Status DeviceManager::RegisterDevice(std::shared_ptr<Device> device) { - IREE_TRACE_SCOPE0("DeviceManager::RegisterDevice"); - absl::MutexLock lock(&device_mutex_); - if (std::find(devices_.begin(), devices_.end(), device) != devices_.end()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Device already registered"; - } - devices_.push_back(std::move(device)); - return OkStatus(); -} - -Status DeviceManager::UnregisterDevice(Device* device) { - IREE_TRACE_SCOPE0("DeviceManager::UnregisterDevice"); - absl::MutexLock lock(&device_mutex_); - auto it = std::find_if(devices_.begin(), devices_.end(), - [device](const std::shared_ptr<Device>& other_device) { - return device == other_device.get(); - }); - if (it == devices_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Device not registered"; - } - devices_.erase(it); - return OkStatus(); -} - -StatusOr<DevicePlacement> DeviceManager::ResolvePlacement( - const PlacementSpec& placement_spec) const { - IREE_TRACE_SCOPE0("DeviceManager::ResolvePlacement"); - absl::MutexLock lock(&device_mutex_); - if (devices_.empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No devices registered"; - } - - // TODO(benvanik): multiple devices and placement. - QCHECK_EQ(devices_.size(), 1) - << "Multiple devices not yet supported (need placement)"; - DevicePlacement device_placement; - device_placement.device = devices_.front(); - - return device_placement; -} - -StatusOr<Allocator*> DeviceManager::FindCompatibleAllocator( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - absl::Span<const DevicePlacement> device_placements) const { - IREE_TRACE_SCOPE0("DeviceManager::FindCompatibleAllocator"); - if (device_placements.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No placements provided"; - } - - // Find the first allocator. As we only return an allocator if all placements - // are compatible we'll compare allocator[0] against allocator[1,N]. - Allocator* some_allocator = nullptr; - for (const auto& device_placement : device_placements) { - auto* allocator = device_placement.device->allocator(); - if (!some_allocator) { - some_allocator = allocator; - continue; - } - // NOTE: as there can be asymmetry between usage restrictions (A can use B - // but B cannot use A) we have to compare both directions. - if (!some_allocator->CanUseBufferLike(allocator, memory_type, buffer_usage, - buffer_usage) || - !allocator->CanUseBufferLike(some_allocator, memory_type, buffer_usage, - buffer_usage)) { - // Allocators are not compatible. - return NotFoundErrorBuilder(IREE_LOC) - << "No single allocator found that is compatible with all " - "placements"; - } - } - return some_allocator; -} - -StatusOr<ref_ptr<Buffer>> DeviceManager::TryAllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - IREE_TRACE_SCOPE("DeviceManager::TryAllocateDeviceVisibleBuffer:size", int) - (static_cast<int>(allocation_size)); - if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Host-local buffers require the kHostLocal bit: " - << MemoryTypeString(memory_type); - } - - // Strip kDeviceVisible as we conditionally add it based on support. - memory_type &= ~MemoryType::kDeviceVisible; - - // Find an allocator that works for device-visible buffers. - // If this fails we'll fall back to allocation a non-device-visible buffer. - auto allocator_or = - FindCompatibleAllocator(memory_type | MemoryType::kDeviceVisible, - buffer_usage, device_placements); - if (allocator_or.ok()) { - return allocator_or.ValueOrDie()->Allocate( - memory_type | MemoryType::kDeviceVisible, buffer_usage, - allocation_size); - } - - // Fallback to allocating a host-local buffer. - return HeapBuffer::Allocate(memory_type, buffer_usage, allocation_size); -} - -StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceVisibleBuffer:size", int) - (static_cast<int>(allocation_size)); - if (!AnyBitSet(memory_type & MemoryType::kHostLocal)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Host-local buffers require the kHostLocal bit: " - << MemoryTypeString(memory_type); - } - - // Always use device-visible. - memory_type |= MemoryType::kDeviceVisible; - - // Find an allocator that works for device-visible buffers. - ASSIGN_OR_RETURN( - auto* allocator, - FindCompatibleAllocator(memory_type, buffer_usage, device_placements)); - return allocator->Allocate(memory_type, buffer_usage, allocation_size); -} - -StatusOr<ref_ptr<Buffer>> DeviceManager::AllocateDeviceLocalBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - IREE_TRACE_SCOPE("DeviceManager::AllocateDeviceLocalBuffer:size", int) - (static_cast<int>(allocation_size)); - if (!AnyBitSet(memory_type & MemoryType::kDeviceLocal)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Device-local buffers require the kDeviceLocal bit: " - << MemoryTypeString(memory_type); - } - - // Find an allocator that works for device-local buffers. - ASSIGN_OR_RETURN( - auto* allocator, - FindCompatibleAllocator(memory_type, buffer_usage, device_placements)); - return allocator->Allocate(memory_type, buffer_usage, allocation_size); -} - -Status DeviceManager::Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, - absl::Time deadline, FenceValue fence) { - IREE_TRACE_SCOPE0("DeviceManager::Submit"); - return command_queue->Submit(batches, fence); -} - -Status DeviceManager::Flush() { - IREE_TRACE_SCOPE0("DeviceManager::Flush"); - return OkStatus(); -} - -Status DeviceManager::WaitIdle(absl::Time deadline) { - IREE_TRACE_SCOPE0("DeviceManager::WaitIdle"); - absl::MutexLock lock(&device_mutex_); - for (const auto& device : devices_) { - RETURN_IF_ERROR(device->WaitIdle(deadline)); - } - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/device_manager.h b/iree/hal/device_manager.h deleted file mode 100644 index b6a8783..0000000 --- a/iree/hal/device_manager.h +++ /dev/null
@@ -1,209 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DEVICE_MANAGER_H_ -#define IREE_HAL_DEVICE_MANAGER_H_ - -#include <vector> - -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/device.h" -#include "iree/hal/device_placement.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/fence.h" - -namespace iree { -namespace hal { - -// Specifies how devices should be resolved to DevicePlacements. -// Most fields are optional and when not included will be ignored. -struct PlacementSpec { - // TODO(benvanik): other requirements (features/caps, power, etc). - - // A list of executable formats that the placement should support. - // If more than one format is provided any device satisfying at least one - // will be considered for placement. The formats can be sorted in descending - // priority order to prefer the first available format in the case of ties. - absl::Span<const ExecutableFormat> available_formats; -}; - -// Manages device lifetime and placement resolution. -// Optionally the DeviceManager may be used for automatic device selection for -// allocations or batched submissions, however this is not required if specific -// devices and scheduling behavior are known to the caller. -// -// Thread-safe. Note that callers must ensure that unregistered devices are kept -// alive for as long as any commands are in-flight that may be using them. -class DeviceManager final { - public: - DeviceManager(); - ~DeviceManager(); - - // Registers a device with the manager. - // The device will be used to resolve placements. Any placements resolved - // prior to the addition of the device will need to be refreshed by the caller - // if they want to make use of the new device. - Status RegisterDevice(std::shared_ptr<Device> device); - - // Unregisters a device with the manager. - // Placements that resolved to the device prior to unregistering will remain - // valid for that device. Callers will need to refresh the placements to - // ensure the device stops being used. - Status UnregisterDevice(Device* device); - - // TODO(benvanik): dispatch info + requirements + etc -> DevicePlacement. - - // Resolves a placement spec to a device placement based on the registered - // devices. - // If the placement is not fully specified the device and queue may be chosen - // at random. See PlacementSpec for more information about resolution and - // ranking. - StatusOr<DevicePlacement> ResolvePlacement( - const PlacementSpec& placement_spec) const; - - // Finds an allocator that can allocate buffers of the given |memory_type| and - // |buffer_usage| such that the buffers can be used interchangebly. - // Fails if there is no Allocator that can satisfy that requirement. - StatusOr<Allocator*> FindCompatibleAllocator( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - absl::Span<const DevicePlacement> device_placements) const; - - // Tries to allocate a host-local buffer that _may_ be optimal for use with - // the given |device_placements| and _may_ be device-visible. The buffer can - // be used for staging uploads to device-local buffers and is useful for times - // when the buffer will be used more on the host than the device. If a buffer - // never needs to be used with a device prefer instead - // Allocator::host_local()::Allocate. - // - // Returns a buffer even if it's not possible to satisfy the requested - // |buffer_usage| for the |device_placements| at the cost of a run-time - // performance hit. - StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements); - StatusOr<ref_ptr<Buffer>> TryAllocateDeviceVisibleBuffer( - BufferUsageBitfield buffer_usage, device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - return TryAllocateDeviceVisibleBuffer( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage, - allocation_size, device_placements); - } - - // Allocates a host-local buffer that is optimal for use on the host but is - // usable by the given |device_placements| (at a possible performance - // penalty). The buffer can be used for staging uploads to device-local - // buffers and is useful for times when the buffer will be used more on the - // host than the device. If a buffer never needs to be used with a device - // prefer instead HeapBuffer::Allocate. - // - // Fails if it is not possible to allocate and satisfy all |device_placements| - // for the requested |buffer_usage|. - StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements); - StatusOr<ref_ptr<Buffer>> AllocateDeviceVisibleBuffer( - BufferUsageBitfield buffer_usage, device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - return AllocateDeviceVisibleBuffer( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, buffer_usage, - allocation_size, device_placements); - } - - // Allocates a device-local buffer that is optimal for use with the given - // |device_placements|. The buffer will not be host-visible and can only be - // used from compatible device queues. - // - // Fails if it is not possible to allocate and satisfy all |device_placements| - // for the requested |buffer_usage|. - StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements); - StatusOr<ref_ptr<Buffer>> AllocateDeviceLocalBuffer( - BufferUsageBitfield buffer_usage, device_size_t allocation_size, - absl::Span<const DevicePlacement> device_placements) { - return AllocateDeviceLocalBuffer(MemoryType::kDeviceLocal, buffer_usage, - allocation_size, device_placements); - } - - // Enqueues a submission against the given target |device| |command_queue|. - // The provided |deadline| is used to determine how long the submission can - // stay waiting in the queue prior to flushing, with absl::InfinitePast - // indicating immediate submission and absl::InfiniteFuture indicating that - // Flush must be called. - // - // If a |fence| is provided it will be signaled when the submission has - // completed and otherwise the caller must use WaitIdle to ensure completion. - // If a sequence of submissions are performed then the semaphore relationships - // can be used to elide waits. Submit(A)+Submit(B, fence) where there is a - // dependency from A->B is safe. - // - // All provided resources must remain alive until the provided |fence| - // resolves or Scheduler::WaitIdle succeeds. - // - // Submissions may be made from any thread. Behavior is undefined - // if a thread is performing a WaitIdle while another thread submits work. - Status Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, absl::Time deadline, - FenceValue fence = {}); - Status Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, - absl::Duration timeout, FenceValue fence = {}) { - return Submit(device, command_queue, batches, - RelativeTimeoutToDeadline(timeout), fence); - } - Status Submit(Device* device, CommandQueue* command_queue, - absl::Span<const SubmissionBatch> batches, - FenceValue fence = {}) { - return Submit(device, command_queue, batches, absl::InfinitePast(), fence); - } - - // Flushes any requests that are pending in the scheduler and ensures they - // begin executing ASAP regardless of policy. - // - // If any used device has encountered an error during submission at any - // point it will be returned here (repeatedly). - Status Flush(); - - // Blocks until all outstanding requests have been completed. - // This is equivalent to having waited on all outstanding fences. - // Implicitly calls Flush to ensure delayed requests are scheduled. - // Work submitted from other threads during a wait may not be included in the - // wait set. - // - // If any used device has encountered an error during submission at any - // point it will be returned here (repeatedly). - Status WaitIdle(absl::Time deadline); - inline Status WaitIdle(absl::Duration timeout) { - return WaitIdle(RelativeTimeoutToDeadline(timeout)); - } - inline Status WaitIdle() { return WaitIdle(absl::InfiniteFuture()); } - - private: - mutable absl::Mutex device_mutex_; - std::vector<std::shared_ptr<Device>> devices_ ABSL_GUARDED_BY(device_mutex_); -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DEVICE_MANAGER_H_
diff --git a/iree/hal/driver.h b/iree/hal/driver.h deleted file mode 100644 index 023660c..0000000 --- a/iree/hal/driver.h +++ /dev/null
@@ -1,61 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DRIVER_H_ -#define IREE_HAL_DRIVER_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "iree/base/status.h" -#include "iree/hal/device.h" -#include "iree/hal/device_info.h" - -namespace iree { -namespace hal { - -class Driver { - public: - virtual ~Driver() = default; - - // Driver name used during registration. - const std::string& name() const { return name_; } - - // TODO(benvanik): info/query (version number, etc). - - // Enumerates devices available for creation from the driver. - // This may fail if the driver is in an invalid state but otherwise will - // return an empty list if no devices are available. - virtual StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() = 0; - - // Creates the driver-defined 'default' device. - // This may simply be the first device enumerated. - virtual StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() = 0; - - // Creates a device as queried with the given |device_info|. - virtual StatusOr<std::shared_ptr<Device>> CreateDevice( - const DeviceInfo& device_info) = 0; - - protected: - explicit Driver(std::string name) : name_(std::move(name)) {} - - private: - const std::string name_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_DRIVER_H_
diff --git a/iree/hal/driver_registry.cc b/iree/hal/driver_registry.cc deleted file mode 100644 index 21ab5ea..0000000 --- a/iree/hal/driver_registry.cc +++ /dev/null
@@ -1,87 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/driver_registry.h" - -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -// static -DriverRegistry* DriverRegistry::shared_registry() { - static auto* singleton = new DriverRegistry(); - return singleton; -} - -DriverRegistry::DriverRegistry() = default; - -DriverRegistry::~DriverRegistry() = default; - -Status DriverRegistry::Register(std::string driver_name, FactoryFn factory_fn) { - absl::MutexLock lock(&mutex_); - for (const auto& pair : driver_factory_fns_) { - if (pair.first == driver_name) { - return AlreadyExistsErrorBuilder(IREE_LOC) - << "Driver already registered: " << driver_name; - } - } - driver_factory_fns_.emplace_back(driver_name, std::move(factory_fn)); - return OkStatus(); -} - -bool DriverRegistry::HasDriver(absl::string_view driver_name) const { - absl::MutexLock lock(&mutex_); - for (const auto& pair : driver_factory_fns_) { - if (pair.first == driver_name) { - return true; - } - } - return false; -} - -std::vector<std::string> DriverRegistry::EnumerateAvailableDrivers() const { - absl::MutexLock lock(&mutex_); - std::vector<std::string> driver_names; - driver_names.reserve(driver_factory_fns_.size()); - for (const auto& pair : driver_factory_fns_) { - driver_names.push_back(pair.first); - } - return driver_names; -} - -StatusOr<std::shared_ptr<Driver>> DriverRegistry::Create( - absl::string_view driver_name) const { - FactoryFn factory_fn; - { - absl::MutexLock lock(&mutex_); - for (const auto& pair : driver_factory_fns_) { - if (pair.first == driver_name) { - factory_fn = pair.second; - break; - } - } - if (!factory_fn) { - return NotFoundErrorBuilder(IREE_LOC) - << "Driver " << driver_name << " not found"; - } - } - return factory_fn(); -} - -} // namespace hal -} // namespace iree - -IREE_REGISTER_MODULE_INITIALIZER( - iree_hal, ::iree::hal::DriverRegistry::shared_registry());
diff --git a/iree/hal/driver_registry.h b/iree/hal/driver_registry.h deleted file mode 100644 index 26b05fc..0000000 --- a/iree/hal/driver_registry.h +++ /dev/null
@@ -1,83 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_DRIVER_REGISTRY_H_ -#define IREE_HAL_DRIVER_REGISTRY_H_ - -#include <memory> -#include <vector> - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/hal/driver.h" - -namespace iree { -namespace hal { - -// Driver registry and factory. -// Factory functions for available drivers are registered with a given name and -// can be invoked with a call to Create. The configuration of the drivers is -// generally contained within the factory function and consumers of the drivers -// don't need to fiddle with things. -// -// This is used for dynamic *safe* link-time driver module registration. -// Roughly: driver_registry provides the shared registry and a way to create -// drivers and *_driver_module.cc files register drivers when linked in. -// Remember to alwayslink=1 on cc_libraries providing modules. -// -// If link-time driver registration is not desired (or possible) it's also -// possible to explicitly register drivers via this registry. This is useful -// when programmatically enabling drivers. -// -// Thread-safe. -class DriverRegistry final { - public: - using FactoryFn = std::function<StatusOr<std::shared_ptr<Driver>>()>; - - // The shared driver registry singleton that modules use when linked in. - static DriverRegistry* shared_registry(); - - DriverRegistry(); - ~DriverRegistry(); - - // Registers a driver and its factory function. - // The function will be called to create a new driver whenever it is requested - // via Create. - Status Register(std::string driver_name, FactoryFn factory_fn); - - // Returns true if there is a driver registered with the given name. - bool HasDriver(absl::string_view driver_name) const; - - // Returns a list of registered drivers. - std::vector<std::string> EnumerateAvailableDrivers() const; - - // TODO(benvanik): flags for enabling debug validation/control/etc. - // Creates a driver by name. - StatusOr<std::shared_ptr<Driver>> Create(absl::string_view driver_name) const; - - private: - mutable absl::Mutex mutex_; - std::vector<std::pair<std::string, FactoryFn>> driver_factory_fns_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace hal -} // namespace iree - -IREE_DECLARE_MODULE_INITIALIZER(iree_hal); -IREE_REQUIRE_MODULE_LINKED(iree_hal); - -#endif // IREE_HAL_DRIVER_REGISTRY_H_
diff --git a/iree/hal/event.h b/iree/hal/event.h deleted file mode 100644 index c7786f4..0000000 --- a/iree/hal/event.h +++ /dev/null
@@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_EVENT_H_ -#define IREE_HAL_EVENT_H_ - -#include "iree/hal/resource.h" - -namespace iree { -namespace hal { - -// Events are used for defining synchronization scopes within CommandBuffers. -// An event only exists within a single CommandBuffer and must not be used -// across CommandBuffers from the same device or others. -// -// See CommandBuffer::SignalEvent and CommandBuffer::WaitEvents for more info. -class Event : public Resource { - public: -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_EVENT_H_
diff --git a/iree/hal/executable.h b/iree/hal/executable.h deleted file mode 100644 index d724d01..0000000 --- a/iree/hal/executable.h +++ /dev/null
@@ -1,57 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_EXECUTABLE_H_ -#define IREE_HAL_EXECUTABLE_H_ - -#include "iree/hal/resource.h" - -namespace iree { -namespace hal { - -class Executable : public Resource { - public: - ~Executable() override = default; - - // True if the executable was prepared with debugging enabled and the device - // and input data support debugging (symbols present, etc). - virtual bool supports_debugging() const = 0; - - // TODO(benvanik): disassembly methods. - - // TODO(benvanik): relative offset calculation: - // - step once - // - step over - // - step out - - // TODO(benvanik): create executable split on breakpoint. - // Executable should return when the breakpoint is hit without any future - // modifications to output buffers. If the breakpoint is not hit the - // executable should run to completion as normal. - - // TODO(benvanik): retrieve coverage info. - // Returns a buffer containing offset -> coverage metrics. Note that depending - // on the device this may only contain a single coverage metric for the entire - // executable or some subset of the available offsets. - - // TODO(benvanik): retrieve profiling info. - - protected: - Executable() = default; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_EXECUTABLE_H_
diff --git a/iree/hal/executable_cache.cc b/iree/hal/executable_cache.cc deleted file mode 100644 index 26ce40c..0000000 --- a/iree/hal/executable_cache.cc +++ /dev/null
@@ -1,25 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/executable_cache.h" - -namespace iree { -namespace hal { - -ExecutableCache::ExecutableCache() = default; - -ExecutableCache::~ExecutableCache() = default; - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/executable_cache.h b/iree/hal/executable_cache.h deleted file mode 100644 index 3e31831..0000000 --- a/iree/hal/executable_cache.h +++ /dev/null
@@ -1,126 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_EXECUTABLE_CACHE_H_ -#define IREE_HAL_EXECUTABLE_CACHE_H_ - -#include "iree/base/bitfield.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/executable_spec.h" - -namespace iree { -namespace hal { - -// Defines how the executable cache performs preparation. -enum class ExecutableCachingMode : uint32_t { - // Allows the cache to reference the provided executable_data after it has - // prepared the executable. Callers must ensure the data remains valid for the - // lifetime of the cache. If memory mapping constant executable data from - // disk this can be used to avoid copies. - kAliasProvidedData = 1 << 0, - - // Allows the prepared executable to be cached persistently (on disk/etc). - // Enable for any executable that is likely to be used in future runs. - // Note that not all caches support persistent serialization and this is just - // a hint. - kAllowPersistentCaching = 1 << 1, - - // Allows the cache to optimize the executable as much as it can. - // This may cause preparation to take significantly longer while (hopefully) - // improving runtime performance. Avoid for one-shot executables. - kAllowOptimization = 1 << 2, - - // Enables Executable debugging methods if supported by the device and - // executable. This may disable certain optimizations or retain additional - // data to allow disassembly, stepping, etc. - // - // Device must support the DeviceFeature::kDebugging feature and executables - // must support the ExecutableFeature::kDebugging feature. - kEnableDebugging = 1 << 3, - - // Enables Executable coverage if supported by the device and executable. - // Depending on the optimization mode this may produce partial coverage - // results (for example, when certain source operations were optimized away). - // - // Device must support the DeviceFeature::kCoverage feature and executables - // must support the ExecutableFeature::kCoverage feature. - kEnableCoverage = 1 << 4, - - // Enables Executable profiling if supported by the device and executable. - // Depending on the optimization mode this may produce partial profiling - // results. Profiling attribution (whether to the entire executable or - // specific operations) depends on the implementation. - // - // Device must support the DeviceFeature::kProfiling feature and executables - // must support the ExecutableFeature::kProfiling feature. - kEnableProfiling = 1 << 5, - - // Default caching mode. - kDefault = kAllowPersistentCaching | kAllowOptimization, -}; -IREE_BITFIELD(ExecutableCachingMode); -using ExecutableCachingModeBitfield = ExecutableCachingMode; - -// A cache of prepared executables for a particular device. -// Caches may be shared across multiple devices from the same driver or specific -// to individual devices. Caches may persist prepared executables across process -// launches or reprepare them each run. Callers should assume that the cache is -// a no-op and the returned Executables only live for as long as the cache does. -// -// The term 'cache' here is rather optimistic - it's perfectly acceptable for -// implementations to not cache at all and return new Executables for each -// PrepareExecutable called (even for the same executable). Callers should -// expect such behavior and try to retain the results of the PrepareExecutable -// calls to reduce overhead in re-preparing executables. -// -// Thread-safe - multiple threads may prepare executables (including the *same* -// executable) simultaneously. -class ExecutableCache { - public: - virtual ~ExecutableCache(); - - // TODO(benvanik): status/queries (size, etc). - - // TODO(b/137153339): serialization/deserialization. - - // Returns true if the executable cache can prepare the given executable input - // format. Perparation may still fail if the particular version or features - // required by the executable are not supported. - virtual bool CanPrepareFormat(ExecutableFormat format) const = 0; - - // Prepares an executable for use. - // The provided |spec| and |executable_data| will be used to either lookup a - // previously prepared executable in the cache or prepare a new one. - // - // Depending on the driver preparation may take a non-trivial amount of time - // (such as when JITing/etc). As the cache is internally synchronized callers - // can issue preparation requests from multiple threads - even for the same - // executables - and calls will block until preparation completes. - // - // When preparing a large number of executables it's recommended to use the - // PrepareExecutables method to batch and wait on the results. - virtual StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) = 0; - - protected: - ExecutableCache(); -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/executable_spec.h b/iree/hal/executable_spec.h deleted file mode 100644 index a88553f..0000000 --- a/iree/hal/executable_spec.h +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_EXECUTABLE_SPEC_H_ -#define IREE_HAL_EXECUTABLE_SPEC_H_ - -#include "absl/types/span.h" -#include "iree/hal/executable_format.h" - -namespace iree { -namespace hal { - -// Defines an executable specification used by a cache to prepare an executable. -struct ExecutableSpec { - // TODO(benvanik): pre-populated hash_code/key to avoid calculation. - - // Format of the executable input data. - ExecutableFormat format = kExecutableFormatUnspecified; - - // A reference to the executable data as input to the cache. - // If ExecutableCachingMode::kAliasProvidedData is set then this reference - // may be retained by the cache and the backing buffer must be kept valid for - // the lifetime of the cache. - absl::Span<const uint8_t> executable_data; - - // TODO(benvanik): add specialization info (constants/defines). - // TODO(benvanik): add compiler flags? could treat as opaque. -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_EXECUTABLE_SPEC_H_
diff --git a/iree/hal/fence.h b/iree/hal/fence.h deleted file mode 100644 index b395e91..0000000 --- a/iree/hal/fence.h +++ /dev/null
@@ -1,72 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_FENCE_H_ -#define IREE_HAL_FENCE_H_ - -#include <cstdint> - -#include "iree/base/status.h" -#include "iree/hal/resource.h" - -namespace iree { -namespace hal { - -// Synchronization mechanism for device->host notification. -// Fences behave like timeline semaphores and contain a monotonically increasing -// uint64_t payload. They may be waited on any number of times - even if they -// have already been signaled. -// -// A fence is updated to its new value after all prior commands have completed -// but the delay between completion and the host being woken varies. Some -// implementations may coalesce fences to avoid spurious waking while others -// will immediately synchronize with the host. -// -// The primary use of fences is for resource lifetime management: all resources -// used by a set of submission batches must be considered live until the fence -// attached to the submission has signaled. -// -// Fences may be set to a permanently failed state by implementations when -// errors occur during asynchronous execution. Users are expected to propagate -// the failures and possibly reset the entire device that produced the error. -// -// For more information on fences see the following docs describing how -// timelines are generally used (specifically in the device->host case): -// https://www.youtube.com/watch?v=SpE--Rf516Y -// https://www.khronos.org/assets/uploads/developers/library/2018-xdc/Vulkan-Timeline-Semaphores-Part-1_Sep18.pdf -// https://docs.microsoft.com/en-us/windows/win32/direct3d12/user-mode-heap-synchronization -class Fence : public Resource { - public: - // Returns a permanent failure status if the fence is indicating an - // asynchronous failure. - // - // Returns the status at the time the method is called without blocking and as - // such is only valid after a fence has been signaled. The same failure status - // will be returned regardless of when in the timeline the error occurred. - virtual Status status() const = 0; - - // Queries the current payload of the fence. As the payload is monotonically - // increasing it is guaranteed that the value is at least equal to the - // previous result of a QueryValue call and coherent with any waits for a - // specified value via Device::WaitAllFences. - virtual StatusOr<uint64_t> QueryValue() = 0; -}; - -// A reference to a fence and associated payload value. -using FenceValue = std::pair<Fence*, uint64_t>; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_FENCE_H_
diff --git a/iree/hal/heap_buffer.cc b/iree/hal/heap_buffer.cc deleted file mode 100644 index 8c78935..0000000 --- a/iree/hal/heap_buffer.cc +++ /dev/null
@@ -1,190 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/heap_buffer.h" - -#include <cstdint> -#include <cstdlib> -#include <string> -#include <utility> - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/allocator.h" -#include "iree/hal/host/host_buffer.h" - -namespace iree { -namespace hal { - -namespace { - -// An allocator that allocates or wraps host-only buffers. -// The resulting buffers are not usable by most devices without a copy and -// using a device allocator is strongly preferred. -class HeapAllocator : public Allocator { - public: - // Returns a singleton heap allocator that can provide buffers that have - // MemoryType::kHostLocal and are allocated with malloc/free. - // These buffers will not be usable by devices directly and may incur - // additional copies. - static Allocator* std_heap(); - - // TODO(benvanik): specify custom allocator (not malloc/free). - HeapAllocator(); - ~HeapAllocator() override; - - bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const override; - - bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const override; - - StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) override; - - StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, - size_t data_length) override; -}; - -// static -Allocator* HeapAllocator::std_heap() { - static Allocator* std_heap_allocator = new HeapAllocator(); - return std_heap_allocator; -} - -HeapAllocator::HeapAllocator() = default; - -HeapAllocator::~HeapAllocator() = default; - -bool HeapAllocator::CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const { - // The host can use anything with kHostVisible. - if (!AnyBitSet(memory_type & MemoryType::kHostVisible)) { - return false; - } - - // Host currently uses mapping to copy buffers, which is done a lot. - if (!AnyBitSet(buffer_usage & BufferUsage::kMapping)) { - return false; - } - - return true; -} - -bool HeapAllocator::CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const { - // This host only allocator cannot serve device visible allocation as we - // can't know which devices these buffers will be used with. - return (memory_type & MemoryType::kHostLocal) == MemoryType::kHostLocal && - !AnyBitSet(memory_type & MemoryType::kDeviceLocal) && - !AnyBitSet(memory_type & MemoryType::kDeviceVisible); -} - -StatusOr<ref_ptr<Buffer>> HeapAllocator::Allocate( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("HeapAllocator::Allocate"); - - if (!CanAllocate(memory_type, buffer_usage, allocation_size)) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Allocation not supported; memory_type=" - << MemoryTypeString(memory_type) - << ", buffer_usage=" << BufferUsageString(buffer_usage) - << ", allocation_size=" << allocation_size; - } - - void* malloced_data = std::calloc(1, allocation_size); - if (!malloced_data) { - return ResourceExhaustedErrorBuilder(IREE_LOC) - << "Failed to malloc " << allocation_size << " bytes"; - } - - auto buffer = - make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage, - allocation_size, malloced_data, true); - return buffer; -} - -StatusOr<ref_ptr<Buffer>> HeapAllocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length) { - auto buffer = make_ref<HostBuffer>(this, memory_type, allowed_access, - buffer_usage, data_length, data, false); - return buffer; -} - -} // namespace - -// static -ref_ptr<Buffer> HeapBuffer::Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - size_t allocation_size) { - auto buffer_or = - HeapAllocator::std_heap()->Allocate(memory_type, usage, allocation_size); - return std::move(buffer_or.ValueOrDie()); -} - -// static -ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - const void* data, size_t data_length) { - return AllocateCopy(usage, MemoryAccess::kAll, data, data_length); -} - -// static -ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - const void* data, size_t data_length) { - IREE_TRACE_SCOPE0("HeapBuffer::AllocateCopy"); - // Ensure we can map so that we can copy into it. - usage |= BufferUsage::kMapping; - auto buffer_or = HeapAllocator::std_heap()->Allocate(MemoryType::kHostLocal, - usage, data_length); - auto buffer = std::move(buffer_or.ValueOrDie()); - buffer->WriteData(0, data, data_length).IgnoreError(); - buffer->set_allowed_access(allowed_access); - return buffer; -} - -// static -ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, const void* data, - size_t data_length) { - auto buffer_or = - HeapAllocator::std_heap()->Wrap(memory_type, usage, data, data_length); - return std::move(buffer_or.ValueOrDie()); -} - -// static -ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, void* data, - size_t data_length) { - auto buffer_or = HeapAllocator::std_heap()->WrapMutable( - memory_type, allowed_access, usage, data, data_length); - return std::move(buffer_or.ValueOrDie()); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/heap_buffer.h b/iree/hal/heap_buffer.h deleted file mode 100644 index eba9b72..0000000 --- a/iree/hal/heap_buffer.h +++ /dev/null
@@ -1,117 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HEAP_BUFFER_H_ -#define IREE_HAL_HEAP_BUFFER_H_ - -#include <memory> - -#include "iree/base/status.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -// Factory for buffers that are allocated from the host heap (malloc/free). -// These buffers cannot be used by devices and will incur copies/transfers when -// used. Prefer device-specific allocators instead. -class HeapBuffer { - public: - // Allocates a zeroed host heap buffer of the given size. - // Returns a buffer allocated with malloc and have MemoryType::kHostLocal - // and will not be usable by devices without copies. - static ref_ptr<Buffer> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - size_t allocation_size); - static ref_ptr<Buffer> Allocate(BufferUsageBitfield usage, - size_t allocation_size) { - return Allocate(MemoryType::kHostLocal, usage, allocation_size); - } - - // Allocates a host heap buffer with a copy of the given data. - // Returns a buffer allocated with malloc and have MemoryType::kHostLocal - // and will not be usable by devices without copies. - static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, - const void* data, size_t data_length); - static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - const void* data, size_t data_length); - template <typename T> - static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, - absl::Span<const T> data); - template <typename T> - static ref_ptr<Buffer> AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - absl::Span<const T> data); - - // Wraps an existing host heap allocation in a buffer. - // Ownership of the host allocation remains with the caller and the memory - // must remain valid for so long as the Buffer may be in use. - // Will have MemoryType::kHostLocal in most cases and may not be usable - // by the device. - static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, const void* data, - size_t data_length); - static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, void* data, - size_t data_length); - template <typename T> - static ref_ptr<Buffer> Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - absl::Span<const T> data); - template <typename T> - static ref_ptr<Buffer> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, - absl::Span<T> data); -}; - -// Inline functions and template definitions follow: - -template <typename T> -ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - absl::Span<const T> data) { - return HeapBuffer::AllocateCopy(usage, MemoryAccess::kAll, data); -} - -template <typename T> -ref_ptr<Buffer> HeapBuffer::AllocateCopy(BufferUsageBitfield usage, - MemoryAccessBitfield allowed_access, - absl::Span<const T> data) { - return HeapBuffer::AllocateCopy(usage, allowed_access, data.data(), - data.size() * sizeof(T)); -} - -template <typename T> -ref_ptr<Buffer> HeapBuffer::Wrap(MemoryTypeBitfield memory_type, - BufferUsageBitfield usage, - absl::Span<const T> data) { - return HeapBuffer::Wrap(memory_type, usage, data.data(), - data.size() * sizeof(T)); -} - -template <typename T> -ref_ptr<Buffer> HeapBuffer::WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, - absl::Span<T> data) { - return HeapBuffer::WrapMutable(memory_type, allowed_access, usage, - data.data(), data.size() * sizeof(T)); -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HEAP_BUFFER_H_
diff --git a/iree/hal/host/BUILD b/iree/hal/host/BUILD deleted file mode 100644 index 6e6c51f..0000000 --- a/iree/hal/host/BUILD +++ /dev/null
@@ -1,155 +0,0 @@ -# Default implementations for HAL types that use the host resources. -# These are generally just wrappers around host heap memory and host threads. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "async_command_queue", - srcs = ["async_command_queue.cc"], - hdrs = ["async_command_queue.h"], - deps = [ - ":host_submission_queue", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_queue", - "//iree/hal:fence", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - ], -) - -cc_test( - name = "async_command_queue_test", - srcs = ["async_command_queue_test.cc"], - deps = [ - ":async_command_queue", - ":host_submission_queue", - "//iree/base:status", - "//iree/base:status_matchers", - "//iree/base:time", - "//iree/hal:command_queue", - "//iree/hal/testing:mock_command_buffer", - "//iree/hal/testing:mock_command_queue", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "host_buffer", - srcs = ["host_buffer.cc"], - hdrs = ["host_buffer.h"], - deps = [ - "//iree/base:logging", - "//iree/base:source_location", - "//iree/base:status", - "//iree/hal:buffer", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "host_event", - srcs = ["host_event.cc"], - hdrs = ["host_event.h"], - deps = [ - "//iree/hal:event", - ], -) - -cc_library( - name = "host_fence", - srcs = ["host_fence.cc"], - hdrs = ["host_fence.h"], - deps = [ - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:fence", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "host_fence_test", - srcs = ["host_fence_test.cc"], - deps = [ - ":host_fence", - "//iree/base:status", - "//iree/base:status_matchers", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "host_local_allocator", - srcs = ["host_local_allocator.cc"], - hdrs = ["host_local_allocator.h"], - deps = [ - ":host_buffer", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:allocator", - "//iree/hal:buffer", - ], -) - -cc_library( - name = "host_local_command_processor", - srcs = ["host_local_command_processor.cc"], - hdrs = ["host_local_command_processor.h"], - deps = [ - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_buffer", - ], -) - -cc_library( - name = "host_submission_queue", - srcs = ["host_submission_queue.cc"], - hdrs = ["host_submission_queue.h"], - deps = [ - ":host_fence", - "//iree/base:intrusive_list", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_queue", - "//iree/hal:fence", - "//iree/hal:semaphore", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - ], -) - -cc_test( - name = "host_submission_queue_test", - srcs = ["host_submission_queue_test.cc"], - deps = [ - ":host_submission_queue", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "inproc_command_buffer", - srcs = ["inproc_command_buffer.cc"], - hdrs = ["inproc_command_buffer.h"], - deps = [ - "//iree/base:arena", - "//iree/base:intrusive_list", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_buffer", - ], -)
diff --git a/iree/hal/host/async_command_queue.cc b/iree/hal/host/async_command_queue.cc deleted file mode 100644 index f7e549f..0000000 --- a/iree/hal/host/async_command_queue.cc +++ /dev/null
@@ -1,127 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/async_command_queue.h" - -#include "absl/base/thread_annotations.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -AsyncCommandQueue::AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue) - : CommandQueue(target_queue->name(), target_queue->supported_categories()), - target_queue_(std::move(target_queue)) { - IREE_TRACE_SCOPE0("AsyncCommandQueue::ctor"); - thread_ = std::thread([this]() { ThreadMain(); }); -} - -AsyncCommandQueue::~AsyncCommandQueue() { - IREE_TRACE_SCOPE0("AsyncCommandQueue::dtor"); - { - // Signal to thread that we want to stop. Note that the thread may have - // already been stopped and that's ok (as we'll Join right away). - // The thread will finish processing any queued submissions. - absl::MutexLock lock(&submission_mutex_); - submission_queue_.SignalShutdown(); - } - thread_.join(); - - // Ensure we shut down OK. - { - absl::MutexLock lock(&submission_mutex_); - CHECK(submission_queue_.empty()) - << "Dirty shutdown of async queue (unexpected thread exit?)"; - } -} - -void AsyncCommandQueue::ThreadMain() { - // TODO(benvanik): make this safer (may die if trace is flushed late). - IREE_TRACE_THREAD_ENABLE(target_queue_->name().c_str()); - - bool is_exiting = false; - while (!is_exiting) { - // Block until we are either requested to exit or there are pending - // submissions. - submission_mutex_.Lock(); - submission_mutex_.Await(absl::Condition( - +[](HostSubmissionQueue* queue) { - return queue->has_shutdown() || !queue->empty(); - }, - &submission_queue_)); - if (!submission_queue_.empty()) { - // Run all ready submissions (this may be called many times). - submission_mutex_.AssertHeld(); - submission_queue_ - .ProcessBatches( - [this](absl::Span<CommandBuffer* const> command_buffers) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(submission_mutex_) { - // Release the lock while we perform the processing so that - // other threads can submit more work. - submission_mutex_.AssertHeld(); - submission_mutex_.Unlock(); - - // Relay the command buffers to the target queue. - // Since we are taking care of all synchronization they - // don't need any waiters or fences. - auto status = target_queue_->Submit( - {{}, command_buffers, {}}, {nullptr, 0u}); - - // Take back the lock so we can manipulate the queue safely. - submission_mutex_.Lock(); - submission_mutex_.AssertHeld(); - - return status; - }) - .IgnoreError(); - submission_mutex_.AssertHeld(); - } - if (submission_queue_.has_shutdown()) { - // Exit when there are no more submissions to process and an exit was - // requested (or we errored out). - is_exiting = true; - } - submission_mutex_.Unlock(); - } -} - -Status AsyncCommandQueue::Submit(absl::Span<const SubmissionBatch> batches, - FenceValue fence) { - IREE_TRACE_SCOPE0("AsyncCommandQueue::Submit"); - absl::MutexLock lock(&submission_mutex_); - return submission_queue_.Enqueue(batches, fence); -} - -Status AsyncCommandQueue::WaitIdle(absl::Time deadline) { - IREE_TRACE_SCOPE0("AsyncCommandQueue::WaitIdle"); - - // Wait until the deadline, the thread exits, or there are no more pending - // submissions. - absl::MutexLock lock(&submission_mutex_); - if (!submission_mutex_.AwaitWithDeadline( - absl::Condition( - +[](HostSubmissionQueue* queue) { - return queue->empty() || !queue->permanent_error().ok(); - }, - &submission_queue_), - deadline)) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for submission thread to go idle"; - } - return submission_queue_.permanent_error(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/async_command_queue.h b/iree/hal/host/async_command_queue.h deleted file mode 100644 index 5716d55..0000000 --- a/iree/hal/host/async_command_queue.h +++ /dev/null
@@ -1,71 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_ -#define IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_ - -#include <memory> -#include <thread> // NOLINT - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/fence.h" -#include "iree/hal/host/host_submission_queue.h" - -namespace iree { -namespace hal { - -// Asynchronous command queue wrapper. -// This creates a single thread to perform all CommandQueue operations. Any -// submitted CommandBuffer is dispatched in FIFO order on the queue thread -// against the provided |target_queue|. -// -// Target queues will receive submissions containing only command buffers as -// all semaphore synchronization is handled by the wrapper. Fences will also be -// omitted and code should safely handle nullptr. -// -// AsyncCommandQueue (as with CommandQueue) is thread-safe. Multiple threads -// may submit command buffers concurrently, though the order of execution in -// such a case depends entirely on the synchronization primitives provided. -class AsyncCommandQueue final : public CommandQueue { - public: - explicit AsyncCommandQueue(std::unique_ptr<CommandQueue> target_queue); - ~AsyncCommandQueue() override; - - Status Submit(absl::Span<const SubmissionBatch> batches, - FenceValue fence) override; - - Status WaitIdle(absl::Time deadline) override; - - private: - // Thread entry point for the async worker thread. - // Waits for submissions to be queued up and processes them eagerly. - void ThreadMain(); - - // CommandQueue that the async queue relays submissions into. - std::unique_ptr<CommandQueue> target_queue_; - - // Thread that runs the ThreadMain() function and processes submissions. - std::thread thread_; - - // Queue that manages submission ordering. - mutable absl::Mutex submission_mutex_; - HostSubmissionQueue submission_queue_ ABSL_GUARDED_BY(submission_mutex_); -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_ASYNC_COMMAND_QUEUE_H_
diff --git a/iree/hal/host/async_command_queue_test.cc b/iree/hal/host/async_command_queue_test.cc deleted file mode 100644 index 307ed3b..0000000 --- a/iree/hal/host/async_command_queue_test.cc +++ /dev/null
@@ -1,232 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/async_command_queue.h" - -#include <cstdint> -#include <memory> -#include <utility> - -#include "absl/memory/memory.h" -#include "absl/time/time.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/base/time.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/host/host_submission_queue.h" -#include "iree/hal/testing/mock_command_buffer.h" -#include "iree/hal/testing/mock_command_queue.h" - -namespace iree { -namespace hal { -namespace { - -using ::testing::_; - -using testing::MockCommandBuffer; -using testing::MockCommandQueue; - -struct AsyncCommandQueueTest : public ::testing::Test { - MockCommandQueue* mock_target_queue; - std::unique_ptr<CommandQueue> command_queue; - - void SetUp() override { - auto mock_queue = absl::make_unique<MockCommandQueue>( - "mock", CommandCategory::kTransfer | CommandCategory::kDispatch); - mock_target_queue = mock_queue.get(); - command_queue = absl::make_unique<AsyncCommandQueue>(std::move(mock_queue)); - } - - void TearDown() override { - command_queue.reset(); - mock_target_queue = nullptr; - } -}; - -// Tests that submitting a command buffer and immediately waiting will not -// deadlock. -TEST_F(AsyncCommandQueueTest, BlockingSubmit) { - ::testing::InSequence sequence; - - auto cmd_buffer = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - - EXPECT_CALL(*mock_target_queue, Submit(_, _)) - .WillOnce( - [&](absl::Span<const SubmissionBatch> batches, FenceValue fence) { - CHECK_EQ(1, batches.size()); - CHECK_EQ(1, batches[0].command_buffers.size()); - CHECK_EQ(cmd_buffer.get(), batches[0].command_buffers[0]); - CHECK_EQ(nullptr, fence.first); - return OkStatus(); - }); - HostFence fence(0u); - ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u})); - ASSERT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true, - absl::InfiniteFuture())); -} - -// Tests that failure is propagated along the fence from the target queue. -TEST_F(AsyncCommandQueueTest, PropagateSubmitFailure) { - ::testing::InSequence sequence; - - auto cmd_buffer = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - - EXPECT_CALL(*mock_target_queue, Submit(_, _)) - .WillOnce( - [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { - return DataLossErrorBuilder(IREE_LOC); - }); - HostFence fence(0u); - ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u})); - EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( - {{&fence, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); -} - -// Tests that waiting for idle is a no-op when nothing is queued. -TEST_F(AsyncCommandQueueTest, WaitIdleWhileIdle) { - ASSERT_OK(command_queue->WaitIdle()); -} - -// Tests that waiting for idle will block when work is pending/in-flight. -TEST_F(AsyncCommandQueueTest, WaitIdleWithPending) { - ::testing::InSequence sequence; - - auto cmd_buffer = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - - EXPECT_CALL(*mock_target_queue, Submit(_, _)) - .WillOnce( - [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { - Sleep(absl::Milliseconds(100)); - return OkStatus(); - }); - HostFence fence(0u); - ASSERT_OK(command_queue->Submit({{}, {cmd_buffer.get()}, {}}, {&fence, 1u})); - - // This should block for a sec or two. - ASSERT_OK(command_queue->WaitIdle()); - - // Should have already expired. - ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue()); - ASSERT_EQ(1u, value); -} - -// Tests that waiting for idle with multiple pending submissions will wait until -// all of them complete while still allowing incremental progress. -TEST_F(AsyncCommandQueueTest, WaitIdleAndProgress) { - ::testing::InSequence sequence; - - EXPECT_CALL(*mock_target_queue, Submit(_, _)) - .WillRepeatedly( - [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { - Sleep(absl::Milliseconds(100)); - return OkStatus(); - }); - - auto cmd_buffer_0 = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - auto cmd_buffer_1 = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - - HostFence fence_0(0u); - ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u})); - HostFence fence_1(0u); - ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u})); - - // This should block for a sec or two. - ASSERT_OK(command_queue->WaitIdle()); - - // Both should have already expired. - ASSERT_OK_AND_ASSIGN(uint64_t value_0, fence_0.QueryValue()); - ASSERT_EQ(1u, value_0); - ASSERT_OK_AND_ASSIGN(uint64_t value_1, fence_1.QueryValue()); - ASSERT_EQ(1u, value_1); -} - -// Tests that failures are sticky. -TEST_F(AsyncCommandQueueTest, StickyFailures) { - ::testing::InSequence sequence; - - // Fail. - EXPECT_CALL(*mock_target_queue, Submit(_, _)) - .WillOnce( - [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { - Sleep(absl::Milliseconds(100)); - return DataLossErrorBuilder(IREE_LOC); - }); - auto cmd_buffer_0 = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - HostFence fence_0(0u); - ASSERT_OK( - command_queue->Submit({{}, {cmd_buffer_0.get()}, {}}, {&fence_0, 1u})); - EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( - {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); - - // Future flushes/waits/etc should also fail. - EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); - - // Future submits should fail asynchronously. - auto cmd_buffer_1 = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - HostFence fence_1(0u); - EXPECT_TRUE(IsDataLoss( - command_queue->Submit({{}, {cmd_buffer_1.get()}, {}}, {&fence_1, 1u}))); -} - -// Tests that a failure with two submissions pending causes the second to -// bail as well. -TEST_F(AsyncCommandQueueTest, FailuresCascadeAcrossSubmits) { - ::testing::InSequence sequence; - - // Fail. - EXPECT_CALL(*mock_target_queue, Submit(_, _)) - .WillOnce( - [](absl::Span<const SubmissionBatch> batches, FenceValue fence) { - Sleep(absl::Milliseconds(100)); - return DataLossErrorBuilder(IREE_LOC); - }); - - auto cmd_buffer_0 = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - auto cmd_buffer_1 = make_ref<MockCommandBuffer>( - nullptr, CommandBufferMode::kOneShot, CommandCategory::kTransfer); - - HostBinarySemaphore semaphore_0_1(false); - HostFence fence_0(0u); - ASSERT_OK(command_queue->Submit({{}, {cmd_buffer_0.get()}, {&semaphore_0_1}}, - {&fence_0, 1u})); - HostFence fence_1(0u); - ASSERT_OK(command_queue->Submit({{&semaphore_0_1}, {cmd_buffer_1.get()}, {}}, - {&fence_1, 1u})); - - EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); - - EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( - {{&fence_0, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); - EXPECT_TRUE(IsDataLoss(HostFence::WaitForFences( - {{&fence_1, 1u}}, /*wait_all=*/true, absl::InfiniteFuture()))); - - // Future flushes/waits/etc should also fail. - EXPECT_TRUE(IsDataLoss(command_queue->WaitIdle())); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_buffer.cc b/iree/hal/host/host_buffer.cc deleted file mode 100644 index 555f6ac..0000000 --- a/iree/hal/host/host_buffer.cc +++ /dev/null
@@ -1,148 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_buffer.h" - -#include <cstdint> -#include <cstdlib> -#include <cstring> - -#include "iree/base/logging.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -class Allocator; - -HostBuffer::HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, device_size_t allocation_size, - void* data, bool owns_data) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, 0, - allocation_size), - data_(data), - owns_data_(owns_data) {} - -HostBuffer::~HostBuffer() { - if (owns_data_ && data_) { - std::free(data_); - data_ = nullptr; - } -} - -Status HostBuffer::FillImpl(device_size_t byte_offset, - device_size_t byte_length, const void* pattern, - device_size_t pattern_length) { - auto data_ptr = data_; - switch (pattern_length) { - case 1: { - uint8_t* data = static_cast<uint8_t*>(data_ptr); - uint8_t value_bits = *static_cast<const uint8_t*>(pattern); - std::fill_n(data + byte_offset, byte_length, value_bits); - break; - } - case 2: { - uint16_t* data = static_cast<uint16_t*>(data_ptr); - uint16_t value_bits = *static_cast<const uint16_t*>(pattern); - std::fill_n(data + byte_offset / sizeof(uint16_t), - byte_length / sizeof(uint16_t), value_bits); - break; - } - case 4: { - uint32_t* data = static_cast<uint32_t*>(data_ptr); - uint32_t value_bits = *static_cast<const uint32_t*>(pattern); - std::fill_n(data + byte_offset / sizeof(uint32_t), - byte_length / sizeof(uint32_t), value_bits); - break; - } - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported scalar data size: " << pattern_length; - } - return OkStatus(); -} - -Status HostBuffer::ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) { - auto data_ptr = static_cast<uint8_t*>(data_); - std::memcpy(data, data_ptr + source_offset, data_length); - return OkStatus(); -} - -Status HostBuffer::WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) { - auto data_ptr = static_cast<uint8_t*>(data_); - std::memcpy(data_ptr + target_offset, data, data_length); - return OkStatus(); -} - -Status HostBuffer::CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - // This is pretty terrible. Let's not do this. - // TODO(benvanik): a way for allocators to indicate transfer compat. - ASSIGN_OR_RETURN(auto source_data, - source_buffer->MapMemory<uint8_t>( - MemoryAccess::kRead, source_offset, data_length)); - CHECK_EQ(data_length, source_data.size()); - auto data_ptr = static_cast<uint8_t*>(data_); - std::memcpy(data_ptr + target_offset, source_data.data(), data_length); - return OkStatus(); -} - -Status HostBuffer::MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) { - auto data_ptr = static_cast<uint8_t*>(data_); - *out_data = data_ptr + local_byte_offset; - - // If we mapped for discard scribble over the bytes. This is not a mandated - // behavior but it will make debugging issues easier. Alternatively for - // heap buffers we could reallocate them such that ASAN yells, but that - // would only work if the entire buffer was discarded. -#ifndef NDEBUG - if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) { - std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); - } -#endif // !NDEBUG - - return OkStatus(); -} - -Status HostBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, - void* data) { - // No-op? We still want error checking to make finding misuse easier. - return OkStatus(); -} - -Status HostBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - // No-op? We still want error checking to make finding misuse easier. - return OkStatus(); -} - -Status HostBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - // No-op? We still want error checking to make finding misuse easier. - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_buffer.h b/iree/hal/host/host_buffer.h deleted file mode 100644 index 7a52758..0000000 --- a/iree/hal/host/host_buffer.h +++ /dev/null
@@ -1,67 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_BUFFER_H_ -#define IREE_HAL_HOST_BUFFER_H_ - -#include <cstdint> - -#include "iree/base/status.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -// A buffer type that operates on host pointers. -// This can be used by Allocator implementations when they support operating -// on host memory (or mapping their memory to host memory). -class HostBuffer : public Buffer { - public: - HostBuffer(Allocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, void* data, bool owns_data); - - ~HostBuffer() override; - - protected: - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override; - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override; - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override; - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override; - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override; - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override; - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - - private: - void* data_ = nullptr; - bool owns_data_ = false; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_BUFFER_H_
diff --git a/iree/hal/host/host_event.cc b/iree/hal/host/host_event.cc deleted file mode 100644 index 9ffac59..0000000 --- a/iree/hal/host/host_event.cc +++ /dev/null
@@ -1,25 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_event.h" - -namespace iree { -namespace hal { - -HostEvent::HostEvent() = default; - -HostEvent::~HostEvent() = default; - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_event.h b/iree/hal/host/host_event.h deleted file mode 100644 index c5fb33a..0000000 --- a/iree/hal/host/host_event.h +++ /dev/null
@@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_HOST_EVENT_H_ -#define IREE_HAL_HOST_HOST_EVENT_H_ - -#include "iree/hal/event.h" - -namespace iree { -namespace hal { - -class HostEvent final : public Event { - public: - HostEvent(); - ~HostEvent() override; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_HOST_EVENT_H_
diff --git a/iree/hal/host/host_fence.cc b/iree/hal/host/host_fence.cc deleted file mode 100644 index 6932b20..0000000 --- a/iree/hal/host/host_fence.cc +++ /dev/null
@@ -1,110 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_fence.h" - -#include <atomic> -#include <cstdint> - -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -HostFence::HostFence(uint64_t initial_value) : value_(initial_value) {} - -HostFence::~HostFence() = default; - -Status HostFence::status() const { - absl::MutexLock lock(&mutex_); - return status_; -} - -StatusOr<uint64_t> HostFence::QueryValue() { - return value_.load(std::memory_order_acquire); -} - -Status HostFence::Signal(uint64_t value) { - absl::MutexLock lock(&mutex_); - if (!status_.ok()) { - return status_; - } - if (value_.exchange(value) >= value) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Fence values must be monotonically increasing"; - } - return OkStatus(); -} - -Status HostFence::Fail(Status status) { - absl::MutexLock lock(&mutex_); - status_ = status; - value_.store(UINT64_MAX, std::memory_order_release); - return OkStatus(); -} - -// static -Status HostFence::WaitForFences(absl::Span<const FenceValue> fences, - bool wait_all, absl::Time deadline) { - IREE_TRACE_SCOPE0("HostFence::WaitForFences"); - - // Some of the fences may already be signaled; we only need to wait for those - // that are not yet at the expected value. - using HostFenceValue = std::pair<HostFence*, uint64_t>; - absl::InlinedVector<HostFenceValue, 4> waitable_fences; - waitable_fences.reserve(fences.size()); - for (auto& fence_value : fences) { - auto* fence = reinterpret_cast<HostFence*>(fence_value.first); - ASSIGN_OR_RETURN(uint64_t current_value, fence->QueryValue()); - if (current_value == UINT64_MAX) { - // Fence has failed. Return the error. - return fence->status(); - } else if (current_value < fence_value.second) { - // Fence has not yet hit the required value; wait for it. - waitable_fences.push_back({fence, fence_value.second}); - } - } - - // TODO(benvanik): maybe sort fences by value in case we are waiting on - // multiple values from the same fence. - - // Loop over the fences and wait for them to complete. - // TODO(b/140026716): add WaitHandle support for !wait_all (wait any). - for (auto& fence_value : waitable_fences) { - auto* fence = fence_value.first; - absl::MutexLock lock(&fence->mutex_); - if (!fence->mutex_.AwaitWithDeadline( - absl::Condition( - +[](HostFenceValue* fence_value) { - return fence_value->first->value_.load( - std::memory_order_acquire) >= fence_value->second; - }, - &fence_value), - deadline)) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for fences"; - } - if (!fence->status_.ok()) { - return fence->status_; - } - } - - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_fence.h b/iree/hal/host/host_fence.h deleted file mode 100644 index 80792e7..0000000 --- a/iree/hal/host/host_fence.h +++ /dev/null
@@ -1,64 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_HOST_FENCE_H_ -#define IREE_HAL_HOST_HOST_FENCE_H_ - -#include <atomic> -#include <cstdint> - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/fence.h" - -namespace iree { -namespace hal { - -// TODO(b/140026716): add WaitHandle support for better multi-wait. -// Simple host-only fence semaphore implemented with a mutex. -// -// Thread-safe (as instances may be imported and used by others). -class HostFence final : public Fence { - public: - // Waits for one or more (or all) fences to reach or exceed the given values. - static Status WaitForFences(absl::Span<const FenceValue> fences, - bool wait_all, absl::Time deadline); - - explicit HostFence(uint64_t initial_value); - ~HostFence() override; - - Status status() const override; - StatusOr<uint64_t> QueryValue() override; - - Status Signal(uint64_t value); - Status Fail(Status status); - - private: - // The mutex is not required to query the value; this lets us quickly check if - // a required value has been exceeded. The mutex is only used to update and - // notify waiters. - std::atomic<uint64_t> value_{0}; - - // We have a full mutex here so that we can perform condvar waits on value - // changes. - mutable absl::Mutex mutex_; - Status status_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_HOST_FENCE_H_
diff --git a/iree/hal/host/host_fence_test.cc b/iree/hal/host/host_fence_test.cc deleted file mode 100644 index a843923..0000000 --- a/iree/hal/host/host_fence_test.cc +++ /dev/null
@@ -1,148 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_fence.h" - -#include <cstdint> -#include <thread> // NOLINT - -#include "absl/time/time.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" - -namespace iree { -namespace hal { -namespace { - -// Tests that a fence that is unused properly cleans itself up. -TEST(HostFenceTest, NoOp) { - HostFence fence(123u); - EXPECT_TRUE(fence.status().ok()); - ASSERT_OK_AND_ASSIGN(uint64_t value, fence.QueryValue()); - EXPECT_EQ(123u, value); -} - -// Tests that a fence will accept new values as it is signaled. -TEST(HostFenceTest, NormalSignaling) { - HostFence fence(2u); - EXPECT_EQ(2u, fence.QueryValue().ValueOrDie()); - EXPECT_OK(fence.Signal(3u)); - EXPECT_EQ(3u, fence.QueryValue().ValueOrDie()); - EXPECT_OK(fence.Signal(40u)); - EXPECT_EQ(40u, fence.QueryValue().ValueOrDie()); -} - -// Tests that a fence will fail to set non-increasing values. -TEST(HostFenceTest, RequireIncreasingValues) { - HostFence fence(2u); - EXPECT_EQ(2u, fence.QueryValue().ValueOrDie()); - // Same value. - EXPECT_TRUE(IsInvalidArgument(fence.Signal(2u))); - // Decreasing. - EXPECT_TRUE(IsInvalidArgument(fence.Signal(1u))); -} - -// Tests that a fence that has failed will remain in a failed state. -TEST(HostFenceTest, StickyFailure) { - HostFence fence(2u); - // Signal to 3. - EXPECT_OK(fence.Signal(3u)); - EXPECT_TRUE(fence.status().ok()); - EXPECT_EQ(3u, fence.QueryValue().ValueOrDie()); - - // Fail now. - EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC))); - EXPECT_TRUE(IsUnknown(fence.status())); - EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie()); - - // Unable to signal again (it'll return the sticky failure). - EXPECT_TRUE(IsUnknown(fence.Signal(4u))); - EXPECT_TRUE(IsUnknown(fence.status())); - EXPECT_EQ(UINT64_MAX, fence.QueryValue().ValueOrDie()); -} - -// Tests waiting on no fences. -TEST(HostFenceTest, EmptyWait) { - EXPECT_OK( - HostFence::WaitForFences({}, /*wait_all=*/true, absl::InfiniteFuture())); -} - -// Tests waiting on a fence that has already been signaled. -TEST(HostFenceTest, WaitAlreadySignaled) { - HostFence fence(2u); - // Test both previous and current values. - EXPECT_OK(HostFence::WaitForFences({{&fence, 1u}}, /*wait_all=*/true, - absl::InfiniteFuture())); - EXPECT_OK(HostFence::WaitForFences({{&fence, 2u}}, /*wait_all=*/true, - absl::InfiniteFuture())); -} - -// Tests waiting on a fence that has not been signaled. -TEST(HostFenceTest, WaitUnsignaled) { - HostFence fence(2u); - // NOTE: we don't actually block here because otherwise we'd lock up. - EXPECT_TRUE(IsDeadlineExceeded(HostFence::WaitForFences( - {{&fence, 3u}}, /*wait_all=*/true, absl::InfinitePast()))); -} - -// Tests waiting on a failed fence (it should return the error on the fence). -TEST(HostFenceTest, WaitAlreadyFailed) { - HostFence fence(2u); - EXPECT_OK(fence.Fail(UnknownErrorBuilder(IREE_LOC))); - EXPECT_TRUE(IsUnknown(HostFence::WaitForFences( - {{&fence, 2u}}, /*wait_all=*/true, absl::InfinitePast()))); -} - -// Tests threading behavior by ping-ponging between the test main thread and -// a little thread. -TEST(HostFenceTest, PingPong) { - HostFence a2b(0u); - HostFence b2a(0u); - std::thread thread([&]() { - // Should advance right past this because the value is already set. - ASSERT_OK(HostFence::WaitForFences({{&a2b, 0u}}, /*wait_all=*/true, - absl::InfiniteFuture())); - ASSERT_OK(b2a.Signal(1u)); - // Jump ahead. - ASSERT_OK(HostFence::WaitForFences({{&a2b, 4u}}, /*wait_all=*/true, - absl::InfiniteFuture())); - }); - ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true, - absl::InfiniteFuture())); - ASSERT_OK(a2b.Signal(4u)); - thread.join(); -} - -// Tests that failure still wakes waiters and propagates the error. -TEST(HostFenceTest, FailNotifies) { - HostFence a2b(0u); - HostFence b2a(0u); - bool got_failure = false; - std::thread thread([&]() { - ASSERT_OK(b2a.Signal(1u)); - got_failure = IsUnknown(HostFence::WaitForFences( - {{&a2b, 1u}}, /*wait_all=*/true, absl::InfiniteFuture())); - }); - ASSERT_OK(HostFence::WaitForFences({{&b2a, 1u}}, /*wait_all=*/true, - absl::InfiniteFuture())); - ASSERT_OK(a2b.Fail(UnknownErrorBuilder(IREE_LOC))); - thread.join(); - ASSERT_TRUE(got_failure); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_local_allocator.cc b/iree/hal/host/host_local_allocator.cc deleted file mode 100644 index 10b898d..0000000 --- a/iree/hal/host/host_local_allocator.cc +++ /dev/null
@@ -1,111 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_local_allocator.h" - -#include <cstdlib> -#include <string> -#include <utility> - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/host/host_buffer.h" - -namespace iree { -namespace hal { - -HostLocalAllocator::HostLocalAllocator() = default; - -HostLocalAllocator::~HostLocalAllocator() = default; - -bool HostLocalAllocator::CanUseBufferLike( - Allocator* source_allocator, MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const { - // Must always have visibility to the device, which ensures we can test - // against the host but have things work on devices with separate address - // spaces. - if (!AnyBitSet(memory_type & MemoryType::kDeviceVisible)) { - return false; - } - - // kHostVisible is required for mapping. - if (AnyBitSet(intended_usage & BufferUsage::kMapping) && - !AnyBitSet(memory_type & MemoryType::kHostVisible)) { - return false; - } - - // Dispatch needs to be specified if we intend to dispatch. - if (AnyBitSet(intended_usage & BufferUsage::kDispatch) && - !AnyBitSet(buffer_usage & BufferUsage::kDispatch)) { - return false; - } - - return true; -} - -bool HostLocalAllocator::CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const { - // Host allows everything, pretty much, so long as it is device-visible (as - // the host is the device here). - return AnyBitSet(memory_type & MemoryType::kDeviceVisible); -} - -Status HostLocalAllocator::MakeCompatible( - MemoryTypeBitfield* memory_type, BufferUsageBitfield* buffer_usage) const { - // Always ensure we are host-visible. - *memory_type |= MemoryType::kHostVisible; - - // Host currently uses mapping to copy buffers, which is done a lot. - // We could probably remove this restriction somehow. - *buffer_usage |= BufferUsage::kMapping; - - // TODO(b/111372612): tensorflow needs transfer too, but shouldn't. - *buffer_usage |= BufferUsage::kTransfer; - - return OkStatus(); -} - -StatusOr<ref_ptr<Buffer>> HostLocalAllocator::Allocate( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("HostLocalAllocator::Allocate"); - - if (!CanAllocate(memory_type, buffer_usage, allocation_size)) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Allocation not supported; memory_type=" - << MemoryTypeString(memory_type) - << ", buffer_usage=" << BufferUsageString(buffer_usage) - << ", allocation_size=" << allocation_size; - } - - // Make compatible with our requirements. - RETURN_IF_ERROR(MakeCompatible(&memory_type, &buffer_usage)); - - void* malloced_data = std::calloc(1, allocation_size); - if (!malloced_data) { - return ResourceExhaustedErrorBuilder(IREE_LOC) - << "Failed to malloc " << allocation_size << " bytes"; - } - - auto buffer = - make_ref<HostBuffer>(this, memory_type, MemoryAccess::kAll, buffer_usage, - allocation_size, malloced_data, true); - return buffer; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_local_allocator.h b/iree/hal/host/host_local_allocator.h deleted file mode 100644 index 020fbea..0000000 --- a/iree/hal/host/host_local_allocator.h +++ /dev/null
@@ -1,60 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_LOCAL_ALLOCATOR_H_ -#define IREE_HAL_HOST_LOCAL_ALLOCATOR_H_ - -#include <cstddef> -#include <memory> - -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer.h" - -namespace iree { -namespace hal { - -// An allocator implementation that allocates buffers from host memory. -// This can be used for drivers that do not have a memory space of their own. -// -// Buffers allocated will have be MemoryType::kHostLocal | kDeviceVisible as -// the 'device' in the case of a host-local queue *is* the host. To keep code -// written initially for a host-local queue working when other queues are used -// the allocator only works with buffers that are kDeviceVisible. -class HostLocalAllocator : public Allocator { - public: - HostLocalAllocator(); - ~HostLocalAllocator() override; - - bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const override; - - bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const override; - - Status MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const override; - - StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) override; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_LOCAL_ALLOCATOR_H_
diff --git a/iree/hal/host/host_local_command_processor.cc b/iree/hal/host/host_local_command_processor.cc deleted file mode 100644 index a2c94ec..0000000 --- a/iree/hal/host/host_local_command_processor.cc +++ /dev/null
@@ -1,120 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_local_command_processor.h" - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -HostLocalCommandProcessor::HostLocalCommandProcessor( - Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : CommandBuffer(allocator, mode, command_categories) {} - -HostLocalCommandProcessor::~HostLocalCommandProcessor() = default; - -Status HostLocalCommandProcessor::Begin() { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::Begin"); - is_recording_ = true; - return OkStatus(); -} - -Status HostLocalCommandProcessor::End() { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::End"); - is_recording_ = false; - return OkStatus(); -} - -Status HostLocalCommandProcessor::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ExecutionBarrier"); - // No-op. - return OkStatus(); -} - -Status HostLocalCommandProcessor::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::SignalEvent"); - // No-op. - return OkStatus(); -} - -Status HostLocalCommandProcessor::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::ResetEvent"); - // No-op. - return OkStatus(); -} - -Status HostLocalCommandProcessor::WaitEvents( - absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::WaitEvents"); - // No-op. - return OkStatus(); -} - -Status HostLocalCommandProcessor::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::FillBuffer"); - return target_buffer->Fill(target_offset, length, pattern, pattern_length); -} - -Status HostLocalCommandProcessor::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::DiscardBuffer"); - // No-op as we don't support lazily allocated buffers. - return OkStatus(); -} - -Status HostLocalCommandProcessor::UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::UpdateBuffer"); - return target_buffer->WriteData( - target_offset, static_cast<const uint8_t*>(source_buffer) + source_offset, - length); -} - -Status HostLocalCommandProcessor::CopyBuffer(Buffer* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("HostLocalCommandProcessor::CopyBuffer"); - return target_buffer->CopyData(target_offset, source_buffer, source_offset, - length); -} - -Status HostLocalCommandProcessor::Dispatch( - const DispatchRequest& dispatch_request) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Command processor does not support dispatch operations"; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_local_command_processor.h b/iree/hal/host/host_local_command_processor.h deleted file mode 100644 index f60d4c2..0000000 --- a/iree/hal/host/host_local_command_processor.h +++ /dev/null
@@ -1,85 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_ -#define IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_ - -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { - -// Host-local command processor for dispatching transfer operations against -// buffers allocated from the HostLocalAllocator. -// This assumes that all buffers are host-visible (if not local) and that all -// buffers can be mapped for access. -// -// Subclasses may implement Dispatch, otherwise the default implementation just -// returns failure. -// -// Thread-compatible (as with CommandBuffer itself). -class HostLocalCommandProcessor : public CommandBuffer { - public: - HostLocalCommandProcessor(Allocator* allocator, - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories); - ~HostLocalCommandProcessor() override; - - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status WaitEvents(absl::Span<Event*> events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - - Status DiscardBuffer(Buffer* buffer) override; - - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status Dispatch(const DispatchRequest& dispatch_request) override; - - private: - bool is_recording_ = false; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_HOST_LOCAL_COMMAND_PROCESSOR_H_
diff --git a/iree/hal/host/host_submission_queue.cc b/iree/hal/host/host_submission_queue.cc deleted file mode 100644 index a86d442..0000000 --- a/iree/hal/host/host_submission_queue.cc +++ /dev/null
@@ -1,295 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_submission_queue.h" - -#include <atomic> -#include <cstdint> - -#include "absl/synchronization/mutex.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -HostBinarySemaphore::HostBinarySemaphore(bool initial_value) { - State state = {0}; - state.signaled = initial_value ? 1 : 0; - state_ = state; -} - -bool HostBinarySemaphore::is_signaled() const { - return state_.load(std::memory_order_acquire).signaled == 1; -} - -Status HostBinarySemaphore::BeginSignaling() { - State old_state = state_.load(std::memory_order_acquire); - if (old_state.signal_pending != 0) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "A signal operation on a binary semaphore is already pending"; - } - State new_state = old_state; - new_state.signal_pending = 1; - state_.compare_exchange_strong(old_state, new_state); - return OkStatus(); -} - -Status HostBinarySemaphore::EndSignaling() { - State old_state = state_.load(std::memory_order_acquire); - DCHECK_EQ(old_state.signal_pending, 1) - << "A signal operation on a binary semaphore was not pending"; - if (old_state.signaled != 0) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "A binary semaphore cannot be signaled multiple times"; - } - State new_state = old_state; - new_state.signal_pending = 0; - new_state.signaled = 1; - state_.compare_exchange_strong(old_state, new_state); - return OkStatus(); -} - -Status HostBinarySemaphore::BeginWaiting() { - State old_state = state_.load(std::memory_order_acquire); - if (old_state.wait_pending != 0) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "A wait operation on a binary semaphore is already pending"; - } - State new_state = old_state; - new_state.wait_pending = 1; - state_.compare_exchange_strong(old_state, new_state); - return OkStatus(); -} - -Status HostBinarySemaphore::EndWaiting() { - State old_state = state_.load(std::memory_order_acquire); - DCHECK_EQ(old_state.wait_pending, 1) - << "A wait operation on a binary semaphore was not pending"; - if (old_state.signaled != 1) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "A binary semaphore cannot be reset multiple times"; - } - State new_state = old_state; - new_state.wait_pending = 0; - new_state.signaled = 0; - state_.compare_exchange_strong(old_state, new_state); - return OkStatus(); -} - -HostSubmissionQueue::HostSubmissionQueue() = default; - -HostSubmissionQueue::~HostSubmissionQueue() = default; - -bool HostSubmissionQueue::IsBatchReady(const PendingBatch& batch) const { - for (auto& wait_point : batch.wait_semaphores) { - if (wait_point.index() == 0) { - auto* binary_semaphore = - reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(wait_point)); - if (!binary_semaphore->is_signaled()) { - return false; - } - } else { - // TODO(b/140141417): implement timeline semaphores. - return false; - } - } - return true; -} - -Status HostSubmissionQueue::Enqueue(absl::Span<const SubmissionBatch> batches, - FenceValue fence) { - IREE_TRACE_SCOPE0("HostSubmissionQueue::Enqueue"); - - if (has_shutdown_) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Cannot enqueue new submissions; queue is exiting"; - } else if (!permanent_error_.ok()) { - return permanent_error_; - } - - // Verify waiting/signaling behavior on semaphores and prepare them all. - // We need to track this to ensure that we are modeling the Vulkan behavior - // and are consistent across HAL implementations. - for (auto& batch : batches) { - for (auto& semaphore_value : batch.wait_semaphores) { - if (semaphore_value.index() == 0) { - auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>( - absl::get<0>(semaphore_value)); - RETURN_IF_ERROR(binary_semaphore->BeginWaiting()); - } else { - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; - } - } - for (auto& semaphore_value : batch.signal_semaphores) { - if (semaphore_value.index() == 0) { - auto* binary_semaphore = reinterpret_cast<HostBinarySemaphore*>( - absl::get<0>(semaphore_value)); - RETURN_IF_ERROR(binary_semaphore->BeginSignaling()); - } else { - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; - } - } - } - - // Add to list - order does not matter as Process evaluates semaphores. - auto submission = absl::make_unique<Submission>(); - submission->fence = std::move(fence); - submission->pending_batches.resize(batches.size()); - for (int i = 0; i < batches.size(); ++i) { - submission->pending_batches[i] = PendingBatch{ - {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()}, - {batches[i].command_buffers.begin(), batches[i].command_buffers.end()}, - {batches[i].signal_semaphores.begin(), - batches[i].signal_semaphores.end()}, - }; - } - list_.push_back(std::move(submission)); - - return OkStatus(); -} - -Status HostSubmissionQueue::ProcessBatches(ExecuteFn execute_fn) { - IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatches"); - - if (!permanent_error_.ok()) { - // Sticky failure state. - return permanent_error_; - } - - // Repeated try to run things until we quiesce or are blocked. - while (permanent_error_.ok() && !list_.empty()) { - // NOTE: to support re-entrancy where |execute_fn| may modify the submission - // list we need to always start from the beginning. If we wanted we could - // track a list of ready submissions however that's a lot of bookkeeping and - // the list is usually short. - bool restart_iteration = false; - for (auto* submission : list_) { - for (int i = 0; i < submission->pending_batches.size(); ++i) { - auto& batch = submission->pending_batches[i]; - if (!IsBatchReady(batch)) { - // Try the next batch in the submission until we find one that is - // ready. If none are ready we'll return to the caller. - continue; - } - - // Batch can run! Process now and remove it from the list so we don't - // try to run it again. - auto batch_status = ProcessBatch(batch, execute_fn); - submission->pending_batches.erase(submission->pending_batches.begin() + - i); - if (batch_status.ok()) { - // Batch succeeded. Since we want to preserve submission order we'll - // break out of the loop and try from the first submission again. - if (submission->pending_batches.empty()) { - // All work for this submission completed successfully. Signal the - // fence and remove the submission from the list. - RETURN_IF_ERROR(CompleteSubmission(submission, OkStatus())); - list_.take(submission).reset(); - } - } else { - // Batch failed; set the permanent error flag and abort so we don't - // try to process anything else. - permanent_error_ = batch_status; - RETURN_IF_ERROR(CompleteSubmission(submission, batch_status)); - list_.take(submission).reset(); - } - restart_iteration = true; - break; - } - if (restart_iteration) break; - } - } - - if (!permanent_error_.ok()) { - // If the sticky error got set while processing we need to abort all - // remaining submissions (simulating a device loss). - FailAllPending(permanent_error_); - return permanent_error_; - } - - return OkStatus(); -} - -Status HostSubmissionQueue::ProcessBatch(const PendingBatch& batch, - const ExecuteFn& execute_fn) { - IREE_TRACE_SCOPE0("HostSubmissionQueue::ProcessBatch"); - - // Complete the waits on all semaphores and reset them. - for (auto& semaphore_value : batch.wait_semaphores) { - if (semaphore_value.index() == 0) { - auto* binary_semaphore = - reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value)); - RETURN_IF_ERROR(binary_semaphore->EndWaiting()); - } else { - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; - } - } - - // Let the caller handle execution of the command buffers. - RETURN_IF_ERROR(execute_fn(batch.command_buffers)); - - // Signal all semaphores to allow them to unblock waiters. - for (auto& semaphore_value : batch.signal_semaphores) { - if (semaphore_value.index() == 0) { - auto* binary_semaphore = - reinterpret_cast<HostBinarySemaphore*>(absl::get<0>(semaphore_value)); - RETURN_IF_ERROR(binary_semaphore->EndSignaling()); - } else { - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) << "Timeline semaphores NYI"; - } - } - - return OkStatus(); -} - -Status HostSubmissionQueue::CompleteSubmission(Submission* submission, - Status status) { - IREE_TRACE_SCOPE0("HostSubmissionQueue::CompleteSubmission"); - - // It's safe to drop any remaining batches - their semaphores will never be - // signaled but that's fine as we should be the only thing relying on them. - submission->pending_batches.clear(); - - // Signal the fence. - auto* fence = static_cast<HostFence*>(submission->fence.first); - if (status.ok()) { - RETURN_IF_ERROR(fence->Signal(submission->fence.second)); - } else { - RETURN_IF_ERROR(fence->Fail(std::move(status))); - } - - return OkStatus(); -} - -void HostSubmissionQueue::FailAllPending(Status status) { - IREE_TRACE_SCOPE0("HostSubmissionQueue::FailAllPending"); - while (!list_.empty()) { - auto submission = list_.take(list_.front()); - CompleteSubmission(submission.get(), status).IgnoreError(); - submission.reset(); - } -} - -void HostSubmissionQueue::SignalShutdown() { - IREE_TRACE_SCOPE0("HostSubmissionQueue::SignalShutdown"); - has_shutdown_ = true; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/host_submission_queue.h b/iree/hal/host/host_submission_queue.h deleted file mode 100644 index 14eae7a..0000000 --- a/iree/hal/host/host_submission_queue.h +++ /dev/null
@@ -1,163 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_ -#define IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_ - -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/status.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/host/host_fence.h" -#include "iree/hal/semaphore.h" - -namespace iree { -namespace hal { - -class HostSubmissionQueue; - -// Simple host-only binary semaphore implemented with a mutex. -// To match the expected HAL behavior (mostly dictated by Vulkan) we can only -// have a single waiter and waits can only occur once a signal has been -// enqueued. -// -// Thread-safe (as instances may be imported and used by others). -class HostBinarySemaphore final : public BinarySemaphore { - public: - explicit HostBinarySemaphore(bool initial_value); - - // Returns true if the semaphore has been signaled. - bool is_signaled() const; - - private: - friend class HostSubmissionQueue; - - // Begins a signal operation and ensures no other signal operation is pending. - Status BeginSignaling(); - // Ends a signal operation by setting the semaphore to the signaled state. - Status EndSignaling(); - - // Begins a wait operation and ensures no other wait operation is pending. - Status BeginWaiting(); - // Ends a wait operation by resetting the semaphore to the unsignaled state. - Status EndWaiting(); - - // A single 32-bit int for lock-free semaphore behavior. We need to do this - // extra tracking so that we get consistent behavior across HAL - // implementations that have strict semaphore semantics. - struct State { - uint32_t signal_pending : 1; - uint32_t wait_pending : 1; - uint32_t signaled : 1; - }; - std::atomic<State> state_{{0, 0, 0}}; -}; - -// Simple host-only timeline semaphore implemented with a mutex. -// -// Thread-safe (as instances may be imported and used by others). -class HostTimelineSemaphore final : public TimelineSemaphore { - public: - // TODO(b/140141417): implement timeline semaphores. -}; - -// A queue managing CommandQueue submissions that uses host-local -// synchronization primitives. Evaluates submission order by respecting the -// wait and signal semaphores defined per batch and notifies fences upon -// submission completion. -// -// Note that it's possible for HAL users to deadlock themselves; we don't try to -// avoid that as in device backends it may not be possible and we want to have -// some kind of warning in the host implementation that TSAN can catch. -// -// Thread-compatible. Const methods may be called from any thread. -class HostSubmissionQueue { - public: - using ExecuteFn = - std::function<Status(absl::Span<CommandBuffer* const> command_buffers)>; - - HostSubmissionQueue(); - ~HostSubmissionQueue(); - - // Returns true if the queue is currently empty. - bool empty() const { return list_.empty(); } - // Returns true if SignalShutdown has been called. - bool has_shutdown() const { return has_shutdown_; } - // The sticky error status, if an error has occurred. - Status permanent_error() const { return permanent_error_; } - - // Enqueues a new submission. - // No work will be performed until Process is called. - Status Enqueue(absl::Span<const SubmissionBatch> batches, FenceValue fence); - - // Processes all ready batches using the provided |execute_fn|. - // The function may be called several times if new batches become ready due to - // prior batches in the sequence completing during processing. - // - // Returns any errors returned by |execute_fn| (which will be the same as - // permanent_error()). When an error occurs all in-flight submissions are - // aborted, the permanent_error() is set, and the queue is shutdown. - Status ProcessBatches(ExecuteFn execute_fn); - - // Marks the queue as having shutdown. All pending submissions will be allowed - // to complete but future enqueues will fail. - void SignalShutdown(); - - private: - // A submitted command buffer batch and its synchronization information. - struct PendingBatch { - absl::InlinedVector<SemaphoreValue, 4> wait_semaphores; - absl::InlinedVector<CommandBuffer*, 4> command_buffers; - absl::InlinedVector<SemaphoreValue, 4> signal_semaphores; - }; - struct Submission : public IntrusiveLinkBase<void> { - absl::InlinedVector<PendingBatch, 4> pending_batches; - FenceValue fence; - }; - - // Returns true if all wait semaphores in the |batch| are signaled. - bool IsBatchReady(const PendingBatch& batch) const; - - // Processes a batch by resetting semaphores, dispatching the command buffers - // to the specified |execute_fn|, and signaling semaphores. - // - // Preconditions: IsBatchReady(batch) == true - Status ProcessBatch(const PendingBatch& batch, const ExecuteFn& execute_fn); - - // Completes a submission by signaling the fence with the given |status|. - Status CompleteSubmission(Submission* submission, Status status); - - // Fails all pending submissions with the given status. - // Errors that occur during this process are silently ignored. - void FailAllPending(Status status); - - // True to exit the thread after all submissions complete. - bool has_shutdown_ = false; - - // A sticky error that is set on the first failed submit. All future - // submissions will be skipped except for fences, which will receive this - // error. - Status permanent_error_; - - // Pending submissions in submission order. - // Note that we may evaluate batches within the list out of order. - IntrusiveList<std::unique_ptr<Submission>> list_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_HOST_SUBMISSION_QUEUE_H_
diff --git a/iree/hal/host/host_submission_queue_test.cc b/iree/hal/host/host_submission_queue_test.cc deleted file mode 100644 index 227131a..0000000 --- a/iree/hal/host/host_submission_queue_test.cc +++ /dev/null
@@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/host_submission_queue.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace iree { -namespace hal { -namespace { - -TEST(HostSubmissionQueueTest, TBD) { - // TODO(benvanik): test! -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/inproc_command_buffer.cc b/iree/hal/host/inproc_command_buffer.cc deleted file mode 100644 index 0833843..0000000 --- a/iree/hal/host/inproc_command_buffer.cc +++ /dev/null
@@ -1,264 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/host/inproc_command_buffer.h" - -#include "iree/base/tracing.h" - -namespace iree { -namespace hal { - -InProcCommandBuffer::InProcCommandBuffer( - Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : CommandBuffer(allocator, mode, command_categories) {} - -InProcCommandBuffer::~InProcCommandBuffer() { Reset(); } - -Status InProcCommandBuffer::Begin() { - IREE_TRACE_SCOPE0("InProcCommandBuffer::Begin"); - is_recording_ = true; - Reset(); - return OkStatus(); -} - -Status InProcCommandBuffer::End() { - IREE_TRACE_SCOPE0("InProcCommandBuffer::End"); - is_recording_ = false; - return OkStatus(); -} - -Status InProcCommandBuffer::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::ExecutionBarrier"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<ExecutionBarrierCmd>()); - cmd->source_stage_mask = source_stage_mask; - cmd->target_stage_mask = target_stage_mask; - cmd->memory_barriers = AppendStructSpan(memory_barriers); - cmd->buffer_barriers = AppendStructSpan(buffer_barriers); - return OkStatus(); -} - -Status InProcCommandBuffer::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::SignalEvent"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<SignalEventCmd>()); - cmd->event = event; - cmd->source_stage_mask = source_stage_mask; - return OkStatus(); -} - -Status InProcCommandBuffer::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::ResetEvent"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<ResetEventCmd>()); - cmd->event = event; - cmd->source_stage_mask = source_stage_mask; - return OkStatus(); -} - -Status InProcCommandBuffer::WaitEvents( - absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::WaitEvents"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<WaitEventsCmd>()); - cmd->events = AppendStructSpan(events); - cmd->source_stage_mask = source_stage_mask; - cmd->target_stage_mask = target_stage_mask; - cmd->memory_barriers = AppendStructSpan(memory_barriers); - cmd->buffer_barriers = AppendStructSpan(buffer_barriers); - return OkStatus(); -} - -Status InProcCommandBuffer::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::FillBuffer"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<FillBufferCmd>()); - cmd->target_buffer = target_buffer; - cmd->target_offset = target_offset; - cmd->length = length; - std::memcpy(cmd->pattern, pattern, pattern_length); - cmd->pattern_length = pattern_length; - return OkStatus(); -} - -Status InProcCommandBuffer::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::DiscardBuffer"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DiscardBufferCmd>()); - cmd->buffer = buffer; - return OkStatus(); -} - -Status InProcCommandBuffer::UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::UpdateBuffer"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<UpdateBufferCmd>()); - cmd->source_buffer = AppendCmdData(source_buffer, source_offset, length); - cmd->target_buffer = target_buffer; - cmd->target_offset = target_offset; - cmd->length = length; - return OkStatus(); -} - -Status InProcCommandBuffer::CopyBuffer(Buffer* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::CopyBuffer"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<CopyBufferCmd>()); - cmd->source_buffer = source_buffer; - cmd->source_offset = source_offset; - cmd->target_buffer = target_buffer; - cmd->target_offset = target_offset; - cmd->length = length; - return OkStatus(); -} - -Status InProcCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) { - IREE_TRACE_SCOPE0("InProcCommandBuffer::Dispatch"); - ASSIGN_OR_RETURN(auto* cmd, AppendCmd<DispatchCmd>()); - cmd->request.executable = dispatch_request.executable; - cmd->request.entry_point = dispatch_request.entry_point; - cmd->request.workload = dispatch_request.workload; - cmd->request.workload_buffer = dispatch_request.workload_buffer; - cmd->request.bindings = AppendStructSpan(dispatch_request.bindings); - return OkStatus(); -} - -void InProcCommandBuffer::Reset() { - auto* cmd_list = ¤t_cmd_list_; - cmd_list->head = cmd_list->tail = nullptr; - cmd_list->arena.Reset(); -} - -InProcCommandBuffer::CmdHeader* InProcCommandBuffer::AppendCmdHeader( - CmdType type, size_t cmd_size) { - auto* cmd_list = ¤t_cmd_list_; - auto* cmd_header = reinterpret_cast<CmdHeader*>( - cmd_list->arena.AllocateBytes(sizeof(CmdHeader) + cmd_size)); - cmd_header->next = nullptr; - cmd_header->type = type; - if (!cmd_list->head) { - cmd_list->head = cmd_header; - } else if (cmd_list->tail) { - cmd_list->tail->next = cmd_header; - } - cmd_list->tail = cmd_header; - return cmd_header; -} - -void* InProcCommandBuffer::AppendCmdData(const void* source_buffer, - device_size_t source_offset, - device_size_t source_length) { - auto* cmd_list = ¤t_cmd_list_; - - uint8_t* allocated_bytes = cmd_list->arena.AllocateBytes(source_length); - std::memcpy(allocated_bytes, - static_cast<const uint8_t*>(source_buffer) + source_offset, - source_length); - return allocated_bytes; -} - -Status InProcCommandBuffer::Process(CommandBuffer* command_processor) const { - IREE_TRACE_SCOPE0("InProcCommandBuffer::Process"); - - RETURN_IF_ERROR(command_processor->Begin()); - - // Process each command in the order they were recorded. - auto* cmd_list = ¤t_cmd_list_; - for (CmdHeader* cmd_header = cmd_list->head; cmd_header != nullptr; - cmd_header = cmd_header->next) { - auto command_status = ProcessCmd(cmd_header, command_processor); - if (!command_status.ok()) { - LOG(ERROR) << "DeviceQueue failure while executing command; permanently " - "failing all future commands: " - << command_status; - } - } - - RETURN_IF_ERROR(command_processor->End()); - - return OkStatus(); -} - -Status InProcCommandBuffer::ProcessCmd(CmdHeader* cmd_header, - CommandBuffer* command_processor) const { - switch (cmd_header->type) { - case CmdType::kExecutionBarrier: { - auto* cmd = reinterpret_cast<ExecutionBarrierCmd*>(cmd_header + 1); - return command_processor->ExecutionBarrier( - cmd->source_stage_mask, cmd->target_stage_mask, cmd->memory_barriers, - cmd->buffer_barriers); - } - case CmdType::kSignalEvent: { - auto* cmd = reinterpret_cast<SignalEventCmd*>(cmd_header + 1); - return command_processor->SignalEvent(cmd->event, cmd->source_stage_mask); - } - case CmdType::kResetEvent: { - auto* cmd = reinterpret_cast<ResetEventCmd*>(cmd_header + 1); - return command_processor->ResetEvent(cmd->event, cmd->source_stage_mask); - } - case CmdType::kWaitEvents: { - auto* cmd = reinterpret_cast<WaitEventsCmd*>(cmd_header + 1); - return command_processor->WaitEvents( - cmd->events, cmd->source_stage_mask, cmd->target_stage_mask, - cmd->memory_barriers, cmd->buffer_barriers); - } - case CmdType::kFillBuffer: { - auto* cmd = reinterpret_cast<FillBufferCmd*>(cmd_header + 1); - return command_processor->FillBuffer(cmd->target_buffer, - cmd->target_offset, cmd->length, - cmd->pattern, cmd->pattern_length); - } - case CmdType::kDiscardBuffer: { - auto* cmd = reinterpret_cast<DiscardBufferCmd*>(cmd_header + 1); - return command_processor->DiscardBuffer(cmd->buffer); - } - case CmdType::kUpdateBuffer: { - auto* cmd = reinterpret_cast<UpdateBufferCmd*>(cmd_header + 1); - return command_processor->UpdateBuffer(cmd->source_buffer, 0, - cmd->target_buffer, - cmd->target_offset, cmd->length); - } - case CmdType::kCopyBuffer: { - auto* cmd = reinterpret_cast<CopyBufferCmd*>(cmd_header + 1); - return command_processor->CopyBuffer( - cmd->source_buffer, cmd->source_offset, cmd->target_buffer, - cmd->target_offset, cmd->length); - } - case CmdType::kDispatch: { - auto* cmd = reinterpret_cast<DispatchCmd*>(cmd_header + 1); - return command_processor->Dispatch(cmd->request); - } - default: - return DataLossErrorBuilder(IREE_LOC) - << "Unrecognized command type " - << static_cast<int>(cmd_header->type) << "; corrupt buffer?"; - } -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/host/inproc_command_buffer.h b/iree/hal/host/inproc_command_buffer.h deleted file mode 100644 index a9d2bf6..0000000 --- a/iree/hal/host/inproc_command_buffer.h +++ /dev/null
@@ -1,241 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_ -#define IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_ - -#include "iree/base/arena.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/status.h" -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { - -// In-process command buffer with support for recording and playback. -// Commands are recorded into heap-allocated arenas with pointers to used -// resources (Buffer*, etc). To replay a command buffer against a real -// implementation use Process to call each command method as it was originally -// recorded. -// -// Thread-compatible (as with CommandBuffer itself). -class InProcCommandBuffer final : public CommandBuffer { - public: - InProcCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories); - ~InProcCommandBuffer() override; - - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - - Status WaitEvents(absl::Span<Event*> events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - - Status DiscardBuffer(Buffer* buffer) override; - - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status Dispatch(const DispatchRequest& dispatch_request) override; - - // Processes all commands in the buffer using the given |command_processor|. - // The commands are issued in the order they were recorded. - Status Process(CommandBuffer* command_processor) const; - - private: - // Type of Cmd, used by CmdHeader to identify the command payload. - enum class CmdType { - kExecutionBarrier, - kSignalEvent, - kResetEvent, - kWaitEvents, - kFillBuffer, - kDiscardBuffer, - kUpdateBuffer, - kCopyBuffer, - kDispatch, - }; - - // Prefix for commands encoded into the CmdList. - // This is used to identify the type of a command as well as connect commands - // in the list sequence. Command data immediately follows the header in - // memory. - struct CmdHeader { - // Optional next command in the list. - CmdHeader* next; - // Type of the command. - CmdType type; - }; - - // A lightweight linked list of commands and an arena that stores them. - // CmdLists are designed to be reused so that the arena allocations are - // amortized across multiple uses. - // - // Note that this and the CmdHeader/Cmd types include raw pointers and as - // such are *not* portable across processes. It'd be possible, though, to - // extend this for cross-process use if a shared-memory Buffer was also - // implemented. For YAGNI we avoid that here. - struct CmdList : public IntrusiveLinkBase<void> { - static constexpr size_t kArenaBlockSize = 64 * 1024; - - Arena arena{kArenaBlockSize}; - CmdHeader* head = nullptr; - CmdHeader* tail = nullptr; - }; - - // Defines an execution barrier. - struct ExecutionBarrierCmd { - static constexpr CmdType kType = CmdType::kExecutionBarrier; - ExecutionStageBitfield source_stage_mask; - ExecutionStageBitfield target_stage_mask; - absl::Span<const MemoryBarrier> memory_barriers; - absl::Span<const BufferBarrier> buffer_barriers; - }; - - // Signals an event. - struct SignalEventCmd { - static constexpr CmdType kType = CmdType::kSignalEvent; - Event* event; - ExecutionStageBitfield source_stage_mask; - }; - - // Resets an event. - struct ResetEventCmd { - static constexpr CmdType kType = CmdType::kResetEvent; - Event* event; - ExecutionStageBitfield source_stage_mask; - }; - - // Waits for one or more events. - struct WaitEventsCmd { - static constexpr CmdType kType = CmdType::kWaitEvents; - absl::Span<Event*> events; - ExecutionStageBitfield source_stage_mask; - ExecutionStageBitfield target_stage_mask; - absl::Span<const MemoryBarrier> memory_barriers; - absl::Span<const BufferBarrier> buffer_barriers; - }; - - // Fills the target buffer with the given repeating value. - struct FillBufferCmd { - static constexpr CmdType kType = CmdType::kFillBuffer; - Buffer* target_buffer; - device_size_t target_offset; - device_size_t length; - uint8_t pattern[4]; - size_t pattern_length; - }; - - // Hints to the device queue that the given buffer will not be used again. - struct DiscardBufferCmd { - static constexpr CmdType kType = CmdType::kDiscardBuffer; - Buffer* buffer; - }; - - // Writes a range of the given target buffer from the embedded memory. - // The source buffer contents immediately follow the command in the arena. - struct UpdateBufferCmd { - static constexpr CmdType kType = CmdType::kUpdateBuffer; - const void* source_buffer; - Buffer* target_buffer; - device_size_t target_offset; - device_size_t length; - }; - - // Copies a range of one buffer to another. - struct CopyBufferCmd { - static constexpr CmdType kType = CmdType::kCopyBuffer; - Buffer* source_buffer; - device_size_t source_offset; - Buffer* target_buffer; - device_size_t target_offset; - device_size_t length; - }; - - // Dispatches an execution request. - struct DispatchCmd { - static constexpr CmdType kType = CmdType::kDispatch; - DispatchRequest request; - }; - - // Resets the command list. - void Reset(); - - // Allocates a command and appends it to the current command list. - // The caller must populate the fields in the returned pointer. - template <typename T> - StatusOr<T*> AppendCmd() { - return reinterpret_cast<T*>(AppendCmdHeader(T::kType, sizeof(T)) + 1); - } - - // Appends a command with the given |type| and payload |cmd_size| prefixed - // with a CmdHeader. Returns a pointer to the CmdHeader that is followed - // immediately by |cmd_size| zero bytes. - CmdHeader* AppendCmdHeader(CmdType type, size_t cmd_size); - - // Appends a byte buffer to the command buffer and returns a pointer to the - // copied data within the command buffer arena. - void* AppendCmdData(const void* source_buffer, device_size_t source_offset, - device_size_t source_length); - - // Appends a span of POD structs to the current CmdList and returns a span - // pointing into the CmdList arena. - template <typename T> - absl::Span<T> AppendStructSpan(absl::Span<T> value) { - static_assert(std::is_standard_layout<T>::value, - "Struct must be a POD type"); - void* data_ptr = AppendCmdData(value.data(), 0, value.size() * sizeof(T)); - return absl::MakeSpan(static_cast<T*>(data_ptr), value.size()); - } - - // Processes a single command. - Status ProcessCmd(CmdHeader* cmd_header, - CommandBuffer* command_processor) const; - - bool is_recording_ = false; - - // NOTE: not synchronized. Expected to be used from a single thread. - CmdList current_cmd_list_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_HOST_INPROC_COMMAND_BUFFER_H_
diff --git a/iree/hal/interpreter/BUILD b/iree/hal/interpreter/BUILD deleted file mode 100644 index 204eb0a..0000000 --- a/iree/hal/interpreter/BUILD +++ /dev/null
@@ -1,190 +0,0 @@ -# HAL implementation running on the CPU using the IREE bytecode. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "bytecode_cache", - srcs = ["bytecode_cache.cc"], - hdrs = ["bytecode_cache.h"], - deps = [ - ":bytecode_executable", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:allocator", - "//iree/hal:executable", - "//iree/hal:executable_cache", - "//iree/hal:executable_format", - "//iree/rt", - ], -) - -cc_library( - name = "bytecode_dispatch", - srcs = [ - "bytecode_dispatch.cc", - "bytecode_dispatch_conversion.h", - "bytecode_dispatch_util.cc", - "bytecode_dispatch_util.h", - ], - hdrs = ["bytecode_dispatch.h"], - deps = [ - ":bytecode_kernels", - "//iree/base:logging", - "//iree/base:memory", - "//iree/base:status", - "//iree/hal:allocator", - "//iree/hal:buffer_view", - "//iree/hal:heap_buffer", - "//iree/rt", - "//iree/schemas/bytecode:interpreter_bytecode_v0", - "//iree/vm:bytecode_module", - "//iree/vm:bytecode_reader", - "//iree/vm:bytecode_tables_interpreter", - "//iree/vm:bytecode_util", - "//iree/vm:opcode_info", - "//iree/vm:type", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "bytecode_executable", - srcs = ["bytecode_executable.cc"], - hdrs = ["bytecode_executable.h"], - deps = [ - ":interpreter_module", - "//iree/base:status", - "//iree/hal:allocator", - "//iree/hal:executable", - "//iree/hal:executable_spec", - "//iree/rt", - "//iree/vm:bytecode_tables_interpreter", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "bytecode_kernels", - hdrs = ["bytecode_kernels.h"], - textual_hdrs = [ - # TODO(benvanik): SIMD variants. - "bytecode_kernels_generic.h", - "bytecode_kernels_ruy.h", - ], - deps = [ - "//iree/base:shape", - "//iree/base:status", - "//iree/hal:buffer_view", - "@com_google_absl//absl/algorithm", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - "@org_tensorflow//tensorflow/lite/experimental/ruy", - "@org_tensorflow//tensorflow/lite/experimental/ruy:context", - ], -) - -cc_test( - name = "bytecode_kernels_test", - srcs = ["bytecode_kernels_test.cc"], - deps = [ - ":bytecode_kernels", - "//iree/base:memory", - "//iree/base:status_matchers", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "interpreter_command_processor", - srcs = ["interpreter_command_processor.cc"], - hdrs = ["interpreter_command_processor.h"], - deps = [ - ":bytecode_executable", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:buffer_view", - "//iree/hal/host:host_local_command_processor", - "//iree/rt", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "interpreter_device", - srcs = ["interpreter_device.cc"], - hdrs = ["interpreter_device.h"], - deps = [ - ":bytecode_cache", - ":bytecode_kernels", - ":interpreter_command_processor", - "//iree/base:memory", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_buffer_validation", - "//iree/hal:command_queue", - "//iree/hal:device", - "//iree/hal:fence", - "//iree/hal/host:async_command_queue", - "//iree/hal/host:host_event", - "//iree/hal/host:host_local_allocator", - "//iree/hal/host:host_submission_queue", - "//iree/hal/host:inproc_command_buffer", - "//iree/rt", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "interpreter_driver", - srcs = ["interpreter_driver.cc"], - hdrs = ["interpreter_driver.h"], - deps = [ - ":interpreter_device", - "//iree/hal:device_info", - "//iree/hal:driver", - ], -) - -cc_library( - name = "interpreter_driver_module", - srcs = ["interpreter_driver_module.cc"], - deps = [ - ":interpreter_driver", - "//iree/base:init", - "//iree/base:status", - "//iree/hal:driver_registry", - ], - alwayslink = 1, -) - -cc_library( - name = "interpreter_module", - srcs = ["interpreter_module.cc"], - hdrs = ["interpreter_module.h"], - deps = [ - ":bytecode_dispatch", - ":bytecode_kernels", - "//iree/base:flatbuffer_util", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:allocator", - "//iree/hal:buffer_view", - "//iree/rt", - "//iree/vm:bytecode_module", - "//iree/vm:bytecode_tables_interpreter", - "@com_google_absl//absl/types:span", - ], -)
diff --git a/iree/hal/interpreter/bytecode_cache.cc b/iree/hal/interpreter/bytecode_cache.cc deleted file mode 100644 index b42db5a..0000000 --- a/iree/hal/interpreter/bytecode_cache.cc +++ /dev/null
@@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/bytecode_cache.h" - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/interpreter/bytecode_executable.h" - -namespace iree { -namespace hal { - -BytecodeCache::BytecodeCache(ref_ptr<rt::Instance> instance, - hal::Allocator* allocator) - : instance_(std::move(instance)), allocator_(allocator) {} - -BytecodeCache::~BytecodeCache() = default; - -bool BytecodeCache::CanPrepareFormat(ExecutableFormat format) const { - return format == kExecutableFormatIreeBytecode; -} - -StatusOr<ref_ptr<Executable>> BytecodeCache::PrepareExecutable( - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) { - IREE_TRACE_SCOPE0("BytecodeCache::PrepareExecutable"); - if (!CanPrepareFormat(spec.format)) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Unsupported format: " << spec.format; - } - - // Wrap the data (or copy it). - bool allow_aliasing_data = - AllBitsSet(mode, ExecutableCachingMode::kAliasProvidedData); - ASSIGN_OR_RETURN(auto executable, - BytecodeExecutable::Load(add_ref(instance_), allocator_, - spec, !allow_aliasing_data)); - - return executable; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_cache.h b/iree/hal/interpreter/bytecode_cache.h deleted file mode 100644 index 4da59ec..0000000 --- a/iree/hal/interpreter/bytecode_cache.h +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_ - -#include "iree/hal/allocator.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" -#include "iree/rt/instance.h" - -namespace iree { -namespace hal { - -class BytecodeCache final : public ExecutableCache { - public: - BytecodeCache(ref_ptr<rt::Instance> instance, hal::Allocator* allocator); - ~BytecodeCache() override; - - bool CanPrepareFormat(ExecutableFormat format) const override; - - StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override; - - private: - ref_ptr<rt::Instance> instance_; - hal::Allocator* allocator_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_CACHE_H_
diff --git a/iree/hal/interpreter/bytecode_dispatch.cc b/iree/hal/interpreter/bytecode_dispatch.cc deleted file mode 100644 index f0e7040..0000000 --- a/iree/hal/interpreter/bytecode_dispatch.cc +++ /dev/null
@@ -1,850 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Implements a full bytecode dispatch system. -// Currently this is verbose and object oriented, but future revisions -// (once we have interesting benchmarks) will likely simplify and inline -// a lot of the checks to make things faster. Consider this to be as -// experimental an implementation as the entire rest of the project :) - -#include "iree/hal/interpreter/bytecode_dispatch.h" - -#include <algorithm> - -#include "absl/base/attributes.h" -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/heap_buffer.h" -#include "iree/hal/interpreter/bytecode_dispatch_conversion.h" -#include "iree/hal/interpreter/bytecode_dispatch_util.h" -#include "iree/hal/interpreter/bytecode_kernels.h" -#include "iree/rt/function.h" -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" -#include "iree/vm/bytecode_module.h" -#include "iree/vm/bytecode_reader.h" -#include "iree/vm/bytecode_tables_interpreter.h" -#include "iree/vm/bytecode_util.h" -#include "iree/vm/opcode_info.h" - -namespace iree { -namespace hal { - -namespace { - -using ::iree::rt::Stack; -using ::iree::rt::StackFrame; -using ::iree::vm::BytecodeReader; - -} // namespace - -Status Dispatch(hal::Allocator* allocator, - kernels::RuntimeState* kernel_runtime_state, Stack* stack, - StackFrame* entry_stack_frame, - absl::Span<BufferView> entry_results) { - // Dispatch table mapping 1:1 with bytecode ops. - // Each entry is a label within this function that can be used for computed - // goto. You can find more information on computed goto here: - // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables - // - // Note that we ensure the table is 256 elements long exactly to make sure - // that unused opcodes are handled gracefully. - static const void* kDispatchTable[256] = { -#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name, -#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled, - IREE_INTERPRETER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED) -#undef DECLARE_DISPATCH -#undef DECLARE_DISPATCH_RESERVED - }; - - // Primary dispatch state. This is our 'native stack frame' and really just - // enough to make dereferencing common addresses (like the current offset) - // faster. You can think of this like CPU state (like PC). - // - // We hope that LLVM decides to keep these in registers (as they are touched - // for every instruction executed). The stack_frame will change as we call - // into different functions. - BytecodeReader reader(stack); - RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame)); - -#define DISPATCH_NEXT() \ - { \ - uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \ - DVLOG(1) \ - << "Interpreter dispatching op code: " \ - << GetOpcodeInfo(vm::interpreter_opcode_table(), opcode).mnemonic; \ - goto* kDispatchTable[opcode]; \ - } - -#define DISPATCH_CORE_OPCODE(opcode, body) \ - _dispatch_##opcode : {body} DISPATCH_NEXT() -#if defined(IREE_SUPPORT_F32) || defined(IREE_SUPPORT_F64) -#define DISPATCH_FLOAT_OPCODE(opcode, body) \ - _dispatch_##opcode : {body} DISPATCH_NEXT() -#else -#define DISPATCH_FLOAT_OPCODE(...) -#endif // IREE_SUPPORT_F32 || IREE_SUPPORT_F64 - - DISPATCH_NEXT(); - - DISPATCH_CORE_OPCODE(kConstant, { - ASSIGN_OR_RETURN(auto value, reader.ReadConstant()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - *dst_local = std::move(value); - }); - - DISPATCH_CORE_OPCODE(kCall, { - auto* old_stack_frame = stack->current_frame(); - ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction()); - // TODO(benvanik): rework register storage interface. - ASSIGN_OR_RETURN( - const auto* function_def, - static_cast<const vm::BytecodeModule*>(target_function.module()) - ->GetFunctionDef(target_function.linkage(), - target_function.ordinal())); - ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function)); - new_stack_frame->mutable_registers()->buffer_views.resize( - function_def->bytecode()->local_count()); - RETURN_IF_ERROR( - reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame)); - DVLOG(1) << "Call; stack now: " << stack->DebugString(); - }); - - DISPATCH_CORE_OPCODE(kCallImport, { - return UnimplementedErrorBuilder(IREE_LOC) - << "Non-module imports not supported"; - }); - - DISPATCH_CORE_OPCODE(kCallIndirect, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect"; - }); - - DISPATCH_CORE_OPCODE(kReturn, { - auto* old_stack_frame = stack->current_frame(); - auto* new_stack_frame = stack->caller_frame(); - if (old_stack_frame == entry_stack_frame) { - // Returning from entry function. Marshal results from the return stmt. - ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount()); - for (int i = 0; i < src_count; ++i) { - ASSIGN_OR_RETURN( - auto* src_local, - reader.ReadLocal(old_stack_frame->mutable_registers())); - entry_results[i] = std::move(*src_local); - } - DVLOG(1) << "Returning to entry"; - return OkStatus(); - } else if (!new_stack_frame) { - return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow"; - } - RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame, - new_stack_frame)); - RETURN_IF_ERROR(stack->PopFrame()); - DVLOG(1) << "Return; stack now: " << stack->DebugString(); - }); - - DISPATCH_CORE_OPCODE(kBranch, { - ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset()); - RETURN_IF_ERROR(reader.CopySlots()); - RETURN_IF_ERROR(reader.BranchToOffset(offset)); - }); - - DISPATCH_CORE_OPCODE(kCondBranch, { - // Evaluate condition first so we can do the copies as we read them for - // which side of the branch we take. - ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); - bool cond_value = BufferViewIsTrue(*cond_local); - ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset()); - if (cond_value) { - RETURN_IF_ERROR(reader.CopySlots()); - RETURN_IF_ERROR(reader.BranchToOffset(true_offset)); - } else { - ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount()); - RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count)); - ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset()); - RETURN_IF_ERROR(reader.CopySlots()); - RETURN_IF_ERROR(reader.BranchToOffset(false_offset)); - } - }); - - DISPATCH_CORE_OPCODE(kCmpI, { - ASSIGN_OR_RETURN(uint8_t predicate, reader.ReadUint8_t()); - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - - switch (static_cast<CmpIPredicate>(predicate)) { - case CmpIPredicate::kEq: - RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareEQ>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kNe: - RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareNE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kSlt: - RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLT>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kSle: - RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareLE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kSgt: - RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGT>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kSge: - RETURN_IF_ERROR(ApplyComparisonOpIS<kernels::CompareGE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kUlt: - RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLT>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kUle: - RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareLE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kUgt: - RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGT>( - lhs_local, rhs_local, dst_local)); - break; - case CmpIPredicate::kUge: - RETURN_IF_ERROR(ApplyComparisonOpIU<kernels::CompareGE>( - lhs_local, rhs_local, dst_local)); - break; - } - }); - - DISPATCH_FLOAT_OPCODE(kCmpF, { - ASSIGN_OR_RETURN(uint8_t p, reader.ReadUint8_t()); - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - - auto predicate = static_cast<CmpFPredicate>(p); - switch (predicate) { - case CmpFPredicate::kOeq: - RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareEQ>( - lhs_local, rhs_local, dst_local)); - break; - case CmpFPredicate::kUne: - RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareNE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpFPredicate::kOlt: - RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLT>( - lhs_local, rhs_local, dst_local)); - break; - case CmpFPredicate::kOle: - RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareLE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpFPredicate::kOgt: - RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGT>( - lhs_local, rhs_local, dst_local)); - break; - case CmpFPredicate::kOge: - RETURN_IF_ERROR(ApplyComparisonOpF<kernels::CompareGE>( - lhs_local, rhs_local, dst_local)); - break; - case CmpFPredicate::kFalse: - case CmpFPredicate::kOne: - case CmpFPredicate::kOrd: - case CmpFPredicate::kUeq: - case CmpFPredicate::kUgt: - case CmpFPredicate::kUge: - case CmpFPredicate::kUlt: - case CmpFPredicate::kUle: - case CmpFPredicate::kUno: - case CmpFPredicate::kTrue: - // TODO(b/132183250) support these if we ever need them. - return UnimplementedErrorBuilder(IREE_LOC) - << "Unsupported comparison predicate value " - << static_cast<int>(p) << " (" - << vm::PredicateToString(predicate) << ")"; - } - }); - - DISPATCH_CORE_OPCODE(kAllocStatic, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static"; - }); - - DISPATCH_CORE_OPCODE(kAllocStack, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack"; - }); - - DISPATCH_CORE_OPCODE(kAllocStackInit, { - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented alloc_stack_init"; - }); - - DISPATCH_CORE_OPCODE(kAllocHeap, { - ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto type, reader.ReadType()); - size_t element_size = type.element_size(); - - // TODO(benvanik): more efficient reading and storage. - size_t element_count = 0; - ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count)); - size_t allocation_size = element_size * element_count; - - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - dst_local->element_size = element_size; - dst_local->shape = shape; - - // TODO(benvanik): properly allocate with attributes from op. - CHECK_EQ(heap_type, 0); - ASSIGN_OR_RETURN( - dst_local->buffer, - allocator->Allocate(MemoryType::kHostLocal | MemoryType::kDeviceVisible, - BufferUsage::kAll, allocation_size)); - }); - - DISPATCH_CORE_OPCODE(kDiscard, { - // NOTE: if we were an encoder we would actually discard the buffer. - ASSIGN_OR_RETURN(auto* local, reader.ReadLocal()); - *local = {}; - }); - - DISPATCH_CORE_OPCODE(kRank, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - int32_t rank = src_local->shape.size(); - RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &rank, sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kDim, { - ASSIGN_OR_RETURN(int32_t axis, reader.ReadUint8_t()); - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(int32_t dim, src_local->shape.ResolveAxis(axis)); - RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &dim, sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kShape, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR(dst_local->buffer->WriteData( - 0, src_local->shape.subspan().data(), - src_local->shape.subspan().size() * sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kLength, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - int32_t length = src_local->shape.element_count(); - RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kDynamicSlice, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths)); - }); - - DISPATCH_CORE_OPCODE(kStaticSlice, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto indices, reader.ReadIndexList()); - ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(*dst_local, src_local->Slice(indices, lengths)); - }); - - DISPATCH_CORE_OPCODE(kDynamicCopy, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto src_indices, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_indices, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>()); - RETURN_IF_ERROR( - ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths)); - }); - - DISPATCH_CORE_OPCODE(kStaticCopy, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto src_indices, reader.ReadIndexList()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_indices, reader.ReadIndexList()); - ASSIGN_OR_RETURN(auto lengths, reader.ReadIndexList()); - RETURN_IF_ERROR( - ApplyCopy(src_local, src_indices, dst_local, dst_indices, lengths)); - }); - - DISPATCH_CORE_OPCODE(kClone, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - dst_local->element_size = src_local->element_size; - dst_local->shape = src_local->shape; - dst_local->buffer = HeapBuffer::Allocate(src_local->buffer->usage(), - src_local->buffer->byte_length()); - RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get())); - }); - - DISPATCH_CORE_OPCODE(kSplit, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented split"; - }); - - DISPATCH_CORE_OPCODE(kAssign, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - *dst_local = *src_local; - }); - - DISPATCH_CORE_OPCODE(kCondAssign, { - ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local; - }); - - DISPATCH_CORE_OPCODE(kReshape, { - // TODO(benvanik): more logic required if strides differ. - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - Shape new_shape = Shape{shape_data}; - if (src_local->shape.element_count() != new_shape.element_count()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "New element count " << new_shape.element_count() - << " != source element count " << src_local->shape.element_count(); - } - dst_local->shape = new_shape; - dst_local->buffer = add_ref(src_local->buffer); - dst_local->element_size = src_local->element_size; - }); - - DISPATCH_CORE_OPCODE(kSelect, { - ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto cond_buffer, cond_local->buffer->MapMemory<uint8_t>( - MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto lhs_buffer, lhs_local->buffer->MapMemory<uint8_t>( - MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto rhs_buffer, rhs_local->buffer->MapMemory<uint8_t>( - MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>( - MemoryAccess::kDiscardWrite)); - if (cond_local->element_size != 1) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "Select cond must be i8"; - } else if (lhs_buffer.size() != rhs_buffer.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "LHS " << lhs_buffer.size() << "b != RHS " << rhs_buffer.size() - << "b; both arguments must match"; - } else if (lhs_buffer.size() != dst_buffer.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Dest " << dst_buffer.size() << "b != LHS/RHS " - << lhs_buffer.size() << "b; dest must match inputs"; - } - switch (lhs_local->element_size) { - case 1: - RETURN_IF_ERROR(kernels::Select::Execute<uint8_t>( - cond_buffer.contents(), lhs_buffer.contents(), - rhs_buffer.contents(), dst_buffer.mutable_contents())); - break; - case 2: - RETURN_IF_ERROR(kernels::Select::Execute<uint16_t>( - cond_buffer.contents(), - ReinterpretSpan<uint16_t>(lhs_buffer.contents()), - ReinterpretSpan<uint16_t>(rhs_buffer.contents()), - ReinterpretSpan<uint16_t>(dst_buffer.mutable_contents()))); - break; - case 4: - RETURN_IF_ERROR(kernels::Select::Execute<uint32_t>( - cond_buffer.contents(), - ReinterpretSpan<uint32_t>(lhs_buffer.contents()), - ReinterpretSpan<uint32_t>(rhs_buffer.contents()), - ReinterpretSpan<uint32_t>(dst_buffer.mutable_contents()))); - break; - case 8: - RETURN_IF_ERROR(kernels::Select::Execute<uint64_t>( - cond_buffer.contents(), - ReinterpretSpan<uint64_t>(lhs_buffer.contents()), - ReinterpretSpan<uint64_t>(rhs_buffer.contents()), - ReinterpretSpan<uint64_t>(dst_buffer.mutable_contents()))); - break; - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } - }); - - DISPATCH_CORE_OPCODE(kTranspose, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Transpose>( - src_local, dst_local, src_local->shape, - absl::MakeConstSpan(perm_data))); - }); - - DISPATCH_CORE_OPCODE(kReverse, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto perm_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR( - ApplyUnaryOpIU<kernels::Reverse>(src_local, dst_local, src_local->shape, - absl::MakeConstSpan(perm_data))); - }); - - DISPATCH_CORE_OPCODE(kPad, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* padding_value, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto edge_padding_low, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto edge_padding_high, - reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto interior_padding, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - - RETURN_IF_ERROR(ApplyBinaryOpIU<kernels::Pad>( - src_local, padding_value, dst_local, src_local->shape, dst_local->shape, - absl::MakeConstSpan(edge_padding_low), - absl::MakeConstSpan(edge_padding_high), - absl::MakeConstSpan(interior_padding))); - }); - - DISPATCH_CORE_OPCODE(kBroadcast, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - dst_local->shape = Shape{shape_data}; - RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Broadcast>(src_local, dst_local)); - }); - - DISPATCH_CORE_OPCODE(kTile, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - dst_local->shape = Shape{shape_data}; - RETURN_IF_ERROR(ApplyUnaryOpIU<kernels::Tile>( - src_local, dst_local, src_local->shape, dst_local->shape)); - }); - - DISPATCH_CORE_OPCODE(kNot, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpIU<kernels::Not>(&reader)); - }); - DISPATCH_CORE_OPCODE(kAnd, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::And>(&reader)); - }); - DISPATCH_CORE_OPCODE(kOr, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Or>(&reader)); - }); - DISPATCH_CORE_OPCODE(kXor, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Xor>(&reader)); - }); - DISPATCH_CORE_OPCODE(kShiftLeft, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::ShiftLeft>(&reader)); - }); - DISPATCH_CORE_OPCODE(kShiftRightLogical, { - RETURN_IF_ERROR( - DispatchElementwiseBinaryOpIU<kernels::ShiftRight>(&reader)); - }); - DISPATCH_CORE_OPCODE(kShiftRightArithmetic, { - RETURN_IF_ERROR( - DispatchElementwiseBinaryOpIS<kernels::ShiftRight>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kAddI, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Add>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kAddF, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Add>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kSubI, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Sub>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kSubF, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Sub>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kAbsI, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpIS<kernels::Abs>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kAbsF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Abs>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kMulI, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Mul>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kMulF, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Mul>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kDivIS, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Div>(&reader)); - }); - DISPATCH_CORE_OPCODE(kDivIU, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Div>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kDivF, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Div>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kMulAddI, { - RETURN_IF_ERROR(DispatchElementwiseTernaryOpIU<kernels::MulAdd>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kMulAddF, { - RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::MulAdd>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kExpF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Exp>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kLogF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Log>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kRsqrtF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Rsqrt>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kCosF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Cos>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kSinF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Sin>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kTanhF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Tanh>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kAtan2F, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Atan2>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kMinIS, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Min>(&reader)); - }); - DISPATCH_CORE_OPCODE(kMinIU, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Min>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kMinF, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Min>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kMaxIS, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIS<kernels::Max>(&reader)); - }); - DISPATCH_CORE_OPCODE(kMaxIU, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpIU<kernels::Max>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kMaxF, { - RETURN_IF_ERROR(DispatchElementwiseBinaryOpF<kernels::Max>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kClampIS, { - RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader)); - }); - DISPATCH_CORE_OPCODE(kClampIU, { - RETURN_IF_ERROR(DispatchElementwiseTernaryOpIS<kernels::Clamp>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kClampF, { - RETURN_IF_ERROR(DispatchElementwiseTernaryOpF<kernels::Clamp>(&reader)); - }); - - DISPATCH_FLOAT_OPCODE(kFloorF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Floor>(&reader)); - }); - DISPATCH_FLOAT_OPCODE(kCeilF, { - RETURN_IF_ERROR(DispatchElementwiseUnaryOpF<kernels::Ceil>(&reader)); - }); - - DISPATCH_CORE_OPCODE(kConvertSS, { - ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR( - ApplyConvertSS::Apply(src_type, src_local, dst_type, dst_local)); - }); - DISPATCH_CORE_OPCODE(kConvertUU, { - ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR( - ApplyConvertUU::Apply(src_type, src_local, dst_type, dst_local)); - }); - DISPATCH_CORE_OPCODE(kConvertSU, { - ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR( - ApplyConvertSU::Apply(src_type, src_local, dst_type, dst_local)); - }); - DISPATCH_CORE_OPCODE(kConvertUS, { - ASSIGN_OR_RETURN(auto src_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_type, reader.ReadType()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR( - ApplyConvertUS::Apply(src_type, src_local, dst_type, dst_local)); - }); - - DISPATCH_CORE_OPCODE(kMatMulI, { - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - // TODO(benvanik): add fused matmul-with-bias op in MLIR and lower to this. - BufferView* bias_local = nullptr; - ASSIGN_OR_RETURN(auto* multiplier_mantissa_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* multiplier_exponent_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR(ValidateMatMulOpI(lhs_local, rhs_local, bias_local, - multiplier_mantissa_local, - multiplier_exponent_local, dst_local)); - auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get(); - // TODO(benvanik): define as a matrix of supported types to enable 8*8=16, - // accumulator options, and other precision modes. - switch (lhs_local->element_size) { - case 1: - RETURN_IF_ERROR(ApplyMatMulOpI<int8_t>( - mat_mul_state, lhs_local, rhs_local, bias_local, - multiplier_mantissa_local, multiplier_exponent_local, dst_local)); - break; - case 2: - RETURN_IF_ERROR(ApplyMatMulOpI<int16_t>( - mat_mul_state, lhs_local, rhs_local, bias_local, - multiplier_mantissa_local, multiplier_exponent_local, dst_local)); - break; - case 4: - RETURN_IF_ERROR(ApplyMatMulOpI<int32_t>( - mat_mul_state, lhs_local, rhs_local, bias_local, - multiplier_mantissa_local, multiplier_exponent_local, dst_local)); - break; - case 8: - RETURN_IF_ERROR(ApplyMatMulOpI<int64_t>( - mat_mul_state, lhs_local, rhs_local, bias_local, - multiplier_mantissa_local, multiplier_exponent_local, dst_local)); - break; - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } - }); - - DISPATCH_FLOAT_OPCODE(kMatMulF, { - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - BufferView* bias_local = nullptr; - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR( - ValidateMatMulOpF(lhs_local, rhs_local, bias_local, dst_local)); - auto* mat_mul_state = kernel_runtime_state->mat_mul_state.get(); - switch (lhs_local->element_size) { - case 4: - RETURN_IF_ERROR(ApplyMatMulOpF<float>( - mat_mul_state, lhs_local, rhs_local, bias_local, dst_local)); - break; - case 8: - RETURN_IF_ERROR(ApplyMatMulOpF<double>( - mat_mul_state, lhs_local, rhs_local, bias_local, dst_local)); - break; - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } - }); - - DISPATCH_CORE_OPCODE(kReduceSumI, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(scotttodd): validate - RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceSum>( - src_local, init_local, dst_local, dimension, src_local->shape, - dst_local->shape)); - }); - - DISPATCH_FLOAT_OPCODE(kReduceSumF, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(scotttodd): validate - RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceSum>( - src_local, init_local, dst_local, dimension, src_local->shape, - dst_local->shape)); - }); - - DISPATCH_CORE_OPCODE(kReduceMinI, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(scotttodd): validate - RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMin>( - src_local, init_local, dst_local, dimension, src_local->shape, - dst_local->shape)); - }); - - DISPATCH_FLOAT_OPCODE(kReduceMinF, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(scotttodd): validate - RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMin>( - src_local, init_local, dst_local, dimension, src_local->shape, - dst_local->shape)); - }); - - DISPATCH_CORE_OPCODE(kReduceMaxI, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(scotttodd): validate - RETURN_IF_ERROR(ApplyBinaryOpIS<kernels::ReduceMax>( - src_local, init_local, dst_local, dimension, src_local->shape, - dst_local->shape)); - }); - - DISPATCH_FLOAT_OPCODE(kReduceMaxF, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* init_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dimension, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(scotttodd): validate - RETURN_IF_ERROR(ApplyBinaryOpF<kernels::ReduceMax>( - src_local, init_local, dst_local, dimension, src_local->shape, - dst_local->shape)); - }); - - DISPATCH_CORE_OPCODE(kTrace, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace"; - }); - - DISPATCH_CORE_OPCODE(kBreak, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break"; - }); - - DISPATCH_CORE_OPCODE(kCondBreak, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break"; - }); - -_dispatch_unhandled: - // TODO(benvanik): better tracing. - return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode"; -} // NOLINT(readability/fn_size) - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_dispatch.h b/iree/hal/interpreter/bytecode_dispatch.h deleted file mode 100644 index edd92d2..0000000 --- a/iree/hal/interpreter/bytecode_dispatch.h +++ /dev/null
@@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_ - -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/interpreter/bytecode_kernels.h" -#include "iree/rt/stack.h" -#include "iree/rt/stack_frame.h" - -namespace iree { -namespace hal { - -Status Dispatch(hal::Allocator* allocator, - kernels::RuntimeState* kernel_runtime_state, rt::Stack* stack, - rt::StackFrame* entry_stack_frame, - absl::Span<BufferView> entry_results); - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_H_
diff --git a/iree/hal/interpreter/bytecode_dispatch_conversion.h b/iree/hal/interpreter/bytecode_dispatch_conversion.h deleted file mode 100644 index 8973250..0000000 --- a/iree/hal/interpreter/bytecode_dispatch_conversion.h +++ /dev/null
@@ -1,395 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Conversion helper tables. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_ - -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/interpreter/bytecode_dispatch_util.h" -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" -#include "iree/vm/type.h" - -namespace iree { -namespace hal { - -template <typename KERNEL, bool src_signed, bool dst_signed, typename... ARGS> -struct ApplyConversionOp { - static Status Apply(const vm::Type& src_type, BufferView* src_local, - const vm::Type& dst_type, BufferView* dst_local, - ARGS... args) { - // Validate ranges so that we cannot go out of bounds on thunk table. - int src_type_index = src_type.type_index(); - int dst_type_index = dst_type.type_index(); - if (src_type_index < 0 || src_type_index >= kBuiltinTypeCount) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Conversion from invalid source builtin type " - << src_type_index; - } else if (dst_type_index < 0 || dst_type_index >= kBuiltinTypeCount) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Conversion to invalid dest builtin type " << dst_type_index; - } - - // All possible combinations of conversions. - using KernelFn = Status (*)(BufferView * src_local, BufferView * dst_local, - ARGS... args); - KernelFn fn = nullptr; - if (src_signed && dst_signed) { - // Signed -> signed. - static const KernelFn - kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { - // src_type = kI8: - /* kI8 */ Thunk<int8_t, int8_t>::Apply, - /* kI16 */ Thunk<int8_t, int16_t>::Apply, - /* kI32 */ Thunk<int8_t, int32_t>::Apply, - /* kI64 */ Thunk<int8_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<int8_t, float>::Apply, - /* kF64 */ Thunk<int8_t, double>::Apply, - - // src_type = kI16: - /* kI8 */ Thunk<int16_t, int8_t>::Apply, - /* kI16 */ Thunk<int16_t, int16_t>::Apply, - /* kI32 */ Thunk<int16_t, int32_t>::Apply, - /* kI64 */ Thunk<int16_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<int16_t, float>::Apply, - /* kF64 */ Thunk<int16_t, double>::Apply, - - // src_type = kI32: - /* kI8 */ Thunk<int32_t, int8_t>::Apply, - /* kI16 */ Thunk<int32_t, int16_t>::Apply, - /* kI32 */ Thunk<int32_t, int32_t>::Apply, - /* kI64 */ Thunk<int32_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<int32_t, float>::Apply, - /* kF64 */ Thunk<int32_t, double>::Apply, - - // src_type = kI64: - /* kI8 */ Thunk<int64_t, int8_t>::Apply, - /* kI16 */ Thunk<int64_t, int16_t>::Apply, - /* kI32 */ Thunk<int64_t, int32_t>::Apply, - /* kI64 */ Thunk<int64_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<int64_t, float>::Apply, - /* kF64 */ Thunk<int64_t, double>::Apply, - - // src_type = kF16: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ Thunk<uint16_t, uint16_t>::Apply, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF32: - /* kI8 */ Thunk<float, int8_t>::Apply, - /* kI16 */ Thunk<float, int16_t>::Apply, - /* kI32 */ Thunk<float, int32_t>::Apply, - /* kI64 */ Thunk<float, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<float, float>::Apply, - /* kF64 */ Thunk<float, double>::Apply, - - // src_type = kF64: - /* kI8 */ Thunk<double, int8_t>::Apply, - /* kI16 */ Thunk<double, int16_t>::Apply, - /* kI32 */ Thunk<double, int32_t>::Apply, - /* kI64 */ Thunk<double, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<double, float>::Apply, - /* kF64 */ Thunk<double, double>::Apply, - }; - fn = - kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; - } else if (src_signed && !dst_signed) { - // Signed -> unsigned. - static const KernelFn - kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { - // src_type = kI8: - /* kI8 */ Thunk<int8_t, uint8_t>::Apply, - /* kI16 */ Thunk<int8_t, uint16_t>::Apply, - /* kI32 */ Thunk<int8_t, uint32_t>::Apply, - /* kI64 */ Thunk<int8_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kI16: - /* kI8 */ Thunk<int16_t, uint8_t>::Apply, - /* kI16 */ Thunk<int16_t, uint16_t>::Apply, - /* kI32 */ Thunk<int16_t, uint32_t>::Apply, - /* kI64 */ Thunk<int16_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kI32: - /* kI8 */ Thunk<int32_t, uint8_t>::Apply, - /* kI16 */ Thunk<int32_t, uint16_t>::Apply, - /* kI32 */ Thunk<int32_t, uint32_t>::Apply, - /* kI64 */ Thunk<int32_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kI64: - /* kI8 */ Thunk<int64_t, uint8_t>::Apply, - /* kI16 */ Thunk<int64_t, uint16_t>::Apply, - /* kI32 */ Thunk<int64_t, uint32_t>::Apply, - /* kI64 */ Thunk<int64_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF16: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF32: - /* kI8 */ Thunk<float, uint8_t>::Apply, - /* kI16 */ Thunk<float, uint16_t>::Apply, - /* kI32 */ Thunk<float, uint32_t>::Apply, - /* kI64 */ Thunk<float, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF64: - /* kI8 */ Thunk<double, uint8_t>::Apply, - /* kI16 */ Thunk<double, uint16_t>::Apply, - /* kI32 */ Thunk<double, uint32_t>::Apply, - /* kI64 */ Thunk<double, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - }; - fn = - kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; - } else if (!src_signed && dst_signed) { - // Unsigned -> signed. - static const KernelFn - kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { - // src_type = kI8: - /* kI8 */ Thunk<uint8_t, int8_t>::Apply, - /* kI16 */ Thunk<uint8_t, int16_t>::Apply, - /* kI32 */ Thunk<uint8_t, int32_t>::Apply, - /* kI64 */ Thunk<uint8_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<uint8_t, float>::Apply, - /* kF64 */ Thunk<uint8_t, double>::Apply, - - // src_type = kI16: - /* kI8 */ Thunk<uint16_t, int8_t>::Apply, - /* kI16 */ Thunk<uint16_t, int16_t>::Apply, - /* kI32 */ Thunk<uint16_t, int32_t>::Apply, - /* kI64 */ Thunk<uint16_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<uint16_t, float>::Apply, - /* kF64 */ Thunk<uint16_t, double>::Apply, - - // src_type = kI32: - /* kI8 */ Thunk<uint32_t, int8_t>::Apply, - /* kI16 */ Thunk<uint32_t, int16_t>::Apply, - /* kI32 */ Thunk<uint32_t, int32_t>::Apply, - /* kI64 */ Thunk<uint32_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<uint32_t, float>::Apply, - /* kF64 */ Thunk<uint32_t, double>::Apply, - - // src_type = kI64: - /* kI8 */ Thunk<uint64_t, int8_t>::Apply, - /* kI16 */ Thunk<uint64_t, int16_t>::Apply, - /* kI32 */ Thunk<uint64_t, int32_t>::Apply, - /* kI64 */ Thunk<uint64_t, int64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ Thunk<uint64_t, float>::Apply, - /* kF64 */ Thunk<uint64_t, double>::Apply, - - // src_type = kF16: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF32: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF64: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - }; - fn = - kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; - } else if (!src_signed && !dst_signed) { - // Unsigned -> unsigned. - static const KernelFn - kConversionTable[kBuiltinTypeCount * kBuiltinTypeCount] = { - // src_type = kI8: - /* kI8 */ Thunk<uint8_t, uint8_t>::Apply, - /* kI16 */ Thunk<uint8_t, uint16_t>::Apply, - /* kI32 */ Thunk<uint8_t, uint32_t>::Apply, - /* kI64 */ Thunk<uint8_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kI16: - /* kI8 */ Thunk<uint16_t, uint8_t>::Apply, - /* kI16 */ Thunk<uint16_t, uint16_t>::Apply, - /* kI32 */ Thunk<uint16_t, uint32_t>::Apply, - /* kI64 */ Thunk<uint16_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kI32: - /* kI8 */ Thunk<uint32_t, uint8_t>::Apply, - /* kI16 */ Thunk<uint32_t, uint16_t>::Apply, - /* kI32 */ Thunk<uint32_t, uint32_t>::Apply, - /* kI64 */ Thunk<uint32_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kI64: - /* kI8 */ Thunk<uint64_t, uint8_t>::Apply, - /* kI16 */ Thunk<uint64_t, uint16_t>::Apply, - /* kI32 */ Thunk<uint64_t, uint32_t>::Apply, - /* kI64 */ Thunk<uint64_t, uint64_t>::Apply, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF16: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF32: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - - // src_type = kF64: - /* kI8 */ nullptr, - /* kI16 */ nullptr, - /* kI32 */ nullptr, - /* kI64 */ nullptr, - /* kF16 */ nullptr, - /* kF32 */ nullptr, - /* kF64 */ nullptr, - }; - fn = - kConversionTable[src_type_index * kBuiltinTypeCount + dst_type_index]; - } - if (!fn) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported conversion from " << src_type_index << " to " - << dst_type_index; - } - return fn(src_local, dst_local, args...); - } - - template <typename SRC, typename DST> - struct Thunk { - static Status Apply(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - ASSIGN_OR_RETURN(auto src_buffer, - src_local->buffer->MapMemory<SRC>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<DST>( - MemoryAccess::kDiscardWrite)); - return KERNEL::Execute(src_buffer.contents(), - dst_buffer.mutable_contents(), args...); - } - }; - -// Disable F32/F64 conversions if they are not supported. -#if !defined(IREE_SUPPORT_F32) - template <typename DST> - struct Thunk<float, DST> { - static Status Apply(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported"; - } - }; - template <typename SRC> - struct Thunk<SRC, float> { - static Status Apply(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - return UnimplementedErrorBuilder(IREE_LOC) << "F32 not supported"; - } - }; -#endif // !IREE_SUPPORT_F32 -#if !defined(IREE_SUPPORT_F64) - template <typename DST> - struct Thunk<double, DST> { - static Status Apply(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported"; - } - }; - template <typename SRC> - struct Thunk<SRC, double> { - static Status Apply(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - return UnimplementedErrorBuilder(IREE_LOC) << "F64 not supported"; - } - }; -#endif // !IREE_SUPPORT_F64 -}; - -using ApplyConvertSS = ApplyConversionOp<kernels::Convert, /*src_signed=*/true, - /*dst_signed=*/true>; -using ApplyConvertUU = ApplyConversionOp<kernels::Convert, /*src_signed=*/false, - /*dst_signed=*/false>; -using ApplyConvertSU = ApplyConversionOp<kernels::Convert, /*src_signed=*/true, - /*dst_signed=*/false>; -using ApplyConvertUS = ApplyConversionOp<kernels::Convert, /*src_signed=*/false, - /*dst_signed=*/true>; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_CONVERSION_H_
diff --git a/iree/hal/interpreter/bytecode_dispatch_util.cc b/iree/hal/interpreter/bytecode_dispatch_util.cc deleted file mode 100644 index 40937e1..0000000 --- a/iree/hal/interpreter/bytecode_dispatch_util.cc +++ /dev/null
@@ -1,107 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/bytecode_dispatch_util.h" - -namespace iree { -namespace hal { - -bool BufferViewIsTrue(const BufferView& buffer_view) { - if (buffer_view.element_size == 0 || !buffer_view.buffer || - buffer_view.byte_length() == 0) { - return false; - } - // TODO(benvanik): map more efficiently (based on element size?). - auto mapping = - buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead); - if (!mapping.ok()) { - return false; - } - for (uint8_t value : mapping.ValueOrDie().contents()) { - if (value) return true; - } - return false; -} - -Status ValidateElementwiseUnaryOp(BufferView* src_local, - BufferView* dst_local) { - // TODO(benvanik): validate shapes. - return OkStatus(); -} - -Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local) { - // TODO(benvanik): validate shapes. - return OkStatus(); -} - -Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local, - BufferView* c_local, - BufferView* dst_local) { - // TODO(benvanik): validate shapes. - return OkStatus(); -} - -Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local, - BufferView* bias_local, - BufferView* multiplier_mantissa_local, - BufferView* multiplier_exponent_local, - BufferView* dst_local) { - // TODO(benvanik): validate shapes. - return OkStatus(); -} - -Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local, - BufferView* bias_local, BufferView* dst_local) { - // TODO(benvanik): validate shapes. - return OkStatus(); -} - -Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices, - BufferView* dst_local, absl::Span<const int32_t> dst_indices, - absl::Span<const int32_t> lengths) { - ASSIGN_OR_RETURN(auto src_buffer, - src_local->buffer->MapMemory<uint8_t>(MemoryAccess::kRead)); - // TODO(benvanik): discard if overwriting the entire buffer. - ASSIGN_OR_RETURN(auto dst_buffer, - dst_local->buffer->MapMemory<uint8_t>(MemoryAccess::kWrite)); - switch (src_local->element_size) { - case 1: - return kernels::Copy::Execute<1>(src_buffer.contents(), src_local->shape, - src_indices, - dst_buffer.mutable_contents(), - dst_local->shape, dst_indices, lengths); - case 2: - return kernels::Copy::Execute<2>(src_buffer.contents(), src_local->shape, - src_indices, - dst_buffer.mutable_contents(), - dst_local->shape, dst_indices, lengths); - case 4: - return kernels::Copy::Execute<4>(src_buffer.contents(), src_local->shape, - src_indices, - dst_buffer.mutable_contents(), - dst_local->shape, dst_indices, lengths); - case 8: - return kernels::Copy::Execute<8>(src_buffer.contents(), src_local->shape, - src_indices, - dst_buffer.mutable_contents(), - dst_local->shape, dst_indices, lengths); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << src_local->element_size; - } -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_dispatch_util.h b/iree/hal/interpreter/bytecode_dispatch_util.h deleted file mode 100644 index cb8d7d9..0000000 --- a/iree/hal/interpreter/bytecode_dispatch_util.h +++ /dev/null
@@ -1,513 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities used by the bytecode_dispatch routines to aid in working with the -// bytecode stream and kernel dispatch. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_ - -#include "absl/base/attributes.h" -#include "absl/container/inlined_vector.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/heap_buffer.h" -#include "iree/hal/interpreter/bytecode_kernels.h" -#include "iree/rt/function.h" -#include "iree/rt/stack.h" -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" -#include "iree/vm/bytecode_reader.h" -#include "iree/vm/type.h" - -// TODO(benvanik): move to dedicated config file/build flags. -#define IREE_SUPPORT_F32 1 -#define IREE_SUPPORT_F64 1 - -namespace iree { -namespace hal { - -// Returns true if the contents of the BufferView are bitwise non-zero. -// Returns false if there is no buffer, the buffer is empty, or the contents are -// bitwise zero. -bool BufferViewIsTrue(const BufferView& buffer_view); - -Status ValidateElementwiseUnaryOp(BufferView* src_local, BufferView* dst_local); -Status ValidateElementwiseBinaryOp(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local); -Status ValidateElementwiseTernaryOp(BufferView* a_local, BufferView* b_local, - BufferView* c_local, BufferView* dst_local); -Status ValidateMatMulOpI(BufferView* lhs_local, BufferView* rhs_local, - BufferView* bias_local, - BufferView* multiplier_mantissa_local, - BufferView* multiplier_exponent_local, - BufferView* dst_local); -Status ValidateMatMulOpF(BufferView* lhs_local, BufferView* rhs_local, - BufferView* bias_local, BufferView* dst_local); - -template <typename KERNEL, typename T, typename... ARGS> -Status ApplyUnaryOp(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - // TODO(benvanik): avoid mapping by changing buffer type? - ASSIGN_OR_RETURN(auto src_buffer, - src_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( - MemoryAccess::kDiscardWrite)); - return KERNEL::Execute(src_buffer.contents(), dst_buffer.mutable_contents(), - args...); -} - -template <typename KERNEL, typename T, typename... ARGS> -Status ApplyBinaryOp(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local, ARGS... args) { - ASSIGN_OR_RETURN(auto lhs_buffer, - lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto rhs_buffer, - rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( - MemoryAccess::kDiscardWrite)); - return KERNEL::Execute(lhs_buffer.contents(), rhs_buffer.contents(), - dst_buffer.mutable_contents(), args...); -} - -template <typename KERNEL, typename T, typename... ARGS> -Status ApplyTernaryOp(BufferView* a_local, BufferView* b_local, - BufferView* c_local, BufferView* dst_local, - ARGS... args) { - ASSIGN_OR_RETURN(auto a_buffer, - a_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto b_buffer, - b_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto c_buffer, - c_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( - MemoryAccess::kDiscardWrite)); - return KERNEL::Execute(a_buffer.contents(), b_buffer.contents(), - c_buffer.contents(), dst_buffer.mutable_contents(), - args...); -} - -template <typename KERNEL, typename T> -Status ApplyComparisonOp(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local) { - ASSIGN_OR_RETURN(auto lhs_buffer, - lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto rhs_buffer, - rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<uint8_t>( - MemoryAccess::kDiscardWrite)); - return KERNEL::Execute(lhs_buffer.contents(), rhs_buffer.contents(), - dst_buffer.mutable_contents()); -} - -template <typename KERNEL, typename... ARGS> -Status ApplyUnaryOpIS(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - switch (src_local->element_size) { - case 1: - return ApplyUnaryOp<KERNEL, int8_t>(src_local, dst_local, args...); - case 2: - return ApplyUnaryOp<KERNEL, int16_t>(src_local, dst_local, args...); - case 4: - return ApplyUnaryOp<KERNEL, int32_t>(src_local, dst_local, args...); - case 8: - return ApplyUnaryOp<KERNEL, int64_t>(src_local, dst_local, args...); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << src_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyUnaryOpIU(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - switch (src_local->element_size) { - case 1: - return ApplyUnaryOp<KERNEL, uint8_t>(src_local, dst_local, args...); - case 2: - return ApplyUnaryOp<KERNEL, uint16_t>(src_local, dst_local, args...); - case 4: - return ApplyUnaryOp<KERNEL, uint32_t>(src_local, dst_local, args...); - case 8: - return ApplyUnaryOp<KERNEL, uint64_t>(src_local, dst_local, args...); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << src_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyUnaryOpF(BufferView* src_local, BufferView* dst_local, - ARGS... args) { - switch (src_local->element_size) { -#if defined(IREE_SUPPORT_F32) - case 4: - return ApplyUnaryOp<KERNEL, float>(src_local, dst_local, args...); -#endif // IREE_SUPPORT_F32 -#if defined(IREE_SUPPORT_F64) - case 8: - return ApplyUnaryOp<KERNEL, double>(src_local, dst_local, args...); -#endif // IREE_SUPPORT_F64 - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << src_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyBinaryOpIS(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local, ARGS... args) { - switch (lhs_local->element_size) { - case 1: - return ApplyBinaryOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local, - args...); - case 2: - return ApplyBinaryOp<KERNEL, int16_t>(lhs_local, rhs_local, dst_local, - args...); - case 4: - return ApplyBinaryOp<KERNEL, int32_t>(lhs_local, rhs_local, dst_local, - args...); - case 8: - return ApplyBinaryOp<KERNEL, int64_t>(lhs_local, rhs_local, dst_local, - args...); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyBinaryOpIU(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local, ARGS... args) { - switch (lhs_local->element_size) { - case 1: - return ApplyBinaryOp<KERNEL, uint8_t>(lhs_local, rhs_local, dst_local, - args...); - case 2: - return ApplyBinaryOp<KERNEL, uint16_t>(lhs_local, rhs_local, dst_local, - args...); - case 4: - return ApplyBinaryOp<KERNEL, uint32_t>(lhs_local, rhs_local, dst_local, - args...); - case 8: - return ApplyBinaryOp<KERNEL, uint64_t>(lhs_local, rhs_local, dst_local, - args...); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyBinaryOpF(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local, ARGS... args) { - switch (lhs_local->element_size) { -#if defined(IREE_SUPPORT_F32) - case 4: - return ApplyBinaryOp<KERNEL, float>(lhs_local, rhs_local, dst_local, - args...); -#endif // IREE_SUPPORT_F32 -#if defined(IREE_SUPPORT_F64) - case 8: - return ApplyBinaryOp<KERNEL, double>(lhs_local, rhs_local, dst_local, - args...); -#endif // IREE_SUPPORT_F64 - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyTernaryOpIS(BufferView* a_local, BufferView* b_local, - BufferView* c_local, BufferView* dst_local, - ARGS... args) { - switch (a_local->element_size) { - case 1: - return ApplyTernaryOp<KERNEL, int8_t>(a_local, b_local, c_local, - dst_local, args...); - case 2: - return ApplyTernaryOp<KERNEL, int16_t>(a_local, b_local, c_local, - dst_local, args...); - case 4: - return ApplyTernaryOp<KERNEL, int32_t>(a_local, b_local, c_local, - dst_local, args...); - case 8: - return ApplyTernaryOp<KERNEL, int64_t>(a_local, b_local, c_local, - dst_local, args...); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << a_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyTernaryOpIU(BufferView* a_local, BufferView* b_local, - BufferView* c_local, BufferView* dst_local, - ARGS... args) { - switch (a_local->element_size) { - case 1: - return ApplyTernaryOp<KERNEL, uint8_t>(a_local, b_local, c_local, - dst_local, args...); - case 2: - return ApplyTernaryOp<KERNEL, uint16_t>(a_local, b_local, c_local, - dst_local, args...); - case 4: - return ApplyTernaryOp<KERNEL, uint32_t>(a_local, b_local, c_local, - dst_local, args...); - case 8: - return ApplyTernaryOp<KERNEL, uint64_t>(a_local, b_local, c_local, - dst_local, args...); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << a_local->element_size; - } -} - -template <typename KERNEL, typename... ARGS> -Status ApplyTernaryOpF(BufferView* a_local, BufferView* b_local, - BufferView* c_local, BufferView* dst_local, - ARGS... args) { - switch (a_local->element_size) { -#if defined(IREE_SUPPORT_F32) - case 4: - return ApplyTernaryOp<KERNEL, float>(a_local, b_local, c_local, dst_local, - args...); -#endif // IREE_SUPPORT_F32 -#if defined(IREE_SUPPORT_F64) - case 8: - return ApplyTernaryOp<KERNEL, double>(a_local, b_local, c_local, - dst_local, args...); -#endif // IREE_SUPPORT_F64 - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << a_local->element_size; - } -} - -template <typename KERNEL> -Status ApplyComparisonOpIS(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local) { - switch (lhs_local->element_size) { - case 1: - return ApplyComparisonOp<KERNEL, int8_t>(lhs_local, rhs_local, dst_local); - case 2: - return ApplyComparisonOp<KERNEL, int16_t>(lhs_local, rhs_local, - dst_local); - case 4: - return ApplyComparisonOp<KERNEL, int32_t>(lhs_local, rhs_local, - dst_local); - case 8: - return ApplyComparisonOp<KERNEL, int64_t>(lhs_local, rhs_local, - dst_local); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } -} - -template <typename KERNEL> -Status ApplyComparisonOpIU(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local) { - switch (lhs_local->element_size) { - case 1: - return ApplyComparisonOp<KERNEL, uint8_t>(lhs_local, rhs_local, - dst_local); - case 2: - return ApplyComparisonOp<KERNEL, uint16_t>(lhs_local, rhs_local, - dst_local); - case 4: - return ApplyComparisonOp<KERNEL, uint32_t>(lhs_local, rhs_local, - dst_local); - case 8: - return ApplyComparisonOp<KERNEL, uint64_t>(lhs_local, rhs_local, - dst_local); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } -} - -template <typename KERNEL> -Status ApplyComparisonOpF(BufferView* lhs_local, BufferView* rhs_local, - BufferView* dst_local) { - switch (lhs_local->element_size) { - case 4: - return ApplyComparisonOp<KERNEL, float>(lhs_local, rhs_local, dst_local); - case 8: - return ApplyComparisonOp<KERNEL, double>(lhs_local, rhs_local, dst_local); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented element size: " << lhs_local->element_size; - } -} - -template <typename T, typename ACC = int32_t> -Status ApplyMatMulOpI(kernels::MatMul::RuntimeState* runtime_state, - BufferView* lhs_local, BufferView* rhs_local, - BufferView* bias_local, - BufferView* multiplier_mantissa_local, - BufferView* multiplier_exponent_local, - BufferView* dst_local) { - kernels::MatMul::Buffers<T, ACC> buffers; - ASSIGN_OR_RETURN(auto lhs_buffer, - lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - buffers.lhs_buffer = lhs_buffer.contents(); - buffers.lhs_shape = lhs_local->shape; - ASSIGN_OR_RETURN(auto rhs_buffer, - rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - buffers.rhs_buffer = rhs_buffer.contents(); - buffers.rhs_shape = rhs_local->shape; - MappedMemory<ACC> bias_buffer; - if (bias_local && bias_local->buffer && !bias_local->shape.empty()) { - if (bias_local->element_size != sizeof(ACC)) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Only " << sizeof(ACC) << "b biases are supported right now"; - } - ASSIGN_OR_RETURN(bias_buffer, - bias_local->buffer->MapMemory<ACC>(MemoryAccess::kRead)); - buffers.bias_buffer = bias_buffer.contents(); - } - ASSIGN_OR_RETURN( - auto multiplier_mantissa_buffer, - multiplier_mantissa_local->buffer->MapMemory<ACC>(MemoryAccess::kRead)); - buffers.multiplier_mantissa_buffer = multiplier_mantissa_buffer.contents(); - ASSIGN_OR_RETURN(auto multiplier_exponent_buffer, - multiplier_exponent_local->buffer->MapMemory<int32_t>( - MemoryAccess::kRead)); - buffers.multiplier_exponent_buffer = multiplier_exponent_buffer.contents(); - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( - MemoryAccess::kDiscardWrite)); - buffers.dst_buffer = dst_buffer.mutable_contents(); - buffers.dst_shape = dst_local->shape; - return kernels::MatMul::Execute(runtime_state, buffers); -} - -template <typename T> -Status ApplyMatMulOpF(kernels::MatMul::RuntimeState* runtime_state, - BufferView* lhs_local, BufferView* rhs_local, - BufferView* bias_local, BufferView* dst_local) { - kernels::MatMul::Buffers<T, T> buffers; - ASSIGN_OR_RETURN(auto lhs_buffer, - lhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - buffers.lhs_buffer = lhs_buffer.contents(); - buffers.lhs_shape = lhs_local->shape; - ASSIGN_OR_RETURN(auto rhs_buffer, - rhs_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - buffers.rhs_buffer = rhs_buffer.contents(); - buffers.rhs_shape = rhs_local->shape; - MappedMemory<T> bias_buffer; - if (bias_local && bias_local->buffer && !bias_local->shape.empty()) { - ASSIGN_OR_RETURN(bias_buffer, - bias_local->buffer->MapMemory<T>(MemoryAccess::kRead)); - buffers.bias_buffer = bias_buffer.contents(); - } - ASSIGN_OR_RETURN(auto dst_buffer, dst_local->buffer->MapMemory<T>( - MemoryAccess::kDiscardWrite)); - buffers.dst_buffer = dst_buffer.mutable_contents(); - buffers.dst_shape = dst_local->shape; - return kernels::MatMul::Execute(runtime_state, buffers); -} - -template <typename KERNEL> -Status DispatchElementwiseUnaryOpIS(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local)); - return ApplyUnaryOpIS<KERNEL>(src_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseUnaryOpIU(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local)); - return ApplyUnaryOpIU<KERNEL>(src_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseUnaryOpF(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* src_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR(ValidateElementwiseUnaryOp(src_local, dst_local)); - return ApplyUnaryOpF<KERNEL>(src_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseBinaryOpIS(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local)); - return ApplyBinaryOpIS<KERNEL>(lhs_local, rhs_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseBinaryOpIU(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local)); - return ApplyBinaryOpIU<KERNEL>(lhs_local, rhs_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseBinaryOpF(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* lhs_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR(ValidateElementwiseBinaryOp(lhs_local, rhs_local, dst_local)); - return ApplyBinaryOpF<KERNEL>(lhs_local, rhs_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseTernaryOpIS(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR( - ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local)); - return ApplyTernaryOpIS<KERNEL>(a_local, b_local, c_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseTernaryOpIU(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR( - ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local)); - return ApplyTernaryOpIU<KERNEL>(a_local, b_local, c_local, dst_local); -} - -template <typename KERNEL> -Status DispatchElementwiseTernaryOpF(vm::BytecodeReader* reader) { - ASSIGN_OR_RETURN(auto* a_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* b_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* c_local, reader->ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader->ReadLocal()); - RETURN_IF_ERROR( - ValidateElementwiseTernaryOp(a_local, b_local, c_local, dst_local)); - return ApplyTernaryOpF<KERNEL>(a_local, b_local, c_local, dst_local); -} - -Status ApplyCopy(BufferView* src_local, absl::Span<const int32_t> src_indices, - BufferView* dst_local, absl::Span<const int32_t> dst_indices, - absl::Span<const int32_t> lengths); - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_DISPATCH_UTIL_H_
diff --git a/iree/hal/interpreter/bytecode_executable.cc b/iree/hal/interpreter/bytecode_executable.cc deleted file mode 100644 index c528263..0000000 --- a/iree/hal/interpreter/bytecode_executable.cc +++ /dev/null
@@ -1,64 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/bytecode_executable.h" - -#include <iostream> - -#include "iree/hal/interpreter/interpreter_module.h" -#include "iree/rt/policy.h" - -namespace iree { -namespace hal { - -// static -StatusOr<ref_ptr<BytecodeExecutable>> BytecodeExecutable::Load( - ref_ptr<rt::Instance> instance, hal::Allocator* allocator, - ExecutableSpec spec, bool allow_aliasing_data) { - // Allocate the executable now. - // We do this here so that if we need to clone the data we are passing that - // to the VM loader instead of the data we may not have access to later. - auto executable = make_ref<BytecodeExecutable>(std::move(instance), allocator, - spec, allow_aliasing_data); - - // Create the executable module. - auto module_def = - ::flatbuffers::GetRoot<ModuleDef>(executable->executable_data().data()); - ASSIGN_OR_RETURN(auto module, - InterpreterModule::FromDef(allocator, *module_def)); - executable->module_ = add_ref(module); - RETURN_IF_ERROR(executable->context()->RegisterModule(std::move(module))); - - return executable; -} - -BytecodeExecutable::BytecodeExecutable(ref_ptr<rt::Instance> instance, - hal::Allocator* allocator, - ExecutableSpec spec, - bool allow_aliasing_data) - : spec_(spec), - context_( - make_ref<rt::Context>(std::move(instance), make_ref<rt::Policy>())) { - if (!allow_aliasing_data) { - // Clone data. - cloned_executable_data_ = {spec.executable_data.begin(), - spec.executable_data.end()}; - spec_.executable_data = absl::MakeConstSpan(cloned_executable_data_); - } -} - -BytecodeExecutable::~BytecodeExecutable() = default; - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/bytecode_executable.h b/iree/hal/interpreter/bytecode_executable.h deleted file mode 100644 index 6cd47de..0000000 --- a/iree/hal/interpreter/bytecode_executable.h +++ /dev/null
@@ -1,68 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_ - -#include <vector> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_spec.h" -#include "iree/rt/context.h" -#include "iree/rt/instance.h" -#include "iree/rt/module.h" - -namespace iree { -namespace hal { - -class BytecodeExecutable final : public Executable { - public: - static StatusOr<ref_ptr<BytecodeExecutable>> Load( - ref_ptr<rt::Instance> instance, hal::Allocator* allocator, - ExecutableSpec spec, bool allow_aliasing_data); - - BytecodeExecutable(ref_ptr<rt::Instance> instance, hal::Allocator* allocator, - ExecutableSpec spec, bool allow_aliasing_data); - ~BytecodeExecutable() override; - - bool supports_debugging() const override { return false; } - - // Reference to the bytecode blob contents. - absl::Span<const uint8_t> executable_data() const { - return spec_.executable_data; - } - - // VM context with the executable registered. - const ref_ptr<rt::Context>& context() const { return context_; } - - // VM module representing the executable. - // Note that there may be more than one module in the Context and only this - // module can be used to lookup executable exports. - const ref_ptr<rt::Module>& module() const { return module_; } - - private: - ExecutableSpec spec_; - std::vector<uint8_t> cloned_executable_data_; - - ref_ptr<rt::Context> context_; - ref_ptr<rt::Module> module_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_EXECUTABLE_H_
diff --git a/iree/hal/interpreter/bytecode_kernels.h b/iree/hal/interpreter/bytecode_kernels.h deleted file mode 100644 index 8485b98..0000000 --- a/iree/hal/interpreter/bytecode_kernels.h +++ /dev/null
@@ -1,371 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Defines kernel functions and provides their implementation via one (or more) -// included files. -// -// Kernels should do the simplest possible operation. Buffer validation is -// handled by the dispatch logic and need not be checked. Kernels may optionally -// accept arguments beyond just the buffers, depending on the required state -// and attributes. -// -// Kernels may optionally have runtime state. This is state that is allocated -// once for the entire Runtime (and stored on RuntimeState) and shared across -// all fibers. This enables kernels that may require thread pools or device -// handles to be shared while kernels that require transient storage to be safe -// to use from multiple fibers concurrently. -// -// All kernels are templated to enable specialization of particular types or -// type combinations. By default the bytecode_kernels_generic.h will provide C++ -// semantics as reference and platform-specific versions can be implemented -// as needed. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_ - -#include <cstdint> - -#include "absl/types/span.h" -#include "iree/base/shape.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { -namespace kernels { - -struct CompareEQ { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer); -}; -struct CompareNE { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer); -}; -struct CompareLT { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer); -}; -struct CompareLE { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer); -}; -struct CompareGT { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer); -}; -struct CompareGE { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer); -}; - -struct Copy { - template <int element_size> - static Status Execute(absl::Span<const uint8_t> src_buffer, - const Shape& src_shape, - absl::Span<const int32_t> src_indices, - absl::Span<uint8_t> dst_buffer, const Shape& dst_shape, - absl::Span<const int32_t> dst_indices, - absl::Span<const int32_t> lengths); -}; - -struct Select { - template <typename T> - static Status Execute(absl::Span<const uint8_t> cond_buffer, - absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Transpose { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - absl::Span<const int32_t> perm); -}; - -struct Pad { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<const T> padding_value, - absl::Span<T> dst_buffer, const Shape& src_shape, - const Shape& dst_shape, - absl::Span<const int32_t> edge_padding_low, - absl::Span<const int32_t> edge_padding_high, - absl::Span<const int32_t> interior_padding); -}; - -struct Reverse { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - absl::Span<const int32_t> dimensions); -}; - -struct Broadcast { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Tile { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - const Shape& dst_shape); -}; - -struct Not { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct And { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Or { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Xor { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct ShiftLeft { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct ShiftRight { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Add { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Sub { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Abs { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Mul { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Div { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -// a + (b * c) -struct MulAdd { - template <typename T> - static Status Execute(absl::Span<const T> a_buffer, - absl::Span<const T> b_buffer, - absl::Span<const T> c_buffer, absl::Span<T> dst_buffer); -}; - -struct Exp { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Log { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Rsqrt { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Cos { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Sin { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Tanh { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Atan2 { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Min { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Max { - template <typename T> - static Status Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer); -}; - -struct Clamp { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<const T> min_buffer, - absl::Span<const T> max_buffer, - absl::Span<T> dst_buffer); -}; - -struct Floor { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Ceil { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer); -}; - -struct Convert { - template <typename SRC, typename DST> - static Status Execute(absl::Span<const SRC> src_buffer, - absl::Span<DST> dst_buffer); -}; - -struct MatMul { - struct RuntimeState; - - static std::unique_ptr<RuntimeState> CreateRuntimeState(); - - template <typename T, typename ACC> - struct Buffers { - Shape lhs_shape; - absl::Span<const T> lhs_buffer; - Shape rhs_shape; - absl::Span<const T> rhs_buffer; - Shape dst_shape; - absl::Span<T> dst_buffer; - - // Optional bias buffer. - absl::Span<const ACC> bias_buffer; - - // Fixed-point multiplier mantissa/exponent. May be a single value (for - // uniform quantization) or one element per row of the destination matrix - // for per-channel. - absl::Span<const ACC> multiplier_mantissa_buffer; - absl::Span<const int32_t> multiplier_exponent_buffer; - }; - - template <typename T, typename ACC> - static Status Execute(RuntimeState* runtime_state, - const Buffers<T, ACC>& buffers); -}; - -struct RuntimeState { - std::unique_ptr<MatMul::RuntimeState> mat_mul_state = - MatMul::CreateRuntimeState(); -}; - -struct ReduceSum { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, - absl::Span<T> dst_buffer, int32_t dimension, - const Shape& src_shape, const Shape& dst_shape); -}; - -struct ReduceMin { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, - absl::Span<T> dst_buffer, int32_t dimension, - const Shape& src_shape, const Shape& dst_shape); -}; - -struct ReduceMax { - template <typename T> - static Status Execute(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, - absl::Span<T> dst_buffer, int32_t dimension, - const Shape& src_shape, const Shape& dst_shape); -}; - -} // namespace kernels -} // namespace hal -} // namespace iree - -#include "iree/hal/interpreter/bytecode_kernels_generic.h" // IWYU pragma: export -#include "iree/hal/interpreter/bytecode_kernels_ruy.h" // IWYU pragma: export - -#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_H_
diff --git a/iree/hal/interpreter/bytecode_kernels_generic.h b/iree/hal/interpreter/bytecode_kernels_generic.h deleted file mode 100644 index ddf69c7..0000000 --- a/iree/hal/interpreter/bytecode_kernels_generic.h +++ /dev/null
@@ -1,708 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_ - -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { -namespace kernels { - -template <typename T> -Status CompareEQ::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] == rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status CompareNE::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] != rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status CompareLT::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] < rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status CompareLE::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] <= rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status CompareGT::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] > rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status CompareGE::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<uint8_t> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] >= rhs_buffer[i]; - } - return OkStatus(); -} - -namespace impl { -inline absl::InlinedVector<size_t, 6> ComputeCopyStrides(const Shape& shape, - size_t element_size) { - absl::InlinedVector<size_t, 6> strides(shape.size()); - strides.back() = element_size; - for (int i = shape.size() - 2; i >= 0; --i) { - strides[i] = strides[i + 1] * shape[i + 1]; - } - return strides; -} - -inline void CopyRegion(absl::Span<const uint8_t> src_buffer, - absl::Span<const size_t> src_strides, - absl::Span<const int32_t> src_indices, - absl::Span<uint8_t> dst_buffer, - absl::Span<const size_t> dst_strides, - absl::Span<const int32_t> dst_indices, - absl::Span<const int32_t> lengths) { - if (lengths.size() > 1) { - for (int i = 0; i < lengths[0]; ++i) { - size_t src_offset = src_strides[0] * (src_indices[0] + i); - size_t dst_offset = dst_strides[0] * (dst_indices[0] + i); - CopyRegion(src_buffer.subspan(src_offset), src_strides.subspan(1), - src_indices.subspan(1), dst_buffer.subspan(dst_offset), - dst_strides.subspan(1), dst_indices.subspan(1), - lengths.subspan(1)); - } - } else { - DCHECK_EQ(dst_strides.size(), 1); - DCHECK_EQ(src_strides.size(), 1); - DCHECK_EQ(src_indices.size(), 1); - DCHECK_EQ(dst_indices.size(), 1); - DCHECK_EQ(lengths.size(), 1); - auto src_offset = src_indices[0] * src_strides[0]; - auto dst_offset = dst_indices[0] * dst_strides[0]; - auto length = dst_strides[0] * lengths[0]; - std::memcpy(dst_buffer.data() + dst_offset, src_buffer.data() + src_offset, - length); - } -} -} // namespace impl - -// TODO(benvanik): replace with a real implementation once copy is defined. -// TODO(gcmn): More consistent/principled handling for scalars. -template <int element_size> -Status Copy::Execute(absl::Span<const uint8_t> src_buffer, - const Shape& src_shape, - absl::Span<const int32_t> src_indices, - absl::Span<uint8_t> dst_buffer, const Shape& dst_shape, - absl::Span<const int32_t> dst_indices, - absl::Span<const int32_t> lengths) { - DCHECK_EQ(src_indices.size(), lengths.size()); - DCHECK_EQ(dst_indices.size(), lengths.size()); - DCHECK_EQ(src_shape.size(), lengths.size()); - DCHECK_EQ(dst_shape.size(), lengths.size()); - if (lengths.empty()) { - std::memcpy(dst_buffer.data(), src_buffer.data(), element_size); - return OkStatus(); - } - - // TODO(gcmn) Maybe we can fast-path earlier if we detect contiguous memory - // across multiple rows. - auto src_strides = impl::ComputeCopyStrides(src_shape, element_size); - auto dst_strides = impl::ComputeCopyStrides(dst_shape, element_size); - DCHECK_EQ(src_strides.size(), lengths.size()); - DCHECK_EQ(dst_strides.size(), lengths.size()); - impl::CopyRegion(src_buffer, src_strides, src_indices, dst_buffer, - dst_strides, dst_indices, lengths); - return OkStatus(); -} - -template <typename T> -Status Select::Execute(absl::Span<const uint8_t> cond_buffer, - absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = cond_buffer[i] ? lhs_buffer[i] : rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Transpose::Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - absl::Span<const int32_t> perm) { - // This implementation is .... not fast. - int rank = src_shape.size(); - absl::InlinedVector<int, 8> src_strides(rank); - absl::InlinedVector<int, 8> dst_strides(rank); - size_t src_stride = 1; - size_t dst_stride = 1; - for (int dim_i = rank - 1; dim_i >= 0; --dim_i) { - src_strides[dim_i] = src_stride; - dst_strides[dim_i] = dst_stride; - src_stride *= src_shape[dim_i]; - dst_stride *= src_shape[perm[dim_i]]; - } - for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) { - size_t src_i = 0; - size_t t = dst_i; - for (int dim_i = 0; dim_i < rank; ++dim_i) { - size_t ratio = t / dst_strides[dim_i]; - t -= ratio * dst_strides[dim_i]; - src_i += ratio * src_strides[perm[dim_i]]; - } - dst_buffer[dst_i] = src_buffer[src_i]; - } - return OkStatus(); -} - -namespace impl { -inline void IncrementShapeIndex(absl::Span<int32_t> indices, - const Shape& shape) { - for (int i = indices.size() - 1; i >= 0; --i) { - if (++indices[i] < shape[i]) return; - indices[i] = 0; - } -} - -inline bool IsPadding(absl::Span<const int32_t> indices, const Shape& shape, - absl::Span<const int32_t> edge_padding_low, - absl::Span<const int32_t> edge_padding_high, - absl::Span<const int32_t> interior_padding) { - for (int i = 0; i < indices.size(); ++i) { - auto index = indices[i]; - if (index < edge_padding_low[i] || - index >= shape[i] - edge_padding_high[i] || - (index - edge_padding_low[i]) % (interior_padding[i] + 1) != 0) { - return true; - } - } - - return false; -} -} // namespace impl - -template <typename T> -Status Pad::Execute(absl::Span<const T> src_buffer, - absl::Span<const T> padding_value_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - const Shape& dst_shape, - absl::Span<const int32_t> edge_padding_low, - absl::Span<const int32_t> edge_padding_high, - absl::Span<const int32_t> interior_padding) { - // This implementation is not at all fast, as it iterates every index in the - // destination buffer individually. Potential improvements: - // 1. Fill the dst buffer with padded value initially. Only need to iterate - // through source buffer and can exit early. - // 2. Use striding to advance through larger swaths of the buffer with a - // memcpy from src and filling (or skipping) padded incides. Especially - // useful when e.g. entire rows are padded. - - // TODO(b/140836672) support negative padding - - if (padding_value_buffer.size() != 1) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Padding value buffer is larger than one element."; - } - auto padding_value = padding_value_buffer.front(); - - absl::InlinedVector<int, 8> dst_indices(src_shape.size(), 0); - - const T* src_ptr = src_buffer.begin(); - T* dst_ptr = dst_buffer.begin(); - while (dst_ptr != dst_buffer.end()) { - if (impl::IsPadding(dst_indices, dst_shape, edge_padding_low, - edge_padding_high, interior_padding)) { - *dst_ptr++ = padding_value; - } else { - DCHECK(src_ptr != src_buffer.end()); - *dst_ptr++ = *src_ptr++; - } - impl::IncrementShapeIndex(absl::MakeSpan(dst_indices), dst_shape); - } - - return OkStatus(); -} - -template <typename T> -Status Reverse::Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - absl::Span<const int32_t> dimensions) { - // This implementation is not fast either - int rank = src_shape.size(); - absl::InlinedVector<int, 8> strides(rank); - size_t stride = 1; - for (int dim_i = rank - 1; dim_i >= 0; --dim_i) { - strides[dim_i] = stride; - stride *= src_shape[dim_i]; - } - absl::flat_hash_set<int32_t> dims_set(dimensions.begin(), dimensions.end()); - for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) { - size_t src_i = 0; - size_t t = dst_i; - for (int dim_i = 0; dim_i < rank; ++dim_i) { - size_t ratio = t / strides[dim_i]; - t -= ratio * strides[dim_i]; - bool do_reverse = dims_set.contains(dim_i); - src_i += (do_reverse ? (src_shape[dim_i] - 1 - ratio) : ratio) * - strides[dim_i]; - } - dst_buffer[dst_i] = src_buffer[src_i]; - } - return OkStatus(); -} - -template <typename T> -Status Broadcast::Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = src_buffer[0]; - } - return OkStatus(); -} - -template <typename T> -Status Tile::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer, - const Shape& src_shape, const Shape& dst_shape) { - // This implementation is .... not fast. - int rank = dst_shape.size(); - absl::InlinedVector<int, 8> src_strides(rank); - absl::InlinedVector<int, 8> dst_strides(rank); - size_t src_stride = 1; - size_t dst_stride = 1; - for (int dim_i = rank - 1; dim_i >= 0; --dim_i) { - src_strides[dim_i] = src_stride; - dst_strides[dim_i] = dst_stride; - src_stride *= src_shape[dim_i]; - dst_stride *= dst_shape[dim_i]; - } - for (size_t dst_i = 0; dst_i < dst_buffer.size(); ++dst_i) { - size_t src_i = 0; - size_t t = dst_i; - for (int dim_i = 0; dim_i < rank; ++dim_i) { - src_i += t / dst_strides[dim_i] % src_shape[dim_i] * src_strides[dim_i]; - t %= dst_strides[dim_i]; - } - dst_buffer[dst_i] = src_buffer[src_i]; - } - return OkStatus(); -} - -template <typename T> -Status Not::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = ~src_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status And::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] & rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Or::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] | rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Xor::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] ^ rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status ShiftLeft::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] << rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status ShiftRight::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] >> rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Add::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] + rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Sub::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] - rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Abs::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::abs(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Mul::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] * rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status Div::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = lhs_buffer[i] / rhs_buffer[i]; - } - return OkStatus(); -} - -template <typename T> -Status MulAdd::Execute(absl::Span<const T> a_buffer, - absl::Span<const T> b_buffer, - absl::Span<const T> c_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = a_buffer[i] + (b_buffer[i] * c_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Exp::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::exp(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Rsqrt::Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = 1.0 / std::sqrt(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Log::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::log(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Cos::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::cos(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Sin::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::sin(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Tanh::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::tanh(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Atan2::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::atan2(lhs_buffer[i], rhs_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Min::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::min(lhs_buffer[i], rhs_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Max::Execute(absl::Span<const T> lhs_buffer, - absl::Span<const T> rhs_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::max(lhs_buffer[i], rhs_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Clamp::Execute(absl::Span<const T> src_buffer, - absl::Span<const T> min_buffer, - absl::Span<const T> max_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - T src = src_buffer[i]; - T min = min_buffer[i]; - T max = max_buffer[i]; - dst_buffer[i] = src <= min ? min : src >= max ? max : src; - } - return OkStatus(); -} - -template <typename T> -Status Floor::Execute(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::floor(src_buffer[i]); - } - return OkStatus(); -} - -template <typename T> -Status Ceil::Execute(absl::Span<const T> src_buffer, absl::Span<T> dst_buffer) { - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = std::ceil(src_buffer[i]); - } - return OkStatus(); -} - -template <typename SRC, typename DST> -Status Convert::Execute(absl::Span<const SRC> src_buffer, - absl::Span<DST> dst_buffer) { - DCHECK_EQ(src_buffer.size(), dst_buffer.size()); - for (size_t i = 0; i < dst_buffer.size(); ++i) { - dst_buffer[i] = static_cast<DST>(src_buffer[i]); - } - return OkStatus(); -} - -namespace impl { - -struct SumKernel { - template <typename T> - inline void operator()(T* value0, const T value1) { - *value0 += value1; - } -}; - -struct MinKernel { - template <typename T> - inline void operator()(T* value0, const T value1) { - *value0 = std::min(*value0, value1); - } -}; - -struct MaxKernel { - template <typename T> - inline void operator()(T* value0, const T value1) { - *value0 = std::max(*value0, value1); - } -}; - -template <typename T, typename KernelImpl> -inline void ReduceDimension(absl::Span<const T> src_buffer, - absl::Span<T> dst_buffer, const Shape& src_shape, - absl::Span<const int32_t> reduce_dims, - absl::Span<const int> dst_strides, int dim, - absl::Span<int> src_indices, size_t flat_src_i, - size_t src_stride) { - if (dim < 0) { - // Base case of the recursion - figure out which elements should be acted - // upon and apply the reduction kernel to them. - - // Derive destination indices from source indices. - // For example, - // reduce_dims: [1, 2] - // src_indices: [2, 1, 3, 0] - // ^ ^ - // | | - // |----- remove these dimensions - // dst_indices: [2, 0] - // - // TODO(scotttodd): Clean this up somehow, share across recursion levels? - size_t dst_size = src_shape.size() - reduce_dims.size(); - absl::InlinedVector<int, 8> dst_indices; - for (size_t i = 0; i < src_indices.size(); ++i) { - if (std::find(std::begin(reduce_dims), std::end(reduce_dims), i) == - std::end(reduce_dims)) { - dst_indices.push_back(src_indices[i]); - } - } - // Compute the flattened index into dst_buffer at [dst_indices]. - size_t dst_i = 0; - for (size_t i = 0; i < dst_indices.size(); ++i) { - dst_i += dst_indices[i] * dst_strides[dst_size - 1 - i]; - } - - // Flattened src and dst indices have been computed, invoke the kernel. - KernelImpl()(&dst_buffer[dst_i], src_buffer[flat_src_i]); - return; - } - - // Iterate through the current dimension in the source shape, recursing - // down one dimension at a time. - // - // This touches each element in the source buffer once, tracking complete - // dimensions within the shaped source buffer and using them to compute - // the corresponding indices (shaped and flattened) within the destination - // buffer. Each element in the destination buffer will be touched multiple - // times. - // - // Note that cache coherency isn't considered here, and some computations - // are redundant, so this could be optimized substantially. - for (size_t dim_i = 0; dim_i < src_shape[dim]; ++dim_i) { - src_indices[dim] = dim_i; - - // Recurse down to the next dimension (e.g. 2 -> 1 -> 0 -> base case) - // * Add the current stride to flat_src_i - // * Multiply src_stride by this dimension's shape - ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, - reduce_dims, dst_strides, dim - 1, - src_indices, flat_src_i + dim_i * src_stride, - src_stride * src_shape[dim]); - } -} - -template <typename T, typename KernelImpl> -Status GenericReduce(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, absl::Span<T> dst_buffer, - int32_t dimension, const Shape& src_shape, - const Shape& dst_shape) { - // Initialize using init_buffer, which is expected to be a scalar. - std::fill_n(dst_buffer.data(), dst_buffer.size(), init_buffer[0]); - - // Precompute destination strides. - int dst_rank = dst_shape.size(); - absl::InlinedVector<int, 8> dst_strides; - size_t dst_stride = 1; - for (int dim_i = dst_rank - 1; dim_i >= 0; --dim_i) { - dst_strides.push_back(dst_stride); - dst_stride *= dst_shape[dim_i]; - } - - // Call the helper (recursive) function, starting with: - // * source index [0, 0, ..., 0] - // * the innermost dimension (last in the shape) - // * flat_src_i of 0 (corresponds to [0, 0, ..., 0] above) - // * source stride 1 - absl::InlinedVector<int, 8> src_indices(src_shape.size(), 0); - ReduceDimension<T, KernelImpl>(src_buffer, dst_buffer, src_shape, {dimension}, - absl::MakeSpan(dst_strides), - src_shape.size() - 1, - absl::MakeSpan(src_indices), 0, 1); - - return OkStatus(); -} - -} // namespace impl - -template <typename T> -Status ReduceSum::Execute(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, - absl::Span<T> dst_buffer, int32_t dimension, - const Shape& src_shape, const Shape& dst_shape) { - return impl::GenericReduce<T, impl::SumKernel>( - src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); -} - -template <typename T> -Status ReduceMin::Execute(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, - absl::Span<T> dst_buffer, int32_t dimension, - const Shape& src_shape, const Shape& dst_shape) { - return impl::GenericReduce<T, impl::MinKernel>( - src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); -} - -template <typename T> -Status ReduceMax::Execute(absl::Span<const T> src_buffer, - absl::Span<const T> init_buffer, - absl::Span<T> dst_buffer, int32_t dimension, - const Shape& src_shape, const Shape& dst_shape) { - return impl::GenericReduce<T, impl::MaxKernel>( - src_buffer, init_buffer, dst_buffer, dimension, src_shape, dst_shape); -} - -} // namespace kernels -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_GENERIC_H_
diff --git a/iree/hal/interpreter/bytecode_kernels_ruy.h b/iree/hal/interpreter/bytecode_kernels_ruy.h deleted file mode 100644 index 60bcb44..0000000 --- a/iree/hal/interpreter/bytecode_kernels_ruy.h +++ /dev/null
@@ -1,81 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_ -#define IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_ - -#include "absl/base/thread_annotations.h" -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "tensorflow/lite/experimental/ruy/context.h" -#include "tensorflow/lite/experimental/ruy/ruy.h" - -namespace iree { -namespace hal { -namespace kernels { - -// TODO(benvanik): something more clever for making this shareable. -// Maybe a factory fn based on the impl selected? -struct MatMul::RuntimeState { - // TODO(benvanik): share the thread pool but keep context per-fiber? - ruy::Context context; -}; - -inline std::unique_ptr<MatMul::RuntimeState> MatMul::CreateRuntimeState() { - return absl::make_unique<RuntimeState>(); -} - -template <typename T, typename ACC> -Status MatMul::Execute(RuntimeState* runtime_state, - const Buffers<T, ACC>& buffers) { - ruy::Matrix<T> lhs_matrix; - ruy::MakeSimpleLayout(buffers.lhs_shape[0], buffers.lhs_shape[1], - ruy::Order::kRowMajor, &lhs_matrix.layout); - lhs_matrix.data.set(buffers.lhs_buffer.data()); - - ruy::Matrix<T> rhs_matrix; - ruy::MakeSimpleLayout(buffers.rhs_shape[0], buffers.rhs_shape[1], - ruy::Order::kRowMajor, &rhs_matrix.layout); - rhs_matrix.data.set(buffers.rhs_buffer.data()); - - ruy::Matrix<T> dst_matrix; - ruy::MakeSimpleLayout(buffers.dst_shape[0], buffers.dst_shape[1], - ruy::Order::kRowMajor, &dst_matrix.layout); - dst_matrix.data.set(buffers.dst_buffer.data()); - - ruy::BasicSpec<ACC, T> spec; - spec.bias = buffers.bias_buffer.data(); - - if (buffers.multiplier_mantissa_buffer.size() == 1) { - spec.multiplier_fixedpoint = buffers.multiplier_mantissa_buffer[0]; - spec.multiplier_exponent = buffers.multiplier_exponent_buffer[0]; - } else { - spec.multiplier_fixedpoint_perchannel = - buffers.multiplier_mantissa_buffer.data(); - spec.multiplier_exponent_perchannel = - buffers.multiplier_exponent_buffer.data(); - } - - ruy::Mul<ruy::kAllPaths>(lhs_matrix, rhs_matrix, spec, - &runtime_state->context, &dst_matrix); - - return OkStatus(); -} - -} // namespace kernels -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_BYTECODE_KERNELS_RUY_H_
diff --git a/iree/hal/interpreter/bytecode_kernels_test.cc b/iree/hal/interpreter/bytecode_kernels_test.cc deleted file mode 100644 index db9a63b..0000000 --- a/iree/hal/interpreter/bytecode_kernels_test.cc +++ /dev/null
@@ -1,407 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/bytecode_kernels.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/memory.h" -#include "iree/base/status_matchers.h" - -namespace iree { -namespace hal { -namespace kernels { - -namespace { - -constexpr float kEpsilon = 0.0001f; - -template <typename T> -std::vector<T> MakeIota(int size) { - std::vector<T> v(size); - std::iota(v.begin(), v.end(), static_cast<T>(1)); - return v; -} - -TEST(Copy, WholeBuffer) { - Shape src_shape = {2, 2}; - auto src_buffer = MakeIota<uint8_t>(4); - std::vector<int32_t> src_indices = {0, 0}; - Shape dst_shape = src_shape; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {0, 0}; - std::vector<int32_t> lengths = {2, 2}; - auto expected_dst = src_buffer; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, FirstRow) { - Shape src_shape = {3, 4}; - auto src_buffer = MakeIota<uint8_t>(12); - std::vector<int32_t> src_indices = {0, 0}; - Shape dst_shape = {1, 4}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {0, 0}; - std::vector<int32_t> lengths = {1, 4}; - std::vector<uint8_t> expected_dst = {1, 2, 3, 4}; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, RowPart) { - Shape src_shape = {3, 4}; - auto src_buffer = MakeIota<uint8_t>(12); - std::vector<int32_t> src_indices = {1, 1}; - Shape dst_shape = {1, 2}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {0, 0}; - std::vector<int32_t> lengths = {1, 2}; - std::vector<uint8_t> expected_dst = {6, 7}; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, MultiRow) { - Shape src_shape = {3, 4}; - auto src_buffer = MakeIota<uint8_t>(12); - std::vector<int32_t> src_indices = {1, 0}; - Shape dst_shape = {2, 4}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {0, 0}; - std::vector<int32_t> lengths = {2, 4}; - std::vector<uint8_t> expected_dst = {5, 6, 7, 8, 9, 10, 11, 12}; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, NonContiguous) { - Shape src_shape = {3, 4}; - auto src_buffer = MakeIota<uint8_t>(12); - std::vector<int32_t> src_indices = {1, 1}; - Shape dst_shape = {2, 2}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {0, 0}; - std::vector<int32_t> lengths = {2, 2}; - std::vector<uint8_t> expected_dst = {6, 7, 10, 11}; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, MultiByte) { - Shape src_shape = {3, 4}; - auto src_vals = MakeIota<int32_t>(12); - auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals)); - std::vector<int32_t> src_indices = {1, 1}; - Shape dst_shape = {2, 2}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count() * sizeof(int32_t)); - std::vector<int32_t> dst_indices = {0, 0}; - std::vector<int32_t> lengths = {2, 2}; - std::vector<int32_t> expected_dst = {6, 7, 10, 11}; - - EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - - absl::Span<int32_t> dst_buffer_int32_t = - ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer)); - - EXPECT_EQ(dst_buffer_int32_t, expected_dst); -} - -TEST(Copy, NotFullDst) { - Shape src_shape = {3, 4}; - auto src_buffer = MakeIota<uint8_t>(12); - std::vector<int32_t> src_indices = {0, 0}; - Shape dst_shape = {4, 3}; - std::vector<uint8_t> dst_buffer(12, 42); - std::vector<int32_t> dst_indices = {1, 1}; - std::vector<int32_t> lengths = {2, 2}; - // clang-format off - std::vector<uint8_t> expected_dst = {42, 42, 42, - 42, 1, 2, - 42, 5, 6, - 42, 42, 42}; - // clang-format on - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, HighRank) { - Shape src_shape = {3, 3, 3, 3}; - auto src_buffer = MakeIota<uint8_t>(81); - std::vector<int32_t> src_indices = {1, 1, 1, 1}; - Shape dst_shape = {2, 2, 2, 2}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {0, 0, 0, 0}; - std::vector<int32_t> lengths = {2, 2, 2, 2}; - std::vector<uint8_t> expected_dst = {41, 42, 44, 45, 50, 51, 53, 54, - 68, 69, 71, 72, 77, 78, 80, 81}; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, Scalar) { - Shape src_shape = {}; - std::vector<uint8_t> src_buffer = {42}; - std::vector<int32_t> src_indices = {}; - Shape dst_shape = {}; - std::vector<uint8_t> dst_buffer(dst_shape.element_count()); - std::vector<int32_t> dst_indices = {}; - std::vector<int32_t> lengths = {}; - std::vector<uint8_t> expected_dst = {42}; - - EXPECT_OK(Copy::Execute<1>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Copy, ScalarMultiByte) { - Shape src_shape = {}; - std::vector<int32_t> src_vals = {INT32_MAX}; - auto src_buffer = ReinterpretSpan<uint8_t>(absl::MakeSpan(src_vals)); - std::vector<int32_t> src_indices = {}; - Shape dst_shape = {}; - std::vector<uint8_t> dst_buffer(sizeof(int32_t)); - std::vector<int32_t> dst_indices = {}; - std::vector<int32_t> lengths = {}; - std::vector<int32_t> expected_dst = {INT32_MAX}; - - EXPECT_OK(Copy::Execute<4>(src_buffer, src_shape, src_indices, - absl::MakeSpan(dst_buffer), dst_shape, dst_indices, - lengths)); - - absl::Span<int32_t> dst_buffer_int32_t = - ReinterpretSpan<int32_t>(absl::MakeSpan(dst_buffer)); - - EXPECT_EQ(dst_buffer_int32_t, expected_dst); -} - -TEST(Pad, NoPadding) { - Shape src_shape = {2, 3}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {0, 0}; - std::vector<int32_t> edge_padding_high = {0, 0}; - std::vector<int32_t> interior_padding = {0, 0}; - Shape dst_shape = src_shape; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - auto expected_dst = src_buffer; - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Pad, LowHighPadding) { - Shape src_shape = {2, 3}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {0, 1}; - std::vector<int32_t> edge_padding_high = {1, 2}; - std::vector<int32_t> interior_padding = {0, 0}; - Shape dst_shape = {3, 6}; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - // clang-format off - std::vector<uint16_t> expected_dst = {0, 1, 2, 3, 0, 0, - 0, 4, 5, 6, 0, 0, - 0, 0, 0, 0, 0, 0}; - // clang-format on - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Pad, OnlyHighPadding) { - Shape src_shape = {2, 3}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {0, 0}; - std::vector<int32_t> edge_padding_high = {1, 3}; - std::vector<int32_t> interior_padding = {0, 0}; - Shape dst_shape = {3, 6}; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - // clang-format off - std::vector<uint16_t> expected_dst = {1, 2, 3, 0, 0, 0, - 4, 5, 6, 0, 0, 0, - 0, 0, 0, 0, 0, 0}; - // clang-format on - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Pad, OnlyLowPadding) { - Shape src_shape = {2, 3}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {1, 3}; - std::vector<int32_t> edge_padding_high = {0, 0}; - std::vector<int32_t> interior_padding = {0, 0}; - Shape dst_shape = {3, 6}; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - // clang-format off - std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, - 0, 0, 0, 1, 2, 3, - 0, 0, 0, 4, 5, 6}; - // clang-format on - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Pad, OnlyInteriorPadding) { - Shape src_shape = {2, 3}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {0, 0}; - std::vector<int32_t> edge_padding_high = {0, 0}; - std::vector<int32_t> interior_padding = {1, 1}; - Shape dst_shape = {3, 5}; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - // clang-format off - std::vector<uint16_t> expected_dst = {1, 0, 2, 0, 3, - 0, 0, 0, 0, 0, - 4, 0, 5, 0, 6}; - // clang-format on - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Pad, AllPaddingTypes) { - Shape src_shape = {2, 3}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {1, 1}; - std::vector<int32_t> edge_padding_high = {1, 2}; - std::vector<int32_t> interior_padding = {1, 1}; - Shape dst_shape = {5, 8}; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - // clang-format off - std::vector<uint16_t> expected_dst = {0, 0, 0, 0, 0, 0, 0, 0, - 0, 1, 0, 2, 0, 3, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 4, 0, 5, 0, 6, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0}; - // clang-format on - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(Pad, HighRank) { - Shape src_shape = {2, 2, 2, 2}; - auto src_buffer = MakeIota<uint16_t>(src_shape.element_count()); - std::vector<uint16_t> pad_value_buffer = {0}; - std::vector<int32_t> edge_padding_low = {1, 0, 0, 0}; - std::vector<int32_t> edge_padding_high = {0, 1, 0, 0}; - std::vector<int32_t> interior_padding = {0, 0, 1, 0}; - Shape dst_shape = {3, 3, 3, 2}; - std::vector<uint16_t> dst_buffer(dst_shape.element_count(), UINT16_MAX); - // clang-format off - std::vector<uint16_t> expected_dst = { 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, - - 1, 2, 0, 0, 3, 4, - 5, 6, 0, 0, 7, 8, - 0, 0, 0, 0, 0, 0, - - 9, 10, 0, 0, 11, 12, - 13, 14, 0, 0, 15, 16, - 0, 0, 0, 0, 0, 0}; - // clang-format on - - ASSERT_EQ(dst_buffer.size(), expected_dst.size()); - - EXPECT_OK(Pad::Execute<uint16_t>( - src_buffer, pad_value_buffer, absl::MakeSpan(dst_buffer), src_shape, - dst_shape, edge_padding_low, edge_padding_high, interior_padding)); - EXPECT_EQ(dst_buffer, expected_dst); -} - -TEST(ReduceSum, Scalar) { - Shape src_shape = {5}; - int32_t dimension = 0; - Shape dst_shape = {1}; - std::vector<float> src_buffer = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; - std::vector<float> init_buffer = {0.0f}; - std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f); - std::vector<float> expected_dst = {5.0f}; - - EXPECT_OK(ReduceSum::Execute<float>(src_buffer, init_buffer, - absl::MakeSpan(dst_buffer), dimension, - src_shape, dst_shape)); - - for (int i = 0; i < dst_buffer.size(); ++i) { - EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon); - } -} - -TEST(ReduceMin, TwoDimensionsToOne) { - Shape src_shape = {3, 3}; - int32_t dimension = 0; - Shape dst_shape = {3}; - std::vector<float> src_buffer = MakeIota<float>(src_shape.element_count()); - std::vector<float> init_buffer = {std::numeric_limits<float>::max()}; - std::vector<float> dst_buffer(dst_shape.element_count(), 0.0f); - std::vector<float> expected_dst = {1.0f, 2.0f, 3.0f}; - - EXPECT_OK(ReduceMin::Execute<float>(src_buffer, init_buffer, - absl::MakeSpan(dst_buffer), dimension, - src_shape, dst_shape)); - - for (int i = 0; i < dst_buffer.size(); ++i) { - EXPECT_NEAR(expected_dst[i], dst_buffer[i], kEpsilon); - } -} - -} // namespace -} // namespace kernels -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_command_processor.cc b/iree/hal/interpreter/interpreter_command_processor.cc deleted file mode 100644 index 783809d..0000000 --- a/iree/hal/interpreter/interpreter_command_processor.cc +++ /dev/null
@@ -1,66 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/interpreter_command_processor.h" - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/interpreter/bytecode_executable.h" -#include "iree/rt/stack.h" - -namespace iree { -namespace hal { - -InterpreterCommandProcessor::InterpreterCommandProcessor( - Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : HostLocalCommandProcessor(allocator, mode, command_categories) {} - -InterpreterCommandProcessor::~InterpreterCommandProcessor() = default; - -Status InterpreterCommandProcessor::Dispatch( - const DispatchRequest& dispatch_request) { - IREE_TRACE_SCOPE0("InterpreterCommandProcessor::Dispatch"); - - // Lookup the exported function. - auto* executable = - static_cast<BytecodeExecutable*>(dispatch_request.executable); - const auto& module = executable->module(); - ASSIGN_OR_RETURN(auto entry_function, module->LookupFunctionByOrdinal( - rt::Function::Linkage::kExport, - dispatch_request.entry_point)); - - rt::Stack stack(executable->context().get()); - - // TODO(benvanik): avoid this by directly referencing the bindings. - absl::InlinedVector<BufferView, 8> arguments; - arguments.reserve(dispatch_request.bindings.size()); - for (auto& binding : dispatch_request.bindings) { - arguments.push_back(BufferView{add_ref(binding.buffer), binding.shape, - binding.element_size}); - } - absl::InlinedVector<BufferView, 8> results; - - RETURN_IF_ERROR(executable->module()->Execute( - &stack, entry_function, std::move(arguments), &results)); - - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_command_processor.h b/iree/hal/interpreter/interpreter_command_processor.h deleted file mode 100644 index b3e474d..0000000 --- a/iree/hal/interpreter/interpreter_command_processor.h +++ /dev/null
@@ -1,36 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_ -#define IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_ - -#include "iree/hal/host/host_local_command_processor.h" - -namespace iree { -namespace hal { - -class InterpreterCommandProcessor final : public HostLocalCommandProcessor { - public: - InterpreterCommandProcessor(Allocator* allocator, - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories); - ~InterpreterCommandProcessor() override; - - Status Dispatch(const DispatchRequest& dispatch_request) override; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_INTERPRETER_COMMAND_PROCESSOR_H_
diff --git a/iree/hal/interpreter/interpreter_device.cc b/iree/hal/interpreter/interpreter_device.cc deleted file mode 100644 index 7a61cc3..0000000 --- a/iree/hal/interpreter/interpreter_device.cc +++ /dev/null
@@ -1,173 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/interpreter_device.h" - -#include <utility> - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/command_buffer_validation.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/fence.h" -#include "iree/hal/host/async_command_queue.h" -#include "iree/hal/host/host_event.h" -#include "iree/hal/host/host_submission_queue.h" -#include "iree/hal/host/inproc_command_buffer.h" -#include "iree/hal/interpreter/bytecode_cache.h" -#include "iree/hal/interpreter/interpreter_command_processor.h" - -namespace iree { -namespace hal { - -namespace { - -// A CommandQueue that performs no synchronization (semaphores/fences) and just -// directly executes command buffers inline. -// -// This is meant to be wrapped by SyncCommandQueue or AsyncCommandQueue that -// themselves perform the synchronization/threading/etc. As such we ignore -// all semaphores in the provided batches under the assumption that if Submit is -// being called then all dependencies are valid. The wrapping queue is also -// responsible for signaling the fence as well as propagating errors in a way -// that is dependent on how it is performing its synchronization. -class UnsynchronizedCommandQueue final : public CommandQueue { - public: - UnsynchronizedCommandQueue(Allocator* allocator, std::string name, - CommandCategoryBitfield supported_categories) - : CommandQueue(std::move(name), supported_categories), - allocator_(allocator) {} - ~UnsynchronizedCommandQueue() override = default; - - Status Submit(absl::Span<const SubmissionBatch> batches, - FenceValue fence) override { - IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::Submit"); - DCHECK_EQ(nullptr, fence.first) - << "Fences must be handled by the wrapping queue"; - - // Process command buffers and propagate errors asynchronously through the - // fence. This ensures that even if we are running synchronously we still - // get consistent failure behavior with drivers that are purely async. - for (auto& batch : batches) { - DCHECK(batch.wait_semaphores.empty() && batch.signal_semaphores.empty()) - << "Semaphores must be handled by the wrapping queue"; - RETURN_IF_ERROR(ProcessCommandBuffers(batch.command_buffers)); - } - - // NOTE: fence is ignored here. - return OkStatus(); - } - - Status WaitIdle(absl::Time deadline) override { - // No-op. - return OkStatus(); - } - - private: - // Processes each command buffer in-turn with a fresh processor. - // This ensures we don't have any state that can carry across buffers. - Status ProcessCommandBuffers( - absl::Span<CommandBuffer* const> command_buffers) { - IREE_TRACE_SCOPE0("UnsynchronizedCommandQueue::ProcessCommandBuffers"); - for (auto* command_buffer : command_buffers) { - auto* inproc_command_buffer = - static_cast<InProcCommandBuffer*>(command_buffer->impl()); - InterpreterCommandProcessor command_processor( - allocator_, command_buffer->mode(), supported_categories()); - RETURN_IF_ERROR(inproc_command_buffer->Process(&command_processor)); - } - return OkStatus(); - } - - Allocator* const allocator_; -}; - -} // namespace - -InterpreterDevice::InterpreterDevice(DeviceInfo device_info) - : Device(std::move(device_info)), instance_(make_ref<rt::Instance>()) { - // We currently only expose a single command queue. - auto command_queue = absl::make_unique<UnsynchronizedCommandQueue>( - &allocator_, "cpu0", - CommandCategory::kTransfer | CommandCategory::kDispatch); - - // TODO(benvanik): allow injection of the wrapper type to support - // SyncCommandQueue without always linking in both. - auto async_command_queue = - absl::make_unique<AsyncCommandQueue>(std::move(command_queue)); - command_queues_.push_back(std::move(async_command_queue)); -} - -InterpreterDevice::~InterpreterDevice() = default; - -std::shared_ptr<ExecutableCache> InterpreterDevice::CreateExecutableCache() { - return std::make_shared<BytecodeCache>(add_ref(instance_), &allocator_); -} - -StatusOr<ref_ptr<CommandBuffer>> InterpreterDevice::CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) { - // TODO(b/140026716): conditionally enable validation. - auto impl = - make_ref<InProcCommandBuffer>(&allocator_, mode, command_categories); - return WrapCommandBufferWithValidation(std::move(impl)); -} - -StatusOr<ref_ptr<Event>> InterpreterDevice::CreateEvent() { - return make_ref<HostEvent>(); -} - -StatusOr<ref_ptr<BinarySemaphore>> InterpreterDevice::CreateBinarySemaphore( - bool initial_value) { - IREE_TRACE_SCOPE0("InterpreterDevice::CreateBinarySemaphore"); - return make_ref<HostBinarySemaphore>(initial_value); -} - -StatusOr<ref_ptr<TimelineSemaphore>> InterpreterDevice::CreateTimelineSemaphore( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("InterpreterDevice::CreateTimelineSemaphore"); - - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) - << "Timeline semaphores not yet implemented"; -} - -StatusOr<ref_ptr<Fence>> InterpreterDevice::CreateFence( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("InterpreterDevice::CreateFence"); - return make_ref<HostFence>(initial_value); -} - -Status InterpreterDevice::WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) { - IREE_TRACE_SCOPE0("InterpreterDevice::WaitAllFences"); - return HostFence::WaitForFences(fences, /*wait_all=*/true, deadline); -} - -StatusOr<int> InterpreterDevice::WaitAnyFence( - absl::Span<const FenceValue> fences, absl::Time deadline) { - IREE_TRACE_SCOPE0("InterpreterDevice::WaitAnyFence"); - return HostFence::WaitForFences(fences, /*wait_all=*/false, deadline); -} - -Status InterpreterDevice::WaitIdle(absl::Time deadline) { - for (auto& command_queue : command_queues_) { - RETURN_IF_ERROR(command_queue->WaitIdle(deadline)); - } - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_device.h b/iree/hal/interpreter/interpreter_device.h deleted file mode 100644 index 5987ab9..0000000 --- a/iree/hal/interpreter/interpreter_device.h +++ /dev/null
@@ -1,79 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_ -#define IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_ - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/device.h" -#include "iree/hal/host/host_local_allocator.h" -#include "iree/hal/interpreter/bytecode_kernels.h" -#include "iree/rt/instance.h" - -namespace iree { -namespace hal { - -class InterpreterDevice final : public Device { - public: - explicit InterpreterDevice(DeviceInfo device_info); - ~InterpreterDevice() override; - - kernels::RuntimeState* kernel_runtime_state() { - return &kernel_runtime_state_; - } - - Allocator* allocator() const override { return &allocator_; } - - absl::Span<CommandQueue*> dispatch_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - absl::Span<CommandQueue*> transfer_queues() const override { - return RawPtrSpan(absl::MakeSpan(command_queues_)); - } - - std::shared_ptr<ExecutableCache> CreateExecutableCache() override; - - StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr<ref_ptr<Event>> CreateEvent() override; - - StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( - bool initial_value) override; - StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( - uint64_t initial_value) override; - - StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override; - Status WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) override; - StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Time deadline) override; - - Status WaitIdle(absl::Time deadline) override; - - private: - ref_ptr<rt::Instance> instance_; - kernels::RuntimeState kernel_runtime_state_; - mutable HostLocalAllocator allocator_; - mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 1> command_queues_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_INTERPRETER_DEVICE_H_
diff --git a/iree/hal/interpreter/interpreter_driver.cc b/iree/hal/interpreter/interpreter_driver.cc deleted file mode 100644 index d9074ef..0000000 --- a/iree/hal/interpreter/interpreter_driver.cc +++ /dev/null
@@ -1,62 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/interpreter_driver.h" - -#include <memory> - -#include "iree/hal/device_info.h" -#include "iree/hal/interpreter/interpreter_device.h" - -namespace iree { -namespace hal { - -namespace { - -DeviceInfo GetDefaultDeviceInfo() { - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // TODO(benvanik): implement debugging/profiling features. - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - DeviceInfo device_info("interpreter", supported_features); - // TODO(benvanik): device info. - return device_info; -} - -} // namespace - -InterpreterDriver::InterpreterDriver() : Driver("interpreter") {} - -InterpreterDriver::~InterpreterDriver() = default; - -StatusOr<std::vector<DeviceInfo>> -InterpreterDriver::EnumerateAvailableDevices() { - std::vector<DeviceInfo> device_infos; - device_infos.push_back(GetDefaultDeviceInfo()); - return device_infos; -} - -StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDefaultDevice() { - return CreateDevice(GetDefaultDeviceInfo()); -} - -StatusOr<std::shared_ptr<Device>> InterpreterDriver::CreateDevice( - const DeviceInfo& device_info) { - auto device = std::make_shared<InterpreterDevice>(device_info); - return device; -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_driver.h b/iree/hal/interpreter/interpreter_driver.h deleted file mode 100644 index b217996..0000000 --- a/iree/hal/interpreter/interpreter_driver.h +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_ -#define IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_ - -#include "iree/hal/driver.h" - -namespace iree { -namespace hal { - -class InterpreterDriver final : public Driver { - public: - InterpreterDriver(); - ~InterpreterDriver() override; - - StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; - - StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override; - - StatusOr<std::shared_ptr<Device>> CreateDevice( - const DeviceInfo& device_info) override; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_INTERPRETER_DRIVER_H_
diff --git a/iree/hal/interpreter/interpreter_driver_module.cc b/iree/hal/interpreter/interpreter_driver_module.cc deleted file mode 100644 index e9e78e7..0000000 --- a/iree/hal/interpreter/interpreter_driver_module.cc +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <memory> - -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/hal/driver_registry.h" -#include "iree/hal/interpreter/interpreter_driver.h" - -namespace iree { -namespace hal { -namespace { - -StatusOr<std::shared_ptr<Driver>> CreateInterpreterDriver() { - return std::make_shared<InterpreterDriver>(); -} - -} // namespace -} // namespace hal -} // namespace iree - -IREE_REGISTER_MODULE_INITIALIZER(iree_hal_interpreter_driver, { - QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( - "interpreter", ::iree::hal::CreateInterpreterDriver)); -}); -IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, - iree_hal_interpreter_driver);
diff --git a/iree/hal/interpreter/interpreter_module.cc b/iree/hal/interpreter/interpreter_module.cc deleted file mode 100644 index 84e962c..0000000 --- a/iree/hal/interpreter/interpreter_module.cc +++ /dev/null
@@ -1,80 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/interpreter/interpreter_module.h" - -#include "iree/base/flatbuffer_util.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/interpreter/bytecode_dispatch.h" -#include "iree/vm/bytecode_tables_interpreter.h" - -namespace iree { -namespace hal { - -// static -StatusOr<ref_ptr<rt::Module>> InterpreterModule::FromDef( - hal::Allocator* allocator, const ModuleDef& module_def) { - ASSIGN_OR_RETURN(auto module_file, - vm::ModuleFile::Create(&module_def, []() {})); - if (module_file->root() == nullptr) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present"; - } - - auto module = - assign_ref(new InterpreterModule(allocator, std::move(module_file))); - - // TODO(benvanik): validate internals here? or make explicit? - - return {std::move(module)}; -} - -InterpreterModule::InterpreterModule( - hal::Allocator* allocator, std::unique_ptr<vm::ModuleFile> module_file) - : vm::BytecodeModule(std::move(module_file), - vm::interpreter_opcode_table()), - allocator_(allocator) {} - -Status InterpreterModule::Execute( - rt::Stack* stack, const rt::Function function, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::InlinedVector<hal::BufferView, 8>* results) const { - IREE_TRACE_SCOPE0("InterperterModule::Execute"); - - // Push stack frame for the function we are calling. - ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function)); - - // TODO(benvanik): rework register storage interface. - ASSIGN_OR_RETURN(const auto* function_def, - GetFunctionDef(function.linkage(), function.ordinal())); - auto* registers = callee_stack_frame->mutable_registers(); - registers->buffer_views.resize(function_def->bytecode()->local_count()); - - // Marshal input arguments. - for (int i = 0; i < arguments.size(); ++i) { - registers->buffer_views[i] = std::move(arguments[i]); - } - - // Run main dispatch loop until it exits (or errors). - RETURN_IF_ERROR(Dispatch(allocator_, &kernel_runtime_state_, stack, - callee_stack_frame, absl::MakeSpan(*results))); - - // Pop the callee frame to balance out the stack. - RETURN_IF_ERROR(stack->PopFrame()); - - return OkStatus(); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/interpreter/interpreter_module.h b/iree/hal/interpreter/interpreter_module.h deleted file mode 100644 index d7d98e3..0000000 --- a/iree/hal/interpreter/interpreter_module.h +++ /dev/null
@@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_ -#define IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_ - -#include <memory> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/interpreter/bytecode_kernels.h" -#include "iree/rt/function.h" -#include "iree/rt/module.h" -#include "iree/rt/stack.h" -#include "iree/vm/bytecode_module.h" -#include "iree/vm/bytecode_tables_interpreter.h" - -namespace iree { -namespace hal { - -class InterpreterModule final : public vm::BytecodeModule { - public: - static StatusOr<ref_ptr<rt::Module>> FromDef(hal::Allocator* allocator, - const ModuleDef& module_def); - - Status Execute( - rt::Stack* stack, const rt::Function function, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::InlinedVector<hal::BufferView, 8>* results) const override; - - private: - InterpreterModule(hal::Allocator* allocator, - std::unique_ptr<vm::ModuleFile> module_file); - - hal::Allocator* allocator_; - mutable kernels::RuntimeState kernel_runtime_state_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_INTERPRETER_INTERPRETER_MODULE_H_
diff --git a/iree/hal/resource.h b/iree/hal/resource.h deleted file mode 100644 index 311a7ee..0000000 --- a/iree/hal/resource.h +++ /dev/null
@@ -1,33 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_RESOURCE_H_ -#define IREE_HAL_RESOURCE_H_ - -#include "iree/base/ref_ptr.h" - -namespace iree { -namespace hal { - -// Abstract resource type whose lifetime is managed by a ResourceSet. -// Used mostly just to get a virtual dtor, though we could add nicer logging. -class Resource : public RefObject<Resource> { - public: - virtual ~Resource() = default; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_RESOURCE_H_
diff --git a/iree/hal/semaphore.h b/iree/hal/semaphore.h deleted file mode 100644 index 74665d3..0000000 --- a/iree/hal/semaphore.h +++ /dev/null
@@ -1,61 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_SEMAPHORE_H_ -#define IREE_HAL_SEMAPHORE_H_ - -#include "absl/types/variant.h" -#include "iree/hal/resource.h" - -namespace iree { -namespace hal { - -// A synchronization primitive used to indicate submission dependencies. -// Semaphores are either of type binary (signaled or unsignaled) or timeline -// (uint64 payload with >= semantics). -class Semaphore : public Resource { - public: -}; - -// Binary semaphores have strict ordering requirements and must be carefully -// balanced. Each binary semaphore must only be waited on after a signal -// operation has been issued and each wait requires exactly one signal. They -// are commonly used only when interacting with external handles that may -// cross device or process boundaries. -class BinarySemaphore : public Semaphore { - public: -}; - -// Timeline semaphores act as a fence along a per-semaphore timeline where -// signaling is done by setting the payload to a monotonically increasing -// 64-bit integer and waiting is done by blocking until the payload is set -// greater-than or equal-to the specified value. Timeline semaphores may be -// waited on or signaled in any order and can be significantly more -// efficient due to system-level coalescing. -class TimelineSemaphore : public Semaphore { - public: - // TODO(benvanik): add value query support. - // TODO(benvanik): add host-side signal/wait. -}; - -// A reference to a strongly-typed semaphore and associated information. -// For TimelineSemaphores the provided payload is used to specify either the -// payload to wait for or new payload value. -using SemaphoreValue = - absl::variant<BinarySemaphore*, std::pair<TimelineSemaphore*, uint64_t>>; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_SEMAPHORE_H_
diff --git a/iree/hal/testing/BUILD b/iree/hal/testing/BUILD deleted file mode 100644 index ac80497..0000000 --- a/iree/hal/testing/BUILD +++ /dev/null
@@ -1,36 +0,0 @@ -# Test utilities for HAL-specific code. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "mock_allocator", - testonly = True, - hdrs = ["mock_allocator.h"], - deps = [ - "//iree/hal:allocator", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "mock_command_buffer", - testonly = True, - hdrs = ["mock_command_buffer.h"], - deps = [ - "//iree/hal:command_buffer", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "mock_command_queue", - testonly = True, - hdrs = ["mock_command_queue.h"], - deps = [ - "//iree/hal:command_queue", - "@com_google_googletest//:gtest", - ], -)
diff --git a/iree/hal/testing/mock_allocator.h b/iree/hal/testing/mock_allocator.h deleted file mode 100644 index cc5548e..0000000 --- a/iree/hal/testing/mock_allocator.h +++ /dev/null
@@ -1,55 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ -#define IREE_HAL_TESTING_MOCK_ALLOCATOR_H_ - -#include "gmock/gmock.h" -#include "iree/hal/allocator.h" - -namespace iree { -namespace hal { -namespace testing { - -class MockAllocator : public ::testing::StrictMock<Allocator> { - public: - MockAllocator() : ::testing::StrictMock<Allocator>() {} - - MOCK_CONST_METHOD4(CanUseBufferLike, - bool(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage)); - - MOCK_CONST_METHOD3(CanAllocate, bool(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size)); - - MOCK_METHOD3(Allocate, - StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size)); - - MOCK_METHOD5(WrapMutable, - StatusOr<ref_ptr<Buffer>>(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, size_t data_length)); -}; - -} // namespace testing -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_TESTING_MOCK_ALLOCATOR_H_
diff --git a/iree/hal/testing/mock_command_buffer.h b/iree/hal/testing/mock_command_buffer.h deleted file mode 100644 index 6c55499..0000000 --- a/iree/hal/testing/mock_command_buffer.h +++ /dev/null
@@ -1,80 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ -#define IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_ - -#include "gmock/gmock.h" -#include "iree/hal/command_buffer.h" - -namespace iree { -namespace hal { -namespace testing { - -class MockCommandBuffer : public ::testing::StrictMock<CommandBuffer> { - public: - MockCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) - : ::testing::StrictMock<CommandBuffer>(allocator, mode, - command_categories) {} - - bool is_recording() const override { return false; } - - MOCK_METHOD0(Begin, Status()); - MOCK_METHOD0(End, Status()); - - MOCK_METHOD4(ExecutionBarrier, - Status(ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers)); - - MOCK_METHOD2(SignalEvent, - Status(Event* event, ExecutionStageBitfield source_stage_mask)); - - MOCK_METHOD2(ResetEvent, - Status(Event* event, ExecutionStageBitfield source_stage_mask)); - - MOCK_METHOD5(WaitEvents, - Status(absl::Span<Event*> events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers)); - - MOCK_METHOD5(FillBuffer, - Status(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length)); - - MOCK_METHOD1(DiscardBuffer, Status(Buffer* buffer)); - - MOCK_METHOD5(UpdateBuffer, - Status(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length)); - - MOCK_METHOD5(CopyBuffer, - Status(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length)); - - MOCK_METHOD1(Dispatch, Status(const DispatchRequest& dispatch_request)); -}; - -} // namespace testing -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_TESTING_MOCK_COMMAND_BUFFER_H_
diff --git a/iree/hal/testing/mock_command_queue.h b/iree/hal/testing/mock_command_queue.h deleted file mode 100644 index 1f5fe8b..0000000 --- a/iree/hal/testing/mock_command_queue.h +++ /dev/null
@@ -1,42 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ -#define IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_ - -#include "gmock/gmock.h" -#include "iree/hal/command_queue.h" - -namespace iree { -namespace hal { -namespace testing { - -class MockCommandQueue : public ::testing::StrictMock<CommandQueue> { - public: - MockCommandQueue(std::string name, - CommandCategoryBitfield supported_categories) - : ::testing::StrictMock<CommandQueue>(std::move(name), - supported_categories) {} - - MOCK_METHOD2(Submit, Status(absl::Span<const SubmissionBatch> batches, - FenceValue fence)); - - MOCK_METHOD1(WaitIdle, Status(absl::Time deadline)); -}; - -} // namespace testing -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_TESTING_MOCK_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD deleted file mode 100644 index 56bcb98..0000000 --- a/iree/hal/vulkan/BUILD +++ /dev/null
@@ -1,380 +0,0 @@ -# HAL implementation using Vulkan and (likely) SPIR-V executables. - -load("//iree:build_defs.bzl", "PLATFORM_VULKAN_LOADER_COPTS", "PLATFORM_VULKAN_TEST_DEPS") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -# --define=IREE_VK=native to use the native Vulkan drivers (and real hardware). -config_setting( - name = "native_vk", - values = { - "define": "IREE_VK=native", - }, -) - -# --define=IREE_VK=swiftshader to use SwiftShader. -config_setting( - name = "swiftshader_vk", - values = { - "define": "IREE_VK=swiftshader", - }, -) - -cc_library( - name = "debug_reporter", - srcs = ["debug_reporter.cc"], - hdrs = ["debug_reporter.h"], - deps = [ - ":dynamic_symbols", - ":status_util", - "//iree/base:status", - "//iree/base:tracing", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "descriptor_pool_cache", - srcs = ["descriptor_pool_cache.cc"], - hdrs = ["descriptor_pool_cache.h"], - deps = [ - ":dynamic_symbols", - ":handle_util", - ":status_util", - "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:tracing", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "descriptor_set_arena", - srcs = ["descriptor_set_arena.cc"], - hdrs = ["descriptor_set_arena.h"], - deps = [ - ":descriptor_pool_cache", - ":pipeline_executable", - ":status_util", - ":vma_allocator", - "//iree/base:arena", - "//iree/base:math", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_buffer", - ], -) - -cc_library( - name = "direct_command_buffer", - srcs = ["direct_command_buffer.cc"], - hdrs = ["direct_command_buffer.h"], - deps = [ - ":descriptor_pool_cache", - ":descriptor_set_arena", - ":dynamic_symbols", - ":handle_util", - ":native_event", - ":pipeline_executable", - ":status_util", - ":vma_allocator", - "//iree/base:math", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_buffer", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "direct_command_queue", - srcs = ["direct_command_queue.cc"], - hdrs = ["direct_command_queue.h"], - deps = [ - ":direct_command_buffer", - ":dynamic_symbols", - ":handle_util", - ":legacy_fence", - ":native_binary_semaphore", - ":status_util", - "//iree/base:arena", - "//iree/base:memory", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:command_queue", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "dynamic_symbols", - srcs = ["dynamic_symbols.cc"], - hdrs = [ - "dynamic_symbol_tables.h", - "dynamic_symbols.h", - ], - copts = PLATFORM_VULKAN_LOADER_COPTS, - linkopts = [ - "-ldl", - ], - deps = [ - "//iree/base:file_path", - "//iree/base:ref_ptr", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:target_platform", - "//iree/base:tracing", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/memory", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_test( - name = "dynamic_symbols_test", - srcs = ["dynamic_symbols_test.cc"], - deps = [ - ":status_util", - ":dynamic_symbols", - "//iree/base:status_matchers", - ] + PLATFORM_VULKAN_TEST_DEPS, -) - -cc_library( - name = "extensibility_util", - srcs = ["extensibility_util.cc"], - hdrs = ["extensibility_util.h"], - deps = [ - ":dynamic_symbols", - ":status_util", - "//iree/base:memory", - "//iree/base:status", - "//iree/base:tracing", - "@com_google_absl//absl/types:span", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "handle_util", - hdrs = ["handle_util.h"], - deps = [ - ":dynamic_symbols", - ":extensibility_util", - "//iree/base:ref_ptr", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/utility", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "legacy_fence", - srcs = ["legacy_fence.cc"], - hdrs = ["legacy_fence.h"], - deps = [ - ":handle_util", - ":status_util", - "//iree/base:intrusive_list", - "//iree/base:ref_ptr", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:fence", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "native_binary_semaphore", - srcs = ["native_binary_semaphore.cc"], - hdrs = ["native_binary_semaphore.h"], - deps = [ - ":handle_util", - "//iree/hal:semaphore", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "native_event", - srcs = ["native_event.cc"], - hdrs = ["native_event.h"], - deps = [ - ":handle_util", - "//iree/hal:event", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "pipeline_cache", - srcs = ["pipeline_cache.cc"], - hdrs = ["pipeline_cache.h"], - deps = [ - ":handle_util", - ":pipeline_executable", - ":status_util", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:executable", - "//iree/hal:executable_cache", - "//iree/hal:executable_format", - "//iree/schemas:spirv_executable_def_cc_fbs", - "@com_github_google_flatbuffers//:flatbuffers", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/synchronization", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "pipeline_executable", - srcs = ["pipeline_executable.cc"], - hdrs = ["pipeline_executable.h"], - deps = [ - ":handle_util", - ":status_util", - "//iree/base:memory", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:executable", - "//iree/hal:executable_cache", - "//iree/hal:executable_spec", - "//iree/schemas:spirv_executable_def_cc_fbs", - "@com_google_absl//absl/container:inlined_vector", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "status_util", - srcs = ["status_util.cc"], - hdrs = ["status_util.h"], - deps = [ - "//iree/base:status", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "vma_allocator", - srcs = [ - "internal_vk_mem_alloc.cc", - "internal_vk_mem_alloc.h", - "vma_allocator.cc", - "vma_buffer.cc", - ], - hdrs = [ - "vma_allocator.h", - "vma_buffer.h", - ], - copts = [ - # Only needed in the implementation cc and not by external users. - "-DVMA_STATIC_VULKAN_FUNCTIONS=0", - ], - deps = [ - ":dynamic_symbols", - ":handle_util", - ":status_util", - "//iree/base:logging", - "//iree/base:source_location", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:allocator", - "//iree/hal:buffer", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@vulkan_headers//:vulkan_headers_no_prototypes", - "@vulkan_memory_allocator//:impl_header_only", - ], -) - -cc_library( - name = "vulkan_device", - srcs = ["vulkan_device.cc"], - hdrs = ["vulkan_device.h"], - deps = [ - ":descriptor_pool_cache", - ":direct_command_buffer", - ":direct_command_queue", - ":dynamic_symbols", - ":extensibility_util", - ":handle_util", - ":legacy_fence", - ":native_binary_semaphore", - ":native_event", - ":pipeline_cache", - ":status_util", - ":vma_allocator", - "//iree/base:memory", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:allocator", - "//iree/hal:command_buffer_validation", - "//iree/hal:command_queue", - "//iree/hal:device", - "//iree/hal:fence", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "vulkan_driver", - srcs = ["vulkan_driver.cc"], - hdrs = ["vulkan_driver.h"], - deps = [ - ":debug_reporter", - ":dynamic_symbols", - ":extensibility_util", - ":status_util", - ":vulkan_device", - "//iree/base:memory", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:device_info", - "//iree/hal:driver", - "@com_google_absl//absl/container:inlined_vector", - "@vulkan_headers//:vulkan_headers_no_prototypes", - ], -) - -cc_library( - name = "vulkan_driver_module", - srcs = ["vulkan_driver_module.cc"], - deps = [ - ":dynamic_symbols", - ":vulkan_driver", - "//iree/base:init", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:driver_registry", - "@com_google_absl//absl/flags:flag", - ], - alwayslink = 1, -)
diff --git a/iree/hal/vulkan/debug_reporter.cc b/iree/hal/vulkan/debug_reporter.cc deleted file mode 100644 index 46ec68d..0000000 --- a/iree/hal/vulkan/debug_reporter.cc +++ /dev/null
@@ -1,160 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/debug_reporter.h" - -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -// NOTE: |user_data| may be nullptr if we are being called during instance -// creation. Otherwise it is a pointer to the DebugReporter instance. - -// NOTE: this callback must be thread safe and must be careful not to reach too -// far outside of the call - it is called in-context from arbitrary threads with -// some amount of Vulkan state on the stack. Assume that creating or deleting -// Vulkan objects, issuing most Vulkan commands, etc are off-limits. - -VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback( - VkDebugUtilsMessageSeverityFlagBitsEXT message_severity, - VkDebugUtilsMessageTypeFlagsEXT message_type, - const VkDebugUtilsMessengerCallbackDataEXT* callback_data, - void* user_data) { - // TODO(benvanik): better logging once we have switched logging APIs. - LOG(ERROR) << callback_data->pMessage; - - return VK_FALSE; // VK_TRUE is reserved for future use. -} - -VKAPI_ATTR VkBool32 VKAPI_CALL DebugReportCallback( - VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT object_type, - uint64_t object, size_t location, int32_t message_code, - const char* layer_prefix, const char* message, void* user_data) { - // TODO(benvanik): better logging once we have switched logging APIs. - LOG(ERROR) << message; - - return VK_FALSE; // VK_TRUE is reserved for future use. -} - -} // namespace - -// static -void DebugReporter::PopulateStaticCreateInfo( - VkDebugUtilsMessengerCreateInfoEXT* create_info) { - create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; - create_info->pNext = nullptr; - create_info->flags = 0; - - // TODO(benvanik): only enable the severities that logging has enabled. - create_info->messageSeverity = - VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT; - - // TODO(benvanik): allow filtering by category as a flag. - create_info->messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; - - create_info->pfnUserCallback = DebugUtilsMessageCallback; - create_info->pUserData = nullptr; -} - -// static -void DebugReporter::PopulateStaticCreateInfo( - VkDebugReportCallbackCreateInfoEXT* create_info) { - create_info->sType = VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT; - create_info->pNext = nullptr; - create_info->flags = 0; - - // TODO(benvanik): only enable the severities that logging has enabled. - create_info->flags |= - VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | - VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | - VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_DEBUG_BIT_EXT; - - create_info->pfnCallback = DebugReportCallback; - create_info->pUserData = nullptr; -} - -// static -StatusOr<std::unique_ptr<DebugReporter>> -DebugReporter::CreateDebugUtilsMessenger( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) { - IREE_TRACE_SCOPE0("DebugReporter::CreateDebugUtilsMessenger"); - - auto debug_reporter = - absl::WrapUnique(new DebugReporter(instance, syms, allocation_callbacks)); - - VkDebugUtilsMessengerCreateInfoEXT create_info; - PopulateStaticCreateInfo(&create_info); - create_info.pUserData = debug_reporter.get(); - - VK_RETURN_IF_ERROR(syms->vkCreateDebugUtilsMessengerEXT( - instance, &create_info, allocation_callbacks, - &debug_reporter->messenger_)); - - return debug_reporter; -} - -// static -StatusOr<std::unique_ptr<DebugReporter>> -DebugReporter::CreateDebugReportCallback( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) { - IREE_TRACE_SCOPE0("DebugReporter::CreateDebugReportCallback"); - - auto debug_reporter = - absl::WrapUnique(new DebugReporter(instance, syms, allocation_callbacks)); - - VkDebugReportCallbackCreateInfoEXT create_info; - PopulateStaticCreateInfo(&create_info); - create_info.pUserData = debug_reporter.get(); - - VK_RETURN_IF_ERROR(syms->vkCreateDebugReportCallbackEXT( - instance, &create_info, allocation_callbacks, - &debug_reporter->callback_)); - - return debug_reporter; -} - -DebugReporter::DebugReporter(VkInstance instance, - const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) - : instance_(instance), - syms_(add_ref(syms)), - allocation_callbacks_(allocation_callbacks) {} - -DebugReporter::~DebugReporter() { - IREE_TRACE_SCOPE0("DebugReporter::dtor"); - if (messenger_ != VK_NULL_HANDLE) { - syms_->vkDestroyDebugUtilsMessengerEXT(instance_, messenger_, - allocation_callbacks_); - } - if (callback_ != VK_NULL_HANDLE) { - syms_->vkDestroyDebugReportCallbackEXT(instance_, callback_, - allocation_callbacks_); - } -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/debug_reporter.h b/iree/hal/vulkan/debug_reporter.h deleted file mode 100644 index ffc594b..0000000 --- a/iree/hal/vulkan/debug_reporter.h +++ /dev/null
@@ -1,87 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_ -#define IREE_HAL_VULKAN_DEBUG_REPORTER_H_ - -#include <vulkan/vulkan.h> - -#include "iree/base/status.h" -#include "iree/hal/vulkan/dynamic_symbols.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A debug reporter that works with the VK_EXT_debug_utils extension. -// One reporter should be created per VkInstance to receive callbacks from the -// API and route them to our logging systems. In general VK_EXT_debug_utils -// should be preferred if available as it provides a much cleaner interface and -// more plug-points than VK_EXT_debug_report. -// -// Since creating a reporter requires a VkInstance it's not possible to report -// on messages during instance creation. To work around this it's possible to -// pass a *CreateInfo struct to vkCreateInstance as part of the -// VkInstanceCreateInfo::pNext chain. The callback will only be used this way -// during the creation call after which users can create the real -// instance-specific reporter. -class DebugReporter final { - public: - // Populates |create_info| with an instance-agnostic callback. - // This can be used during instance creation by chaining the |create_info| to - // VkInstanceCreateInfo::pNext. - // - // Only use if VK_EXT_debug_utils is present. - static void PopulateStaticCreateInfo( - VkDebugUtilsMessengerCreateInfoEXT* create_info); - - // Populates |create_info| with an instance-agnostic callback. - // This can be used during instance creation by chaining the |create_info| to - // VkInstanceCreateInfo::pNext. - // - // Only use if VK_EXT_debug_report is present. - static void PopulateStaticCreateInfo( - VkDebugReportCallbackCreateInfoEXT* create_info); - - // Creates a debug messenger for the given Vulkan |instance| with - // VK_EXT_debug_utils enabled. - static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugUtilsMessenger( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks); - - // Creates a debug report callback for the given Vulkan |instance| with - // VK_EXT_debug_report enabled. - static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugReportCallback( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks); - - ~DebugReporter(); - - private: - DebugReporter(VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks); - - VkInstance instance_ = VK_NULL_HANDLE; - ref_ptr<DynamicSymbols> syms_; - const VkAllocationCallbacks* allocation_callbacks_ = nullptr; - - VkDebugUtilsMessengerEXT messenger_ = VK_NULL_HANDLE; - VkDebugReportCallbackEXT callback_ = VK_NULL_HANDLE; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_DEBUG_REPORTER_H_
diff --git a/iree/hal/vulkan/descriptor_pool_cache.cc b/iree/hal/vulkan/descriptor_pool_cache.cc deleted file mode 100644 index 3e1e4ef..0000000 --- a/iree/hal/vulkan/descriptor_pool_cache.cc +++ /dev/null
@@ -1,102 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/descriptor_pool_cache.h" - -#include <array> - -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -// TODO(benvanik): be more conservative with descriptor set count or allow -// chaining in the command buffer when pools run out. -static constexpr int kMaxDescriptorSets = 4096; - -} // namespace - -DescriptorSetGroup::~DescriptorSetGroup() { - CHECK(descriptor_pools_.empty()) - << "DescriptorSetGroup must be reset explicitly"; -} - -Status DescriptorSetGroup::Reset() { - IREE_TRACE_SCOPE0("DescriptorSetGroup::Reset"); - - RETURN_IF_ERROR(descriptor_pool_cache_->ReleaseDescriptorPools( - absl::MakeSpan(descriptor_pools_))); - descriptor_pools_.clear(); - - return OkStatus(); -} - -DescriptorPoolCache::DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} - -StatusOr<DescriptorPool> DescriptorPoolCache::AcquireDescriptorPool( - VkDescriptorType descriptor_type, int max_descriptor_count) { - IREE_TRACE_SCOPE0("DescriptorPoolCache::AcquireDescriptorPool"); - - // TODO(benvanik): lookup in cache. - - VkDescriptorPoolCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.maxSets = kMaxDescriptorSets; - std::array<VkDescriptorPoolSize, 1> pool_sizes; - pool_sizes[0].type = descriptor_type; - pool_sizes[0].descriptorCount = max_descriptor_count; - create_info.poolSizeCount = pool_sizes.size(); - create_info.pPoolSizes = pool_sizes.data(); - - DescriptorPool descriptor_pool; - descriptor_pool.descriptor_type = descriptor_type; - descriptor_pool.max_descriptor_count = max_descriptor_count; - descriptor_pool.handle = VK_NULL_HANDLE; - - VK_RETURN_IF_ERROR(syms().vkCreateDescriptorPool( - *logical_device_, &create_info, logical_device_->allocator(), - &descriptor_pool.handle)); - - return descriptor_pool; -} - -Status DescriptorPoolCache::ReleaseDescriptorPools( - absl::Span<DescriptorPool> descriptor_pools) { - IREE_TRACE_SCOPE0("DescriptorPoolCache::ReleaseDescriptorPools"); - - for (const auto& descriptor_pool : descriptor_pools) { - // Always reset immediately. We could do this on allocation instead however - // this leads to better errors when using the validation layers as we'll - // throw if there are in-flight command buffers using the sets in the pool. - VK_RETURN_IF_ERROR(syms().vkResetDescriptorPool(*logical_device_, - descriptor_pool.handle, 0)); - - // TODO(benvanik): release to cache. - syms().vkDestroyDescriptorPool(*logical_device_, descriptor_pool.handle, - logical_device_->allocator()); - } - - return OkStatus(); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/descriptor_pool_cache.h b/iree/hal/vulkan/descriptor_pool_cache.h deleted file mode 100644 index bb4fa33..0000000 --- a/iree/hal/vulkan/descriptor_pool_cache.h +++ /dev/null
@@ -1,103 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_ -#define IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_ - -#include "absl/container/inlined_vector.h" -#include "iree/base/ref_ptr.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class DescriptorPoolCache; - -// A descriptor pool with a single descriptor type of some number. -// We only support a single descriptor type for now as we only generate SPIR-V -// that uses a single type. -struct DescriptorPool { - // Type of the descriptor in the set. - VkDescriptorType descriptor_type = VK_DESCRIPTOR_TYPE_MAX_ENUM; - // Maximum number of descriptors of the given type per allocation. - int max_descriptor_count = 0; - // Pool handle. - VkDescriptorPool handle = VK_NULL_HANDLE; -}; - -// A group of descriptor sets allocated and released together. -// The group must be explicitly reset with Reset() prior to disposing. -class DescriptorSetGroup final { - public: - DescriptorSetGroup() = default; - DescriptorSetGroup(ref_ptr<DescriptorPoolCache> descriptor_pool_cache, - absl::InlinedVector<DescriptorPool, 8> descriptor_pools) - : descriptor_pool_cache_(std::move(descriptor_pool_cache)), - descriptor_pools_(std::move(descriptor_pools)) {} - DescriptorSetGroup(const DescriptorSetGroup&) = delete; - DescriptorSetGroup& operator=(const DescriptorSetGroup&) = delete; - DescriptorSetGroup(DescriptorSetGroup&& other) noexcept - : descriptor_pool_cache_(std::move(other.descriptor_pool_cache_)), - descriptor_pools_(std::move(other.descriptor_pools_)) {} - DescriptorSetGroup& operator=(DescriptorSetGroup&& other) { - std::swap(descriptor_pool_cache_, other.descriptor_pool_cache_); - std::swap(descriptor_pools_, other.descriptor_pools_); - return *this; - } - ~DescriptorSetGroup(); - - Status Reset(); - - private: - ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; - absl::InlinedVector<DescriptorPool, 8> descriptor_pools_; -}; - -// A "cache" (or really, pool) of descriptor pools. These pools are allocated -// as needed to satisfy different descriptor size requirements and are given -// to command buffers during recording to write descriptor updates and bind -// resources. After the descriptors in the pool are no longer used (all -// command buffers using descriptor sets allocated from the pool have retired) -// the pool is returned here to be reused in the future. -class DescriptorPoolCache final : public RefObject<DescriptorPoolCache> { - public: - explicit DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device); - - const ref_ptr<VkDeviceHandle>& logical_device() const { - return logical_device_; - } - const DynamicSymbols& syms() const { return *logical_device_->syms(); } - - // Acquires a new descriptor pool for use by the caller. - // The pool will have been reset and have all descriptor sets available. - // When all sets allocated from the pool are no longer in use it must be - // returned to the cache with ReleaseDescriptorPool. - StatusOr<DescriptorPool> AcquireDescriptorPool( - VkDescriptorType descriptor_type, int max_descriptor_count); - - // Releases descriptor pools back to the cache. The pools will be reset - // immediately and must no longer be in use by any in-flight command. - Status ReleaseDescriptorPools(absl::Span<DescriptorPool> descriptor_pools); - - private: - ref_ptr<VkDeviceHandle> logical_device_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc deleted file mode 100644 index 636cfcd..0000000 --- a/iree/hal/vulkan/descriptor_set_arena.cc +++ /dev/null
@@ -1,204 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/descriptor_set_arena.h" - -#include "iree/base/math.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vma_buffer.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast<VmaBuffer*>(buffer->allocated_buffer()); -} - -StatusOr<absl::Span<VkWriteDescriptorSet>> PopulateDescriptorSetWriteInfos( - const PipelineDescriptorSets& pipeline_descriptor_sets, - absl::Span<const BufferBinding> bindings, VkDescriptorSet dst_set, - Arena* arena) { - int required_descriptor_count = - pipeline_descriptor_sets.buffer_binding_set_map.size(); - - arena->Reset(); - auto buffer_infos = - arena->AllocateSpan<VkDescriptorBufferInfo>(required_descriptor_count); - auto write_infos = arena->AllocateSpan<VkWriteDescriptorSet>(bindings.size()); - - for (int i = 0; i < bindings.size(); ++i) { - const auto& binding = bindings[i]; - - auto& buffer_info = buffer_infos[i]; - ASSIGN_OR_RETURN(auto buffer, CastBuffer(binding.buffer)); - buffer_info.buffer = buffer->handle(); - // TODO(benvanik): properly subrange (add to BufferBinding). - buffer_info.offset = binding.buffer->byte_offset(); - buffer_info.range = binding.buffer->byte_length(); - - auto& write_info = write_infos[i]; - write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; - write_info.pNext = nullptr; - write_info.dstSet = dst_set; - write_info.dstBinding = pipeline_descriptor_sets.buffer_binding_set_map[i]; - write_info.dstArrayElement = 0; - write_info.descriptorCount = 1; - write_info.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - write_info.pImageInfo = nullptr; - write_info.pBufferInfo = &buffer_info; - write_info.pTexelBufferView = nullptr; - } - - return write_infos; -} - -} // namespace - -DescriptorSetArena::DescriptorSetArena( - ref_ptr<DescriptorPoolCache> descriptor_pool_cache) - : logical_device_(add_ref(descriptor_pool_cache->logical_device())), - descriptor_pool_cache_(std::move(descriptor_pool_cache)) {} - -DescriptorSetArena::~DescriptorSetArena() { - if (!used_descriptor_pools_.empty()) { - descriptor_pool_cache_ - ->ReleaseDescriptorPools(absl::MakeSpan(used_descriptor_pools_)) - .IgnoreError(); - used_descriptor_pools_.clear(); - } -} - -Status DescriptorSetArena::BindDescriptorSet( - VkCommandBuffer command_buffer, PipelineExecutable* executable, - absl::Span<const BufferBinding> bindings) { - // Always prefer using push descriptors when available as we can avoid the - // additional API overhead of updating/resetting pools. - if (logical_device_->enabled_extensions().push_descriptors) { - return PushDescriptorSet(command_buffer, executable, bindings); - } - - IREE_TRACE_SCOPE0("DescriptorSetArena::BindDescriptorSet"); - - // Pick a bucket based on the number of descriptors required. - // NOTE: right now we are 1:1 with bindings. - int required_descriptor_count = bindings.size() * 1; - int max_descriptor_count = - std::max(8, RoundUpToNearestPow2(required_descriptor_count)); - int bucket = TrailingZeros(max_descriptor_count >> 3); - if (bucket >= descriptor_pool_buckets_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Too many descriptors required: " << required_descriptor_count - << " (max=" << (1 << (descriptor_pool_buckets_.size() + 3)) << ")"; - } - if (descriptor_pool_buckets_[bucket].handle == VK_NULL_HANDLE) { - // Acquire a pool for this max_descriptor_count bucket. - ASSIGN_OR_RETURN( - descriptor_pool_buckets_[bucket], - descriptor_pool_cache_->AcquireDescriptorPool( - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count)); - used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]); - } - auto& descriptor_pool = descriptor_pool_buckets_[bucket]; - - const auto& pipeline_descriptor_sets = executable->descriptor_sets(); - - VkDescriptorSetAllocateInfo allocate_info; - allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; - allocate_info.pNext = nullptr; - allocate_info.descriptorPool = descriptor_pool.handle; - allocate_info.descriptorSetCount = 1; - allocate_info.pSetLayouts = - &pipeline_descriptor_sets.buffer_binding_set_layout; - VkDescriptorSet descriptor_set = VK_NULL_HANDLE; - VkResult result = syms().vkAllocateDescriptorSets( - *logical_device_, &allocate_info, &descriptor_set); - if (result == VK_ERROR_OUT_OF_POOL_MEMORY) { - // Allocation failed because the pool is either out of descriptors or too - // fragmented. We'll just allocate another pool. - ASSIGN_OR_RETURN( - descriptor_pool_buckets_[bucket], - descriptor_pool_cache_->AcquireDescriptorPool( - VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, max_descriptor_count)); - used_descriptor_pools_.push_back(descriptor_pool_buckets_[bucket]); - } - - // Get a list of VkWriteDescriptorSet structs with all bound buffers. - ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( - pipeline_descriptor_sets, bindings, - descriptor_set, &scratch_arena_)); - - // This is the reason why push descriptor sets are good. - // We can't batch these effectively as we don't know prior to recording what - // descriptor sets we will need and what buffers they will point to (without - // doing just as much work as actually recording the buffer to try to find - // out). - syms().vkUpdateDescriptorSets(*logical_device_, write_infos.size(), - write_infos.data(), 0, nullptr); - - // Bind the descriptor set. - syms().vkCmdBindDescriptorSets(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, - executable->pipeline_layout(), - pipeline_descriptor_sets.buffer_binding_set, 1, - &descriptor_set, 0, nullptr); - - return OkStatus(); -} - -Status DescriptorSetArena::PushDescriptorSet( - VkCommandBuffer command_buffer, PipelineExecutable* executable, - absl::Span<const BufferBinding> bindings) { - IREE_TRACE_SCOPE0("DescriptorSetArena::PushDescriptorSet"); - - const auto& pipeline_descriptor_sets = executable->descriptor_sets(); - - // Get a list of VkWriteDescriptorSet structs with all bound buffers. - ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( - pipeline_descriptor_sets, bindings, - VK_NULL_HANDLE, &scratch_arena_)); - - // Fast path using push descriptors. These are pooled internally by the - // command buffer and prevent the need for our own pooling mechanisms. - syms().vkCmdPushDescriptorSetKHR(command_buffer, - VK_PIPELINE_BIND_POINT_COMPUTE, - executable->pipeline_layout(), - pipeline_descriptor_sets.buffer_binding_set, - write_infos.size(), write_infos.data()); - - return OkStatus(); -} - -StatusOr<DescriptorSetGroup> DescriptorSetArena::Flush() { - IREE_TRACE_SCOPE0("DescriptorSetArena::Flush"); - - if (used_descriptor_pools_.empty()) { - // No resources to free. - return DescriptorSetGroup{}; - } - - for (auto& bucket : descriptor_pool_buckets_) { - bucket = {}; - } - return DescriptorSetGroup(add_ref(descriptor_pool_cache_), - std::move(used_descriptor_pools_)); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/descriptor_set_arena.h b/iree/hal/vulkan/descriptor_set_arena.h deleted file mode 100644 index bbbe5bc..0000000 --- a/iree/hal/vulkan/descriptor_set_arena.h +++ /dev/null
@@ -1,76 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_ -#define IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_ - -#include <array> -#include <vector> - -#include "iree/base/arena.h" -#include "iree/base/status.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/pipeline_executable.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A reusable arena for allocating descriptor sets and batching updates. -class DescriptorSetArena final { - public: - explicit DescriptorSetArena( - ref_ptr<DescriptorPoolCache> descriptor_pool_cache); - ~DescriptorSetArena(); - - // Allocates and binds a descriptor set from the arena. - // The command buffer will have the descriptor set containing |bindings| bound - // to it. - Status BindDescriptorSet(VkCommandBuffer command_buffer, - PipelineExecutable* executable, - absl::Span<const BufferBinding> bindings); - - // Flushes all pending writes to descriptor sets allocated from the arena and - // returns a group that - when dropped - will release the descriptor sets - // back to the pools they were allocated from. - StatusOr<DescriptorSetGroup> Flush(); - - private: - const DynamicSymbols& syms() const { return *logical_device_->syms(); } - - // Pushes the descriptor set to the command buffer, if supported. - Status PushDescriptorSet(VkCommandBuffer command_buffer, - PipelineExecutable* executable, - absl::Span<const BufferBinding> bindings); - - ref_ptr<VkDeviceHandle> logical_device_; - ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; - - // Arena used for temporary binding information used during allocation. - Arena scratch_arena_; - - // A list of pools acquired on demand as different descriptor counts are - // needed. Allocation granularity is max_descriptor_count=[8, 16, 32, 64]. - std::array<DescriptorPool, 4> descriptor_pool_buckets_; - - // All pools that have been used during allocation. - absl::InlinedVector<DescriptorPool, 8> used_descriptor_pools_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_DESCRIPTOR_SET_ARENA_H_
diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc deleted file mode 100644 index 7ced0a2..0000000 --- a/iree/hal/vulkan/direct_command_buffer.cc +++ /dev/null
@@ -1,403 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/direct_command_buffer.h" - -#include "absl/base/attributes.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/math.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -VkPipelineStageFlags ConvertPipelineStageFlags( - ExecutionStageBitfield stage_mask) { - VkPipelineStageFlags flags = 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandIssue) - ? VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT - : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandProcess) - ? VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT - : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kDispatch) - ? VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT - : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kTransfer) - ? VK_PIPELINE_STAGE_TRANSFER_BIT - : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandRetire) - ? VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT - : 0; - flags |= AnyBitSet(stage_mask & ExecutionStage::kHost) - ? VK_PIPELINE_STAGE_HOST_BIT - : 0; - return flags; -} - -VkAccessFlags ConvertAccessMask(AccessScopeBitfield access_mask) { - VkAccessFlags flags = 0; - flags |= AnyBitSet(access_mask & AccessScope::kIndirectCommandRead) - ? VK_ACCESS_INDIRECT_COMMAND_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kConstantRead) - ? VK_ACCESS_UNIFORM_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kDispatchRead) - ? VK_ACCESS_SHADER_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kDispatchWrite) - ? VK_ACCESS_SHADER_WRITE_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kTransferRead) - ? VK_ACCESS_TRANSFER_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kTransferWrite) - ? VK_ACCESS_TRANSFER_WRITE_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kHostRead) - ? VK_ACCESS_HOST_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kHostWrite) - ? VK_ACCESS_HOST_WRITE_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kMemoryRead) - ? VK_ACCESS_MEMORY_READ_BIT - : 0; - flags |= AnyBitSet(access_mask & AccessScope::kMemoryWrite) - ? VK_ACCESS_MEMORY_WRITE_BIT - : 0; - return flags; -} - -// Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value. -uint32_t SplatPattern(const void* pattern, size_t pattern_length) { - switch (pattern_length) { - case 1: { - uint32_t pattern_value = *static_cast<const uint8_t*>(pattern); - return (pattern_value << 24) | (pattern_value << 16) | - (pattern_value << 8) | pattern_value; - } - case 2: { - uint32_t pattern_value = *static_cast<const uint16_t*>(pattern); - return (pattern_value << 16) | pattern_value; - } - case 4: { - uint32_t pattern_value = *static_cast<const uint32_t*>(pattern); - return pattern_value; - } - default: - return 0; // Already verified that this should not be possible. - } -} - -} // namespace - -DirectCommandBuffer::DirectCommandBuffer( - Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories, - ref_ptr<DescriptorPoolCache> descriptor_pool_cache, - ref_ptr<VkCommandPoolHandle> command_pool, VkCommandBuffer command_buffer) - : CommandBuffer(allocator, mode, command_categories), - command_pool_(std::move(command_pool)), - command_buffer_(command_buffer), - descriptor_set_arena_(std::move(descriptor_pool_cache)) {} - -DirectCommandBuffer::~DirectCommandBuffer() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor"); - descriptor_set_group_.Reset().IgnoreError(); - absl::MutexLock lock(command_pool_->mutex()); - syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_, - 1, &command_buffer_); -} - -StatusOr<NativeEvent*> DirectCommandBuffer::CastEvent(Event* event) const { - // TODO(benvanik): assert the event is valid. - return static_cast<NativeEvent*>(event); -} - -StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer(Buffer* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast<VmaBuffer*>(buffer->allocated_buffer()); -} - -Status DirectCommandBuffer::Begin() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin"); - - is_recording_ = true; - - // NOTE: we require that command buffers not be recorded while they are - // in-flight so this is safe. - RETURN_IF_ERROR(descriptor_set_group_.Reset()); - - VkCommandBufferBeginInfo begin_info; - begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - begin_info.pNext = nullptr; - begin_info.flags = AllBitsSet(mode(), CommandBufferMode::kOneShot) - ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT - : 0; - begin_info.pInheritanceInfo = nullptr; - VK_RETURN_IF_ERROR( - syms()->vkBeginCommandBuffer(command_buffer_, &begin_info)); - - return OkStatus(); -} - -Status DirectCommandBuffer::End() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::End"); - - VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_)); - - // Flush all pending descriptor set writes (if any). - ASSIGN_OR_RETURN(descriptor_set_group_, descriptor_set_arena_.Flush()); - - is_recording_ = false; - - return OkStatus(); -} - -Status DirectCommandBuffer::ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier"); - - absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( - memory_barriers.size()); - for (int i = 0; i < memory_barriers.size(); ++i) { - const auto& memory_barrier = memory_barriers[i]; - auto& info = memory_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); - } - - absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( - buffer_barriers.size()); - for (int i = 0; i < buffer_barriers.size(); ++i) { - const auto& buffer_barrier = buffer_barriers[i]; - auto& info = buffer_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); - info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer)); - info.buffer = device_buffer->handle(); - info.offset = buffer_barrier.offset; - info.size = buffer_barrier.length; - } - - syms()->vkCmdPipelineBarrier( - command_buffer_, ConvertPipelineStageFlags(source_stage_mask), - ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0, - memory_barrier_infos.size(), memory_barrier_infos.data(), - buffer_barrier_infos.size(), buffer_barrier_infos.data(), 0, nullptr); - - return OkStatus(); -} - -Status DirectCommandBuffer::SignalEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent"); - ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); - syms()->vkCmdSetEvent(command_buffer_, device_event->handle(), - ConvertPipelineStageFlags(source_stage_mask)); - return OkStatus(); -} - -Status DirectCommandBuffer::ResetEvent( - Event* event, ExecutionStageBitfield source_stage_mask) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent"); - ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); - syms()->vkCmdResetEvent(command_buffer_, device_event->handle(), - ConvertPipelineStageFlags(source_stage_mask)); - return OkStatus(); -} - -Status DirectCommandBuffer::WaitEvents( - absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents"); - - absl::InlinedVector<VkEvent, 4> event_handles(events.size()); - for (int i = 0; i < events.size(); ++i) { - ASSIGN_OR_RETURN(auto* device_event, CastEvent(events[i])); - event_handles[i] = device_event->handle(); - } - - absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( - memory_barriers.size()); - for (int i = 0; i < memory_barriers.size(); ++i) { - const auto& memory_barrier = memory_barriers[i]; - auto& info = memory_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); - } - - absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( - buffer_barriers.size()); - for (int i = 0; i < buffer_barriers.size(); ++i) { - const auto& buffer_barrier = buffer_barriers[i]; - auto& info = buffer_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); - info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer)); - info.buffer = device_buffer->handle(); - info.offset = buffer_barrier.offset; - info.size = buffer_barrier.length; - } - - syms()->vkCmdWaitEvents( - command_buffer_, event_handles.size(), event_handles.data(), - ConvertPipelineStageFlags(source_stage_mask), - ConvertPipelineStageFlags(target_stage_mask), memory_barrier_infos.size(), - memory_barrier_infos.data(), buffer_barrier_infos.size(), - buffer_barrier_infos.data(), 0, nullptr); - return OkStatus(); -} - -Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer, - device_size_t target_offset, - device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer"); - ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); - - // Note that fill only accepts 4-byte aligned values so we need to splat out - // our variable-length pattern. - target_offset += target_buffer->byte_offset(); - uint32_t dword_pattern = SplatPattern(pattern, pattern_length); - syms()->vkCmdFillBuffer(command_buffer_, target_device_buffer->handle(), - target_offset, length, dword_pattern); - - return OkStatus(); -} - -Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer"); - // NOTE: we could use this to prevent queue family transitions. - return OkStatus(); -} - -Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer"); - ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); - - // Vulkan only allows updates of <= 65536 because you really, really, really - // shouldn't do large updates like this (as it wastes command buffer space and - // may be slower than just using write-through mapped memory). The - // recommendation in the spec for larger updates is to split the single update - // into multiple updates over the entire desired range. - const auto* source_buffer_ptr = static_cast<const uint8_t*>(source_buffer); - target_offset += target_buffer->byte_offset(); - while (length > 0) { - device_size_t chunk_length = - std::min(static_cast<device_size_t>(65536u), length); - syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(), - target_offset, chunk_length, source_buffer_ptr); - source_buffer_ptr += chunk_length; - target_offset += chunk_length; - length -= chunk_length; - } - - return OkStatus(); -} - -Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer, - device_size_t source_offset, - Buffer* target_buffer, - device_size_t target_offset, - device_size_t length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer"); - ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer)); - ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); - - VkBufferCopy region; - region.srcOffset = source_buffer->byte_offset() + source_offset; - region.dstOffset = target_buffer->byte_offset() + target_offset; - region.size = length; - syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(), - target_device_buffer->handle(), 1, ®ion); - - return OkStatus(); -} - -Status DirectCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch"); - - // Get the compiled and linked pipeline for the specified entry point and - // bind it to the command buffer. - auto* executable = - static_cast<PipelineExecutable*>(dispatch_request.executable); - ASSIGN_OR_RETURN(VkPipeline pipeline, executable->GetPipelineForEntryPoint( - dispatch_request.entry_point)); - syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline); - - // Either allocate, update, and bind a descriptor set or use push descriptor - // sets to use the command buffer pool when supported. - RETURN_IF_ERROR(descriptor_set_arena_.BindDescriptorSet( - command_buffer_, executable, dispatch_request.bindings)); - - // TODO(benvanik): divide workload by caps and issue multiple dispatches. - // TODO(benvanik): track local workgroup/subgroup size and divide into groups. - if (dispatch_request.workload_buffer) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Dynamic dispatches not yet implemented"; - } - uint32_t group_count_x = dispatch_request.workload[0]; - uint32_t group_count_y = dispatch_request.workload[1]; - uint32_t group_count_z = dispatch_request.workload[2]; - - // TODO(GH-67): pre-divide workload by tile size. - if (executable->is_matmul()) { - group_count_x = (group_count_x + 16 - 1) / 16; - group_count_y = (group_count_y + 16 - 1) / 16; - group_count_z = 1; - } - - syms()->vkCmdDispatch(command_buffer_, group_count_x, group_count_y, - group_count_z); - - return OkStatus(); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h deleted file mode 100644 index 989ea77..0000000 --- a/iree/hal/vulkan/direct_command_buffer.h +++ /dev/null
@@ -1,103 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ -#define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ - -#include <vulkan/vulkan.h> - -#include "iree/hal/command_buffer.h" -#include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/descriptor_set_arena.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/native_event.h" -#include "iree/hal/vulkan/pipeline_executable.h" -#include "iree/hal/vulkan/vma_buffer.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// Command buffer implementation that directly maps to VkCommandBuffer. -// This records the commands on the calling thread without additional threading -// indirection. -class DirectCommandBuffer final : public CommandBuffer { - public: - DirectCommandBuffer(Allocator* allocator, CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories, - ref_ptr<DescriptorPoolCache> descriptor_pool_cache, - ref_ptr<VkCommandPoolHandle> command_pool, - VkCommandBuffer command_buffer); - ~DirectCommandBuffer() override; - - VkCommandBuffer handle() const { return command_buffer_; } - - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - Status SignalEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status ResetEvent(Event* event, - ExecutionStageBitfield source_stage_mask) override; - Status WaitEvents(absl::Span<Event*> events, - ExecutionStageBitfield source_stage_mask, - ExecutionStageBitfield target_stage_mask, - absl::Span<const MemoryBarrier> memory_barriers, - absl::Span<const BufferBarrier> buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, device_size_t target_offset, - device_size_t length, const void* pattern, - size_t pattern_length) override; - Status DiscardBuffer(Buffer* buffer) override; - Status UpdateBuffer(const void* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset, - Buffer* target_buffer, device_size_t target_offset, - device_size_t length) override; - - Status Dispatch(const DispatchRequest& dispatch_request) override; - - private: - const ref_ptr<DynamicSymbols>& syms() const { return command_pool_->syms(); } - - StatusOr<NativeEvent*> CastEvent(Event* event) const; - StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) const; - - bool is_recording_ = false; - ref_ptr<VkCommandPoolHandle> command_pool_; - VkCommandBuffer command_buffer_; - - // TODO(b/140026716): may grow large - should try to reclaim or reuse. - DescriptorSetArena descriptor_set_arena_; - - // The current descriptor set group in use by the command buffer, if any. - // This must remain valid until all in-flight submissions of the command - // buffer complete. - DescriptorSetGroup descriptor_set_group_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
diff --git a/iree/hal/vulkan/direct_command_queue.cc b/iree/hal/vulkan/direct_command_queue.cc deleted file mode 100644 index fb5cd59..0000000 --- a/iree/hal/vulkan/direct_command_queue.cc +++ /dev/null
@@ -1,201 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/direct_command_queue.h" - -#include <cstdint> - -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "iree/base/memory.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/direct_command_buffer.h" -#include "iree/hal/vulkan/legacy_fence.h" -#include "iree/hal/vulkan/native_binary_semaphore.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -DirectCommandQueue::DirectCommandQueue( - std::string name, CommandCategoryBitfield supported_categories, - const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue) - : CommandQueue(std::move(name), supported_categories), - logical_device_(add_ref(logical_device)), - queue_(queue) {} - -DirectCommandQueue::~DirectCommandQueue() { - IREE_TRACE_SCOPE0("DirectCommandQueue::dtor"); - absl::MutexLock lock(&queue_mutex_); - syms()->vkQueueWaitIdle(queue_); -} - -Status DirectCommandQueue::TranslateBatchInfo(const SubmissionBatch& batch, - VkSubmitInfo* submit_info, - Arena* arena) { - // TODO(benvanik): see if we can go to finer-grained stages. - // For example, if this was just queue ownership transfers then we can use - // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT. - VkPipelineStageFlags dst_stage_mask = - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; - - auto wait_semaphore_handles = - arena->AllocateSpan<VkSemaphore>(batch.wait_semaphores.size()); - auto wait_dst_stage_masks = - arena->AllocateSpan<VkPipelineStageFlags>(batch.wait_semaphores.size()); - for (int i = 0; i < batch.wait_semaphores.size(); ++i) { - const auto& semaphore_value = batch.wait_semaphores[i]; - if (semaphore_value.index() == 0) { - const auto& binary_semaphore = - static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value)); - wait_semaphore_handles[i] = binary_semaphore->handle(); - } else { - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) - << "Timeline semaphores not yet implemented"; - } - wait_dst_stage_masks[i] = dst_stage_mask; - } - - auto signal_semaphore_handles = - arena->AllocateSpan<VkSemaphore>(batch.signal_semaphores.size()); - for (int i = 0; i < batch.signal_semaphores.size(); ++i) { - const auto& semaphore_value = batch.signal_semaphores[i]; - if (semaphore_value.index() == 0) { - const auto& binary_semaphore = - static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value)); - signal_semaphore_handles[i] = binary_semaphore->handle(); - } else { - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) - << "Timeline semaphores not yet implemented"; - } - } - - auto command_buffer_handles = - arena->AllocateSpan<VkCommandBuffer>(batch.command_buffers.size()); - for (int i = 0; i < batch.command_buffers.size(); ++i) { - const auto& command_buffer = batch.command_buffers[i]; - auto* direct_command_buffer = - static_cast<DirectCommandBuffer*>(command_buffer->impl()); - command_buffer_handles[i] = direct_command_buffer->handle(); - } - - submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; - submit_info->pNext = nullptr; - submit_info->waitSemaphoreCount = wait_semaphore_handles.size(); - submit_info->pWaitSemaphores = wait_semaphore_handles.data(); - submit_info->pWaitDstStageMask = wait_dst_stage_masks.data(); - submit_info->commandBufferCount = command_buffer_handles.size(); - submit_info->pCommandBuffers = command_buffer_handles.data(); - submit_info->signalSemaphoreCount = signal_semaphore_handles.size(); - submit_info->pSignalSemaphores = signal_semaphore_handles.data(); - - return OkStatus(); -} - -Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches, - FenceValue fence) { - IREE_TRACE_SCOPE0("DirectCommandQueue::Submit"); - - // Map the submission batches to VkSubmitInfos. - // Note that we must keep all arrays referenced alive until submission - // completes and since there are a bunch of them we use an arena. - Arena arena(4 * 1024); - auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batches.size()); - for (int i = 0; i < batches.size(); ++i) { - RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i], &arena)); - } - - // TODO(b/140141417): implement timeline semaphore fences and switch here. - auto legacy_fence = reinterpret_cast<LegacyFence*>(fence.first); - ASSIGN_OR_RETURN(VkFence fence_handle, - legacy_fence->AcquireSignalFence(fence.second)); - - { - absl::MutexLock lock(&queue_mutex_); - VK_RETURN_IF_ERROR(syms()->vkQueueSubmit( - queue_, submit_infos.size(), submit_infos.data(), fence_handle)); - } - - return OkStatus(); -} - -Status DirectCommandQueue::WaitIdle(absl::Time deadline) { - if (deadline == absl::InfiniteFuture()) { - // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it - // requires fewer calls into the driver). - IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle"); - absl::MutexLock lock(&queue_mutex_); - VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); - return OkStatus(); - } - - IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence"); - - // Create a new fence just for this wait. This keeps us thread-safe as the - // behavior of wait+reset is racey. - VkFenceCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - VkFence fence = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateFence( - *logical_device_, &create_info, logical_device_->allocator(), &fence)); - auto fence_cleanup = MakeCleanup([this, fence]() { - syms()->vkDestroyFence(*logical_device_, fence, - logical_device_->allocator()); - }); - - uint64_t timeout; - if (deadline == absl::InfinitePast()) { - // Do not wait. - timeout = 0; - } else if (deadline == absl::InfiniteFuture()) { - // Wait forever. - timeout = UINT64_MAX; - } else { - // Convert to relative time in nanoseconds. - // The implementation may not wait with this granularity (like, by 10000x). - absl::Time now = absl::Now(); - if (deadline < now) { - return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline in the past"; - } - timeout = static_cast<uint64_t>(absl::ToInt64Nanoseconds(deadline - now)); - } - - { - absl::MutexLock lock(&queue_mutex_); - VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence)); - } - - VkResult result = - syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout); - switch (result) { - case VK_SUCCESS: - return OkStatus(); - case VK_TIMEOUT: - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for idle"; - default: - return VkResultToStatus(result); - } -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/direct_command_queue.h b/iree/hal/vulkan/direct_command_queue.h deleted file mode 100644 index 99fd0dd..0000000 --- a/iree/hal/vulkan/direct_command_queue.h +++ /dev/null
@@ -1,69 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_ -#define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_ - -#include <vulkan/vulkan.h> - -#include <cstdint> -#include <string> - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "iree/base/arena.h" -#include "iree/base/status.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// Command queue implementation directly maps to VkQueue. -class DirectCommandQueue final : public CommandQueue { - public: - DirectCommandQueue(std::string name, - CommandCategoryBitfield supported_categories, - const ref_ptr<VkDeviceHandle>& logical_device, - VkQueue queue); - ~DirectCommandQueue() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - Status Submit(absl::Span<const SubmissionBatch> batches, - FenceValue fence) override; - - Status WaitIdle(absl::Time deadline) override; - - private: - Status TranslateBatchInfo(const SubmissionBatch& batch, - VkSubmitInfo* submit_info, Arena* arena); - - ref_ptr<VkDeviceHandle> logical_device_; - - // VkQueue needs to be externally synchronized. - mutable absl::Mutex queue_mutex_; - VkQueue queue_ ABSL_GUARDED_BY(queue_mutex_); -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/dynamic_symbols.cc b/iree/hal/vulkan/dynamic_symbols.cc deleted file mode 100644 index 00380ec..0000000 --- a/iree/hal/vulkan/dynamic_symbols.cc +++ /dev/null
@@ -1,238 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/dynamic_symbols.h" - -#include <cstddef> -#include <cstdlib> - -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/memory/memory.h" -#include "iree/base/file_path.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/dynamic_symbol_tables.h" - -#if defined(IREE_PLATFORM_WINDOWS) -#include <windows.h> -#else -#include <dlfcn.h> -#endif // IREE_PLATFORM_WINDOWS - -namespace iree { -namespace hal { -namespace vulkan { - -// Read-only table of function pointer information designed to be in .rdata. -// To reduce binary size this structure is packed (knowing that we won't have -// gigabytes of function pointers :). -struct FunctionPtrInfo { - // Name of the function (like 'vkSomeFunction'). - const char* function_name; - // 1 if the function pointer can be resolved via vkGetDeviceProcAddr. - uint32_t is_device : 1; - // 1 if the function is required and the loader should bail if not found. - uint32_t is_required : 1; - // TODO(benvanik): remove from table by manually walking sizeof(uintptr_t). - // An offset in bytes from the base of &syms to where the PFN_vkSomeFunction - // member is located. - uint32_t member_offset : 30; -} ABSL_ATTRIBUTE_PACKED; - -namespace { - -#define REQUIRED_PFN_FUNCTION_PTR(function_name, is_device) \ - {#function_name, is_device, 1, offsetof(DynamicSymbols, function_name)}, -#define OPTIONAL_PFN_FUNCTION_PTR(function_name, is_device) \ - {#function_name, is_device, 0, offsetof(DynamicSymbols, function_name)}, -#define EXCLUDED_PFN_FUNCTION_PTR(function_name, is_device) -#define INS_PFN_FUNCTION_PTR(requirement, function_name) \ - requirement##_PFN_FUNCTION_PTR(function_name, 0) -#define DEV_PFN_FUNCTION_PTR(requirement, function_name) \ - requirement##_PFN_FUNCTION_PTR(function_name, 1) - -// Defines the table of mandatory FunctionPtrInfos resolved prior to instance -// creation. These are safe to call with no instance parameter and should be -// exported by all loaders/ICDs. -static constexpr const FunctionPtrInfo kInstancelessFunctionPtrInfos[] = { - REQUIRED_PFN_FUNCTION_PTR(vkCreateInstance, false) // - REQUIRED_PFN_FUNCTION_PTR(vkEnumerateInstanceLayerProperties, false) // - REQUIRED_PFN_FUNCTION_PTR(vkEnumerateInstanceExtensionProperties, false) // -}; - -// Defines the table of FunctionPtrInfos for dynamic loading that must wait -// until an instance has been created to be resolved. -static constexpr const FunctionPtrInfo kDynamicFunctionPtrInfos[] = { - IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(INS_PFN_FUNCTION_PTR, - DEV_PFN_FUNCTION_PTR)}; - -} // namespace - -// static -StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::Create( - const GetProcAddrFn& get_proc_addr) { - IREE_TRACE_SCOPE0("DynamicSymbols::Create"); - - auto syms = make_ref<DynamicSymbols>(); - - // Resolve the method the shared object uses to resolve other functions. - // Some libraries will export all symbols while others will only export this - // single function. - syms->vkGetInstanceProcAddr = reinterpret_cast<PFN_vkGetInstanceProcAddr>( - get_proc_addr("vkGetInstanceProcAddr")); - if (!syms->vkGetInstanceProcAddr) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required method vkGetInstanceProcAddr not " - "found in provided Vulkan library (did you pick the wrong file?)"; - } - - // Resolve the mandatory functions that we need to create instances. - // If the provided |get_proc_addr| cannot resolve these then it's not a loader - // or ICD we want to use, anyway. - for (int i = 0; i < ABSL_ARRAYSIZE(kInstancelessFunctionPtrInfos); ++i) { - const auto& function_ptr = kInstancelessFunctionPtrInfos[i]; - auto* member_ptr = reinterpret_cast<PFN_vkVoidFunction*>( - reinterpret_cast<uint8_t*>(syms.get()) + function_ptr.member_offset); - *member_ptr = - syms->vkGetInstanceProcAddr(VK_NULL_HANDLE, function_ptr.function_name); - if (*member_ptr == nullptr) { - return UnavailableErrorBuilder(IREE_LOC) - << "Mandatory Vulkan function " << function_ptr.function_name - << " not available; invalid loader/ICD?"; - } - } - - return syms; -} - -// static -StatusOr<ref_ptr<DynamicSymbols>> DynamicSymbols::CreateFromSystemLoader() { - IREE_TRACE_SCOPE0("DynamicSymbols::CreateFromSystemLoader"); - -#if defined(IREE_VK_ICD_FILENAMES) -#define IREE_STRINGIFY_(x) #x -#define IREE_STRING_(x) IREE_STRINGIFY_(x) - std::string vk_icd_filenames = IREE_STRING_(IREE_VK_ICD_FILENAMES); -#undef IREE_STRINGIFY_ -#undef IREE_STRING_ -#if defined(IREE_PLATFORM_WINDOWS) - // TODO(b/138220713): Set VK_ICD_FILENAMES on Windows -#else - ::setenv("VK_ICD_FILENAMES", vk_icd_filenames.c_str(), 0); -#endif // IREE_PLATFORM_WINDOWS -#else - // Leave VK_ICD_FILENAMES unchanged and rely on the system Vulkan loader to - // discover ICDs. -#endif // IREE_VK_ICD_FILENAMES - -// NOTE: we could factor this out into base, but this is the only place we use -// it right now so it's fine. -#if defined(IREE_PLATFORM_WINDOWS) - HMODULE library = ::LoadLibraryA("vulkan-1.dll"); - if (!library) { - return UnavailableErrorBuilder(IREE_LOC) - << "Unable to open vulkan-1.dll; driver not installed/on PATH"; - } - ASSIGN_OR_RETURN(auto syms, Create([library](const char* function_name) { - return reinterpret_cast<PFN_vkVoidFunction>( - ::GetProcAddress(library, function_name)); - })); - syms->close_fn_ = [library]() { - // TODO(benvanik): disable if we want to get profiling results. Sometimes - // closing the library can prevent proper symbolization on crashes or - // in sampling profilers. - ::FreeLibrary(library); - }; - return syms; -#else - void* library = ::dlopen("libvulkan.so.1", RTLD_LAZY | RTLD_LOCAL); - if (!library) { - return UnavailableErrorBuilder(IREE_LOC) - << "Unable to open libvulkan.so; driver not installed/on " - "LD_LIBRARY_PATH"; - } - ASSIGN_OR_RETURN(auto syms, Create([library](const char* function_name) { - return reinterpret_cast<PFN_vkVoidFunction>( - ::dlsym(library, function_name)); - })); - syms->close_fn_ = [library]() { - // TODO(benvanik): disable if we want to get profiling results. Sometimes - // closing the library can prevent proper symbolization on crashes or - // in sampling profilers. - ::dlclose(library); - }; - return syms; -#endif // IREE_PLATFORM_WINDOWS -} - -Status DynamicSymbols::LoadFromInstance(VkInstance instance) { - IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromInstance"); - return LoadFromDevice(instance, VK_NULL_HANDLE); -} - -Status DynamicSymbols::LoadFromDevice(VkInstance instance, VkDevice device) { - IREE_TRACE_SCOPE0("DynamicSymbols::LoadFromDevice"); - - if (!instance) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Instance must have been created and a default instance proc " - "lookup function is required"; - } - - // Setup the lookup methods first. The rest of the syms uses these to - // resolve function pointers. - this->vkGetDeviceProcAddr = reinterpret_cast<PFN_vkGetDeviceProcAddr>( - this->vkGetInstanceProcAddr(instance, "vkGetDeviceProcAddr")); - if (!this->vkGetDeviceProcAddr) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required Vulkan function vkGetDeviceProcAddr not available; " - "invalid driver handle?"; - } - - // Load the rest of the functions. - for (int i = 0; i < ABSL_ARRAYSIZE(kDynamicFunctionPtrInfos); ++i) { - const auto& function_ptr = kDynamicFunctionPtrInfos[i]; - auto* member_ptr = reinterpret_cast<PFN_vkVoidFunction*>( - reinterpret_cast<uint8_t*>(this) + function_ptr.member_offset); - if (function_ptr.is_device && device) { - *member_ptr = - this->vkGetDeviceProcAddr(device, function_ptr.function_name); - } else { - *member_ptr = - this->vkGetInstanceProcAddr(instance, function_ptr.function_name); - } - if (*member_ptr == nullptr && function_ptr.is_required) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required Vulkan function " << function_ptr.function_name - << " not available"; - } - } - - return OkStatus(); -} - -DynamicSymbols::DynamicSymbols() = default; - -DynamicSymbols::~DynamicSymbols() { - if (close_fn_) { - close_fn_(); - } -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/dynamic_symbols.h b/iree/hal/vulkan/dynamic_symbols.h deleted file mode 100644 index 429c7f6..0000000 --- a/iree/hal/vulkan/dynamic_symbols.h +++ /dev/null
@@ -1,129 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_ -#define IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_ - -#include <vulkan/vulkan.h> - -#include <cstdint> -#include <functional> -#include <memory> - -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/vulkan/dynamic_symbol_tables.h" - -namespace iree { -namespace hal { -namespace vulkan { - -struct FunctionPtrInfo; - -// Dynamic Vulkan function loader for use with vulkan.hpp. -// This loader is a subset of the DispatchLoaderDynamic implementation that only -// loads functions we are interested in (a compute-specific subset) and avoids -// extensions we will never use. -// -// This exposes all Vulkan methods as function pointer members. Optional -// methods will be nullptr if not present. Excluded methods will be omitted. -// -// DynamicSymbols instances are designed to be passed to vulkan.hpp methods as -// the last argument, though they may also be called directly. -// **Always make sure to pass the loader to vulkan.hpp methods!** -// -// Loading is performed by walking a table of required and optional functions -// (defined in dynamic_symbol_tables.h) and populating the member function -// pointers exposed on this struct when available. For example, if the -// vkSomeFunction method is marked in the table as OPTIONAL the loader will -// attempt to lookup the function and if successful set the -// DynamicSymbols::vkSomeFunction pointer to the resolved address. If the -// function is not found then it will be set to nullptr so users can check for -// function availability. -// -// Documentation: -// https://github.com/KhronosGroup/Vulkan-Hpp#extensions--per-device-function-pointers -// -// Usage: -// ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader()); -// VkInstance instance = VK_NULL_HANDLE; -// syms->vkCreateInstance(..., &instance); -// RETURN_IF_ERROR(syms->LoadFromInstance(instance)); -struct DynamicSymbols : public RefObject<DynamicSymbols> { - using GetProcAddrFn = - std::function<PFN_vkVoidFunction(const char* function_name)>; - - DynamicSymbols(); - ~DynamicSymbols(); - - // Creates the dynamic symbol table using the given |get_proc_addr| to resolve - // the vkCreateInstance function. - // - // After the instance is created the caller must use LoadFromInstance (or - // LoadFromDevice) to load the remaining symbols. - static StatusOr<ref_ptr<DynamicSymbols>> Create( - const GetProcAddrFn& get_proc_addr); - - // Loads all required and optional Vulkan functions from the Vulkan loader. - // This will look for a Vulkan loader on the system (like libvulkan.so) and - // dlsym the functions from that. - // - // The loaded function pointers will point to thunks in the ICD. This may - // enable additional debug checking and more readable stack traces (as - // errors come from within the ICD, where we have symbols). - static StatusOr<ref_ptr<DynamicSymbols>> CreateFromSystemLoader(); - - // Loads all required and optional Vulkan functions from the given instance. - // - // The loaded function pointers will point to thunks in the ICD. This may - // enable additional debug checking and more readable stack traces (as - // errors come from within the ICD, where we have symbols). - Status LoadFromInstance(VkInstance instance); - - // Loads all required and optional Vulkan functions from the given device, - // falling back to the instance when required. - // - // This attempts to directly query the methods from the device, bypassing any - // ICD or shim layers. These methods will generally have less overhead at - // runtime as they need not jump through the various trampolines. - Status LoadFromDevice(VkInstance instance, VkDevice device); - - // Define members for each function pointer. - // See dynamic_symbol_tables.h for the full list of methods. - // - // Each required and optional function in the loader tables will expand to - // the following member, such as for example 'vkSomeFunction': - // PFN_vkSomeFunction vkSomeFunction; -#define REQUIRED_PFN(function_name) PFN_##function_name function_name -#define OPTIONAL_PFN(function_name) PFN_##function_name function_name -#define EXCLUDED_PFN(function_name) -#define PFN_MEMBER(requirement, function_name) requirement##_PFN(function_name); - REQUIRED_PFN(vkGetInstanceProcAddr); - REQUIRED_PFN(vkGetDeviceProcAddr); - IREE_VULKAN_DYNAMIC_SYMBOL_TABLES(PFN_MEMBER, PFN_MEMBER); -#undef REQUIRED_PFN -#undef OPTIONAL_PFN -#undef EXCLUDED_PFN -#undef PFN_MEMBER - - private: - // Optional callback on loader destruction. - std::function<void()> close_fn_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_DYNAMIC_SYMBOLS_H_
diff --git a/iree/hal/vulkan/dynamic_symbols_test.cc b/iree/hal/vulkan/dynamic_symbols_test.cc deleted file mode 100644 index 2576a44..0000000 --- a/iree/hal/vulkan/dynamic_symbols_test.cc +++ /dev/null
@@ -1,73 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/dynamic_symbols.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { -namespace { - -VkApplicationInfo GetApplicationInfo() { - VkApplicationInfo app_info; - app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - app_info.pNext = nullptr; - app_info.pApplicationName = "IREE-ML-TEST"; - app_info.applicationVersion = 0; - app_info.pEngineName = "IREE"; - app_info.engineVersion = 0; - app_info.apiVersion = VK_API_VERSION_1_0; - return app_info; -} - -VkInstanceCreateInfo GetInstanceCreateInfo(VkApplicationInfo* app_info) { - VkInstanceCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.pApplicationInfo = app_info; - create_info.enabledLayerCount = 0; - create_info.ppEnabledLayerNames = nullptr; - create_info.enabledExtensionCount = 0; - create_info.ppEnabledExtensionNames = nullptr; - return create_info; -} - -TEST(DynamicSymbolsTest, CreateFromSystemLoader) { - auto status_or_syms = DynamicSymbols::CreateFromSystemLoader(); - ASSERT_OK(status_or_syms); - ref_ptr<DynamicSymbols> syms = std::move(status_or_syms.ValueOrDie()); - - // Create and destroy a VkInstance using the symbols. This is mainly testing - // that the symbols were loaded successfully and are actually able to be used. - VkApplicationInfo app_info = GetApplicationInfo(); - VkInstanceCreateInfo create_info = GetInstanceCreateInfo(&app_info); - VkInstance instance = VK_NULL_HANDLE; - VK_CHECK_OK( - syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)); - - ASSERT_OK(syms->LoadFromInstance(instance)); - - syms->vkDestroyInstance(instance, /*pAllocator=*/nullptr); -} - -} // namespace -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/extensibility_util.cc b/iree/hal/vulkan/extensibility_util.cc deleted file mode 100644 index ca1d2a7..0000000 --- a/iree/hal/vulkan/extensibility_util.cc +++ /dev/null
@@ -1,221 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/extensibility_util.h" - -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -StatusOr<std::vector<const char*>> MatchAvailableLayers( - absl::Span<const char* const> required_layers, - absl::Span<const char* const> optional_layers, - absl::Span<const VkLayerProperties> properties) { - IREE_TRACE_SCOPE0("MatchAvailableLayers"); - - std::vector<const char*> enabled_layers; - enabled_layers.reserve(required_layers.size() + optional_layers.size()); - - for (const char* layer_name : required_layers) { - bool found = false; - for (const auto& layer_properties : properties) { - if (std::strcmp(layer_name, layer_properties.layerName) == 0) { - VLOG(1) << "Enabling required layer: " << layer_name; - found = true; - enabled_layers.push_back(layer_name); - break; - } - } - if (!found) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required layer " << layer_name << " not available"; - } - } - - for (const char* layer_name : optional_layers) { - bool found = false; - for (const auto& layer_properties : properties) { - if (std::strcmp(layer_name, layer_properties.layerName) == 0) { - VLOG(1) << "Enabling optional layer: " << layer_name; - found = true; - enabled_layers.push_back(layer_name); - break; - } - } - if (!found) { - VLOG(1) << "Optional layer " << layer_name << " not available"; - } - } - - return enabled_layers; -} - -StatusOr<std::vector<const char*>> MatchAvailableExtensions( - absl::Span<const char* const> required_extensions, - absl::Span<const char* const> optional_extensions, - absl::Span<const VkExtensionProperties> properties) { - IREE_TRACE_SCOPE0("MatchAvailableExtensions"); - - std::vector<const char*> enabled_extensions; - enabled_extensions.reserve(required_extensions.size() + - optional_extensions.size()); - - for (const char* extension_name : required_extensions) { - bool found = false; - for (const auto& extension_properties : properties) { - if (std::strcmp(extension_name, extension_properties.extensionName) == - 0) { - VLOG(1) << "Enabling required extension: " << extension_name; - found = true; - enabled_extensions.push_back(extension_name); - break; - } - } - if (!found) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required extension " << extension_name << " not available"; - } - } - - for (const char* extension_name : optional_extensions) { - bool found = false; - for (const auto& extension_properties : properties) { - if (std::strcmp(extension_name, extension_properties.extensionName) == - 0) { - VLOG(1) << "Enabling optional extension: " << extension_name; - found = true; - enabled_extensions.push_back(extension_name); - break; - } - } - if (!found) { - VLOG(1) << "Optional extension " << extension_name << " not available"; - } - } - - return enabled_extensions; -} - -} // namespace - -StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { - uint32_t layer_property_count = 0; - VK_RETURN_IF_ERROR( - syms.vkEnumerateInstanceLayerProperties(&layer_property_count, nullptr)); - std::vector<VkLayerProperties> layer_properties(layer_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceLayerProperties( - &layer_property_count, layer_properties.data())); - ASSIGN_OR_RETURN(auto enabled_layers, - MatchAvailableLayers(extensibility_spec.required_layers, - extensibility_spec.optional_layers, - layer_properties), - _ << "Unable to find all required instance layers"); - return enabled_layers; -} - -StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { - uint32_t extension_property_count = 0; - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( - nullptr, &extension_property_count, nullptr)); - std::vector<VkExtensionProperties> extension_properties( - extension_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( - nullptr, &extension_property_count, extension_properties.data())); - ASSIGN_OR_RETURN( - auto enabled_extensions, - MatchAvailableExtensions(extensibility_spec.required_extensions, - extensibility_spec.optional_extensions, - extension_properties), - _ << "Unable to find all required instance extensions"); - IREE_ENABLE_LEAK_CHECKS(); - return enabled_extensions; -} - -StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers( - VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { - uint32_t layer_property_count = 0; - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties( - physical_device, &layer_property_count, nullptr)); - std::vector<VkLayerProperties> layer_properties(layer_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceLayerProperties( - physical_device, &layer_property_count, layer_properties.data())); - ASSIGN_OR_RETURN(auto enabled_layers, - MatchAvailableLayers(extensibility_spec.required_layers, - extensibility_spec.optional_layers, - layer_properties), - _ << "Unable to find all required device layers"); - return enabled_layers; -} - -StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions( - VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { - uint32_t extension_property_count = 0; - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( - physical_device, nullptr, &extension_property_count, nullptr)); - std::vector<VkExtensionProperties> extension_properties( - extension_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( - physical_device, nullptr, &extension_property_count, - extension_properties.data())); - ASSIGN_OR_RETURN( - auto enabled_extensions, - MatchAvailableExtensions(extensibility_spec.required_extensions, - extensibility_spec.optional_extensions, - extension_properties), - _ << "Unable to find all required device extensions"); - return enabled_extensions; -} - -InstanceExtensions PopulateEnabledInstanceExtensions( - absl::Span<const char* const> extension_names) { - InstanceExtensions extensions = {0}; - for (const char* extension_name : extension_names) { - if (std::strcmp(extension_name, VK_EXT_DEBUG_REPORT_EXTENSION_NAME) == 0) { - extensions.debug_report = true; - } else if (std::strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == - 0) { - extensions.debug_utils = true; - } - } - return extensions; -} - -DeviceExtensions PopulateEnabledDeviceExtensions( - absl::Span<const char* const> extension_names) { - DeviceExtensions extensions = {0}; - for (const char* extension_name : extension_names) { - if (std::strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == - 0) { - extensions.push_descriptors = true; - } - } - return extensions; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/extensibility_util.h b/iree/hal/vulkan/extensibility_util.h deleted file mode 100644 index 8f2b784..0000000 --- a/iree/hal/vulkan/extensibility_util.h +++ /dev/null
@@ -1,100 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for working with layers and extensions. - -#ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ -#define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ - -#include <vulkan/vulkan.h> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/vulkan/dynamic_symbols.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// Describes required and optional extensibility points. -struct ExtensibilitySpec { - // A list of required and optional layers. - std::vector<const char*> required_layers; - std::vector<const char*> optional_layers; - - // A list of required and optional extensions. - // Prefer using the _EXTENSION_NAME macros to make tracking easier (such as - // 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME'). - std::vector<const char*> required_extensions; - std::vector<const char*> optional_extensions; -}; - -// Returns a list of layer names available for instances. -// Fails if any required_layers are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of extension names available for instances. -// Fails if any required_extensions are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of layer names available for the given |physical_device|. -// Fails if any required_layers are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableDeviceLayers( - VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of extension names available for the given |physical_device|. -// Fails if any required_extensions are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions( - VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Bits for enabled instance extensions. -// We must use this to query support instead of just detecting symbol names as -// ICDs will resolve the functions sometimes even if they don't support the -// extension (or we didn't ask for it to be enabled). -struct InstanceExtensions { - // VK_EXT_debug_report is enabled and a callback is regsitered. - // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_report - bool debug_report : 1; - - // VK_EXT_debug_utils is enabled and a debug messenger is registered. - // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_utils - bool debug_utils : 1; -}; - -// Returns a bitfield with all of the provided extension names. -InstanceExtensions PopulateEnabledInstanceExtensions( - absl::Span<const char* const> extension_names); - -// Bits for enabled device extensions. -// We must use this to query support instead of just detecting symbol names as -// ICDs will resolve the functions sometimes even if they don't support the -// extension (or we didn't ask for it to be enabled). -struct DeviceExtensions { - // VK_KHR_push_descriptor is enabled and vkCmdPushDescriptorSetKHR is valid. - bool push_descriptors : 1; -}; - -// Returns a bitfield with all of the provided extension names. -DeviceExtensions PopulateEnabledDeviceExtensions( - absl::Span<const char* const> extension_names); - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
diff --git a/iree/hal/vulkan/handle_util.h b/iree/hal/vulkan/handle_util.h deleted file mode 100644 index 3173de4..0000000 --- a/iree/hal/vulkan/handle_util.h +++ /dev/null
@@ -1,136 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Helpers for wrapping Vulkan handles that don't require us to wrap every type. -// This keeps our compilation time reasonable (as the vulkancpp library is -// insane) while giving us nice safety around cleanup and ensuring we use -// dynamic symbols and consistent allocators. -// -// Do not add functionality beyond handle management to these types. Keep our -// Vulkan usage mostly functional and C-like to ensure minimal code size and -// readability. - -#ifndef IREE_HAL_VULKAN_HANDLE_UTIL_H_ -#define IREE_HAL_VULKAN_HANDLE_UTIL_H_ - -#include <vulkan/vulkan.h> - -#include "absl/synchronization/mutex.h" -#include "absl/utility/utility.h" -#include "iree/base/ref_ptr.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/extensibility_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VkDeviceHandle : public RefObject<VkDeviceHandle> { - public: - VkDeviceHandle(const ref_ptr<DynamicSymbols>& syms, - DeviceExtensions enabled_extensions, - const VkAllocationCallbacks* allocator = nullptr) - : syms_(add_ref(syms)), - enabled_extensions_(enabled_extensions), - allocator_(allocator) {} - ~VkDeviceHandle() { reset(); } - - VkDeviceHandle(const VkDeviceHandle&) = delete; - VkDeviceHandle& operator=(const VkDeviceHandle&) = delete; - VkDeviceHandle(VkDeviceHandle&& other) noexcept - : value_(absl::exchange(other.value_, - static_cast<VkDevice>(VK_NULL_HANDLE))), - syms_(std::move(other.syms_)), - enabled_extensions_(other.enabled_extensions_), - allocator_(other.allocator_) {} - - void reset() { - if (value_ == VK_NULL_HANDLE) return; - syms_->vkDestroyDevice(value_, allocator_); - value_ = VK_NULL_HANDLE; - } - - VkDevice value() const noexcept { return value_; } - VkDevice* mutable_value() noexcept { return &value_; } - operator VkDevice() const noexcept { return value_; } - - const ref_ptr<DynamicSymbols>& syms() const noexcept { return syms_; } - const VkAllocationCallbacks* allocator() const noexcept { return allocator_; } - - const DeviceExtensions& enabled_extensions() const { - return enabled_extensions_; - } - - private: - VkDevice value_ = VK_NULL_HANDLE; - ref_ptr<DynamicSymbols> syms_; - DeviceExtensions enabled_extensions_; - const VkAllocationCallbacks* allocator_ = nullptr; -}; - -class VkCommandPoolHandle : public RefObject<VkCommandPoolHandle> { - public: - explicit VkCommandPoolHandle(const ref_ptr<VkDeviceHandle>& logical_device) - : logical_device_(add_ref(logical_device)) {} - ~VkCommandPoolHandle() { reset(); } - - VkCommandPoolHandle(const VkCommandPoolHandle&) = delete; - VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete; - VkCommandPoolHandle(VkCommandPoolHandle&& other) noexcept - : logical_device_(std::move(other.logical_device_)), - value_(absl::exchange(other.value_, - static_cast<VkCommandPool>(VK_NULL_HANDLE))) {} - VkCommandPoolHandle& operator=(VkCommandPoolHandle&& other) { - std::swap(logical_device_, other.logical_device_); - std::swap(value_, other.value_); - return *this; - } - - void reset() { - if (value_ == VK_NULL_HANDLE) return; - syms()->vkDestroyCommandPool(*logical_device_, value_, allocator()); - value_ = VK_NULL_HANDLE; - } - - VkCommandPool value() const noexcept { return value_; } - VkCommandPool* mutable_value() noexcept { return &value_; } - operator VkCommandPool() const noexcept { return value_; } - - const ref_ptr<VkDeviceHandle>& logical_device() const noexcept { - return logical_device_; - } - const ref_ptr<DynamicSymbols>& syms() const noexcept { - return logical_device_->syms(); - } - const VkAllocationCallbacks* allocator() const noexcept { - return logical_device_->allocator(); - } - - absl::Mutex* mutex() const { return &mutex_; } - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkCommandPool value_ = VK_NULL_HANDLE; - - // Vulkan command pools are not thread safe and require external - // synchronization. Since we allow arbitrary threads to allocate and - // deallocate the HAL command buffers we need to externally synchronize. - mutable absl::Mutex mutex_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_HANDLE_UTIL_H_
diff --git a/iree/hal/vulkan/internal_vk_mem_alloc.cc b/iree/hal/vulkan/internal_vk_mem_alloc.cc deleted file mode 100644 index 0884a34..0000000 --- a/iree/hal/vulkan/internal_vk_mem_alloc.cc +++ /dev/null
@@ -1,68 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// This file configures VMA to use common Google/Abseil types in an effort to -// better integrate with applications compiled using other Google code. By using -// the same types that dependers are likely using we can often reduce binary -// size and ease debugging (such as by using absl::Mutex to get better tsan -// warnings). - -// Only compile if an external implementation has not been otherwise linked. -#if !defined(VULKAN_MEMORY_ALLOCATOR_EXTERNAL_IMPL) - -#include "absl/container/flat_hash_map.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/logging.h" - -// Use std::vector instead of the VMA version. -#define VMA_USE_STL_VECTOR 1 - -// TODO(benvanik): figure out why std::list cannot be used. -// #define VMA_USE_STL_LIST 1 - -// Use absl::flat_hash_map instead of std::unordered_map. -#define VmaPair std::pair -#define VMA_MAP_TYPE(KeyT, ValueT) \ - absl::flat_hash_map<KeyT, ValueT, std::hash<KeyT>, std::equal_to<KeyT>, \ - VmaStlAllocator<std::pair<KeyT, ValueT> > > - -// Use CHECK for assertions. -#define VMA_ASSERT CHECK -#define VMA_HEAVY_ASSERT DCHECK - -// Use LOG for logging. -#ifndef NDEBUG -#define VMA_DEBUG_LOG(...) ABSL_RAW_LOG(INFO, __VA_ARGS__) -#else -#define VMA_DEBUG_LOG(...) -#endif // !NDEBUG - -// Use absl::Mutex for VMA_MUTEX. -#define VMA_MUTEX absl::Mutex -class AbslVmaRWMutex { - public: - void LockRead() ABSL_SHARED_LOCK_FUNCTION() { mutex_.ReaderLock(); } - void UnlockRead() ABSL_UNLOCK_FUNCTION() { mutex_.ReaderUnlock(); } - void LockWrite() ABSL_EXCLUSIVE_LOCK_FUNCTION() { mutex_.WriterLock(); } - void UnlockWrite() ABSL_UNLOCK_FUNCTION() { mutex_.WriterUnlock(); } - - private: - absl::Mutex mutex_; -}; -#define VMA_RW_MUTEX AbslVmaRWMutex - -#define VMA_IMPLEMENTATION -#include "vk_mem_alloc.h" - -#endif
diff --git a/iree/hal/vulkan/legacy_fence.cc b/iree/hal/vulkan/legacy_fence.cc deleted file mode 100644 index 481e855..0000000 --- a/iree/hal/vulkan/legacy_fence.cc +++ /dev/null
@@ -1,396 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/legacy_fence.h" - -#include <cstdint> - -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -// Inserts the given |fence_signal| into |list| in ascending order. -void InsertOutstandingFenceSignal(OutstandingFenceSignal* fence_signal, - IntrusiveList<OutstandingFenceSignal>* list) { - for (auto existing_signal : *list) { - if (existing_signal->value > fence_signal->value) { - list->insert(existing_signal, fence_signal); - return; - } - } - list->push_back(fence_signal); -} - -} // namespace - -// static -StatusOr<ref_ptr<LegacyFencePool>> LegacyFencePool::Create( - ref_ptr<VkDeviceHandle> logical_device) { - IREE_TRACE_SCOPE0("LegacyFencePool::Create"); - ref_ptr<LegacyFencePool> fence_pool( - new LegacyFencePool(std::move(logical_device))); - RETURN_IF_ERROR(fence_pool->PreallocateFences()); - return fence_pool; -} - -LegacyFencePool::LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} - -LegacyFencePool::~LegacyFencePool() { - IREE_TRACE_SCOPE0("LegacyFencePool::dtor"); - - absl::MutexLock lock(&mutex_); - for (auto& fence_signal : storage_) { - syms()->vkDestroyFence(*logical_device_, fence_signal.fence, - logical_device_->allocator()); - } - unused_fences_.clear(); - unresolved_fences_.clear(); -} - -Status LegacyFencePool::PreallocateFences() { - IREE_TRACE_SCOPE0("LegacyFencePool::PreallocateFences"); - - VkFenceCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - - absl::MutexLock lock(&mutex_); - for (int i = 0; i < kMaxInFlightFenceCount; ++i) { - auto* fence_signal = &storage_[i]; - VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info, - logical_device_->allocator(), - &fence_signal->fence)); - unused_fences_.push_back(fence_signal); - } - - return OkStatus(); -} - -StatusOr<OutstandingFenceSignal*> LegacyFencePool::Acquire() { - IREE_TRACE_SCOPE0("LegacyFencePool::Acquire"); - - absl::MutexLock lock(&mutex_); - if (unused_fences_.empty()) { - return ResourceExhaustedErrorBuilder(IREE_LOC) - << "Fence pool out of unused fences"; - } - - auto* fence_signal = unused_fences_.front(); - unused_fences_.pop_front(); - return fence_signal; -} - -void LegacyFencePool::ReleaseResolved( - IntrusiveList<OutstandingFenceSignal>* fence_signals) { - IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseResolved"); - - // Get a list of fences we need to reset. Note that not all fences may have - // been signaled and we can avoid resetting them. - absl::InlinedVector<VkFence, 8> handles; - handles.reserve(fence_signals->size()); - for (auto* fence_signal : *fence_signals) { - if (fence_signal->is_pending) { - handles.push_back(fence_signal->fence); - } - } - if (!handles.empty()) { - syms()->vkResetFences(*logical_device_, handles.size(), handles.data()); - } - - absl::MutexLock lock(&mutex_); - unused_fences_.merge_from(fence_signals); -} - -void LegacyFencePool::ReleaseUnresolved( - IntrusiveList<OutstandingFenceSignal>* fence_signals) { - IREE_TRACE_SCOPE0("LegacyFencePool::ReleaseUnresolved"); - - absl::MutexLock lock(&mutex_); - while (!fence_signals->empty()) { - auto* fence_signal = fence_signals->front(); - fence_signals->pop_front(); - if (fence_signal->is_pending) { - // Fence was submitted and may still have a pending signal on it. We can't - // reuse it until it has resolved. - // TODO(benvanik): fix these fences by reallocating? We aren't leaking - // here (technically) but we will exhaust the pool pretty quickly. - unresolved_fences_.push_back(fence_signal); - } else { - // Fence was never actually submitted so we can reuse it no problem. - unused_fences_.push_back(fence_signal); - } - } -} - -// static -Status LegacyFence::WaitForFences(VkDeviceHandle* logical_device, - absl::Span<const FenceValue> fences, - bool wait_all, absl::Time deadline) { - IREE_TRACE_SCOPE0("LegacyFence::WaitForFences"); - - // NOTE: we could pool this state too (probably right on the LegacyFencePool) - // or be smarter about using stack-allocated storage. The best idea is to use - // real timeline semaphores, though, so not much effort has been spent on - // optimizing this. - absl::InlinedVector<VkFence, 4> handles; - handles.reserve(fences.size()); - - // Loop over the fences and wait for any/all to signal. In wait_all mode we - // perform the bookkeeping to remove fences that have already been signaled so - // that we only wait on ones we need to (and possibly avoid making the vk call - // entirely!). - while (true) { - // Grab handles and acquire fences for all fences not yet at the requested - // timeline value. - for (const auto& fence_value : fences) { - auto* fence = reinterpret_cast<LegacyFence*>(fence_value.first); - // NOTE: this will return the sticky fence error if the fence has failed. - ASSIGN_OR_RETURN(VkFence handle, - fence->AcquireWaitFence(fence_value.second)); - if (handle != VK_NULL_HANDLE) { - // Fence is unresolved and we need to really wait for it. - handles.push_back(handle); - } - } - if (handles.empty()) { - // All fences resolved. - return OkStatus(); - } - - uint64_t timeout_nanos; - if (deadline == absl::InfiniteFuture()) { - timeout_nanos = UINT64_MAX; - } else if (deadline == absl::InfinitePast()) { - timeout_nanos = 0; - } else { - auto relative_nanos = absl::ToInt64Nanoseconds(deadline - absl::Now()); - timeout_nanos = relative_nanos < 0 ? 0 : relative_nanos; - } - - // Wait on the fences we still need. - // Note that waking does not actually indicate all fences were hit! We need - // to do another pass above on the next iteration to make sure that we don't - // need to wait again on another fence. - VK_RETURN_IF_ERROR(logical_device->syms()->vkWaitForFences( - *logical_device, handles.size(), handles.data(), wait_all, - timeout_nanos)); - handles.clear(); - } - - return OkStatus(); -} - -LegacyFence::LegacyFence(ref_ptr<LegacyFencePool> fence_pool, - uint64_t initial_value) - : fence_pool_(std::move(fence_pool)), value_(initial_value) {} - -LegacyFence::~LegacyFence() { - IREE_TRACE_SCOPE0("LegacyFence::dtor"); - CHECK_OK(TryResolveOutstandingFences(UINT64_MAX)); - absl::MutexLock lock(&mutex_); - CHECK(outstanding_signals_.empty()) - << "Destroying a fence without first waiting on outstanding signals"; -} - -Status LegacyFence::status() const { - if (value_.load() != UINT64_MAX) { - return OkStatus(); - } - absl::MutexLock lock(&mutex_); - return status_; -} - -StatusOr<uint64_t> LegacyFence::QueryValue() { - RETURN_IF_ERROR(TryResolveOutstandingFences(UINT64_MAX)); - return value_.load(); -} - -StatusOr<VkFence> LegacyFence::AcquireSignalFence(uint64_t value) { - absl::MutexLock lock(&mutex_); - - // It's an error to signal out of order (as that requires a lot more - // tracking and magic to get right). - if (value_.load() >= value) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Attempting to signal a timeline fence out of order; value=" - << value_ << ", new_value=" << value; - } - - // Scan to see if there's waiters for this value (or values before it). - // We may be able to reuse a previously allocated fence in the case that a - // user is waiting prior to actually submitting the signal operation. - OutstandingFenceSignal* signal_state = nullptr; - for (auto* fence_signal : outstanding_signals_) { - if (fence_signal->value == value) { - // Fence is going to be signaled at exactly the required value. - if (fence_signal->is_pending) { - // Already have signaled to this value - that's a paddlin'. - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Duplicate signal of timeline fence for value=" << value; - } - signal_state = fence_signal; - break; - } - } - if (!signal_state) { - // Allocate a signal state entry and a VkFence to submit with. - // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush. - ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire()); - signal_state->value = value; - InsertOutstandingFenceSignal(signal_state, &outstanding_signals_); - } - - signal_state->is_pending = true; - return signal_state->fence; -} - -StatusOr<VkFence> LegacyFence::AcquireWaitFence(uint64_t value) { - // If we've already resolved then we want to avoid doing any kind of wait. - // Since the value is monotonically increasing we can do a lock-free peek - // here to see if we need to bother taking a full lock. - if (value_.load() >= value) { - return VK_NULL_HANDLE; - } - - absl::MutexLock lock(&mutex_); - - // Try to resolve any outstanding fence signals. - RETURN_IF_ERROR(TryResolveOutstandingFencesLocked(value)); - if (value_.load() >= value) { - return VK_NULL_HANDLE; - } - - // Try to find an existing fence we can reuse based on the required value. - OutstandingFenceSignal* signal_state = nullptr; - for (auto* fence_signal : outstanding_signals_) { - if (fence_signal->value >= value) { - // Fence is going to be signaled at or above the required value. - signal_state = fence_signal; - break; // |outstanding_signals_| is in sorted order. - } - } - if (!signal_state) { - // Allocate a signal state entry and a VkFence that we will need to signal - // in the future. We can't yet insert it into the queue but it will go in - // when the user tries to signal a value >= the required value. - // TODO(benvanik): check for RESOURCE_EXHAUSTED and force a flush. - ASSIGN_OR_RETURN(signal_state, fence_pool_->Acquire()); - signal_state->value = value; - InsertOutstandingFenceSignal(signal_state, &outstanding_signals_); - } - - return signal_state->fence; -} - -Status LegacyFence::TryResolveOutstandingFences(uint64_t upper_value) { - absl::MutexLock lock(&mutex_); - return TryResolveOutstandingFencesLocked(upper_value); -} - -Status LegacyFence::TryResolveOutstandingFencesLocked(uint64_t upper_value) { - // Fast-path for when we have no outstanding fences. - // NOTE: we hold the lock during the entire resolve process so that any waiter - // will only be woken once we have resolved to the furthest possible value. - if (outstanding_signals_.empty() || value_ > upper_value) { - return OkStatus(); - } - - IREE_TRACE_SCOPE0("LegacyFence::TryResolveOutstandingFences"); - - IntrusiveList<OutstandingFenceSignal> resolved_fences; - IntrusiveList<OutstandingFenceSignal> unresolved_fences; - VkDevice device = *fence_pool_->logical_device(); - const auto& syms = fence_pool_->syms(); - bool keep_resolving = true; - while (keep_resolving && !outstanding_signals_.empty()) { - auto* fence_signal = outstanding_signals_.front(); - if (fence_signal->value > upper_value) { - // Signal is for a value beyond our upper limit - early exit so that we - // don't spend time dealing with signals we don't yet care about. This can - // prevent live lock where one thread is signaling fences as fast/faster - // than another thread can consume them. - keep_resolving = false; - break; - } - VkResult fence_status = syms->vkGetFenceStatus(device, fence_signal->fence); - switch (fence_status) { - case VK_SUCCESS: { - // Fence has signaled meaning that we have reached this point in the - // timeline and can advance the value. - value_.store(fence_signal->value); - outstanding_signals_.erase(fence_signal); - resolved_fences.push_back(fence_signal); - - // Run backwards and resolve any non-pending fences as they will never - // be used. - for (auto* it = fence_signal; it != nullptr;) { - auto* prev_fence_signal = it; - it = outstanding_signals_.previous(it); - if (!prev_fence_signal->is_pending) { - outstanding_signals_.erase(prev_fence_signal); - unresolved_fences.push_back(prev_fence_signal); - } - } - break; - } - case VK_NOT_READY: - if (fence_signal->is_pending) { - // Fence has not yet been signaled. We stop here and wait for future - // attempts at resolution. - keep_resolving = false; - } - // Fence is not even pending yet - we may have skipped it. Keep - // resolving to see if there's a higher value we can use. - break; - default: - // Fence indicates an error (device lost, out of memory, etc). - // Propagate this back to our status (and thus any waiters). - // Since we only take the first error we find we skip all remaining - // fences. - status_ = VkResultToStatus(fence_status); - value_.store(UINT64_MAX); - outstanding_signals_.erase(fence_signal); - resolved_fences.push_back(fence_signal); - break; - } - } - - // Release resolved fences back to the pool. Note that we can only do this - // to fences we know have actually completed: unresolved fences after an error - // may still be in-flight and we don't want to reuse them. - fence_pool_->ReleaseResolved(&resolved_fences); - fence_pool_->ReleaseUnresolved(&unresolved_fences); - if (!status_.ok()) { - fence_pool_->ReleaseUnresolved(&outstanding_signals_); - } - - return status_; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/legacy_fence.h b/iree/hal/vulkan/legacy_fence.h deleted file mode 100644 index 42f0362..0000000 --- a/iree/hal/vulkan/legacy_fence.h +++ /dev/null
@@ -1,200 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// TODO(b/140141417): share the pool (and possibly most of the fence impl) with -// the timeline semaphores fallback. - -#ifndef IREE_HAL_VULKAN_LEGACY_FENCE_H_ -#define IREE_HAL_VULKAN_LEGACY_FENCE_H_ - -#include <vulkan/vulkan.h> - -#include <array> -#include <atomic> - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/fence.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// An outstanding legacy fence signal for a particular timeline value. -// Each signal to a new value gets a new VkFence and these are stored in a -// LegacyFence to quickly scan and process signaled fences. -// -// Must be externally synchronized via the LegacyFence mutex. -struct OutstandingFenceSignal : public IntrusiveLinkBase<void> { - // Allocated fence that is passed to vkQueueSubmit/vkWaitForFences. - // Represents a point in the timeline of value. - VkFence fence = VK_NULL_HANDLE; - - // Value that the fence payload should be when the fence is signaled. - // Note that since fences may resolve out of order we still need to check that - // we are only ever advancing the timeline and not just setting this value. - uint64_t value = UINT64_MAX; - - // True when the fence has been submitted and is pending on the device. - bool is_pending = false; -}; - -// A pool of VkFences that can be used by LegacyFence to simulate individual -// payload value signaling. Note that we prefer a pool instead of a ringbuffer -// as we want to allow out-of-order completion. -class LegacyFencePool final : public RefObject<LegacyFencePool> { - public: - static constexpr int kMaxInFlightFenceCount = 64; - - // Allocates a new fence pool and all fences. - static StatusOr<ref_ptr<LegacyFencePool>> Create( - ref_ptr<VkDeviceHandle> logical_device); - - ~LegacyFencePool(); - - const ref_ptr<VkDeviceHandle>& logical_device() const { - return logical_device_; - } - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - // Acquires a fence from the pool for use by the caller. - // The fence is guaranteed to not be in-flight and will have been reset to an - // unsignaled state. - // - // Returns RESOURCE_EXHAUSTED if the pool has no more available fences. - // Callers are expected to handle this by waiting on previous fences or for - // complete device idle. Yes, that's as bad as it sounds, and if we start - // seeing that we should bump up the max count. - StatusOr<OutstandingFenceSignal*> Acquire(); - - // Releases one or more fences back to the pool. - // The fences must either be signaled or not be in-flight. - void ReleaseResolved(IntrusiveList<OutstandingFenceSignal>* fence_signals); - - // Releases one or more unresolved fences back to the pool. - // These may be in any state and will be assumed as untouchable. - void ReleaseUnresolved(IntrusiveList<OutstandingFenceSignal>* fence_signals); - - private: - explicit LegacyFencePool(ref_ptr<VkDeviceHandle> logical_device); - - Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_); - - ref_ptr<VkDeviceHandle> logical_device_; - - absl::Mutex mutex_; - std::array<OutstandingFenceSignal, kMaxInFlightFenceCount> storage_ - ABSL_GUARDED_BY(mutex_); - IntrusiveList<OutstandingFenceSignal> unused_fences_ ABSL_GUARDED_BY(mutex_); - IntrusiveList<OutstandingFenceSignal> unresolved_fences_ - ABSL_GUARDED_BY(mutex_); -}; - -// A fence implemented using a pool of native VkFences. -// This is supported unconditionally on all versions of Vulkan. When timeline -// semaphores are available we prefer using those instead and this is only -// present as a fallback. We keep this implementation separate so that it can be -// compiled out when the target is known to have the extension. -// -// Simulation of timeline semaphore-based fences is done via a pool of native -// VkFences that each represent a single signaled value. This means that worst -// case we are using one fence per submit however that's no different than if -// we did anything else. Though we can't cancel previously-queued fences when -// increasing values are signaled we can be clever when querying and releasing -// by always walking in reverse relying on the monotonically increasing values. -// -// Valid usage patterns we need to handle: -// 1. fence signaled and waited on (common case) -// 2. fence waited on before beginning signaling -// 3. fence signaled and never waited on -// -// Case 1 is fairly straightforward: we acquire a VkFence, pass that to the -// queue submit, and then vkWaitForFences/query it for completion. -// -// Case 2 requires that we reserve a fence during the wait so that we can pass -// it to vkWaitForFences and track it such that we can reuse it during a future -// signal operation. Since we don't know during signaling if the specific value -// we waited on will ever have its own dedicated signal operation we need to be -// conservative and try to coalesce for correctness. This means that if a wait -// for a value of 1 is performed and we get a signal for a value of 2 we need to -// combine the two. If a signal for a value of 1 is later performed it then -// becomes a no-op. This could lead to some additional latency however that's a -// risk (or benefit!) of using timelines. Rule of thumb: don't do out of order -// signaling. -// -// Case 3 is like case 2 where we need to reserve a fence to wait on, however -// since we don't know if it will ever be signaled we need to take care to -// properly release the VkFence back to the pool for reuse: we don't want to -// return it while there are still waiters for its original event. For this -// reason we track the waiters on a given fence during their wait operation and -// if a fence is released with waiters active we put them in a special -// unresolved until the waiters continue on. -class LegacyFence final : public Fence { - public: - // Waits for one or more (or all) fences to reach or exceed the given values. - static Status WaitForFences(VkDeviceHandle* logical_device, - absl::Span<const FenceValue> fences, - bool wait_all, absl::Time deadline); - - LegacyFence(ref_ptr<LegacyFencePool> fence_pool, uint64_t initial_value); - ~LegacyFence() override; - - Status status() const override; - - StatusOr<uint64_t> QueryValue() override; - - // Acquires a new fence for signaling a specific value. - StatusOr<VkFence> AcquireSignalFence(uint64_t value); - - private: - // Acquires a new fence for waiting on a specific value. - // Returns VK_NULL_HANDLE if the fence already resolved and the sticky error - // if the fence is in an error state. - StatusOr<VkFence> AcquireWaitFence(uint64_t value); - - // Runs down the outstanding fences list and resolves to the latest signaled - // value. Will early exit if the value moves beyond |upper_value|. - Status TryResolveOutstandingFences(uint64_t upper_value) - ABSL_LOCKS_EXCLUDED(mutex_); - Status TryResolveOutstandingFencesLocked(uint64_t upper_value) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - ref_ptr<LegacyFencePool> fence_pool_; - - // The current highest value of the fence as verified during a wait or query. - // Kept outside of |mutex_| so that queries do not require a lock. - std::atomic<uint64_t> value_; - - mutable absl::Mutex mutex_; - - // Sticky status failure value set on first failure. - Status status_ ABSL_GUARDED_BY(mutex_); - - // Outstanding VkFences representing signal values. - // Expected to be sorted in ascending order by value. - IntrusiveList<OutstandingFenceSignal> outstanding_signals_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_LEGACY_FENCE_H_
diff --git a/iree/hal/vulkan/native_binary_semaphore.cc b/iree/hal/vulkan/native_binary_semaphore.cc deleted file mode 100644 index 5439a67..0000000 --- a/iree/hal/vulkan/native_binary_semaphore.cc +++ /dev/null
@@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/native_binary_semaphore.h" - -namespace iree { -namespace hal { -namespace vulkan { - -NativeBinarySemaphore::NativeBinarySemaphore( - ref_ptr<VkDeviceHandle> logical_device, VkSemaphore handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} - -NativeBinarySemaphore::~NativeBinarySemaphore() { - logical_device_->syms()->vkDestroySemaphore(*logical_device_, handle_, - logical_device_->allocator()); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/native_binary_semaphore.h b/iree/hal/vulkan/native_binary_semaphore.h deleted file mode 100644 index fc55ebb..0000000 --- a/iree/hal/vulkan/native_binary_semaphore.h +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_ -#define IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_ - -#include <vulkan/vulkan.h> - -#include "iree/hal/semaphore.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A binary semaphore implemented using the native VkSemaphore type. -// This is supported unconditionally on all versions of Vulkan. -class NativeBinarySemaphore final : public BinarySemaphore { - public: - NativeBinarySemaphore(ref_ptr<VkDeviceHandle> logical_device, - VkSemaphore handle); - ~NativeBinarySemaphore() override; - - VkSemaphore handle() const { return handle_; } - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkSemaphore handle_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_NATIVE_BINARY_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/native_event.cc b/iree/hal/vulkan/native_event.cc deleted file mode 100644 index 28dbc56..0000000 --- a/iree/hal/vulkan/native_event.cc +++ /dev/null
@@ -1,31 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/native_event.h" - -namespace iree { -namespace hal { -namespace vulkan { - -NativeEvent::NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} - -NativeEvent::~NativeEvent() { - logical_device_->syms()->vkDestroyEvent(*logical_device_, handle_, - logical_device_->allocator()); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/native_event.h b/iree/hal/vulkan/native_event.h deleted file mode 100644 index 691ef6d..0000000 --- a/iree/hal/vulkan/native_event.h +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_ -#define IREE_HAL_VULKAN_NATIVE_EVENT_H_ - -#include <vulkan/vulkan.h> - -#include "iree/hal/event.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// An event implemented with the native VkEvent type. -class NativeEvent final : public Event { - public: - NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle); - ~NativeEvent() override; - - VkEvent handle() const { return handle_; } - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkEvent handle_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_NATIVE_EVENT_H_
diff --git a/iree/hal/vulkan/pipeline_cache.cc b/iree/hal/vulkan/pipeline_cache.cc deleted file mode 100644 index aba1917..0000000 --- a/iree/hal/vulkan/pipeline_cache.cc +++ /dev/null
@@ -1,235 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/pipeline_cache.h" - -#include "absl/synchronization/mutex.h" -#include "flatbuffers/flatbuffers.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/executable_format.h" -#include "iree/hal/vulkan/status_util.h" -#include "iree/schemas/spirv_executable_def_generated.h" - -namespace iree { -namespace hal { -namespace vulkan { - -PipelineCache::PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device) - : logical_device_(add_ref(logical_device)) {} - -PipelineCache::~PipelineCache() { - IREE_TRACE_SCOPE0("PipelineCache::dtor"); - ClearLayoutCaches(); -} - -bool PipelineCache::CanPrepareFormat(ExecutableFormat format) const { - return format == kExecutableFormatSpirV; -} - -StatusOr<ref_ptr<Executable>> PipelineCache::PrepareExecutable( - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) { - IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable"); - if (!CanPrepareFormat(spec.format)) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Unsupported 4CC format: 0x" << std::hex << spec.format; - } - if (spec.executable_data.size() <= 4 || - !SpirVExecutableDefBufferHasIdentifier(spec.executable_data.data())) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Supplied executable data does not contain a SpirVExecutableDef"; - } - - // Get the SPIR-V executable def flatbuffer. - const auto& spirv_executable_def = - *::flatbuffers::GetRoot<SpirVExecutableDef>(spec.executable_data.data()); - - // Create (or reuse) a pipeline layout. - if (!spirv_executable_def.pipeline_layout()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Missing pipeline layout def"; - } - ASSIGN_OR_RETURN( - auto pipeline_layout_entry, - LookupOrInsertPipelineLayout(*spirv_executable_def.pipeline_layout())); - - // Create the executable (which may itself own many pipelines). - ASSIGN_OR_RETURN(auto executable, PipelineExecutable::Create( - logical_device_, - /*pipeline_cache=*/VK_NULL_HANDLE, - pipeline_layout_entry->pipeline_layout, - pipeline_layout_entry->descriptor_sets, - mode, spirv_executable_def)); - return executable; -} - -StatusOr<const PipelineCache::CachedPipelineLayout*> -PipelineCache::LookupOrInsertPipelineLayout( - const VkPipelineLayoutDef& pipeline_layout_def) { - IREE_TRACE_SCOPE0("PipelineCache::LookupOrInsertPipelineLayout"); - absl::MutexLock lock(&mutex_); - - // Build a list of the required descriptor set layouts and push constants. - // If we were being fast about this we would just hash the def and directly - // look up the pipeline layout. - PipelineDescriptorSets descriptor_sets; - descriptor_sets.buffer_binding_set = pipeline_layout_def.buffer_binding_set(); - descriptor_sets.buffer_binding_set_layout = VK_NULL_HANDLE; - absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts; - if (pipeline_layout_def.descriptor_set_layouts()) { - const auto& layout_defs = *pipeline_layout_def.descriptor_set_layouts(); - descriptor_set_layouts.resize(layout_defs.size()); - for (int i = 0; i < descriptor_set_layouts.size(); ++i) { - if (!layout_defs[i]) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "Missing layout def"; - } - ASSIGN_OR_RETURN(descriptor_set_layouts[i], - LookupOrInsertDescriptorSetLayout(*layout_defs[i])); - if (i == pipeline_layout_def.buffer_binding_set()) { - descriptor_sets.buffer_binding_set_layout = descriptor_set_layouts[i]; - descriptor_sets.buffer_binding_set_map.resize( - layout_defs[i]->bindings()->size()); - for (int j = 0; j < layout_defs[i]->bindings()->size(); ++j) { - descriptor_sets.buffer_binding_set_map[j] = - layout_defs[i]->bindings()->Get(j)->binding(); - } - } - } - } - - absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges; - if (pipeline_layout_def.push_constant_ranges()) { - const auto& range_defs = *pipeline_layout_def.push_constant_ranges(); - push_constant_ranges.resize(range_defs.size()); - for (int i = 0; i < push_constant_ranges.size(); ++i) { - if (!range_defs[i]) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Missing push constant range def"; - } - push_constant_ranges[i].stageFlags = range_defs[i]->stage_flags(); - push_constant_ranges[i].offset = range_defs[i]->offset(); - push_constant_ranges[i].size = range_defs[i]->size(); - } - } - - // Scan for an existing pipeline layout that matches the descriptor sets. - for (auto& entry : pipeline_layout_cache_) { - if (entry.descriptor_set_layouts.size() != descriptor_set_layouts.size() || - entry.push_constant_ranges.size() != push_constant_ranges.size()) { - continue; - } - if (std::memcmp( - descriptor_set_layouts.data(), entry.descriptor_set_layouts.data(), - descriptor_set_layouts.size() * sizeof(VkDescriptorSetLayout)) == - 0 && - std::memcmp( - push_constant_ranges.data(), entry.push_constant_ranges.data(), - push_constant_ranges.size() * sizeof(VkPushConstantRange)) == 0) { - return &entry; - } - } - - VkPipelineLayoutCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.setLayoutCount = descriptor_set_layouts.size(); - create_info.pSetLayouts = descriptor_set_layouts.data(); - create_info.pushConstantRangeCount = push_constant_ranges.size(); - create_info.pPushConstantRanges = push_constant_ranges.data(); - - // Create and insert into the cache. - VkPipelineLayout pipeline_layout = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreatePipelineLayout( - *logical_device_, &create_info, logical_device_->allocator(), - &pipeline_layout)); - pipeline_layout_cache_.push_back({std::move(descriptor_set_layouts), - std::move(push_constant_ranges), - pipeline_layout, descriptor_sets}); - return &pipeline_layout_cache_.back(); -} - -StatusOr<VkDescriptorSetLayout> -PipelineCache::LookupOrInsertDescriptorSetLayout( - const VkDescriptorSetLayoutDef& descriptor_set_layout_def) { - // Build a list of bindings in the set. - // If we were being fast we would hash the bindings and directly lookup - // without doing this allocation. - absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings; - if (descriptor_set_layout_def.bindings()) { - const auto& binding_defs = *descriptor_set_layout_def.bindings(); - bindings.resize(binding_defs.size()); - for (int i = 0; i < binding_defs.size(); ++i) { - bindings[i].binding = binding_defs[i]->binding(); - bindings[i].descriptorType = - static_cast<VkDescriptorType>(binding_defs[i]->descriptor_type()); - bindings[i].descriptorCount = binding_defs[i]->descriptor_count(); - bindings[i].stageFlags = binding_defs[i]->stage_flags(); - bindings[i].pImmutableSamplers = nullptr; - } - } - - // Scan for an existing descriptor set layout that matches the bindings. - for (auto& entry : descriptor_set_layout_cache_) { - if (entry.bindings.size() != bindings.size()) continue; - if (std::memcmp(bindings.data(), entry.bindings.data(), - bindings.size() * sizeof(VkDescriptorSetLayoutBinding)) == - 0) { - return entry.descriptor_set_layout; - } - } - - VkDescriptorSetLayoutCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - if (logical_device_->enabled_extensions().push_descriptors) { - // Note that we can *only* use push descriptor sets if we set this create - // flag. That's fine, though, as the command buffer recording logic always - // prefers the extension if available. - create_info.flags |= - VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; - } - create_info.bindingCount = bindings.size(); - create_info.pBindings = bindings.data(); - - // Create and insert into the cache. - VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateDescriptorSetLayout( - *logical_device_, &create_info, logical_device_->allocator(), - &descriptor_set_layout)); - descriptor_set_layout_cache_.push_back( - {std::move(bindings), descriptor_set_layout}); - return descriptor_set_layout; -} - -void PipelineCache::ClearLayoutCaches() { - absl::MutexLock lock(&mutex_); - for (auto& entry : pipeline_layout_cache_) { - syms()->vkDestroyPipelineLayout(*logical_device_, entry.pipeline_layout, - logical_device_->allocator()); - } - pipeline_layout_cache_.clear(); - for (auto& entry : descriptor_set_layout_cache_) { - syms()->vkDestroyDescriptorSetLayout(*logical_device_, - entry.descriptor_set_layout, - logical_device_->allocator()); - } - descriptor_set_layout_cache_.clear(); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_cache.h b/iree/hal/vulkan/pipeline_cache.h deleted file mode 100644 index 1847f36..0000000 --- a/iree/hal/vulkan/pipeline_cache.h +++ /dev/null
@@ -1,85 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_ -#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_ - -#include <vulkan/vulkan.h> - -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/pipeline_executable.h" -#include "iree/schemas/spirv_executable_def_generated.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class PipelineCache final : public ExecutableCache { - public: - explicit PipelineCache(const ref_ptr<VkDeviceHandle>& logical_device); - ~PipelineCache() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - bool CanPrepareFormat(ExecutableFormat format) const override; - - StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableCachingModeBitfield mode, const ExecutableSpec& spec) override; - - private: - struct CachedDescriptorSetLayout { - absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> bindings; - VkDescriptorSetLayout descriptor_set_layout; - }; - struct CachedPipelineLayout { - absl::InlinedVector<VkDescriptorSetLayout, 4> descriptor_set_layouts; - absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges; - VkPipelineLayout pipeline_layout; - PipelineDescriptorSets descriptor_sets; - }; - - StatusOr<const CachedPipelineLayout*> LookupOrInsertPipelineLayout( - const VkPipelineLayoutDef& pipeline_layout_def) - ABSL_LOCKS_EXCLUDED(mutex_); - StatusOr<VkDescriptorSetLayout> LookupOrInsertDescriptorSetLayout( - const VkDescriptorSetLayoutDef& descriptor_set_layout_def) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - void ClearLayoutCaches() ABSL_LOCKS_EXCLUDED(mutex_); - - ref_ptr<VkDeviceHandle> logical_device_; - - // A "cache" of descriptor set and pipeline layouts for various values. - // We never evict and just do a simple linear scan on lookup. This is fine for - // now as we only support a single descriptor type and really we only need to - // check for binding count. As we go toward more general usage of descriptors - // (images/etc) we will likely want to change this to a real cache. - absl::Mutex mutex_; - std::vector<CachedDescriptorSetLayout> descriptor_set_layout_cache_ - ABSL_GUARDED_BY(mutex_); - std::vector<CachedPipelineLayout> pipeline_layout_cache_ - ABSL_GUARDED_BY(mutex_); -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_CACHE_H_
diff --git a/iree/hal/vulkan/pipeline_executable.cc b/iree/hal/vulkan/pipeline_executable.cc deleted file mode 100644 index 9c1514e..0000000 --- a/iree/hal/vulkan/pipeline_executable.cc +++ /dev/null
@@ -1,191 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/pipeline_executable.h" - -#include "absl/container/inlined_vector.h" -#include "iree/base/memory.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -// Generates the baked specialization constant data based on the flatbuffer. -// We only support uint32_t right now so this is easy. -// Note that the returned vectors are referenced by pointers in |out_info| and -// must remain valid until the info is no longer in use. -std::pair<std::vector<VkSpecializationMapEntry>, std::vector<uint8_t>> -PopulateSpecializationInfo(const VkSpecializationInfoDef* info_def) { - int entry_count = - info_def && info_def->map_entries() ? info_def->map_entries()->size() : 0; - if (!entry_count) { - return {}; - } - - std::vector<VkSpecializationMapEntry> entries; - entries.reserve(entry_count); - std::vector<uint8_t> data; - data.resize(entry_count * sizeof(uint32_t)); - - uint32_t offset = 0; - for (const auto* entry_def : *info_def->map_entries()) { - if (!entry_def) continue; - entries.push_back({}); - auto& entry = entries.back(); - entry.constantID = entry_def->constant_id(); - entry.offset = offset; - entry.size = sizeof(uint32_t); - uint32_t value = entry_def->uint32_value(); - std::memcpy(data.data() + offset, &value, sizeof(value)); - offset += entry.size; - } - - return {std::move(entries), std::move(data)}; -} - -} // namespace - -// static -StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create( - const ref_ptr<VkDeviceHandle>& logical_device, - VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout, - PipelineDescriptorSets descriptor_sets, ExecutableCachingModeBitfield mode, - const SpirVExecutableDef& spirv_executable_def) { - IREE_TRACE_SCOPE0("PipelineExecutable::Create"); - const auto& syms = logical_device->syms(); - if (!spirv_executable_def.entry_points() || - spirv_executable_def.entry_points()->size() == 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No entry points defined"; - } - if (!spirv_executable_def.code()) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No SPIR-V code present"; - } - const auto& code = *spirv_executable_def.code(); - - // Create the shader module. - VkShaderModuleCreateInfo shader_module_create_info; - shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; - shader_module_create_info.pNext = nullptr; - shader_module_create_info.flags = 0; - shader_module_create_info.codeSize = code.size() * sizeof(uint32_t); - shader_module_create_info.pCode = code.data(); - VkShaderModule shader_module = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR( - syms->vkCreateShaderModule(*logical_device, &shader_module_create_info, - logical_device->allocator(), &shader_module)); - - // We only need to keep this around during pipeline creation so ensure we - // always clean it up when we exit this function. - auto shader_module_cleanup = MakeCleanup([&logical_device, shader_module]() { - logical_device->syms()->vkDestroyShaderModule( - *logical_device, shader_module, logical_device->allocator()); - }); - - // Specialization info is currently constant against all entry points. - std::vector<VkSpecializationMapEntry> spec_entries; - std::vector<uint8_t> spec_data; - std::tie(spec_entries, spec_data) = - PopulateSpecializationInfo(spirv_executable_def.specialization_info()); - VkSpecializationInfo specialization_info; - specialization_info.mapEntryCount = spec_entries.size(); - specialization_info.pMapEntries = spec_entries.data(); - specialization_info.dataSize = spec_data.size(); - specialization_info.pData = spec_data.data(); - - // Create pipelines for each entry point. - const auto& entry_points = *spirv_executable_def.entry_points(); - absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos; - pipeline_create_infos.resize(entry_points.size()); - for (int entry_ordinal = 0; entry_ordinal < entry_points.size(); - ++entry_ordinal) { - auto& pipeline_create_info = pipeline_create_infos[entry_ordinal]; - pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; - pipeline_create_info.pNext = nullptr; - pipeline_create_info.flags = 0; - if (!AllBitsSet(mode, ExecutableCachingMode::kAllowOptimization)) { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; - } - if (entry_ordinal == 0) { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; - } else { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; - } - pipeline_create_info.layout = pipeline_layout; - pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE; - pipeline_create_info.basePipelineIndex = 0; - auto& stage_create_info = pipeline_create_info.stage; - stage_create_info.sType = - VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - stage_create_info.pNext = nullptr; - stage_create_info.flags = 0; - stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT; - stage_create_info.module = shader_module; - stage_create_info.pName = entry_points[entry_ordinal]->c_str(); - stage_create_info.pSpecializationInfo = &specialization_info; - } - absl::InlinedVector<VkPipeline, 1> pipelines; - pipelines.resize(entry_points.size()); - - // Some ICDs appear to leak in here, out of our control. - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines( - *logical_device, pipeline_cache, pipeline_create_infos.size(), - pipeline_create_infos.data(), logical_device->allocator(), - pipelines.data())); - IREE_ENABLE_LEAK_CHECKS(); - - auto executable = - make_ref<PipelineExecutable>(CtorKey{}, logical_device, pipeline_layout, - descriptor_sets, std::move(pipelines)); - executable->tag_ = - spirv_executable_def.tag() ? spirv_executable_def.tag()->str() : ""; - return executable; -} - -PipelineExecutable::PipelineExecutable( - CtorKey ctor_key, const ref_ptr<VkDeviceHandle>& logical_device, - VkPipelineLayout pipeline_layout, PipelineDescriptorSets descriptor_sets, - absl::InlinedVector<VkPipeline, 1> pipelines) - : logical_device_(add_ref(logical_device)), - pipeline_layout_(pipeline_layout), - descriptor_sets_(descriptor_sets), - pipelines_(std::move(pipelines)) {} - -PipelineExecutable::~PipelineExecutable() { - IREE_TRACE_SCOPE0("PipelineExecutable::dtor"); - for (auto pipeline : pipelines_) { - syms()->vkDestroyPipeline(*logical_device_, pipeline, - logical_device_->allocator()); - } - pipelines_.clear(); -} - -StatusOr<VkPipeline> PipelineExecutable::GetPipelineForEntryPoint( - int entry_ordinal) const { - if (entry_ordinal < 0 || entry_ordinal >= pipelines_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal"; - } - return pipelines_[entry_ordinal]; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_executable.h b/iree/hal/vulkan/pipeline_executable.h deleted file mode 100644 index 1ca21aa..0000000 --- a/iree/hal/vulkan/pipeline_executable.h +++ /dev/null
@@ -1,91 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ -#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ - -#include <vulkan/vulkan.h> - -#include <vector> - -#include "absl/container/inlined_vector.h" -#include "iree/base/status.h" -#include "iree/hal/executable.h" -#include "iree/hal/executable_cache.h" -#include "iree/hal/executable_spec.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/schemas/spirv_executable_def_generated.h" - -namespace iree { -namespace hal { -namespace vulkan { - -struct PipelineDescriptorSets { - uint32_t buffer_binding_set; - VkDescriptorSetLayout buffer_binding_set_layout; - absl::InlinedVector<uint32_t, 8> buffer_binding_set_map; -}; - -class PipelineExecutable final : public Executable { - public: - static StatusOr<ref_ptr<PipelineExecutable>> Create( - const ref_ptr<VkDeviceHandle>& logical_device, - VkPipelineCache pipeline_cache, VkPipelineLayout pipeline_layout, - PipelineDescriptorSets descriptor_sets, - ExecutableCachingModeBitfield mode, - const SpirVExecutableDef& spirv_executable_def); - - // Private constructor. - struct CtorKey { - private: - friend class PipelineExecutable; - CtorKey() = default; - }; - PipelineExecutable(CtorKey ctor_key, - const ref_ptr<VkDeviceHandle>& logical_device, - VkPipelineLayout pipeline_layout, - PipelineDescriptorSets descriptor_sets, - absl::InlinedVector<VkPipeline, 1> pipelines); - ~PipelineExecutable() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - bool supports_debugging() const override { return false; } - - VkPipelineLayout pipeline_layout() const { return pipeline_layout_; } - const PipelineDescriptorSets& descriptor_sets() const { - return descriptor_sets_; - } - - bool is_matmul() const { return tag_ == "__matmul__"; } - - StatusOr<VkPipeline> GetPipelineForEntryPoint(int entry_ordinal) const; - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkPipelineLayout pipeline_layout_; - PipelineDescriptorSets descriptor_sets_; - std::string tag_; - - // One pipeline per entry point. - absl::InlinedVector<VkPipeline, 1> pipelines_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
diff --git a/iree/hal/vulkan/status_util.cc b/iree/hal/vulkan/status_util.cc deleted file mode 100644 index 4b08e21..0000000 --- a/iree/hal/vulkan/status_util.cc +++ /dev/null
@@ -1,231 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/status_util.h" - -#include "iree/base/status.h" - -namespace iree { -namespace hal { -namespace vulkan { - -Status VkResultToStatus(VkResult result) { - switch (result) { - // Success codes. - case VK_SUCCESS: - // Command successfully completed. - return OkStatus(); - case VK_NOT_READY: - // A fence or query has not yet completed. - return OkStatus(); - case VK_TIMEOUT: - // A wait operation has not completed in the specified time. - return OkStatus(); - case VK_EVENT_SET: - // An event is signaled. - return OkStatus(); - case VK_EVENT_RESET: - // An event is unsignaled. - return OkStatus(); - case VK_INCOMPLETE: - // A return array was too small for the result. - return OkStatus(); - case VK_SUBOPTIMAL_KHR: - // A swapchain no longer matches the surface properties exactly, but can - // still be used to present to the surface successfully. - return OkStatus(); - - // Error codes. - case VK_ERROR_OUT_OF_HOST_MEMORY: - // A host memory allocation has failed. - return ResourceExhaustedError("VK_ERROR_OUT_OF_HOST_MEMORY"); - case VK_ERROR_OUT_OF_DEVICE_MEMORY: - // A device memory allocation has failed. - return ResourceExhaustedError("VK_ERROR_OUT_OF_DEVICE_MEMORY"); - case VK_ERROR_INITIALIZATION_FAILED: - // Initialization of an object could not be completed for - // implementation-specific reasons. - return InternalError("VK_ERROR_INITIALIZATION_FAILED"); - case VK_ERROR_DEVICE_LOST: - // The logical or physical device has been lost. - // - // A logical device may become lost for a number of - // implementation-specific reasons, indicating that pending and future - // command execution may fail and cause resources and backing memory to - // become undefined. - // - // Typical reasons for device loss will include things like execution - // timing out (to prevent denial of service), power management events, - // platform resource management, or implementation errors. - // - // When this happens, certain commands will return - // VK_ERROR_DEVICE_LOST (see Error Codes for a list of such - // commands). After any such event, the logical device is considered lost. - // It is not possible to reset the logical device to a non-lost state, - // however the lost state is specific to a logical device (VkDevice), and - // the corresponding physical device (VkPhysicalDevice) may be otherwise - // unaffected. - // - // In some cases, the physical device may also be lost, and attempting to - // create a new logical device will fail, returning VK_ERROR_DEVICE_LOST. - // This is usually indicative of a problem with the underlying - // implementation, or its connection to the host. If the physical device - // has not been lost, and a new logical device is successfully created - // from that physical device, it must be in the non-lost state. - // - // Whilst logical device loss may be recoverable, in the case of physical - // device loss, it is unlikely that an application will be able to recover - // unless additional, unaffected physical devices exist on the system. The - // error is largely informational and intended only to inform the user - // that a platform issue has occurred, and should be investigated further. - // For example, underlying hardware may have developed a fault or become - // physically disconnected from the rest of the system. In many cases, - // physical device loss may cause other more serious issues such as the - // operating system crashing; in which case it may not be reported via the - // Vulkan API. - // - // Undefined behavior caused by an application error may cause a device to - // become lost. However, such undefined behavior may also cause - // unrecoverable damage to the process, and it is then not guaranteed that - // the API objects, including the VkPhysicalDevice or the VkInstance are - // still valid or that the error is recoverable. - // - // When a device is lost, its child objects are not implicitly destroyed - // and their handles are still valid. Those objects must still be - // destroyed before their parents or the device can be destroyed (see the - // Object Lifetime section). The host address space corresponding to - // device memory mapped using vkMapMemory is still valid, and host memory - // accesses to these mapped regions are still valid, but the contents are - // undefined. It is still legal to call any API command on the device and - // child objects. - // - // Once a device is lost, command execution may fail, and commands that - // return a VkResult may return VK_ERROR_DEVICE_LOST. - // Commands that do not allow run-time errors must still operate correctly - // for valid usage and, if applicable, return valid data. - // - // Commands that wait indefinitely for device execution (namely - // vkDeviceWaitIdle, vkQueueWaitIdle, vkWaitForFences with a maximum - // timeout, and vkGetQueryPoolResults with the VK_QUERY_RESULT_WAIT_BIT - // bit set in flags) must return in finite time even in the case - // of a lost device, and return either VK_SUCCESS or - // VK_ERROR_DEVICE_LOST. For any command that may return - // VK_ERROR_DEVICE_LOST, for the purpose of determining whether a - // command buffer is in the pending state, or whether resources are - // considered in-use by the device, a return value of - // VK_ERROR_DEVICE_LOST is equivalent to VK_SUCCESS. - return InternalError("VK_ERROR_DEVICE_LOST"); - case VK_ERROR_MEMORY_MAP_FAILED: - // Mapping of a memory object has failed. - return InternalError("VK_ERROR_MEMORY_MAP_FAILED"); - case VK_ERROR_LAYER_NOT_PRESENT: - // A requested layer is not present or could not be loaded. - return UnimplementedError("VK_ERROR_LAYER_NOT_PRESENT"); - case VK_ERROR_EXTENSION_NOT_PRESENT: - // A requested extension is not supported. - return UnimplementedError("VK_ERROR_EXTENSION_NOT_PRESENT"); - case VK_ERROR_FEATURE_NOT_PRESENT: - // A requested feature is not supported. - return UnimplementedError("VK_ERROR_FEATURE_NOT_PRESENT"); - case VK_ERROR_INCOMPATIBLE_DRIVER: - // The requested version of Vulkan is not supported by the driver or is - // otherwise incompatible for implementation-specific reasons. - return FailedPreconditionError("VK_ERROR_INCOMPATIBLE_DRIVER"); - case VK_ERROR_TOO_MANY_OBJECTS: - // Too many objects of the type have already been created. - return ResourceExhaustedError("VK_ERROR_TOO_MANY_OBJECTS"); - case VK_ERROR_FORMAT_NOT_SUPPORTED: - // A requested format is not supported on this device. - return UnimplementedError("VK_ERROR_FORMAT_NOT_SUPPORTED"); - case VK_ERROR_FRAGMENTED_POOL: - // A pool allocation has failed due to fragmentation of the pool’s memory. - // This must only be returned if no attempt to allocate host or device - // memory was made to accommodate the new allocation. - return ResourceExhaustedError("VK_ERROR_FRAGMENTED_POOL"); - case VK_ERROR_OUT_OF_POOL_MEMORY: - // A pool memory allocation has failed. This must only be returned if no - // attempt to allocate host or device memory was made to accommodate the - // new allocation. If the failure was definitely due to fragmentation of - // the pool, VK_ERROR_FRAGMENTED_POOL should be returned instead. - return ResourceExhaustedError("VK_ERROR_OUT_OF_POOL_MEMORY"); - case VK_ERROR_INVALID_EXTERNAL_HANDLE: - // An external handle is not a valid handle of the specified type. - return InvalidArgumentError("VK_ERROR_INVALID_EXTERNAL_HANDLE"); - case VK_ERROR_SURFACE_LOST_KHR: - // A surface is no longer available. - return UnavailableError("VK_ERROR_SURFACE_LOST_KHR"); - case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR: - // The requested window is already in use by Vulkan or another API in a - // manner which prevents it from being used again. - return InvalidArgumentError("VK_ERROR_NATIVE_WINDOW_IN_USE_KHR"); - case VK_ERROR_OUT_OF_DATE_KHR: - // A surface has changed in such a way that it is no longer compatible - // with the swapchain, and further presentation requests using the - // swapchain will fail. Applications must query the new surface properties - // and recreate their swapchain if they wish to continue presenting to the - // surface. - return FailedPreconditionError("VK_ERROR_OUT_OF_DATE_KHR"); - case VK_ERROR_INCOMPATIBLE_DISPLAY_KHR: - // The display used by a swapchain does not use the same presentable image - // layout, or is incompatible in a way that prevents sharing an image. - return InvalidArgumentError("VK_ERROR_INCOMPATIBLE_DISPLAY_KHR"); - case VK_ERROR_VALIDATION_FAILED_EXT: - // Validation layer testing failed. It is not expected that an - // application would see this this error code during normal use of the - // validation layers. - return InvalidArgumentError("VK_ERROR_VALIDATION_FAILED_EXT"); - case VK_ERROR_INVALID_SHADER_NV: - // One or more shaders failed to compile or link. More details are - // reported back to the application when the validation layer is enabled - // using the extension VK_EXT_debug_report. - return InvalidArgumentError("VK_ERROR_INVALID_SHADER_NV"); - case VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT: - // When creating an image with - // VkImageDrmFormatModifierExplicitCreateInfoEXT, it is the application’s - // responsibility to satisfy all Valid Usage requirements. However, the - // implementation must validate that the provided pPlaneLayouts, when - // combined with the provided drmFormatModifier and other creation - // parameters in VkImageCreateInfo and its pNext chain, produce a valid - // image. (This validation is necessarily implementation-dependent and - // outside the scope of Vulkan, and therefore not described by Valid Usage - // requirements). If this validation fails, then vkCreateImage returns - // VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT. - return InvalidArgumentError( - "VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT"); - case VK_ERROR_FRAGMENTATION_EXT: - // A descriptor pool creation has failed due to fragmentation. - return ResourceExhaustedError("VK_ERROR_FRAGMENTATION_EXT"); - case VK_ERROR_NOT_PERMITTED_EXT: - // When creating a queue, the caller does not have sufficient privileges - // to request to acquire a priority above the default priority - // (VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_EXT). - return PermissionDeniedError("VK_ERROR_NOT_PERMITTED_EXT"); - case VK_ERROR_INVALID_DEVICE_ADDRESS_EXT: - // A buffer creation failed because the requested address is not - // available. - return OutOfRangeError("VK_ERROR_INVALID_DEVICE_ADDRESS_EXT"); - case VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT: - // An operation on a swapchain created with - // VK_FULL_SCREEN_EXCLUSIVE_APPLICATION_CONTROLLED_EXT failed as it did - // not have exlusive full-screen access. This may occur due to - // implementation-dependent reasons, outside of the application’s control. - return UnavailableError("VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT"); - default: - return UnknownError(std::to_string(result)); - } -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/status_util.h b/iree/hal/vulkan/status_util.h deleted file mode 100644 index 85f9def..0000000 --- a/iree/hal/vulkan/status_util.h +++ /dev/null
@@ -1,87 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_STATUS_UTIL_H_ -#define IREE_HAL_VULKAN_STATUS_UTIL_H_ - -#include <vulkan/vulkan.h> - -#include "iree/base/status.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// RETURN_IF_ERROR but implicitly converts the VkResult return value to -// a Status. -// -// Usage: -// VK_RETURN_IF_ERROR(vkDoThing(...)); -#define VK_RETURN_IF_ERROR(expr) \ - RETURN_IF_ERROR(::iree::hal::vulkan::VkResultToStatus(expr)) - -// CHECK_OK but implicitly converts the VkResults return value to a -// Status and checks that it is OkStatus. -// -// Usage: -// VK_CHECK_OK(vkDoThing(...)); -#define VK_CHECK_OK(expr) CHECK_OK(::iree::hal::vulkan::VkResultToStatus(expr)) - -// Converts a VkResult to a Status object. -// -// Vulkan considers the following as "success codes" and users should ensure -// they first check the result prior to converting: -// -// - VK_SUCCESS -> OkStatus() -// - VK_NOT_READY -> OkStatus() -// - VK_TIMEOUT -> OkStatus() -// - VK_EVENT_SET -> OkStatus() -// - VK_EVENT_RESET -> OkStatus() -// - VK_INCOMPLETE -> OkStatus() -// - VK_SUBOPTIMAL_KHR -> OkStatus() -// -// The rest are considered as "error codes": -// -// - VK_ERROR_OUT_OF_HOST_MEMORY -> ResourceExhaustedError("VK...") -// - VK_ERROR_OUT_OF_DEVICE_MEMORY -> ResourceExhaustedError("VK...") -// - VK_ERROR_INITIALIZATION_FAILED -> InternalError("VK...") -// - VK_ERROR_DEVICE_LOST -> InternalError("VK...") -// - VK_ERROR_MEMORY_MAP_FAILED -> InternalError("VK...") -// - VK_ERROR_LAYER_NOT_PRESENT -> NotFoundError("VK...") -// - VK_ERROR_EXTENSION_NOT_PRESENT -> NotFoundError("VK...") -// - VK_ERROR_FEATURE_NOT_PRESENT -> NotFoundError("VK...") -// - VK_ERROR_INCOMPATIBLE_DRIVER -> FailedPreconditionError("VK...") -// - VK_ERROR_TOO_MANY_OBJECTS -> ResourceExhaustedError("VK...") -// - VK_ERROR_FORMAT_NOT_SUPPORTED -> UnimplementedError("VK...") -// - VK_ERROR_FRAGMENTED_POOL -> ResourceExhaustedError("VK...") -// - VK_ERROR_OUT_OF_POOL_MEMORY -> ResourceExhaustedError("VK...") -// - VK_ERROR_INVALID_EXTERNAL_HANDLE -> InvalidArgumentError("VK...") -// - VK_ERROR_SURFACE_LOST_KHR -> InternalError("VK...") -// - VK_ERROR_NATIVE_WINDOW_IN_USE_KHR -> InternalError("VK...") -// - VK_ERROR_OUT_OF_DATE_KHR -> InternalError("VK...") -// - VK_ERROR_INCOMPATIBLE_DISPLAY_KHR -> InternalError("VK...") -// - VK_ERROR_VALIDATION_FAILED_EXT -> InternalError("VK...") -// - VK_ERROR_INVALID_SHADER_NV -> InternalError("VK...") -// - VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT -> InternalError -// - VK_ERROR_FRAGMENTATION_EXT -> ResourceExhaustedError("VK...") -// - VK_ERROR_NOT_PERMITTED_EXT -> PermissionDeniedError("VK...") -// - VK_ERROR_INVALID_DEVICE_ADDRESS_EXT -> OutOfRangeError("VK...") -// - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT -> InternalError("VK...") -Status VkResultToStatus(VkResult result); - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_STATUS_UTIL_H_
diff --git a/iree/hal/vulkan/vma_allocator.cc b/iree/hal/vulkan/vma_allocator.cc deleted file mode 100644 index 181d6e1..0000000 --- a/iree/hal/vulkan/vma_allocator.cc +++ /dev/null
@@ -1,248 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/vma_allocator.h" - -#include "absl/flags/flag.h" -#include "absl/memory/memory.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/buffer.h" -#include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vma_buffer.h" - -#if VMA_RECORDING_ENABLED -ABSL_FLAG(std::string, vma_recording_file, "", - "File path to write a CSV containing the VMA recording."); -ABSL_FLAG(bool, vma_recording_flush_after_call, false, - "Flush the VMA recording file after every call (useful if " - "crashing/not exiting cleanly)."); -#endif // VMA_RECORDING_ENABLED - -namespace iree { -namespace hal { -namespace vulkan { - -// static -StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create( - VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device) { - IREE_TRACE_SCOPE0("VmaAllocator::Create"); - - const auto& syms = logical_device->syms(); - VmaVulkanFunctions vulkan_fns; - vulkan_fns.vkGetPhysicalDeviceProperties = - syms->vkGetPhysicalDeviceProperties; - vulkan_fns.vkGetPhysicalDeviceMemoryProperties = - syms->vkGetPhysicalDeviceMemoryProperties; - vulkan_fns.vkAllocateMemory = syms->vkAllocateMemory; - vulkan_fns.vkFreeMemory = syms->vkFreeMemory; - vulkan_fns.vkMapMemory = syms->vkMapMemory; - vulkan_fns.vkUnmapMemory = syms->vkUnmapMemory; - vulkan_fns.vkFlushMappedMemoryRanges = syms->vkFlushMappedMemoryRanges; - vulkan_fns.vkInvalidateMappedMemoryRanges = - syms->vkInvalidateMappedMemoryRanges; - vulkan_fns.vkBindBufferMemory = syms->vkBindBufferMemory; - vulkan_fns.vkBindImageMemory = syms->vkBindImageMemory; - vulkan_fns.vkGetBufferMemoryRequirements = - syms->vkGetBufferMemoryRequirements; - vulkan_fns.vkGetImageMemoryRequirements = syms->vkGetImageMemoryRequirements; - vulkan_fns.vkCreateBuffer = syms->vkCreateBuffer; - vulkan_fns.vkDestroyBuffer = syms->vkDestroyBuffer; - vulkan_fns.vkCreateImage = syms->vkCreateImage; - vulkan_fns.vkDestroyImage = syms->vkDestroyImage; - vulkan_fns.vkCmdCopyBuffer = syms->vkCmdCopyBuffer; - - VmaRecordSettings record_settings; -#if VMA_RECORDING_ENABLED - record_settings.flags = absl::GetFlag(FLAGS_vma_recording_flush_after_call) - ? VMA_RECORD_FLUSH_AFTER_CALL_BIT - : 0; - record_settings.pFilePath = absl::GetFlag(FLAGS_vma_recording_file).c_str(); -#else - record_settings.flags = 0; - record_settings.pFilePath = nullptr; -#endif // VMA_RECORDING_ENABLED - - VmaAllocatorCreateInfo create_info; - create_info.flags = 0; - create_info.physicalDevice = physical_device; - create_info.device = *logical_device; - create_info.preferredLargeHeapBlockSize = 64 * 1024 * 1024; - create_info.pAllocationCallbacks = logical_device->allocator(); - create_info.pDeviceMemoryCallbacks = nullptr; - create_info.frameInUseCount = 0; - create_info.pHeapSizeLimit = nullptr; - create_info.pVulkanFunctions = &vulkan_fns; - create_info.pRecordSettings = &record_settings; - ::VmaAllocator vma = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(vmaCreateAllocator(&create_info, &vma)); - - auto allocator = - absl::WrapUnique(new VmaAllocator(physical_device, logical_device, vma)); - // TODO(benvanik): query memory properties/types. - return allocator; -} - -VmaAllocator::VmaAllocator(VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, - ::VmaAllocator vma) - : physical_device_(physical_device), - logical_device_(add_ref(logical_device)), - vma_(vma) {} - -VmaAllocator::~VmaAllocator() { - IREE_TRACE_SCOPE0("VmaAllocator::dtor"); - vmaDestroyAllocator(vma_); -} - -bool VmaAllocator::CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const { - // TODO(benvanik): ensure there is a memory type that can satisfy the request. - return source_allocator == this; -} - -bool VmaAllocator::CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const { - // TODO(benvnik): ensure there is a memory type that can satisfy the request. - return true; -} - -Status VmaAllocator::MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const { - // TODO(benvanik): mutate to match supported memory types. - return OkStatus(); -} - -StatusOr<ref_ptr<VmaBuffer>> VmaAllocator::AllocateInternal( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - MemoryAccessBitfield allowed_access, size_t allocation_size, - VmaAllocationCreateFlags flags) { - IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal"); - - VkBufferCreateInfo buffer_create_info; - buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - buffer_create_info.pNext = nullptr; - buffer_create_info.flags = 0; - buffer_create_info.size = allocation_size; - buffer_create_info.usage = 0; - if (AllBitsSet(buffer_usage, BufferUsage::kTransfer)) { - buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT; - buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT; - } - if (AllBitsSet(buffer_usage, BufferUsage::kDispatch)) { - buffer_create_info.usage |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; - buffer_create_info.usage |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - buffer_create_info.usage |= VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT; - } - buffer_create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - buffer_create_info.queueFamilyIndexCount = 0; - buffer_create_info.pQueueFamilyIndices = nullptr; - - VmaAllocationCreateInfo allocation_create_info; - allocation_create_info.flags = flags; - allocation_create_info.usage = VMA_MEMORY_USAGE_UNKNOWN; - allocation_create_info.requiredFlags = 0; - allocation_create_info.preferredFlags = 0; - allocation_create_info.memoryTypeBits = 0; // Automatic selection. - allocation_create_info.pool = VK_NULL_HANDLE; - allocation_create_info.pUserData = nullptr; - if (AllBitsSet(memory_type, MemoryType::kDeviceLocal)) { - if (AllBitsSet(memory_type, MemoryType::kHostVisible)) { - // Device-local, host-visible. - allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_TO_GPU; - allocation_create_info.preferredFlags |= - VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; - } else { - // Device-local only. - allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_ONLY; - allocation_create_info.requiredFlags |= - VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; - } - } else { - if (AllBitsSet(memory_type, MemoryType::kDeviceVisible)) { - // Host-local, device-visible. - allocation_create_info.usage = VMA_MEMORY_USAGE_GPU_TO_CPU; - } else { - // Host-local only. - allocation_create_info.usage = VMA_MEMORY_USAGE_CPU_ONLY; - } - } - if (AllBitsSet(memory_type, MemoryType::kHostCached)) { - allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - } - if (AllBitsSet(memory_type, MemoryType::kHostCoherent)) { - allocation_create_info.requiredFlags |= - VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - } - if (AllBitsSet(memory_type, MemoryType::kTransient)) { - allocation_create_info.preferredFlags |= - VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; - } - if (AllBitsSet(buffer_usage, BufferUsage::kMapping)) { - allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; - } - - VkBuffer buffer = VK_NULL_HANDLE; - VmaAllocation allocation = VK_NULL_HANDLE; - VmaAllocationInfo allocation_info; - VK_RETURN_IF_ERROR(vmaCreateBuffer(vma_, &buffer_create_info, - &allocation_create_info, &buffer, - &allocation, &allocation_info)); - - return make_ref<VmaBuffer>(this, memory_type, allowed_access, buffer_usage, - allocation_size, 0, allocation_size, buffer, - allocation, allocation_info); -} - -StatusOr<ref_ptr<Buffer>> VmaAllocator::Allocate( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("VmaAllocator::Allocate"); - return AllocateInternal(memory_type, buffer_usage, MemoryAccess::kAll, - allocation_size, /*flags=*/0); -} - -StatusOr<ref_ptr<Buffer>> VmaAllocator::AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) { - IREE_TRACE_SCOPE0("VmaAllocator::AllocateConstant"); - // TODO(benvanik): import memory to avoid the copy. - ASSIGN_OR_RETURN( - auto buffer, - AllocateInternal(MemoryType::kDeviceLocal | MemoryType::kHostVisible, - buffer_usage, - MemoryAccess::kRead | MemoryAccess::kDiscardWrite, - source_buffer->byte_length(), - /*flags=*/0)); - RETURN_IF_ERROR(buffer->CopyData(0, source_buffer.get(), 0, kWholeBuffer)); - buffer->set_allowed_access(MemoryAccess::kRead); - return buffer; -} - -StatusOr<ref_ptr<Buffer>> VmaAllocator::WrapMutable( - MemoryTypeBitfield memory_type, MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, void* data, size_t data_length) { - IREE_TRACE_SCOPE0("VmaAllocator::WrapMutable"); - // TODO(benvanik): import memory. - return UnimplementedErrorBuilder(IREE_LOC) - << "Wrapping host memory is not yet implemented"; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/vma_allocator.h b/iree/hal/vulkan/vma_allocator.h deleted file mode 100644 index 2c0054c..0000000 --- a/iree/hal/vulkan/vma_allocator.h +++ /dev/null
@@ -1,110 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ -#define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ - -#include <vulkan/vulkan.h> - -#include <memory> - -#include "iree/base/status.h" -#include "iree/hal/allocator.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/internal_vk_mem_alloc.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VmaBuffer; - -// A HAL allocator using the Vulkan Memory Allocator (VMA) to manage memory. -// VMA (//third_party/vulkan_memory_allocator) provides dlmalloc-like behavior -// with suballocations made with various policies (best fit, first fit, etc). -// This reduces the number of allocations we need from the Vulkan implementation -// (which can sometimes be limited to as little as 4096 total allowed) and -// manages higher level allocation semantics like slab allocation and -// defragmentation. -// -// VMA is internally synchronized and the functionality exposed on the HAL -// interface is thread-safe. -// -// More information: -// https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator -// https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/ -class VmaAllocator final : public Allocator { - public: - static StatusOr<std::unique_ptr<VmaAllocator>> Create( - VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device); - - ~VmaAllocator() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - ::VmaAllocator vma() const { return vma_; } - - bool CanUseBufferLike(Allocator* source_allocator, - MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - BufferUsageBitfield intended_usage) const override; - - bool CanAllocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) const override; - - Status MakeCompatible(MemoryTypeBitfield* memory_type, - BufferUsageBitfield* buffer_usage) const override; - - StatusOr<ref_ptr<Buffer>> Allocate(MemoryTypeBitfield memory_type, - BufferUsageBitfield buffer_usage, - size_t allocation_size) override; - - StatusOr<ref_ptr<Buffer>> AllocateConstant( - BufferUsageBitfield buffer_usage, ref_ptr<Buffer> source_buffer) override; - - StatusOr<ref_ptr<Buffer>> WrapMutable(MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield buffer_usage, - void* data, - size_t data_length) override; - - private: - VmaAllocator(VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, - ::VmaAllocator vma); - - StatusOr<ref_ptr<VmaBuffer>> AllocateInternal( - MemoryTypeBitfield memory_type, BufferUsageBitfield buffer_usage, - MemoryAccessBitfield allowed_access, size_t allocation_size, - VmaAllocationCreateFlags flags); - - VkPhysicalDevice physical_device_; - ref_ptr<VkDeviceHandle> logical_device_; - - // Internally synchronized. We could externally synchronize if we thought it - // was worth it, however I'm not sure we'd be able to do much better with the - // current Allocator API. - ::VmaAllocator vma_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc deleted file mode 100644 index f7d3ed0..0000000 --- a/iree/hal/vulkan/vma_buffer.cc +++ /dev/null
@@ -1,163 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/vma_buffer.h" - -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vma_allocator.h" - -namespace iree { -namespace hal { -namespace vulkan { - -VmaBuffer::VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, - BufferUsageBitfield usage, device_size_t allocation_size, - device_size_t byte_offset, device_size_t byte_length, - VkBuffer buffer, VmaAllocation allocation, - VmaAllocationInfo allocation_info) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, - byte_offset, byte_length), - vma_(allocator->vma()), - buffer_(buffer), - allocation_(allocation), - allocation_info_(allocation_info) { - // TODO(benvanik): set debug name instead and use the - // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag. - vmaSetAllocationUserData(vma_, allocation_, this); -} - -VmaBuffer::~VmaBuffer() { - IREE_TRACE_SCOPE0("VmaBuffer::dtor"); - vmaDestroyBuffer(vma_, buffer_, allocation_); -} - -Status VmaBuffer::FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) { - ASSIGN_OR_RETURN(auto mapping, MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, - byte_offset, byte_length)); - void* data_ptr = static_cast<void*>(mapping.mutable_data()); - switch (pattern_length) { - case 1: { - uint8_t* data = static_cast<uint8_t*>(data_ptr); - uint8_t value_bits = *static_cast<const uint8_t*>(pattern); - std::fill_n(data + byte_offset, byte_length, value_bits); - break; - } - case 2: { - uint16_t* data = static_cast<uint16_t*>(data_ptr); - uint16_t value_bits = *static_cast<const uint16_t*>(pattern); - std::fill_n(data + byte_offset / sizeof(uint16_t), - byte_length / sizeof(uint16_t), value_bits); - break; - } - case 4: { - uint32_t* data = static_cast<uint32_t*>(data_ptr); - uint32_t value_bits = *static_cast<const uint32_t*>(pattern); - std::fill_n(data + byte_offset / sizeof(uint32_t), - byte_length / sizeof(uint32_t), value_bits); - break; - } - default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported scalar data size: " << pattern_length; - } - return OkStatus(); -} - -Status VmaBuffer::ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) { - ASSIGN_OR_RETURN( - auto mapping, - MapMemory<uint8_t>(MemoryAccess::kRead, source_offset, data_length)); - std::memcpy(data, mapping.data(), mapping.byte_length()); - return OkStatus(); -} - -Status VmaBuffer::WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) { - ASSIGN_OR_RETURN(auto mapping, - MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, - target_offset, data_length)); - std::memcpy(mapping.mutable_data(), data, mapping.byte_length()); - return OkStatus(); -} - -Status VmaBuffer::CopyDataImpl(device_size_t target_offset, - Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) { - // This is pretty terrible. Let's not do this. - // TODO(benvanik): a way for allocators to indicate transfer compat. - ASSIGN_OR_RETURN(auto source_mapping, - source_buffer->MapMemory<uint8_t>( - MemoryAccess::kRead, source_offset, data_length)); - CHECK_EQ(data_length, source_mapping.size()); - ASSIGN_OR_RETURN(auto target_mapping, - MapMemory<uint8_t>(MemoryAccess::kDiscardWrite, - target_offset, data_length)); - CHECK_EQ(data_length, target_mapping.size()); - std::memcpy(target_mapping.mutable_data() + target_offset, - source_mapping.data(), data_length); - return OkStatus(); -} - -Status VmaBuffer::MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) { - uint8_t* data_ptr = nullptr; - VK_RETURN_IF_ERROR( - vmaMapMemory(vma_, allocation_, reinterpret_cast<void**>(&data_ptr))); - *out_data = data_ptr + local_byte_offset; - - // If we mapped for discard scribble over the bytes. This is not a mandated - // behavior but it will make debugging issues easier. Alternatively for - // heap buffers we could reallocate them such that ASAN yells, but that - // would only work if the entire buffer was discarded. -#ifndef NDEBUG - if (AnyBitSet(memory_access & MemoryAccess::kDiscard)) { - std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); - } -#endif // !NDEBUG - - return OkStatus(); -} - -Status VmaBuffer::UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) { - vmaUnmapMemory(vma_, allocation_); - return OkStatus(); -} - -Status VmaBuffer::InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - vmaInvalidateAllocation(vma_, allocation_, local_byte_offset, - local_byte_length); - return OkStatus(); -} - -Status VmaBuffer::FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) { - vmaFlushAllocation(vma_, allocation_, local_byte_offset, local_byte_length); - return OkStatus(); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/vma_buffer.h b/iree/hal/vulkan/vma_buffer.h deleted file mode 100644 index b768f71..0000000 --- a/iree/hal/vulkan/vma_buffer.h +++ /dev/null
@@ -1,79 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_ -#define IREE_HAL_VULKAN_VMA_BUFFER_H_ - -#include <vulkan/vulkan.h> - -#include "iree/hal/buffer.h" -#include "vk_mem_alloc.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VmaAllocator; - -// A buffer implementation representing an allocation made from within a pool of -// a Vulkan Memory Allocator instance. See VmaAllocator for more information. -class VmaBuffer final : public Buffer { - public: - VmaBuffer(VmaAllocator* allocator, MemoryTypeBitfield memory_type, - MemoryAccessBitfield allowed_access, BufferUsageBitfield usage, - device_size_t allocation_size, device_size_t byte_offset, - device_size_t byte_length, VkBuffer buffer, - VmaAllocation allocation, VmaAllocationInfo allocation_info); - ~VmaBuffer() override; - - VkBuffer handle() const { return buffer_; } - VmaAllocation allocation() const { return allocation_; } - const VmaAllocationInfo& allocation_info() const { return allocation_info_; } - - // Exposed so that VmaAllocator can reset access after initial mapping. - using Buffer::set_allowed_access; - - private: - Status FillImpl(device_size_t byte_offset, device_size_t byte_length, - const void* pattern, device_size_t pattern_length) override; - Status ReadDataImpl(device_size_t source_offset, void* data, - device_size_t data_length) override; - Status WriteDataImpl(device_size_t target_offset, const void* data, - device_size_t data_length) override; - Status CopyDataImpl(device_size_t target_offset, Buffer* source_buffer, - device_size_t source_offset, - device_size_t data_length) override; - Status MapMemoryImpl(MappingMode mapping_mode, - MemoryAccessBitfield memory_access, - device_size_t local_byte_offset, - device_size_t local_byte_length, - void** out_data) override; - Status UnmapMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length, void* data) override; - Status InvalidateMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - Status FlushMappedMemoryImpl(device_size_t local_byte_offset, - device_size_t local_byte_length) override; - - ::VmaAllocator vma_; - VkBuffer buffer_; - VmaAllocation allocation_; - VmaAllocationInfo allocation_info_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_VMA_BUFFER_H_
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc deleted file mode 100644 index 56f3483..0000000 --- a/iree/hal/vulkan/vulkan_device.cc +++ /dev/null
@@ -1,500 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/vulkan_device.h" - -#include <functional> -#include <utility> - -#include "absl/container/inlined_vector.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/command_buffer_validation.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/fence.h" -#include "iree/hal/vulkan/direct_command_buffer.h" -#include "iree/hal/vulkan/direct_command_queue.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/legacy_fence.h" -#include "iree/hal/vulkan/native_binary_semaphore.h" -#include "iree/hal/vulkan/native_event.h" -#include "iree/hal/vulkan/pipeline_cache.h" -#include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vma_allocator.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -constexpr uint32_t kInvalidQueueFamilyIndex = -1; - -struct QueueFamilyInfo { - uint32_t dispatch_index = kInvalidQueueFamilyIndex; - uint32_t dispatch_queue_count = 0; - uint32_t transfer_index = kInvalidQueueFamilyIndex; - uint32_t transfer_queue_count = 0; -}; - -// Finds the first queue in the listing (which is usually the driver-preferred) -// that has all of the |required_queue_flags| and none of the -// |excluded_queue_flags|. -// Returns kInvalidQueueFamilyIndex if no matching queue is found. -uint32_t FindFirstQueueFamilyWithFlags( - absl::Span<const VkQueueFamilyProperties> queue_family_properties, - uint32_t required_queue_flags, uint32_t excluded_queue_flags) { - for (int queue_family_index = 0; - queue_family_index < queue_family_properties.size(); - ++queue_family_index) { - const auto& properties = queue_family_properties[queue_family_index]; - if ((properties.queueFlags & required_queue_flags) == - required_queue_flags && - (properties.queueFlags & excluded_queue_flags) == 0) { - return queue_family_index; - } - } - return kInvalidQueueFamilyIndex; -} - -// Selects queue family indices for compute and transfer queues. -// Note that both queue families may be the same if there is only one family -// available. -StatusOr<QueueFamilyInfo> SelectQueueFamilies( - VkPhysicalDevice physical_device, const ref_ptr<DynamicSymbols>& syms) { - // Enumerate queue families available on the device. - uint32_t queue_family_count = 0; - syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device, - &queue_family_count, nullptr); - absl::InlinedVector<VkQueueFamilyProperties, 4> queue_family_properties( - queue_family_count); - syms->vkGetPhysicalDeviceQueueFamilyProperties( - physical_device, &queue_family_count, queue_family_properties.data()); - - QueueFamilyInfo queue_family_info; - - // Try to find a dedicated compute queue (no graphics caps). - // Some may support both transfer and compute. If that fails then fallback to - // any queue that supports compute. - queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_COMPUTE_BIT, VK_QUEUE_GRAPHICS_BIT); - if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { - queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_COMPUTE_BIT, 0); - } - if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { - return NotFoundErrorBuilder(IREE_LOC) - << "Unable to find any queue family support compute operations"; - } - queue_family_info.dispatch_queue_count = - queue_family_properties[queue_family_info.dispatch_index].queueCount; - - // Try to find a dedicated transfer queue (no compute or graphics caps). - // Not all devices have one, and some have only a queue family for everything - // and possibly a queue family just for compute/etc. If that fails then - // fallback to any queue that supports transfer. Finally, if /that/ fails then - // we just won't create a transfer queue and instead use the compute queue for - // all operations. - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, - VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT); - if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, VK_QUEUE_GRAPHICS_BIT); - } - if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, 0); - } - if (queue_family_info.transfer_index != kInvalidQueueFamilyIndex) { - queue_family_info.transfer_queue_count = - queue_family_properties[queue_family_info.transfer_index].queueCount; - } - - // Ensure that we don't share the dispatch queues with transfer queues if that - // would put us over the queue count. - if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - queue_family_info.transfer_queue_count = std::min( - queue_family_properties[queue_family_info.dispatch_index].queueCount - - queue_family_info.dispatch_queue_count, - queue_family_info.transfer_queue_count); - } - - return queue_family_info; -} - -// Creates a transient command pool for the given queue family. -// Command buffers allocated from the pool must only be issued on queues -// belonging to the specified family. -StatusOr<ref_ptr<VkCommandPoolHandle>> CreateTransientCommandPool( - const ref_ptr<VkDeviceHandle>& logical_device, - uint32_t queue_family_index) { - VkCommandPoolCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT; - create_info.queueFamilyIndex = queue_family_index; - - auto command_pool = make_ref<VkCommandPoolHandle>(logical_device); - VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateCommandPool( - *logical_device, &create_info, logical_device->allocator(), - command_pool->mutable_value())); - return command_pool; -} - -} // namespace - -// static -StatusOr<std::shared_ptr<VulkanDevice>> VulkanDevice::Create( - const DeviceInfo& device_info, VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, - const ref_ptr<DynamicSymbols>& syms) { - IREE_TRACE_SCOPE0("VulkanDevice::Create"); - - // Find the layers and extensions we need (or want) that are also available - // on the device. This will fail when required ones are not present. - ASSIGN_OR_RETURN( - auto enabled_layer_names, - MatchAvailableDeviceLayers(physical_device, extensibility_spec, *syms)); - ASSIGN_OR_RETURN(auto enabled_extension_names, - MatchAvailableDeviceExtensions(physical_device, - extensibility_spec, *syms)); - auto enabled_device_extensions = - PopulateEnabledDeviceExtensions(enabled_extension_names); - - // Find queue families we will expose as HAL queues. - ASSIGN_OR_RETURN(auto queue_family_info, - SelectQueueFamilies(physical_device, syms)); - - // Limit the number of queues we create (for now). - // We may want to allow this to grow, but each queue adds overhead and we need - // to measure to make sure we can effectively use them all. - queue_family_info.dispatch_queue_count = - std::min(2u, queue_family_info.dispatch_queue_count); - queue_family_info.transfer_queue_count = - std::min(1u, queue_family_info.transfer_queue_count); - bool has_dedicated_transfer_queues = - queue_family_info.transfer_queue_count > 0; - - // Setup the queue info we'll be using. - // Each queue here (created from within a family) will map to a HAL queue. - // - // Note that we need to handle the case where we have transfer queues that are - // of the same queue family as the dispatch queues: Vulkan requires that all - // queues created from the same family are done in the same - // VkDeviceQueueCreateInfo struct. - DVLOG(1) << "Creating " << queue_family_info.dispatch_queue_count - << " dispatch queue(s) in queue family " - << queue_family_info.dispatch_index; - absl::InlinedVector<VkDeviceQueueCreateInfo, 2> queue_create_info; - absl::InlinedVector<float, 4> dispatch_queue_priorities; - absl::InlinedVector<float, 4> transfer_queue_priorities; - queue_create_info.push_back({}); - auto& dispatch_queue_info = queue_create_info.back(); - dispatch_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - dispatch_queue_info.pNext = nullptr; - dispatch_queue_info.flags = 0; - dispatch_queue_info.queueFamilyIndex = queue_family_info.dispatch_index; - dispatch_queue_info.queueCount = queue_family_info.dispatch_queue_count; - if (has_dedicated_transfer_queues) { - if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count - << " dedicated transfer queue(s) in shared queue family " - << queue_family_info.transfer_index; - dispatch_queue_info.queueCount += queue_family_info.transfer_queue_count; - } else { - DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count - << " dedicated transfer queue(s) in independent queue family " - << queue_family_info.transfer_index; - queue_create_info.push_back({}); - auto& transfer_queue_info = queue_create_info.back(); - transfer_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - transfer_queue_info.pNext = nullptr; - transfer_queue_info.queueFamilyIndex = queue_family_info.transfer_index; - transfer_queue_info.queueCount = queue_family_info.transfer_queue_count; - transfer_queue_info.flags = 0; - transfer_queue_priorities.resize(transfer_queue_info.queueCount); - transfer_queue_info.pQueuePriorities = transfer_queue_priorities.data(); - } - } - dispatch_queue_priorities.resize(dispatch_queue_info.queueCount); - dispatch_queue_info.pQueuePriorities = dispatch_queue_priorities.data(); - - // TODO(benvanik): specify features with VkPhysicalDeviceFeatures. - - // Create device and its queues. - VkDeviceCreateInfo device_create_info = {}; - device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; - device_create_info.pNext = nullptr; - device_create_info.enabledLayerCount = enabled_layer_names.size(); - device_create_info.ppEnabledLayerNames = enabled_layer_names.data(); - device_create_info.enabledExtensionCount = enabled_extension_names.size(); - device_create_info.ppEnabledExtensionNames = enabled_extension_names.data(); - device_create_info.queueCreateInfoCount = queue_create_info.size(); - device_create_info.pQueueCreateInfos = queue_create_info.data(); - device_create_info.pEnabledFeatures = nullptr; - auto logical_device = make_ref<VkDeviceHandle>( - syms, enabled_device_extensions, /*allocator=*/nullptr); - VK_RETURN_IF_ERROR(syms->vkCreateDevice(physical_device, &device_create_info, - logical_device->allocator(), - logical_device->mutable_value())); - - // Create the device memory allocator. - // TODO(benvanik): allow other types to be plugged in. - ASSIGN_OR_RETURN(auto allocator, - VmaAllocator::Create(physical_device, logical_device)); - - // Create command pools for each queue family. If we don't have a transfer - // queue then we'll ignore that one and just use the dispatch pool. - // If we wanted to expose the pools through the HAL to allow the VM to more - // effectively manage them (pool per fiber, etc) we could, however I doubt the - // overhead of locking the pool will be even a blip. - ASSIGN_OR_RETURN(auto dispatch_command_pool, - CreateTransientCommandPool( - logical_device, queue_family_info.dispatch_index)); - ref_ptr<VkCommandPoolHandle> transfer_command_pool; - if (has_dedicated_transfer_queues) { - ASSIGN_OR_RETURN(transfer_command_pool, - CreateTransientCommandPool( - logical_device, queue_family_info.transfer_index)); - } - - // Get the queues and create the HAL wrappers. - absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues; - for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) { - VkQueue queue = VK_NULL_HANDLE; - syms->vkGetDeviceQueue(*logical_device, queue_family_info.dispatch_index, i, - &queue); - std::string queue_name = absl::StrCat(device_info.name(), ":d", i); - command_queues.push_back(absl::make_unique<DirectCommandQueue>( - std::move(queue_name), - CommandCategory::kDispatch | CommandCategory::kTransfer, logical_device, - queue)); - } - if (has_dedicated_transfer_queues) { - uint32_t base_queue_index = 0; - if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - // Sharing a family, so transfer queues follow compute queues. - base_queue_index = queue_family_info.dispatch_index; - } - for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) { - VkQueue queue = VK_NULL_HANDLE; - syms->vkGetDeviceQueue(*logical_device, queue_family_info.transfer_index, - base_queue_index + i, &queue); - std::string queue_name = absl::StrCat(device_info.name(), ":t", i); - command_queues.push_back(absl::make_unique<DirectCommandQueue>( - std::move(queue_name), CommandCategory::kTransfer, logical_device, - queue)); - } - } - - // TODO(b/140141417): implement timeline semaphore fences and switch here. - ASSIGN_OR_RETURN(auto legacy_fence_pool, - LegacyFencePool::Create(add_ref(logical_device))); - - return std::make_shared<VulkanDevice>( - CtorKey{}, device_info, physical_device, std::move(logical_device), - std::move(allocator), std::move(command_queues), - std::move(dispatch_command_pool), std::move(transfer_command_pool), - std::move(legacy_fence_pool)); -} - -VulkanDevice::VulkanDevice( - CtorKey ctor_key, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device, - std::unique_ptr<Allocator> allocator, - absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues, - ref_ptr<VkCommandPoolHandle> dispatch_command_pool, - ref_ptr<VkCommandPoolHandle> transfer_command_pool, - ref_ptr<LegacyFencePool> legacy_fence_pool) - : Device(device_info), - physical_device_(physical_device), - logical_device_(std::move(logical_device)), - allocator_(std::move(allocator)), - command_queues_(std::move(command_queues)), - descriptor_pool_cache_( - make_ref<DescriptorPoolCache>(add_ref(logical_device_))), - dispatch_command_pool_(std::move(dispatch_command_pool)), - transfer_command_pool_(std::move(transfer_command_pool)), - legacy_fence_pool_(std::move(legacy_fence_pool)) { - // Populate the queue lists based on queue capabilities. - for (auto& command_queue : command_queues_) { - if (command_queue->can_dispatch()) { - dispatch_queues_.push_back(command_queue.get()); - if (transfer_command_pool_ == VK_NULL_HANDLE) { - transfer_queues_.push_back(command_queue.get()); - } - } else { - transfer_queues_.push_back(command_queue.get()); - } - } -} - -VulkanDevice::~VulkanDevice() { - IREE_TRACE_SCOPE0("VulkanDevice::dtor"); - - // Drop all command queues. These may wait until idle. - command_queues_.clear(); - dispatch_queues_.clear(); - transfer_queues_.clear(); - - // Drop command pools now that we know there are no more outstanding command - // buffers. - dispatch_command_pool_.reset(); - transfer_command_pool_.reset(); - - // Now that no commands are outstanding we can release all descriptor sets. - descriptor_pool_cache_.reset(); - - // Finally, destroy the device. - logical_device_.reset(); -} - -std::shared_ptr<ExecutableCache> VulkanDevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableCache"); - return std::make_shared<PipelineCache>(logical_device_); -} - -StatusOr<ref_ptr<CommandBuffer>> VulkanDevice::CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateCommandBuffer"); - - // Select the command pool to used based on the types of commands used. - // Note that we may not have a dedicated transfer command pool if there are no - // dedicated transfer queues. - ref_ptr<VkCommandPoolHandle> command_pool; - if (transfer_command_pool_ && - !AllBitsSet(command_categories, CommandCategory::kDispatch)) { - command_pool = add_ref(transfer_command_pool_); - } else { - command_pool = add_ref(dispatch_command_pool_); - } - - VkCommandBufferAllocateInfo allocate_info; - allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - allocate_info.pNext = nullptr; - allocate_info.commandPool = *command_pool; - allocate_info.commandBufferCount = 1; - allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; - - VkCommandBuffer command_buffer = VK_NULL_HANDLE; - { - absl::MutexLock lock(command_pool->mutex()); - VK_RETURN_IF_ERROR(syms()->vkAllocateCommandBuffers( - *logical_device_, &allocate_info, &command_buffer)); - } - - // TODO(b/140026716): conditionally enable validation. - auto impl = make_ref<DirectCommandBuffer>( - allocator(), mode, command_categories, add_ref(descriptor_pool_cache_), - add_ref(command_pool), command_buffer); - return WrapCommandBufferWithValidation(std::move(impl)); -} - -StatusOr<ref_ptr<Event>> VulkanDevice::CreateEvent() { - IREE_TRACE_SCOPE0("VulkanDevice::CreateEvent"); - - // TODO(b/138729892): pool events. - VkEventCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - VkEvent event_handle = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateEvent(*logical_device_, &create_info, - logical_device_->allocator(), - &event_handle)); - - return make_ref<NativeEvent>(add_ref(logical_device_), event_handle); -} - -StatusOr<ref_ptr<BinarySemaphore>> VulkanDevice::CreateBinarySemaphore( - bool initial_value) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateBinarySemaphore"); - - VkSemaphoreCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = initial_value ? VK_FENCE_CREATE_SIGNALED_BIT : 0; - VkSemaphore semaphore_handle = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info, - logical_device_->allocator(), - &semaphore_handle)); - - return make_ref<NativeBinarySemaphore>(add_ref(logical_device_), - semaphore_handle); -} - -StatusOr<ref_ptr<TimelineSemaphore>> VulkanDevice::CreateTimelineSemaphore( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateTimelineSemaphore"); - - // TODO(b/140141417): implement timeline semaphores. - return UnimplementedErrorBuilder(IREE_LOC) - << "Timeline semaphores not yet implemented"; -} - -StatusOr<ref_ptr<Fence>> VulkanDevice::CreateFence(uint64_t initial_value) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateFence"); - - // TODO(b/140141417): implement timeline semaphore fences and switch here. - // NOTE: we'll want some magic factory so that we can cleanly compile out the - // legacy implementation and pool. - - return make_ref<LegacyFence>(add_ref(legacy_fence_pool_), initial_value); -} - -Status VulkanDevice::WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitAllFences"); - - // TODO(b/140141417): implement timeline semaphore fences and switch here. - - return LegacyFence::WaitForFences(logical_device_.get(), fences, - /*wait_all=*/true, deadline); -} - -StatusOr<int> VulkanDevice::WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Time deadline) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitAnyFence"); - - // TODO(b/140141417): implement timeline semaphore fences and switch here. - - return LegacyFence::WaitForFences(logical_device_.get(), fences, - /*wait_all=*/false, deadline); -} - -Status VulkanDevice::WaitIdle(absl::Time deadline) { - if (deadline == absl::InfiniteFuture()) { - // Fast path for using vkDeviceWaitIdle, which is usually cheaper (as it - // requires fewer calls into the driver). - IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#vkDeviceWaitIdle"); - VK_RETURN_IF_ERROR(syms()->vkDeviceWaitIdle(*logical_device_)); - return OkStatus(); - } - - IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#Fences"); - for (auto& command_queue : command_queues_) { - RETURN_IF_ERROR(command_queue->WaitIdle(deadline)); - } - return OkStatus(); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h deleted file mode 100644 index f4557c1..0000000 --- a/iree/hal/vulkan/vulkan_device.h +++ /dev/null
@@ -1,120 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_ -#define IREE_HAL_VULKAN_VULKAN_DEVICE_H_ - -#include <vulkan/vulkan.h> - -#include <functional> -#include <memory> - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/allocator.h" -#include "iree/hal/device.h" -#include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/legacy_fence.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VulkanDevice final : public Device { - public: - static StatusOr<std::shared_ptr<VulkanDevice>> Create( - const DeviceInfo& device_info, VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, - const ref_ptr<DynamicSymbols>& syms); - - // Private constructor. - struct CtorKey { - private: - friend class VulkanDevice; - CtorKey() = default; - }; - VulkanDevice( - CtorKey ctor_key, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device, - std::unique_ptr<Allocator> allocator, - absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues, - ref_ptr<VkCommandPoolHandle> dispatch_command_pool, - ref_ptr<VkCommandPoolHandle> transfer_command_pool, - ref_ptr<LegacyFencePool> legacy_fence_pool); - ~VulkanDevice() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - Allocator* allocator() const override { return allocator_.get(); } - - absl::Span<CommandQueue*> dispatch_queues() const override { - return absl::MakeSpan(dispatch_queues_); - } - - absl::Span<CommandQueue*> transfer_queues() const override { - return absl::MakeSpan(transfer_queues_); - } - - std::shared_ptr<ExecutableCache> CreateExecutableCache() override; - - StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - CommandBufferModeBitfield mode, - CommandCategoryBitfield command_categories) override; - - StatusOr<ref_ptr<Event>> CreateEvent() override; - - StatusOr<ref_ptr<BinarySemaphore>> CreateBinarySemaphore( - bool initial_value) override; - StatusOr<ref_ptr<TimelineSemaphore>> CreateTimelineSemaphore( - uint64_t initial_value) override; - - StatusOr<ref_ptr<Fence>> CreateFence(uint64_t initial_value) override; - Status WaitAllFences(absl::Span<const FenceValue> fences, - absl::Time deadline) override; - StatusOr<int> WaitAnyFence(absl::Span<const FenceValue> fences, - absl::Time deadline) override; - - Status WaitIdle(absl::Time deadline) override; - - private: - VkPhysicalDevice physical_device_; - ref_ptr<VkDeviceHandle> logical_device_; - - std::unique_ptr<Allocator> allocator_; - - mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues_; - mutable absl::InlinedVector<CommandQueue*, 4> dispatch_queues_; - mutable absl::InlinedVector<CommandQueue*, 4> transfer_queues_; - - ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; - - ref_ptr<VkCommandPoolHandle> dispatch_command_pool_; - ref_ptr<VkCommandPoolHandle> transfer_command_pool_; - - // TODO(b/140141417): implement timeline semaphore fences and conditionally - // compile the legacy fence pool out. - ref_ptr<LegacyFencePool> legacy_fence_pool_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_VULKAN_DEVICE_H_
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc deleted file mode 100644 index a5238d0..0000000 --- a/iree/hal/vulkan/vulkan_driver.cc +++ /dev/null
@@ -1,229 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/vulkan_driver.h" - -#include <memory> - -#include "absl/container/inlined_vector.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/device_info.h" -#include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vulkan_device.h" - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -// Returns a VkApplicationInfo struct populated with the default app info. -// We may allow hosting applications to override this via weak-linkage if it's -// useful, otherwise this is enough to create the application. -VkApplicationInfo GetDefaultApplicationInfo() { - VkApplicationInfo info; - info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - info.pNext = nullptr; - info.pApplicationName = "IREE-ML"; - info.applicationVersion = 0; - info.pEngineName = "IREE"; - info.engineVersion = 0; - info.apiVersion = VK_API_VERSION_1_0; - return info; -} - -// Populates device information from the given Vulkan physical device handle. -StatusOr<DeviceInfo> PopulateDeviceInfo(VkPhysicalDevice physical_device, - const ref_ptr<DynamicSymbols>& syms) { - VkPhysicalDeviceFeatures physical_device_features; - syms->vkGetPhysicalDeviceFeatures(physical_device, &physical_device_features); - // TODO(benvanik): check and optionally require these features: - // - physical_device_features.robustBufferAccess - // - physical_device_features.shaderInt16 - // - physical_device_features.shaderInt64 - // - physical_device_features.shaderFloat64 - - VkPhysicalDeviceProperties physical_device_properties; - syms->vkGetPhysicalDeviceProperties(physical_device, - &physical_device_properties); - // TODO(benvanik): check and optionally require reasonable limits. - - // TODO(benvanik): more clever/sanitized device naming. - std::string name = std::string(physical_device_properties.deviceName); - - DeviceFeatureBitfield supported_features = DeviceFeature::kNone; - // TODO(benvanik): implement debugging/profiling features. - // TODO(benvanik): use props to determine if we have timing info. - // supported_features |= DeviceFeature::kDebugging; - // supported_features |= DeviceFeature::kCoverage; - // supported_features |= DeviceFeature::kProfiling; - return DeviceInfo(std::move(name), supported_features, physical_device); -} - -} // namespace - -// static -StatusOr<std::shared_ptr<VulkanDriver>> VulkanDriver::Create( - Options options, ref_ptr<DynamicSymbols> syms) { - IREE_TRACE_SCOPE0("VulkanDriver::Create"); - - // Find the layers and extensions we need (or want) that are also available - // on the instance. This will fail when required ones are not present. - ASSIGN_OR_RETURN( - auto enabled_layer_names, - MatchAvailableInstanceLayers(options.instance_extensibility, *syms)); - ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableInstanceExtensions(options.instance_extensibility, *syms)); - auto instance_extensions = - PopulateEnabledInstanceExtensions(enabled_extension_names); - - // Create the instance this driver will use for all requests. - VkApplicationInfo app_info = GetDefaultApplicationInfo(); - app_info.apiVersion = options.api_version; - VkInstanceCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.pApplicationInfo = &app_info; - create_info.enabledLayerCount = enabled_layer_names.size(); - create_info.ppEnabledLayerNames = enabled_layer_names.data(); - create_info.enabledExtensionCount = enabled_extension_names.size(); - create_info.ppEnabledExtensionNames = enabled_extension_names.data(); - - // If we have the debug_utils extension then we can chain a one-shot messenger - // callback that we can use to log out the instance creation errors. Once we - // have the real instance we can then register a real messenger. - union { - VkDebugUtilsMessengerCreateInfoEXT debug_utils_create_info; - VkDebugReportCallbackCreateInfoEXT debug_report_create_info; - }; - if (instance_extensions.debug_utils) { - create_info.pNext = &debug_utils_create_info; - DebugReporter::PopulateStaticCreateInfo(&debug_utils_create_info); - } else if (instance_extensions.debug_report) { - create_info.pNext = &debug_report_create_info; - DebugReporter::PopulateStaticCreateInfo(&debug_report_create_info); - } - - // Some ICDs appear to leak in here, out of our control. - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VkInstance instance = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR( - syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)) - << "Unable to create Vulkan instance"; - IREE_ENABLE_LEAK_CHECKS(); - - // TODO(benvanik): enable validation layers if needed. - - // Now that the instance has been created we can fetch all of the instance - // symbols. - RETURN_IF_ERROR(syms->LoadFromInstance(instance)); - - // The real debug messenger (not just the static one used above) can now be - // created as we've loaded all the required symbols. - // TODO(benvanik): strip in release builds. - std::unique_ptr<DebugReporter> debug_reporter; - if (instance_extensions.debug_utils) { - ASSIGN_OR_RETURN(debug_reporter, DebugReporter::CreateDebugUtilsMessenger( - instance, syms, - /*allocation_callbacks=*/nullptr)); - } else if (instance_extensions.debug_report) { - ASSIGN_OR_RETURN(debug_reporter, - DebugReporter::CreateDebugReportCallback( - instance, syms, /*allocation_callbacks=*/nullptr)); - } - - return std::make_shared<VulkanDriver>( - CtorKey{}, std::move(syms), instance, std::move(debug_reporter), - std::move(options.device_extensibility)); -} - -VulkanDriver::VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms, - VkInstance instance, - std::unique_ptr<DebugReporter> debug_reporter, - ExtensibilitySpec device_extensibility_spec) - : Driver("vulkan"), - syms_(std::move(syms)), - instance_(instance), - debug_reporter_(std::move(debug_reporter)), - device_extensibility_spec_(std::move(device_extensibility_spec)) {} - -VulkanDriver::~VulkanDriver() { - IREE_TRACE_SCOPE0("VulkanDriver::dtor"); - debug_reporter_.reset(); - syms()->vkDestroyInstance(instance_, /*pAllocator=*/nullptr); -} - -StatusOr<std::vector<DeviceInfo>> VulkanDriver::EnumerateAvailableDevices() { - IREE_TRACE_SCOPE0("VulkanDriver::EnumerateAvailableDevices"); - - // Query all available devices (at this moment, note that this may change!). - uint32_t physical_device_count = 0; - VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( - instance_, &physical_device_count, nullptr)); - absl::InlinedVector<VkPhysicalDevice, 2> physical_devices( - physical_device_count); - VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( - instance_, &physical_device_count, physical_devices.data())); - - // Convert to our HAL structure. - std::vector<DeviceInfo> device_infos; - device_infos.reserve(physical_device_count); - for (auto physical_device : physical_devices) { - // TODO(benvanik): if we fail should we just ignore the device in the list? - ASSIGN_OR_RETURN(auto device_info, - PopulateDeviceInfo(physical_device, syms())); - device_infos.push_back(std::move(device_info)); - } - return device_infos; -} - -StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDefaultDevice() { - IREE_TRACE_SCOPE0("VulkanDriver::CreateDefaultDevice"); - - // Query available devices. - ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices()); - if (available_devices.empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No devices are available"; - } - - // Just create the first one we find. - return CreateDevice(available_devices.front()); -} - -StatusOr<std::shared_ptr<Device>> VulkanDriver::CreateDevice( - const DeviceInfo& device_info) { - IREE_TRACE_SCOPE0("VulkanDriver::CreateDevice"); - - auto physical_device = - static_cast<VkPhysicalDevice>(device_info.driver_handle()); - - // Attempt to create the device. - // This may fail if the device was enumerated but is in exclusive use, - // disabled by the system, or permission is denied. - ASSIGN_OR_RETURN(auto device, - VulkanDevice::Create(device_info, physical_device, - device_extensibility_spec_, syms())); - - return device; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/vulkan_driver.h b/iree/hal/vulkan/vulkan_driver.h deleted file mode 100644 index d2383db..0000000 --- a/iree/hal/vulkan/vulkan_driver.h +++ /dev/null
@@ -1,84 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_ -#define IREE_HAL_VULKAN_VULKAN_DRIVER_H_ - -#include <vulkan/vulkan.h> - -#include <memory> -#include <vector> - -#include "iree/hal/driver.h" -#include "iree/hal/vulkan/debug_reporter.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/extensibility_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VulkanDriver final : public Driver { - public: - struct Options { - // Vulkan version that will be requested. - // Driver creation will fail if the required version is not available. - uint32_t api_version = VK_API_VERSION_1_0; - - // Extensibility descriptions for instances and devices. - // Device descriptions will be used for all devices created by the driver. - ExtensibilitySpec instance_extensibility; - ExtensibilitySpec device_extensibility; - }; - - static StatusOr<std::shared_ptr<VulkanDriver>> Create( - Options options, ref_ptr<DynamicSymbols> syms); - - // TODO(benvanik): method to wrap an existing instance/device (interop). - - // Private constructor. - struct CtorKey { - private: - friend class VulkanDriver; - CtorKey() = default; - }; - VulkanDriver(CtorKey ctor_key, ref_ptr<DynamicSymbols> syms, - VkInstance instance, - std::unique_ptr<DebugReporter> debug_reporter, - ExtensibilitySpec device_extensibility_spec); - ~VulkanDriver() override; - - const ref_ptr<DynamicSymbols>& syms() const { return syms_; } - - VkInstance instance() const { return instance_; } - - StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; - - StatusOr<std::shared_ptr<Device>> CreateDefaultDevice() override; - - StatusOr<std::shared_ptr<Device>> CreateDevice( - const DeviceInfo& device_info) override; - - private: - ref_ptr<DynamicSymbols> syms_; - VkInstance instance_; - std::unique_ptr<DebugReporter> debug_reporter_; - ExtensibilitySpec device_extensibility_spec_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_VULKAN_DRIVER_H_
diff --git a/iree/hal/vulkan/vulkan_driver_module.cc b/iree/hal/vulkan/vulkan_driver_module.cc deleted file mode 100644 index 4d3d745..0000000 --- a/iree/hal/vulkan/vulkan_driver_module.cc +++ /dev/null
@@ -1,102 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <memory> - -#include "absl/flags/flag.h" -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/driver_registry.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/vulkan_driver.h" - -ABSL_FLAG(bool, vulkan_validation_layers, true, - "Enables standard Vulkan validation layers."); -ABSL_FLAG(bool, vulkan_debug_utils, true, - "Enables VK_EXT_debug_utils, records markers, and logs errors."); -ABSL_FLAG(bool, vulkan_debug_report, false, - "Enables VK_EXT_debug_report and logs errors."); -ABSL_FLAG(bool, vulkan_push_descriptors, true, - "Enables use of vkCmdPushDescriptorSetKHR, if available."); - -namespace iree { -namespace hal { -namespace vulkan { -namespace { - -StatusOr<std::shared_ptr<Driver>> CreateVulkanDriver() { - IREE_TRACE_SCOPE0("CreateVulkanDriver"); - - // Load the Vulkan library. This will fail if the library cannot be found or - // does not have the expected functions. - ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader()); - - // Setup driver options from flags. We do this here as we want to enable other - // consumers that may not be using modules/command line flags to be able to - // set their options however they want. - VulkanDriver::Options options; - - // TODO: validation layers have bugs when using VK_EXT_debug_report, so if the - // user requested that we force them off with a warning. Prefer using - // VK_EXT_debug_utils when available. - if (absl::GetFlag(FLAGS_vulkan_debug_report) && - absl::GetFlag(FLAGS_vulkan_validation_layers)) { - LOG(WARNING) << "VK_EXT_debug_report has issues with modern validation " - "layers; disabling validation"; - absl::SetFlag(&FLAGS_vulkan_validation_layers, false); - } - - // REQUIRED: these are required extensions that must be present for IREE to - // work (such as those relied upon by SPIR-V kernels, etc). - options.device_extensibility.required_extensions.push_back( - VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); - - if (absl::GetFlag(FLAGS_vulkan_validation_layers)) { - options.instance_extensibility.optional_layers.push_back( - "VK_LAYER_LUNARG_standard_validation"); - } - - if (absl::GetFlag(FLAGS_vulkan_debug_report)) { - options.instance_extensibility.optional_extensions.push_back( - VK_EXT_DEBUG_REPORT_EXTENSION_NAME); - } - if (absl::GetFlag(FLAGS_vulkan_debug_utils)) { - options.instance_extensibility.optional_extensions.push_back( - VK_EXT_DEBUG_UTILS_EXTENSION_NAME); - } - - if (absl::GetFlag(FLAGS_vulkan_push_descriptors)) { - options.instance_extensibility.optional_extensions.push_back( - VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); - options.device_extensibility.optional_extensions.push_back( - VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); - } - - // Create the driver and VkInstance. - ASSIGN_OR_RETURN(auto driver, VulkanDriver::Create(options, std::move(syms))); - - return driver; -} - -} // namespace -} // namespace vulkan -} // namespace hal -} // namespace iree - -IREE_REGISTER_MODULE_INITIALIZER(iree_hal_vulkan_driver, { - QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register( - "vulkan", ::iree::hal::vulkan::CreateVulkanDriver)); -}); -IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_vulkan_driver);
diff --git a/iree/rt/BUILD b/iree/rt/BUILD deleted file mode 100644 index 33aea24..0000000 --- a/iree/rt/BUILD +++ /dev/null
@@ -1,83 +0,0 @@ -# Runtime API for interacting with IREE modules and invoking functions. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "api", - srcs = ["api.cc"], - hdrs = ["api.h"], - visibility = ["//visibility:public"], - deps = [ - ":api_hdrs", - ":rt", - "//iree/base:api", - "//iree/base:api_util", - "//iree/base:tracing", - "//iree/hal:api", - "//iree/hal:buffer_view", - "//iree/hal:driver_registry", - "//iree/rt/debug:debug_server_interface", - "@com_google_absl//absl/time", - ], -) - -cc_library( - name = "api_hdrs", - hdrs = ["api.h"], - deps = [ - "//iree/base:api_hdrs", - "//iree/hal:api_hdrs", - ], -) - -cc_library( - name = "rt", - srcs = [ - "context.cc", - "function.cc", - "instance.cc", - "invocation.cc", - "module_printer.cc", - "source_location.cc", - "stack.cc", - "stack_frame.cc", - "stack_trace.cc", - ], - hdrs = [ - "context.h", - "disassembler.h", - "function.h", - "function_signature.h", - "instance.h", - "invocation.h", - "module.h", - "module_printer.h", - "module_signature.h", - "policy.h", - "source_location.h", - "source_resolver.h", - "stack.h", - "stack_frame.h", - "stack_trace.h", - ], - deps = [ - "//iree/base:bitfield", - "//iree/base:intrusive_list", - "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:buffer_view", - "//iree/hal:device_manager", - "//iree/rt/debug:debug_server_interface", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -)
diff --git a/iree/rt/api.cc b/iree/rt/api.cc deleted file mode 100644 index 04cbca4..0000000 --- a/iree/rt/api.cc +++ /dev/null
@@ -1,788 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/api.h" - -#include "absl/time/time.h" -#include "iree/base/api.h" -#include "iree/base/api_util.h" -#include "iree/base/tracing.h" -#include "iree/hal/api.h" -#include "iree/hal/api_detail.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/driver_registry.h" -#include "iree/rt/context.h" -#include "iree/rt/debug/debug_server.h" -#include "iree/rt/function.h" -#include "iree/rt/instance.h" -#include "iree/rt/invocation.h" -#include "iree/rt/module.h" -#include "iree/rt/policy.h" - -namespace iree { -namespace rt { - -//===----------------------------------------------------------------------===// -// iree::rt::Instance -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create( - iree_allocator_t allocator, iree_rt_instance_t** out_instance) { - IREE_TRACE_SCOPE0("iree_rt_instance_create"); - - if (!out_instance) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_instance = nullptr; - - auto instance = make_ref<Instance>(); - *out_instance = reinterpret_cast<iree_rt_instance_t*>(instance.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_instance_retain(iree_rt_instance_t* instance) { - IREE_TRACE_SCOPE0("iree_rt_instance_retain"); - auto* handle = reinterpret_cast<Instance*>(instance); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_instance_release(iree_rt_instance_t* instance) { - IREE_TRACE_SCOPE0("iree_rt_instance_release"); - auto* handle = reinterpret_cast<Instance*>(instance); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex( - iree_rt_instance_t* instance, iree_string_view_t driver_name) { - IREE_TRACE_SCOPE0("iree_rt_instance_register_driver_ex"); - auto* handle = reinterpret_cast<Instance*>(instance); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - IREE_API_ASSIGN_OR_RETURN( - auto driver, hal::DriverRegistry::shared_registry()->Create( - absl::string_view{driver_name.data, driver_name.size})); - IREE_API_ASSIGN_OR_RETURN(auto available_devices, - driver->EnumerateAvailableDevices()); - for (const auto& device_info : available_devices) { - LOG(INFO) << " Device: " << device_info.name(); - } - LOG(INFO) << "Creating default device..."; - IREE_API_ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); - IREE_API_RETURN_IF_ERROR(handle->device_manager()->RegisterDevice(device)); - LOG(INFO) << "Successfully created device '" << device->info().name() << "'"; - - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::rt::Module -//===----------------------------------------------------------------------===// - -namespace { - -class ExternalModule final : public Module { - public: - ExternalModule(iree_rt_external_module_t impl, iree_allocator_t allocator) - : impl_(impl), allocator_(allocator) { - IREE_TRACE_SCOPE0("ExternalModule::ctor"); - } - - ~ExternalModule() override { - IREE_TRACE_SCOPE0("ExternalModule::dtor"); - impl_.destroy(impl_.self); - std::memset(&impl_, 0, sizeof(impl_)); - } - - absl::string_view name() const override { - auto result = impl_.name(impl_.self); - return absl::string_view{result.data, result.size}; - } - - const ModuleSignature signature() const override { - auto signature = impl_.signature(impl_.self); - return ModuleSignature{ - signature.import_function_count, - signature.export_function_count, - signature.internal_function_count, - signature.state_slot_count, - }; - } - - SourceResolver* source_resolver() const override { return nullptr; } - - Disassembler* disassembler() const override { return nullptr; } - - std::string DebugStringShort() const override { return std::string(name()); } - - StatusOr<const Function> LookupFunctionByOrdinal( - Function::Linkage linkage, int32_t ordinal) const override { - IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByOrdinal"); - iree_rt_function_t function; - auto status = impl_.lookup_function_by_ordinal( - impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal, - &function); - if (status != IREE_STATUS_OK) { - return FromApiStatus(status, IREE_LOC); - } - return Function{reinterpret_cast<Module*>(function.module), - static_cast<Function::Linkage>(function.linkage), - function.ordinal}; - } - - StatusOr<const Function> LookupFunctionByName( - Function::Linkage linkage, absl::string_view name) const override { - IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByName"); - iree_rt_function_t function; - auto status = impl_.lookup_function_by_name( - impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), - iree_string_view_t{name.data(), name.size()}, &function); - if (status != IREE_STATUS_OK) { - return FromApiStatus(status, IREE_LOC); - } - return Function{reinterpret_cast<Module*>(function.module), - static_cast<Function::Linkage>(function.linkage), - function.ordinal}; - } - - StatusOr<absl::string_view> GetFunctionName(Function::Linkage linkage, - int32_t ordinal) const override { - IREE_TRACE_SCOPE0("ExternalModule::GetFunctionName"); - iree_string_view_t name; - auto status = impl_.get_function_name( - impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal, - &name); - RETURN_IF_ERROR(FromApiStatus(status, IREE_LOC)); - return absl::string_view{name.data, name.size}; - } - - StatusOr<const FunctionSignature> GetFunctionSignature( - Function::Linkage linkage, int32_t ordinal) const override { - IREE_TRACE_SCOPE0("ExternalModule::GetFunctionSignature"); - iree_rt_function_signature_t signature; - auto status = impl_.get_function_signature( - impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal, - &signature); - if (status != IREE_STATUS_OK) { - return FromApiStatus(status, IREE_LOC); - } - return FunctionSignature{signature.argument_count, signature.result_count}; - } - - Status Execute( - Stack* stack, const Function function, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::InlinedVector<hal::BufferView, 8>* results) const override { - // TODO(benvanik): fn ptr callback to external code. Waiting on fibers. - return UnimplementedErrorBuilder(IREE_LOC) - << "External calls not yet implemented"; - } - - private: - iree_rt_external_module_t impl_; - iree_allocator_t allocator_; -}; - -} // namespace - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external( - iree_rt_external_module_t impl, iree_allocator_t allocator, - iree_rt_module_t** out_module) { - IREE_TRACE_SCOPE0("iree_rt_module_create_external"); - - if (!out_module) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_module = nullptr; - - auto module = make_ref<ExternalModule>(impl, allocator); - *out_module = reinterpret_cast<iree_rt_module_t*>(module.release()); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_retain(iree_rt_module_t* module) { - IREE_TRACE_SCOPE0("iree_rt_module_retain"); - auto* handle = reinterpret_cast<Module*>(module); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_release(iree_rt_module_t* module) { - IREE_TRACE_SCOPE0("iree_rt_module_release"); - auto* handle = reinterpret_cast<Module*>(module); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_rt_module_name(const iree_rt_module_t* module) { - IREE_TRACE_SCOPE0("iree_rt_module_name"); - const auto* handle = reinterpret_cast<const Module*>(module); - CHECK(handle) << "NULL module handle"; - return iree_string_view_t{handle->name().data(), handle->name().size()}; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module, - iree_rt_function_linkage_t linkage, - int32_t ordinal, - iree_rt_function_t* out_function) { - IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_ordinal"); - - if (!out_function) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_function, 0, sizeof(*out_function)); - - auto* handle = reinterpret_cast<Module*>(module); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto function_or = handle->LookupFunctionByOrdinal( - static_cast<Function::Linkage>(linkage), ordinal); - if (!function_or.ok()) { - // Map this invalid argument to not found, per the API spec. - if (IsInvalidArgument(function_or.status())) { - return IREE_STATUS_NOT_FOUND; - } - return ToApiStatus(std::move(function_or).status()); - } - auto function = *function_or; - - out_function->module = module; - out_function->linkage = - static_cast<iree_rt_function_linkage_t>(function.linkage()); - out_function->ordinal = function.ordinal(); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_lookup_function_by_name(iree_rt_module_t* module, - iree_rt_function_linkage_t linkage, - iree_string_view_t name, - iree_rt_function_t* out_function) { - IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_name"); - - if (!out_function) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_function, 0, sizeof(*out_function)); - - auto* handle = reinterpret_cast<Module*>(module); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - IREE_API_ASSIGN_OR_RETURN( - auto function, - handle->LookupFunctionByName(static_cast<Function::Linkage>(linkage), - absl::string_view{name.data, name.size})); - - out_function->linkage = - static_cast<iree_rt_function_linkage_t>(function.linkage()); - out_function->module = module; - out_function->linkage = linkage; - out_function->ordinal = function.ordinal(); - - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::rt::Function -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_rt_function_name(const iree_rt_function_t* function) { - IREE_TRACE_SCOPE0("iree_rt_function_name"); - CHECK(function && function->module) << "NULL function handle"; - auto* module = reinterpret_cast<Module*>(function->module); - auto name_or = module->GetFunctionName( - static_cast<Function::Linkage>(function->linkage), function->ordinal); - if (!name_or.ok()) return {}; - auto name = name_or.ValueOrDie(); - return iree_string_view_t{name.data(), name.size()}; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_function_signature(const iree_rt_function_t* function, - iree_rt_function_signature_t* out_signature) { - IREE_TRACE_SCOPE0("iree_rt_function_signature"); - - if (!out_signature) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_signature, 0, sizeof(*out_signature)); - - if (!function || !function->module) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto* module = reinterpret_cast<Module*>(function->module); - IREE_API_ASSIGN_OR_RETURN( - auto signature, module->GetFunctionSignature( - static_cast<Function::Linkage>(function->linkage), - function->ordinal)); - out_signature->argument_count = signature.argument_count(); - out_signature->result_count = signature.result_count(); - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::rt::Policy -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t iree_rt_policy_create( - iree_allocator_t allocator, iree_rt_policy_t** out_policy) { - IREE_TRACE_SCOPE0("iree_rt_policy_create"); - - if (!out_policy) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_policy = nullptr; - - auto policy = make_ref<Policy>(); - - *out_policy = reinterpret_cast<iree_rt_policy_t*>(policy.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_policy_retain(iree_rt_policy_t* policy) { - IREE_TRACE_SCOPE0("iree_rt_policy_retain"); - auto* handle = reinterpret_cast<Policy*>(policy); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_policy_release(iree_rt_policy_t* policy) { - IREE_TRACE_SCOPE0("iree_rt_policy_release"); - auto* handle = reinterpret_cast<Policy*>(policy); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::rt::Context -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create( - iree_rt_instance_t* instance, iree_rt_policy_t* policy, - iree_allocator_t allocator, iree_rt_context_t** out_context) { - IREE_TRACE_SCOPE0("iree_rt_context_create"); - - if (!out_context) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_context = nullptr; - - if (!instance || !policy) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto context = - make_ref<Context>(add_ref(reinterpret_cast<Instance*>(instance)), - add_ref(reinterpret_cast<Policy*>(policy))); - - *out_context = reinterpret_cast<iree_rt_context_t*>(context.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_retain(iree_rt_context_t* context) { - IREE_TRACE_SCOPE0("iree_rt_context_retain"); - auto* handle = reinterpret_cast<Context*>(context); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_release(iree_rt_context_t* context) { - IREE_TRACE_SCOPE0("iree_rt_context_release"); - auto* handle = reinterpret_cast<Context*>(context); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT int32_t IREE_API_CALL -iree_rt_context_id(const iree_rt_context_t* context) { - IREE_TRACE_SCOPE0("iree_rt_context_id"); - const auto* handle = reinterpret_cast<const Context*>(context); - CHECK(handle) << "NULL context handle"; - return handle->id(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules( - iree_rt_context_t* context, iree_rt_module_t** modules, - iree_host_size_t module_count) { - IREE_TRACE_SCOPE0("iree_rt_context_register_modules"); - auto* handle = reinterpret_cast<Context*>(context); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - if (module_count && !modules) { - return IREE_STATUS_INVALID_ARGUMENT; - } - for (size_t i = 0; i < module_count; ++i) { - auto* module = reinterpret_cast<Module*>(modules[i]); - if (!module) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR(handle->RegisterModule(add_ref(module))); - } - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL -iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context, - iree_string_view_t module_name) { - IREE_TRACE_SCOPE0("iree_rt_context_lookup_module_by_name"); - const auto* handle = reinterpret_cast<const Context*>(context); - CHECK(handle) << "NULL context handle"; - auto module_or = handle->LookupModuleByName( - absl::string_view{module_name.data, module_name.size}); - if (!module_or.ok()) { - return nullptr; - } - return reinterpret_cast<iree_rt_module_t*>(module_or.ValueOrDie()); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function( - const iree_rt_context_t* context, iree_string_view_t full_name, - iree_rt_function_t* out_function) { - IREE_TRACE_SCOPE0("iree_rt_context_resolve_function"); - - if (!out_function) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_function, 0, sizeof(*out_function)); - - const auto* handle = reinterpret_cast<const Context*>(context); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto full_name_view = absl::string_view{full_name.data, full_name.size}; - size_t last_dot = full_name_view.rfind('.'); - if (last_dot == absl::string_view::npos) { - return IREE_STATUS_INVALID_ARGUMENT; - } - auto module_name = full_name_view.substr(0, last_dot); - auto function_name = full_name_view.substr(last_dot + 1); - - iree_rt_module_t* module = iree_rt_context_lookup_module_by_name( - context, iree_string_view_t{module_name.data(), module_name.size()}); - if (!module) { - return IREE_STATUS_NOT_FOUND; - } - - return iree_rt_module_lookup_function_by_name( - module, IREE_RT_FUNCTION_LINKAGE_EXPORT, - iree_string_view_t{function_name.data(), function_name.size()}, - out_function); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_allocate_device_visible_buffer( - iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, - iree_host_size_t allocation_size, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_visible_buffer"); - - if (!out_buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_buffer, 0, sizeof(*out_buffer)); - - const auto* handle = reinterpret_cast<const Context*>(context); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } else if (!allocation_size) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - // TODO(benvanik): reroute to context based on current policy. - auto* device_manager = handle->instance()->device_manager(); - IREE_API_ASSIGN_OR_RETURN(auto device_placement, - device_manager->ResolvePlacement({})); - IREE_API_ASSIGN_OR_RETURN(auto buffer, - device_manager->AllocateDeviceVisibleBuffer( - static_cast<hal::BufferUsage>(buffer_usage), - allocation_size, {device_placement})); - - *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_allocate_device_local_buffer( - iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, - iree_host_size_t allocation_size, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer) { - IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_local_buffer"); - - if (!out_buffer) { - return IREE_STATUS_INVALID_ARGUMENT; - } - std::memset(out_buffer, 0, sizeof(*out_buffer)); - - const auto* handle = reinterpret_cast<const Context*>(context); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } else if (!allocation_size) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - // TODO(benvanik): reroute to context based on current policy. - auto* device_manager = handle->instance()->device_manager(); - IREE_API_ASSIGN_OR_RETURN(auto device_placement, - device_manager->ResolvePlacement({})); - IREE_API_ASSIGN_OR_RETURN(auto buffer, - device_manager->AllocateDeviceLocalBuffer( - static_cast<hal::BufferUsage>(buffer_usage), - allocation_size, {device_placement})); - - *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release()); - - return IREE_STATUS_OK; -} - -//===----------------------------------------------------------------------===// -// iree::rt::Invocation -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create( - iree_rt_context_t* context, iree_rt_function_t* function, - iree_rt_policy_t* policy, - const iree_rt_invocation_dependencies_t* dependencies, - iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count, - iree_hal_buffer_view_t** results, iree_host_size_t result_count, - iree_allocator_t allocator, iree_rt_invocation_t** out_invocation) { - IREE_TRACE_SCOPE0("iree_rt_invocation_create"); - - if (!out_invocation) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_invocation = nullptr; - - if (!context || !function || !function->module) { - return IREE_STATUS_INVALID_ARGUMENT; - } else if (dependencies && - (dependencies->invocation_count && !dependencies->invocations)) { - return IREE_STATUS_INVALID_ARGUMENT; - } else if ((argument_count && !arguments) || (result_count && !results)) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - // TODO(benvanik): unwrap without needing to retain here. - absl::InlinedVector<ref_ptr<Invocation>, 4> dependent_invocations; - if (dependencies) { - dependent_invocations.resize(dependencies->invocation_count); - for (int i = 0; i < dependencies->invocation_count; ++i) { - dependent_invocations[i] = - add_ref(reinterpret_cast<Invocation*>(dependencies->invocations[i])); - } - } - - // TODO(benvanik): unwrap without needing to retain here. - absl::InlinedVector<hal::BufferView, 8> argument_views(argument_count); - for (int i = 0; i < argument_count; ++i) { - const auto* api_buffer_view = - reinterpret_cast<const hal::iree_hal_buffer_view*>(arguments[i]); - if (!api_buffer_view) { - return IREE_STATUS_INVALID_ARGUMENT; - } - argument_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer), - api_buffer_view->impl.shape, - api_buffer_view->impl.element_size}; - } - - // TODO(benvanik): unwrap without needing to retain here. - absl::InlinedVector<hal::BufferView, 8> result_views(result_count); - for (int i = 0; i < result_count; ++i) { - const auto* api_buffer_view = - reinterpret_cast<const hal::iree_hal_buffer_view*>(results[i]); - if (api_buffer_view) { - result_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer), - api_buffer_view->impl.shape, - api_buffer_view->impl.element_size}; - } - } - - IREE_API_ASSIGN_OR_RETURN( - auto invocation, - Invocation::Create( - add_ref(reinterpret_cast<Context*>(context)), - Function{reinterpret_cast<Module*>(function->module), - static_cast<Function::Linkage>(function->linkage), - function->ordinal}, - add_ref(reinterpret_cast<Policy*>(policy)), - std::move(dependent_invocations), std::move(argument_views), - result_views.empty() - ? absl::optional<absl::InlinedVector<hal::BufferView, 8>>( - absl::nullopt) - : std::move(result_views))); - - *out_invocation = - reinterpret_cast<iree_rt_invocation_t*>(invocation.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_retain(iree_rt_invocation_t* invocation) { - IREE_TRACE_SCOPE0("iree_rt_invocation_retain"); - auto* handle = reinterpret_cast<Invocation*>(invocation); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->AddReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_release(iree_rt_invocation_t* invocation) { - IREE_TRACE_SCOPE0("iree_rt_invocation_release"); - auto* handle = reinterpret_cast<Invocation*>(invocation); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - handle->ReleaseReference(); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_query_status(iree_rt_invocation_t* invocation) { - IREE_TRACE_SCOPE0("iree_rt_invocation_query_status"); - auto* handle = reinterpret_cast<Invocation*>(invocation); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR(handle->QueryStatus()); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results( - iree_rt_invocation_t* invocation, iree_host_size_t result_capacity, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_results, - iree_host_size_t* out_result_count) { - IREE_TRACE_SCOPE0("iree_rt_invocation_consume_results"); - - if (!out_result_count) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_result_count = 0; - if (!out_results) { - std::memset(out_results, 0, - sizeof(iree_hal_buffer_view_t*) * result_capacity); - } - - auto* handle = reinterpret_cast<Invocation*>(invocation); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - const auto& function = handle->function(); - int32_t result_count = function.signature().result_count(); - *out_result_count = result_count; - if (!out_results) { - return IREE_STATUS_OK; - } else if (result_capacity < result_count) { - return IREE_STATUS_OUT_OF_RANGE; - } - - IREE_API_ASSIGN_OR_RETURN(auto results, handle->ConsumeResults()); - iree_status_t status = IREE_STATUS_OK; - int i = 0; - for (i = 0; i < results.size(); ++i) { - iree_shape_t shape; - status = ToApiShape(results[i].shape, &shape); - if (status != IREE_STATUS_OK) break; - status = iree_hal_buffer_view_create( - reinterpret_cast<iree_hal_buffer_t*>(results[i].buffer.get()), shape, - results[i].element_size, allocator, &out_results[i]); - if (status != IREE_STATUS_OK) break; - } - if (status != IREE_STATUS_OK) { - // Release already-retained buffer views on failure. - for (int j = 0; j < i; ++j) { - iree_hal_buffer_view_release(out_results[j]); - out_results[j] = nullptr; - } - } - return status; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await( - iree_rt_invocation_t* invocation, iree_time_t deadline) { - IREE_TRACE_SCOPE0("iree_rt_invocation_await"); - auto* handle = reinterpret_cast<Invocation*>(invocation); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR(handle->Await(ToAbslTime(deadline))); - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_abort(iree_rt_invocation_t* invocation) { - IREE_TRACE_SCOPE0("iree_rt_invocation_abort"); - auto* handle = reinterpret_cast<Invocation*>(invocation); - if (!handle) { - return IREE_STATUS_INVALID_ARGUMENT; - } - IREE_API_RETURN_IF_ERROR(handle->Abort()); - return IREE_STATUS_OK; -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/api.h b/iree/rt/api.h deleted file mode 100644 index 37e66c3..0000000 --- a/iree/rt/api.h +++ /dev/null
@@ -1,400 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// See iree/base/api.h for documentation on the API conventions used. - -#ifndef IREE_RT_API_H_ -#define IREE_RT_API_H_ - -#include <stdint.h> - -#include "iree/base/api.h" -#include "iree/hal/api.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -//===----------------------------------------------------------------------===// -// Types and Enums -//===----------------------------------------------------------------------===// - -typedef struct iree_rt_instance iree_rt_instance_t; -typedef struct iree_rt_context iree_rt_context_t; -typedef struct iree_rt_policy iree_rt_policy_t; -typedef struct iree_rt_module iree_rt_module_t; -typedef struct iree_rt_invocation iree_rt_invocation_t; - -// Describes the type of a function reference. -typedef enum { - // Function is internal to the module and may not be reflectable. - IREE_RT_FUNCTION_LINKAGE_INTERNAL = 0, - // Function is an import from another module. - IREE_RT_FUNCTION_LINKAGE_IMPORT = 1, - // Function is an export from the module. - IREE_RT_FUNCTION_LINKAGE_EXPORT = 2, -} iree_rt_function_linkage_t; - -// A function reference that can be used with the iree_rt_function_* methods. -// These should be treated as opaque and the accessor functions should be used -// instead. -typedef struct { - // Module the function is contained within. - iree_rt_module_t* module; - // Linkage of the function. Note that IREE_RT_FUNCTION_LINKAGE_INTERNAL - // functions may be missing reflection information. - iree_rt_function_linkage_t linkage; - // Ordinal within the module in the linkage scope. - int32_t ordinal; -} iree_rt_function_t; - -// Describes the expected calling convention and arguments/results of a -// function. -typedef struct { - // Total number of arguments to the function. - int32_t argument_count; - // Total number of results from the function. - int32_t result_count; -} iree_rt_function_signature_t; - -// Describes the imports, exports, and capabilities of a module. -typedef struct { - // Total number of imported functions. - int32_t import_function_count; - // Total number of exported functions. - int32_t export_function_count; - // Total number of internal functions, if debugging info is present and they - // can be queried. - int32_t internal_function_count; - // Total number of state block resource slots consumed. - int32_t state_slot_count; -} iree_rt_module_signature_t; - -// Dependency information used to order invocations. -typedef struct { - // Prior invocations that must complete before the new invocation begins. - iree_rt_invocation_t** invocations; - iree_host_size_t invocation_count; - - // TODO(benvanik): wait semaphores/importing. -} iree_rt_invocation_dependencies_t; - -// Defines an external module that can be used to reflect and execute functions. -// Modules must be thread-safe as lookups and executions may occur in any order -// from any thread. -// -// Modules will have their resolve_imports function called upon registration -// with a context and may use the provided resolver to find imported functions. -typedef struct { - // User-defined pointer passed to all functions. - void* self; - // Destroys |self| when all references to the module have been released. - iree_status_t(IREE_API_PTR* destroy)(void* self); - // Returns the name of the module (used during resolution). - iree_string_view_t(IREE_API_PTR* name)(void* self); - // Sets |out_module_signature| to the reflected signature of the module. - iree_rt_module_signature_t(IREE_API_PTR* signature)(void* self); - // Sets |out_function| to a resolved function by ordinal, if found. - iree_status_t(IREE_API_PTR* lookup_function_by_ordinal)( - void* self, iree_rt_function_linkage_t linkage, int32_t ordinal, - iree_rt_function_t* out_function); - // Sets |out_function| to a resolved function by name, if found. - iree_status_t(IREE_API_PTR* lookup_function_by_name)( - void* self, iree_rt_function_linkage_t linkage, iree_string_view_t name, - iree_rt_function_t* out_function); - // Sets |out_name| to the name of the function with the given ordinal, if - // found. - iree_status_t(IREE_API_PTR* get_function_name)( - void* self, iree_rt_function_linkage_t linkage, int32_t ordinal, - iree_string_view_t* out_name); - // Sets |out_signature| to the reflected signature of the given - // function, if found. - iree_status_t(IREE_API_PTR* get_function_signature)( - void* self, iree_rt_function_linkage_t linkage, int32_t ordinal, - iree_rt_function_signature_t* out_signature); -} iree_rt_external_module_t; - -//===----------------------------------------------------------------------===// -// iree::rt::Instance -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Creates a new instance. This should be shared with all contexts in an -// application to ensure that resources are tracked properly and threads are -// managed correctly. -// |out_instance| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create( - iree_allocator_t allocator, iree_rt_instance_t** out_instance); - -// Retains the given |instance| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_instance_retain(iree_rt_instance_t* instance); - -// Releases the given |instance| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_instance_release(iree_rt_instance_t* instance); - -// TEMPORARY: until policies and placement are performed this can be used to -// explicitly create and register drivers by name. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex( - iree_rt_instance_t* instance, iree_string_view_t driver_name); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::rt::Module -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Creates a module with an external backing implementation. -// The provided |external_module| definition will be used to query the module -// state as needed. No caching occurs within the implementation to allow calls -// to return different values per-invocation. -// -// |out_module| must be released by the caller. -// iree_rt_external_module_t::destroy is called when the last reference to the -// iree_rt_module_t is released. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external( - iree_rt_external_module_t impl, iree_allocator_t allocator, - iree_rt_module_t** out_module); - -// Retains the given |module| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_retain(iree_rt_module_t* module); - -// Releases the given |module| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_release(iree_rt_module_t* module); - -// Returns the name of the module. -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_rt_module_name(const iree_rt_module_t* module); - -// Sets |out_function| to a function with |ordinal| in the given linkage or -// returns IREE_STATUS_NOT_FOUND. The function reference is valid for the -// lifetime of |module|. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module, - iree_rt_function_linkage_t linkage, - int32_t ordinal, - iree_rt_function_t* out_function); - -// Sets |out_function| to a function with |name| in the given linkage or returns -// IREE_STATUS_NOT_FOUND. The function reference is valid for the lifetime of -// |module|. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_module_lookup_function_by_name(iree_rt_module_t* module, - iree_rt_function_linkage_t linkage, - iree_string_view_t name, - iree_rt_function_t* out_function); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::rt::Function -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Returns the name of the function as exported from the module. -IREE_API_EXPORT iree_string_view_t IREE_API_CALL -iree_rt_function_name(const iree_rt_function_t* function); - -// Sets |out_function_signature| to the reflected signature of the function. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_function_signature(const iree_rt_function_t* function, - iree_rt_function_signature_t* out_signature); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::rt::Policy -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// TODO(benvanik): define policies. For now they are no-ops. -IREE_API_EXPORT iree_status_t iree_rt_policy_create( - iree_allocator_t allocator, iree_rt_policy_t** out_policy); - -// Retains the given |policy| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_policy_retain(iree_rt_policy_t* policy); - -// Releases the given |policy| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_policy_release(iree_rt_policy_t* policy); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::rt::Context -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Creates a new context that uses the given |instance| for device management. -// |out_context| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create( - iree_rt_instance_t* instance, iree_rt_policy_t* policy, - iree_allocator_t allocator, iree_rt_context_t** out_context); - -// Retains the given |context| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_retain(iree_rt_context_t* context); - -// Releases the given |context| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_release(iree_rt_context_t* context); - -// Returns a process-unique ID for the |context|. -IREE_API_EXPORT int32_t IREE_API_CALL -iree_rt_context_id(const iree_rt_context_t* context); - -// Registers a list of modules with the context and resolves imports. -// The modules will be retained by the context until destruction. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules( - iree_rt_context_t* context, iree_rt_module_t** modules, - iree_host_size_t module_count); - -// Returns a reference to the module registered with the given name or nullptr -// if not found. The caller must retain the returned module if they want to -// continue using it. -IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL -iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context, - iree_string_view_t module_name); - -// Sets |out_function| to to an exported function with the fully-qualified name -// of |full_name| or returns IREE_STATUS_NOT_FOUND. The function reference is -// valid for the lifetime of |context|. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function( - const iree_rt_context_t* context, iree_string_view_t full_name, - iree_rt_function_t* out_function); - -// Allocates a host-local buffer that is optimal for use on the host but is -// usable by the given |device_placements| (at a possible performance -// penalty). The buffer can be used for staging uploads to device-local -// buffers and is useful for times when the buffer will be used more on the -// host than the device. If a buffer never needs to be used with a device -// prefer instead HeapBuffer::Allocate. -// -// Fails if it is not possible to allocate and satisfy all placements for the -// requested |buffer_usage|. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_allocate_device_visible_buffer( - iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, - iree_host_size_t allocation_size, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer); - -// Allocates a device-local buffer that is optimal for use with the given -// |device_placements|. The buffer will not be host-visible and can only be -// used from compatible device queues. -// -// Fails if it is not possible to allocate and satisfy all placements for the -// requested |buffer_usage|. -// |out_buffer| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_context_allocate_device_local_buffer( - iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, - iree_host_size_t allocation_size, iree_allocator_t allocator, - iree_hal_buffer_t** out_buffer); - -#endif // IREE_API_NO_PROTOTYPES - -//===----------------------------------------------------------------------===// -// iree::rt::Invocation -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Creates a new invocation tracking object for invoking the given |function| -// from |context|. |arguments| will be retained until the invocation is made. -// If |dependencies| are provided then the invocation will wait until they are -// resolved before executing. If a |policy| is provided it will override the -// context-level policy. -// -// Optionally |results| may be provided with preallocated buffers that will -// receive the outputs of the invocation. Invocation will fail if they do not -// match expected sizes. -// -// Note that it's possible for the invocation to complete prior to the return of -// this function. Any errors that occur will be set on the invocation and -// callers should query its state prior to assuming it is in-flight. -// -// |out_invocation| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create( - iree_rt_context_t* context, iree_rt_function_t* function, - iree_rt_policy_t* policy, - const iree_rt_invocation_dependencies_t* dependencies, - iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count, - iree_hal_buffer_view_t** results, iree_host_size_t result_count, - iree_allocator_t allocator, iree_rt_invocation_t** out_invocation); - -// Retains the given |invocation| for the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_retain(iree_rt_invocation_t* invocation); - -// Releases the given |invocation| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_release(iree_rt_invocation_t* invocation); - -// Queries the completion status of the invocation. -// Returns one of the following: -// IREE_STATUS_OK: the invocation completed successfully. -// IREE_STATUS_UNAVAILABLE: the invocation has not yet completed. -// IREE_STATUS_CANCELLED: the invocation was cancelled internally. -// IREE_STATUS_ABORTED: the invocation was aborted. -// IREE_STATUS_*: an error occurred during invocation. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_query_status(iree_rt_invocation_t* invocation); - -// Populates |out_results| to the values of the results. -// |result_capacity| defines the number of elements available in |out_results| -// and |out_result_count| will be set with the actual number of results -// available. If |result_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be -// returned wtih the required capacity in |out_result_count|. To only query the -// required capacity |out_results| may be passed as nullptr. -// -// Ownership of returned results will be transferred to the caller and they must -// be released if no longer needed. -// -// Returns errors as with iree_rt_invocation_query_status, for example in the -// case of not-yet-completed or aborted invocations. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results( - iree_rt_invocation_t* invocation, iree_host_size_t result_capacity, - iree_allocator_t allocator, iree_hal_buffer_view_t** out_results, - iree_host_size_t* out_result_count); - -// Blocks the caller until the invocation completes (successfully or otherwise). -// -// Returns IREE_STATUS_DEADLINE_EXCEEDED if |deadline| elapses before the -// invocation completes and otherwise returns iree_rt_invocation_query_status. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await( - iree_rt_invocation_t* invocation, iree_time_t deadline); - -// Attempts to abort the invocation if it is in-flight. -// A no-op if the invocation has already completed. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_rt_invocation_abort(iree_rt_invocation_t* invocation); - -#endif // IREE_API_NO_PROTOTYPES - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_RT_API_H_
diff --git a/iree/rt/context.cc b/iree/rt/context.cc deleted file mode 100644 index 6ae7a0a..0000000 --- a/iree/rt/context.cc +++ /dev/null
@@ -1,167 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/context.h" - -#include <atomic> - -#include "absl/strings/str_cat.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/rt/debug/debug_server.h" -#include "iree/rt/instance.h" -#include "iree/rt/invocation.h" - -namespace iree { -namespace rt { - -namespace { - -int32_t NextUniqueContextId() { - static std::atomic<int32_t> next_id = {0}; - return ++next_id; -} - -} // namespace - -Context::Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy) - : id_(NextUniqueContextId()), - instance_(std::move(instance)), - policy_(std::move(policy)) { - IREE_TRACE_SCOPE("Context::ctor", int32_t)(id_); - instance_->RegisterContext(this); -} - -Context::~Context() { - IREE_TRACE_SCOPE("Context::dtor", int32_t)(id_); - instance_->UnregisterContext(this); -} - -std::string Context::DebugStringShort() const { - return absl::StrCat("context_", id_); -} - -Status Context::RegisterModule(ref_ptr<Module> module) { - IREE_TRACE_SCOPE0("Context::RegisterModule"); - - // Ensure no conflicts in naming - we don't support shadowing. - for (const auto& existing_module : modules_) { - if (existing_module->name() == module->name()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Module '" << module->name() - << "' has already been registered in the context"; - } - } - - // Try resolving prior to actually registering; if we can't resolve an import - // then we want to fail the entire registration. - ASSIGN_OR_RETURN(auto import_table, ResolveImports(module.get())); - - auto* debug_server = instance_->debug_server(); - if (debug_server) { - CHECK_OK(debug_server->RegisterContextModule(this, module.get())); - } - - modules_.push_back(std::move(module)); - module_import_tables_.push_back(std::move(import_table)); - return OkStatus(); -} - -StatusOr<ModuleImportTable> Context::ResolveImports(Module* module) { - IREE_TRACE_SCOPE0("Context::ResolveImports"); - - int32_t import_count = module->signature().import_function_count(); - ModuleImportTable import_table; - import_table.first = module; - import_table.second.resize(import_count); - - for (int32_t i = 0; i < import_count; ++i) { - ASSIGN_OR_RETURN(auto import_function_name, - module->GetFunctionName(Function::Linkage::kImport, i)); - ASSIGN_OR_RETURN(import_table.second[i], - ResolveFunction(import_function_name)); - } - - return import_table; -} - -StatusOr<Module*> Context::LookupModuleByName( - absl::string_view module_name) const { - for (const auto& module : modules_) { - if (module->name() == module_name) { - return module.get(); - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "No module with the name '" << module_name - << "' has been registered"; -} - -StatusOr<const Function> Context::ResolveFunction( - absl::string_view full_name) const { - size_t last_dot = full_name.rfind('.'); - if (last_dot == absl::string_view::npos) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "'" << full_name - << "' is not fully qualified (expected 'module.function')"; - } - auto module_name = full_name.substr(0, last_dot); - auto function_name = full_name.substr(last_dot + 1); - ASSIGN_OR_RETURN(auto* module, LookupModuleByName(module_name)); - return module->LookupFunctionByName(Function::Linkage::kExport, - function_name); -} - -StatusOr<const Function> Context::ResolveImport(const Module* module, - int32_t ordinal) const { - for (const auto& import_table_ref : module_import_tables_) { - if (import_table_ref.first == module) { - const auto& import_table = import_table_ref.second; - if (ordinal >= import_table.size()) { - return NotFoundErrorBuilder(IREE_LOC) - << "Import ordinal " << ordinal - << " out of bounds of import table (" << import_table.size() - << ")"; - } - return import_table[ordinal]; - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "Import ordinal " << ordinal << " not found"; -} - -void Context::RegisterInvocation(Invocation* invocation) { - { - absl::MutexLock lock(&invocations_mutex_); - invocations_.push_back(invocation); - } - auto* debug_server = instance_->debug_server(); - if (debug_server) { - CHECK_OK(debug_server->RegisterInvocation(invocation)); - } -} - -void Context::UnregisterInvocation(Invocation* invocation) { - auto* debug_server = instance_->debug_server(); - if (debug_server) { - CHECK_OK(debug_server->UnregisterInvocation(invocation)); - } - { - absl::MutexLock lock(&invocations_mutex_); - invocations_.erase(invocation); - } -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/context.h b/iree/rt/context.h deleted file mode 100644 index d67b79a..0000000 --- a/iree/rt/context.h +++ /dev/null
@@ -1,119 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_CONTEXT_H_ -#define IREE_RT_CONTEXT_H_ - -#include <ostream> - -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/optional.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/rt/invocation.h" -#include "iree/rt/module.h" -#include "iree/rt/policy.h" - -namespace iree { -namespace rt { - -class Instance; - -using ModuleImportTable = std::pair<Module*, std::vector<Function>>; - -// An isolated execution context. -// Effectively a sandbox where modules can be loaded and run with restricted -// visibility and where they can maintain state. -// -// Modules have imports resolved automatically when registered by searching -// existing modules registered within the context and load order is used for -// resolution. For example, target-specific modules should be loaded prior to -// generic modules that may import functions defined there and if a function is -// not available in the target-specific modules the fallback provided by the -// generic module will be used. -// -// Thread-compatible and must be externally synchronized. -class Context final : public RefObject<Context> { - public: - Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy); - ~Context(); - - // A process-unique ID for the context. - int32_t id() const { return id_; } - - // Instance this context uses for shared resources. - const ref_ptr<Instance>& instance() const { return instance_; } - - // A short human-readable name for the context. - std::string DebugStringShort() const; - - // A list of modules registered with the context. - absl::Span<const ref_ptr<Module>> modules() const { - return absl::MakeConstSpan(modules_); - } - - // Registers a new module with the context. - // Imports from the module will be resolved using the existing modules in the - // context. The module will be retained by the context until destruction. - Status RegisterModule(ref_ptr<Module> module); - - // Looks up a module by name. - StatusOr<Module*> LookupModuleByName(absl::string_view module_name) const; - - // Resolves an exported function by fully-qualified name. The function - // reference is valid for the lifetime of the context. - StatusOr<const Function> ResolveFunction(absl::string_view full_name) const; - - // Resolves an imported function by import ordinal. The function reference is - // valid for the lifetime of the context. - StatusOr<const Function> ResolveImport(const Module* module, - int32_t ordinal) const; - - private: - // Resolves imports for the given module. - StatusOr<ModuleImportTable> ResolveImports(Module* module); - - friend class Invocation; - void RegisterInvocation(Invocation* invocation); - void UnregisterInvocation(Invocation* invocation); - - int32_t id_; - ref_ptr<Instance> instance_; - ref_ptr<Policy> policy_; - - absl::InlinedVector<ref_ptr<Module>, 4> modules_; - absl::InlinedVector<ModuleImportTable, 4> module_import_tables_; - - absl::Mutex invocations_mutex_; - IntrusiveList<Invocation, offsetof(Invocation, context_list_link_)> - invocations_ ABSL_GUARDED_BY(invocations_mutex_); - - friend class Instance; - IntrusiveListLink instance_list_link_; -}; - -inline std::ostream& operator<<(std::ostream& stream, const Context& context) { - stream << context.DebugStringShort(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_CONTEXT_H_
diff --git a/iree/rt/debug/BUILD b/iree/rt/debug/BUILD deleted file mode 100644 index 8c4a939..0000000 --- a/iree/rt/debug/BUILD +++ /dev/null
@@ -1,192 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -# TODO(benvanik): re-enable debugger after refactoring. -# cc_library( -# name = "debug_client", -# srcs = ["debug_client.cc"], -# hdrs = ["debug_client.h"], -# deps = [ -# ":debug_client_interface", -# ":debug_client_tcp", # build-cleaner: keep -# "@com_google_absl//absl/container:flat_hash_map", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/types:optional", -# "@com_google_absl//absl/types:span", -# "//iree/base:source_location", -# "//iree/base:status", -# "//iree/schemas", -# ], -# ) -# -# cc_library( -# name = "debug_client_interface", -# hdrs = ["debug_client.h"], -# deps = [ -# "@com_google_absl//absl/container:flat_hash_map", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/types:optional", -# "@com_google_absl//absl/types:span", -# "//iree/base:status", -# "//iree/schemas", -# ], -# ) -# -# cc_library( -# name = "debug_client_tcp", -# srcs = ["debug_client_tcp.cc"], -# deps = [ -# ":debug_client_interface", -# ":debug_tcp_util", -# "@com_google_absl//absl/container:flat_hash_map", -# "@com_google_absl//absl/memory", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/types:span", -# "@com_github_google_flatbuffers//:flatbuffers", -# "//iree/base:flatbuffer_util", -# "//iree/base:status", -# "//iree/rt", -# "//iree/schemas", -# ], -# ) -# -# cc_library( -# name = "debug_server", -# hdrs = ["debug_server.h"], -# deps = [ -# ":debug_server_interface", -# "//third_party/flatbuffers:flatbuffers", -# "//iree/schemas", -# "//iree/base:status", -# ] + select({ -# "//iree:debug": [":debug_server_tcp"], -# "//conditions:default": [":debug_server_disabled"], -# }), -# ) - -cc_library( - name = "debug_server", - hdrs = ["debug_server.h"], - deps = [ - ":debug_server_disabled", - ":debug_server_interface", - "//iree/base:status", - ], -) - -cc_library( - name = "debug_server_interface", - hdrs = ["debug_server.h"], - deps = ["//iree/base:status"], -) - -cc_library( - name = "debug_server_disabled", - srcs = ["debug_server_disabled.cc"], - deps = [ - ":debug_server_interface", - "@com_google_absl//absl/memory", - ], -) - -# TODO(benvanik): re-enable debugger after refactoring. -# cc_library( -# name = "debug_server_tcp", -# srcs = ["debug_server_tcp.cc"], -# deps = [ -# ":debug_server_interface", -# ":debug_service", -# ":debug_tcp_util", -# "@com_google_absl//absl/base:core_headers", -# "@com_google_absl//absl/memory", -# "@com_google_absl//absl/synchronization", -# "@com_github_google_flatbuffers//:flatbuffers", -# "//iree/base:status", -# "//iree/schemas", -# ], -# ) - -cc_library( - name = "debug_server_flags", - srcs = ["debug_server_flags.cc"], - hdrs = ["debug_server_flags.h"], - deps = [ - ":debug_server", - "//iree/base:memory", - "//iree/base:status", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - ], -) - -# TODO(benvanik): re-enable debugger after refactoring. -# cc_library( -# name = "debug_server_flags", -# srcs = ["debug_server_flags.cc"], -# hdrs = ["debug_server_flags.h"], -# copts = select({ -# "//iree:debug": [ -# "-DIREE_DEBUG_EMBEDDED_APP_PRESENT=1", -# ], -# "//conditions:default": [], -# }), -# deps = [ -# ":debug_server", -# "@com_google_absl//absl/flags:flag", -# "@com_google_absl//absl/strings", -# "//iree/base:memory", -# "//iree/base:status", -# ] + select({ -# "//iree:debug": [ -# "//iree/tools/debugger:debug_app_embedded", -# "//third_party/GL/native:EGL", # build-cleaner: keep -# "//third_party/GL/native:GLESv2", # build-cleaner: keep -# ], -# "//conditions:default": [], -# }), -# ) -# -# cc_library( -# name = "debug_service", -# srcs = ["debug_service.cc"], -# hdrs = ["debug_service.h"], -# deps = [ -# ":debug_session", -# "@com_google_absl//absl/base:core_headers", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/synchronization", -# "@com_github_google_flatbuffers//:flatbuffers", -# "//iree/base:flatbuffer_util", -# "//iree/base:source_location", -# "//iree/base:status", -# "//iree/rt", -# "//iree/schemas", -# "//iree/schemas:reflection_data", -# ], -# ) -# -# cc_library( -# name = "debug_session", -# srcs = ["debug_session.cc"], -# hdrs = ["debug_session.h"], -# deps = [ -# "@com_google_absl//absl/base:core_headers", -# "@com_google_absl//absl/synchronization", -# "//iree/base:source_location", -# "//iree/base:status", -# "//iree/rt", -# "//iree/schemas", -# ], -# ) -# -# cc_library( -# name = "debug_tcp_util", -# hdrs = ["debug_tcp_util.h"], -# deps = [ -# "@com_github_google_flatbuffers//:flatbuffers", -# "//iree/base:status", -# "//iree/schemas", -# ], -# )
diff --git a/iree/rt/debug/debug_adapter.h b/iree/rt/debug/debug_adapter.h deleted file mode 100644 index ab1d495..0000000 --- a/iree/rt/debug/debug_adapter.h +++ /dev/null
@@ -1,106 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DEBUG_ADAPTER_H_ -#define IREE_RT_DEBUG_ADAPTER_H_ - -#include <functional> - -#include "iree/base/status.h" -#include "iree/rt/invocation.h" - -namespace iree { -namespace rt { -namespace debug { - -struct StepTarget { - // TODO(benvanik): step target info (matching RPC message). - // module / function / offset - // relative to current: once, out, return, etc -}; - -// TODO(benvanik): move to fiber base. -// Interface for debugging invocations. -// This is only accessible in debug builds where such features are compiled in. -class DebugAdapter { - public: - // Called when an invocation completes suspending (in response to a Suspend or - // Step request). The |suspend_status| will indicate if the suspension was - // successful. - using SuspendCallback = std::function<void(Status suspend_status)>; - - // Returns true if the invocation is suspended. - // This only returns true if the invocation has been requested to suspend with - // Suspend and the runtime has acked the suspend. Once suspended (and until - // resumed) invocation state will not change and may be observed from any - // thread. - // - // Safe to call from any thread. - bool IsSuspended(Invocation* invocation); - - // Suspends the invocation at the next possible chance. - // - // Fibers have a suspension depth and each call to Suspend must be matched - // with a call to Resume. Fibers will only resume excution when all prior - // Suspend calls have their matching Resume called. - // - // Optionally callers may provide a |suspend_callback| that will be called - // from a random thread when the invocation is suspended (or fails to - // suspend). - // - // Safe to call from any thread. - // Returns StatusCode::kUnavailable if debugging is not supported. - Status Suspend(ref_ptr<Invocation> invocation, - SuspendCallback suspend_callback = nullptr); - - // Resumes the invocation if it is suspended (or cancels a pending suspend). - // This may wake threads if they are currently waiting on the invocation to - // execute. - // - // Safe to call from any thread. - // Returns StatusCode::kUnavailable if debugging is not supported. - Status Resume(Invocation* invocation); - - // Steps invocation execution. - // This will attempt to resume the invocation and will complete - // asynchronously. Upon returning the invocation should be assumed resumed and - // callers must query is_suspended to wait until the invocation suspends - // again. Optionally callers may provide a |suspend_callback| that will be - // called from a random thread when the invocation is suspended (or fails to - // suspend). - // - // Safe to call from any thread while the invocation is suspended. - // Returns StatusCode::kUnavailable if debugging is not supported and - // StatusCode::kFailedPrecondition if the invocation is not suspended. - Status Step(ref_ptr<Invocation> invocation, StepTarget step_target, - SuspendCallback suspend_callback = nullptr); - - // Returns a call stack that can be used to query and manipulate the - // invocation state. The behaviors supported depend on the stack frames and - // the backend support and may be conditionally enabled via compile-time or - // run-time flags. - // - // Safe to call from any thread while the invocation is suspended. - // Returns StatusCode::kUnavailable if debugging is not supported and - // StatusCode::kFailedPrecondition if the invocation is not suspended. - // The returned stack will be invalidated when the invocation is stepped or - // resumed. - StatusOr<Stack> CaptureStack(Invocation* invocation); -}; - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_ADAPTER_H_
diff --git a/iree/rt/debug/debug_client.cc b/iree/rt/debug/debug_client.cc deleted file mode 100644 index 66d42a7..0000000 --- a/iree/rt/debug/debug_client.cc +++ /dev/null
@@ -1,64 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/debug/debug_client.h" - -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { -namespace rt { -namespace debug { - -Status DebugClient::GetFunction( - std::string module_name, std::string function_name, - std::function<void(StatusOr<RemoteFunction*> function)> callback) { - return ResolveFunction( - module_name, function_name, - [this, module_name, callback](StatusOr<int> function_ordinal) { - if (!function_ordinal.ok()) { - callback(function_ordinal.status()); - return; - } - auto status = - GetFunction(module_name, function_ordinal.ValueOrDie(), callback); - if (!status.ok()) { - callback(std::move(status)); - } - }); -} - -Status DebugClient::StepInvocationOver(const RemoteInvocation& invocation, - std::function<void()> callback) { - // TODO(benvanik): implement bytecode stepping search. - // int bytecode_offset = 0; - // return StepInvocationToOffset(invocation, bytecode_offset, - // std::move(callback)); - return UnimplementedErrorBuilder(IREE_LOC) - << "StepInvocationOver not yet implemented"; -} - -Status DebugClient::StepInvocationOut(const RemoteInvocation& invocation, - std::function<void()> callback) { - // TODO(benvanik): implement bytecode stepping search. - // int bytecode_offset = 0; - // return StepInvocationToOffset(invocation, bytecode_offset, - // std::move(callback)); - return UnimplementedErrorBuilder(IREE_LOC) - << "StepInvocationOut not yet implemented"; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_client.h b/iree/rt/debug/debug_client.h deleted file mode 100644 index 4031f5e..0000000 --- a/iree/rt/debug/debug_client.h +++ /dev/null
@@ -1,286 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DEBUG_DEBUG_CLIENT_H_ -#define IREE_RT_DEBUG_DEBUG_CLIENT_H_ - -#include <functional> -#include <memory> - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/schemas/debug_service_generated.h" - -namespace iree { -namespace rt { -namespace debug { - -// Remote breakpoint currently active on the server. -class RemoteBreakpoint { - public: - enum class Type { - kBytecodeFunction = 0, - kNativeFunction = 1, - }; - - virtual ~RemoteBreakpoint() = default; - - int id() const { return id_; } - Type type() const { return type_; } - - virtual const std::string& module_name() const = 0; - virtual const std::string& function_name() const = 0; - virtual int function_ordinal() const = 0; - virtual int bytecode_offset() const = 0; - - protected: - explicit RemoteBreakpoint(int id, Type type) : id_(id), type_(type) {} - - private: - int id_; - Type type_; -}; - -class RemoteModule; - -class RemoteFunction { - public: - virtual ~RemoteFunction() = default; - - RemoteModule* module() const { return module_; } - int ordinal() const { return function_ordinal_; } - virtual const std::string& name() const = 0; - - virtual const FunctionDef& def() = 0; - - virtual bool is_loaded() const = 0; - virtual bool CheckLoadedOrRequest() = 0; - - using LoadCallback = std::function<void(StatusOr<RemoteFunction*>)>; - virtual void WhenLoaded(LoadCallback callback) = 0; - - virtual const BytecodeDef* bytecode() = 0; - - protected: - RemoteFunction(RemoteModule* module, int function_ordinal) - : module_(module), function_ordinal_(function_ordinal) {} - - RemoteModule* module_; - int function_ordinal_; -}; - -class RemoteModule { - public: - virtual ~RemoteModule() = default; - - int context_id() const { return context_id_; } - const std::string& name() const { return name_; } - - virtual const ModuleDef& def() = 0; - - virtual bool is_loaded() const = 0; - virtual bool CheckLoadedOrRequest() = 0; - - using LoadCallback = std::function<void(StatusOr<RemoteModule*>)>; - virtual void WhenLoaded(LoadCallback callback) = 0; - - virtual absl::Span<RemoteFunction*> functions() = 0; - - protected: - RemoteModule(int context_id, std::string name) - : context_id_(context_id), name_(std::move(name)) {} - - private: - int context_id_; - std::string name_; -}; - -class RemoteContext { - public: - virtual ~RemoteContext() = default; - - int id() const { return id_; } - - virtual absl::Span<RemoteModule* const> modules() const = 0; - - protected: - explicit RemoteContext(int id) : id_(id) {} - - private: - int id_; -}; - -class RemoteInvocation { - public: - virtual ~RemoteInvocation() = default; - - int id() const { return id_; } - const std::string& name() const { return name_; } - - virtual const rpc::InvocationDefT& def() const = 0; - - protected: - explicit RemoteInvocation(int id) - : id_(id), name_(absl::StrCat("Invocation ", id)) {} - - private: - int id_; - std::string name_; -}; - -// Debugger RPC server client. -// Statefully tracks a DebugServer to provide common client operations and -// memoized queries. -// -// Thread-compatible. Do not use the client from multiple threads concurrently. -// All remote updates of local state are performed by the Poll function. See -// Poll for more details. -class DebugClient { - public: - // Debug event listener interface. - // Event methods will be called from within Poll calls (so on that thread). - // - // When the server posts an event it will mark the client as unready and - // suspend execution of all invocations until MakeReady is used to indicate - // that the client is ready for the server to resume. Each event needs a - // matching MakeReady ack. - // - // Listeners can defer acking if they need to perform additional queries or - // state changes to the server or wait for user interaction. Multiple events - // may come in while unready if there was a series of events pending on the - // server. - class Listener { - public: - virtual ~Listener() = default; - - // Signals that a context has been registered on the server. - virtual Status OnContextRegistered(const RemoteContext& context) = 0; - virtual Status OnContextUnregistered(const RemoteContext& context) = 0; - - // Signals that a module has been loaded into a context on the server. - virtual Status OnModuleLoaded(const RemoteContext& context, - const RemoteModule& module) = 0; - - // Signals that a invocation has been registered on the server. - virtual Status OnInvocationRegistered( - const RemoteInvocation& invocation) = 0; - virtual Status OnInvocationUnregistered( - const RemoteInvocation& invocation) = 0; - - // Signals that a breakpoint has been hit by a invocation on the server. - virtual Status OnBreakpointHit(const RemoteBreakpoint& breakpoint, - const RemoteInvocation& invocation) = 0; - }; - - // Connects to a remote debug service at the provided IP:port. - // The specified |listener| will receive async event notifications. - static StatusOr<std::unique_ptr<DebugClient>> Connect( - absl::string_view service_address, Listener* listener); - - virtual ~DebugClient() = default; - - // Returns true if the client is connected to a service. - // virtual bool is_connected() const = 0; - - // A list of all contexts registered with the server. - virtual absl::Span<RemoteContext* const> contexts() const = 0; - - // A list of all invocations registered with the server. - virtual absl::Span<RemoteInvocation* const> invocations() const = 0; - - // A list of all breakpoints registered with the server. - virtual absl::Span<RemoteBreakpoint* const> breakpoints() const = 0; - - // Resolves a function to a module ordinal. - // This will occur asynchronously and the |callback| will be issued on the - // polling thread. - virtual Status ResolveFunction( - std::string module_name, std::string function_name, - std::function<void(StatusOr<int> function_ordinal)> callback) = 0; - - // Gets a function body instance. - // The provided |callback| will be issued on the polling thread when the - // function is available. - virtual Status GetFunction( - std::string module_name, int function_ordinal, - std::function<void(StatusOr<RemoteFunction*> function)> callback) = 0; - Status GetFunction( - std::string module_name, std::string function_name, - std::function<void(StatusOr<RemoteFunction*> function)> callback); - - // Adds a breakpoint for the given module:function:offset. - // The breakpoint will apply to all contexts with the module loaded. - virtual Status AddFunctionBreakpoint( - std::string module_name, std::string function_name, int offset, - std::function<void(const RemoteBreakpoint& breakpoint)> callback = - nullptr) = 0; - - // Removes a breakpoint from the server. - virtual Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) = 0; - - // Notifies the server that the debug session is ready to continue. - // This must be called once on connection to and in acknowledgement to any - // events posted by the server (read: any call to the Listener::On* methods). - virtual Status MakeReady() = 0; - - // Suspends all invocations running on the server. - virtual Status SuspendAllInvocations() = 0; - - // Resumes all invocations running on the server. - virtual Status ResumeAllInvocations() = 0; - - // Suspends a list of invocations running on the server. Invocations not in - // the provided list will not be suspended, such as new invocations created - // while the request is pending. - virtual Status SuspendInvocations( - absl::Span<RemoteInvocation*> invocations) = 0; - - // Resumes a list of invocations running on the server. - virtual Status ResumeInvocations( - absl::Span<RemoteInvocation*> invocations) = 0; - - // Steps a invocation one bytecode operation. - virtual Status StepInvocation(const RemoteInvocation& invocation, - std::function<void()> callback) = 0; - // Steps a invocation over one bytecode operation, not stopping until it - // completes. - Status StepInvocationOver(const RemoteInvocation& invocation, - std::function<void()> callback); - // Steps a invocation out of the current block. - Status StepInvocationOut(const RemoteInvocation& invocation, - std::function<void()> callback); - // Steps a invocation to a specific bytecode offset within the current - // function. - virtual Status StepInvocationToOffset(const RemoteInvocation& invocation, - int bytecode_offset, - std::function<void()> callback) = 0; - - // TODO(benvanik): profiling modes. - - // Polls for the current state of the debug service and processes incoming - // responses. Must be called as frequently as the UI is desired to update. - // Returns CancelledError when the service is being shutdown/disconnected. - // - // Events on the Listener will be called from within this method. - virtual Status Poll() = 0; -}; - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_DEBUG_CLIENT_H_
diff --git a/iree/rt/debug/debug_client_tcp.cc b/iree/rt/debug/debug_client_tcp.cc deleted file mode 100644 index 7b00410..0000000 --- a/iree/rt/debug/debug_client_tcp.cc +++ /dev/null
@@ -1,1127 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netdb.h> -#include <netinet/in.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <cstring> -#include <queue> - -#include "absl/container/flat_hash_map.h" -#include "absl/memory/memory.h" -#include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" -#include "absl/types/span.h" -#include "flatbuffers/base.h" -#include "flatbuffers/flatbuffers.h" -#include "iree/base/flatbuffer_util.h" -#include "iree/base/status.h" -#include "iree/rt/debug/debug_client.h" -#include "iree/rt/debug/debug_tcp_util.h" -#include "iree/rt/module.h" -#include "iree/schemas/debug_service_generated.h" -#include "iree/schemas/module_def_generated.h" - -namespace iree { -namespace rt { -namespace debug { -namespace { - -using ::flatbuffers::FlatBufferBuilder; - -// Parses a host:port address, with support for the RFC 3986 IPv6 [host]:port -// format. Returns a pair of (hostname, port), with port being 0 if none was -// specified. -// -// Parses: -// foo (port 0) / foo:123 -// 1.2.3.4 (port 0) / 1.2.3.4:123 -// [foo] (port 0) / [foo]:123 -// [::1] (port 0) / [::1]:123 -StatusOr<std::pair<std::string, int>> ParseAddress(absl::string_view address) { - address = absl::StripAsciiWhitespace(address); - absl::string_view hostname; - absl::string_view port_str; - size_t bracket_loc = address.find_last_of(']'); - if (bracket_loc != std::string::npos) { - // Has at least a ]. Let's assume it's mostly right. - if (address.find('[') != 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Mismatched brackets in address: " << address; - } - hostname = address.substr(1, bracket_loc - 1); - port_str = address.substr(bracket_loc + 1); - if (port_str.find(':') == 0) { - port_str.remove_prefix(1); - } - } else { - size_t colon_loc = address.find_last_of(':'); - if (colon_loc != std::string::npos) { - hostname = address.substr(0, colon_loc); - port_str = address.substr(colon_loc + 1); - } else { - hostname = address; - port_str = ""; - } - } - int port = 0; - if (!port_str.empty() && !absl::SimpleAtoi(port_str, &port)) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unable to parse port '" << port_str << "' from " << address; - } - return std::make_pair(std::string(hostname), port); -} - -class TcpDebugClient final : public DebugClient { - public: - class TcpRemoteBreakpoint : public RemoteBreakpoint { - public: - TcpRemoteBreakpoint(int id, Type type, TcpDebugClient* client) - : RemoteBreakpoint(id, type) {} - - const std::string& module_name() const override { return def_.module_name; } - const std::string& function_name() const override { - return def_.function_name; - } - int function_ordinal() const override { return def_.function_ordinal; } - int bytecode_offset() const override { return def_.bytecode_offset; } - - Status MergeFrom(const rpc::BreakpointDef& breakpoint_def) { - breakpoint_def.UnPackTo(&def_); - return OkStatus(); - } - - private: - rpc::BreakpointDefT def_; - }; - - class TcpRemoteFunction final : public RemoteFunction { - public: - TcpRemoteFunction(RemoteModule* module, int function_ordinal, - const FunctionDef* function_def, TcpDebugClient* client) - : RemoteFunction(module, function_ordinal), - def_(function_def), - client_(client) { - name_ = def_->name() ? std::string(WrapString(def_->name())) : ""; - } - - const std::string& name() const override { return name_; } - - const FunctionDef& def() override { return *def_; } - - bool is_loaded() const override { - return contents_.flatbuffers_buffer.size() > 0; - } - - bool CheckLoadedOrRequest() override { - if (!is_loaded()) { - DemandContents(); - } - return is_loaded(); - } - - void WhenLoaded(LoadCallback callback) override { - if (is_loaded()) { - callback(this); - return; - } - load_callbacks_.push_back(std::move(callback)); - } - - const BytecodeDef* bytecode() override { - CHECK(is_loaded()); - return contents_.bytecode_def; - } - - private: - void DemandContents() { - if (!has_requested_contents_) { - VLOG(2) << "Client " << client_->fd() << ": GetFunction(" - << module()->context_id() << ", " << module()->name() << ", " - << ordinal() << ")"; - FlatBufferBuilder fbb; - rpc::GetFunctionRequestT request; - request.session_id = client_->session_id(); - request.context_id = module()->context_id(); - request.module_name = module()->name(); - request.function_ordinal = ordinal(); - auto status = - client_->IssueRequest<rpc::GetFunctionRequest, - rpc::ResponseUnion::GetFunctionResponse>( - rpc::GetFunctionRequest::Pack(fbb, &request), std::move(fbb), - [this](Status status, - const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - const auto& response = - *response_union.message_as_GetFunctionResponse(); - VLOG(2) << "Client " << client_->fd() << ": GetFunction(" - << module()->context_id() << ", " << module()->name() - << ", " << ordinal() << ") = ..."; - RETURN_IF_ERROR(MergeFrom(response)); - for (auto& callback : load_callbacks_) { - callback(this); - } - load_callbacks_.clear(); - return OkStatus(); - }); - if (!status.ok()) { - LOG(ERROR) << "Failed to request module: " << status; - return; - } - has_requested_contents_ = true; - } - } - - Status MergeFrom(const rpc::GetFunctionResponse& response) { - // Clone and retain the contents. - // TODO(benvanik): find a way to steal to avoid the reserialization. - BytecodeDefT bytecode_def_storage; - response.bytecode()->UnPackTo(&bytecode_def_storage); - ::flatbuffers::FlatBufferBuilder fbb; - fbb.Finish(response.bytecode()->Pack(fbb, &bytecode_def_storage)); - contents_.flatbuffers_buffer = fbb.Release(); - contents_.bytecode_def = ::flatbuffers::GetRoot<BytecodeDef>( - contents_.flatbuffers_buffer.data()); - return OkStatus(); - } - - const FunctionDef* def_; - TcpDebugClient* client_; - std::string name_; - bool has_requested_contents_ = false; - std::vector<LoadCallback> load_callbacks_; - struct { - ::flatbuffers::DetachedBuffer flatbuffers_buffer; - const BytecodeDef* bytecode_def = nullptr; - } contents_; - }; - - class TcpRemoteModule final : public RemoteModule { - public: - TcpRemoteModule(int context_id, std::string module_name, - TcpDebugClient* client) - : RemoteModule(context_id, std::move(module_name)), client_(client) {} - - const ModuleDef& def() override { - CHECK(is_loaded()); - return *module_file_->root(); - } - - bool is_loaded() const override { return module_file_ != nullptr; } - - bool CheckLoadedOrRequest() override { - if (!is_loaded()) { - DemandModuleDef(); - } - return is_loaded(); - } - - void WhenLoaded(LoadCallback callback) override { - if (is_loaded()) { - callback(this); - return; - } - load_callbacks_.push_back(std::move(callback)); - } - - absl::Span<RemoteFunction*> functions() override { - auto* module_def = DemandModuleDef(); - if (!module_def) return {}; - return {reinterpret_cast<RemoteFunction**>(functions_.data()), - functions_.size()}; - } - - private: - const ModuleDef* DemandModuleDef() { - if (module_file_) { - return module_file_->root(); - } - if (!has_requested_module_def_) { - VLOG(2) << "Client " << client_->fd() << ": GetModule(" << context_id() - << ", " << name() << ")"; - FlatBufferBuilder fbb; - rpc::GetModuleRequestT request; - request.session_id = client_->session_id(); - request.context_id = context_id(); - request.module_name = name(); - auto status = - client_->IssueRequest<rpc::GetModuleRequest, - rpc::ResponseUnion::GetModuleResponse>( - rpc::GetModuleRequest::Pack(fbb, &request), std::move(fbb), - [this](Status status, - const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - const auto& response = - *response_union.message_as_GetModuleResponse(); - VLOG(2) << "Client " << client_->fd() << ": GetModule(" - << context_id() << ", " << name() << ") = ..."; - RETURN_IF_ERROR(MergeFrom(response)); - for (auto& callback : load_callbacks_) { - callback(this); - } - load_callbacks_.clear(); - return OkStatus(); - }); - if (!status.ok()) { - LOG(ERROR) << "Failed to request module: " << status; - return nullptr; - } - has_requested_module_def_ = true; - } - return nullptr; - } - - Status MergeFrom(const rpc::GetModuleResponse& response) { - // Clone and retain the module. - // TODO(benvanik): find a way to steal to avoid the reserialization. - ModuleDefT module_def_storage; - response.module_()->UnPackTo(&module_def_storage); - FlatBufferBuilder fbb; - auto module_offs = response.module_()->Pack(fbb, &module_def_storage); - FinishModuleDefBuffer(fbb, module_offs); - ASSIGN_OR_RETURN(auto module_file, - ModuleFile::CreateWithBackingBuffer(fbb.Release())); - - const auto& module_def = module_file->root(); - const auto& function_table = *module_def->function_table(); - functions_.reserve(function_table.functions()->size()); - for (int i = 0; i < function_table.functions()->size(); ++i) { - const auto* function_def = function_table.functions()->Get(i); - functions_.push_back(absl::make_unique<TcpRemoteFunction>( - this, i, function_def, client_)); - } - - module_file_ = std::move(module_file); - return OkStatus(); - } - - TcpDebugClient* client_; - bool has_requested_module_def_ = false; - std::vector<LoadCallback> load_callbacks_; - std::unique_ptr<ModuleFile> module_file_; - std::vector<std::unique_ptr<RemoteFunction>> functions_; - }; - - class TcpRemoteContext final : public RemoteContext { - public: - TcpRemoteContext(int context_id, TcpDebugClient* client) - : RemoteContext(context_id), client_(client) {} - - absl::Span<RemoteModule* const> modules() const override { - return absl::MakeConstSpan(modules_); - } - - Status AddModule(std::unique_ptr<TcpRemoteModule> module) { - modules_.push_back(module.get()); - module_map_.insert({module->name(), std::move(module)}); - return OkStatus(); - } - - Status MergeFrom(const rpc::ContextDef& context_def) { return OkStatus(); } - - private: - TcpDebugClient* client_; - std::vector<RemoteModule*> modules_; - absl::flat_hash_map<std::string, std::unique_ptr<TcpRemoteModule>> - module_map_; - }; - - class TcpRemoteInvocation final : public RemoteInvocation { - public: - TcpRemoteInvocation(int invocation_id, TcpDebugClient* client) - : RemoteInvocation(invocation_id), client_(client) {} - - const rpc::InvocationDefT& def() const override { return def_; } - - Status MergeFrom(const rpc::InvocationDef& invocation_def) { - invocation_def.UnPackTo(&def_); - return OkStatus(); - } - - private: - TcpDebugClient* client_; - rpc::InvocationDefT def_; - }; - - static StatusOr<std::unique_ptr<TcpDebugClient>> Create(int fd, - Listener* listener) { - VLOG(2) << "Client " << fd << ": Setting up socket options..."; - // Disable Nagel's algorithm to ensure we have low latency. - RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(fd, false)); - // Enable keepalive assuming the client is local and this high freq is ok. - RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(fd, true)); - // Linger around for a bit to flush all data. - RETURN_IF_ERROR(tcp::ToggleSocketLinger(fd, true)); - // Disable blocking as we are poll based. - RETURN_IF_ERROR(tcp::ToggleSocketBlocking(fd, false)); - - auto client = absl::make_unique<TcpDebugClient>(fd, listener); - RETURN_IF_ERROR(client->Refresh()); - return client; - } - - TcpDebugClient(int fd, Listener* listener) : fd_(fd), listener_(listener) {} - - ~TcpDebugClient() override { - VLOG(2) << "Client " << fd_ << ": Shutting down session socket..."; - ::shutdown(fd_, SHUT_WR); - VLOG(2) << "Client " << fd_ << ": Closing session socket..."; - ::close(fd_); - VLOG(2) << "Client " << fd_ << ": Closed session socket!"; - fd_ = -1; - } - - int fd() const { return fd_; } - int session_id() const { return session_id_; } - - absl::Span<RemoteContext* const> contexts() const override { - return absl::MakeConstSpan(contexts_); - } - - absl::Span<RemoteInvocation* const> invocations() const override { - return absl::MakeConstSpan(invocations_); - } - - absl::Span<RemoteBreakpoint* const> breakpoints() const override { - return absl::MakeConstSpan(breakpoints_); - } - - // Writes the given typed request message to the given fd by wrapping it in - // a size-prefixed rpc::Request union. - // - // Example: - // FlatBufferBuilder fbb; - // rpc::SuspendInvocationRequestBuilder request(fbb); - // RETURN_IF_ERROR(WriteRequest(fd_, request.Finish(), std::move(fbb))); - template <typename T> - Status WriteRequest(int fd, ::flatbuffers::Offset<T> request_offs, - FlatBufferBuilder fbb) { - rpc::RequestBuilder request_builder(fbb); - request_builder.add_message_type(rpc::RequestUnionTraits<T>::enum_value); - request_builder.add_message(request_offs.Union()); - fbb.FinishSizePrefixed(request_builder.Finish()); - auto write_status = tcp::WriteBuffer(fd, fbb.Release()); - if (shutdown_pending_ && IsUnavailable(write_status)) { - return OkStatus(); - } - return write_status; - } - - Status ResolveFunction( - std::string module_name, std::string function_name, - std::function<void(StatusOr<int> function_ordinal)> callback) override { - VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name << ", " - << function_name << ")"; - FlatBufferBuilder fbb; - rpc::ResolveFunctionRequestT request; - request.session_id = session_id_; - request.module_name = module_name; - request.function_name = function_name; - return IssueRequest<rpc::ResolveFunctionRequest, - rpc::ResponseUnion::ResolveFunctionResponse>( - rpc::ResolveFunctionRequest::Pack(fbb, &request), std::move(fbb), - [this, module_name, function_name, callback]( - Status status, const rpc::Response& response_union) -> Status { - if (status.ok()) { - const auto& response = - *response_union.message_as_ResolveFunctionResponse(); - VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name - << ", " << function_name - << ") = " << response.function_ordinal(); - callback(response.function_ordinal()); - } else { - callback(std::move(status)); - } - return OkStatus(); - }); - } - - Status GetFunction(std::string module_name, int function_ordinal, - std::function<void(StatusOr<RemoteFunction*> function)> - callback) override { - // See if we have the module already. If not, we'll fetch it first. - RemoteModule* target_module = nullptr; - for (auto* context : contexts_) { - for (auto* module : context->modules()) { - if (module->name() == module_name) { - target_module = module; - break; - } - } - if (target_module) break; - } - if (!target_module) { - // TODO(benvanik): fetch contexts first. - return UnimplementedErrorBuilder(IREE_LOC) - << "Demand fetch contexts not yet implemented"; - } - // Found at least one module with the right name. - if (target_module->is_loaded()) { - callback(target_module->functions()[function_ordinal]); - return OkStatus(); - } else { - // Wait until the module completes loading. - target_module->WhenLoaded( - [callback, function_ordinal](StatusOr<RemoteModule*> module_or) { - if (!module_or.ok()) { - callback(module_or.status()); - return; - } - callback(module_or.ValueOrDie()->functions()[function_ordinal]); - }); - return OkStatus(); - } - } - - Status AddFunctionBreakpoint( - std::string module_name, std::string function_name, int offset, - std::function<void(const RemoteBreakpoint& breakpoint)> callback) - override { - VLOG(2) << "Client " << fd_ << ": AddFunctionBreakpoint(" << module_name - << ", " << function_name << ", " << offset << ")"; - FlatBufferBuilder fbb; - - auto breakpoint = absl::make_unique<rpc::BreakpointDefT>(); - breakpoint->module_name = module_name; - breakpoint->function_name = function_name; - breakpoint->function_ordinal = -1; - breakpoint->bytecode_offset = offset; - rpc::AddBreakpointRequestT request; - request.session_id = session_id_; - request.breakpoint = std::move(breakpoint); - return IssueRequest<rpc::AddBreakpointRequest, - rpc::ResponseUnion::AddBreakpointResponse>( - rpc::AddBreakpointRequest::Pack(fbb, &request), std::move(fbb), - [this, callback](Status status, - const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - const auto& response = - *response_union.message_as_AddBreakpointResponse(); - RETURN_IF_ERROR(RegisterBreakpoint(*response.breakpoint())); - if (callback) { - ASSIGN_OR_RETURN( - auto breakpoint, - GetBreakpoint(response.breakpoint()->breakpoint_id())); - callback(*breakpoint); - } - return OkStatus(); - }); - } - - Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) override { - VLOG(2) << "Client " << fd_ << ": RemoveBreakpoint(" << breakpoint.id() - << ")"; - int breakpoint_id = breakpoint.id(); - ASSIGN_OR_RETURN(auto* breakpoint_ptr, GetBreakpoint(breakpoint_id)); - RETURN_IF_ERROR(UnregisterBreakpoint(breakpoint_ptr)); - FlatBufferBuilder fbb; - rpc::RemoveBreakpointRequestBuilder request(fbb); - request.add_session_id(session_id_); - request.add_breakpoint_id(breakpoint_id); - return IssueRequest<rpc::RemoveBreakpointRequest, - rpc::ResponseUnion::RemoveBreakpointResponse>( - request.Finish(), std::move(fbb), - [](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - // No non-error status. - return OkStatus(); - }); - } - - Status MakeReady() override { - FlatBufferBuilder fbb; - rpc::MakeReadyRequestBuilder request(fbb); - request.add_session_id(session_id_); - return IssueRequest<rpc::MakeReadyRequest, - rpc::ResponseUnion::MakeReadyResponse>( - request.Finish(), std::move(fbb), - [](Status status, const rpc::Response& response_union) { - return status; - }); - } - - Status SuspendAllInvocations() override { - VLOG(2) << "Client " << fd_ << ": SuspendAllInvocations()"; - FlatBufferBuilder fbb; - rpc::SuspendInvocationsRequestBuilder request(fbb); - request.add_session_id(session_id_); - return IssueRequest<rpc::SuspendInvocationsRequest, - rpc::ResponseUnion::SuspendInvocationsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - return RefreshInvocations(); - }); - } - - Status ResumeAllInvocations() override { - VLOG(2) << "Client " << fd_ << ": ResumeAllInvocations()"; - FlatBufferBuilder fbb; - rpc::ResumeInvocationsRequestBuilder request(fbb); - request.add_session_id(session_id_); - return IssueRequest<rpc::ResumeInvocationsRequest, - rpc::ResponseUnion::ResumeInvocationsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - return RefreshInvocations(); - }); - } - - Status SuspendInvocations( - absl::Span<RemoteInvocation*> invocations) override { - VLOG(2) << "Client " << fd_ << ": SuspendInvocations(...)"; - FlatBufferBuilder fbb; - auto invocation_ids_offs = fbb.CreateVector<int32_t>( - invocations.size(), - [&invocations](size_t i) { return invocations[i]->id(); }); - rpc::SuspendInvocationsRequestBuilder request(fbb); - request.add_session_id(session_id_); - request.add_invocation_ids(invocation_ids_offs); - return IssueRequest<rpc::SuspendInvocationsRequest, - rpc::ResponseUnion::SuspendInvocationsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - return RefreshInvocations(); - }); - } - - Status ResumeInvocations(absl::Span<RemoteInvocation*> invocations) override { - VLOG(2) << "Client " << fd_ << ": ResumeInvocations(...)"; - FlatBufferBuilder fbb; - auto invocation_ids_offs = fbb.CreateVector<int32_t>( - invocations.size(), - [&invocations](size_t i) { return invocations[i]->id(); }); - rpc::ResumeInvocationsRequestBuilder request(fbb); - request.add_session_id(session_id_); - request.add_invocation_ids(invocation_ids_offs); - return IssueRequest<rpc::ResumeInvocationsRequest, - rpc::ResponseUnion::ResumeInvocationsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - return RefreshInvocations(); - }); - } - - Status StepInvocation(const RemoteInvocation& invocation, - std::function<void()> callback) override { - int step_id = next_step_id_++; - VLOG(2) << "Client " << fd_ << ": StepInvocation(" << invocation.id() - << ") as step_id=" << step_id; - rpc::StepInvocationRequestT step_request; - step_request.step_id = step_id; - step_request.invocation_id = invocation.id(); - step_request.step_mode = rpc::StepMode::STEP_ONCE; - return StepInvocation(&step_request, std::move(callback)); - } - - Status StepInvocationToOffset(const RemoteInvocation& invocation, - int bytecode_offset, - std::function<void()> callback) override { - int step_id = next_step_id_++; - VLOG(2) << "Client " << fd_ << ": StepInvocationToOffset(" - << invocation.id() << ", " << bytecode_offset - << ") as step_id=" << step_id; - rpc::StepInvocationRequestT step_request; - step_request.step_id = step_id; - step_request.invocation_id = invocation.id(); - step_request.step_mode = rpc::StepMode::STEP_TO_OFFSET; - step_request.bytecode_offset = bytecode_offset; - return StepInvocation(&step_request, std::move(callback)); - } - - Status Poll() override { - while (true) { - // If nothing awaiting then return immediately. - if (!tcp::CanReadBuffer(fd_)) { - break; - } - - // Read the pending response and dispatch. - auto packet_buffer_or = tcp::ReadBuffer<rpc::ServicePacket>(fd_); - if (!packet_buffer_or.ok()) { - if (shutdown_pending_ && IsUnavailable(packet_buffer_or.status())) { - // This is a graceful close. - return CancelledErrorBuilder(IREE_LOC) << "Service shutdown"; - } - return packet_buffer_or.status(); - } - const auto& packet = packet_buffer_or.ValueOrDie().GetRoot(); - if (packet.response()) { - RETURN_IF_ERROR(DispatchResponse(*packet.response())); - } - if (packet.event()) { - RETURN_IF_ERROR(DispatchEvent(packet)); - } - } - return OkStatus(); - } - - using ResponseCallback = - std::function<Status(Status status, const rpc::Response& response)>; - - template <typename T, rpc::ResponseUnion response_type> - Status IssueRequest(::flatbuffers::Offset<T> request_offs, - FlatBufferBuilder fbb, ResponseCallback callback) { - RETURN_IF_ERROR(WriteRequest(fd_, request_offs, std::move(fbb))); - pending_responses_.push({response_type, std::move(callback)}); - return OkStatus(); - } - - private: - Status Refresh() { - RETURN_IF_ERROR(RefreshContexts()); - RETURN_IF_ERROR(RefreshInvocations()); - RETURN_IF_ERROR(RefreshBreakpoints()); - return OkStatus(); - } - - Status RefreshContexts() { - VLOG(2) << "Request contexts refresh..."; - FlatBufferBuilder fbb; - rpc::ListContextsRequestBuilder request(fbb); - request.add_session_id(session_id_); - return IssueRequest<rpc::ListContextsRequest, - rpc::ResponseUnion::ListContextsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - VLOG(2) << "Refreshing contexts..."; - const auto& response = - *response_union.message_as_ListContextsResponse(); - for (auto* context_def : *response.contexts()) { - auto context_or = GetContext(context_def->context_id()); - if (!context_or.ok()) { - // Not found; add new. - RETURN_IF_ERROR(RegisterContext(context_def->context_id())); - context_or = GetContext(context_def->context_id()); - } - RETURN_IF_ERROR(context_or.status()); - RETURN_IF_ERROR(context_or.ValueOrDie()->MergeFrom(*context_def)); - } - VLOG(2) << "Refreshed contexts!"; - return OkStatus(); - }); - } - - Status RefreshInvocations() { - VLOG(2) << "Request invocation states refresh..."; - FlatBufferBuilder fbb; - rpc::ListInvocationsRequestBuilder request(fbb); - request.add_session_id(session_id_); - return IssueRequest<rpc::ListInvocationsRequest, - rpc::ResponseUnion::ListInvocationsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - VLOG(2) << "Refreshing invocation states..."; - const auto& response = - *response_union.message_as_ListInvocationsResponse(); - for (auto* invocation_def : *response.invocations()) { - auto invocation_or = GetInvocation(invocation_def->invocation_id()); - if (!invocation_or.ok()) { - // Not found; add new. - RETURN_IF_ERROR( - RegisterInvocation(invocation_def->invocation_id())); - invocation_or = GetInvocation(invocation_def->invocation_id()); - } - RETURN_IF_ERROR(invocation_or.status()); - RETURN_IF_ERROR( - invocation_or.ValueOrDie()->MergeFrom(*invocation_def)); - } - // TODO(benvanik): handle removals/deaths. - VLOG(2) << "Refreshed invocation states!"; - return OkStatus(); - }); - } - - Status RefreshBreakpoints() { - VLOG(2) << "Requesting breakpoint refresh..."; - FlatBufferBuilder fbb; - rpc::ListBreakpointsRequestBuilder request(fbb); - request.add_session_id(session_id_); - return IssueRequest<rpc::ListBreakpointsRequest, - rpc::ResponseUnion::ListBreakpointsResponse>( - request.Finish(), std::move(fbb), - [this](Status status, const rpc::Response& response_union) -> Status { - if (!status.ok()) return status; - VLOG(2) << "Refreshing breakpoints..."; - const auto& response = - *response_union.message_as_ListBreakpointsResponse(); - for (auto* breakpoint_def : *response.breakpoints()) { - auto breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id()); - if (!breakpoint_or.ok()) { - // Not found; add new. - RETURN_IF_ERROR(RegisterBreakpoint(*breakpoint_def)); - breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id()); - } - RETURN_IF_ERROR(breakpoint_or.status()); - RETURN_IF_ERROR( - breakpoint_or.ValueOrDie()->MergeFrom(*breakpoint_def)); - } - // TODO(benvanik): handle removals/deaths. - VLOG(2) << "Refreshed breakpoints!"; - return OkStatus(); - }); - } - - Status DispatchResponse(const rpc::Response& response) { - if (pending_responses_.empty()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Response received but no request is pending"; - } - auto type_callback = std::move(pending_responses_.front()); - pending_responses_.pop(); - - if (response.status()) { - const auto& status = *response.status(); - Status client_status = - StatusBuilder(static_cast<StatusCode>(status.code()), IREE_LOC) - << "Server request failed: " << WrapString(status.message()); - return type_callback.second(std::move(client_status), response); - } - - if (!response.message()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Response contains no message body"; - } - - if (response.message_type() != type_callback.first) { - return DataLossErrorBuilder(IREE_LOC) - << "Out of order response (mismatch pending)"; - } - return type_callback.second(OkStatus(), response); - } - - Status DispatchEvent(const rpc::ServicePacket& packet) { - switch (packet.event_type()) { -#define DISPATCH_EVENT(event_name) \ - case rpc::EventUnion::event_name##Event: { \ - VLOG(2) << "EVENT: " << #event_name; \ - return On##event_name(*packet.event_as_##event_name##Event()); \ - } - DISPATCH_EVENT(ServiceShutdown); - DISPATCH_EVENT(ContextRegistered); - DISPATCH_EVENT(ContextUnregistered); - DISPATCH_EVENT(ModuleLoaded); - DISPATCH_EVENT(InvocationRegistered); - DISPATCH_EVENT(InvocationUnregistered); - DISPATCH_EVENT(BreakpointResolved); - DISPATCH_EVENT(BreakpointHit); - DISPATCH_EVENT(StepCompleted); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented debug service event: " - << static_cast<int>(packet.event_type()); - } - } - - StatusOr<TcpRemoteContext*> GetContext(int context_id) { - auto it = context_map_.find(context_id); - if (it == context_map_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Context was never registered"; - } - return it->second.get(); - } - - Status OnServiceShutdown(const rpc::ServiceShutdownEvent& event) { - LOG(INFO) << "Service is shutting down; setting pending shutdown flag"; - shutdown_pending_ = true; - return OkStatus(); - } - - Status RegisterContext(int context_id) { - auto context = absl::make_unique<TcpRemoteContext>(context_id, this); - VLOG(2) << "RegisterContext(" << context_id << ")"; - auto context_ptr = context.get(); - context_map_.insert({context_id, std::move(context)}); - contexts_.push_back(context_ptr); - return listener_->OnContextRegistered(*context_ptr); - } - - Status OnContextRegistered(const rpc::ContextRegisteredEvent& event) { - VLOG(2) << "OnContextRegistered(" << event.context_id() << ")"; - auto it = context_map_.find(event.context_id()); - if (it != context_map_.end()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Context already registered"; - } - return RegisterContext(event.context_id()); - } - - Status OnContextUnregistered(const rpc::ContextUnregisteredEvent& event) { - VLOG(2) << "OnContextUnregistered(" << event.context_id() << ")"; - auto it = context_map_.find(event.context_id()); - if (it == context_map_.end()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Context was never registered"; - } - auto context = std::move(it->second); - context_map_.erase(it); - auto list_it = std::find(contexts_.begin(), contexts_.end(), context.get()); - contexts_.erase(list_it); - return listener_->OnContextUnregistered(*context); - } - - Status OnModuleLoaded(const rpc::ModuleLoadedEvent& event) { - VLOG(2) << "OnModuleLoaded(" << event.context_id() << ", " - << WrapString(event.module_name()) << ")"; - ASSIGN_OR_RETURN(auto* context, GetContext(event.context_id())); - auto module_name = WrapString(event.module_name()); - auto module = absl::make_unique<TcpRemoteModule>( - event.context_id(), std::string(module_name), this); - auto* module_ptr = module.get(); - RETURN_IF_ERROR(context->AddModule(std::move(module))); - return listener_->OnModuleLoaded(*context, *module_ptr); - } - - StatusOr<TcpRemoteInvocation*> GetInvocation(int invocation_id) { - auto it = invocation_map_.find(invocation_id); - if (it == invocation_map_.end()) { - return NotFoundErrorBuilder(IREE_LOC) - << "Invocation was never registered"; - } - return it->second.get(); - } - - Status RegisterInvocation(int invocation_id) { - VLOG(2) << "RegisterInvocation(" << invocation_id << ")"; - auto invocation = - absl::make_unique<TcpRemoteInvocation>(invocation_id, this); - auto invocation_ptr = invocation.get(); - invocation_map_.insert({invocation_id, std::move(invocation)}); - invocations_.push_back(invocation_ptr); - RETURN_IF_ERROR(RefreshInvocations()); - return listener_->OnInvocationRegistered(*invocation_ptr); - } - - Status OnInvocationRegistered(const rpc::InvocationRegisteredEvent& event) { - VLOG(2) << "OnInvocationRegistered(" << event.invocation_id() << ")"; - auto it = invocation_map_.find(event.invocation_id()); - if (it != invocation_map_.end()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Invocation already registered"; - } - return RegisterInvocation(event.invocation_id()); - } - - Status OnInvocationUnregistered( - const rpc::InvocationUnregisteredEvent& event) { - VLOG(2) << "OnInvocationUnregistered(" << event.invocation_id() << ")"; - auto it = invocation_map_.find(event.invocation_id()); - if (it == invocation_map_.end()) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Invocation was never registered"; - } - auto invocation = std::move(it->second); - invocation_map_.erase(it); - auto list_it = - std::find(invocations_.begin(), invocations_.end(), invocation.get()); - invocations_.erase(list_it); - return listener_->OnInvocationUnregistered(*invocation); - } - - StatusOr<TcpRemoteBreakpoint*> GetBreakpoint(int breakpoint_id) { - auto it = breakpoint_map_.find(breakpoint_id); - if (it == breakpoint_map_.end()) { - return NotFoundErrorBuilder(IREE_LOC) - << "Breakpoint " << breakpoint_id << " was never registered"; - } - return it->second.get(); - } - - Status RegisterBreakpoint(const rpc::BreakpointDef& breakpoint_def) { - auto it = breakpoint_map_.find(breakpoint_def.breakpoint_id()); - if (it != breakpoint_map_.end()) { - VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id() - << ") (update)"; - return it->second->MergeFrom(breakpoint_def); - } - - VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id() << ")"; - auto breakpoint = absl::make_unique<TcpRemoteBreakpoint>( - breakpoint_def.breakpoint_id(), - static_cast<RemoteBreakpoint::Type>(breakpoint_def.breakpoint_type()), - this); - RETURN_IF_ERROR(breakpoint->MergeFrom(breakpoint_def)); - breakpoints_.push_back(breakpoint.get()); - breakpoint_map_.insert({breakpoint->id(), std::move(breakpoint)}); - return OkStatus(); - } - - Status UnregisterBreakpoint(RemoteBreakpoint* breakpoint) { - VLOG(2) << "UnregisterBreakpoint(" << breakpoint->id() << ")"; - auto it = breakpoint_map_.find(breakpoint->id()); - if (it == breakpoint_map_.end()) { - return NotFoundErrorBuilder(IREE_LOC) - << "Breakpoint was never registered"; - } - breakpoint_map_.erase(it); - auto list_it = - std::find(breakpoints_.begin(), breakpoints_.end(), breakpoint); - breakpoints_.erase(list_it); - return OkStatus(); - } - - Status OnBreakpointResolved(const rpc::BreakpointResolvedEvent& event) { - VLOG(2) << "OnBreakpointResolved(" << event.breakpoint()->breakpoint_id() - << ")"; - auto it = breakpoint_map_.find(event.breakpoint()->breakpoint_id()); - if (it == breakpoint_map_.end()) { - RETURN_IF_ERROR(RegisterBreakpoint(*event.breakpoint())); - } else { - RETURN_IF_ERROR(it->second->MergeFrom(*event.breakpoint())); - } - return OkStatus(); - } - - Status OnBreakpointHit(const rpc::BreakpointHitEvent& event) { - VLOG(2) << "OnBreakpointHit(" << event.breakpoint_id() << ")"; - ASSIGN_OR_RETURN(auto* breakpoint, GetBreakpoint(event.breakpoint_id())); - auto* invocation_def = event.invocation(); - auto invocation_or = GetInvocation(invocation_def->invocation_id()); - if (!invocation_or.ok()) { - // Not found; add new. - RETURN_IF_ERROR(RegisterInvocation(invocation_def->invocation_id())); - invocation_or = GetInvocation(invocation_def->invocation_id()); - } - RETURN_IF_ERROR(invocation_or.status()); - RETURN_IF_ERROR(invocation_or.ValueOrDie()->MergeFrom(*invocation_def)); - return listener_->OnBreakpointHit(*breakpoint, *invocation_or.ValueOrDie()); - } - - Status StepInvocation(rpc::StepInvocationRequestT* step_request, - std::function<void()> callback) { - FlatBufferBuilder fbb; - auto status = IssueRequest<rpc::StepInvocationRequest, - rpc::ResponseUnion::StepInvocationResponse>( - rpc::StepInvocationRequest::Pack(fbb, step_request), std::move(fbb), - [](Status status, const rpc::Response& response_union) -> Status { - return status; - }); - RETURN_IF_ERROR(status); - pending_step_callbacks_[step_request->step_id] = std::move(callback); - return OkStatus(); - } - - Status OnStepCompleted(const rpc::StepCompletedEvent& event) { - VLOG(2) << "OnStepCompleted(" << event.step_id() << ")"; - - // Update all invocation states that are contained. - // This may only be a subset of relevant states. - for (auto* invocation_def : *event.invocations()) { - ASSIGN_OR_RETURN(auto invocation, - GetInvocation(invocation_def->invocation_id())); - RETURN_IF_ERROR(invocation->MergeFrom(*invocation_def)); - } - - // Dispatch step callback. Note that it may have been cancelled and that's - // ok. We'll just make ready to resume execution. - auto it = pending_step_callbacks_.find(event.step_id()); - if (it != pending_step_callbacks_.end()) { - it->second(); - pending_step_callbacks_.erase(it); - } else { - LOG(WARNING) << "Step " << event.step_id() - << " not found; was cancelled?"; - RETURN_IF_ERROR(MakeReady()); - } - return OkStatus(); - } - - int session_id_ = 123; - - int fd_ = -1; - Listener* listener_; - bool shutdown_pending_ = false; - std::queue<std::pair<rpc::ResponseUnion, ResponseCallback>> - pending_responses_; - - std::vector<RemoteContext*> contexts_; - absl::flat_hash_map<int, std::unique_ptr<TcpRemoteContext>> context_map_; - std::vector<RemoteInvocation*> invocations_; - absl::flat_hash_map<int, std::unique_ptr<TcpRemoteInvocation>> - invocation_map_; - std::vector<RemoteBreakpoint*> breakpoints_; - absl::flat_hash_map<int, std::unique_ptr<TcpRemoteBreakpoint>> - breakpoint_map_; - - int next_step_id_ = 1; - absl::flat_hash_map<int, std::function<void()>> pending_step_callbacks_; -}; - -} // namespace - -// static -StatusOr<std::unique_ptr<DebugClient>> DebugClient::Connect( - absl::string_view service_address, Listener* listener) { - // Parse address into hostname and port. - ASSIGN_OR_RETURN(auto hostname_port, ParseAddress(service_address)); - if (hostname_port.second == 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "No port specified in service address; port must match the " - "server: " - << service_address; - } - - // Attempt to resolve the address. - // Note that if we only wanted local debugging we could remove the dep on - // getaddrinfo/having a valid DNS setup. - addrinfo hints = {0}; - hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - addrinfo* resolved_address = nullptr; - auto port_str = std::to_string(hostname_port.second); - int getaddrinfo_ret = ::getaddrinfo( - hostname_port.first.c_str(), port_str.c_str(), &hints, &resolved_address); - if (getaddrinfo_ret != 0) { - return UnavailableErrorBuilder(IREE_LOC) - << "Unable to resolve debug service address for " << service_address - << ": (" << getaddrinfo_ret << ") " - << ::gai_strerror(getaddrinfo_ret); - } - - // Attempt to connect with each address returned from the query. - int fd = -1; - for (addrinfo* rp = resolved_address; rp != nullptr; rp = rp->ai_next) { - fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (fd == -1) continue; - if (::connect(fd, rp->ai_addr, rp->ai_addrlen) == 0) { - break; // Success! - } - ::close(fd); - fd = -1; - } - ::freeaddrinfo(resolved_address); - if (fd == -1) { - return UnavailableErrorBuilder(IREE_LOC) - << "Unable to connect to " << service_address << " on any address: (" - << errno << ") " << ::strerror(errno); - } - - LOG(INFO) << "Connected to debug service at " << service_address; - - return TcpDebugClient::Create(fd, listener); -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_server.h b/iree/rt/debug/debug_server.h deleted file mode 100644 index c10e0d4..0000000 --- a/iree/rt/debug/debug_server.h +++ /dev/null
@@ -1,88 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DEBUG_DEBUG_SERVER_H_ -#define IREE_RT_DEBUG_DEBUG_SERVER_H_ - -#include "iree/base/status.h" - -namespace iree { -namespace rt { -class Context; -class Instance; -class Invocation; -class Module; -} // namespace rt -} // namespace iree - -namespace iree { -namespace rt { -namespace debug { - -// Runtime debugging server. -// Enabled only when compiled in (by defining IREE_DEBUG=1), this provides an -// RPC server that allows debuggers to attach, query, and manipulate contexts. -// This interface is used by various parts of the runtime such as dispatch to -// query the current debug state and signal events. -// -// Thread-safe. Contexts may be registered and unregistered from any thread. -class DebugServer { - public: - // Creates a new debug service listening on the provided |port|. - // Even when disabled the device can still be created however it will not - // perform any actual operations and act as if the debugger is not attached. - static StatusOr<std::unique_ptr<DebugServer>> Create(int listen_port); - - // TODO(benvanik): ensure this gets optimized out when disabled. - // Seems to be the case: https://gcc.godbolt.org/z/0zf-L4 - virtual ~DebugServer() = default; - - // Attaches a callback that will be made when the debug server is shutting - // down. This can be used to keep resources alive that require the debugger. - // The callback will be made from a random thread. - virtual void AtExit(std::function<void()> callback) = 0; - - // Blocks the caller until a client session connects and resumes all fibers. - // Returns AbortedError if a session connects/is connected but disconnects - // during the wait. - virtual Status WaitUntilSessionReady() = 0; - - protected: - friend class ::iree::rt::Instance; - - // Registers a context with the debug service. - // Ownership remains with the caller and UnregisterContext must be called - // prior to the context being destroyed. - virtual Status RegisterContext(Context* context) = 0; - virtual Status UnregisterContext(Context* context) = 0; - - friend class ::iree::rt::Context; - - // Registers a new module linked into an existing Context. - virtual Status RegisterContextModule(Context* context, Module* module) = 0; - - friend class ::iree::rt::Invocation; - - // Registers an invocation with the debug service. - // Ownership remains with the caller and UnregisterInvocation must be called - // prior to the fiber state being destroyed. - virtual Status RegisterInvocation(Invocation* invocation) = 0; - virtual Status UnregisterInvocation(Invocation* invocation) = 0; -}; - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_DEBUG_SERVER_H_
diff --git a/iree/rt/debug/debug_server_disabled.cc b/iree/rt/debug/debug_server_disabled.cc deleted file mode 100644 index e9b9bd8..0000000 --- a/iree/rt/debug/debug_server_disabled.cc +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/debug/debug_server.h" - -namespace iree { -namespace rt { -namespace debug { - -// static -StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) { - return {nullptr}; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_server_flags.cc b/iree/rt/debug/debug_server_flags.cc deleted file mode 100644 index a7814c0..0000000 --- a/iree/rt/debug/debug_server_flags.cc +++ /dev/null
@@ -1,80 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/debug/debug_server_flags.h" - -#include "absl/flags/flag.h" -#include "absl/strings/str_cat.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" - -#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT) -#include "iree/tools/debugger/debug_app_embedded.h" -#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT - -ABSL_FLAG(int32_t, iree_debug_service_port, 6000, - "TCP port to listen for debug service connections."); -ABSL_FLAG(bool, iree_wait_for_debugger, false, - "Waits until a debugger connects to continue startup."); -ABSL_FLAG(bool, iree_attach_debugger, false, "Attaches a debugger at startup."); - -namespace iree { -namespace rt { -namespace debug { - -StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags() { - // Create the server based on whatever version is compiled in. - // Note that this will return nullptr if no server is available. - ASSIGN_OR_RETURN( - auto debug_server, - DebugServer::Create(absl::GetFlag(FLAGS_iree_debug_service_port))); - if (!debug_server) { - return nullptr; - } - -#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT) - // If the embedded debug UI is present then we can launch that now. - std::unique_ptr<EmbeddedDebugger> debugger; - if (absl::GetFlag(FLAGS_iree_attach_debugger)) { - LOG(INFO) << "Attaching debugger at startup..."; - ASSIGN_OR_RETURN( - debugger, - AttachDebugger(absl::StrCat( - "localhost:", absl::GetFlag(FLAGS_iree_debug_service_port)))); - RETURN_IF_ERROR(debug_server->WaitUntilSessionReady()); - LOG(INFO) << "Debugger attached"; - // TODO(benvanik): C++14 to avoid this. - auto debugger_baton = IreeMoveToLambda(debugger); - debug_server->AtExit([debugger_baton]() { debugger_baton.value.reset(); }); - } -#else - if (absl::GetFlag(FLAGS_iree_attach_debugger)) { - LOG(WARNING) << "--iree_attach_debugger specified but no embedded debugger " - "is present. Build with --define=IREE_DEBUG=1."; - } -#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT - - // Wait for a debugger to connect. - if (absl::GetFlag(FLAGS_iree_wait_for_debugger)) { - LOG(INFO) << "Waiting for a debugger to connect..."; - RETURN_IF_ERROR(debug_server->WaitUntilSessionReady()); - LOG(INFO) << "Debugger ready, resuming..."; - } - - return std::move(debug_server); -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_server_flags.h b/iree/rt/debug/debug_server_flags.h deleted file mode 100644 index 21ed287..0000000 --- a/iree/rt/debug/debug_server_flags.h +++ /dev/null
@@ -1,33 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_ -#define IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_ - -#include "iree/base/status.h" -#include "iree/rt/debug/debug_server.h" - -namespace iree { -namespace rt { -namespace debug { - -// Creates a debug server based on the current --iree_* debug flags. -// Returns nullptr if no server is compiled in or the flags are not set. -StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags(); - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
diff --git a/iree/rt/debug/debug_server_tcp.cc b/iree/rt/debug/debug_server_tcp.cc deleted file mode 100644 index 14240e7..0000000 --- a/iree/rt/debug/debug_server_tcp.cc +++ /dev/null
@@ -1,459 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <sys/socket.h> -#include <sys/types.h> -#include <unistd.h> - -#include <algorithm> -#include <cerrno> -#include <exception> -#include <thread> // NOLINT - -#include "absl/base/thread_annotations.h" -#include "absl/memory/memory.h" -#include "absl/synchronization/mutex.h" -#include "flatbuffers/flatbuffers.h" -#include "iree/base/status.h" -#include "iree/rt/debug/debug_server.h" -#include "iree/rt/debug/debug_service.h" -#include "iree/rt/debug/debug_tcp_util.h" -#include "iree/schemas/debug_service_generated.h" - -namespace iree { -namespace rt { -namespace debug { -namespace { - -// Writes the given typed response message to the given fd by wrapping it in -// a size-prefixed rpc::Request union. -// -// Example: -// ::flatbuffers::FlatBufferBuilder fbb; -// rpc::SuspendInvocationResponseBuilder response(fbb); -// RETURN_IF_ERROR(WriteResponse(fd_, response.Finish(), std::move(fbb))); -template <typename T> -Status WriteResponse(int fd, ::flatbuffers::Offset<T> message_offs, - ::flatbuffers::FlatBufferBuilder fbb) { - rpc::ResponseBuilder response_builder(fbb); - response_builder.add_message_type(rpc::ResponseUnionTraits<T>::enum_value); - response_builder.add_message(message_offs.Union()); - auto response_offs = response_builder.Finish(); - rpc::ServicePacketBuilder packet_builder(fbb); - packet_builder.add_response(response_offs); - fbb.FinishSizePrefixed(packet_builder.Finish()); - return tcp::WriteBuffer(fd, fbb.Release()); -} - -class TcpDebugSession : public DebugSession { - public: - using ClosedCallback = - std::function<void(TcpDebugSession* session, Status status)>; - - static StatusOr<std::unique_ptr<TcpDebugSession>> Accept( - DebugService* debug_service, int client_fd, - ClosedCallback closed_callback) { - VLOG(2) << "Client " << client_fd << ": Setting up socket options..."; - // Disable Nagel's algorithm to ensure we have low latency. - RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(client_fd, false)); - // Enable keepalive assuming the client is local and this high freq is ok. - RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(client_fd, true)); - // Linger around for a bit to flush all data. - RETURN_IF_ERROR(tcp::ToggleSocketLinger(client_fd, true)); - - return absl::make_unique<TcpDebugSession>(debug_service, client_fd, - std::move(closed_callback)); - } - - TcpDebugSession(DebugService* debug_service, int client_fd, - ClosedCallback closed_callback) - : debug_service_(debug_service), - client_fd_(client_fd), - closed_callback_(std::move(closed_callback)) { - CHECK_OK(debug_service_->RegisterDebugSession(this)); - session_thread_ = std::thread([this]() { SessionThread(); }); - } - - ~TcpDebugSession() override { - CHECK_OK(debug_service_->UnregisterDebugSession(this)); - VLOG(2) << "Client " << client_fd_ << ": Shutting down session socket..."; - ::shutdown(client_fd_, SHUT_RD); - if (session_thread_.joinable() && - session_thread_.get_id() != std::this_thread::get_id()) { - VLOG(2) << "Client " << client_fd_ << ": Joining socket thread..."; - session_thread_.join(); - VLOG(2) << "Client " << client_fd_ << ": Joined socket thread!"; - } else { - VLOG(2) << "Client " << client_fd_ << ": Detaching socket thread..."; - session_thread_.detach(); - } - VLOG(2) << "Client " << client_fd_ << ": Closing session socket..."; - ::close(client_fd_); - VLOG(2) << "Client " << client_fd_ << ": Closed session socket!"; - client_fd_ = -1; - } - - Status OnServiceShutdown() { - VLOG(2) << "Client " << client_fd_ << ": Post OnServiceShutdown()"; - ::flatbuffers::FlatBufferBuilder fbb; - rpc::ServiceShutdownEventBuilder event(fbb); - return PostEvent(event.Finish(), std::move(fbb)); - } - - Status OnContextRegistered(Context* context) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnContextRegistered(" - << context->id() << ")"; - ::flatbuffers::FlatBufferBuilder fbb; - rpc::ContextRegisteredEventBuilder event(fbb); - event.add_context_id(context->id()); - return PostEvent(event.Finish(), std::move(fbb)); - } - Status OnContextUnregistered(Context* context) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnContextUnregistered(" - << context->id() << ")"; - ::flatbuffers::FlatBufferBuilder fbb; - rpc::ContextUnregisteredEventBuilder event(fbb); - event.add_context_id(context->id()); - return PostEvent(event.Finish(), std::move(fbb)); - } - - Status OnModuleLoaded(Context* context, Module* module) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnModuleLoaded(" - << context->id() << ", " << module->name() << ")"; - ::flatbuffers::FlatBufferBuilder fbb; - auto module_name_offs = - fbb.CreateString(module->name().data(), module->name().size()); - rpc::ModuleLoadedEventBuilder event(fbb); - event.add_context_id(context->id()); - event.add_module_name(module_name_offs); - return PostEvent(event.Finish(), std::move(fbb)); - } - - Status OnInvocationRegistered(Invocation* invocation) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationRegistered(" - << invocation->id() << ")"; - ::flatbuffers::FlatBufferBuilder fbb; - rpc::InvocationRegisteredEventBuilder event(fbb); - event.add_invocation_id(invocation->id()); - return PostEvent(event.Finish(), std::move(fbb)); - } - Status OnInvocationUnregistered(Invocation* invocation) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationUnregistered(" - << invocation->id() << ")"; - ::flatbuffers::FlatBufferBuilder fbb; - rpc::InvocationUnregisteredEventBuilder event(fbb); - event.add_invocation_id(invocation->id()); - return PostEvent(event.Finish(), std::move(fbb)); - } - - Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint, - Context* context) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointResolved(" - << breakpoint.breakpoint_id << ", " << context->id() << ", " - << breakpoint.function_ordinal << ")"; - rpc::BreakpointResolvedEventT event; - event.breakpoint = absl::make_unique<rpc::BreakpointDefT>(); - *event.breakpoint = breakpoint; - event.context_id = context->id(); - ::flatbuffers::FlatBufferBuilder fbb; - return PostEvent(rpc::BreakpointResolvedEvent::Pack(fbb, &event), - std::move(fbb)); - } - - Status OnBreakpointHit(int breakpoint_id, - const Invocation& invocation) override { - VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointHit(" - << breakpoint_id << ", " << invocation.id() << ")"; - ::flatbuffers::FlatBufferBuilder fbb; - ASSIGN_OR_RETURN(auto invocation_offs, - debug_service_->SerializeInvocation(invocation, &fbb)); - rpc::BreakpointHitEventBuilder event(fbb); - event.add_breakpoint_id(breakpoint_id); - event.add_invocation(invocation_offs); - return PostEvent(event.Finish(), std::move(fbb)); - } - - private: - void SessionThread() { - VLOG(2) << "Client " << client_fd_ << ": Thread entry"; - Status session_status = OkStatus(); - while (session_status.ok()) { - auto buffer_or = tcp::ReadBuffer<rpc::Request>(client_fd_); - if (!buffer_or.ok()) { - if (IsCancelled(buffer_or.status())) { - // Graceful shutdown. - VLOG(2) << "Client " << client_fd_ << ": Graceful shutdown requested"; - break; - } - // Error reading. - session_status = std::move(buffer_or).status(); - LOG(ERROR) << "Client " << client_fd_ - << ": Error reading request buffer: " << session_status; - break; - } - auto request_buffer = std::move(buffer_or).ValueOrDie(); - session_status = DispatchRequest(request_buffer.GetRoot()); - if (!session_status.ok()) { - LOG(ERROR) << "Client " << client_fd_ - << ": Error dispatching request: " << session_status; - break; - } - } - VLOG(2) << "Client " << client_fd_ << ": Thread exit"; - AbortSession(session_status); - } - - void AbortSession(Status status) { - if (status.ok()) { - VLOG(2) << "Debug client disconnected"; - } else { - LOG(ERROR) << "Debug session aborted; " << status; - ::flatbuffers::FlatBufferBuilder fbb; - auto message_offs = - fbb.CreateString(status.message().data(), status.message().size()); - rpc::StatusBuilder status_builder(fbb); - status_builder.add_code(static_cast<int>(status.code())); - status_builder.add_message(message_offs); - auto status_offs = status_builder.Finish(); - rpc::ResponseBuilder response(fbb); - response.add_status(status_offs); - fbb.FinishSizePrefixed(response.Finish()); - tcp::WriteBuffer(client_fd_, fbb.Release()).IgnoreError(); - } - closed_callback_(this, std::move(status)); - } - - template <typename T> - Status PostEvent(::flatbuffers::Offset<T> event_offs, - ::flatbuffers::FlatBufferBuilder fbb) { - rpc::ServicePacketBuilder packet_builder(fbb); - packet_builder.add_event_type(rpc::EventUnionTraits<T>::enum_value); - packet_builder.add_event(event_offs.Union()); - fbb.FinishSizePrefixed(packet_builder.Finish()); - return tcp::WriteBuffer(client_fd_, fbb.Release()); - } - - Status DispatchRequest(const rpc::Request& request) { - ::flatbuffers::FlatBufferBuilder fbb; - switch (request.message_type()) { -#define DISPATCH_REQUEST(method_name) \ - case rpc::RequestUnion::method_name##Request: { \ - VLOG(2) << "Client " << client_fd_ \ - << ": DispatchRequest(" #method_name ")..."; \ - ASSIGN_OR_RETURN(auto response_offs, \ - debug_service_->method_name( \ - *request.message_as_##method_name##Request(), &fbb)); \ - return WriteResponse(client_fd_, response_offs, std::move(fbb)); \ - } - DISPATCH_REQUEST(MakeReady); - DISPATCH_REQUEST(GetStatus); - DISPATCH_REQUEST(ListContexts); - DISPATCH_REQUEST(GetModule); - DISPATCH_REQUEST(GetFunction); - DISPATCH_REQUEST(ListInvocations); - DISPATCH_REQUEST(SuspendInvocations); - DISPATCH_REQUEST(ResumeInvocations); - DISPATCH_REQUEST(StepInvocation); - DISPATCH_REQUEST(GetInvocationLocal); - DISPATCH_REQUEST(SetInvocationLocal); - DISPATCH_REQUEST(ListBreakpoints); - DISPATCH_REQUEST(AddBreakpoint); - DISPATCH_REQUEST(RemoveBreakpoint); - DISPATCH_REQUEST(StartProfiling); - DISPATCH_REQUEST(StopProfiling); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented debug service request: " - << static_cast<int>(request.message_type()); - } - } - - DebugService* debug_service_; - int client_fd_; - ClosedCallback closed_callback_; - std::thread session_thread_; -}; - -class TcpDebugServer final : public DebugServer { - public: - static StatusOr<std::unique_ptr<TcpDebugServer>> Listen(int port) { - // We support both IPv4 and IPv6 by using the IN6ADDR_ANY. This requires - // that we setup the socket as INET6 and enable reuse (so the same port can - // be bound for both IPv4 and IPv6). - int listen_fd = ::socket(AF_INET6, SOCK_STREAM, 0); - RETURN_IF_ERROR(tcp::ToggleSocketAddressReuse(listen_fd, true)); - - struct sockaddr_in6 socket_addr = {0}; - socket_addr.sin6_family = AF_INET6; - socket_addr.sin6_port = htons(port); - socket_addr.sin6_addr = in6addr_any; - if (::bind(listen_fd, reinterpret_cast<struct sockaddr*>(&socket_addr), - sizeof(socket_addr)) < 0) { - return AlreadyExistsErrorBuilder(IREE_LOC) - << "Unable to bind socket to port " << port << ": (" << errno - << ") " << ::strerror(errno); - } - if (::listen(listen_fd, 1)) { - ::close(listen_fd); - return AlreadyExistsErrorBuilder(IREE_LOC) - << "Unable to listen on port " << port << ": (" << errno << ") " - << ::strerror(errno); - } - return absl::make_unique<TcpDebugServer>(listen_fd); - } - - TcpDebugServer(int listen_fd) : listen_fd_(listen_fd) { - server_thread_ = std::thread([this]() { ListenThread(); }); - } - - ~TcpDebugServer() ABSL_LOCKS_EXCLUDED(mutex_) override { - absl::ReleasableMutexLock lock(&mutex_); - LOG(INFO) << "Shutting down debug server..."; - - // Notify all sessions. - for (auto& session : sessions_) { - session->OnServiceShutdown().IgnoreError(); - } - - // Shut down listen socket first so that we can't accept new connections. - VLOG(2) << "Shutting down listen socket..."; - ::shutdown(listen_fd_, SHUT_RDWR); - if (server_thread_.joinable()) { - VLOG(2) << "Joining listen thread..."; - server_thread_.join(); - VLOG(2) << "Joined listen thread!"; - } - VLOG(2) << "Closing listen socket..."; - ::close(listen_fd_); - listen_fd_ = -1; - VLOG(2) << "Closed listen socket!"; - - // Kill all active sessions. Note that we must do this outside of our lock. - std::vector<std::unique_ptr<TcpDebugSession>> sessions = - std::move(sessions_); - std::vector<std::function<void()>> at_exit_callbacks = - std::move(at_exit_callbacks_); - lock.Release(); - VLOG(2) << "Clearing live sessions..."; - sessions.clear(); - VLOG(2) << "Calling AtExit callbacks..."; - for (auto& callback : at_exit_callbacks) { - callback(); - } - LOG(INFO) << "Debug server shutdown!"; - } - - DebugService* debug_service() { return &debug_service_; } - - Status AcceptNewSession(int client_fd) { - LOG(INFO) << "Accepting new client session as " << client_fd; - ASSIGN_OR_RETURN(auto session, - TcpDebugSession::Accept( - &debug_service_, client_fd, - [this](TcpDebugSession* session, Status status) { - absl::MutexLock lock(&mutex_); - for (auto it = sessions_.begin(); - it != sessions_.end(); ++it) { - if (it->get() == session) { - sessions_.erase(it); - break; - } - } - return OkStatus(); - })); - - absl::MutexLock lock(&mutex_); - sessions_.push_back(std::move(session)); - return OkStatus(); - } - - void AtExit(std::function<void()> callback) override { - absl::MutexLock lock(&mutex_); - at_exit_callbacks_.push_back(std::move(callback)); - } - - Status WaitUntilSessionReady() override { - return debug_service_.WaitUntilAllSessionsReady(); - } - - protected: - Status RegisterContext(Context* context) override { - return debug_service_.RegisterContext(context); - } - Status UnregisterContext(Context* context) override { - return debug_service_.UnregisterContext(context); - } - Status RegisterContextModule(Context* context, Module* module) override { - return debug_service_.RegisterContextModule(context, module); - } - Status RegisterInvocation(Invocation* invocation) override { - return debug_service_.RegisterInvocation(invocation); - } - Status UnregisterInvocation(Invocation* invocation) override { - return debug_service_.UnregisterInvocation(invocation); - } - - private: - void ListenThread() { - VLOG(2) << "Listen thread entry"; - while (true) { - struct sockaddr_in accept_socket_addr; - socklen_t accept_socket_addr_length = sizeof(accept_socket_addr); - int accepted_fd = ::accept( - listen_fd_, reinterpret_cast<struct sockaddr*>(&accept_socket_addr), - &accept_socket_addr_length); - if (accepted_fd < 0) { - if (errno == EINVAL) { - // Shutting down gracefully. - break; - } - // We may be able to recover from some of these cases, but... shrug. - LOG(FATAL) << "Failed to accept client socket: (" << errno << ") " - << ::strerror(errno); - break; - } - auto accept_status = AcceptNewSession(accepted_fd); - if (!accept_status.ok()) { - LOG(ERROR) << "Failed to accept incoming debug client: " - << accept_status; - } - } - VLOG(2) << "Listen thread exit"; - } - - int listen_fd_; - std::thread server_thread_; - - absl::Mutex mutex_; - std::vector<std::unique_ptr<TcpDebugSession>> sessions_ - ABSL_GUARDED_BY(mutex_); - std::vector<std::function<void()>> at_exit_callbacks_ ABSL_GUARDED_BY(mutex_); - - DebugService debug_service_; -}; - -} // namespace - -// static -StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) { - ASSIGN_OR_RETURN(auto debug_server, TcpDebugServer::Listen(listen_port)); - LOG(INFO) << "Debug server listening on localhost:" << listen_port; - return debug_server; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_service.cc b/iree/rt/debug/debug_service.cc deleted file mode 100644 index 9c1f7ac..0000000 --- a/iree/rt/debug/debug_service.cc +++ /dev/null
@@ -1,850 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/debug/debug_service.h" - -#include <algorithm> -#include <memory> - -#include "absl/strings/str_join.h" -#include "absl/synchronization/mutex.h" -#include "flatbuffers/flatbuffers.h" -#include "flatbuffers/reflection.h" -#include "iree/base/flatbuffer_util.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/rt/instance.h" -#include "iree/schemas/debug_service_generated.h" -#include "iree/schemas/reflection_data.h" - -namespace iree { -namespace rt { -namespace debug { -namespace { - -using ::flatbuffers::FlatBufferBuilder; -using ::flatbuffers::Offset; -using ::iree::hal::BufferView; - -int32_t NextUniqueBreakpointId() { - static std::atomic<int32_t> next_id = 0; - return ++next_id; -} - -// Gets an embedded flatbuffers reflection schema. -const ::reflection::Schema& GetSchema(const char* schema_name) { - for (const auto* file_toc = schemas::reflection_data_create(); - file_toc != nullptr; ++file_toc) { - if (std::strcmp(file_toc->name, schema_name) == 0) { - return *::reflection::GetSchema(file_toc->data); - } - } - LOG(FATAL) << "FlatBuffer schema '" << schema_name - << "' not found in binary; ensure it is in :reflection_data"; -} - -// Recursively copies a flatbuffer table, returning the root offset in |fbb|. -template <typename T> -StatusOr<Offset<T>> DeepCopyTable(const char* schema_name, const T& table_def, - FlatBufferBuilder* fbb) { - const auto* root_table = - reinterpret_cast<const ::flatbuffers::Table*>(std::addressof(table_def)); - const auto& schema = GetSchema(schema_name); - return {::flatbuffers::CopyTable(*fbb, schema, *schema.root_table(), - *root_table, - /*use_string_pooling=*/false) - .o}; -} - -// Serializes a buffer_view value, optionally including the entire buffer -// contents. -StatusOr<Offset<rpc::BufferViewDef>> SerializeBufferView( - const BufferView& buffer_view, bool include_buffer_contents, - FlatBufferBuilder* fbb) { - auto shape_offs = fbb->CreateVector(buffer_view.shape.subspan().data(), - buffer_view.shape.subspan().size()); - rpc::BufferViewDefBuilder value(*fbb); - value.add_is_valid(buffer_view.buffer != nullptr); - value.add_shape(shape_offs); - value.add_element_size(buffer_view.element_size); - if (include_buffer_contents) { - // TODO(benvanik): add buffer data. - } - return value.Finish(); -} - -// Serializes a stack frame. -StatusOr<Offset<rpc::StackFrameDef>> SerializeStackFrame( - const StackFrame& stack_frame, FlatBufferBuilder* fbb) { - ASSIGN_OR_RETURN(int function_ordinal, - stack_frame.module().function_table().LookupFunctionOrdinal( - stack_frame.function())); - auto module_name_offs = fbb->CreateString(stack_frame.module().name().data(), - stack_frame.module().name().size()); - std::vector<Offset<rpc::BufferViewDef>> local_offs_list; - for (const auto& local : stack_frame.locals()) { - ASSIGN_OR_RETURN( - auto local_offs, - SerializeBufferView(local, /*include_buffer_contents=*/false, fbb)); - local_offs_list.push_back(local_offs); - } - auto locals_offs = fbb->CreateVector(local_offs_list); - rpc::StackFrameDefBuilder sfb(*fbb); - sfb.add_module_name(module_name_offs); - sfb.add_function_ordinal(function_ordinal); - sfb.add_offset(stack_frame.offset()); - sfb.add_locals(locals_offs); - return sfb.Finish(); -} - -// Resolves a local from a invocation:frame:local_index to a BufferView. -StatusOr<BufferView*> ResolveInvocationLocal(Invocation* invocation, - int frame_index, int local_index) { - auto frames = invocation->mutable_stack()->mutable_frames(); - if (frame_index < 0 || frame_index > frames.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Frame index " << frame_index << " out of bounds (" - << frames.size() << ")"; - } - auto locals = frames[frame_index].mutable_locals(); - if (local_index < 0 || local_index > locals.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Local index " << local_index << " out of bounds (" - << locals.size() << ")"; - } - return &locals[local_index]; -} - -// Suspends a set of invocations and blocks until all have been suspended (or -// one or more fails to suspend). This works only when the caller is *not* one -// of the threads executing a invocation in |invocations| (this normally -// shouldn't happen, but may if we support eval()-like semantics). -Status SuspendInvocationsAndWait(absl::Span<Invocation*> invocations) { - absl::Mutex suspend_mutex; - Status one_suspend_status = OkStatus(); - std::list<int> pending_suspend_ids; - for (auto* invocation : invocations) { - pending_suspend_ids.push_back(invocation->id()); - } - for (auto* invocation : invocations) { - auto suspend_callback = [&, invocation](Status suspend_status) { - absl::MutexLock lock(&suspend_mutex); - auto it = std::find(pending_suspend_ids.begin(), - pending_suspend_ids.end(), invocation->id()); - CHECK(it != pending_suspend_ids.end()); - pending_suspend_ids.erase(it); - if (!suspend_status.ok()) { - one_suspend_status = std::move(suspend_status); - } - }; - RETURN_IF_ERROR(invocation->Suspend(suspend_callback)); - } - suspend_mutex.LockWhen(absl::Condition( - +[](std::list<int>* pending_suspend_ids) { - return pending_suspend_ids->empty(); - }, - &pending_suspend_ids)); - suspend_mutex.Unlock(); - return one_suspend_status; -} - -} // namespace - -Status DebugService::SuspendAllInvocations() { - VLOG(2) << "SuspendAllInvocations"; - for (auto* invocation : invocations_) { - RETURN_IF_ERROR(invocation->Suspend()); - } - return OkStatus(); -} - -Status DebugService::ResumeAllInvocations() { - VLOG(2) << "ResumeAllInvocations"; - for (auto* invocation : invocations_) { - RETURN_IF_ERROR(invocation->Resume()); - } - return OkStatus(); -} - -Status DebugService::RegisterContext(Context* context) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "RegisterContext(" << context->id() << ")"; - RETURN_IF_ERROR(SuspendAllInvocations()); - RETURN_IF_ERROR(UnreadyAllSessions()); - contexts_.push_back(context); - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnContextRegistered(context)); - } - RETURN_IF_ERROR(ResumeAllInvocations()); - return OkStatus(); -} - -Status DebugService::UnregisterContext(Context* context) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "UnregisterContext(" << context->id() << ")"; - auto it = std::find(contexts_.begin(), contexts_.end(), context); - if (it == contexts_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Context not registered"; - } - RETURN_IF_ERROR(SuspendAllInvocations()); - RETURN_IF_ERROR(UnreadyAllSessions()); - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnContextUnregistered(context)); - } - contexts_.erase(it); - RETURN_IF_ERROR(ResumeAllInvocations()); - return OkStatus(); -} - -StatusOr<Context*> DebugService::GetContext(int context_id) const { - for (auto* context : contexts_) { - if (context->id() == context_id) { - return context; - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "Context with ID " << context_id - << " not registered with the debug service"; -} - -Status DebugService::RegisterContextModule(Context* context, Module* module) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "RegisterContextModule(" << context->id() << ", " << module->name() - << ")"; - RETURN_IF_ERROR(SuspendAllInvocations()); - RETURN_IF_ERROR(UnreadyAllSessions()); - RETURN_IF_ERROR(RegisterModuleBreakpoints(context, module)); - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnModuleLoaded(context, module)); - } - RETURN_IF_ERROR(ResumeAllInvocations()); - return OkStatus(); -} - -StatusOr<Module*> DebugService::GetModule(int context_id, - absl::string_view module_name) const { - ASSIGN_OR_RETURN(auto* context, GetContext(context_id)); - for (const auto& module : context->modules()) { - if (module->name() == module_name) { - return module.get(); - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "Module '" << module_name << "' not found on context " - << context_id; -} - -Status DebugService::RegisterInvocation(Invocation* invocation) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "RegisterInvocation(" << invocation->id() << ")"; - RETURN_IF_ERROR(SuspendAllInvocations()); - RETURN_IF_ERROR(UnreadyAllSessions()); - invocations_.push_back(invocation); - if (sessions_unready_) { - // Suspend immediately as a debugger is not yet read. - RETURN_IF_ERROR(invocation->Suspend()); - } - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnInvocationRegistered(invocation)); - } - RETURN_IF_ERROR(ResumeAllInvocations()); - return OkStatus(); -} - -Status DebugService::UnregisterInvocation(Invocation* invocation) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "UnregisterInvocation(" << invocation->id() << ")"; - auto it = std::find(invocations_.begin(), invocations_.end(), invocation); - if (it == invocations_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Invocation state not registered"; - } - RETURN_IF_ERROR(SuspendAllInvocations()); - RETURN_IF_ERROR(UnreadyAllSessions()); - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnInvocationUnregistered(invocation)); - } - invocations_.erase(it); - RETURN_IF_ERROR(ResumeAllInvocations()); - return OkStatus(); -} - -StatusOr<Invocation*> DebugService::GetInvocation(int invocation_id) const { - for (auto* invocation : invocations_) { - if (invocation->id() == invocation_id) { - return invocation; - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "Invocation state with ID " << invocation_id - << " not registered with the debug service"; -} - -StatusOr<Offset<rpc::InvocationDef>> DebugService::SerializeInvocation( - const Invocation& invocation, FlatBufferBuilder* fbb) { - std::vector<Offset<rpc::StackFrameDef>> frame_offs_list; - for (const auto& frame : invocation.stack().frames()) { - ASSIGN_OR_RETURN(auto frame_offs, SerializeStackFrame(frame, fbb)); - frame_offs_list.push_back(frame_offs); - } - auto frames_offs = fbb->CreateVector(frame_offs_list); - rpc::InvocationDefBuilder fsb(*fbb); - fsb.add_invocation_id(invocation.id()); - fsb.add_frames(frames_offs); - return fsb.Finish(); -} - -Status DebugService::RegisterDebugSession(DebugSession* session) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "RegisterDebugSession(" << session->id() << ")"; - sessions_.push_back(session); - if (session->is_ready()) { - ++sessions_ready_; - } else { - // Immediately suspend all invocations until the session readies up (or - // disconnects). - ++sessions_unready_; - RETURN_IF_ERROR(SuspendAllInvocations()); - } - return OkStatus(); -} - -Status DebugService::UnregisterDebugSession(DebugSession* session) { - absl::MutexLock lock(&mutex_); - VLOG(2) << "UnregisterDebugSession(" << session->id() << ")"; - auto it = std::find(sessions_.begin(), sessions_.end(), session); - if (it == sessions_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Session not registered"; - } - sessions_.erase(it); - if (session->is_ready()) { - --sessions_ready_; - } else { - // If the session never readied up then we still have all invocations - // suspended waiting for it. We should resume so that we don't block - // forever. - --sessions_unready_; - RETURN_IF_ERROR(ResumeAllInvocations()); - } - return OkStatus(); -} - -Status DebugService::WaitUntilAllSessionsReady() { - VLOG(1) << "Waiting until all sessions are ready..."; - struct CondState { - DebugService* service; - bool had_sessions; - bool consider_aborted; - } cond_state; - { - absl::MutexLock lock(&mutex_); - cond_state.service = this; - cond_state.had_sessions = !sessions_.empty(); - cond_state.consider_aborted = false; - } - mutex_.LockWhen(absl::Condition( - +[](CondState* cond_state) { - cond_state->service->mutex_.AssertHeld(); - if (cond_state->service->sessions_ready_ > 0) { - // One or more sessions are ready. - return true; - } - if (cond_state->service->sessions_unready_ > 0) { - // One or more sessions are connected but not yet ready. - cond_state->had_sessions = true; - return false; - } - if (cond_state->had_sessions && - cond_state->service->sessions_.empty()) { - // We had sessions but now we don't, consider this an error and bail. - // This can happen when a session connects but never readies up. - cond_state->consider_aborted = true; - return true; - } - return false; - }, - &cond_state)); - mutex_.Unlock(); - if (cond_state.consider_aborted) { - return AbortedErrorBuilder(IREE_LOC) - << "At least one session connected but never readied up"; - } - VLOG(1) << "Sessions ready, resuming"; - return OkStatus(); -} - -StatusOr<Offset<rpc::MakeReadyResponse>> DebugService::MakeReady( - const rpc::MakeReadyRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: MakeReady()"; - // TODO(benvanik): support more than one session. - CHECK_LE(sessions_.size(), 1) << "Only one session is currently supported"; - if (!sessions_.empty()) { - RETURN_IF_ERROR(sessions_[0]->OnReady()); - } - sessions_ready_ = 0; - sessions_unready_ = 0; - for (auto* session : sessions_) { - sessions_ready_ += session->is_ready() ? 1 : 0; - sessions_unready_ += session->is_ready() ? 0 : 1; - } - rpc::MakeReadyResponseBuilder response(*fbb); - return response.Finish(); -} - -Status DebugService::UnreadyAllSessions() { - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnUnready()); - } - sessions_ready_ = 0; - sessions_unready_ = sessions_.size(); - return OkStatus(); -} - -StatusOr<Offset<rpc::GetStatusResponse>> DebugService::GetStatus( - const rpc::GetStatusRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: GetStatus()"; - rpc::GetStatusResponseBuilder response(*fbb); - response.add_protocol(0); - return response.Finish(); -} - -StatusOr<Offset<rpc::ListContextsResponse>> DebugService::ListContexts( - const rpc::ListContextsRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: ListContexts()"; - std::vector<Offset<rpc::ContextDef>> context_offs; - for (auto* context : contexts_) { - std::vector<Offset<rpc::NativeFunctionDef>> native_function_offs_list; - for (const auto& pair : context->native_functions()) { - auto name_offs = fbb->CreateString(pair.first); - rpc::NativeFunctionDefBuilder native_function(*fbb); - native_function.add_name(name_offs); - native_function_offs_list.push_back(native_function.Finish()); - } - auto native_functions_offs = fbb->CreateVector(native_function_offs_list); - - std::vector<std::string> module_names; - for (const auto& module : context->modules()) { - module_names.push_back(std::string(module->name())); - } - auto module_names_offs = fbb->CreateVectorOfStrings(module_names); - - rpc::ContextDefBuilder context_def(*fbb); - context_def.add_context_id(context->id()); - context_def.add_native_functions(native_functions_offs); - context_def.add_module_names(module_names_offs); - context_offs.push_back(context_def.Finish()); - } - - auto contexts_offs = fbb->CreateVector(context_offs); - rpc::ListContextsResponseBuilder response(*fbb); - response.add_contexts(contexts_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::GetModuleResponse>> DebugService::GetModule( - const rpc::GetModuleRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: GetModule(" << request.context_id() << ", " - << WrapString(request.module_name()) << ")"; - ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(), - WrapString(request.module_name()))); - // TODO(benvanik): find a way to do this without possibly duping all memory. - // I suspect that when we make constants poolable then there's only one - // place to kill and there may be magic we could use to do that during a - // reflection pass. - ModuleDefT module_t; - module->def().UnPackTo(&module_t); - for (auto& function : module_t.function_table->functions) { - function->bytecode->contents.clear(); - } - auto trimmed_module_offs = ModuleDef::Pack(*fbb, &module_t); - rpc::GetModuleResponseBuilder response(*fbb); - response.add_module_(trimmed_module_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::GetFunctionResponse>> DebugService::GetFunction( - const rpc::GetFunctionRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: GetFunction(" << WrapString(request.module_name()) << ", " - << request.function_ordinal() << ")"; - ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(), - WrapString(request.module_name()))); - ASSIGN_OR_RETURN(auto& function, module->function_table().LookupFunction( - request.function_ordinal())); - Offset<BytecodeDef> bytecode_offs; - if (function.def().bytecode()) { - ASSIGN_OR_RETURN( - bytecode_offs, - DeepCopyTable("bytecode_def.bfbs", *function.def().bytecode(), fbb)); - } - rpc::GetFunctionResponseBuilder response(*fbb); - response.add_bytecode(bytecode_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::ResolveFunctionResponse>> DebugService::ResolveFunction( - const rpc::ResolveFunctionRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: ResolveFunction(" << WrapString(request.module_name()) - << ", " << WrapString(request.function_name()) << ")"; - std::vector<int32_t> context_ids; - auto context_ids_offs = fbb->CreateVector(context_ids); - int function_ordinal = -1; - for (auto* context : contexts_) { - for (const auto& module : context->modules()) { - if (module->name() == WrapString(request.module_name())) { - ASSIGN_OR_RETURN(function_ordinal, - module->function_table().LookupFunctionOrdinalByName( - WrapString(request.function_name()))); - context_ids.push_back(context->id()); - break; - } - } - } - rpc::ResolveFunctionResponseBuilder response(*fbb); - response.add_context_ids(context_ids_offs); - response.add_function_ordinal(function_ordinal); - return response.Finish(); -} - -StatusOr<Offset<rpc::ListInvocationsResponse>> DebugService::ListInvocations( - const rpc::ListInvocationsRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: ListInvocations()"; - std::vector<Offset<rpc::InvocationDef>> invocation_offsets; - for (auto* invocation : invocations_) { - ASSIGN_OR_RETURN(auto invocation_offs, - SerializeInvocation(*invocation, fbb)); - invocation_offsets.push_back(invocation_offs); - } - auto invocations_offs = fbb->CreateVector(invocation_offsets); - rpc::ListInvocationsResponseBuilder response(*fbb); - response.add_invocations(invocations_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::SuspendInvocationsResponse>> -DebugService::SuspendInvocations(const rpc::SuspendInvocationsRequest& request, - FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: SuspendInvocations(invocation_ids=[" - << (request.invocation_ids() - ? absl::StrJoin(*request.invocation_ids(), ", ") - : "") - << "])"; - std::vector<Offset<rpc::InvocationDef>> invocation_offsets; - if (request.invocation_ids() && request.invocation_ids()->size() > 0) { - // Suspending a list of invocations. - std::vector<Invocation*> invocations_to_suspend; - for (int invocation_id : *request.invocation_ids()) { - ASSIGN_OR_RETURN(auto* invocation, GetInvocation(invocation_id)); - invocations_to_suspend.push_back(invocation); - } - RETURN_IF_ERROR( - SuspendInvocationsAndWait(absl::MakeSpan(invocations_to_suspend))); - for (auto* invocation : invocations_to_suspend) { - ASSIGN_OR_RETURN(auto invocation_offs, - SerializeInvocation(*invocation, fbb)); - invocation_offsets.push_back(invocation_offs); - } - } else { - // Suspending all invocations. - RETURN_IF_ERROR(SuspendAllInvocations()); - for (auto* invocation : invocations_) { - ASSIGN_OR_RETURN(auto invocation_offs, - SerializeInvocation(*invocation, fbb)); - invocation_offsets.push_back(invocation_offs); - } - } - auto invocations_offs = fbb->CreateVector(invocation_offsets); - rpc::SuspendInvocationsResponseBuilder response(*fbb); - response.add_invocations(invocations_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::ResumeInvocationsResponse>> -DebugService::ResumeInvocations(const rpc::ResumeInvocationsRequest& request, - FlatBufferBuilder* fbb) { - VLOG(1) << "RPC: ResumeInvocations(invocation_ids=[" - << (request.invocation_ids() - ? absl::StrJoin(*request.invocation_ids(), ", ") - : "") - << "])"; - absl::MutexLock lock(&mutex_); - if (request.invocation_ids() && request.invocation_ids()->size() > 0) { - // Resuming a list of invocations. - for (int invocation_id : *request.invocation_ids()) { - ASSIGN_OR_RETURN(auto* invocation, GetInvocation(invocation_id)); - RETURN_IF_ERROR(invocation->Resume()); - } - } else { - // Resuming all invocations. - RETURN_IF_ERROR(ResumeAllInvocations()); - } - rpc::ResumeInvocationsResponseBuilder response(*fbb); - return response.Finish(); -} - -StatusOr<Offset<rpc::StepInvocationResponse>> DebugService::StepInvocation( - const rpc::StepInvocationRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: StepInvocation(" << request.invocation_id() << ")"; - ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id())); - Invocation::StepTarget step_target; - // TODO(benvanik): step settings. - RETURN_IF_ERROR(invocation->Step(step_target)); - rpc::StepInvocationResponseBuilder response(*fbb); - return response.Finish(); -} - -StatusOr<Offset<rpc::GetInvocationLocalResponse>> -DebugService::GetInvocationLocal(const rpc::GetInvocationLocalRequest& request, - FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: GetInvocationLocal(" << request.invocation_id() << ", " - << request.frame_index() << ", " << request.local_index() << ")"; - ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id())); - ASSIGN_OR_RETURN(auto* local, - ResolveInvocationLocal(invocation, request.frame_index(), - request.local_index())); - - ASSIGN_OR_RETURN( - auto value_offs, - SerializeBufferView(*local, /*include_buffer_contents=*/true, fbb)); - rpc::GetInvocationLocalResponseBuilder response(*fbb); - response.add_value(value_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::SetInvocationLocalResponse>> -DebugService::SetInvocationLocal(const rpc::SetInvocationLocalRequest& request, - FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: SetInvocationLocal(" << request.invocation_id() << ", " - << request.frame_index() << ", " << request.local_index() << ")"; - ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id())); - ASSIGN_OR_RETURN(auto* local, - ResolveInvocationLocal(invocation, request.frame_index(), - request.local_index())); - - if (!request.value()) { - local->shape.clear(); - local->element_size = 0; - local->buffer.reset(); - } else { - const auto& value = *request.value(); - local->shape.clear(); - if (value.shape()) { - for (int dim : *value.shape()) { - local->shape.push_back(dim); - } - } - local->element_size = value.element_size(); - // TODO(benvanik): copy buffer data. - } - - ASSIGN_OR_RETURN( - auto value_offs, - SerializeBufferView(*local, /*include_buffer_contents=*/true, fbb)); - rpc::SetInvocationLocalResponseBuilder response(*fbb); - response.add_value(value_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::ListBreakpointsResponse>> DebugService::ListBreakpoints( - const rpc::ListBreakpointsRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: ListBreakpoints()"; - std::vector<Offset<rpc::BreakpointDef>> breakpoint_offs; - for (const auto& breakpoint : breakpoints_) { - breakpoint_offs.push_back(rpc::BreakpointDef::Pack(*fbb, &breakpoint)); - } - auto breakpoints_offs = fbb->CreateVector(breakpoint_offs); - rpc::ListBreakpointsResponseBuilder response(*fbb); - response.add_breakpoints(breakpoints_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::AddBreakpointResponse>> DebugService::AddBreakpoint( - const rpc::AddBreakpointRequest& request, FlatBufferBuilder* fbb) { - if (!request.breakpoint()) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No breakpoint specified"; - } - absl::MutexLock lock(&mutex_); - int breakpoint_id = NextUniqueBreakpointId(); - VLOG(1) << "RPC: AddBreakpoint(" << breakpoint_id << ")"; - - RETURN_IF_ERROR(SuspendAllInvocations()); - - rpc::BreakpointDefT breakpoint; - request.breakpoint()->UnPackTo(&breakpoint); - breakpoint.breakpoint_id = breakpoint_id; - switch (breakpoint.breakpoint_type) { - case rpc::BreakpointType::BYTECODE_FUNCTION: - case rpc::BreakpointType::NATIVE_FUNCTION: - for (auto* context : contexts_) { - auto module_or = context->LookupModule(breakpoint.module_name); - if (!module_or.ok()) continue; - auto* module = module_or.ValueOrDie(); - RETURN_IF_ERROR( - RegisterFunctionBreakpoint(context, module, &breakpoint)); - } - break; - default: - return UnimplementedErrorBuilder(IREE_LOC) << "Unhandled breakpoint type"; - } - breakpoints_.push_back(std::move(breakpoint)); - - RETURN_IF_ERROR(ResumeAllInvocations()); - - auto breakpoint_offs = rpc::BreakpointDef::Pack(*fbb, &breakpoints_.back()); - rpc::AddBreakpointResponseBuilder response(*fbb); - response.add_breakpoint(breakpoint_offs); - return response.Finish(); -} - -StatusOr<Offset<rpc::RemoveBreakpointResponse>> DebugService::RemoveBreakpoint( - const rpc::RemoveBreakpointRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: RemoveBreakpoint(" << request.breakpoint_id() << ")"; - RETURN_IF_ERROR(SuspendAllInvocations()); - - bool found = false; - for (auto it = breakpoints_.begin(); it != breakpoints_.end(); ++it) { - if (it->breakpoint_id == request.breakpoint_id()) { - auto& breakpoint = *it; - found = true; - switch (breakpoint.breakpoint_type) { - case rpc::BreakpointType::BYTECODE_FUNCTION: - case rpc::BreakpointType::NATIVE_FUNCTION: - RETURN_IF_ERROR(UnregisterFunctionBreakpoint(breakpoint)); - break; - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unhandled breakpoint type"; - } - breakpoints_.erase(it); - break; - } - } - - RETURN_IF_ERROR(ResumeAllInvocations()); - if (!found) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Breakpoint ID " << request.breakpoint_id() << " not found"; - } - - rpc::RemoveBreakpointResponseBuilder response(*fbb); - return response.Finish(); -} - -Status DebugService::RegisterModuleBreakpoints(Context* context, - Module* module) { - for (auto& breakpoint : breakpoints_) { - switch (breakpoint.breakpoint_type) { - case rpc::BreakpointType::BYTECODE_FUNCTION: - if (breakpoint.module_name == module->name()) { - RETURN_IF_ERROR( - RegisterFunctionBreakpoint(context, module, &breakpoint)); - } - break; - default: - // Not relevant to modules. - break; - } - } - return OkStatus(); -} - -Status DebugService::RegisterFunctionBreakpoint( - Context* context, Module* module, rpc::BreakpointDefT* breakpoint) { - if (!breakpoint->function_name.empty()) { - ASSIGN_OR_RETURN(breakpoint->function_ordinal, - module->function_table().LookupFunctionOrdinalByName( - breakpoint->function_name)); - } - RETURN_IF_ERROR(module->mutable_function_table()->RegisterBreakpoint( - breakpoint->function_ordinal, breakpoint->bytecode_offset, - std::bind(&DebugService::OnFunctionBreakpointHit, this, - breakpoint->breakpoint_id, std::placeholders::_1))); - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnBreakpointResolved(*breakpoint, context)); - } - return OkStatus(); -} - -Status DebugService::UnregisterFunctionBreakpoint( - const rpc::BreakpointDefT& breakpoint) { - for (auto* context : contexts_) { - auto module_or = context->LookupModule(breakpoint.module_name); - if (!module_or.ok()) continue; - auto* module = module_or.ValueOrDie(); - RETURN_IF_ERROR(module->mutable_function_table()->UnregisterBreakpoint( - breakpoint.function_ordinal, breakpoint.bytecode_offset)); - } - return OkStatus(); -} - -Status DebugService::OnFunctionBreakpointHit(int breakpoint_id, - const Invocation& invocation) { - absl::ReleasableMutexLock lock(&mutex_); - LOG(INFO) << "Breakpoint hit: " << breakpoint_id; - RETURN_IF_ERROR(UnreadyAllSessions()); - for (auto* session : sessions_) { - RETURN_IF_ERROR(session->OnBreakpointHit(breakpoint_id, invocation)); - } - lock.Release(); - - // TODO(benvanik): on-demand attach if desired? - - // Wait until all clients are ready. - auto wait_status = WaitUntilAllSessionsReady(); - if (IsAborted(wait_status)) { - // This means we lost all sessions. Just continue. - VLOG(1) << "No sessions active; ignoring breakpoint and continuing"; - return OkStatus(); - } - return wait_status; -} - -StatusOr<Offset<rpc::StartProfilingResponse>> DebugService::StartProfiling( - const rpc::StartProfilingRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: StartProfiling()"; - // TODO(benvanik): implement profiling. - // ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id())); - // rpc::StartProfilingResponseBuilder response(*fbb); - // return response.Finish(); - return UnimplementedErrorBuilder(IREE_LOC) - << "StartProfiling not yet implemented"; -} - -StatusOr<Offset<rpc::StopProfilingResponse>> DebugService::StopProfiling( - const rpc::StopProfilingRequest& request, FlatBufferBuilder* fbb) { - absl::MutexLock lock(&mutex_); - VLOG(1) << "RPC: StopProfiling()"; - // TODO(benvanik): implement profiling. - // ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id())); - // rpc::StopProfilingResponseBuilder response(*fbb); - // return response.Finish(); - return UnimplementedErrorBuilder(IREE_LOC) - << "StopProfiling not yet implemented"; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_service.h b/iree/rt/debug/debug_service.h deleted file mode 100644 index 47f17ab..0000000 --- a/iree/rt/debug/debug_service.h +++ /dev/null
@@ -1,175 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DEBUG_DEBUG_SERVICE_H_ -#define IREE_RT_DEBUG_DEBUG_SERVICE_H_ - -#include <vector> - -#include "absl/base/thread_annotations.h" -#include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" -#include "flatbuffers/flatbuffers.h" -#include "iree/base/status.h" -#include "iree/rt/context.h" -#include "iree/rt/debug/debug_session.h" -#include "iree/schemas/debug_service_generated.h" - -namespace iree { -namespace rt { -namespace debug { - -// Debugging service used to implement the DebugService RPC methods in a -// transport-independent way. Specific DebugServer implementations can compose -// with a DebugService to avoid needing to maintain state themselves. Multiple -// DebugServer instances could share the same DebugService instance to ensure -// all clients - regardless of transport - share the same state. -// -// Thread-safe. -class DebugService { - public: - // Registers a context with the debug service. - // Ownership remains with the caller and UnregisterContext must be called - // prior to the context being destroyed. - Status RegisterContext(Context* context); - Status UnregisterContext(Context* context); - - // Registers a new module linked into an existing Context. - Status RegisterContextModule(Context* context, Module* module); - - // Registers a invocation state with the debug service. - // Ownership remains with the caller and UnregisterInvocation must be called - // prior to the invocation state being destroyed. - Status RegisterInvocation(Invocation* invocation); - Status UnregisterInvocation(Invocation* invocation); - - // Registers a debug session with the service. - Status RegisterDebugSession(DebugSession* session); - Status UnregisterDebugSession(DebugSession* session); - - // Blocks the caller until all sessions are ready. - // Returns AbortedError if a session connects/is already connected but - // disconnects during the wait. - Status WaitUntilAllSessionsReady(); - - StatusOr<::flatbuffers::Offset<rpc::MakeReadyResponse>> MakeReady( - const rpc::MakeReadyRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - StatusOr<::flatbuffers::Offset<rpc::GetStatusResponse>> GetStatus( - const rpc::GetStatusRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - StatusOr<::flatbuffers::Offset<rpc::ListContextsResponse>> ListContexts( - const rpc::ListContextsRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - StatusOr<::flatbuffers::Offset<rpc::GetModuleResponse>> GetModule( - const rpc::GetModuleRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::GetFunctionResponse>> GetFunction( - const rpc::GetFunctionRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::ResolveFunctionResponse>> ResolveFunction( - const rpc::ResolveFunctionRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - StatusOr<::flatbuffers::Offset<rpc::ListInvocationsResponse>> ListInvocations( - const rpc::ListInvocationsRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::SuspendInvocationsResponse>> - SuspendInvocations(const rpc::SuspendInvocationsRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::ResumeInvocationsResponse>> - ResumeInvocations(const rpc::ResumeInvocationsRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::StepInvocationResponse>> StepInvocation( - const rpc::StepInvocationRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::GetInvocationLocalResponse>> - GetInvocationLocal(const rpc::GetInvocationLocalRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::SetInvocationLocalResponse>> - SetInvocationLocal(const rpc::SetInvocationLocalRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - StatusOr<::flatbuffers::Offset<rpc::ListBreakpointsResponse>> ListBreakpoints( - const rpc::ListBreakpointsRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::AddBreakpointResponse>> AddBreakpoint( - const rpc::AddBreakpointRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::RemoveBreakpointResponse>> - RemoveBreakpoint(const rpc::RemoveBreakpointRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - StatusOr<::flatbuffers::Offset<rpc::StartProfilingResponse>> StartProfiling( - const rpc::StartProfilingRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - StatusOr<::flatbuffers::Offset<rpc::StopProfilingResponse>> StopProfiling( - const rpc::StopProfilingRequest& request, - ::flatbuffers::FlatBufferBuilder* fbb); - - // Serializes an invocation and its stack frames. - StatusOr<::flatbuffers::Offset<rpc::InvocationDef>> SerializeInvocation( - const Invocation& invocation, ::flatbuffers::FlatBufferBuilder* fbb); - - private: - StatusOr<Context*> GetContext(int context_id) const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - StatusOr<Module*> GetModule(int context_id, - absl::string_view module_name) const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - StatusOr<Invocation*> GetInvocation(int invocation_id) const - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Suspends all invocations on all contexts. Returns only once all invocations - // have been suspended successfully. Fails if any invocation fails to suspend. - Status SuspendAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Resumes all invocations on all contexts (the inverse of - // SuspendAllInvocations). Returns immediately. - Status ResumeAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Marks all sessions as unready. - Status UnreadyAllSessions() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Attempts to re-register all breakpoints for a module. - Status RegisterModuleBreakpoints(Context* context, Module* module) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - Status RegisterFunctionBreakpoint(Context* context, Module* module, - rpc::BreakpointDefT* breakpoint) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - Status UnregisterFunctionBreakpoint(const rpc::BreakpointDefT& breakpoint) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Signals that the given breakpoint was hit by the specified invocation. - // Called without the debug lock held. - Status OnFunctionBreakpointHit(int breakpoint_id, - const Invocation& invocation); - - absl::Mutex mutex_; - std::vector<Context*> contexts_ ABSL_GUARDED_BY(mutex_); - std::vector<Invocation*> invocations_ ABSL_GUARDED_BY(mutex_); - std::vector<DebugSession*> sessions_ ABSL_GUARDED_BY(mutex_); - int sessions_unready_ ABSL_GUARDED_BY(mutex_) = 0; - int sessions_ready_ ABSL_GUARDED_BY(mutex_) = 0; - - std::vector<rpc::BreakpointDefT> breakpoints_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_DEBUG_SERVICE_H_
diff --git a/iree/rt/debug/debug_session.cc b/iree/rt/debug/debug_session.cc deleted file mode 100644 index 83a28da..0000000 --- a/iree/rt/debug/debug_session.cc +++ /dev/null
@@ -1,49 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/debug/debug_session.h" - -#include "iree/base/source_location.h" -#include "iree/base/status.h" - -namespace iree { -namespace rt { -namespace debug { - -bool DebugSession::is_ready() const { - absl::MutexLock lock(&mutex_); - return ready_ == 0; -} - -Status DebugSession::OnReady() { - absl::MutexLock lock(&mutex_); - if (ready_ > 0) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Session has already readied up"; - } - ++ready_; - VLOG(2) << "Session " << id() << ": ++ready = " << ready_; - return OkStatus(); -} - -Status DebugSession::OnUnready() { - absl::MutexLock lock(&mutex_); - --ready_; - VLOG(2) << "Session " << id() << ": --ready = " << ready_; - return OkStatus(); -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/rt/debug/debug_session.h b/iree/rt/debug/debug_session.h deleted file mode 100644 index f6cd87c..0000000 --- a/iree/rt/debug/debug_session.h +++ /dev/null
@@ -1,93 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DEBUG_DEBUG_SESSION_H_ -#define IREE_RT_DEBUG_DEBUG_SESSION_H_ - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/status.h" -#include "iree/rt/context.h" -#include "iree/rt/invocation.h" -#include "iree/rt/module.h" -#include "iree/schemas/debug_service_generated.h" - -namespace iree { -namespace rt { -namespace debug { - -// An active debugging session maintained by the DebugService. -// Each connected client gets a session and transport-specific implementations -// use the event methods to receive signals from the service. -// -// All methods are called only while the debug lock is held and may be called -// from any thread. -class DebugSession { - public: - virtual ~DebugSession() = default; - - // Session ID used in all RPCs related to this session. - // This can be used for attributing RPCs to the originating session when - // multiple sessions may be active at a time/over the same transport. - int id() const { return session_id_; } - - // Returns true if the session has issued a MakeReady request and is ok if - // execution resumes. - bool is_ready() const; - - // Signals that the session has readied up and is now active. - // Called with the global debug lock held. - virtual Status OnReady(); - - // Signals that the session has gone unready (from an event/etc) and the - // service is now awaiting it to ready up. - // Called with the global debug lock held. - virtual Status OnUnready(); - - // Signals that a context has been registered. - // Called with the global debug lock held. - virtual Status OnContextRegistered(Context* context) = 0; - virtual Status OnContextUnregistered(Context* context) = 0; - - // Signals that a module has been loaded in a context. - // Called with the global debug lock held. - virtual Status OnModuleLoaded(Context* context, Module* module) = 0; - - // Signals that a invocation has been registered. - // Called with the global debug lock held. - virtual Status OnInvocationRegistered(Invocation* invocation) = 0; - virtual Status OnInvocationUnregistered(Invocation* invocation) = 0; - - // Signals that a breakpoint has been resolved to a particular function in a - // context. - // Called with the global debug lock held. - virtual Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint, - Context* context) = 0; - - // Signals that the given breakpoint has been hit during execution. - // Called with the global debug lock held. - virtual Status OnBreakpointHit(int breakpoint_id, - const Invocation& invocation) = 0; - - private: - mutable absl::Mutex mutex_; - int session_id_ = 0; - int ready_ ABSL_GUARDED_BY(mutex_) = -1; -}; - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_DEBUG_SESSION_H_
diff --git a/iree/rt/debug/debug_tcp_util.h b/iree/rt/debug/debug_tcp_util.h deleted file mode 100644 index d9c33bf..0000000 --- a/iree/rt/debug/debug_tcp_util.h +++ /dev/null
@@ -1,217 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Utilities for working with TCP sockets. -// These are (mostly) portable to systems implementing BSD sockets. - -#ifndef IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_ -#define IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_ - -#include <fcntl.h> -#include <netinet/in.h> -#include <netinet/tcp.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include <cstddef> - -#include "flatbuffers/base.h" -#include "flatbuffers/flatbuffers.h" -#include "iree/base/status.h" -#include "iree/schemas/debug_service_generated.h" - -namespace iree { -namespace rt { -namespace debug { -namespace tcp { - -// Toggles address reuse on a socket. Call prior to binding. -// This is useful if a socket is sitting in close_wait from a previous process -// while a new one is trying to bind to it. -inline Status ToggleSocketAddressReuse(int fd, bool is_enabled) { - int toggle = is_enabled ? 1 : 0; - ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &toggle, sizeof(toggle)); - return OkStatus(); -} - -// Toggles the linger option on a socket. -// Enabling linger will ensure all data on the socket is sent (if it can be -// sent within N sec) prior to closing. Disabling linger will cause the socket -// to close gracefully. -inline Status ToggleSocketLinger(int fd, bool is_enabled) { - struct linger linger; - linger.l_onoff = is_enabled ? 1 : 0; - linger.l_linger = 1; - ::setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); - return OkStatus(); -} - -// Toggles Nagel's algorithm on a socket. -// Enabled by default, sockets have ~250ms delay for small packets. Disabling -// the algorithm will make socket flushes actually send data. -inline Status ToggleSocketNagelsAlgorithm(int fd, bool is_enabled) { - int toggle = is_enabled ? 1 : 0; - ::setsockopt(fd, SOL_TCP, TCP_NODELAY, &toggle, sizeof(toggle)); - return OkStatus(); -} - -// Toggles TCP keepalive on a socket. -// Assumes that the remote side is on the local machine/network and that we can -// spam it with packets. -// -// NOTE: we may want to adjust this when real debuggers are attached (to prevent -// dropping our own connections). Need to figure out how to reliably detect -// debug suspends vs. actual death. -inline Status ToggleSocketLocalKeepalive(int fd, bool is_enabled) { - // Toggle keepalive. - int keepalive_enable = is_enabled ? 1 : 0; - ::setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive_enable, - sizeof(keepalive_enable)); - // Begin sending keepalive probes after N sec. - int keepalive_idle_delay = 3; - ::setsockopt(fd, SOL_TCP, TCP_KEEPIDLE, &keepalive_idle_delay, - sizeof(keepalive_idle_delay)); - // Try one probe and bail (faster detection). - int keepalive_retry_count = 1; - ::setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &keepalive_retry_count, - sizeof(keepalive_retry_count)); - // Send keepalives every N sec. - int keepalive_interval = 1; - ::setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &keepalive_interval, - sizeof(keepalive_interval)); - return OkStatus(); -} - -// Toggles the blocking state of a socket. -// If a socket has been set to non-blocking methods like read and write will -// return EWOULDBLOCK if they would have blocked on the specific operation. -inline Status ToggleSocketBlocking(int fd, bool is_blocking) { - if (is_blocking) { - ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) & ~O_NONBLOCK); - } else { - ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) | O_NONBLOCK); - } - return OkStatus(); -} - -// RAII wrapper for messages containing flatbuffer roots of type T. -template <typename T> -struct MessageBuffer { - public: - explicit MessageBuffer(std::vector<uint8_t> buffer) - : buffer_(std::move(buffer)) {} - MessageBuffer(const MessageBuffer&) = delete; - MessageBuffer& operator=(const MessageBuffer&) = delete; - MessageBuffer(MessageBuffer&&) = default; - MessageBuffer& operator=(MessageBuffer&&) = default; - - const T& GetRoot() const { - return *::flatbuffers::GetRoot<T>(buffer_.data()); - } - - private: - std::vector<uint8_t> buffer_; -}; - -// Reads a size prefix value from the given fd. -// If |poll_only| is true then the size prefix is not consumed from the stream -// and the call will return 0 if there is no size prefix available. -// Returns CancelledError if a (probably) graceful close is detected. -inline StatusOr<size_t> ReadSizePrefix(int fd, bool poll_only) { - ::flatbuffers::uoffset_t size_prefix = 0; - int read_bytes = ::recv(fd, &size_prefix, sizeof(size_prefix), - poll_only ? (MSG_PEEK | MSG_DONTWAIT) : 0); - if (read_bytes == 0) { - // Remote side disconnected. - return CancelledErrorBuilder(IREE_LOC) << "Graceful remote close"; - } else if (read_bytes < 0) { - if (errno == ECONNRESET) { - return CancelledErrorBuilder(IREE_LOC) << "Ungraceful remote close"; - } - return DataLossErrorBuilder(IREE_LOC) - << "Failed to read size prefix from socket: (" << errno << ") " - << ::strerror(errno); - } else if (read_bytes != sizeof(size_prefix)) { - if (poll_only) { - // No data available. - return 0; - } else { - return DataLossErrorBuilder(IREE_LOC) - << "Failed to read full size prefix (got " << read_bytes << "b of " - << sizeof(size_prefix) << "b expected)"; - } - } - return size_prefix; -} - -// Returns true if ReadBuffer will (likely) not block when called. -// Returns CancelledError if a (probably) graceful close is detected. -inline StatusOr<bool> CanReadBuffer(int fd) { - ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/true)); - return size_prefix != 0; -} - -// Reads a size-prefixed message from the given fd. -// This will block until the entire message contents are available. -// Returns a buffer reference that will deallocate the buffer automatically or -// CancelledError if a (probably) graceful close is detected. -template <typename T> -StatusOr<MessageBuffer<T>> ReadBuffer(int fd) { - // Read the size prefix (written as a uoffset_t by the Write* methods). - ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/false)); - - // Allocate the buffer for the entire message. - // We'll use the BufferRef to free() it when it's no longer required. - std::vector<uint8_t> buffer(size_prefix); - - // Read the entire message contents. - int full_read_bytes = ::recv(fd, buffer.data(), buffer.size(), 0); - if (full_read_bytes < 0) { - return DataLossErrorBuilder(IREE_LOC) - << "Failed to read full message contents from socket: (" << errno - << ") " << ::strerror(errno); - } else if (full_read_bytes != buffer.size()) { - return DataLossErrorBuilder(IREE_LOC) - << "Failed to read full message contents (got " << full_read_bytes - << "b of " << buffer.size() << "b expected)"; - } - - // Verify the contents. Not strictly required (as we won't ever ship this to - // prod), but useful in ensuring our socket code isn't corrupting things. - ::flatbuffers::Verifier verifier(buffer.data(), buffer.size()); - if (!verifier.VerifyBuffer<T>()) { - return DataLossErrorBuilder(IREE_LOC) - << "Verification of input buffer of type " << typeid(T).name() - << " (" << buffer.size() << "b) failed"; - } - - // Wrap the buffer to get some RAII goodness. - return MessageBuffer<T>(std::move(buffer)); -} - -// Writes a buffer to the given fd. -inline Status WriteBuffer(int fd, ::flatbuffers::DetachedBuffer buffer) { - if (::send(fd, buffer.data(), buffer.size(), 0) < 0) { - return UnavailableErrorBuilder(IREE_LOC) - << "Write failed: (" << errno << ") " << ::strerror(errno); - } - return OkStatus(); -} - -} // namespace tcp -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
diff --git a/iree/rt/disassembler.h b/iree/rt/disassembler.h deleted file mode 100644 index 52d91ee..0000000 --- a/iree/rt/disassembler.h +++ /dev/null
@@ -1,64 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_DISASSEMBLER_H_ -#define IREE_RT_DISASSEMBLER_H_ - -#include <cstdint> -#include <ostream> -#include <vector> - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "iree/base/status.h" -#include "iree/rt/function.h" -#include "iree/rt/source_location.h" - -namespace iree { -namespace rt { - -// A single disassembled instruction. -struct Instruction { - // Offset of the instruction within the function. - // The meaning of this is backend-dependent. - SourceOffset offset; - - // The first line of |long_text|. - absl::string_view short_text; - - // Human-readable text of the instruction. May contain multiple lines. - std::string long_text; -}; - -// Disassembles functions into instructions. -// -// Thread-safe. -class Disassembler { - public: - virtual ~Disassembler() = default; - - // Disassembles one or more instructions within the given function based on - // source offsets. - virtual StatusOr<std::vector<Instruction>> DisassembleInstructions( - const Function& function, SourceOffset offset, - int32_t instruction_count = INT32_MAX) const = 0; - - protected: - Disassembler() = default; -}; - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_DISASSEMBLER_H_
diff --git a/iree/rt/function.cc b/iree/rt/function.cc deleted file mode 100644 index c32fbce..0000000 --- a/iree/rt/function.cc +++ /dev/null
@@ -1,33 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/function.h" - -#include "iree/rt/module.h" - -namespace iree { -namespace rt { - -absl::string_view Function::name() const { - auto result_or = module_->GetFunctionName(linkage_, ordinal_); - return result_or.ok() ? result_or.ValueOrDie() : absl::string_view(); -} - -const FunctionSignature Function::signature() const { - auto result_or = module_->GetFunctionSignature(linkage_, ordinal_); - return result_or.ok() ? result_or.ValueOrDie() : FunctionSignature(); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/function.h b/iree/rt/function.h deleted file mode 100644 index 9acc0e6..0000000 --- a/iree/rt/function.h +++ /dev/null
@@ -1,84 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_FUNCTION_H_ -#define IREE_RT_FUNCTION_H_ - -#include <cstdint> -#include <ostream> - -#include "absl/strings/string_view.h" -#include "iree/rt/function_signature.h" - -namespace iree { -namespace rt { - -class Module; - -// Reference to a function within a module. -// Functions are either visible or hidden from the module interface and may be -// of one Linkage type. Imports and exports are always visible (as they are -// required for dynamic linking) however functions with internal linkage may be -// hidden in optimized builds to reduce the amount of reflection metadata -// required. -class Function final { - public: - enum class Linkage { - // Function is internal to the module and may not be reflectable. - kInternal = 0, - // Function is an import from another module. - kImport = 1, - // Function is an export from the module. - kExport = 2, - }; - - Function() = default; - Function(const Module* module, Linkage linkage, int32_t ordinal) - : module_(module), linkage_(linkage), ordinal_(ordinal) {} - - // Module the function is contained within. - const Module* module() const { return module_; } - - // Linkage of the function. Note that Linkage::kInternal functions may be - // missing reflection information. - Linkage linkage() const { return linkage_; } - - // Ordinal within the module in the linkage scope. - int32_t ordinal() const { return ordinal_; } - - // Returns the original name of the function. - // Internal functions may return empty if debugging info has been stripped. - absl::string_view name() const; - - // Returns the signature of the function. - // Always present for imports and exports but may be empty for internal - // functions if debugging info has been stripped. - const FunctionSignature signature() const; - - private: - const Module* module_ = nullptr; - Linkage linkage_ = Linkage::kInternal; - int32_t ordinal_ = -1; -}; - -inline std::ostream& operator<<(std::ostream& stream, - const Function& function) { - stream << '@' << function.name() << '#' << function.ordinal(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_FUNCTION_H_
diff --git a/iree/rt/instance.cc b/iree/rt/instance.cc deleted file mode 100644 index d793bc7..0000000 --- a/iree/rt/instance.cc +++ /dev/null
@@ -1,51 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/instance.h" - -#include "iree/base/tracing.h" -#include "iree/rt/debug/debug_server.h" - -namespace iree { -namespace rt { - -Instance::Instance(std::unique_ptr<debug::DebugServer> debug_server) - : debug_server_(std::move(debug_server)) { - IREE_TRACE_SCOPE0("Instance::ctor"); -} - -Instance::~Instance() { IREE_TRACE_SCOPE0("Instance::dtor"); } - -void Instance::RegisterContext(Context* context) { - { - absl::MutexLock lock(&contexts_mutex_); - contexts_.push_back(context); - } - if (debug_server_) { - CHECK_OK(debug_server_->RegisterContext(context)); - } -} - -void Instance::UnregisterContext(Context* context) { - if (debug_server_) { - CHECK_OK(debug_server_->UnregisterContext(context)); - } - { - absl::MutexLock lock(&contexts_mutex_); - contexts_.erase(context); - } -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/instance.h b/iree/rt/instance.h deleted file mode 100644 index 61fd544..0000000 --- a/iree/rt/instance.h +++ /dev/null
@@ -1,70 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_INSTANCE_H_ -#define IREE_RT_INSTANCE_H_ - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/ref_ptr.h" -#include "iree/hal/device_manager.h" -#include "iree/rt/context.h" -#include "iree/rt/debug/debug_server.h" - -namespace iree { -namespace rt { - -// Shared runtime instance responsible for routing Context events, enumerating -// and creating hardware device interfaces, and managing thread pools. -// -// A single runtime instance can service multiple contexts and hosting -// applications should try to reuse instances as much as possible. This ensures -// that resource allocation across contexts is handled and extraneous device -// interaction is avoided. For devices that may have exclusive access -// restrictions it is mandatory to share instances, so plan accordingly. -// -// Thread-safe. -class Instance final : public RefObject<Instance> { - public: - // Creates an instance with an optional attached |debug_server|. - Instance() : Instance(nullptr) {} - explicit Instance(std::unique_ptr<debug::DebugServer> debug_server); - ~Instance(); - Instance(const Instance&) = delete; - Instance& operator=(const Instance&) = delete; - - // Optional debug server that has access to contexts in this instance. - debug::DebugServer* debug_server() const { return debug_server_.get(); } - - // Device manager used to enumerate available devices. - hal::DeviceManager* device_manager() const { return &device_manager_; } - - private: - friend class Context; - void RegisterContext(Context* context); - void UnregisterContext(Context* context); - - std::unique_ptr<debug::DebugServer> debug_server_; - mutable hal::DeviceManager device_manager_; - - absl::Mutex contexts_mutex_; - IntrusiveList<Context, offsetof(Context, instance_list_link_)> contexts_ - ABSL_GUARDED_BY(contexts_mutex_); -}; - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_INSTANCE_H_
diff --git a/iree/rt/invocation.cc b/iree/rt/invocation.cc deleted file mode 100644 index 738c65b..0000000 --- a/iree/rt/invocation.cc +++ /dev/null
@@ -1,185 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/invocation.h" - -#include <atomic> -#include <iterator> - -#include "absl/strings/str_cat.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/rt/context.h" - -namespace iree { -namespace rt { - -namespace { - -int32_t NextUniqueInvocationId() { - static std::atomic<int32_t> next_id = {0}; - return ++next_id; -} - -} // namespace - -// static -StatusOr<ref_ptr<Invocation>> Invocation::Create( - ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, - absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::optional<absl::InlinedVector<hal::BufferView, 8>> results) { - IREE_TRACE_SCOPE0("Invocation::Create"); - - const auto& signature = function.signature(); - if (arguments.size() != signature.argument_count()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Argument count mismatch; expected " << signature.argument_count() - << " but received " << arguments.size(); - } else if (results.has_value() && - results.value().size() != signature.result_count()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Result count mismatch; expected " << signature.result_count() - << " but received " << results.value().size(); - } - - absl::InlinedVector<hal::BufferView, 8> results_value; - if (results.has_value()) { - results_value = std::move(results.value()); - } else { - results_value.resize(signature.result_count()); - } - - auto invocation = assign_ref( - new Invocation(std::move(context), function, std::move(policy))); - - // TODO(benvanik): grab execution state, insert deps, etc. - if (!dependencies.empty()) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Dependencies are not yet implemented"; - } - - // TODO(benvanik): fiber scheduling and such. - auto execute_status = function.module()->Execute( - &invocation->stack_, function, std::move(arguments), &results_value); - if (execute_status.ok()) { - invocation->CompleteSuccess(std::move(results_value)); - } else { - invocation->CompleteFailure(std::move(execute_status), nullptr); - } - - return invocation; -} - -// static -StatusOr<ref_ptr<Invocation>> Invocation::Create( - ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, - absl::Span<const ref_ptr<Invocation>> dependencies, - absl::Span<const hal::BufferView> arguments) { - absl::InlinedVector<ref_ptr<Invocation>, 4> dependency_list; - dependency_list.reserve(dependencies.size()); - for (auto& dependency : dependencies) { - dependency_list.push_back(add_ref(dependency)); - } - absl::InlinedVector<hal::BufferView, 8> argument_list; - argument_list.reserve(arguments.size()); - for (auto& buffer_view : arguments) { - argument_list.push_back(buffer_view); - } - return Invocation::Create(std::move(context), function, std::move(policy), - std::move(dependency_list), - std::move(argument_list)); -} - -Invocation::Invocation(ref_ptr<Context> context, const Function function, - ref_ptr<Policy> policy) - : id_(NextUniqueInvocationId()), - context_(std::move(context)), - function_(function), - policy_(std::move(policy)), - stack_(context_.get()) { - IREE_TRACE_SCOPE0("Invocation::ctor"); - context_->RegisterInvocation(this); -} - -Invocation::~Invocation() { - IREE_TRACE_SCOPE0("Invocation::dtor"); - context_->UnregisterInvocation(this); -} - -std::string Invocation::DebugStringShort() const { - return absl::StrCat("invocation_", id_); -} - -std::string Invocation::DebugString() const { return DebugStringShort(); } - -Status Invocation::QueryStatus() { - IREE_TRACE_SCOPE0("Invocation::QueryStatus"); - absl::MutexLock lock(&status_mutex_); - return completion_status_; -} - -StatusOr<absl::InlinedVector<hal::BufferView, 8>> Invocation::ConsumeResults() { - IREE_TRACE_SCOPE0("Invocation::ConsumeResults"); - absl::MutexLock lock(&status_mutex_); - if (!completion_status_.ok()) { - return completion_status_; - } - return std::move(results_); -} - -Status Invocation::Await(absl::Time deadline) { - IREE_TRACE_SCOPE0("Invocation::Await"); - absl::MutexLock lock(&status_mutex_); - // TODO(benvanik): implement async invocation behavior. - return completion_status_; -} - -Status Invocation::Abort() { - IREE_TRACE_SCOPE0("Invocation::Abort"); - // TODO(benvanik): implement async invocation behavior. - return UnimplementedErrorBuilder(IREE_LOC) - << "Async invocations not yet implemented"; -} - -void Invocation::CompleteSuccess( - absl::InlinedVector<hal::BufferView, 8> results) { - IREE_TRACE_SCOPE0("Invocation::CompleteSuccess"); - absl::MutexLock lock(&status_mutex_); - if (IsAborted(completion_status_)) { - // Ignore as the invocation was already aborted prior to completion. - return; - } - DCHECK(IsUnavailable(completion_status_)); - completion_status_ = OkStatus(); - failure_stack_trace_.reset(); - results_ = std::move(results); -} - -void Invocation::CompleteFailure( - Status completion_status, std::unique_ptr<StackTrace> failure_stack_trace) { - IREE_TRACE_SCOPE0("Invocation::CompleteFailure"); - absl::MutexLock lock(&status_mutex_); - if (IsAborted(completion_status_)) { - // Ignore as the invocation was already aborted prior to completion. - return; - } - DCHECK(IsUnavailable(completion_status_)); - completion_status_ = std::move(completion_status); - failure_stack_trace_ = std::move(failure_stack_trace); - results_.clear(); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/invocation.h b/iree/rt/invocation.h deleted file mode 100644 index 600fd8a..0000000 --- a/iree/rt/invocation.h +++ /dev/null
@@ -1,157 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_INVOCATION_H_ -#define IREE_RT_INVOCATION_H_ - -#include <ostream> -#include <string> - -#include "absl/base/thread_annotations.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/rt/function.h" -#include "iree/rt/policy.h" -#include "iree/rt/stack.h" -#include "iree/rt/stack_trace.h" - -namespace iree { -namespace rt { - -class Context; - -// An asynchronous invocation of a function. -// Holds the invocation state and allows querying and waiting on completion. -// Invocations are conceptually fibers and may suspend and resume execution -// several times before completing. -// -// Thread-safe. -class Invocation final : public RefObject<Invocation> { - public: - // TODO(benvanik): define error propagation semantics across dependencies. - // TODO(benvanik): support more dependency types (semaphores, etc). - // Creates a new invocation tracking object for invoking the given |function| - // from |context|. |arguments| will be retained until the invocation is made. - // If |dependencies| are provided then the invocation will wait until they are - // resolved before executing. If a |policy| is provided it will override the - // context-level policy. - // - // Optionally |results| may be provided with preallocated buffers that will - // receive the outputs of the invocation. Invocation will fail if they do not - // match expected sizes. - // - // Note that it's possible for the invocation to complete prior to the return - // of this function. Any errors that occur will be set on the invocation and - // callers should query its state prior to assuming it is in-flight. - static StatusOr<ref_ptr<Invocation>> Create( - ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, - absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::optional<absl::InlinedVector<hal::BufferView, 8>> results = - absl::nullopt); - static StatusOr<ref_ptr<Invocation>> Create( - ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, - absl::Span<const ref_ptr<Invocation>> dependencies, - absl::Span<const hal::BufferView> arguments); - - ~Invocation(); - - // A process-unique ID for the invocation. - int32_t id() const { return id_; } - - // Context this invocation is running within. - const ref_ptr<Context>& context() const { return context_; } - - // Function being invoked. - const Function& function() const { return function_; } - - // A single-line human-readable debug string for the invocation. - std::string DebugStringShort() const; - - // A long-form debug string with stack trace (if available). - std::string DebugString() const; - - // Queries the completion status of the invocation. - // Returns one of the following: - // StatusCode::kOk: the invocation completed successfully. - // StatusCode::kUnavailable: the invocation has not yet completed. - // StatusCode::kCancelled: the invocation was cancelled internally. - // StatusCode::kAborted: the invocation was aborted. - // StatusCode::*: an error occurred during invocation. - Status QueryStatus(); - - // Returns ownership of the results of the operation to the caller. - // If the invocation failed then the result will be returned as if Query had - // been called. - StatusOr<absl::InlinedVector<hal::BufferView, 8>> ConsumeResults(); - - // Blocks the caller until the invocation completes (successfully or - // otherwise). - // - // Returns StatusCode::kDeadlineExceeded if |deadline| elapses before the - // invocation completes and otherwise returns the result of Query(). - Status Await(absl::Time deadline); - - // Attempts to abort the invocation if it is in-flight. - // A no-op if the invocation has already completed. - Status Abort(); - - // TODO(benvanik): export a hal::TimelineSemaphore. - - private: - friend class Context; - - Invocation(ref_ptr<Context> context, const Function function, - ref_ptr<Policy> policy); - - // Completes the invocation with a successful result. - void CompleteSuccess(absl::InlinedVector<hal::BufferView, 8> results); - - // Completes the invocation with a failure, including an optional stack trace. - void CompleteFailure(Status completion_status, - std::unique_ptr<StackTrace> failure_stack_trace); - - int32_t id_; - ref_ptr<Context> context_; - const Function function_; - ref_ptr<Policy> policy_; - - Stack stack_; - - absl::Mutex status_mutex_; - Status completion_status_ ABSL_GUARDED_BY(status_mutex_) = - UnavailableErrorBuilder(IREE_LOC); - std::unique_ptr<StackTrace> failure_stack_trace_ - ABSL_GUARDED_BY(status_mutex_); - absl::InlinedVector<hal::BufferView, 8> results_ - ABSL_GUARDED_BY(status_mutex_); - - friend class Context; - IntrusiveListLink context_list_link_; -}; - -inline std::ostream& operator<<(std::ostream& stream, - const Invocation& invocation) { - stream << invocation.DebugStringShort(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_INVOCATION_H_
diff --git a/iree/rt/module.h b/iree/rt/module.h deleted file mode 100644 index f91a3a5..0000000 --- a/iree/rt/module.h +++ /dev/null
@@ -1,109 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_MODULE_H_ -#define IREE_RT_MODULE_H_ - -#include <ostream> - -#include "absl/container/inlined_vector.h" -#include "absl/strings/string_view.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/rt/function.h" -#include "iree/rt/module_signature.h" - -namespace iree { -namespace rt { - -class Disassembler; -class SourceResolver; -class Stack; - -// Abstract compiled module interface for resolving functions. -// -// Modules are (generally) stateless, immutable, and may exist in multiple -// contexts at the same time. -class Module : public RefObject<Module> { - public: - virtual ~Module() = default; - - // Name of the module used to resolve fully-qualified references. - // The lifetime of the returned reference is not guaranteed beyond the current - // calling scope and callers must clone it if they want to retain it. - virtual absl::string_view name() const = 0; - - // A description of the module imports, exports, and other metadata. - virtual const ModuleSignature signature() const = 0; - - // Returns a resolver capable of resolving functions to source and performing - // basic debugging logic (such as offset calculation). - // May be nullptr if debugging info has been stripped. - virtual SourceResolver* source_resolver() const = 0; - - // Returns a disassembler that can be used to disassemble functions in the - // module. May be nullptr if debugging info has been stripped or disassembly - // has been disabled as a compile option. - virtual Disassembler* disassembler() const = 0; - - // A short human-readable string that matches the compiler formatting. - virtual std::string DebugStringShort() const = 0; - - // Looks up a visible function by ordinal. - // Internal functions may not be found if debugging info has been stripped. - virtual StatusOr<const Function> LookupFunctionByOrdinal( - Function::Linkage linkage, int32_t ordinal) const = 0; - - // Looks up a visible function by name. - // Internal functions may not be found if debugging info has been stripped. - virtual StatusOr<const Function> LookupFunctionByName( - Function::Linkage linkage, absl::string_view name) const = 0; - - // Returns the name of the visible function as a string reference. - // - // May return empty for functions with internal linkage if debugging info has - // been stripped. - // - // The lifetime of the returned reference is not guaranteed beyond the current - // calling scope and callers must clone it if they want to retain it. - virtual StatusOr<absl::string_view> GetFunctionName( - Function::Linkage linkage, int32_t ordinal) const = 0; - - // Returns the full function signature for the given |ordinal|. - // - // May return empty for functions with internal linkage if the debugging info - // has been stripped. - virtual StatusOr<const FunctionSignature> GetFunctionSignature( - Function::Linkage linkage, int32_t ordinal) const = 0; - - // Temporary until scheduler is built. - virtual Status Execute( - Stack* stack, const Function function, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::InlinedVector<hal::BufferView, 8>* results) const = 0; - - protected: - Module() = default; -}; - -inline std::ostream& operator<<(std::ostream& stream, const Module& module) { - stream << module.DebugStringShort(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_MODULE_H_
diff --git a/iree/rt/module_printer.cc b/iree/rt/module_printer.cc deleted file mode 100644 index dd0267e..0000000 --- a/iree/rt/module_printer.cc +++ /dev/null
@@ -1,65 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/module_printer.h" - -#include <iomanip> - -#include "iree/rt/disassembler.h" -#include "iree/rt/source_resolver.h" - -namespace iree { -namespace rt { - -Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags, - std::ostream* stream) { - *stream << "Imports:\n"; - for (int i = 0; i < module.signature().import_function_count(); ++i) { - ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal( - Function::Linkage::kImport, i)); - *stream << " " << i << ": " << function << "\n"; - } - *stream << "Exports:\n"; - for (int i = 0; i < module.signature().export_function_count(); ++i) { - ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal( - Function::Linkage::kExport, i)); - *stream << " " << i << ": " << function << "\n"; - } - if (module.signature().internal_function_count()) { - *stream << "Internal:\n"; - auto* disassembler = module.disassembler(); - for (int i = 0; i < module.signature().internal_function_count(); ++i) { - ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal( - Function::Linkage::kInternal, i)); - *stream << " " << i << ": " << function << "\n"; - if (disassembler && AllBitsSet(flags, PrintModuleFlag::kDisassemble)) { - auto instructions_or = - disassembler->DisassembleInstructions(function, 0); - if (IsUnavailable(instructions_or.status())) continue; - for (const auto& instruction : instructions_or.ValueOrDie()) { - *stream << " " << std::setw(6) << instruction.offset << ": " - << instruction.long_text << "\n"; - } - } - } - } - return OkStatus(); -} - -Status PrintModuleToStream(const Module& module, std::ostream* stream) { - return PrintModuleToStream(module, PrintModuleFlag::kNone, stream); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/module_printer.h b/iree/rt/module_printer.h deleted file mode 100644 index 945dbbd..0000000 --- a/iree/rt/module_printer.h +++ /dev/null
@@ -1,42 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_MODULE_PRINTER_H_ -#define IREE_RT_MODULE_PRINTER_H_ - -#include <ostream> - -#include "iree/base/bitfield.h" -#include "iree/base/status.h" -#include "iree/rt/module.h" - -namespace iree { -namespace rt { - -enum class PrintModuleFlag { - kNone = 0, - kDisassemble = 1 << 0, -}; -IREE_BITFIELD(PrintModuleFlag); -using PrintModuleFlagBitfield = PrintModuleFlag; - -// Prints all functions within the module to the given |stream|. -Status PrintModuleToStream(const Module& module, std::ostream* stream); -Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags, - std::ostream* stream); - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_MODULE_PRINTER_H_
diff --git a/iree/rt/policy.h b/iree/rt/policy.h deleted file mode 100644 index 8bd72fa..0000000 --- a/iree/rt/policy.h +++ /dev/null
@@ -1,43 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_POLICY_H_ -#define IREE_RT_POLICY_H_ - -#include "iree/base/ref_ptr.h" - -namespace iree { -namespace rt { - -// Defines how invocation scheduling is to be performed. -// The policy instance is used by the scheduler to determine when submissions -// should be flushed to target queues. -// -// Thread-safe; the policy may be evaluated from arbitrary threads after an -// invocation has began processing. -class Policy : public RefObject<Policy> { - public: - virtual ~Policy() = default; - - // TODO(benvanik): constraints: - // - max memory usage - // - max delay - // - max in-flight items/etc - // - allowed device types -}; - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_POLICY_H_
diff --git a/iree/rt/source_location.cc b/iree/rt/source_location.cc deleted file mode 100644 index 835905b..0000000 --- a/iree/rt/source_location.cc +++ /dev/null
@@ -1,32 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/source_location.h" - -#include <sstream> - -#include "iree/rt/source_resolver.h" - -namespace iree { -namespace rt { - -std::string SourceLocation::DebugStringShort() const { - if (is_unknown()) return "(unknown)"; - std::ostringstream stream; - resolver_->PrintSourceLocation(resolver_args_, &stream); - return stream.str(); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/source_resolver.h b/iree/rt/source_resolver.h deleted file mode 100644 index 3ceb1b9..0000000 --- a/iree/rt/source_resolver.h +++ /dev/null
@@ -1,67 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_SOURCE_RESOLVER_H_ -#define IREE_RT_SOURCE_RESOLVER_H_ - -#include <cstdint> -#include <ostream> -#include <vector> - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "iree/base/status.h" -#include "iree/rt/function.h" -#include "iree/rt/source_location.h" - -namespace iree { -namespace rt { - -// Resolves offsets within functions to SourceLocations and provides source -// language services. -// -// Thread-safe. -class SourceResolver { - public: - virtual ~SourceResolver() = default; - - // Resolves a function-relative offset to a source location. - // Not all offsets within a function may have source mapping information. - virtual absl::optional<SourceLocation> ResolveFunctionOffset( - const Function& function, SourceOffset offset) = 0; - - // Converts a source location to a human-readable string, commonly in a single - // line denoting an original source file location (such as path:line:col). - virtual void PrintSourceLocation(SourceResolverArgs resolver_args, - std::ostream* stream) const = 0; - - // TODO(benvanik): query local variable names. - - // TODO(benvanik): step target calculation (relative mapping). - // TODO(benvanik): step target based on SourceLocation delta. - - // TODO(benvanik): expression evaluator? (setting variables) - - protected: - friend class SourceLocation; - - SourceResolver() = default; - - // TODO(benvanik): get line mapping information. -}; - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_SOURCE_RESOLVER_H_
diff --git a/iree/rt/stack.cc b/iree/rt/stack.cc deleted file mode 100644 index f4f300f..0000000 --- a/iree/rt/stack.cc +++ /dev/null
@@ -1,60 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/stack.h" - -#include <iterator> - -#include "absl/strings/str_join.h" -#include "iree/base/status.h" - -namespace iree { -namespace rt { - -constexpr int Stack::kMaxStackDepth; - -Stack::Stack(Context* context) : context_(context) {} - -Stack::~Stack() = default; - -StatusOr<StackFrame*> Stack::PushFrame(Function function) { - if (stack_depth_ + 1 > kMaxStackDepth) { - return InternalErrorBuilder(IREE_LOC) - << "Max stack depth of " << kMaxStackDepth << " exceeded"; - } - frames_[stack_depth_++] = StackFrame(function); - - // TODO(benvanik): WTF scope enter. - - return current_frame(); -} - -Status Stack::PopFrame() { - if (stack_depth_ == 0) { - return InternalErrorBuilder(IREE_LOC) << "Unbalanced stack pop"; - } - - // TODO(benvanik): WTF scope leave. - - --stack_depth_; - frames_[stack_depth_] = {}; - return OkStatus(); -} - -std::string Stack::DebugString() const { - return absl::StrJoin(frames(), "\n", StackFrameFormatter()); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/stack.h b/iree/rt/stack.h deleted file mode 100644 index a9547fe..0000000 --- a/iree/rt/stack.h +++ /dev/null
@@ -1,85 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_STACK_H_ -#define IREE_RT_STACK_H_ - -#include <functional> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/rt/stack_frame.h" - -namespace iree { -namespace rt { - -class Context; - -// A runtime call stack for managing stack frames. -// The frames within a stack may be from different backends and may provide -// varying levels of information based on capabilities. -// -// Thread-compatible. Do not attempt to investigate a stack while another thread -// may be mutating it! -class Stack final { - public: - static constexpr int kMaxStackDepth = 32; - - explicit Stack(Context* context); - Stack(const Stack&) = delete; - Stack& operator=(const Stack&) = delete; - ~Stack(); - - // Context defining the module and global workspaces. - Context* context() const { return context_; } - - // All stack frames within the stack. - absl::Span<StackFrame> frames() { - return absl::MakeSpan(frames_).subspan(0, stack_depth_); - } - absl::Span<const StackFrame> frames() const { - return absl::MakeConstSpan(frames_).subspan(0, stack_depth_); - } - - // The current stack frame. - StackFrame* current_frame() { - return stack_depth_ > 0 ? &frames_[stack_depth_ - 1] : nullptr; - } - - // The stack frame of the caller of the current function. - StackFrame* caller_frame() { - return stack_depth_ > 1 ? &frames_[stack_depth_ - 2] : nullptr; - } - - StatusOr<StackFrame*> PushFrame(Function function); - Status PopFrame(); - - // Returns a full stack frame listing in human-readable form. - std::string DebugString() const; - - private: - Context* context_ = nullptr; - std::array<StackFrame, kMaxStackDepth> frames_; - int stack_depth_ = 0; -}; - -inline std::ostream& operator<<(std::ostream& stream, const Stack& stack) { - stream << stack.DebugString(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_STACK_H_
diff --git a/iree/rt/stack_frame.cc b/iree/rt/stack_frame.cc deleted file mode 100644 index 12ea5c8..0000000 --- a/iree/rt/stack_frame.cc +++ /dev/null
@@ -1,34 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/stack_frame.h" - -#include "absl/strings/str_cat.h" -#include "iree/rt/source_resolver.h" - -namespace iree { -namespace rt { - -absl::optional<SourceLocation> StackFrame::source_location() const { - auto* source_resolver = function_.module()->source_resolver(); - if (!source_resolver) return absl::nullopt; - return source_resolver->ResolveFunctionOffset(function_, offset_); -} - -std::string StackFrame::DebugStringShort() const { - return absl::StrCat(module().name(), ":", function().name(), "@", offset()); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/stack_frame.h b/iree/rt/stack_frame.h deleted file mode 100644 index bb77ef4..0000000 --- a/iree/rt/stack_frame.h +++ /dev/null
@@ -1,106 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_STACK_FRAME_H_ -#define IREE_RT_STACK_FRAME_H_ - -#include <ostream> - -#include "absl/types/span.h" -#include "iree/rt/function.h" -#include "iree/rt/module.h" -#include "iree/rt/source_location.h" - -namespace iree { -namespace rt { - -// TODO(benvanik): allocate in-place from an arena. -// Register table used within a stack frame. -struct Registers { - std::vector<hal::BufferView> buffer_views; -}; - -// A single frame on the call stack containing current execution state and -// register values. -// -// As different backends may support different features this interface exposes -// only the things we want to view in our debugger/stack dumps. This allows us -// to ignore the actual implementation (bytecode VM, compiled C code, etc) so -// long as it can respond to queries for register values. This has the benefit -// of keeping the actual frame very lightweight as we are not storing the values -// but instead just routing to the real storage via indirection. If the debugger -// is not attached and no errors are hit then no additional bookkeeping is done. -// -// Thread-compatible, as is the owning Stack/StackTrace. -class StackFrame final { - public: - StackFrame() = default; - explicit StackFrame(Function function) : function_(function) {} - StackFrame(Function function, SourceOffset offset, Registers registers) - : function_(function), - offset_(offset), - registers_(std::move(registers)) {} - StackFrame(const StackFrame&) = delete; - StackFrame& operator=(const StackFrame&) = delete; - StackFrame(StackFrame&&) = default; - StackFrame& operator=(StackFrame&&) = default; - - // Module that owns the function this stack frame represents. - const Module& module() const { return *function_.module(); } - - // Function the stack frame represents. - const Function& function() const { return function_; } - - // Current virtual offset within the function. - // The exact meaning of the offset is backend dependent and callers should - // treat them as opaque and must use the SourceResolver to compute new - // offsets (such as 'next offset'). - SourceOffset offset() const { return offset_; } - SourceOffset* mutable_offset() { return &offset_; } - - // Returns a source location, if available, for the current offset within the - // target function. - absl::optional<SourceLocation> source_location() const; - - // Registers used within the stack frame. - // Storage is implementation-defined and is valid only for the lifetime of the - // frame. - const Registers& registers() const { return registers_; } - Registers* mutable_registers() { return ®isters_; } - - // A short human-readable string for the frame; a single line. - std::string DebugStringShort() const; - - private: - Function function_; - SourceOffset offset_ = 0; - Registers registers_; -}; - -struct StackFrameFormatter { - void operator()(std::string* out, const StackFrame& stack_frame) const { - out->append(stack_frame.DebugStringShort()); - } -}; - -inline std::ostream& operator<<(std::ostream& stream, - const StackFrame& stack_frame) { - stream << stack_frame.DebugStringShort(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_STACK_FRAME_H_
diff --git a/iree/rt/stack_trace.cc b/iree/rt/stack_trace.cc deleted file mode 100644 index ace8803..0000000 --- a/iree/rt/stack_trace.cc +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/rt/stack_trace.h" - -#include "absl/strings/str_join.h" -#include "iree/base/status.h" - -namespace iree { -namespace rt { - -std::string StackTrace::DebugString() const { - return absl::StrJoin(frames_, "\n", StackFrameFormatter()); -} - -} // namespace rt -} // namespace iree
diff --git a/iree/rt/stack_trace.h b/iree/rt/stack_trace.h deleted file mode 100644 index beca466..0000000 --- a/iree/rt/stack_trace.h +++ /dev/null
@@ -1,65 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_RT_STACK_TRACE_H_ -#define IREE_RT_STACK_TRACE_H_ - -#include <ostream> -#include <vector> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/rt/stack_frame.h" - -namespace iree { -namespace rt { - -// A snapshot of a stack at a point in time. -// The frames within a stack may be from different backends and may provide -// varying levels of information based on capabilities. -// -// Depending on the capture options the trace may contain references to register -// values (such as buffers) from the time of capture. If the buffers were -// modified after the capture was taken those results will be reflected! -class StackTrace final { - public: - StackTrace() = default; - explicit StackTrace(std::vector<StackFrame> frames) - : frames_(std::move(frames)) {} - StackTrace(const StackTrace&) = delete; - StackTrace& operator=(const StackTrace&) = delete; - ~StackTrace() = default; - - // All stack frames within the stack. - absl::Span<const StackFrame> frames() const { - return absl::MakeConstSpan(frames_); - } - - // Returns a full stack frame listing in human-readable form. - std::string DebugString() const; - - private: - std::vector<StackFrame> frames_; -}; - -inline std::ostream& operator<<(std::ostream& stream, - const StackTrace& stack_trace) { - stream << stack_trace.DebugString(); - return stream; -} - -} // namespace rt -} // namespace iree - -#endif // IREE_RT_STACK_TRACE_H_
diff --git a/iree/samples/hal/BUILD b/iree/samples/hal/BUILD deleted file mode 100644 index 8fa22bf..0000000 --- a/iree/samples/hal/BUILD +++ /dev/null
@@ -1,48 +0,0 @@ -# Samples demonstrating use of the HAL API. -# These do not rely on higher layers of the system (such as the VM or runtime). - -load("//iree:build_defs.bzl", "PLATFORM_VULKAN_TEST_DEPS") -load("//iree/tools:compilation.bzl", "iree_bytecode_module") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_bytecode_module( - name = "simple_compute_test_module", - srcs = ["simple_compute_test.mlir"], - cc_namespace = "iree::hal::samples", -) - -cc_test( - name = "simple_compute_test", - srcs = ["simple_compute_test.cc"], - data = [ - # When building with --config=asan you must specify the following - # envvar when using Vulkan + a local Nvidia GPU: - # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt - "//iree/tools:sanitizer_suppressions.txt", - ], - deps = [ - ":simple_compute_test_module_cc", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "//iree/base:flatbuffer_util", - "//iree/base:status_matchers", - "//iree/hal:command_buffer", - "//iree/hal:command_queue", - "//iree/hal:driver_registry", - "//iree/schemas", - "//iree/base:status", - - # These are the drivers we support running with and can produce - # executables for from the source MLIR. - "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep - "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep - - # TODO(b/142004903): enable when Dawn HAL implementation is functional - # "//iree/hal/dawn:dawn_driver_module", # build-cleaner: keep - ] + PLATFORM_VULKAN_TEST_DEPS, -)
diff --git a/iree/samples/hal/simple_compute_test.cc b/iree/samples/hal/simple_compute_test.cc deleted file mode 100644 index 24d7414..0000000 --- a/iree/samples/hal/simple_compute_test.cc +++ /dev/null
@@ -1,221 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// A simple backend-agnostic compute test for the HAL API. -// This will load an IREE module containing one or more executables and attempt -// to run them against all registered driver backends. -// -// The input file, simple_compute_test.mlir, is as generic as possible to ensure -// we don't need too many variants. This means that it does not use any FFI -// imports requiring runtime support, uses floats exclusively (as that's assumed -// available everywhere), etc. -// -// The `iree_bytecode_module` build rule is used to translate the MLIR to the -// module flatbuffer. Additional target support can be defined there. - -#include "absl/container/inlined_vector.h" -#include "absl/strings/str_replace.h" -#include "absl/time/time.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/flatbuffer_util.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/command_buffer.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/driver_registry.h" -#include "iree/samples/hal/simple_compute_test_module.h" -#include "iree/schemas/module_def_generated.h" - -namespace iree { -namespace hal { -namespace samples { -namespace { - -using ModuleFile = FlatBufferFile<ModuleDef>; - -struct TestParams { - // HAL driver to use for the test. - std::string driver_name; - // Ordinal within the module to execute. - int executable_ordinal; - // Name of the executable (just for prettier logging). - std::string executable_name; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}) << "_ex" - << params.executable_ordinal << "_" << params.executable_name; -} - -// Loads the precompiled module file (from simple_compute_test.mlir). -std::unique_ptr<ModuleFile> LoadModuleFile() { - const auto* file_toc = simple_compute_test_module_create(); - return ModuleFile::WrapBuffer( - ModuleDefIdentifier(), - absl::MakeSpan(reinterpret_cast<const uint8_t*>(file_toc->data), - file_toc->size)) - .ValueOrDie(); -} - -// Builds a list of tests to run for each [driver x available executable]. -std::vector<TestParams> GetAvailableDriverTestParams() { - auto module_file = LoadModuleFile(); - auto& executable_table = *module_file->root()->executable_table(); - std::vector<TestParams> all_test_params; - for (const auto& driver_name : - DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) { - int executable_ordinal = 0; - for (const auto* multi_arch_executable_def : - *executable_table.multi_arch_executables()) { - TestParams test_params; - test_params.driver_name = driver_name; - test_params.executable_ordinal = executable_ordinal--; - test_params.executable_name = - std::string(WrapString(multi_arch_executable_def->name())); - all_test_params.push_back(std::move(test_params)); - } - } - return all_test_params; -} - -class SimpleComputeTest : public ::testing::Test, - public ::testing::WithParamInterface<TestParams> { - protected: - virtual void SetUp() { module_file_ = LoadModuleFile(); } - - std::unique_ptr<ModuleFile> module_file_; -}; - -TEST_P(SimpleComputeTest, RunOnce) { - const auto& test_params = GetParam(); - - // Create driver for this test (based on params) and then get a default - // device. - LOG(INFO) << "Creating driver '" << test_params.driver_name << "'..."; - auto driver_or = - DriverRegistry::shared_registry()->Create(test_params.driver_name); - if (IsUnavailable(driver_or.status())) { - LOG(WARNING) << "Skipping test as driver is unavailable: " - << driver_or.status(); - GTEST_SKIP(); - return; - } - ASSERT_OK_AND_ASSIGN(auto driver, driver_or); - ASSERT_OK_AND_ASSIGN(auto available_devices, - driver->EnumerateAvailableDevices()); - for (const auto& device_info : available_devices) { - LOG(INFO) << " Device: " << device_info.name(); - } - LOG(INFO) << "Creating default device..."; - ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice()); - LOG(INFO) << "Successfully created device '" << device->info().name() << "'"; - - // Attempt to compile the appropriate executable. This may fail if there's no - // executable available in the input file that the driver can load. - auto executable_cache = device->CreateExecutableCache(); - auto& executable_table = *module_file_->root()->executable_table(); - auto multi_arch_executable_def = - executable_table.multi_arch_executables()->Get( - test_params.executable_ordinal); - ref_ptr<Executable> executable; - for (auto executable_def : *multi_arch_executable_def->executables()) { - if (!executable_cache->CanPrepareFormat(executable_def->format())) { - continue; - } - ExecutableSpec spec; - spec.format = executable_def->format(); - spec.executable_data = *executable_def->contents(); - ASSERT_OK_AND_ASSIGN(executable, - executable_cache->PrepareExecutable( - ExecutableCachingMode::kDefault, spec)); - break; - } - ASSERT_NE(executable, nullptr) - << "No executable found that has a supported format for driver " - << test_params.driver_name; - - // Create I/O buffers. - ASSERT_OK_AND_ASSIGN(auto arg0_buffer, - device->allocator()->Allocate( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, - BufferUsage::kAll, 4 * sizeof(float))); - ASSERT_OK_AND_ASSIGN(auto arg1_buffer, - device->allocator()->Allocate( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, - BufferUsage::kAll, 4 * sizeof(float))); - ASSERT_OK_AND_ASSIGN(auto ret0_buffer, - device->allocator()->Allocate( - MemoryType::kHostLocal | MemoryType::kDeviceVisible, - BufferUsage::kAll, 4 * sizeof(float))); - - // Populate initial values for 4 * 2 = 8. - // We scribble into the result buffer so that it's easy to ensure it's - // overwritten. - ASSERT_OK(arg0_buffer->Fill32(4.0f)); - ASSERT_OK(arg1_buffer->Fill32(2.0f)); - ASSERT_OK(ret0_buffer->Fill32(99999.0f)); - - // Record the command buffer that dispatches the executable. - ASSERT_OK_AND_ASSIGN( - auto cmd, device->CreateCommandBuffer( - CommandBufferMode::kOneShot, - CommandCategory::kTransfer | CommandCategory::kDispatch)); - ASSERT_OK(cmd->Begin()); - DispatchRequest dispatch_request; - dispatch_request.executable = executable.get(); - dispatch_request.entry_point = 0; - dispatch_request.workload[0] = 4; - dispatch_request.workload[1] = 1; - dispatch_request.workload[2] = 1; - BufferBinding bindings[3]; - bindings[0].buffer = arg0_buffer.get(); - bindings[0].access = MemoryAccess::kRead; - bindings[0].element_size = sizeof(float); - bindings[0].shape = {4}; - bindings[1].buffer = arg1_buffer.get(); - bindings[1].access = MemoryAccess::kRead; - bindings[1].element_size = sizeof(float); - bindings[1].shape = {4}; - bindings[2].buffer = ret0_buffer.get(); - bindings[2].access = MemoryAccess::kDiscardWrite; - bindings[2].element_size = sizeof(float); - bindings[2].shape = {4}; - dispatch_request.bindings = bindings; - ASSERT_OK(cmd->Dispatch(dispatch_request)); - ASSERT_OK(cmd->End()); - - // Schedule and wait for completion. - ASSERT_FALSE(device->dispatch_queues().empty()); - CommandQueue* queue = device->dispatch_queues().front(); - ASSERT_OK_AND_ASSIGN(auto fence, device->CreateFence(0u)); - ASSERT_OK( - queue->Submit(SubmissionBatch{{}, {cmd.get()}, {}}, {fence.get(), 1u})); - ASSERT_OK(device->WaitAllFences({{fence.get(), 1u}}, absl::InfiniteFuture())); - - // Read back the results. - ASSERT_OK_AND_ASSIGN(auto ret0_mapping, - ret0_buffer->MapMemory<float>(MemoryAccess::kRead)); - EXPECT_THAT(ret0_mapping.contents(), - ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); -} - -INSTANTIATE_TEST_SUITE_P(AllDrivers, SimpleComputeTest, - ::testing::ValuesIn(GetAvailableDriverTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace -} // namespace samples -} // namespace hal -} // namespace iree
diff --git a/iree/samples/rt/BUILD b/iree/samples/rt/BUILD deleted file mode 100644 index e880a5d..0000000 --- a/iree/samples/rt/BUILD +++ /dev/null
@@ -1,72 +0,0 @@ -# Samples demonstrating use of the RT API. - -load("//iree/tools:compilation.bzl", "iree_bytecode_module") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -iree_bytecode_module( - name = "simple_module_test_bytecode_module", - srcs = ["simple_module_test.mlir"], - cc_namespace = "iree::rt::samples", -) - -cc_test( - name = "bytecode_module_test", - srcs = ["bytecode_module_test.cc"], - data = [ - # When building with --config=asan you must specify the following - # envvar when using Vulkan + a local Nvidia GPU: - # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt - "//iree/tools:sanitizer_suppressions.txt", - ], - deps = [ - ":simple_module_test_bytecode_module_cc", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "//iree/base:flatbuffer_util", - "//iree/base:status", - "//iree/base:status_matchers", - "//iree/hal:buffer_view", - "//iree/hal:driver_registry", - "//iree/rt", - "//iree/schemas", - "//iree/vm:bytecode_module", - "//iree/vm:sequencer_module", - - # These are the drivers we support running with and can produce - # executables for from the source MLIR. - "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep - # TODO(benvanik): include SPIR-V. - # "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep - ], -) - -cc_test( - name = "bytecode_module_api_test", - srcs = ["bytecode_module_api_test.cc"], - data = [ - # When building with --config=asan you must specify the following - # envvar when using Vulkan + a local Nvidia GPU: - # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt - "//iree/tools:sanitizer_suppressions.txt", - ], - deps = [ - ":simple_module_test_bytecode_module_cc", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "//iree/base:api", - "//iree/hal:api", - "//iree/hal:driver_registry", - "//iree/rt:api", - "//iree/vm:api", - - # These are the drivers we support running with and can produce - # executables for from the source MLIR. - "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep - # TODO(benvanik): include SPIR-V. - # "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep - ], -)
diff --git a/iree/samples/rt/bytecode_module_api_test.cc b/iree/samples/rt/bytecode_module_api_test.cc deleted file mode 100644 index 385a4ae..0000000 --- a/iree/samples/rt/bytecode_module_api_test.cc +++ /dev/null
@@ -1,188 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/strings/str_replace.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -// C API: -#include "iree/base/api.h" -#include "iree/hal/api.h" -#include "iree/rt/api.h" -#include "iree/vm/api.h" - -// Only temporary, used for test device registration: -#include "iree/hal/driver_registry.h" - -// Compiled module embedded here to avoid file IO: -#include "iree/samples/rt/simple_module_test_bytecode_module.h" - -namespace iree { -namespace rt { -namespace samples { -namespace { - -#define ASSERT_API_OK(expr) ASSERT_EQ(IREE_STATUS_OK, (expr)) - -struct TestParams { - // HAL driver to use for the test. - std::string driver_name; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}); -} - -// Builds a list of tests to run based on the linked in driver modules. -std::vector<TestParams> GetAvailableDriverTestParams() { - std::vector<TestParams> all_test_params; - for (const auto& driver_name : - hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) { - TestParams test_params; - test_params.driver_name = driver_name; - all_test_params.push_back(std::move(test_params)); - } - return all_test_params; -} - -class BytecodeModuleApiTest : public ::testing::Test, - public ::testing::WithParamInterface<TestParams> { - protected: -}; - -TEST_P(BytecodeModuleApiTest, RunOnce) { - iree_rt_instance_t* instance = nullptr; - ASSERT_API_OK(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &instance)); - - // TEMPORARY: until policies and placement are performed with manually - // register drivers via a magic function. - const auto& driver_name = GetParam().driver_name; - LOG(INFO) << "Creating driver '" << driver_name << "'..."; - ASSERT_API_OK(iree_rt_instance_register_driver_ex( - instance, iree_string_view_t{driver_name.data(), driver_name.size()})); - - // Allocate a context that will hold the module state across invocations. - iree_rt_policy_t* dummy_policy = nullptr; - ASSERT_API_OK(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &dummy_policy)); - iree_rt_context_t* context = nullptr; - ASSERT_API_OK(iree_rt_context_create(instance, dummy_policy, - IREE_ALLOCATOR_DEFAULT, &context)); - iree_rt_policy_release(dummy_policy); - - // Load bytecode module from the embedded data. - LOG(INFO) << "Loading simple_module_test.mlir..."; - const auto* module_file_toc = simple_module_test_bytecode_module_create(); - iree_rt_module_t* bytecode_module = nullptr; - ASSERT_API_OK(iree_vm_bytecode_module_create_from_buffer( - iree_const_byte_span_t{ - reinterpret_cast<const uint8_t*>(module_file_toc->data), - module_file_toc->size}, - nullptr, nullptr, IREE_ALLOCATOR_DEFAULT, &bytecode_module)); - - // Register modules that we want to be able to use in the context. - std::vector<iree_rt_module_t*> modules; - modules.push_back(bytecode_module); - ASSERT_API_OK( - iree_rt_context_register_modules(context, &modules[0], modules.size())); - iree_rt_module_release(bytecode_module); - LOG(INFO) << "Module loaded and context is ready for use"; - - // Lookup the entry point function. - iree_rt_function_t main_function; - const char kMainFunctionName[] = "module.simple_mul"; - ASSERT_API_OK(iree_rt_context_resolve_function( - context, - iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1}, - &main_function)); - - // Allocate buffers that can be mapped on the CPU and that can also be used - // on the device. Not all devices support this, but the ones we have now do. - LOG(INFO) << "Creating I/O buffers..."; - constexpr int kElementCount = 4; - iree_hal_buffer_t* arg0_buffer = nullptr; - iree_hal_buffer_t* arg1_buffer = nullptr; - ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer( - context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, - IREE_ALLOCATOR_DEFAULT, &arg0_buffer)); - ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer( - context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, - IREE_ALLOCATOR_DEFAULT, &arg1_buffer)); - - // Populate initial values for 4 * 2 = 8. - float kFloat4 = 4.0f; - float kFloat2 = 2.0f; - ASSERT_API_OK(iree_hal_buffer_fill(arg0_buffer, 0, IREE_WHOLE_BUFFER, - &kFloat4, sizeof(float))); - ASSERT_API_OK(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER, - &kFloat2, sizeof(float))); - - // Wrap buffers in buffer views to provide shape information. - std::array<iree_hal_buffer_view_t*, 2> arg_buffer_views; - ASSERT_API_OK(iree_hal_buffer_view_create( - arg0_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float), - IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[0])); - ASSERT_API_OK(iree_hal_buffer_view_create( - arg1_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float), - IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[1])); - iree_hal_buffer_release(arg0_buffer); - iree_hal_buffer_release(arg1_buffer); - - // Call into the @simple_mul function. - LOG(INFO) << "Calling @simple_mul..."; - iree_rt_invocation_t* invocation = nullptr; - ASSERT_API_OK(iree_rt_invocation_create( - context, &main_function, nullptr, nullptr, arg_buffer_views.data(), 2, - nullptr, 0, IREE_ALLOCATOR_DEFAULT, &invocation)); - ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[0])); - ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[1])); - ASSERT_API_OK( - iree_rt_invocation_await(invocation, IREE_TIME_INFINITE_FUTURE)); - - // Get the result buffers from the invocation. - LOG(INFO) << "Retreiving results..."; - std::array<iree_hal_buffer_view_t*, 2> result_buffer_views; - iree_host_size_t result_count; - ASSERT_API_OK(iree_rt_invocation_consume_results( - invocation, result_buffer_views.size(), IREE_ALLOCATOR_DEFAULT, - result_buffer_views.data(), &result_count)); - iree_rt_invocation_release(invocation); - - // Read back the results and ensure we got the right values. - LOG(INFO) << "Reading back results..."; - iree_hal_buffer_t* result_buffer = - iree_hal_buffer_view_buffer(result_buffer_views[0]); - iree_hal_mapped_memory_t mapped_memory; - ASSERT_API_OK(iree_hal_buffer_map(result_buffer, IREE_HAL_MEMORY_ACCESS_READ, - 0, IREE_WHOLE_BUFFER, &mapped_memory)); - ASSERT_THAT(absl::Span<const float>( - reinterpret_cast<const float*>(mapped_memory.contents.data), - mapped_memory.contents.data_length / sizeof(float)), - ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); - ASSERT_API_OK(iree_hal_buffer_unmap(result_buffer, &mapped_memory)); - LOG(INFO) << "Results match!"; - - iree_hal_buffer_view_release(result_buffer_views[0]); - - iree_rt_context_release(context); - iree_rt_instance_release(instance); -} - -INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleApiTest, - ::testing::ValuesIn(GetAvailableDriverTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace -} // namespace samples -} // namespace rt -} // namespace iree
diff --git a/iree/samples/rt/bytecode_module_test.cc b/iree/samples/rt/bytecode_module_test.cc deleted file mode 100644 index 9671fd9..0000000 --- a/iree/samples/rt/bytecode_module_test.cc +++ /dev/null
@@ -1,170 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// A simple sample demonstrating simple synchronous module loading and VM use. -// This will load an IREE module containing a @simple_mul method that performs -// an element-wise multiplication. It will invoke @simple_mul in the VM, once -// for each available HAL driver linked into the binary. -// -// The synchronous invocation method (Context::Invoke) used here waits until all -// asynchronous HAL work completes before returning. It's still possible get -// overlapped execution by invoking methods from other threads with their own -// FiberState, though it's best to use the asynchronous API instead. -// -// The `iree_module` build rule is used to translate the MLIR to the module -// flatbuffer. Additional HAL backend target support can be defined there. - -#include "iree/vm/bytecode_module.h" - -#include "absl/strings/str_replace.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "iree/base/flatbuffer_util.h" -#include "iree/base/status.h" -#include "iree/base/status_matchers.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/driver_registry.h" -#include "iree/rt/context.h" -#include "iree/rt/instance.h" -#include "iree/samples/rt/simple_module_test_bytecode_module.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/vm/sequencer_module.h" - -namespace iree { -namespace rt { -namespace samples { -namespace { - -using ::iree::hal::BufferView; -using ::iree::vm::ModuleFile; - -struct TestParams { - // HAL driver to use for the test. - std::string driver_name; -}; - -std::ostream& operator<<(std::ostream& os, const TestParams& params) { - return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}); -} - -// Builds a list of tests to run based on the linked in driver modules. -std::vector<TestParams> GetAvailableDriverTestParams() { - std::vector<TestParams> all_test_params; - for (const auto& driver_name : - hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) { - TestParams test_params; - test_params.driver_name = driver_name; - all_test_params.push_back(std::move(test_params)); - } - return all_test_params; -} - -class BytecodeModuleTest : public ::testing::Test, - public ::testing::WithParamInterface<TestParams> { - protected: -}; - -TEST_P(BytecodeModuleTest, RunOnce) { - auto instance = make_ref<Instance>(); - - // Create driver for this test (based on params) and then get a default - // device. - const auto& test_params = GetParam(); - LOG(INFO) << "Creating driver '" << test_params.driver_name << "'..."; - auto driver_or = - hal::DriverRegistry::shared_registry()->Create(test_params.driver_name); - if (IsUnavailable(driver_or.status())) { - LOG(WARNING) << "Skipping test as driver is unavailable: " - << driver_or.status(); - GTEST_SKIP(); - return; - } - ASSERT_OK_AND_ASSIGN(auto driver, driver_or); - ASSERT_OK_AND_ASSIGN(auto available_devices, - driver->EnumerateAvailableDevices()); - for (const auto& device_info : available_devices) { - LOG(INFO) << " Device: " << device_info.name(); - } - LOG(INFO) << "Creating default device..."; - ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice()); - ASSERT_OK(instance->device_manager()->RegisterDevice(device)); - LOG(INFO) << "Successfully created device '" << device->info().name() << "'"; - - // Make a new context and load the precompiled module file (from - // simple_module_test.mlir) into it. - LOG(INFO) << "Loading simple_module_test.mlir..."; - auto policy = make_ref<Policy>(); - Context context(add_ref(instance), add_ref(policy)); - const auto* module_file_toc = simple_module_test_bytecode_module_create(); - ASSERT_OK_AND_ASSIGN(auto module_file, - vm::ModuleFile::WrapBuffer( - ModuleDefIdentifier(), - absl::MakeSpan(reinterpret_cast<const uint8_t*>( - module_file_toc->data), - module_file_toc->size))); - ASSERT_OK_AND_ASSIGN(auto main_module, - vm::SequencerModule::FromFile(std::move(module_file))); - ASSERT_OK(context.RegisterModule(std::move(main_module))); - LOG(INFO) << "Module loaded and context is ready for use"; - - // Allocate buffers that can be mapped on the CPU and that can also be used - // on the device. Not all devices support this, but the ones we have now do. - LOG(INFO) << "Creating I/O buffers..."; - constexpr int kElementCount = 4; - ASSERT_OK_AND_ASSIGN( - auto arg0_buffer, - instance->device_manager()->AllocateDeviceVisibleBuffer( - hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}})); - ASSERT_OK_AND_ASSIGN( - auto arg1_buffer, - instance->device_manager()->AllocateDeviceVisibleBuffer( - hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}})); - - // Populate initial values for 4 * 2 = 8. - ASSERT_OK(arg0_buffer->Fill32(4.0f)); - ASSERT_OK(arg1_buffer->Fill32(2.0f)); - - // Call into the @simple_mul function. - LOG(INFO) << "Calling @simple_mul..."; - absl::InlinedVector<BufferView, 8> args{ - BufferView{add_ref(arg0_buffer), {kElementCount}, sizeof(float)}, - BufferView{add_ref(arg1_buffer), {kElementCount}, sizeof(float)}, - }; - ASSERT_OK_AND_ASSIGN(auto simple_mul, - context.ResolveFunction("module.simple_mul")); - ASSERT_OK_AND_ASSIGN(auto invocation, - Invocation::Create(add_ref(&context), simple_mul, - nullptr, {}, std::move(args))); - ASSERT_OK(invocation->Await(absl::InfiniteFuture())); - ASSERT_OK_AND_ASSIGN(auto results, invocation->ConsumeResults()); - - // Read back the results and ensure we got the right values. - LOG(INFO) << "Reading back results..."; - auto& ret_buffer_view = results[0]; - ASSERT_OK_AND_ASSIGN( - auto ret_mapping, - ret_buffer_view.buffer->MapMemory<float>(hal::MemoryAccess::kRead)); - ASSERT_THAT(ret_mapping.contents(), - ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); - LOG(INFO) << "Results match!"; -} - -INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleTest, - ::testing::ValuesIn(GetAvailableDriverTestParams()), - ::testing::PrintToStringParamName()); - -} // namespace -} // namespace samples -} // namespace rt -} // namespace iree
diff --git a/iree/schemas/BUILD b/iree/schemas/BUILD deleted file mode 100644 index ebe0bd6..0000000 --- a/iree/schemas/BUILD +++ /dev/null
@@ -1,252 +0,0 @@ -load("//iree:build_defs.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library") -load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -FLATC_ARGS = [ - # Preserve workspace-relative include paths in generated code. - "--keep-prefix", - # Use C++11 'enum class' for enums. - "--scoped-enums", - # Include reflection tables used for dumping debug representations. - "--reflect-names", - # Generate FooT types for unpack/pack support. Note that this should only - # be used in tooling as the code size/runtime overhead is non-trivial. - "--gen-object-api", -] - -iree_flatbuffer_cc_library( - name = "archive_def_cc_fbs", - srcs = ["archive_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":bytecode_def_cc_fbs_includes", - ":device_def_cc_fbs_includes", - ":device_group_def_cc_fbs_includes", - ":device_table_def_cc_fbs_includes", - ":executable_def_cc_fbs_includes", - ":executable_table_def_cc_fbs_includes", - ":function_def_cc_fbs_includes", - ":function_table_def_cc_fbs_includes", - ":module_def_cc_fbs_includes", - ":source_map_def_cc_fbs_includes", - ":type_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "bytecode_def_cc_fbs", - srcs = ["bytecode_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_flatbuffer_cc_library( - name = "debug_service_cc_fbs", - srcs = ["debug_service.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":bytecode_def_cc_fbs_includes", - ":device_def_cc_fbs_includes", - ":device_group_def_cc_fbs_includes", - ":device_table_def_cc_fbs_includes", - ":executable_def_cc_fbs_includes", - ":executable_table_def_cc_fbs_includes", - ":function_def_cc_fbs_includes", - ":function_table_def_cc_fbs_includes", - ":module_def_cc_fbs_includes", - ":source_map_def_cc_fbs_includes", - ":type_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "device_def_cc_fbs", - srcs = ["device_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_flatbuffer_cc_library( - name = "device_group_def_cc_fbs", - srcs = ["device_group_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_flatbuffer_cc_library( - name = "device_table_def_cc_fbs", - srcs = ["device_table_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":device_def_cc_fbs_includes", - ":device_group_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "executable_def_cc_fbs", - srcs = ["executable_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_flatbuffer_cc_library( - name = "executable_table_def_cc_fbs", - srcs = ["executable_table_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":executable_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "function_def_cc_fbs", - srcs = ["function_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":bytecode_def_cc_fbs_includes", - ":type_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "function_table_def_cc_fbs", - srcs = ["function_table_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":bytecode_def_cc_fbs_includes", - ":function_def_cc_fbs_includes", - ":type_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "module_def_cc_fbs", - srcs = ["module_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ":bytecode_def_cc_fbs_includes", - ":device_def_cc_fbs_includes", - ":device_group_def_cc_fbs_includes", - ":device_table_def_cc_fbs_includes", - ":executable_def_cc_fbs_includes", - ":executable_table_def_cc_fbs_includes", - ":function_def_cc_fbs_includes", - ":function_table_def_cc_fbs_includes", - ":source_map_def_cc_fbs_includes", - ":type_def_cc_fbs_includes", - ], -) - -iree_flatbuffer_cc_library( - name = "source_map_def_cc_fbs", - srcs = ["source_map_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_flatbuffer_cc_library( - name = "spirv_executable_def_cc_fbs", - srcs = ["spirv_executable_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_flatbuffer_cc_library( - name = "type_def_cc_fbs", - srcs = ["type_def.fbs"], - flatc_args = FLATC_ARGS, - includes = [ - ], -) - -iree_build_test( - name = "schema_build_test", - targets = [ - ":archive_def_cc_fbs", - ":bytecode_def_cc_fbs", - ":debug_service_cc_fbs", - ":device_def_cc_fbs", - ":device_group_def_cc_fbs", - ":device_table_def_cc_fbs", - ":executable_def_cc_fbs", - ":executable_table_def_cc_fbs", - ":function_def_cc_fbs", - ":function_table_def_cc_fbs", - ":module_def_cc_fbs", - ":source_map_def_cc_fbs", - ":spirv_executable_def_cc_fbs", - ":type_def_cc_fbs", - ], -) - -cc_library( - name = "schemas", - hdrs = [ - ":archive_def_generated.h", - ":bytecode_def_generated.h", - ":debug_service_generated.h", - ":device_def_generated.h", - ":device_group_def_generated.h", - ":device_table_def_generated.h", - ":executable_def_generated.h", - ":executable_table_def_generated.h", - ":function_def_generated.h", - ":function_table_def_generated.h", - ":module_def_generated.h", - ":source_map_def_generated.h", - ":type_def_generated.h", - ], - deps = [ - ":archive_def_cc_fbs", - ":bytecode_def_cc_fbs", - ":debug_service_cc_fbs", - ":device_def_cc_fbs", - ":device_group_def_cc_fbs", - ":device_table_def_cc_fbs", - ":executable_def_cc_fbs", - ":executable_table_def_cc_fbs", - ":function_def_cc_fbs", - ":function_table_def_cc_fbs", - ":module_def_cc_fbs", - ":source_map_def_cc_fbs", - ":spirv_executable_def_cc_fbs", - ":type_def_cc_fbs", - "@com_github_google_flatbuffers//:flatbuffers", - ], -) - -REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [ - "archive_def.bfbs", - "bytecode_def.bfbs", - "debug_service.bfbs", - "executable_def.bfbs", - "executable_table_def.bfbs", - "function_def.bfbs", - "function_table_def.bfbs", - "module_def.bfbs", - "source_map_def.bfbs", - "spirv_executable_def.bfbs", - "type_def.bfbs", - "device_def.bfbs", - "device_group_def.bfbs", - "device_table_def.bfbs", -] - -cc_embed_data( - name = "reflection_data", - srcs = REFLECTION_SRCS, - cc_file_output = "reflection_data.cc", - cpp_namespace = "iree::schemas", - h_file_output = "reflection_data.h", -)
diff --git a/iree/schemas/archive_def.fbs b/iree/schemas/archive_def.fbs deleted file mode 100644 index 2e96683..0000000 --- a/iree/schemas/archive_def.fbs +++ /dev/null
@@ -1,14 +0,0 @@ -include "iree/schemas/module_def.fbs"; - -namespace iree; - -// 'Executable ARChive'. -file_identifier "EARC"; -file_extension "earc"; - -table ArchiveDef { - name:string; - modules:[ModuleDef]; -} - -root_type ArchiveDef;
diff --git a/iree/schemas/bytecode/BUILD b/iree/schemas/bytecode/BUILD deleted file mode 100644 index 11f5541..0000000 --- a/iree/schemas/bytecode/BUILD +++ /dev/null
@@ -1,29 +0,0 @@ -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "bytecode_v0", - hdrs = ["bytecode_v0.h"], - deps = [ - "//iree/base:bitfield", - "@com_google_absl//absl/base:core_headers", - ], -) - -cc_library( - name = "interpreter_bytecode_v0", - hdrs = ["interpreter_bytecode_v0.h"], - deps = [ - ":bytecode_v0", - ], -) - -cc_library( - name = "sequencer_bytecode_v0", - hdrs = ["sequencer_bytecode_v0.h"], - deps = [ - ":bytecode_v0", - ], -)
diff --git a/iree/schemas/bytecode/bytecode_v0.h b/iree/schemas/bytecode/bytecode_v0.h deleted file mode 100644 index 46f56b8..0000000 --- a/iree/schemas/bytecode/bytecode_v0.h +++ /dev/null
@@ -1,143 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Opcode table for the V0 binary format. -// Additions are fine but changing the behavior or order of any opcodes will -// break pasring of existing files. -// -// Opcodes have been selected on frequency of use, general applicability, and -// relative stability. Experimental ops should be implemented via the FFI fisrt -// before graduating into the core set. Ops that may only be present on certain -// targets should also be kept as imports via the FFI. -// -// Opcodes may be specified for particular types (int32_t), categories of types -// (all floating-point types), or implicit types (output matches input). Saving -// opcode space by sharing a single opcode for multiple types is preferred -// except where hot operations are performed (for example, comparison used in -// loop iteratosr). - -#ifndef IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_ -#define IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_ - -#include <cstdint> - -#include "iree/base/bitfield.h" - -namespace iree { - -#define IREE_CONSTANT_ENCODING_LIST(ENC) \ - ENC(0x00, kDense, "dense") \ - ENC(0x01, kSplat, "splat") - -#define IREE_TYPE_LIST(TYP) \ - TYP(0x00, kI8, "i8", 1) \ - TYP(0x01, kI16, "i16", 2) \ - TYP(0x02, kI32, "i32", 4) \ - TYP(0x03, kI64, "i64", 8) \ - TYP(0x04, kF16, "f16", 2) \ - TYP(0x05, kF32, "f32", 4) \ - TYP(0x06, kF64, "f64", 8) \ - TYP(0x80, kDevice, "device", 0) \ - TYP(0x81, kCommandBuffer, "command_buffer", 0) \ - TYP(0x82, kEvent, "event", 0) \ - TYP(0x83, kSemaphore, "semaphore", 0) \ - TYP(0x84, kFence, "fence", 0) \ - TYP(0xFF, kOpaque, "opaque", 0) - -#define IREE_CMPI_PREDICATE_LIST(PRED) \ - PRED(0, kEq, "eq") \ - PRED(1, kNe, "ne") \ - PRED(2, kSlt, "slt") \ - PRED(3, kSle, "sle") \ - PRED(4, kSgt, "sgt") \ - PRED(5, kSge, "sge") \ - PRED(6, kUlt, "ult") \ - PRED(7, kUle, "ule") \ - PRED(8, kUgt, "ugt") \ - PRED(9, kUge, "uge") - -#define IREE_CMPF_PREDICATE_LIST(PRED) \ - PRED(0, kFalse, "false") \ - PRED(1, kOeq, "oeq") \ - PRED(2, kOgt, "ogt") \ - PRED(3, kOge, "oge") \ - PRED(4, kOlt, "olt") \ - PRED(5, kOle, "ole") \ - PRED(6, kOne, "one") \ - PRED(7, kOrd, "ord") \ - PRED(8, kUeq, "ueq") \ - PRED(9, kUgt, "ugt") \ - PRED(10, kUge, "uge") \ - PRED(11, kUlt, "ult") \ - PRED(12, kUle, "ule") \ - PRED(13, kUne, "une") \ - PRED(14, kUno, "uno") \ - PRED(15, kTrue, "true") - -// NOTE: FF is a to-be-defined flag value for encoding/decoding. -#define FLAG(V) ::iree::OpcodeFlag::V - -#define RSV(opcode, RESERVED_OPC) \ - RESERVED_OPC(opcode, kReserved##opcode, "rsv." #opcode, FLAG(kDefault), "", \ - FF) - -#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal, - -enum class ConstantEncoding : uint8_t { - IREE_CONSTANT_ENCODING_LIST(DECLARE_ENUM) -}; - -enum class BuiltinType : uint8_t { IREE_TYPE_LIST(DECLARE_ENUM) }; - -enum class CmpIPredicate : uint8_t { IREE_CMPI_PREDICATE_LIST(DECLARE_ENUM) }; - -enum class CmpFPredicate : uint8_t { IREE_CMPF_PREDICATE_LIST(DECLARE_ENUM) }; - -#undef DECLARE_ENUM - -static constexpr uint8_t kBuiltinTypeCount = - static_cast<uint8_t>(BuiltinType::kF64) + 1; - -enum class OpcodeFlag : uint8_t { - kDefault = 0, -}; -IREE_BITFIELD(OpcodeFlag); -using OpcodeFlagBitfield = OpcodeFlag; - -enum class OperandEncoding : char { - kNone = '\0', - kInputSlot = 's', - kVariadicInputSlots = 'S', - kOutputSlot = 'o', - kVariadicOutputSlots = 'O', - kResultSlot = 'r', - kVariadicResultSlots = 'R', - kVariadicTransferSlots = 'T', - kConstant = 'c', - kFunctionOrdinal = 'f', - kImportOrdinal = 'F', - kDispatchOrdinal = 'd', - kBlockOffset = 'b', - kTypeIndex = 't', - kIndex = 'i', - kIndexList = 'I', - kCmpIPredicate = 'p', - kCmpFPredicate = 'P', -}; -IREE_BITFIELD(OperandEncoding); -using OperandEncodingBitfield = OperandEncoding; - -} // namespace iree - -#endif // IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
diff --git a/iree/schemas/bytecode/interpreter_bytecode_v0.h b/iree/schemas/bytecode/interpreter_bytecode_v0.h deleted file mode 100644 index 05a0c56..0000000 --- a/iree/schemas/bytecode/interpreter_bytecode_v0.h +++ /dev/null
@@ -1,325 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Opcode table for the V0 binary format. -// Additions are fine but changing the behavior or order of any opcodes will -// break parsing of existing files. -// -// Opcodes have been selected on frequency of use, general applicability, and -// relative stability. Experimental ops should be implemented via the Foreign -// Function Interface (FFI) first before graduating into the core set. Ops that -// may only be present on certain targets should also be kept as imports via the -// FFI. -// -// Opcodes may be specified for particular types (int32_t), categories of types -// (all floating-point types), or implicit types (output matches input). Saving -// opcode space by sharing a single opcode for multiple types is preferred -// except where hot operations are performed (for example, comparison used in -// loop iterators). - -#ifndef IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_ -#define IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_ - -#include "iree/schemas/bytecode/bytecode_v0.h" - -namespace iree { - -#define IREE_INTERPRETER_OPCODE_LIST(OPC, RESERVED_OPC) \ - OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \ - \ - OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \ - OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \ - OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \ - OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \ - OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \ - OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \ - OPC(0x07, kCmpI, "cmp_i", FLAG(kDefault), "psso", FF) \ - OPC(0x08, kCmpF, "cmp_f", FLAG(kDefault), "Psso", FF) \ - \ - RSV(0x09, RESERVED_OPC) \ - RSV(0x0A, RESERVED_OPC) \ - RSV(0x0B, RESERVED_OPC) \ - RSV(0x0C, RESERVED_OPC) \ - RSV(0x0D, RESERVED_OPC) \ - RSV(0x0E, RESERVED_OPC) \ - RSV(0x0F, RESERVED_OPC) \ - RSV(0x10, RESERVED_OPC) \ - RSV(0x11, RESERVED_OPC) \ - RSV(0x12, RESERVED_OPC) \ - RSV(0x13, RESERVED_OPC) \ - RSV(0x14, RESERVED_OPC) \ - RSV(0x15, RESERVED_OPC) \ - RSV(0x16, RESERVED_OPC) \ - RSV(0x17, RESERVED_OPC) \ - RSV(0x18, RESERVED_OPC) \ - RSV(0x19, RESERVED_OPC) \ - RSV(0x1A, RESERVED_OPC) \ - RSV(0x1B, RESERVED_OPC) \ - RSV(0x1C, RESERVED_OPC) \ - RSV(0x1D, RESERVED_OPC) \ - RSV(0x1E, RESERVED_OPC) \ - RSV(0x1F, RESERVED_OPC) \ - \ - OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \ - OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \ - OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \ - OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \ - OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \ - \ - RSV(0x25, RESERVED_OPC) \ - RSV(0x26, RESERVED_OPC) \ - RSV(0x27, RESERVED_OPC) \ - RSV(0x28, RESERVED_OPC) \ - RSV(0x29, RESERVED_OPC) \ - RSV(0x2A, RESERVED_OPC) \ - RSV(0x2B, RESERVED_OPC) \ - RSV(0x2C, RESERVED_OPC) \ - RSV(0x2D, RESERVED_OPC) \ - RSV(0x2E, RESERVED_OPC) \ - RSV(0x2F, RESERVED_OPC) \ - \ - OPC(0x30, kRank, "rank", FLAG(kDefault), "so", FF) \ - OPC(0x31, kDim, "dim", FLAG(kDefault), "iso", FF) \ - OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \ - OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \ - OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "sssr", FF) \ - OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "sIIr", FF) \ - OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \ - OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sIoII", FF) \ - OPC(0x38, kClone, "clone", FLAG(kDefault), "sr", FF) \ - RSV(0x39, RESERVED_OPC) \ - OPC(0x3A, kSplit, "split", FLAG(kDefault), "isR", FF) \ - OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \ - OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \ - OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \ - OPC(0x3E, kSelect, "select", FLAG(kDefault), "ssso", FF) \ - OPC(0x3F, kTranspose, "transpose", FLAG(kDefault), "sso", FF) \ - OPC(0x40, kBroadcast, "broadcast", FLAG(kDefault), "sso", FF) \ - OPC(0x41, kTile, "tile", FLAG(kDefault), "sso", FF) \ - OPC(0x42, kReverse, "reverse", FLAG(kDefault), "sso", FF) \ - OPC(0x43, kPad, "pad", FLAG(kDefault), "ssssso", FF) \ - \ - RSV(0x44, RESERVED_OPC) \ - RSV(0x45, RESERVED_OPC) \ - RSV(0x46, RESERVED_OPC) \ - RSV(0x47, RESERVED_OPC) \ - RSV(0x48, RESERVED_OPC) \ - RSV(0x49, RESERVED_OPC) \ - RSV(0x4A, RESERVED_OPC) \ - RSV(0x4B, RESERVED_OPC) \ - RSV(0x4C, RESERVED_OPC) \ - RSV(0x4D, RESERVED_OPC) \ - RSV(0x4E, RESERVED_OPC) \ - RSV(0x4F, RESERVED_OPC) \ - \ - OPC(0x50, kNot, "not", FLAG(kDefault), "so", FF) \ - OPC(0x51, kAnd, "and", FLAG(kDefault), "sso", FF) \ - OPC(0x52, kOr, "or", FLAG(kDefault), "sso", FF) \ - OPC(0x53, kXor, "xor", FLAG(kDefault), "sso", FF) \ - OPC(0x54, kShiftLeft, "sll", FLAG(kDefault), "sso", FF) \ - OPC(0x55, kShiftRightLogical, "srl", FLAG(kDefault), "sso", FF) \ - OPC(0x56, kShiftRightArithmetic, "sra", FLAG(kDefault), "sso", FF) \ - \ - RSV(0x57, RESERVED_OPC) \ - RSV(0x58, RESERVED_OPC) \ - RSV(0x59, RESERVED_OPC) \ - RSV(0x5A, RESERVED_OPC) \ - RSV(0x5B, RESERVED_OPC) \ - RSV(0x5C, RESERVED_OPC) \ - RSV(0x5D, RESERVED_OPC) \ - RSV(0x5E, RESERVED_OPC) \ - RSV(0x5F, RESERVED_OPC) \ - RSV(0x60, RESERVED_OPC) \ - RSV(0x61, RESERVED_OPC) \ - RSV(0x62, RESERVED_OPC) \ - RSV(0x63, RESERVED_OPC) \ - RSV(0x64, RESERVED_OPC) \ - RSV(0x65, RESERVED_OPC) \ - RSV(0x66, RESERVED_OPC) \ - RSV(0x67, RESERVED_OPC) \ - RSV(0x68, RESERVED_OPC) \ - RSV(0x69, RESERVED_OPC) \ - RSV(0x6A, RESERVED_OPC) \ - RSV(0x6B, RESERVED_OPC) \ - RSV(0x6C, RESERVED_OPC) \ - RSV(0x6D, RESERVED_OPC) \ - RSV(0x6E, RESERVED_OPC) \ - RSV(0x6F, RESERVED_OPC) \ - \ - /* TODO(benvanik): remove ones we don't need/can emulate */ \ - OPC(0x70, kAddI, "add_i", FLAG(kDefault), "sso", FF) \ - OPC(0x71, kAddF, "add_f", FLAG(kDefault), "sso", FF) \ - OPC(0x72, kSubI, "sub_i", FLAG(kDefault), "sso", FF) \ - OPC(0x73, kSubF, "sub_f", FLAG(kDefault), "sso", FF) \ - OPC(0x74, kAbsI, "abs_i", FLAG(kDefault), "so", FF) \ - OPC(0x75, kAbsF, "abs_f", FLAG(kDefault), "so", FF) \ - OPC(0x76, kMulI, "mul_i", FLAG(kDefault), "sso", FF) \ - OPC(0x77, kMulF, "mul_f", FLAG(kDefault), "sso", FF) \ - OPC(0x78, kDivIS, "div_i_s", FLAG(kDefault), "sso", FF) \ - OPC(0x79, kDivIU, "div_i_u", FLAG(kDefault), "sso", FF) \ - OPC(0x7A, kDivF, "div_f", FLAG(kDefault), "sso", FF) \ - OPC(0x7B, kMulAddI, "madd_i", FLAG(kDefault), "ssso", FF) \ - OPC(0x7C, kMulAddF, "madd_f", FLAG(kDefault), "ssso", FF) \ - OPC(0x7D, kCosF, "cos_f", FLAG(kDefault), "so", FF) \ - OPC(0x7E, kSinF, "sin_f", FLAG(kDefault), "so", FF) \ - OPC(0x7F, kTanhF, "tanh_f", FLAG(kDefault), "so", FF) \ - OPC(0x80, kAtan2F, "atan2_f", FLAG(kDefault), "sso", FF) \ - OPC(0x81, kExpF, "exp_f", FLAG(kDefault), "so", FF) \ - OPC(0x82, kLogF, "log_f", FLAG(kDefault), "so", FF) \ - OPC(0x83, kRsqrtF, "rsqrt_f", FLAG(kDefault), "so", FF) \ - \ - RSV(0x84, RESERVED_OPC) \ - RSV(0x85, RESERVED_OPC) \ - RSV(0x86, RESERVED_OPC) \ - RSV(0x87, RESERVED_OPC) \ - RSV(0x88, RESERVED_OPC) \ - RSV(0x89, RESERVED_OPC) \ - RSV(0x8A, RESERVED_OPC) \ - RSV(0x8B, RESERVED_OPC) \ - RSV(0x8C, RESERVED_OPC) \ - RSV(0x8D, RESERVED_OPC) \ - RSV(0x8E, RESERVED_OPC) \ - RSV(0x8F, RESERVED_OPC) \ - \ - OPC(0x90, kMinIS, "min_i_s", FLAG(kDefault), "sso", FF) \ - OPC(0x91, kMinIU, "min_i_u", FLAG(kDefault), "sso", FF) \ - OPC(0x92, kMinF, "min_f", FLAG(kDefault), "sso", FF) \ - OPC(0x93, kMaxIS, "max_i_s", FLAG(kDefault), "sso", FF) \ - OPC(0x94, kMaxIU, "max_i_u", FLAG(kDefault), "sso", FF) \ - OPC(0x95, kMaxF, "max_f", FLAG(kDefault), "sso", FF) \ - OPC(0x96, kClampIS, "clamp_i_s", FLAG(kDefault), "ssso", FF) \ - OPC(0x97, kClampIU, "clamp_i_u", FLAG(kDefault), "ssso", FF) \ - OPC(0x98, kClampF, "clamp_f", FLAG(kDefault), "ssso", FF) \ - OPC(0x99, kFloorF, "floor_f", FLAG(kDefault), "so", FF) \ - OPC(0x9A, kCeilF, "ceil_f", FLAG(kDefault), "so", FF) \ - \ - OPC(0x9B, kConvertSS, "convert_s_s", FLAG(kDefault), "tsto", FF) \ - OPC(0x9C, kConvertUU, "convert_u_u", FLAG(kDefault), "tsto", FF) \ - OPC(0x9D, kConvertSU, "convert_s_u", FLAG(kDefault), "tsto", FF) \ - OPC(0x9E, kConvertUS, "convert_u_s", FLAG(kDefault), "tsto", FF) \ - \ - RSV(0x9F, RESERVED_OPC) \ - \ - /* TODO(benvanik): reduction/sum/etc */ \ - /* TODO(benvanik): sort */ \ - \ - OPC(0xA0, kMatMulI, "matmul_i", FLAG(kDefault), "sssso", FF) \ - OPC(0xA1, kMatMulF, "matmul_f", FLAG(kDefault), "sso", FF) \ - /* TODO(benvanik): convolution */ \ - \ - OPC(0xA2, kReduceSumI, "reduce_sum_i", FLAG(kDefault), "ssio", FF) \ - OPC(0xA3, kReduceSumF, "reduce_sum_f", FLAG(kDefault), "ssio", FF) \ - OPC(0xA4, kReduceMinI, "reduce_min_i", FLAG(kDefault), "ssio", FF) \ - OPC(0xA5, kReduceMinF, "reduce_min_f", FLAG(kDefault), "ssio", FF) \ - OPC(0xA6, kReduceMaxI, "reduce_max_i", FLAG(kDefault), "ssio", FF) \ - OPC(0xA7, kReduceMaxF, "reduce_max_f", FLAG(kDefault), "ssio", FF) \ - RSV(0xA8, RESERVED_OPC) \ - RSV(0xA9, RESERVED_OPC) \ - RSV(0xAA, RESERVED_OPC) \ - RSV(0xAB, RESERVED_OPC) \ - RSV(0xAC, RESERVED_OPC) \ - RSV(0xAD, RESERVED_OPC) \ - RSV(0xAE, RESERVED_OPC) \ - RSV(0xAF, RESERVED_OPC) \ - RSV(0xB0, RESERVED_OPC) \ - RSV(0xB1, RESERVED_OPC) \ - RSV(0xB2, RESERVED_OPC) \ - RSV(0xB3, RESERVED_OPC) \ - RSV(0xB4, RESERVED_OPC) \ - RSV(0xB5, RESERVED_OPC) \ - RSV(0xB6, RESERVED_OPC) \ - RSV(0xB7, RESERVED_OPC) \ - RSV(0xB8, RESERVED_OPC) \ - RSV(0xB9, RESERVED_OPC) \ - RSV(0xBA, RESERVED_OPC) \ - RSV(0xBB, RESERVED_OPC) \ - RSV(0xBC, RESERVED_OPC) \ - RSV(0xBD, RESERVED_OPC) \ - RSV(0xBE, RESERVED_OPC) \ - RSV(0xBF, RESERVED_OPC) \ - RSV(0xC0, RESERVED_OPC) \ - RSV(0xC1, RESERVED_OPC) \ - RSV(0xC2, RESERVED_OPC) \ - RSV(0xC3, RESERVED_OPC) \ - RSV(0xC4, RESERVED_OPC) \ - RSV(0xC5, RESERVED_OPC) \ - RSV(0xC6, RESERVED_OPC) \ - RSV(0xC7, RESERVED_OPC) \ - RSV(0xC8, RESERVED_OPC) \ - RSV(0xC9, RESERVED_OPC) \ - RSV(0xCA, RESERVED_OPC) \ - RSV(0xCB, RESERVED_OPC) \ - RSV(0xCC, RESERVED_OPC) \ - RSV(0xCD, RESERVED_OPC) \ - RSV(0xCE, RESERVED_OPC) \ - RSV(0xCF, RESERVED_OPC) \ - RSV(0xD0, RESERVED_OPC) \ - RSV(0xD1, RESERVED_OPC) \ - RSV(0xD2, RESERVED_OPC) \ - RSV(0xD3, RESERVED_OPC) \ - RSV(0xD4, RESERVED_OPC) \ - RSV(0xD5, RESERVED_OPC) \ - RSV(0xD6, RESERVED_OPC) \ - RSV(0xD7, RESERVED_OPC) \ - RSV(0xD8, RESERVED_OPC) \ - RSV(0xD9, RESERVED_OPC) \ - RSV(0xDA, RESERVED_OPC) \ - RSV(0xDB, RESERVED_OPC) \ - RSV(0xDC, RESERVED_OPC) \ - RSV(0xDD, RESERVED_OPC) \ - RSV(0xDE, RESERVED_OPC) \ - RSV(0xDF, RESERVED_OPC) \ - RSV(0xE0, RESERVED_OPC) \ - RSV(0xE1, RESERVED_OPC) \ - RSV(0xE2, RESERVED_OPC) \ - RSV(0xE3, RESERVED_OPC) \ - RSV(0xE4, RESERVED_OPC) \ - RSV(0xE5, RESERVED_OPC) \ - RSV(0xE6, RESERVED_OPC) \ - RSV(0xE7, RESERVED_OPC) \ - RSV(0xE8, RESERVED_OPC) \ - RSV(0xE9, RESERVED_OPC) \ - RSV(0xEA, RESERVED_OPC) \ - RSV(0xEB, RESERVED_OPC) \ - RSV(0xEC, RESERVED_OPC) \ - RSV(0xED, RESERVED_OPC) \ - RSV(0xEE, RESERVED_OPC) \ - RSV(0xEF, RESERVED_OPC) \ - RSV(0xF0, RESERVED_OPC) \ - RSV(0xF1, RESERVED_OPC) \ - RSV(0xF2, RESERVED_OPC) \ - RSV(0xF3, RESERVED_OPC) \ - RSV(0xF4, RESERVED_OPC) \ - RSV(0xF5, RESERVED_OPC) \ - RSV(0xF6, RESERVED_OPC) \ - RSV(0xF7, RESERVED_OPC) \ - RSV(0xF8, RESERVED_OPC) \ - RSV(0xF9, RESERVED_OPC) \ - RSV(0xFA, RESERVED_OPC) \ - RSV(0xFB, RESERVED_OPC) \ - RSV(0xFC, RESERVED_OPC) \ - \ - OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \ - OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \ - OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF) - -#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal, -enum class InterpreterOpcode : uint8_t { - IREE_INTERPRETER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM) -}; -#undef DECLARE_ENUM - -} // namespace iree - -#endif // IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
diff --git a/iree/schemas/bytecode/sequencer_bytecode_v0.h b/iree/schemas/bytecode/sequencer_bytecode_v0.h deleted file mode 100644 index 7cba224..0000000 --- a/iree/schemas/bytecode/sequencer_bytecode_v0.h +++ /dev/null
@@ -1,313 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Opcode table for the V0 binary format. -// Additions are fine but changing the behavior or order of any opcodes will -// break parsing of existing files. -// -// Opcodes have been selected on frequency of use, general applicability, and -// relative stability. Experimental ops should be implemented via the Foreign -// Function Interface (FFI) first before graduating into the core set. Ops that -// may only be present on certain targets should also be kept as imports via the -// FFI. -// -// Opcodes may be specified for particular types (int32_t), categories of types -// (all floating-point types), or implicit types (output matches input). Saving -// opcode space by sharing a single opcode for multiple types is preferred -// except where hot operations are performed (for example, comparison used in -// loop iterators). - -#ifndef IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_ -#define IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_ - -#include "iree/schemas/bytecode/bytecode_v0.h" - -namespace iree { - -#define IREE_SEQUENCER_OPCODE_LIST(OPC, RESERVED_OPC) \ - OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \ - \ - OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \ - OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \ - OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \ - OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \ - OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \ - OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \ - \ - RSV(0x07, RESERVED_OPC) \ - RSV(0x08, RESERVED_OPC) \ - RSV(0x09, RESERVED_OPC) \ - RSV(0x0A, RESERVED_OPC) \ - RSV(0x0B, RESERVED_OPC) \ - RSV(0x0C, RESERVED_OPC) \ - RSV(0x0D, RESERVED_OPC) \ - RSV(0x0E, RESERVED_OPC) \ - RSV(0x0F, RESERVED_OPC) \ - \ - OPC(0x10, kDynamicDispatch, "dynamic_dispatch", FLAG(kDefault), "dsSOR", FF) \ - OPC(0x11, kStaticDispatch, "static_dispatch", FLAG(kDefault), "diiiSOR", FF) \ - \ - RSV(0x12, RESERVED_OPC) \ - RSV(0x13, RESERVED_OPC) \ - RSV(0x14, RESERVED_OPC) \ - RSV(0x15, RESERVED_OPC) \ - RSV(0x16, RESERVED_OPC) \ - RSV(0x17, RESERVED_OPC) \ - RSV(0x18, RESERVED_OPC) \ - RSV(0x19, RESERVED_OPC) \ - RSV(0x1A, RESERVED_OPC) \ - RSV(0x1B, RESERVED_OPC) \ - RSV(0x1C, RESERVED_OPC) \ - RSV(0x1D, RESERVED_OPC) \ - RSV(0x1E, RESERVED_OPC) \ - RSV(0x1F, RESERVED_OPC) \ - \ - OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \ - OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \ - OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \ - OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \ - OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \ - \ - RSV(0x25, RESERVED_OPC) \ - RSV(0x26, RESERVED_OPC) \ - RSV(0x27, RESERVED_OPC) \ - RSV(0x28, RESERVED_OPC) \ - RSV(0x29, RESERVED_OPC) \ - RSV(0x2A, RESERVED_OPC) \ - RSV(0x2B, RESERVED_OPC) \ - RSV(0x2C, RESERVED_OPC) \ - RSV(0x2D, RESERVED_OPC) \ - RSV(0x2E, RESERVED_OPC) \ - RSV(0x2F, RESERVED_OPC) \ - RSV(0x30, RESERVED_OPC) \ - \ - OPC(0x31, kComputeRange, "compute_range", FLAG(kDefault), "sissoo", FF) \ - OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \ - OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \ - OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "ssstsr", FF) \ - OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "siitIr", FF) \ - OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \ - OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sioii", FF) \ - OPC(0x38, kDynamicFill, "dynamic_fill", FLAG(kDefault), "soss", FF) \ - OPC(0x39, kStaticFill, "static_fill", FLAG(kDefault), "ioii", FF) \ - OPC(0x3A, kClone, "clone", FLAG(kDefault), "sr", FF) \ - OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \ - OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \ - OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \ - \ - RSV(0x3E, RESERVED_OPC) \ - RSV(0x3F, RESERVED_OPC) \ - RSV(0x40, RESERVED_OPC) \ - RSV(0x41, RESERVED_OPC) \ - RSV(0x42, RESERVED_OPC) \ - RSV(0x43, RESERVED_OPC) \ - RSV(0x44, RESERVED_OPC) \ - RSV(0x45, RESERVED_OPC) \ - RSV(0x46, RESERVED_OPC) \ - RSV(0x47, RESERVED_OPC) \ - RSV(0x48, RESERVED_OPC) \ - RSV(0x49, RESERVED_OPC) \ - RSV(0x4A, RESERVED_OPC) \ - RSV(0x4B, RESERVED_OPC) \ - RSV(0x4C, RESERVED_OPC) \ - RSV(0x4D, RESERVED_OPC) \ - RSV(0x4E, RESERVED_OPC) \ - RSV(0x4F, RESERVED_OPC) \ - RSV(0x50, RESERVED_OPC) \ - RSV(0x51, RESERVED_OPC) \ - RSV(0x52, RESERVED_OPC) \ - RSV(0x53, RESERVED_OPC) \ - RSV(0x54, RESERVED_OPC) \ - RSV(0x55, RESERVED_OPC) \ - RSV(0x56, RESERVED_OPC) \ - RSV(0x57, RESERVED_OPC) \ - RSV(0x58, RESERVED_OPC) \ - RSV(0x59, RESERVED_OPC) \ - RSV(0x5A, RESERVED_OPC) \ - RSV(0x5B, RESERVED_OPC) \ - RSV(0x5C, RESERVED_OPC) \ - RSV(0x5D, RESERVED_OPC) \ - RSV(0x5E, RESERVED_OPC) \ - RSV(0x5F, RESERVED_OPC) \ - RSV(0x60, RESERVED_OPC) \ - RSV(0x61, RESERVED_OPC) \ - RSV(0x62, RESERVED_OPC) \ - RSV(0x63, RESERVED_OPC) \ - RSV(0x64, RESERVED_OPC) \ - RSV(0x65, RESERVED_OPC) \ - RSV(0x66, RESERVED_OPC) \ - RSV(0x67, RESERVED_OPC) \ - RSV(0x68, RESERVED_OPC) \ - RSV(0x69, RESERVED_OPC) \ - RSV(0x6A, RESERVED_OPC) \ - RSV(0x6B, RESERVED_OPC) \ - RSV(0x6C, RESERVED_OPC) \ - RSV(0x6D, RESERVED_OPC) \ - RSV(0x6E, RESERVED_OPC) \ - RSV(0x6F, RESERVED_OPC) \ - RSV(0x70, RESERVED_OPC) \ - RSV(0x71, RESERVED_OPC) \ - RSV(0x72, RESERVED_OPC) \ - RSV(0x73, RESERVED_OPC) \ - RSV(0x74, RESERVED_OPC) \ - RSV(0x75, RESERVED_OPC) \ - RSV(0x76, RESERVED_OPC) \ - RSV(0x77, RESERVED_OPC) \ - RSV(0x78, RESERVED_OPC) \ - RSV(0x79, RESERVED_OPC) \ - RSV(0x7A, RESERVED_OPC) \ - RSV(0x7B, RESERVED_OPC) \ - RSV(0x7C, RESERVED_OPC) \ - RSV(0x7D, RESERVED_OPC) \ - RSV(0x7E, RESERVED_OPC) \ - RSV(0x7F, RESERVED_OPC) \ - RSV(0x80, RESERVED_OPC) \ - RSV(0x81, RESERVED_OPC) \ - RSV(0x82, RESERVED_OPC) \ - RSV(0x83, RESERVED_OPC) \ - RSV(0x84, RESERVED_OPC) \ - RSV(0x85, RESERVED_OPC) \ - RSV(0x86, RESERVED_OPC) \ - RSV(0x87, RESERVED_OPC) \ - RSV(0x88, RESERVED_OPC) \ - RSV(0x89, RESERVED_OPC) \ - RSV(0x8A, RESERVED_OPC) \ - RSV(0x8B, RESERVED_OPC) \ - RSV(0x8C, RESERVED_OPC) \ - RSV(0x8D, RESERVED_OPC) \ - RSV(0x8E, RESERVED_OPC) \ - RSV(0x8F, RESERVED_OPC) \ - RSV(0x90, RESERVED_OPC) \ - RSV(0x91, RESERVED_OPC) \ - RSV(0x92, RESERVED_OPC) \ - RSV(0x93, RESERVED_OPC) \ - RSV(0x94, RESERVED_OPC) \ - RSV(0x95, RESERVED_OPC) \ - RSV(0x96, RESERVED_OPC) \ - RSV(0x97, RESERVED_OPC) \ - RSV(0x98, RESERVED_OPC) \ - RSV(0x99, RESERVED_OPC) \ - RSV(0x9A, RESERVED_OPC) \ - RSV(0x9B, RESERVED_OPC) \ - RSV(0x9C, RESERVED_OPC) \ - RSV(0x9D, RESERVED_OPC) \ - RSV(0x9E, RESERVED_OPC) \ - RSV(0x9F, RESERVED_OPC) \ - RSV(0xA0, RESERVED_OPC) \ - RSV(0xA1, RESERVED_OPC) \ - RSV(0xA2, RESERVED_OPC) \ - RSV(0xA3, RESERVED_OPC) \ - RSV(0xA4, RESERVED_OPC) \ - RSV(0xA5, RESERVED_OPC) \ - RSV(0xA6, RESERVED_OPC) \ - RSV(0xA7, RESERVED_OPC) \ - RSV(0xA8, RESERVED_OPC) \ - RSV(0xA9, RESERVED_OPC) \ - RSV(0xAA, RESERVED_OPC) \ - RSV(0xAB, RESERVED_OPC) \ - RSV(0xAC, RESERVED_OPC) \ - RSV(0xAD, RESERVED_OPC) \ - RSV(0xAE, RESERVED_OPC) \ - RSV(0xAF, RESERVED_OPC) \ - RSV(0xB0, RESERVED_OPC) \ - RSV(0xB1, RESERVED_OPC) \ - RSV(0xB2, RESERVED_OPC) \ - RSV(0xB3, RESERVED_OPC) \ - RSV(0xB4, RESERVED_OPC) \ - RSV(0xB5, RESERVED_OPC) \ - RSV(0xB6, RESERVED_OPC) \ - RSV(0xB7, RESERVED_OPC) \ - RSV(0xB8, RESERVED_OPC) \ - RSV(0xB9, RESERVED_OPC) \ - RSV(0xBA, RESERVED_OPC) \ - RSV(0xBB, RESERVED_OPC) \ - RSV(0xBC, RESERVED_OPC) \ - RSV(0xBD, RESERVED_OPC) \ - RSV(0xBE, RESERVED_OPC) \ - RSV(0xBF, RESERVED_OPC) \ - RSV(0xC0, RESERVED_OPC) \ - RSV(0xC1, RESERVED_OPC) \ - RSV(0xC2, RESERVED_OPC) \ - RSV(0xC3, RESERVED_OPC) \ - RSV(0xC4, RESERVED_OPC) \ - RSV(0xC5, RESERVED_OPC) \ - RSV(0xC6, RESERVED_OPC) \ - RSV(0xC7, RESERVED_OPC) \ - RSV(0xC8, RESERVED_OPC) \ - RSV(0xC9, RESERVED_OPC) \ - RSV(0xCA, RESERVED_OPC) \ - RSV(0xCB, RESERVED_OPC) \ - RSV(0xCC, RESERVED_OPC) \ - RSV(0xCD, RESERVED_OPC) \ - RSV(0xCE, RESERVED_OPC) \ - RSV(0xCF, RESERVED_OPC) \ - RSV(0xD0, RESERVED_OPC) \ - RSV(0xD1, RESERVED_OPC) \ - RSV(0xD2, RESERVED_OPC) \ - RSV(0xD3, RESERVED_OPC) \ - RSV(0xD4, RESERVED_OPC) \ - RSV(0xD5, RESERVED_OPC) \ - RSV(0xD6, RESERVED_OPC) \ - RSV(0xD7, RESERVED_OPC) \ - RSV(0xD8, RESERVED_OPC) \ - RSV(0xD9, RESERVED_OPC) \ - RSV(0xDA, RESERVED_OPC) \ - RSV(0xDB, RESERVED_OPC) \ - RSV(0xDC, RESERVED_OPC) \ - RSV(0xDD, RESERVED_OPC) \ - RSV(0xDE, RESERVED_OPC) \ - RSV(0xDF, RESERVED_OPC) \ - RSV(0xE0, RESERVED_OPC) \ - RSV(0xE1, RESERVED_OPC) \ - RSV(0xE2, RESERVED_OPC) \ - RSV(0xE3, RESERVED_OPC) \ - RSV(0xE4, RESERVED_OPC) \ - RSV(0xE5, RESERVED_OPC) \ - RSV(0xE6, RESERVED_OPC) \ - RSV(0xE7, RESERVED_OPC) \ - RSV(0xE8, RESERVED_OPC) \ - RSV(0xE9, RESERVED_OPC) \ - RSV(0xEA, RESERVED_OPC) \ - RSV(0xEB, RESERVED_OPC) \ - RSV(0xEC, RESERVED_OPC) \ - RSV(0xED, RESERVED_OPC) \ - RSV(0xEE, RESERVED_OPC) \ - RSV(0xEF, RESERVED_OPC) \ - RSV(0xF0, RESERVED_OPC) \ - RSV(0xF1, RESERVED_OPC) \ - RSV(0xF2, RESERVED_OPC) \ - RSV(0xF3, RESERVED_OPC) \ - RSV(0xF4, RESERVED_OPC) \ - RSV(0xF5, RESERVED_OPC) \ - RSV(0xF6, RESERVED_OPC) \ - RSV(0xF7, RESERVED_OPC) \ - RSV(0xF8, RESERVED_OPC) \ - RSV(0xF9, RESERVED_OPC) \ - RSV(0xFA, RESERVED_OPC) \ - RSV(0xFB, RESERVED_OPC) \ - RSV(0xFC, RESERVED_OPC) \ - \ - OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \ - OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \ - OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF) - -#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal, -enum class SequencerOpcode : uint8_t { - IREE_SEQUENCER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM) -}; -#undef DECLARE_ENUM - -} // namespace iree - -#endif // IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
diff --git a/iree/schemas/debug_service.fbs b/iree/schemas/debug_service.fbs deleted file mode 100644 index 36e84b5..0000000 --- a/iree/schemas/debug_service.fbs +++ /dev/null
@@ -1,347 +0,0 @@ -include "iree/schemas/function_def.fbs"; -include "iree/schemas/module_def.fbs"; - -namespace iree.rt.debug.rpc; - -table Status { - code:int; - message:string; -} - -table CreateSessionRequest { -} -table CreateSessionResponse { - session_id:int; -} - -table MakeReadyRequest { - session_id:int; -} -table MakeReadyResponse { -} - -table GetStatusRequest { - session_id:int; - // TODO(benvanik): caps debugger supports? version expected? -} -table GetStatusResponse { - protocol:int; - // TODO(benvanik): run state. - // TODO(benvanik): profiling state. -} - -table NativeFunctionDef { - name:string; - // TODO(benvanik): more information about the fns (stack trace of registrant?) -} - -table ContextDef { - context_id:int; - native_functions:[NativeFunctionDef]; - module_names:[string]; -} - -table ListContextsRequest { - session_id:int; -} - -table ListContextsResponse { - contexts:[ContextDef]; -} - -table GetModuleRequest { - session_id:int; - context_id:int; - module_name:string; -} -table GetModuleResponse { - module:ModuleDef; -} - -table GetFunctionRequest { - session_id:int; - context_id:int; - module_name:string; - function_ordinal:int; -} -table GetFunctionResponse { - bytecode:BytecodeDef; - // TODO(benvanik): import info (linked module, etc). -} - -table ResolveFunctionRequest { - session_id:int; - module_name:string; - function_name:string; -} -table ResolveFunctionResponse { - context_ids:[int]; - function_ordinal:int; -} - -table BufferViewDef { - is_valid:bool; - shape:[int]; - element_size:int; - // TODO(benvanik): buffer attrs (type, access, usage). - // TODO(benvanik): buffer size/allocated_size. - // TODO(benvanik): buffer data (if accessible). -} - -table StackFrameDef { - module_name:string; - function_ordinal:int; - offset:int; - locals:[BufferViewDef]; -} - -table InvocationDef { - invocation_id:int; - frames:[StackFrameDef]; -} - -table ListInvocationsRequest { - session_id:int; -} -table ListInvocationsResponse { - invocations:[InvocationDef]; -} - -table SuspendInvocationsRequest { - session_id:int; - invocation_ids:[int]; -} -table SuspendInvocationsResponse { - invocations:[InvocationDef]; -} - -table ResumeInvocationsRequest { - session_id:int; - invocation_ids:[int]; -} -table ResumeInvocationsResponse { -} - -enum StepMode : uint8 { - STEP_ONCE = 0, - STEP_TO_OFFSET = 1, -} - -table StepInvocationRequest { - session_id:int; - step_id:int; - invocation_id:int; - step_mode:StepMode; - bytecode_offset:int; -} -table StepInvocationResponse {} - -table GetInvocationLocalRequest { - session_id:int; - invocation_id:int; - frame_index:int; - local_index:int; -} -table GetInvocationLocalResponse { - value:BufferViewDef; -} - -table SetInvocationLocalRequest { - session_id:int; - invocation_id:int; - frame_index:int; - local_index:int; - value:BufferViewDef; -} -table SetInvocationLocalResponse { - value:BufferViewDef; -} - -enum BreakpointType : uint8 { - BYTECODE_FUNCTION = 0, - NATIVE_FUNCTION = 1, -} - -table BreakpointDef { - breakpoint_id:int; - breakpoint_type:BreakpointType; - - module_name:string; - function_name:string; - function_ordinal:int; - bytecode_offset:int; -} - -table ListBreakpointsRequest { - session_id:int; -} -table ListBreakpointsResponse { - breakpoints:[BreakpointDef]; -} - -table AddBreakpointRequest { - session_id:int; - breakpoint:BreakpointDef; -} -table AddBreakpointResponse { - breakpoint:BreakpointDef; -} - -table RemoveBreakpointRequest { - session_id:int; - breakpoint_id:int; -} -table RemoveBreakpointResponse { -} - -table StartProfilingRequest { - session_id:int; - context_id:int; - // TODO(benvanik): profiling mode. - // mode: sampling_timing, instrumented_coverage, instrumented_log, - // invoke_log -} -table StartProfilingResponse { - // TODO(benvanik): current/new mode. -} - -table StopProfilingRequest { - session_id:int; - context_id:int; -} -table StopProfilingResponse { - // TODO(benvanik): profiling data. -} - -// TODO(benvanik): streaming profiling data query. - -table ServiceShutdownEvent { -} - -table ContextRegisteredEvent { - context_id:int; -} -table ContextUnregisteredEvent { - context_id:int; -} - -table ModuleLoadedEvent { - context_id:int; - module_name:string; -} - -table InvocationRegisteredEvent { - invocation_id:int; -} -table InvocationUnregisteredEvent { - invocation_id:int; -} - -table BreakpointResolvedEvent { - breakpoint:BreakpointDef; - context_id:int; -} - -table BreakpointHitEvent { - breakpoint_id:int; - invocation:InvocationDef; -} - -table StepCompletedEvent { - step_id:int; - invocations:[InvocationDef]; -} - -union RequestUnion { - CreateSessionRequest, - MakeReadyRequest, - GetStatusRequest, - ListContextsRequest, - GetModuleRequest, - GetFunctionRequest, - ResolveFunctionRequest, - ListInvocationsRequest, - SuspendInvocationsRequest, - ResumeInvocationsRequest, - StepInvocationRequest, - GetInvocationLocalRequest, - SetInvocationLocalRequest, - ListBreakpointsRequest, - AddBreakpointRequest, - RemoveBreakpointRequest, - StartProfilingRequest, - StopProfilingRequest, -} - -union ResponseUnion { - CreateSessionResponse, - MakeReadyResponse, - GetStatusResponse, - ListContextsResponse, - GetModuleResponse, - GetFunctionResponse, - ResolveFunctionResponse, - ListInvocationsResponse, - SuspendInvocationsResponse, - ResumeInvocationsResponse, - StepInvocationResponse, - GetInvocationLocalResponse, - SetInvocationLocalResponse, - ListBreakpointsResponse, - AddBreakpointResponse, - RemoveBreakpointResponse, - StartProfilingResponse, - StopProfilingResponse, -} - -union EventUnion { - ServiceShutdownEvent, - ContextRegisteredEvent, - ContextUnregisteredEvent, - ModuleLoadedEvent, - InvocationRegisteredEvent, - InvocationUnregisteredEvent, - BreakpointResolvedEvent, - BreakpointHitEvent, - StepCompletedEvent, -} - -table Request { - message:RequestUnion; -} - -table Response { - status:Status; - message:ResponseUnion; -} - -table ServicePacket { - response:Response; - event:EventUnion; -} - -// NOTE: we aren't using this yet as the FlatBuffers gRPC code is... suspect. -rpc_service DebugServiceRpc { - MakeReady(MakeReadyRequest):MakeReadyResponse; - - GetStatus(GetStatusRequest):GetStatusResponse; - - ListContexts(ListContextsRequest):ListContextsResponse; - GetModule(GetModuleRequest):GetModuleResponse; - GetFunction(GetFunctionRequest):GetFunctionResponse; - ResolveFunction(ResolveFunctionRequest):ResolveFunctionResponse; - - ListInvocations(ListInvocationsRequest):ListInvocationsResponse; - SuspendInvocations(SuspendInvocationsRequest):SuspendInvocationsResponse; - ResumeInvocations(ResumeInvocationsRequest):ResumeInvocationsResponse; - StepInvocation(StepInvocationRequest):StepInvocationResponse; - GetInvocationLocal(GetInvocationLocalRequest):GetInvocationLocalResponse; - SetInvocationLocal(SetInvocationLocalRequest):SetInvocationLocalResponse; - - ListBreakpoints(ListBreakpointsRequest):ListBreakpointsResponse; - AddBreakpoint(AddBreakpointRequest):AddBreakpointResponse; - RemoveBreakpoint(RemoveBreakpointRequest):RemoveBreakpointResponse; - - StartProfiling(StartProfilingRequest):StartProfilingResponse; - StopProfiling(StopProfilingRequest):StopProfilingResponse; -}
diff --git a/iree/schemas/device_table_def.fbs b/iree/schemas/device_table_def.fbs deleted file mode 100644 index 9b00e29..0000000 --- a/iree/schemas/device_table_def.fbs +++ /dev/null
@@ -1,13 +0,0 @@ -include "iree/schemas/device_def.fbs"; -include "iree/schemas/device_group_def.fbs"; - -namespace iree; - -// A table of devices used for runtime device resolution and referencing. -table DeviceTableDef { - // One or more virtual devices referenced by ordinal in the sequencer ops. - devices:[DeviceDef]; - - // Zero or more device groups that specify which devices must be compatible. - device_groups:[DeviceGroupDef]; -}
diff --git a/iree/schemas/executable_table_def.fbs b/iree/schemas/executable_table_def.fbs deleted file mode 100644 index 25174ad..0000000 --- a/iree/schemas/executable_table_def.fbs +++ /dev/null
@@ -1,28 +0,0 @@ -include "iree/schemas/executable_def.fbs"; - -namespace iree; - -// A fat executable containing multiple format variants for the same logical -// entry points. -table MultiArchExecutableDef { - // Friendly name of the executable used for diagnostics. - name:string; - - // Number of available entry points. - // This is used for bytecode verification even when the executable is not - // fully loaded into a device. All executables must have the same entry - // points. - entry_point_count:uint; - - // A set of executables of various formats and supported feature sets. - // The runtime will select the appropriate executable based on the dispatch - // requirements. - executables:[ExecutableDef]; -} - -// A table of executables used for runtime dispatch lookup. -table ExecutableTableDef { - // One or more top level executables referenced by sequencer dispatch ops. - // Ordinal is referenced by dispatch ops to index into the table. - multi_arch_executables:[MultiArchExecutableDef]; -}
diff --git a/iree/schemas/function_def.fbs b/iree/schemas/function_def.fbs deleted file mode 100644 index 9547604..0000000 --- a/iree/schemas/function_def.fbs +++ /dev/null
@@ -1,18 +0,0 @@ -include "iree/schemas/bytecode_def.fbs"; -include "iree/schemas/type_def.fbs"; - -namespace iree; - -table FunctionAttributeDef { - key:string; - value:string; -} - -table FunctionDef { - name:string; - type:FunctionTypeDef; - - attrs:[FunctionAttributeDef]; - - bytecode:BytecodeDef; -}
diff --git a/iree/schemas/function_table_def.fbs b/iree/schemas/function_table_def.fbs deleted file mode 100644 index 2065df7..0000000 --- a/iree/schemas/function_table_def.fbs +++ /dev/null
@@ -1,9 +0,0 @@ -include "iree/schemas/function_def.fbs"; - -namespace iree; - -table FunctionTableDef { - functions:[FunctionDef]; - imports:[int]; - exports:[int]; -}
diff --git a/iree/schemas/module_def.fbs b/iree/schemas/module_def.fbs deleted file mode 100644 index 70bcbe9..0000000 --- a/iree/schemas/module_def.fbs +++ /dev/null
@@ -1,20 +0,0 @@ -include "iree/schemas/executable_table_def.fbs"; -include "iree/schemas/device_table_def.fbs"; -include "iree/schemas/function_table_def.fbs"; -include "iree/schemas/source_map_def.fbs"; - -namespace iree; - -// 'Executable MODule'. -file_identifier "EMOD"; -file_extension "emod"; - -table ModuleDef { - name:string; - device_table:DeviceTableDef; - function_table:FunctionTableDef; - executable_table:ExecutableTableDef; - source_map:SourceMapDef; -} - -root_type ModuleDef;
diff --git a/iree/tools/BUILD b/iree/tools/BUILD deleted file mode 100644 index cca9c86..0000000 --- a/iree/tools/BUILD +++ /dev/null
@@ -1,124 +0,0 @@ -# Misc tools used to optimize, translate, and evaluate IREE. -# Most of these are not designed to run on-device. - -load("//iree:build_defs.bzl", "PLATFORM_VULKAN_DEPS") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -exports_files([ - "run_lit.sh", - "sanitizer_suppressions.txt", -]) - -cc_binary( - name = "iree-opt", - deps = [ - "//iree/compiler/Transforms", - "//iree/compiler/Transforms/Interpreter", - "//iree/compiler/Transforms/Sequencer", - "//iree/compiler/Translation/SPIRV", - "@llvm//:support", - "@local_config_mlir//:AffineDialectRegistration", - "@local_config_mlir//:MlirOptLib", - "@local_config_mlir//:MlirOptMain", - "@local_config_mlir//:StandardDialectRegistration", - "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", - ], -) - -cc_binary( - name = "iree-run-mlir", - srcs = ["run_mlir_main.cc"], - deps = [ - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - "//iree/base:source_location", - "//iree/rt", - "//iree/vm:sequencer_module", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Parser", - "@local_config_mlir//:Support", - "//iree/base:init", - "//iree/base:status", - "//iree/compiler/Translation/Sequencer", - "//iree/compiler/Translation/Interpreter", - "//iree/compiler/Translation/SPIRV", - "//iree/hal:buffer_view_string_util", - "//iree/hal:driver_registry", - "//iree/schemas", - "//iree/rt/debug:debug_server_flags", - ] + PLATFORM_VULKAN_DEPS + [ - "//iree/hal/interpreter:interpreter_driver_module", - # TODO(b/142004903): enable when Dawn HAL implementation is functional - # "//iree/hal/dawn:dawn_driver_module", - "//iree/hal/vulkan:vulkan_driver_module", - ], -) - -cc_binary( - name = "iree-translate", - srcs = ["iree_translate_main.cc"], - deps = [ - "//iree/compiler/Translation/Interpreter", - "//iree/compiler/Translation/SPIRV", - "//iree/compiler/Translation/Sequencer", - "@llvm//:support", - "@local_config_mlir//:AffineDialectRegistration", - "@local_config_mlir//:IR", - "@local_config_mlir//:Pass", - "@local_config_mlir//:StandardDialectRegistration", - "@local_config_mlir//:Support", - "@local_config_mlir//:TranslateClParser", - "@local_config_mlir//:Translation", - "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", - ], -) - -cc_binary( - name = "run_module", - srcs = ["run_module_main.cc"], - deps = [ - "//iree/base:file_io", - "//iree/base:file_path", - "//iree/base:init", - "//iree/base:source_location", - "//iree/base:status", - "//iree/hal:buffer_view_string_util", - "//iree/hal:driver_registry", - "//iree/hal/interpreter:interpreter_driver_module", - "//iree/rt", - "//iree/rt/debug:debug_server_flags", - "//iree/schemas", - "//iree/vm:sequencer_module", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - ], -) - -cc_binary( - name = "benchmark_module", - testonly = 1, - srcs = ["benchmark_module.cc"], - deps = [ - "//iree/base:file_io", - "//iree/base:file_path", - "//iree/base:init", - "//iree/base:source_location", - "//iree/base:status", - "//iree/hal:buffer_view_string_util", - "//iree/hal:driver_registry", - "//iree/hal/interpreter:interpreter_driver_module", - "//iree/rt", - "//iree/rt/debug:debug_server_flags", - "//iree/schemas", - "//iree/vm:sequencer_module", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/strings", - "@com_google_benchmark//:benchmark", - ], -)
diff --git a/iree/tools/benchmark_module.cc b/iree/tools/benchmark_module.cc deleted file mode 100644 index 1b62a52..0000000 --- a/iree/tools/benchmark_module.cc +++ /dev/null
@@ -1,157 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <iostream> -#include <vector> - -#include "absl/flags/flag.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "benchmark/benchmark.h" -#include "iree/base/file_io.h" -#include "iree/base/file_path.h" -#include "iree/base/init.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view_string_util.h" -#include "iree/hal/driver_registry.h" -#include "iree/rt/context.h" -#include "iree/rt/debug/debug_server_flags.h" -#include "iree/rt/instance.h" -#include "iree/rt/module_printer.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/vm/sequencer_module.h" - -ABSL_FLAG(std::string, main_module, "", "Main module with entry point."); -ABSL_FLAG(std::string, main_function, "", - "Function within the main module to execute."); - -ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values."); -ABSL_FLAG(std::string, input_file, "", - "Input shapes and optional values serialized in a file."); - -namespace iree { -namespace { - -// Parses a list of input shapes and values from a string of newline-separated -// inputs. Expects the contents to have one value per line with each value -// listed as -// [shape]xtype=[value] -// Example: -// 4x4xi8=0,1,2,3 -StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags( - hal::Allocator* allocator) { - std::string file_contents; - if (!absl::GetFlag(FLAGS_input_values).empty()) { - file_contents = - absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}}); - } else if (!absl::GetFlag(FLAGS_input_file).empty()) { - ASSIGN_OR_RETURN(file_contents, - file_io::GetFileContents(absl::GetFlag(FLAGS_input_file))); - } - std::vector<hal::BufferView> inputs; - for (const auto& line : - absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) { - ASSIGN_OR_RETURN(auto input, - hal::ParseBufferViewFromString(line, allocator)); - inputs.push_back(input); - } - return inputs; -} - -Status Run(benchmark::State& state) { - ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags()); - auto instance = make_ref<rt::Instance>(std::move(debug_server)); - ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( - "interpreter")); - ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); - RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); - auto policy = make_ref<rt::Policy>(); - auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy)); - - // Load main module. - ASSIGN_OR_RETURN( - auto main_module_file, - vm::ModuleFile::LoadFile(ModuleDefIdentifier(), - absl::GetFlag(FLAGS_main_module)), - _ << "while loading module file " << absl::GetFlag(FLAGS_main_module)); - ASSIGN_OR_RETURN(auto main_module, - vm::SequencerModule::FromFile(std::move(main_module_file))); - - // Register the main module with the context. - // We could add additional modules (specializations, shared libraries, etc). - // ModuleFiles are stateless so we could have the same module_file used by - // multiple contexts simultaneously. - RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module))); - - rt::Function main_function; - if (!absl::GetFlag(FLAGS_main_function).empty()) { - // User-specified main function. - ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName( - rt::Function::Linkage::kExport, - absl::GetFlag(FLAGS_main_function))); - } else { - // No main function specified; to prevent non-deterministic behavior we - // require one unless there's exactly one exported function in the module. - if (main_module->signature().export_function_count() == 1) { - ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal( - rt::Function::Linkage::kExport, 0)); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "--main_function= must be specified to disambiguate the " - "function to run"; - } - } - - // Call into the main function. - ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator())); - - for (auto _ : state) { - ASSIGN_OR_RETURN(auto invocation, - rt::Invocation::Create(add_ref(context), main_function, - make_ref<rt::Policy>(), {}, - absl::MakeConstSpan(arguments))); - RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); - } - - return OkStatus(); -} - -void BM_RunModule(benchmark::State& state) { - // Delegate to a status-returning function so we can use the status macros. - CHECK_OK(Run(state)); -} - -// By default only the main thread is included in CPU time. Include all the -// threads instead. To make single and multi-threaded benchmarks more -// comparable, use the wall time to determine how many iterations to run. -// See https://github.com/google/benchmark#cpu-timers, -BENCHMARK(BM_RunModule)->MeasureProcessCPUTime()->UseRealTime(); - -} // namespace - -extern "C" int main(int argc, char** argv) { - // The benchmark library uses a different mechanism for its flags. This - // consumes any arguments it understands from argv. It must come before - // InitializeEnvironment to avoid failures on unknown flags. - ::benchmark::Initialize(&argc, argv); - InitializeEnvironment(&argc, &argv); - size_t run_benchmark_count = ::benchmark::RunSpecifiedBenchmarks(); - CHECK_GT(run_benchmark_count, 0) << "No benchmarks were run"; - return 0; -} - -} // namespace iree
diff --git a/iree/tools/compilation.bzl b/iree/tools/compilation.bzl deleted file mode 100644 index b232099..0000000 --- a/iree/tools/compilation.bzl +++ /dev/null
@@ -1,43 +0,0 @@ -"""Rules for compiling IREE executables, modules, and archives.""" - -load("//build_tools/embed_data:build_defs.bzl", "cc_embed_data") - -# TODO(benvanik): port to a full starlark rule, document, etc. -def iree_bytecode_module( - name, - srcs, - cc_namespace = None, - visibility = None): - native.genrule( - name = name, - srcs = srcs, - outs = [ - "%s.emod" % (name), - ], - cmd = " && ".join([ - " ".join([ - "$(location //iree/tools:iree-translate)", - "-mlir-to-iree-module", - "-o $(location %s.emod)" % (name), - ] + ["$(locations %s)" % (src) for src in srcs]), - ]), - tools = [ - "//iree/tools:iree-translate", - ], - message = "Compiling IREE module %s..." % (name), - output_to_bindir = 1, - ) - - # Embed the module for use in C++. This avoids the need for file IO in - # tests and samples that would otherwise complicate execution/porting. - if cc_namespace: - cc_embed_data( - name = "%s_cc" % (name), - identifier = name, - srcs = ["%s.emod" % (name)], - cc_file_output = "%s.cc" % (name), - h_file_output = "%s.h" % (name), - cpp_namespace = cc_namespace, - visibility = visibility, - flatten = True, - )
diff --git a/iree/tools/debugger/BUILD b/iree/tools/debugger/BUILD deleted file mode 100644 index e3b90a4..0000000 --- a/iree/tools/debugger/BUILD +++ /dev/null
@@ -1,173 +0,0 @@ -# IREE Debugger UIs. -# -# The main debugger UI can be used in standalone mode connected to a remote -# host (via :debugger) or can be directly embedded into the IREE runtime to -# allow for attaching (--iree_attach_debugger). -# -# By default the IREE runtime does not compile in debug support. To link it in -# pass --define=IREE_DEBUG=1 to bazel builds of the runtime. - -# TODO(benvanik): re-enable debugger after refactoring. -# load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -# TODO(benvanik): re-enable debugger after refactoring. -# alias( -# name = "debugger", -# actual = select({ -# "//tools/cc_target_os:emscripten": ":debug_app_emscripten_files", -# "//conditions:default": ":debug_app_native", -# }), -# ) -# -# cc_library( -# name = "debug_app_library", -# srcs = ["debug_app.cc"], -# hdrs = ["debug_app.h"], -# deps = [ -# "//third_party/GL:GLES2_headers", -# "//third_party/SDL2", -# "@com_google_absl//absl/flags:flag", -# "@com_google_absl//absl/memory", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/types:optional", -# "//third_party/dear_imgui", -# "//third_party/dear_imgui:imgui_sdl_opengl3", -# "//iree/base:memory", -# "//iree/base:source_location", -# "//iree/base:status", -# "//iree/rt/debug:debug_client", -# "//iree/schemas", -# ], -# ) -# -# # NOTE: users must also link in a GL implementation, like: -# # "//third_party/GL/native:GLESv2", # build-cleaner: keep -# cc_library( -# name = "debug_app_embedded", -# srcs = ["debug_app_embedded.cc"], -# hdrs = ["debug_app_embedded.h"], -# deps = [ -# ":debug_app_library", -# "//third_party/SDL2", -# "@com_google_absl//absl/base:core_headers", -# "@com_google_absl//absl/memory", -# "@com_google_absl//absl/strings", -# "@com_google_absl//absl/synchronization", -# "//third_party/dear_imgui", -# "//iree/base:memory", -# "//iree/base:status", -# ], -# ) -# -# EMSCRIPTEN_LINKOPTS_COMMON = [ -# # Error at compile time on unresolved symbols. -# "-s ERROR_ON_UNDEFINED_SYMBOLS=1", -# -# # Required by SDL. -# "-s EXTRA_EXPORTED_RUNTIME_METHODS=Pointer_stringify", -# -# # TODO(benvanik): tweak to enable support when needed. -# "-s ALLOW_MEMORY_GROWTH=1", -# # "-s WASM_MEM_MAX=268435456", # 256MB -# # "-s TOTAL_MEMORY=268435456", # 256MB -# ] -# -# EMSCRIPTEN_LINKOPTS_DBG = [ -# # Show WASM stack trace in Chrome debugger. -# "-g2", -# "-s DEMANGLE_SUPPORT=1", -# -# # Enable verbose assertions. -# "-s ASSERTIONS=2", -# "-s SAFE_HEAP=1", -# "-s STACK_OVERFLOW_CHECK=2", -# ] -# -# EMSCRIPTEN_LINKOPTS_OPT = [] -# -# cc_binary( -# name = "debug_app_emscripten", -# srcs = ["debug_app_main_emscripten.cc"], -# linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({ -# "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG, -# "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT, -# "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT, -# }), -# tags = [ -# "manual", -# "notap", # TODO(b/137088911): Build/test on TAP -# "wasm", -# ], -# deps = [ -# ":debug_app_library", -# "//third_party/SDL2", -# "@com_google_absl//absl/memory", -# "//third_party/dear_imgui", -# "//third_party/dear_imgui:imgui_sdl_opengl3", -# "//iree/base:init", -# "//iree/base:source_location", -# "//iree/base:status", -# ], -# ) -# -# auto_wasm_binary( -# name = "debug_app_emscripten_binary", -# cc_target = ":debug_app_emscripten", -# tags = ["manual"], -# ) -# -# Fileset( -# name = "debug_app_emscripten_files", -# out = "wasm_files", -# entries = [ -# FilesetEntry( -# files = [":debug_app_emscripten_binary"], -# strip_prefix = "debug_app_emscripten_binary", -# destdir = "wasm", -# ), -# FilesetEntry( -# files = ["debug_app.html"], -# destdir = "wasm", -# ), -# ], -# tags = ["manual"], -# ) -# -# cc_binary( -# name = "debug_app_native", -# srcs = ["debug_app_main_native.cc"], -# deps = [ -# ":debug_app_embedded", -# "//third_party/GL/native:EGL", # build-cleaner: keep -# "//third_party/GL/native:GLESv2", # build-cleaner: keep -# "//iree/base:init", -# "//iree/base:status", -# ], -# ) -# -# cc_binary( -# name = "debug_cli", -# srcs = ["debug_cli_main.cc"], -# deps = [ -# ":debug_prompt", -# "@com_google_absl//absl/flags:flag", -# "//iree/base:init", -# "//iree/base:status", -# ], -# ) -# -# cc_library( -# name = "debug_prompt", -# srcs = ["debug_prompt.cc"], -# hdrs = ["debug_prompt.h"], -# deps = [ -# "@com_google_absl//absl/strings", -# "//iree/base:status", -# "//iree/rt/debug:debug_client", -# ], -# )
diff --git a/iree/tools/debugger/debug_app.cc b/iree/tools/debugger/debug_app.cc deleted file mode 100644 index 3e63f53..0000000 --- a/iree/tools/debugger/debug_app.cc +++ /dev/null
@@ -1,1422 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/tools/debugger/debug_app.h" - -#include <GLES2/gl2.h> - -#include <algorithm> -#include <cstdio> - -#include "absl/flags/flag.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_join.h" -#include "absl/strings/str_split.h" -#include "absl/types/optional.h" -#include "iree/base/memory.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/rt/debug/debug_client.h" -#include "iree/schemas/debug_service_generated.h" -#include "iree/vm/bytecode_module.h" -#include "iree/vm/bytecode_tables_sequencer.h" -#include "third_party/dear_imgui/imgui.h" -#include "third_party/dear_imgui/imgui_internal.h" - -namespace iree { -namespace rt { -namespace debug { -namespace { - -void PushButtonHue(float hue) { - ImGui::PushStyleColor(ImGuiCol_Button, - (ImVec4)ImColor::HSV(hue / 7.0f, 0.6f, 0.6f)); - ImGui::PushStyleColor(ImGuiCol_ButtonHovered, - (ImVec4)ImColor::HSV(hue / 7.0f, 0.7f, 0.7f)); - ImGui::PushStyleColor(ImGuiCol_ButtonActive, - (ImVec4)ImColor::HSV(hue / 7.0f, 0.8f, 0.8f)); -} - -void PushButtonColor(const ImVec4& color) { - ImGui::PushStyleColor(ImGuiCol_Button, color); - ImGui::PushStyleColor(ImGuiCol_ButtonHovered, color); - ImGui::PushStyleColor(ImGuiCol_ButtonActive, color); -} - -void PopButtonStyle() { ImGui::PopStyleColor(3); } - -bool AreBreakpointsEqual(const RemoteBreakpoint& breakpoint, - const DebugApp::UserBreakpoint& user_breakpoint) { - if (user_breakpoint.active_breakpoint == &breakpoint) { - return true; - } else if (user_breakpoint.type != breakpoint.type()) { - return false; - } - switch (breakpoint.type()) { - case RemoteBreakpoint::Type::kBytecodeFunction: - if (user_breakpoint.function_ordinal != -1 && - user_breakpoint.function_ordinal != breakpoint.function_ordinal()) { - return false; - } - return breakpoint.module_name() == user_breakpoint.module_name && - breakpoint.function_name() == user_breakpoint.function_name && - breakpoint.bytecode_offset() == user_breakpoint.bytecode_offset; - case RemoteBreakpoint::Type::kNativeFunction: - return breakpoint.function_name() == user_breakpoint.native_function; - default: - return false; - } -} - -} // namespace - -// static -void DebugApp::PumpMainLoopThunk(void* arg) { - auto status = reinterpret_cast<DebugApp*>(arg)->PumpMainLoop(); - if (IsCancelled(status)) { - return; - } else if (!status.ok()) { - CHECK_OK(status); - } -} - -DebugApp::DebugApp(SDL_Window* window, SDL_GLContext gl_context, - const char* glsl_version) - : window_(window), gl_context_(gl_context) { - VLOG(1) << "DebugApp initializing..."; - IMGUI_CHECKVERSION(); - ImGui::CreateContext(); - ImGuiIO& io = ImGui::GetIO(); - io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard; - io.ConfigFlags |= ImGuiConfigFlags_DockingEnable; - - // TODO(benvanik): ini file for settings. - io.IniFilename = nullptr; - // ImGui::LoadIniSettingsFromMemory() - // ImGui::SaveIniSettingsToMemory() - - // TODO(benvanik): theming. - ImGui::StyleColorsDark(); - - // Setup Platform/Renderer bindings - ImGui_ImplSDL2_InitForOpenGL(window_, gl_context_); - ImGui_ImplOpenGL3_Init(glsl_version); - SDL_GL_MakeCurrent(nullptr, nullptr); - VLOG(1) << "DebugApp initialized"; -} - -DebugApp::~DebugApp() { - VLOG(1) << "DebugApp shutting down..."; - ImGui_ImplOpenGL3_Shutdown(); - ImGui_ImplSDL2_Shutdown(); - ImGui::DestroyContext(); - - SDL_GL_DeleteContext(gl_context_); - SDL_GL_MakeCurrent(nullptr, nullptr); - SDL_DestroyWindow(window_); - SDL_Quit(); - VLOG(1) << "DebugApp shut down (SDL_Quit)"; -} - -Status DebugApp::Connect(absl::string_view service_address) { - VLOG(1) << "Connecting to debug service at " << service_address << "..."; - ASSIGN_OR_RETURN(debug_client_, DebugClient::Connect(service_address, this)); - - // TODO(benvanik): load breakpoints from file. - UserBreakpoint user_breakpoint; - user_breakpoint.module_name = "module"; - user_breakpoint.function_name = "main"; - user_breakpoint.bytecode_offset = 0; - user_breakpoint.wants_enabled = true; - user_breakpoint_list_.push_back(std::move(user_breakpoint)); - RETURN_IF_ERROR(RefreshActiveBreakpoints()); - - // Set paused so that we need to resume to continue execution. - is_paused_ = true; - return OkStatus(); -} - -Status DebugApp::Disconnect() { - VLOG(1) << "Disconnecting from debug service"; - debug_client_.reset(); - return OkStatus(); -} - -bool DebugApp::is_paused() const { - if (!debug_client_) { - return false; - } - if (!hit_breakpoints_.empty()) { - return true; // One or more breakpoints hit. - } - return is_paused_ || !is_stepping_; -} - -RemoteInvocation* DebugApp::GetSelectedInvocation() const { - if (!debug_client_ || !selected_invocation_id_.has_value()) { - return nullptr; - } - for (auto* invocation : debug_client_->invocations()) { - if (invocation->id() == selected_invocation_id_.value()) { - return invocation; - } - } - return nullptr; -} - -Status DebugApp::RefreshActiveBreakpoints() { - // Set all breakpoints to disabled. We'll re-enable them as we find them - // below. - for (auto& user_breakpoint : user_breakpoint_list_) { - user_breakpoint.active_breakpoint = nullptr; - } - - // If not connected then no breakpoints are active. - if (!debug_client_) { - return OkStatus(); - } - - // Reconcile the user breakpoint list with the breakpoints available on the - // server. - for (auto* breakpoint : debug_client_->breakpoints()) { - auto it = - std::find_if(user_breakpoint_list_.begin(), user_breakpoint_list_.end(), - [breakpoint](const UserBreakpoint& user_breakpoint) { - return AreBreakpointsEqual(*breakpoint, user_breakpoint); - }); - if (it == user_breakpoint_list_.end()) { - // Breakpoint not found - add to user list. - UserBreakpoint user_breakpoint; - user_breakpoint.type = breakpoint->type(); - user_breakpoint.active_breakpoint = breakpoint; - user_breakpoint.module_name = breakpoint->module_name(); - user_breakpoint.function_name = breakpoint->function_name(); - user_breakpoint.function_ordinal = breakpoint->function_ordinal(); - user_breakpoint.bytecode_offset = breakpoint->bytecode_offset(); - user_breakpoint_list_.push_back(std::move(user_breakpoint)); - } else { - // Breakpoint found - set the active pointer. - UserBreakpoint& user_breakpoint = *it; - user_breakpoint.active_breakpoint = breakpoint; - user_breakpoint.is_enabling = false; - user_breakpoint.module_name = breakpoint->module_name(); - user_breakpoint.function_name = breakpoint->function_name(); - user_breakpoint.function_ordinal = breakpoint->function_ordinal(); - user_breakpoint.bytecode_offset = breakpoint->bytecode_offset(); - } - } - - // Ensure any breakpoint the user wants enabled is active/otherwise. - for (auto& user_breakpoint : user_breakpoint_list_) { - if (user_breakpoint.wants_enabled && !user_breakpoint.is_enabling && - !user_breakpoint.active_breakpoint) { - // Add breakpoint on server. - switch (user_breakpoint.type) { - case RemoteBreakpoint::Type::kBytecodeFunction: - RETURN_IF_ERROR(debug_client_->AddFunctionBreakpoint( - user_breakpoint.module_name, user_breakpoint.function_name, - user_breakpoint.bytecode_offset, - [&user_breakpoint](const RemoteBreakpoint& breakpoint) { - user_breakpoint.function_ordinal = - breakpoint.function_ordinal(); - })); - break; - case RemoteBreakpoint::Type::kNativeFunction: - // TODO(benvanik): native breakpoint support. - return UnimplementedErrorBuilder(IREE_LOC) - << "Native function breakpoints are TODO"; - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented breakpoint type"; - } - user_breakpoint.is_enabling = true; - } else if (!user_breakpoint.wants_enabled && - user_breakpoint.active_breakpoint) { - // Remove breakpoint from server. - RETURN_IF_ERROR( - debug_client_->RemoveBreakpoint(*user_breakpoint.active_breakpoint)); - - user_breakpoint.active_breakpoint = nullptr; - } - } - - return OkStatus(); -} - -bool DebugApp::IsStoppedAtBreakpoint( - const UserBreakpoint& user_breakpoint) const { - return std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(), - user_breakpoint.active_breakpoint) != hit_breakpoints_.end(); -} - -int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name, - int function_ordinal, - int offset) { - for (int i = 0; i < user_breakpoint_list_.size(); ++i) { - auto& user_breakpoint = user_breakpoint_list_[i]; - if (user_breakpoint.module_name == module_name && - user_breakpoint.function_ordinal == function_ordinal && - user_breakpoint.bytecode_offset == offset) { - return i; - } - } - return -1; -} - -int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name, - absl::string_view function_name, - int offset) { - for (int i = 0; i < user_breakpoint_list_.size(); ++i) { - auto& user_breakpoint = user_breakpoint_list_[i]; - if (user_breakpoint.module_name == module_name && - user_breakpoint.function_name == function_name && - user_breakpoint.bytecode_offset == offset) { - return i; - } - } - return -1; -} - -Status DebugApp::ResumeFromBreakpoint(UserBreakpoint* user_breakpoint) { - if (!user_breakpoint->active_breakpoint) { - return FailedPreconditionErrorBuilder(IREE_LOC) << "Breakpoint not active"; - } - VLOG(1) << "Resuming from breakpoint " - << user_breakpoint->active_breakpoint->id() << "..."; - auto it = std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(), - user_breakpoint->active_breakpoint); - if (it == hit_breakpoints_.end()) { - return NotFoundErrorBuilder(IREE_LOC) << "Breakpoint not found"; - } - hit_breakpoints_.erase(it); - return debug_client_->MakeReady(); -} - -Status DebugApp::OnContextRegistered(const RemoteContext& context) { - // Ack event. - return debug_client_->MakeReady(); -} - -Status DebugApp::OnContextUnregistered(const RemoteContext& context) { - // Close documents that may reference modules in the context. - std::vector<CodeViewDocument*> closing_documents; - for (auto& document : documents_) { - auto* module = document->function->module(); - if (module->context_id() != context.id()) { - // Document is not from this context so it's fine. - continue; - } - - // See if any other live context still has the module loaded. We can change - // the document over to that. - RemoteModule* replacement_module = nullptr; - for (auto* context : debug_client_->contexts()) { - for (auto* other_module : context->modules()) { - if (other_module->name() == module->name()) { - replacement_module = other_module; - break; - } - } - if (replacement_module) break; - } - if (replacement_module && replacement_module->is_loaded()) { - // Replace document module reference. - int function_ordinal = document->function->ordinal(); - auto functions = replacement_module->functions(); - if (function_ordinal < functions.size()) { - document->function = functions[function_ordinal]; - } else { - document->function = nullptr; - } - } else { - document->function = nullptr; - } - - if (!document->function) { - // Close the document if we don't have a valid function for it. - VLOG(1) - << "Closing document " << document->title - << " because the last context using the module is being unregistered"; - closing_documents.push_back(document.get()); - } - } - for (auto* document : closing_documents) { - auto it = std::find_if( - documents_.begin(), documents_.end(), - [document](const std::unique_ptr<CodeViewDocument>& open_document) { - return document == open_document.get(); - }); - documents_.erase(it); - } - - // Ack event. - return debug_client_->MakeReady(); -} - -Status DebugApp::OnModuleLoaded(const RemoteContext& context, - const RemoteModule& module) { - // Ack event. - return debug_client_->MakeReady(); -} - -Status DebugApp::OnInvocationRegistered(const RemoteInvocation& invocation) { - if (!selected_invocation_id_.has_value()) { - selected_invocation_id_ = invocation.id(); - selected_stack_frame_index_ = {}; - } - - // Ack event. - return debug_client_->MakeReady(); -} - -Status DebugApp::OnInvocationUnregistered(const RemoteInvocation& invocation) { - if (selected_invocation_id_.has_value() && - selected_invocation_id_.value() == invocation.id()) { - selected_invocation_id_ = {}; - selected_stack_frame_index_ = {}; - } - - // Ack event. - return debug_client_->MakeReady(); -} - -Status DebugApp::OnBreakpointHit(const RemoteBreakpoint& breakpoint, - const RemoteInvocation& invocation) { - // Keep track of where we are stopped. - hit_breakpoints_.push_back(&breakpoint); - return NavigateToCodeView(invocation, -1, NavigationMode::kMatchDocument); -} - -Status DebugApp::PumpMainLoop() { - ImGuiIO& io = ImGui::GetIO(); - - if (debug_client_) { - RETURN_IF_ERROR(debug_client_->Poll()); - } - RETURN_IF_ERROR(RefreshActiveBreakpoints()); - - SDL_GL_MakeCurrent(window_, gl_context_); - - SDL_Event event; - while (SDL_PollEvent(&event)) { - ImGui_ImplSDL2_ProcessEvent(&event); - if (event.type == SDL_QUIT) { - return CancelledErrorBuilder(IREE_LOC) << "Quit hotkey"; - } else if (event.type == SDL_WINDOWEVENT && - event.window.event == SDL_WINDOWEVENT_CLOSE && - event.window.windowID == SDL_GetWindowID(window_)) { - return CancelledErrorBuilder(IREE_LOC) << "Window closed"; - } - } - ImGui_ImplOpenGL3_NewFrame(); - ImGui_ImplSDL2_NewFrame(window_); - ImGui::NewFrame(); - - auto draw_status = DrawUI(); - if (!draw_status.ok()) { - // TODO(benvanik): show on screen? Probably all messed up. - LOG(ERROR) << draw_status; - } - - // Blit the entire ImGui UI. - ImGui::Render(); - SDL_GL_MakeCurrent(window_, gl_context_); - glViewport(0, 0, (int)io.DisplaySize.x, (int)io.DisplaySize.y); - glClearColor(0.45f, 0.55f, 0.60f, 1.0f); - glClear(GL_COLOR_BUFFER_BIT); - // Workaround for terrible bad SDL/graphics driver leaks. - IREE_DISABLE_LEAK_CHECKS(); - ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData()); - IREE_ENABLE_LEAK_CHECKS(); - - // Render additional viewport windows (desktop only). - if (io.ConfigFlags & ImGuiConfigFlags_ViewportsEnable) { - SDL_Window* backup_current_window = SDL_GL_GetCurrentWindow(); - SDL_GLContext backup_current_context = SDL_GL_GetCurrentContext(); - ImGui::UpdatePlatformWindows(); - ImGui::RenderPlatformWindowsDefault(); - SDL_GL_MakeCurrent(backup_current_window, backup_current_context); - } - - SDL_GL_SwapWindow(window_); - return OkStatus(); -} - -Status DebugApp::LayoutInitialDockSpace() { - dockspace_id_ = ImGui::GetID("MainDockSpace"); - if (ImGui::DockBuilderGetNode(dockspace_id_)) { - // Already configured. - return OkStatus(); - } - ImGui::DockBuilderAddNode(dockspace_id_, ImGuiDockNodeFlags_DockSpace); - - dock_content_id_ = dockspace_id_; - dock_top_id_ = ImGui::DockBuilderSplitNode(dock_content_id_, ImGuiDir_Up, - 0.05f, nullptr, &dock_content_id_); - dock_left_id_ = ImGui::DockBuilderSplitNode( - dock_content_id_, ImGuiDir_Left, 0.20f, nullptr, &dock_content_id_); - dock_bottom_id_ = ImGui::DockBuilderSplitNode( - dock_content_id_, ImGuiDir_Down, 0.20f, nullptr, &dock_content_id_); - dock_right_id_ = ImGui::DockBuilderSplitNode( - dock_content_id_, ImGuiDir_Right, 0.20f, nullptr, &dock_content_id_); - dock_bottom_left_id_ = ImGui::DockBuilderSplitNode( - dock_bottom_id_, ImGuiDir_Left, 0.50f, nullptr, &dock_bottom_right_id_); - - ImGui::DockBuilderDockWindow("Toolbar", dock_top_id_); - auto* dock_top_node = ImGui::DockBuilderGetNode(dock_top_id_); - dock_top_node->LocalFlags = ImGuiDockNodeFlags_NoSplit | - ImGuiDockNodeFlags_NoResize | - ImGuiDockNodeFlags_AutoHideTabBar; - - ImGui::DockBuilderDockWindow("Modules", dock_left_id_); - ImGui::DockBuilderDockWindow("Locals", dock_bottom_left_id_); - ImGui::DockBuilderDockWindow("Invocations", dock_bottom_right_id_); - ImGui::DockBuilderDockWindow("Breakpoints", dock_bottom_right_id_); - - ImGui::DockBuilderFinish(dockspace_id_); - return OkStatus(); -} - -Status DebugApp::DrawUI() { - ImGuiWindowFlags window_flags = - ImGuiWindowFlags_MenuBar | ImGuiWindowFlags_NoDocking; - window_flags |= ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_NoCollapse | - ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoMove | - ImGuiWindowFlags_NoNavFocus; - - ImGuiViewport* viewport = ImGui::GetMainViewport(); - ImGui::SetNextWindowPos(viewport->Pos); - ImGui::SetNextWindowSize(viewport->Size); - ImGui::SetNextWindowViewport(viewport->ID); - ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0.0f); - ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 0.0f); - ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0.0f, 0.0f)); - ImGui::Begin("IREEDebugRoot", nullptr, window_flags); - ImGui::PopStyleVar(3); - - RETURN_IF_ERROR(LayoutInitialDockSpace()); - ImGui::DockSpace(dockspace_id_, ImVec2(0.0f, 0.0f), ImGuiDockNodeFlags_None); - - RETURN_IF_ERROR(DrawMainMenu()); - RETURN_IF_ERROR(DrawToolbar()); - - ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(2, 2)); - RETURN_IF_ERROR(DrawBreakpointListPanel()); - RETURN_IF_ERROR(DrawModuleListPanel()); - RETURN_IF_ERROR(DrawLocalListPanel()); - RETURN_IF_ERROR(DrawInvocationListPanel()); - ImGui::PopStyleVar(); - - RETURN_IF_ERROR(DrawCodeViewPanels()); - - ImGui::End(); - return OkStatus(); -} - -Status DebugApp::DrawMainMenu() { - if (!ImGui::BeginMenuBar()) return OkStatus(); - - // TODO(benvanik): main menu. - if (ImGui::BeginMenu("File")) { - ImGui::EndMenu(); - } - - ImGui::EndMenuBar(); - return OkStatus(); -} - -Status DebugApp::DrawToolbar() { - // TODO(benvanik): figure out how to make this not grow. - ImGui::Begin("Toolbar", nullptr, - ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoTitleBar | - ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse | - ImGuiWindowFlags_NoScrollbar); - ImGui::BeginGroup(); - -#if !defined(IMGUI_DISABLE_DEMO_WINDOWS) - static bool show_demo_window = false; - if (ImGui::Button("Demo")) { - show_demo_window = !show_demo_window; - } - if (show_demo_window) { - ImGui::SetNextWindowDockID(dock_content_id_); - ImGui::ShowDemoWindow(&show_demo_window); - } -#endif // !IMGUI_DISABLE_DEMO_WINDOWS - - ImGui::SameLine(); - if (!debug_client_) { - if (ImGui::Button("Connect")) { - // TODO(benvanik): connection dialog and/or autoconnect. - } - } else { - if (ImGui::Button("Disconnect")) { - debug_client_.reset(); - } - } - - ImGui::SameLine(); - if (debug_client_) { - ImGui::Text("<status>"); - } else { - ImGui::TextDisabled("disconnected"); - } - - ImGui::SameLine(); - ImGui::Spacing(); - ImGui::SameLine(); - ImGui::Spacing(); - - ImGui::SameLine(); - ImGui::BeginGroup(); - ImGui::Text("Invocation: "); - ImGui::SameLine(); - ImGui::SetNextItemWidth(300); - auto* selected_invocation = GetSelectedInvocation(); - const std::string& active_invocation_name = - selected_invocation ? selected_invocation->name() : ""; - if (ImGui::BeginCombo("##active_invocation", active_invocation_name.c_str(), - ImGuiComboFlags_PopupAlignLeft)) { - if (debug_client_) { - for (auto* invocation : debug_client_->invocations()) { - ImGui::PushID(invocation->id()); - bool is_selected = invocation == selected_invocation; - if (ImGui::Selectable(invocation->name().c_str(), is_selected)) { - RETURN_IF_ERROR(NavigateToCodeView(*invocation, -1, - NavigationMode::kMatchDocument)); - } - if (is_selected) { - ImGui::SetItemDefaultFocus(); - } - ImGui::PopID(); - } - } - ImGui::EndCombo(); - } - ImGui::EndGroup(); - - ImGui::SameLine(); - ImGui::BeginGroup(); - static const float kPauseButtonHue = 0.0f; - static const float kResumeButtonHue = 2.0f; - static const float kStepButtonHue = 1.0f; - if (debug_client_ && !is_paused()) { - PushButtonHue(kPauseButtonHue); - if (ImGui::Button("Pause")) { - RETURN_IF_ERROR(debug_client_->SuspendAllInvocations()); - } - PopButtonStyle(); - } else if (debug_client_ && is_paused()) { - ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666); - ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA); - ImGui::ButtonEx("Pause", {}, ImGuiButtonFlags_Disabled); - ImGui::PopStyleColor(2); - } - if (debug_client_ && is_paused()) { - ImGui::SameLine(); - PushButtonHue(kResumeButtonHue); - if (ImGui::Button("Resume")) { - if (is_paused_) { - is_paused_ = false; - RETURN_IF_ERROR(debug_client_->MakeReady()); - } - while (!hit_breakpoints_.empty()) { - hit_breakpoints_.pop_back(); - RETURN_IF_ERROR(debug_client_->MakeReady()); - } - } - PopButtonStyle(); - } else { - ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666); - ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA); - ImGui::SameLine(); - ImGui::ButtonEx("Resume", {}, ImGuiButtonFlags_Disabled); - ImGui::PopStyleColor(2); - } - - if (debug_client_ && is_paused() && selected_invocation) { - ImGui::SameLine(); - PushButtonHue(kStepButtonHue); - if (ImGui::Button("Step Into")) { - RETURN_IF_ERROR( - debug_client_->StepInvocation(*selected_invocation, [this]() { - is_paused_ = true; - is_stepping_ = false; - })); - is_stepping_ = true; - } - PopButtonStyle(); - ImGui::SameLine(); - if (ImGui::Button("Step Over")) { - RETURN_IF_ERROR( - debug_client_->StepInvocationOver(*selected_invocation, [this]() { - is_paused_ = true; - is_stepping_ = false; - })); - is_stepping_ = true; - } - ImGui::SameLine(); - if (ImGui::Button("Step Out")) { - RETURN_IF_ERROR( - debug_client_->StepInvocationOut(*selected_invocation, [this]() { - is_paused_ = true; - is_stepping_ = false; - })); - is_stepping_ = true; - } - if (ImGui::BeginPopup("Step to...")) { - // TODO(benvanik): step to Invoke exit, next FFI call, etc - ImGui::MenuItem("(stuff)"); - ImGui::EndPopup(); - } - ImGui::SameLine(); - if (ImGui::Button("Step to...")) { - ImGui::OpenPopup("Step to..."); - } - } else { - ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666); - ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA); - ImGui::SameLine(); - ImGui::ButtonEx("Step Into", {}, ImGuiButtonFlags_Disabled); - ImGui::SameLine(); - ImGui::ButtonEx("Step Over", {}, ImGuiButtonFlags_Disabled); - ImGui::SameLine(); - ImGui::ButtonEx("Step Out", {}, ImGuiButtonFlags_Disabled); - ImGui::SameLine(); - ImGui::ButtonEx("Step to...", {}, ImGuiButtonFlags_Disabled); - ImGui::PopStyleColor(2); - } - ImGui::EndGroup(); - - ImGui::EndGroup(); - ImGui::End(); - return OkStatus(); -} - -Status DebugApp::DrawBreakpointListPanel() { - static bool is_panel_visible = true; - if (!ImGui::Begin("Breakpoints", &is_panel_visible, ImGuiWindowFlags_None)) { - ImGui::End(); - return OkStatus(); - } - - ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8)); - absl::optional<RemoteBreakpoint::Type> add_breakpoint_type; - if (ImGui::BeginPopup("+ Function")) { - if (ImGui::MenuItem("Bytecode Function")) { - add_breakpoint_type = RemoteBreakpoint::Type::kBytecodeFunction; - } - if (ImGui::MenuItem("Native Function")) { - add_breakpoint_type = RemoteBreakpoint::Type::kNativeFunction; - } - ImGui::EndPopup(); - } - ImGui::PopStyleVar(); - if (ImGui::Button("+ Function")) { - ImGui::OpenPopup("+ Function"); - } - RETURN_IF_ERROR(DrawAddBreakpointDialogs(add_breakpoint_type)); - - ImGui::SameLine(); - if (ImGui::Button("Remove All")) { - // TODO(benvanik): removal all is broken - need removebreakpoints or a - // 'want_removal' flag so that RefreshActiveBreakpoints handles things. - // Right now if you have 2 breakpoints and hit remove all the second will - // come back during the next refresh (as the server hasn't removed it yet). - for (auto& user_breakpoint : user_breakpoint_list_) { - if (user_breakpoint.active_breakpoint) { - RETURN_IF_ERROR(debug_client_->RemoveBreakpoint( - *user_breakpoint.active_breakpoint)); - user_breakpoint.active_breakpoint = nullptr; - } - } - user_breakpoint_list_.clear(); - } - ImGui::Separator(); - - ImGui::BeginChild("BreakpointList", ImVec2(-1, -1), false, - ImGuiWindowFlags_AlwaysVerticalScrollbar); - std::vector<UserBreakpoint*> dead_breakpoints; - for (auto& user_breakpoint : user_breakpoint_list_) { - ASSIGN_OR_RETURN(bool should_keep, DrawBreakpoint(&user_breakpoint)); - if (!should_keep) { - dead_breakpoints.push_back(&user_breakpoint); - } - } - for (auto* user_breakpoint : dead_breakpoints) { - for (auto it = user_breakpoint_list_.begin(); - it != user_breakpoint_list_.end(); ++it) { - if (&*it == user_breakpoint) { - if (user_breakpoint->active_breakpoint) { - RETURN_IF_ERROR(debug_client_->RemoveBreakpoint( - *user_breakpoint->active_breakpoint)); - } - user_breakpoint_list_.erase(it); - break; - } - } - } - ImGui::EndChild(); - - ImGui::End(); - return OkStatus(); -} - -StatusOr<bool> DebugApp::DrawBreakpoint(UserBreakpoint* user_breakpoint) { - std::string breakpoint_name; - switch (user_breakpoint->type) { - case RemoteBreakpoint::Type::kBytecodeFunction: - breakpoint_name = - absl::StrCat("[bytecode] ", user_breakpoint->module_name, ":", - user_breakpoint->function_name, ":", - user_breakpoint->bytecode_offset); - if (user_breakpoint->function_ordinal != -1) { - absl::StrAppend(&breakpoint_name, " @", - user_breakpoint->function_ordinal); - } - break; - case RemoteBreakpoint::Type::kNativeFunction: - breakpoint_name = - absl::StrCat("[native ] ", user_breakpoint->native_function); - break; - } - ImGui::BeginGroup(); - bool is_closing = true; - bool is_expanded = ImGui::CollapsingHeader( - ("##" + breakpoint_name).c_str(), &is_closing, - ImGuiTreeNodeFlags_Framed | ImGuiTreeNodeFlags_NoTreePushOnOpen | - ImGuiTreeNodeFlags_NoAutoOpenOnLog | ImGuiTreeNodeFlags_OpenOnArrow | - ImGuiTreeNodeFlags_OpenOnDoubleClick); - ImGui::SameLine(); - ImGui::Checkbox(breakpoint_name.c_str(), &user_breakpoint->wants_enabled); - ImGui::EndGroup(); - if (!is_expanded) { - return is_closing; - } - ImGui::PushID(breakpoint_name.c_str()); - - ImGui::Text("(breakpoint stats/etc)"); - - ImGui::PopID(); - return is_closing; -} - -Status DebugApp::DrawAddBreakpointDialogs( - absl::optional<RemoteBreakpoint::Type> add_breakpoint_type) { - if (add_breakpoint_type.has_value()) { - switch (add_breakpoint_type.value()) { - case RemoteBreakpoint::Type::kBytecodeFunction: - ImGui::OpenPopup("Add Bytecode Function Breakpoint"); - break; - case RemoteBreakpoint::Type::kNativeFunction: - ImGui::OpenPopup("Add Native Function Breakpoint"); - break; - } - } - ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8)); - RETURN_IF_ERROR(DrawAddBytecodeFunctionBreakpointDialog()); - RETURN_IF_ERROR(DrawAddNativeFunctionBreakpointDialog()); - ImGui::PopStyleVar(); - return OkStatus(); -} - -Status DebugApp::DrawAddBytecodeFunctionBreakpointDialog() { - ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver); - bool close_popup = true; - if (!ImGui::BeginPopupModal("Add Bytecode Function Breakpoint", &close_popup, - ImGuiWindowFlags_None)) { - return OkStatus(); - } - ImGui::BeginGroup(); - ImGui::BeginChild("##data_entry", - ImVec2(0, -ImGui::GetFrameHeightWithSpacing())); - - ImGui::TextWrapped( - "Adds a breakpoint set on the entry of the function (offset=0)."); - ImGui::Separator(); - - // TODO(benvanik): fancy list, filtering, etc. - - static char module_name[256] = {0}; - ImGui::InputText("Module", module_name, sizeof(module_name)); - ImGui::SetItemDefaultFocus(); - - static char function_name[256] = {0}; - ImGui::InputText("Function", function_name, sizeof(function_name)); - - ImGui::EndChild(); - ImGui::Separator(); - - if (ImGui::Button("Add")) { - int offset = 0; - if (FindMatchingUserBreakpointIndex(module_name, function_name, offset) == - -1) { - UserBreakpoint user_breakpoint; - user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction; - user_breakpoint.module_name = module_name; - user_breakpoint.function_name = function_name; - user_breakpoint.bytecode_offset = offset; - user_breakpoint.wants_enabled = true; - user_breakpoint_list_.push_back(std::move(user_breakpoint)); - } - ImGui::CloseCurrentPopup(); - } - ImGui::SameLine(); - if (ImGui::Button("Cancel")) { - ImGui::CloseCurrentPopup(); - } - - ImGui::EndGroup(); - ImGui::EndPopup(); - return OkStatus(); -} - -Status DebugApp::DrawAddNativeFunctionBreakpointDialog() { - ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver); - bool close_popup = true; - if (!ImGui::BeginPopupModal("Add Native Function Breakpoint", &close_popup, - ImGuiWindowFlags_None)) { - return OkStatus(); - } - ImGui::BeginGroup(); - ImGui::BeginChild("##data_entry", - ImVec2(0, -ImGui::GetFrameHeightWithSpacing())); - - ImGui::TextWrapped( - "Adds a breakpoint set on any call to the given FFI imported " - "function."); - ImGui::Separator(); - - static char function_name[256] = {0}; - ImGui::InputText("Function", function_name, sizeof(function_name)); - ImGui::SetItemDefaultFocus(); - - ImGui::EndChild(); - ImGui::Separator(); - - if (ImGui::Button("Add")) { - UserBreakpoint user_breakpoint; - user_breakpoint.type = RemoteBreakpoint::Type::kNativeFunction; - user_breakpoint.native_function = function_name; - user_breakpoint.wants_enabled = true; - user_breakpoint_list_.push_back(std::move(user_breakpoint)); - ImGui::CloseCurrentPopup(); - } - ImGui::SameLine(); - if (ImGui::Button("Cancel")) { - ImGui::CloseCurrentPopup(); - } - - ImGui::EndGroup(); - ImGui::EndPopup(); - return OkStatus(); -} - -Status DebugApp::DrawModuleListPanel() { - static bool is_panel_visible = true; - if (!ImGui::Begin("Modules", &is_panel_visible, ImGuiWindowFlags_None)) { - ImGui::End(); - return OkStatus(); - } else if (!debug_client_) { - ImGui::TextDisabled("disconnected"); - ImGui::End(); - return OkStatus(); - } - ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(4, 4)); - - ImGui::BeginGroup(); - ImGui::SetNextItemWidth(ImGui::GetContentRegionAvailWidth()); - static char function_name_filter_text[256] = {0}; - ImGui::InputTextWithHint( - "##function_name_filter", "Filter functions", function_name_filter_text, - sizeof(function_name_filter_text), ImGuiInputTextFlags_AutoSelectAll); - ImGuiTextFilter function_name_filter(function_name_filter_text); - ImGui::EndGroup(); - - ImGui::Separator(); - - ImGui::BeginGroup(); - ImGui::BeginChild("##context_list", ImVec2(0, -ImGui::GetFrameHeight())); - for (auto* context : debug_client_->contexts()) { - RETURN_IF_ERROR(DrawContext(*context, function_name_filter)); - } - ImGui::EndChild(); - ImGui::EndGroup(); - - ImGui::PopStyleVar(); - ImGui::End(); - return OkStatus(); -} - -Status DebugApp::DrawContext(const RemoteContext& context, - const ImGuiTextFilter& filter) { - std::string context_name = absl::StrCat("Context ", context.id()); - if (!ImGui::CollapsingHeader(context_name.c_str(), nullptr, - ImGuiTreeNodeFlags_DefaultOpen | - ImGuiTreeNodeFlags_Framed | - ImGuiTreeNodeFlags_NoTreePushOnOpen | - ImGuiTreeNodeFlags_NoAutoOpenOnLog | - ImGuiTreeNodeFlags_OpenOnArrow | - ImGuiTreeNodeFlags_OpenOnDoubleClick)) { - return OkStatus(); - } - ImGui::PushID(context.id()); - for (auto* module : context.modules()) { - RETURN_IF_ERROR(DrawModule(module, filter)); - } - ImGui::PopID(); - return OkStatus(); -} - -Status DebugApp::DrawModule(RemoteModule* module, - const ImGuiTextFilter& filter) { - ImGui::PushID(module->name().c_str()); - if (ImGui::TreeNodeEx(module->name().c_str(), - ImGuiTreeNodeFlags_Framed | - ImGuiTreeNodeFlags_DefaultOpen | - ImGuiTreeNodeFlags_OpenOnDoubleClick | - ImGuiTreeNodeFlags_OpenOnArrow)) { - if (module->CheckLoadedOrRequest()) { - for (auto* function : module->functions()) { - char function_name[128]; - if (function->name().empty()) { - std::snprintf(function_name, sizeof(function_name), "@%d", - function->ordinal()); - } else { - std::snprintf(function_name, sizeof(function_name), "@%d %s", - function->ordinal(), function->name().c_str()); - } - if (filter.IsActive() && !filter.PassFilter(function_name)) { - continue; - } - ImGui::PushID(function->ordinal()); - bool is_selected = false; - if (ImGui::Selectable("##selectable", &is_selected, - ImGuiSelectableFlags_AllowDoubleClick | - ImGuiSelectableFlags_DrawFillAvailWidth)) { - if (is_selected) { - RETURN_IF_ERROR(NavigateToCodeView(module->name(), - function->ordinal(), 0, - NavigationMode::kMatchDocument)); - } - } - ImGui::SameLine(); - // TODO(benvanik): detect if breakpoint active at offset 0. - ImGui::BulletText("%s", function_name); - ImGui::PopID(); - } - } else { - ImGui::TextDisabled("Loading..."); - } - ImGui::TreePop(); - } - ImGui::PopID(); - return OkStatus(); -} - -Status DebugApp::DrawLocalListPanel() { - static bool is_panel_visible = true; - if (!ImGui::Begin("Locals", &is_panel_visible, ImGuiWindowFlags_None)) { - ImGui::End(); - return OkStatus(); - } else if (!debug_client_) { - ImGui::TextDisabled("disconnected"); - ImGui::End(); - return OkStatus(); - } - auto* invocation = GetSelectedInvocation(); - if (!invocation) { - ImGui::TextDisabled("select a invocation to view locals"); - ImGui::End(); - return OkStatus(); - } else if (invocation->def().frames.empty()) { - ImGui::TextDisabled("(invocation has no frames)"); - ImGui::End(); - return OkStatus(); - } - int stack_frame_index = selected_stack_frame_index_.value_or(-1); - if (stack_frame_index == -1) { - stack_frame_index = invocation->def().frames.size() - 1; - } - auto& stack_frame = invocation->def().frames[stack_frame_index]; - - // TODO(benvanik): toggle for IREE VM locals vs. source locals. - for (int i = 0; i < stack_frame->locals.size(); ++i) { - auto& local = stack_frame->locals[i]; - RETURN_IF_ERROR(DrawLocal(invocation, stack_frame_index, i, *local)); - } - - ImGui::End(); - return OkStatus(); -} - -Status DebugApp::DrawLocal(RemoteInvocation* invocation, int stack_frame_index, - int local_index, const rpc::BufferViewDefT& local) { - // TODO(benvanik): columns and such in fancy table. - ImGui::Text("l%d", local_index); - ImGui::SameLine(50); - if (local.is_valid) { - auto shape_str = - absl::StrCat(absl::StrJoin(local.shape, "x"), "x", local.element_size); - ImGui::Text("%s", shape_str.c_str()); - } else { - ImGui::TextDisabled("∅"); - } - // TODO(benvanik): editing options (change shape, change contents, upload). - // TODO(benvanik): save/download/log options. - return OkStatus(); -} - -Status DebugApp::DrawInvocationListPanel() { - static bool is_panel_visible = true; - if (!ImGui::Begin("Invocations", &is_panel_visible, ImGuiWindowFlags_None)) { - ImGui::End(); - return OkStatus(); - } else if (!debug_client_) { - ImGui::TextDisabled("disconnected"); - ImGui::End(); - return OkStatus(); - } - for (auto* invocation : debug_client_->invocations()) { - RETURN_IF_ERROR(DrawInvocation(*invocation)); - } - ImGui::End(); - return OkStatus(); -} - -Status DebugApp::DrawInvocation(const RemoteInvocation& invocation) { - // TODO(benvanik): expand if any breakpoints are stopped in invocation. - if (selected_invocation_id_.has_value() && - selected_invocation_id_.value() == invocation.id()) { - ImGui::SetNextTreeNodeOpen(true); - } - if (!ImGui::CollapsingHeader(invocation.name().c_str())) { - return OkStatus(); - } - ImGui::PushID(invocation.id()); - - for (int i = 0; i < invocation.def().frames.size(); ++i) { - const auto& stack_frame = invocation.def().frames[i]; - ImGui::PushID(i); - // TODO(benvanik): highlight frames with breakpoints in them. - bool is_selected = selected_invocation_id_.has_value() && - selected_invocation_id_.value() == invocation.id() && - selected_stack_frame_index_.has_value() && - selected_stack_frame_index_.value() == i; - if (ImGui::Selectable("##selectable", &is_selected, - ImGuiSelectableFlags_AllowDoubleClick | - ImGuiSelectableFlags_DrawFillAvailWidth)) { - // TODO(benvanik): detect when clicking but already selected. - if (is_selected) { - RETURN_IF_ERROR( - NavigateToCodeView(invocation, i, NavigationMode::kMatchDocument)); - } - } - ImGui::SameLine(); - ImGui::Bullet(); - ImGui::SameLine(); - // TODO(benvanik): better naming/etc (resolve function). - ImGui::Text("%s:%d:%d", stack_frame->module_name.c_str(), - stack_frame->function_ordinal, stack_frame->offset); - - ImGui::PopID(); - } - - ImGui::PopID(); - return OkStatus(); -} - -DebugApp::CodeViewDocument* DebugApp::FindMatchingDocument( - absl::string_view module_name, int function_ordinal) { - for (auto& document : documents_) { - if (document->function->module()->name() == module_name && - document->function->ordinal() == function_ordinal) { - return document.get(); - } - } - return nullptr; -} - -Status DebugApp::NavigateToCodeView(absl::string_view module_name, - int function_ordinal, int offset, - NavigationMode navigation_mode) { - if (!debug_client_) { - return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; - } - VLOG(1) << "NavigateToCodeView(" << module_name << ", " << function_ordinal - << ", " << offset << ")"; - CodeViewDocument* existing_document = nullptr; - switch (navigation_mode) { - case NavigationMode::kNewDocument: - // Fall through and create below. - break; - case NavigationMode::kCurrentDocument: - // Not yet done - treat as a new document. - break; - case NavigationMode::kMatchDocument: - existing_document = FindMatchingDocument(module_name, function_ordinal); - break; - } - if (existing_document) { - ImGui::SetWindowFocus(existing_document->title.c_str()); - return OkStatus(); - } - - // TODO(benvanik): make this common code. - RETURN_IF_ERROR(debug_client_->GetFunction( - std::string(module_name), function_ordinal, - [this, offset](StatusOr<RemoteFunction*> function_or) { - if (!function_or.ok()) { - // TODO(benvanik): error dialog. - CHECK_OK(function_or.status()); - } - auto* function = function_or.ValueOrDie(); - auto document = absl::make_unique<CodeViewDocument>(); - document->title = - absl::StrCat(function->module()->name(), ":", function->name()); - document->function = function; - document->focus_offset = offset; - ImGui::SetWindowFocus(document->title.c_str()); - documents_.push_back(std::move(document)); - })); - return OkStatus(); -} - -Status DebugApp::NavigateToCodeView(absl::string_view module_name, - absl::string_view function_name, int offset, - NavigationMode navigation_mode) { - if (!debug_client_) { - return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; - } - return debug_client_->ResolveFunction( - std::string(module_name), std::string(function_name), - [this, navigation_mode, module_name, - offset](StatusOr<int> function_ordinal) { - CHECK_OK(function_ordinal.status()); - CHECK_OK(NavigateToCodeView(module_name, function_ordinal.ValueOrDie(), - offset, navigation_mode)); - }); -} - -Status DebugApp::NavigateToCodeView(const RemoteInvocation& invocation, - int stack_frame_index, - NavigationMode navigation_mode) { - if (!debug_client_) { - return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; - } - const auto& stack_frame = stack_frame_index == -1 - ? *invocation.def().frames.back() - : *invocation.def().frames[stack_frame_index]; - selected_invocation_id_ = invocation.id(); - selected_stack_frame_index_ = stack_frame_index; - return NavigateToCodeView(stack_frame.module_name, - stack_frame.function_ordinal, stack_frame.offset, - NavigationMode::kMatchDocument); -} - -Status DebugApp::NavigateToCodeView(const UserBreakpoint& user_breakpoint, - NavigationMode navigation_mode) { - if (!debug_client_) { - return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; - } - switch (user_breakpoint.type) { - case RemoteBreakpoint::Type::kBytecodeFunction: - if (user_breakpoint.function_ordinal != -1) { - return NavigateToCodeView( - user_breakpoint.module_name, user_breakpoint.function_ordinal, - user_breakpoint.bytecode_offset, navigation_mode); - } else { - return NavigateToCodeView( - user_breakpoint.module_name, user_breakpoint.function_name, - user_breakpoint.bytecode_offset, navigation_mode); - } - case RemoteBreakpoint::Type::kNativeFunction: - return UnimplementedErrorBuilder(IREE_LOC) - << "Navigation to non-bytecode functions unimplemented"; - } -} - -Status DebugApp::DrawCodeViewPanels() { - // If we've disconnected then we need to clear bodies. - // TODO(benvanik): allow documents to persist by caching all required info. - if (!debug_client_) { - documents_.clear(); - return OkStatus(); - } - - std::vector<CodeViewDocument*> closing_documents; - for (auto& document : documents_) { - ASSIGN_OR_RETURN(bool is_open, DrawCodeViewDocument(document.get())); - if (!is_open) { - closing_documents.push_back(document.get()); - } - } - for (auto* closing_document : closing_documents) { - auto it = std::find_if( - documents_.begin(), documents_.end(), - [closing_document](const std::unique_ptr<CodeViewDocument>& document) { - return document.get() == closing_document; - }); - documents_.erase(it); - } - return OkStatus(); -} - -StatusOr<bool> DebugApp::DrawCodeViewDocument(CodeViewDocument* document) { - ImGui::SetNextWindowDockID(dockspace_id_, ImGuiCond_FirstUseEver); - ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0, 0)); - bool is_open = true; - bool is_visible = - ImGui::Begin(document->title.c_str(), &is_open, ImGuiWindowFlags_None); - if (!is_open || !is_visible) { - ImGui::End(); - ImGui::PopStyleVar(); - return is_open; - } - ImGui::PopStyleVar(); - - auto* remote_module = document->function->module(); - auto* remote_function = document->function; - if (remote_module->CheckLoadedOrRequest() && - remote_function->CheckLoadedOrRequest()) { - // TODO(benvanik): draw function signature. - if (remote_function->bytecode()) { - RETURN_IF_ERROR(DrawBytecodeCodeView(document)); - } else { - // TODO(benvanik): display native registration info. - ImGui::TextDisabled("(native)"); - } - } else { - ImGui::TextDisabled("loading..."); - } - - ImGui::End(); - return true; -} - -Status DebugApp::PrepareBytecodeCodeView(CodeViewDocument* document) { - auto* remote_module = document->function->module(); - auto* remote_function = document->function; - - document->bytecode_info.lines = remote_function->name(); - - return OkStatus(); -} - -Status DebugApp::DrawBytecodeCodeView(CodeViewDocument* document) { - // Ensure we have cached our line information. - RETURN_IF_ERROR(PrepareBytecodeCodeView(document)); - - auto* remote_module = document->function->module(); - auto* remote_function = document->function; - - ImGui::BeginGroup(); - ImGui::BeginChild("##bytecode_view", ImVec2(0, 0), false, - ImGuiWindowFlags_AlwaysVerticalScrollbar); - ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0, 0)); - - // TODO(benvanik): cache breakpoints for this function for faster lookup. - - auto& bytecode_info = document->bytecode_info; - ImGuiListClipper clipper(bytecode_info.lines.size(), - ImGui::GetTextLineHeightWithSpacing()); - while (clipper.Step()) { - for (int i = clipper.DisplayStart; i < clipper.DisplayEnd; ++i) { - ImGui::PushID(i); - - // TODO(benvanik): lookup line info. - int bytecode_offset = 0; - int breakpoint_index = FindMatchingUserBreakpointIndex( - remote_module->name(), remote_function->ordinal(), bytecode_offset); - bool has_breakpoint = breakpoint_index != -1; - bool active_on_any_invocation = false; - bool active_on_selected_invocation = false; - - ImGui::Dummy(ImVec2(4, 0)); - - // Gutter breakpoint button. - ImGui::SameLine(); - if (has_breakpoint) { - PushButtonHue(0.0f); // Red - if (ImGui::Button(" ##toggle_breakpoint")) { - CHECK_GE(breakpoint_index, 0); - auto& user_breakpoint = user_breakpoint_list_[breakpoint_index]; - if (user_breakpoint.active_breakpoint) { - RETURN_IF_ERROR(debug_client_->RemoveBreakpoint( - *user_breakpoint.active_breakpoint)); - } - user_breakpoint_list_.erase(user_breakpoint_list_.begin() + - breakpoint_index); - } - PopButtonStyle(); - if (ImGui::IsItemHovered()) { - ImGui::SetTooltip("Remove the breakpoint at this offset."); - } - } else { - PushButtonColor(ImGui::GetStyleColorVec4(ImGuiCol_ChildBg)); - if (ImGui::Button(" ##toggle_breakpoint")) { - UserBreakpoint user_breakpoint; - user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction; - user_breakpoint.module_name = remote_module->name(); - user_breakpoint.function_name = remote_function->name(); - user_breakpoint.bytecode_offset = bytecode_offset; - user_breakpoint.wants_enabled = true; - user_breakpoint_list_.push_back(std::move(user_breakpoint)); - } - PopButtonStyle(); - if (ImGui::IsItemHovered()) { - ImGui::SetTooltip("Add a breakpoint at this offset."); - } - } - - // Active execution chevron (shows when active or any invocation is - // executing this region). - ImGui::SameLine(); - if (active_on_selected_invocation) { - // The selected invocation is active here. - ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_SeparatorActive), - " > "); - } else if (active_on_any_invocation) { - // At least one other invocation is active here. - ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_Separator), " > "); - } else { - // Not active. - ImGui::Text(" "); - } - - // Line contents. - ImGui::SameLine(); - ImGui::Text("%s", bytecode_info.lines[i].c_str()); - - if (document->focus_offset.has_value() && - bytecode_offset == document->focus_offset.value()) { - document->bytecode_offset = document->focus_offset.value(); - document->focus_offset = {}; - ImGui::SetScrollHereY(); - } - - ImGui::PopID(); - } - } - - ImGui::PopStyleVar(); - ImGui::EndChild(); - ImGui::EndGroup(); - - return OkStatus(); -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/tools/debugger/debug_app.h b/iree/tools/debugger/debug_app.h deleted file mode 100644 index e970f5f..0000000 --- a/iree/tools/debugger/debug_app.h +++ /dev/null
@@ -1,200 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_H_ -#define IREE_TOOLS_DEBUGGER_DEBUG_APP_H_ - -#include <SDL.h> - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "iree/base/status.h" -#include "iree/rt/debug/debug_client.h" - -// NOTE: order matters here, imgui must come first: -#include "third_party/dear_imgui/imgui.h" -// NOTE: must follow imgui.h: -#include "third_party/dear_imgui/examples/imgui_impl_opengl3.h" -#include "third_party/dear_imgui/examples/imgui_impl_sdl.h" - -namespace iree { -namespace rt { -namespace debug { - -// Debug client app UI. -// Uses a DebugClient to communicate with a remote DebugServer and ImGui to -// display a nifty UI. -// -// See the ImGui site for more info: https://github.com/ocornut/imgui -// The most useful thing is the imgui_demo.cpp file that contains example usage -// of most features. -class DebugApp : private DebugClient::Listener { - public: - struct UserBreakpoint { - RemoteBreakpoint::Type type = RemoteBreakpoint::Type::kBytecodeFunction; - const RemoteBreakpoint* active_breakpoint = nullptr; - bool wants_enabled = true; - bool is_enabling = false; - // TODO(benvanik): reuse BreakpointDef here? - std::string module_name; - std::string function_name; - int function_ordinal = -1; - int bytecode_offset = 0; - std::string native_function; - }; - - static void PumpMainLoopThunk(void* arg); - - DebugApp(SDL_Window* window, SDL_GLContext gl_context, - const char* glsl_version); - ~DebugApp(); - - // Connects to the service at the specified address. - Status Connect(absl::string_view service_address); - // Disconnects from the currently connected service, if any. - Status Disconnect(); - - // Returns true if the remote service is paused at our request. - bool is_paused() const; - - // Pumps the main UI loop once. - // This polls the DebugClient, SDL input, and renders the UI. - // It should be called as frequently as possible to ensure snappy UI updates. - // Returns CancelledError if the app is being closed by the user. - Status PumpMainLoop(); - - // Defines how NavigationToCodeView methods behave. - enum class NavigationMode { - // The target will be opened in a new document tab. - kNewDocument, - // The target will be opened in the current document tab, replacing the - // current contents. - kCurrentDocument, - // The target will be opened in a document tab that mostly matches (like - // the same function in a module at a different offset), otherwise a new - // document will be opened. - kMatchDocument, - }; - - // Navigates to a particular function offset based on resolution of the given - // arguments. Navigation may happen asynchronously if targets need to be - // resolved or contents fetched. - Status NavigateToCodeView(absl::string_view module_name, int function_ordinal, - int offset, NavigationMode navigation_mode); - Status NavigateToCodeView(absl::string_view module_name, - absl::string_view function_name, int offset, - NavigationMode navigation_mode); - Status NavigateToCodeView(const RemoteInvocation& invocation, - int stack_frame_index, - NavigationMode navigation_mode); - Status NavigateToCodeView(const UserBreakpoint& user_breakpoint, - NavigationMode navigation_mode); - - private: - struct CodeViewDocument { - // Document display title (and ID). - std::string title; - // Function (and offset within the function) being displayed. - RemoteFunction* function = nullptr; - int bytecode_offset = 0; - // Set to a bytecode offset to have the document focus there. - absl::optional<int> focus_offset; - // Cached info for bytecode display. - struct { - std::vector<std::string> lines; - } bytecode_info; - }; - - CodeViewDocument* FindMatchingDocument(absl::string_view module_name, - int function_ordinal); - RemoteInvocation* GetSelectedInvocation() const; - - Status RefreshActiveBreakpoints(); - bool IsStoppedAtBreakpoint(const UserBreakpoint& user_breakpoint) const; - int FindMatchingUserBreakpointIndex(absl::string_view module_name, - int function_ordinal, int offset); - int FindMatchingUserBreakpointIndex(absl::string_view module_name, - absl::string_view function_name, - int offset); - Status ResumeFromBreakpoint(UserBreakpoint* user_breakpoint); - - Status OnContextRegistered(const RemoteContext& context) override; - Status OnContextUnregistered(const RemoteContext& context) override; - Status OnModuleLoaded(const RemoteContext& context, - const RemoteModule& module) override; - Status OnInvocationRegistered(const RemoteInvocation& invocation) override; - Status OnInvocationUnregistered(const RemoteInvocation& invocation) override; - Status OnBreakpointHit(const RemoteBreakpoint& breakpoint, - const RemoteInvocation& invocation) override; - - Status LayoutInitialDockSpace(); - - Status DrawUI(); - Status DrawMainMenu(); - Status DrawToolbar(); - - Status DrawBreakpointListPanel(); - StatusOr<bool> DrawBreakpoint(UserBreakpoint* user_breakpoint); - Status DrawAddBreakpointDialogs( - absl::optional<RemoteBreakpoint::Type> add_breakpoint_type); - Status DrawAddBytecodeFunctionBreakpointDialog(); - Status DrawAddNativeFunctionBreakpointDialog(); - - Status DrawModuleListPanel(); - Status DrawContext(const RemoteContext& context, - const ImGuiTextFilter& filter); - Status DrawModule(RemoteModule* module, const ImGuiTextFilter& filter); - - Status DrawLocalListPanel(); - Status DrawLocal(RemoteInvocation* invocation, int stack_frame_index, - int local_index, const rpc::BufferViewDefT& local); - - Status DrawInvocationListPanel(); - Status DrawInvocation(const RemoteInvocation& invocation); - - Status DrawCodeViewPanels(); - StatusOr<bool> DrawCodeViewDocument(CodeViewDocument* document); - Status PrepareBytecodeCodeView(CodeViewDocument* document); - Status DrawBytecodeCodeView(CodeViewDocument* document); - - SDL_Window* window_ = nullptr; - SDL_GLContext gl_context_ = nullptr; - - ImGuiID dockspace_id_; - ImGuiID dock_top_id_; - ImGuiID dock_left_id_; - ImGuiID dock_bottom_id_; - ImGuiID dock_bottom_left_id_; - ImGuiID dock_bottom_right_id_; - ImGuiID dock_right_id_; - ImGuiID dock_content_id_; - - std::unique_ptr<DebugClient> debug_client_; - std::vector<UserBreakpoint> user_breakpoint_list_; - - bool is_paused_ = false; - std::vector<const RemoteBreakpoint*> hit_breakpoints_; - bool is_stepping_ = false; - - absl::optional<int> selected_invocation_id_; - absl::optional<int> selected_stack_frame_index_; - - std::vector<std::unique_ptr<CodeViewDocument>> documents_; -}; - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
diff --git a/iree/tools/debugger/debug_app_embedded.cc b/iree/tools/debugger/debug_app_embedded.cc deleted file mode 100644 index f44692c..0000000 --- a/iree/tools/debugger/debug_app_embedded.cc +++ /dev/null
@@ -1,153 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/tools/debugger/debug_app_embedded.h" - -#include <SDL.h> - -#include <thread> // NOLINT - -#include "absl/base/thread_annotations.h" -#include "absl/memory/memory.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/tools/debugger/debug_app.h" -#include "third_party/SDL2/include/SDL_thread.h" - -namespace iree { -namespace rt { -namespace debug { - -class InProcessEmbeddedDebugger : public EmbeddedDebugger { - public: - explicit InProcessEmbeddedDebugger(std::unique_ptr<DebugApp> app) - : app_(std::move(app)) { - thread_ = - SDL_CreateThread(&ThreadMainThunk, "InProcessEmbeddedDebugger", this); - } - - ~InProcessEmbeddedDebugger() override { - VLOG(1) << "Setting shutdown flag and waiting on thread..."; - shutdown_flag_ = true; - int status = 0; - SDL_WaitThread(thread_, &status); - VLOG(1) << "Thread shutdown, killing app..."; - app_.reset(); - } - - Status AwaitClose() override { - await_mutex_.LockWhen(absl::Condition( - +[](bool* is_shutdown) { return *is_shutdown; }, &is_shutdown_)); - auto status = std::move(shutdown_status_); - await_mutex_.Unlock(); - return status; - } - - private: - static int ThreadMainThunk(void* arg) { - return reinterpret_cast<InProcessEmbeddedDebugger*>(arg)->ThreadMain(); - } - - int ThreadMain() { - VLOG(1) << "Thread entry"; - while (!shutdown_flag_) { - auto status = app_->PumpMainLoop(); - if (IsCancelled(status)) { - shutdown_flag_ = true; - break; - } else if (!shutdown_flag_ && !status.ok()) { - absl::MutexLock lock(&await_mutex_); - shutdown_status_ = std::move(status); - // TODO(benvanik): don't check unless no one is watching. - CHECK_OK(shutdown_status_); - } - } - app_.reset(); - { - absl::MutexLock lock(&await_mutex_); - is_shutdown_ = true; - } - VLOG(1) << "Thread exit"; - return 0; - } - - std::unique_ptr<DebugApp> app_; - SDL_Thread* thread_; - std::atomic<bool> shutdown_flag_ = {false}; - absl::Mutex await_mutex_; - bool is_shutdown_ ABSL_GUARDED_BY(await_mutex_) = false; - Status shutdown_status_ ABSL_GUARDED_BY(await_mutex_); -}; - -StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger() { - return AttachDebugger(""); -} - -StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger( - absl::string_view service_address) { - LOG(INFO) << "Launching embedded debugger; service=" << service_address; - // Workaround for terrible bad SDL/graphics driver leaks. - IREE_DISABLE_LEAK_CHECKS(); - - if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) { - return InternalErrorBuilder(IREE_LOC) - << "Unable to init SDL: " << SDL_GetError(); - } - -#if __APPLE__ - // GL 3.2 Core + GLSL 150 - const char* glsl_version = "#version 150"; - SDL_GL_SetAttribute( - SDL_GL_CONTEXT_FLAGS, - SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG); // Always required on Mac - SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 2); -#else - // GL 3.0 + GLSL 130 - const char* glsl_version = "#version 130"; - SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0); -#endif - - SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1); - SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24); - SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8); - SDL_DisplayMode current; - SDL_GetCurrentDisplayMode(0, ¤t); - SDL_WindowFlags window_flags = (SDL_WindowFlags)( - SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI); - SDL_Window* window = - SDL_CreateWindow("IREE Debugger (embedded)", SDL_WINDOWPOS_CENTERED, - SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags); - SDL_GLContext gl_context = SDL_GL_CreateContext(window); - SDL_GL_MakeCurrent(nullptr, nullptr); - - IREE_ENABLE_LEAK_CHECKS(); - - auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version); - if (!service_address.empty()) { - RETURN_IF_ERROR(app->Connect(service_address)); - } - - auto handle = absl::make_unique<InProcessEmbeddedDebugger>(std::move(app)); - return handle; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/tools/debugger/debug_app_embedded.h b/iree/tools/debugger/debug_app_embedded.h deleted file mode 100644 index 8e7c15f..0000000 --- a/iree/tools/debugger/debug_app_embedded.h +++ /dev/null
@@ -1,52 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_ -#define IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_ - -#include <memory> - -#include "absl/strings/string_view.h" -#include "iree/base/status.h" - -namespace iree { -namespace rt { -namespace debug { - -// RAII handle for keeping the debugger alive. -// When the instance is destroyed the debugger app will be closed. -class EmbeddedDebugger { - public: - virtual ~EmbeddedDebugger() = default; - - // Blocks the caller until the debugger is closed by the user. - virtual Status AwaitClose() = 0; -}; - -// Launches the debugger app. -// Returns a handle that can be used to wait for the debugger to close or -// force it to close. -StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger(); - -// Launches the debugger app and attaches to the given server address. -// Returns a handle that can be used to wait for the debugger to close or -// force it to close. -StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger( - absl::string_view service_address); - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
diff --git a/iree/tools/debugger/debug_app_main_emscripten.cc b/iree/tools/debugger/debug_app_main_emscripten.cc deleted file mode 100644 index 1df8d98..0000000 --- a/iree/tools/debugger/debug_app_main_emscripten.cc +++ /dev/null
@@ -1,69 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Emscripten debug_app entry point. -// Though we are using SDL here we need to do some emscripten-specific magic to -// handle the different main looping mode (as we can't block in main() like on -// other platforms) as well as support some emscripten-specific features for -// file upload/download/etc. - -#include <SDL.h> -#include <emscripten.h> - -#include "iree/base/init.h" -#include "iree/tools/debugger/debug_app.h" - -namespace iree { -namespace rt { -namespace debug { - -extern "C" int main(int argc, char** argv) { - InitializeEnvironment(&argc, &argv); - - if (SDL_Init(SDL_INIT_VIDEO) != 0) { - printf("Error: %s\n", SDL_GetError()); - return -1; - } - - const char* glsl_version = "#version 100"; - SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_ES); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 2); - SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0); - - SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1); - SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24); - SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8); - SDL_DisplayMode current; - SDL_GetCurrentDisplayMode(0, ¤t); - SDL_WindowFlags window_flags = (SDL_WindowFlags)( - SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI); - SDL_Window* window = - SDL_CreateWindow("IREE Debugger", SDL_WINDOWPOS_CENTERED, - SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags); - SDL_GLContext gl_context = SDL_GL_CreateContext(window); - if (!gl_context) { - printf("Failed to initialize WebGL context!\n"); - return 1; - } - - auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version); - ::emscripten_set_main_loop_arg(DebugApp::PumpMainLoopThunk, app.release(), 0, - false); - return 0; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/tools/debugger/debug_app_main_native.cc b/iree/tools/debugger/debug_app_main_native.cc deleted file mode 100644 index d364d14..0000000 --- a/iree/tools/debugger/debug_app_main_native.cc +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Native (linux/etc) debug_app entry point. -// This should work on any platform with pthreads and SDL support. - -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/tools/debugger/debug_app_embedded.h" - -namespace iree { -namespace rt { -namespace debug { - -Status Run() { - ASSIGN_OR_RETURN(auto handle, LaunchDebugger()); - RETURN_IF_ERROR(handle->AwaitClose()); - handle.reset(); - return OkStatus(); -} - -extern "C" int main(int argc, char** argv) { - InitializeEnvironment(&argc, &argv); - auto status = Run(); - if (!status.ok()) { - LOG(ERROR) << "Debugger error: " << status; - return 1; - } - return 0; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/tools/debugger/debug_cli_main.cc b/iree/tools/debugger/debug_cli_main.cc deleted file mode 100644 index 4970a70..0000000 --- a/iree/tools/debugger/debug_cli_main.cc +++ /dev/null
@@ -1,40 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/flags/flag.h" -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/tools/debugger/debug_prompt.h" - -ABSL_FLAG(std::string, debug_service_uri, "0.0.0.0:6000", - "IP/port of debug service to connect to."); - -namespace iree { -namespace rt { -namespace debug { - -Status Run() { - // TODO(benvanik): retry until connected? would allow auto-build reconnects. - return AttachDebugPrompt(absl::GetFlag(FLAGS_debug_service_uri)); -} - -extern "C" int main(int argc, char** argv) { - InitializeEnvironment(&argc, &argv); - CHECK_OK(Run()); - return 0; -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/tools/debugger/debug_prompt.cc b/iree/tools/debugger/debug_prompt.cc deleted file mode 100644 index aceda80..0000000 --- a/iree/tools/debugger/debug_prompt.cc +++ /dev/null
@@ -1,90 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/tools/debugger/debug_prompt.h" - -#include "iree/base/status.h" -#include "iree/rt/debug/debug_client.h" - -namespace iree { -namespace rt { -namespace debug { -namespace { - -class DebugPrompt : private DebugClient::Listener { - public: - Status Connect(absl::string_view debug_service_uri) { - // Connect to the debug service. - ASSIGN_OR_RETURN(debug_client_, - DebugClient::Connect(debug_service_uri, this)); - return OkStatus(); - } - - Status Run() { - // Query commands, transmit requests, and dispatch responses. - while (true) { - RETURN_IF_ERROR(debug_client_->Poll()); - - // TODO(benvanik): ask for a command. - // TODO(benvanik): display stuff. - } - } - - private: - Status OnContextRegistered(const RemoteContext& context) override { - // Ack. - return debug_client_->MakeReady(); - } - - Status OnContextUnregistered(const RemoteContext& context) override { - // Ack. - return debug_client_->MakeReady(); - } - - Status OnModuleLoaded(const RemoteContext& context, - const RemoteModule& module) override { - // Ack. - return debug_client_->MakeReady(); - } - - Status OnInvocationRegistered(const RemoteInvocation& invocation) override { - // Ack. - return debug_client_->MakeReady(); - } - - Status OnInvocationUnregistered(const RemoteInvocation& invocation) override { - // Ack. - return debug_client_->MakeReady(); - } - - Status OnBreakpointHit(const RemoteBreakpoint& breakpoint, - const RemoteInvocation& invocation) override { - // Ack. - return debug_client_->MakeReady(); - } - - std::unique_ptr<DebugClient> debug_client_; -}; - -} // namespace - -Status AttachDebugPrompt(absl::string_view debug_service_uri) { - DebugPrompt debug_prompt; - RETURN_IF_ERROR(debug_prompt.Connect(debug_service_uri)); - return debug_prompt.Run(); -} - -} // namespace debug -} // namespace rt -} // namespace iree
diff --git a/iree/tools/debugger/debug_prompt.h b/iree/tools/debugger/debug_prompt.h deleted file mode 100644 index 2c9bf7a..0000000 --- a/iree/tools/debugger/debug_prompt.h +++ /dev/null
@@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_ -#define IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_ - -#include "absl/strings/string_view.h" -#include "iree/base/status.h" - -namespace iree { -namespace rt { -namespace debug { - -// TODO(benvanik): take stdin/stdout as arguments. -// Attaches a debug prompt reading stdin for commands and printing results to -// stdout. The calling thread will block until the debugger is exited or the -// debug service closes. -Status AttachDebugPrompt(absl::string_view debug_service_uri); - -} // namespace debug -} // namespace rt -} // namespace iree - -#endif // IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
diff --git a/iree/tools/run_mlir_main.cc b/iree/tools/run_mlir_main.cc deleted file mode 100644 index 3b7b3f9..0000000 --- a/iree/tools/run_mlir_main.cc +++ /dev/null
@@ -1,360 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// IREE source.mlir -> execution output test runner. -// This is meant to be called from LIT for FileCheck tests, and tries to match -// the interface of mlir-opt (featuring -split-input-file, etc) so it's easier -// to work with there. If you want a more generalized runner for standalone -// precompiled IREE modules use //third_party/iree/tools:run_module. -// -// By default all exported functions in the module will be run in order. -// All input values, provided via -input-values, will be passed to the -// functions (this means all input signatures must match). Results from the -// executed functions will be printed to stdout for checking. -// Use -output_types to set the function output data types, which like args will -// be used for all functions executed. -// -// Example input: -// // RUN: iree-run %s | FileCheck %s -// // CHECK-LABEL: @foo -// // CHECK: 1xf32: 2 -// func @foo() -> memref<f32> attributes {iree.module.export} { -// %0 = "iree.constant"() {value: dense<tensor<f32>, 2.0>} : () -> memref<f32> -// return %0 : memref<f32> -// } - -#include <iostream> - -#include "absl/flags/flag.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "iree/base/init.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/compiler/Translation/Sequencer/SequencerModuleTranslation.h" -#include "iree/hal/buffer_view_string_util.h" -#include "iree/hal/driver_registry.h" -#include "iree/rt/context.h" -#include "iree/rt/debug/debug_server_flags.h" -#include "iree/rt/instance.h" -#include "iree/rt/invocation.h" -#include "iree/rt/module.h" -#include "iree/rt/module_printer.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/vm/sequencer_module.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/SourceMgr.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/Parser.h" -#include "mlir/Support/FileUtilities.h" - -ABSL_FLAG(bool, split_input_file, true, - "Split the input file into multiple modules."); - -ABSL_FLAG(std::string, target_backends, "", - "Comma-separated list of target backends to translate executables " - "into. Omit to translate using all linked-in backend translators."); -ABSL_FLAG( - bool, export_all, true, - "Automatically add the iree.module.export attribute to all functions."); - -ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values."); -ABSL_FLAG(std::string, output_types, "", - "Output data types (comma delimited list of b/i/u/f for " - "binary/signed int/unsigned int/float)."); - -// TODO(benvanik): is there a more canonical flag we can use? -ABSL_FLAG(bool, print_mlir, true, "Prints MLIR IR during translation."); - -ABSL_FLAG(bool, print_bytecode, false, - "Prints IREE bytecode after translation."); - -ABSL_FLAG(bool, run, true, - "Option to run the file. Setting it to false just compiles it."); - -namespace iree { -namespace { - -using ::iree::hal::BufferView; -using ::iree::rt::Function; -using ::iree::rt::Module; - -// Returns a driver name capable of handling input from the given backend. -std::string BackendToDriverName(std::string backend) { - size_t dash = backend.find('-'); - if (dash == std::string::npos) { - return backend; - } else { - return backend.substr(0, dash); - } -} - -// Prepares a module for evaluation by running MLIR import and IREE translation. -StatusOr<ref_ptr<Module>> PrepareModule( - std::string target_backend, - std::unique_ptr<llvm::MemoryBuffer> file_buffer) { - mlir::MLIRContext context; - - // Parse input MLIR module. - llvm::SourceMgr source_mgr; - source_mgr.AddNewSourceBuffer(std::move(file_buffer), llvm::SMLoc()); - mlir::OwningModuleRef mlir_module = - mlir::parseSourceFile(source_mgr, &context); - - if (absl::GetFlag(FLAGS_export_all)) { - for (auto function : mlir_module->getOps<mlir::FuncOp>()) { - function.setAttr("iree.module.export", mlir::UnitAttr::get(&context)); - } - } - - // Translate from MLIR to IREE bytecode. - mlir::iree_compiler::ModuleTranslationOptions options; - options.print_mlir = absl::GetFlag(FLAGS_print_mlir); - options.target_backends = {target_backend}; - auto iree_module_bytes = - mlir::iree_compiler::translateMlirToIreeSequencerModule(mlir_module.get(), - options); - if (iree_module_bytes.empty()) { - return iree::InternalErrorBuilder(IREE_LOC) - << "Error translating MLIR to an IREE sequencer module"; - } - - if (absl::GetFlag(FLAGS_print_mlir)) { - mlir_module->dump(); - } - - // Wrap module in a file handle. - ASSIGN_OR_RETURN(auto iree_module_file, - vm::ModuleFile::FromBuffer(ModuleDefIdentifier(), - std::move(iree_module_bytes))); - return vm::SequencerModule::FromFile(std::move(iree_module_file)); -} - -// Parses a list of input shapes and values from a string of newline-separated -// inputs. Expects the contents to have one value per line with each value -// listed as -// [shape]xtype=[value] -// Example: -// 4x4xi8=0,1,2,3 -StatusOr<std::vector<BufferView>> ParseInputsFromFlags( - hal::Allocator *allocator) { - std::string file_contents = - absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}}); - std::vector<BufferView> inputs; - std::vector<std::string> lines = absl::StrSplit( - file_contents, absl::ByAnyChar("\n;"), absl::SkipWhitespace()); - for (const auto &line : lines) { - ASSIGN_OR_RETURN(auto input, - hal::ParseBufferViewFromString(line, allocator)); - inputs.push_back(input); - } - return inputs; -} - -// Outputs all results from the function to stdout in IREE BufferView format. -Status OutputFunctionResults(const Function &function, - absl::Span<BufferView> results) { - std::vector<std::string> output_types = - absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "), - absl::SkipWhitespace()); - if (!output_types.empty() && output_types.size() != results.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "--output_types= specified but has " << output_types.size() - << " types when the function returns " << results.size(); - } - - for (int i = 0; i < results.size(); ++i) { - const auto &result = results[i]; - auto print_mode = hal::BufferViewPrintMode::kFloatingPoint; - if (!output_types.empty()) { - ASSIGN_OR_RETURN(print_mode, - hal::ParseBufferViewPrintMode(output_types[i])); - } - ASSIGN_OR_RETURN(auto result_str, - hal::PrintBufferViewToString(result, print_mode, 1024)); - LOG(INFO) << "result[" << i << "]: " << result.buffer->DebugString(); - std::cout << result_str << "\n"; - } - - return OkStatus(); -} - -// Evaluates a single function in its own fiber, printing the results to stdout. -Status EvaluateFunction(const ref_ptr<rt::Context> &context, - hal::Allocator *allocator, const Function &function) { - std::cout << "EXEC @" << function.name() << std::endl; - - // Create invocation that will perform the execution. - ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(allocator)); - ASSIGN_OR_RETURN( - auto invocation, - rt::Invocation::Create(add_ref(context), function, make_ref<rt::Policy>(), - {}, absl::MakeConstSpan(arguments))); - - // Wait until invocation completes. - RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); - - // Print outputs. - ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults()); - RETURN_IF_ERROR(OutputFunctionResults(function, absl::MakeSpan(results))); - - return OkStatus(); -} - -// Evaluates all exported functions within given module. -Status EvaluateFunctions(absl::string_view target_backend, - ref_ptr<Module> module) { - // Create the context we'll use for this (ensuring that we can't interfere - // with other running evaluations, such as when in a multithreaded test - // runner). - ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags()); - auto instance = make_ref<rt::Instance>(std::move(debug_server)); - ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( - target_backend)); - ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); - RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); - - if (absl::GetFlag(FLAGS_print_bytecode)) { - RETURN_IF_ERROR(rt::PrintModuleToStream( - *module, rt::PrintModuleFlag::kDisassemble, &std::cout)); - } - - // Evaluate all exported functions. - auto policy = make_ref<rt::Policy>(); - auto run_function = [&](int ordinal) -> Status { - // Setup a new context for this invocation. - auto context = make_ref<rt::Context>(add_ref(instance), add_ref(policy)); - RETURN_IF_ERROR(context->RegisterModule(add_ref(module))); - - // Invoke the function and print results. - ASSIGN_OR_RETURN(auto function, - module->LookupFunctionByOrdinal( - rt::Function::Linkage::kExport, ordinal)); - RETURN_IF_ERROR(EvaluateFunction(context, device->allocator(), function)); - return OkStatus(); - }; - - Status evaluate_status = OkStatus(); - for (int i = 0; i < module->signature().export_function_count(); ++i) { - evaluate_status = run_function(i); - if (!evaluate_status.ok()) { - break; - } - } - - RETURN_IF_ERROR(instance->device_manager()->UnregisterDevice(device.get())); - device.reset(); - driver.reset(); - - return evaluate_status; -} - -// Translates and runs a single LLVM file buffer. -Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer) { - std::vector<std::string> target_backends; - if (absl::GetFlag(FLAGS_target_backends).empty()) { - target_backends = - hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers(); - } else { - // We need to map specific backends names to drivers (like 'vulkan-spirv' to - // the driver 'vulkan'). - target_backends = absl::StrSplit(absl::GetFlag(FLAGS_target_backends), ','); - } - - for (auto target_backend : target_backends) { - // Prepare the module for execution and evaluate it. - auto cloned_file_buffer = llvm::MemoryBuffer::getMemBufferCopy( - file_buffer->getBuffer(), file_buffer->getBufferIdentifier()); - ASSIGN_OR_RETURN(auto module, PrepareModule(target_backend + '*', - std::move(cloned_file_buffer))); - if (!absl::GetFlag(FLAGS_run)) { - continue; - } - RETURN_IF_ERROR(EvaluateFunctions(BackendToDriverName(target_backend), - std::move(module))); - } - - return OkStatus(); -} - -// Runs the given .mlir file based on the current flags. -Status RunFile(std::string mlir_filename) { - // Load input file/from stdin. - std::string error_message; - auto file = mlir::openInputFile(mlir_filename, &error_message); - if (!file) { - return NotFoundErrorBuilder(IREE_LOC) - << "Unable to open input file " << mlir_filename << ": " - << error_message; - } - - if (!absl::GetFlag(FLAGS_split_input_file)) { - // Use entire buffer as a single module. - return EvaluateFile(std::move(file)); - } - - // Split the buffer into separate modules and evaluate independently. - // This matches the -split-input-file arg to mlir-opt. - const char kSplitMarker[] = "// -----\n"; - auto *full_buffer = file.get(); - llvm::SmallVector<llvm::StringRef, 8> source_buffers; - full_buffer->getBuffer().split(source_buffers, kSplitMarker); - - // Add the original buffer to the source manager. - llvm::SourceMgr fileSourceMgr; - fileSourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); - - // Process each chunk in turn. Only return the first error (but log all). - Status any_failure; - for (auto &sub_source_buffer : source_buffers) { - auto split_loc = llvm::SMLoc::getFromPointer(sub_source_buffer.data()); - unsigned split_line = fileSourceMgr.getLineAndColumn(split_loc).first; - auto sub_buffer = llvm::MemoryBuffer::getMemBufferCopy( - sub_source_buffer, full_buffer->getBufferIdentifier() + - llvm::Twine(" split at line #") + - llvm::Twine(split_line)); - auto sub_failure = EvaluateFile(std::move(sub_buffer)); - if (!sub_failure.ok()) { - LOG(ERROR) << sub_failure; - if (any_failure.ok()) { - any_failure = std::move(sub_failure); - } - } - } - - return any_failure; -} - -} // namespace - -extern "C" int main(int argc, char **argv) { - InitializeEnvironment(&argc, &argv); - if (argc < 2) { - LOG(ERROR) << "Must supply an input .mlir file."; - return 1; - } - auto status = RunFile(argv[1]); - if (!status.ok()) { - std::cerr << "ERROR running file (" << argv[1] << "): " << status << "\n"; - return 1; - } - return 0; -} - -} // namespace iree
diff --git a/iree/tools/run_module_main.cc b/iree/tools/run_module_main.cc deleted file mode 100644 index be634ee..0000000 --- a/iree/tools/run_module_main.cc +++ /dev/null
@@ -1,183 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <iostream> -#include <vector> - -#include "absl/flags/flag.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "iree/base/file_io.h" -#include "iree/base/file_path.h" -#include "iree/base/init.h" -#include "iree/base/source_location.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view_string_util.h" -#include "iree/hal/driver_registry.h" -#include "iree/rt/context.h" -#include "iree/rt/debug/debug_server_flags.h" -#include "iree/rt/instance.h" -#include "iree/rt/module_printer.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/vm/sequencer_module.h" - -ABSL_FLAG(std::string, main_module, "", "Main module with entry point."); -ABSL_FLAG(std::string, main_function, "", - "Function within the main module to execute."); - -ABSL_FLAG(bool, print_disassembly, true, - "Prints bytecode disassembly for the module."); - -ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values."); -ABSL_FLAG(std::string, input_file, "", - "Input shapes and optional values serialized in a file."); - -ABSL_FLAG(std::string, output_types, "", - "Output data types (comma delimited list of b/i/u/f for " - "binary/signed int/unsigned int/float)."); - -namespace iree { -namespace { - -// Parses a list of input shapes and values from a string of newline-separated -// inputs. Expects the contents to have one value per line with each value -// listed as -// [shape]xtype=[value] -// Example: -// 4x4xi8=0,1,2,3 -StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags( - hal::Allocator* allocator) { - std::string file_contents; - if (!absl::GetFlag(FLAGS_input_values).empty()) { - file_contents = - absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}}); - } else if (!absl::GetFlag(FLAGS_input_file).empty()) { - ASSIGN_OR_RETURN(file_contents, - file_io::GetFileContents(absl::GetFlag(FLAGS_input_file))); - } - std::vector<hal::BufferView> inputs; - for (const auto& line : - absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) { - ASSIGN_OR_RETURN(auto input, - hal::ParseBufferViewFromString(line, allocator)); - inputs.push_back(input); - } - return inputs; -} - -} // namespace - -Status Run() { - ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags()); - auto instance = make_ref<rt::Instance>(std::move(debug_server)); - ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( - "interpreter")); - ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); - RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); - auto policy = make_ref<rt::Policy>(); - auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy)); - - // Load main module. - ASSIGN_OR_RETURN( - auto main_module_file, - vm::ModuleFile::LoadFile(ModuleDefIdentifier(), - absl::GetFlag(FLAGS_main_module)), - _ << "while loading module file " << absl::GetFlag(FLAGS_main_module)); - ASSIGN_OR_RETURN(auto main_module, - vm::SequencerModule::FromFile(std::move(main_module_file))); - - // Register the main module with the context. - // We could add additional modules (specializations, shared libraries, etc). - // ModuleFioles are stateless so we could have the same module_file used by - // multiple contexts simultaneously. - RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module))); - - // Dump the registered modules. - rt::PrintModuleFlagBitfield print_flags = rt::PrintModuleFlag::kNone; - if (absl::GetFlag(FLAGS_print_disassembly)) { - print_flags |= rt::PrintModuleFlag::kDisassemble; - } - for (const auto& module : context->modules()) { - RETURN_IF_ERROR(PrintModuleToStream(*module, print_flags, &std::cout)); - } - - rt::Function main_function; - if (!absl::GetFlag(FLAGS_main_function).empty()) { - // User-specified main function. - ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName( - rt::Function::Linkage::kExport, - absl::GetFlag(FLAGS_main_function))); - } else { - // No main function specified; to prevent non-deterministic behavior we - // require one unless there's exactly one exported function in the module. - if (main_module->signature().export_function_count() == 1) { - ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal( - rt::Function::Linkage::kExport, 0)); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "--main_function= must be specified to disambiguate the " - "function to run"; - } - } - - // Call into the main function. - ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator())); - ASSIGN_OR_RETURN(auto invocation, - rt::Invocation::Create(add_ref(context), main_function, - make_ref<rt::Policy>(), {}, - absl::MakeConstSpan(arguments))); - - // Wait until invocation completes. - RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); - ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults()); - - // Dump all results to stdout. - std::vector<std::string> output_types = - absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "), - absl::SkipWhitespace()); - if (!output_types.empty() && output_types.size() != results.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "--output_types= specified but has " << output_types.size() - << " types when the function returns " << results.size(); - } - for (int i = 0; i < results.size(); ++i) { - const auto& result = results[i]; - auto print_mode = hal::BufferViewPrintMode::kFloatingPoint; - if (!output_types.empty()) { - ASSIGN_OR_RETURN(print_mode, - hal::ParseBufferViewPrintMode(output_types[i])); - } - ASSIGN_OR_RETURN(auto result_str, - PrintBufferViewToString(result, print_mode, 1024)); - const auto& buffer = result.buffer; - if (!buffer) { - return InternalErrorBuilder(IREE_LOC) - << "result[" << i << "] unexpectedly has no buffer"; - } - LOG(INFO) << "result[" << i << "]: " << buffer->DebugString(); - std::cout << result_str << "\n"; - } - - return OkStatus(); -} - -extern "C" int main(int argc, char** argv) { - InitializeEnvironment(&argc, &argv); - CHECK_OK(Run()); - return 0; -} - -} // namespace iree
diff --git a/iree/tools/web/BUILD b/iree/tools/web/BUILD deleted file mode 100644 index f55ef9d..0000000 --- a/iree/tools/web/BUILD +++ /dev/null
@@ -1,89 +0,0 @@ -# IREE web tools. - -load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -EMSCRIPTEN_LINKOPTS_COMMON = [ - # Error at compile time on unresolved symbols. - "-s ERROR_ON_UNDEFINED_SYMBOLS=1", - - # Note: If pthreads and memory growth are enabled, WASM_MEM_MAX must be set. - # Also, USE_PTHREADS + ALLOW_MEMORY_GROWTH may run non-wasm code slowly. - # "-s ALLOW_MEMORY_GROWTH=1", - # "-s WASM_MEM_MAX=268435456", # 256MB - # "-s TOTAL_MEMORY=268435456", # 256MB - - # Request a prepopulated pool of web workers for pthreads to use. - # Without this, threads may not start until the javascript thread yields. - # See considerations at https://emscripten.org/docs/porting/pthreads.html. - "-s PTHREAD_POOL_SIZE=1", -] - -EMSCRIPTEN_LINKOPTS_DBG = [ - # Show WASM stack trace in Chrome debugger. - "-g2", - "-s DEMANGLE_SUPPORT=1", - - # Enable verbose assertions. - "-s ASSERTIONS=2", - "-s SAFE_HEAP=1", - "-s STACK_OVERFLOW_CHECK=2", -] - -EMSCRIPTEN_LINKOPTS_OPT = [] - -# To use run_module_emscripten: -# > bazel build third_party/iree/tools/web:run_module_emscripten_files - -cc_binary( - name = "run_module_emscripten", - srcs = ["run_module_emscripten.cc"], - linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({ - "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG, - "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT, - "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT, - }), - tags = [ - "manual", - "notap", # TODO(b/137088911): Build/test on TAP - "wasm", - ], - deps = [ - "//iree/base:init", - "//iree/base:status", - "//iree/hal:buffer_view_string_util", - "//iree/hal:driver_registry", - "//iree/hal/interpreter:interpreter_driver_module", - "//iree/rt", - "//iree/vm:sequencer_module", - "//third_party/emscripten:embind", - ], -) - -auto_wasm_binary( - name = "run_module_emscripten_binary", - cc_target = ":run_module_emscripten", - tags = ["manual"], - threads = "emscripten", -) - -Fileset( - name = "run_module_emscripten_files", - out = "wasm_files", - entries = [ - FilesetEntry( - files = [":run_module_emscripten_binary"], - strip_prefix = "run_module_emscripten_binary", - destdir = "wasm", - ), - FilesetEntry( - files = ["run_module.html"], - destdir = "wasm", - ), - ], - tags = ["manual"], -)
diff --git a/iree/tools/web/run_module_emscripten.cc b/iree/tools/web/run_module_emscripten.cc deleted file mode 100644 index 49d88a7..0000000 --- a/iree/tools/web/run_module_emscripten.cc +++ /dev/null
@@ -1,140 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <emscripten.h> -#include <emscripten/bind.h> - -#include <vector> - -#include "absl/strings/str_replace.h" -#include "absl/strings/str_split.h" -#include "absl/strings/string_view.h" -#include "iree/base/flatbuffer_util.h" -#include "iree/base/init.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/buffer_view_string_util.h" -#include "iree/hal/driver_registry.h" -#include "iree/rt/context.h" -#include "iree/rt/instance.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/vm/sequencer_module.h" - -namespace iree { - -// Parses a list of input shapes and values from a string of newline-separated -// inputs. Expects the contents to have one value per line with each value -// listed as -// [shape]xtype=[value] -// Example: -// 4x4xi8=0,1,2,3 -StatusOr<std::vector<hal::BufferView>> ParseInputs( - absl::string_view inputs_string, hal::Allocator* allocator) { - std::string input_lines = absl::StrReplaceAll(inputs_string, {{"\\n", "\n"}}); - std::vector<hal::BufferView> input_buffer_views; - for (const auto& input_line : - absl::StrSplit(input_lines, '\n', absl::SkipWhitespace())) { - ASSIGN_OR_RETURN(auto input_buffer_view, - hal::ParseBufferViewFromString(input_line, allocator)); - input_buffer_views.push_back(input_buffer_view); - } - return input_buffer_views; -} - -// Runs an IREE module with the provided inputs and returns its outputs. -StatusOr<std::string> RunIreeModule(std::string module_file_data, - absl::string_view inputs_string) { - auto instance = make_ref<rt::Instance>(); - - // Create driver and device. - ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( - "interpreter")); - ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); - RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); - - auto policy = make_ref<rt::Policy>(); - auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy)); - - // Load main module FlatBuffer. - ASSIGN_OR_RETURN(auto main_module_file, - FlatBufferFile<ModuleDef>::FromString(ModuleDefIdentifier(), - module_file_data)); - ASSIGN_OR_RETURN(auto main_module, - vm::SequencerModule::FromFile(std::move(main_module_file))); - - // Register the main module with the context. - RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module))); - - // Setup arguments and storage for results. - // TODO(scotttodd): Receive main function name from JS. - ASSIGN_OR_RETURN(auto main_function, - main_module->LookupFunctionByName( - rt::Function::Linkage::kExport, "main")); - - ASSIGN_OR_RETURN(auto arguments, - ParseInputs(inputs_string, device->allocator())); - - // Call into the main function. - ASSIGN_OR_RETURN(auto invocation, - rt::Invocation::Create(add_ref(context), main_function, - make_ref<rt::Policy>(), {}, - absl::MakeConstSpan(arguments))); - - // Wait until invocation completes. - // TODO(scotttodd): make this an async callback. - RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); - ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults()); - - // Dump all results to stdout. - // TODO(scotttodd): Receive output types / print mode from JS. - // TODO(scotttodd): Return list of outputs instead of just the first (proto?) - for (int i = 0; i < results.size(); ++i) { - const auto& result = results[i]; - auto print_mode = hal::BufferViewPrintMode::kFloatingPoint; - ASSIGN_OR_RETURN(auto result_str, - PrintBufferViewToString(result, print_mode, 1024)); - const auto& buffer = result.buffer; - if (!buffer) { - return InternalErrorBuilder(IREE_LOC) - << "result[" << i << "] unexpectedly has no buffer"; - } - - return result_str; - } - - return InternalErrorBuilder(IREE_LOC) << "Received no results"; -} - -std::string RunIreeModuleEntry(std::string module_file_data, - std::string inputs_string) { - // TODO(scotttodd): optimize, minimize copies - // https://groups.google.com/d/msg/emscripten-discuss/CMfYljLWMvY/Di52WB2QAgAJ - auto result_or = RunIreeModule(std::move(module_file_data), inputs_string); - if (!result_or.ok()) { - return "Error: " + result_or.status().ToString(); - } else { - return result_or.ValueOrDie(); - } -} - -EMSCRIPTEN_BINDINGS(iree) { - emscripten::function("runIreeModule", &RunIreeModuleEntry); -} - -extern "C" int main(int argc, char** argv) { - InitializeEnvironment(&argc, &argv); - return 0; -} - -} // namespace iree
diff --git a/iree/vm/BUILD b/iree/vm/BUILD deleted file mode 100644 index b197f32..0000000 --- a/iree/vm/BUILD +++ /dev/null
@@ -1,200 +0,0 @@ -# Bytecode VM used by the IREE sequencer and interpreter. - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "api", - srcs = ["api.cc"], - hdrs = ["api.h"], - visibility = ["//visibility:public"], - deps = [ - ":api_hdrs", - ":sequencer_module", - "//iree/base:api", - "//iree/base:api_util", - "//iree/base:flatbuffer_util", - "//iree/base:tracing", - "//iree/rt:api", - ], -) - -cc_library( - name = "api_hdrs", - hdrs = ["api.h"], - deps = [ - "//iree/base:api_hdrs", - "//iree/rt:api_hdrs", - ], -) - -cc_library( - name = "bytecode_module", - srcs = [ - "bytecode_disassembler.cc", - "bytecode_module.cc", - ], - hdrs = [ - "bytecode_disassembler.h", - "bytecode_module.h", - ], - deps = [ - ":bytecode_util", - ":opcode_info", - ":source_map_resolver", - ":type", - "//iree/base:flatbuffer_util", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:buffer_view", - "//iree/rt", - "//iree/schemas", - "//iree/schemas/bytecode:bytecode_v0", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "bytecode_reader", - srcs = ["bytecode_reader.cc"], - hdrs = ["bytecode_reader.h"], - deps = [ - ":bytecode_module", - ":type", - "//iree/base:shape", - "//iree/base:status", - "//iree/hal:buffer_view", - "//iree/hal:heap_buffer", - "//iree/rt", - "//iree/schemas/bytecode:bytecode_v0", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - ], -) - -cc_library( - name = "bytecode_tables_interpreter", - srcs = ["bytecode_tables_interpreter.cc"], - hdrs = ["bytecode_tables_interpreter.h"], - deps = [ - ":opcode_info", - "//iree/schemas/bytecode:interpreter_bytecode_v0", - ], -) - -cc_library( - name = "bytecode_tables_sequencer", - srcs = ["bytecode_tables_sequencer.cc"], - hdrs = ["bytecode_tables_sequencer.h"], - deps = [ - ":opcode_info", - "//iree/schemas/bytecode:sequencer_bytecode_v0", - ], -) - -cc_library( - name = "bytecode_util", - srcs = ["bytecode_util.cc"], - hdrs = ["bytecode_util.h"], - deps = [ - "//iree/schemas/bytecode:bytecode_v0", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "bytecode_validator", - srcs = ["bytecode_validator.cc"], - hdrs = ["bytecode_validator.h"], - deps = [ - ":bytecode_module", - "//iree/base:status", - "//iree/schemas", - ], -) - -cc_library( - name = "opcode_info", - hdrs = ["opcode_info.h"], - deps = [ - "//iree/schemas/bytecode:bytecode_v0", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "sequencer_dispatch", - srcs = ["sequencer_dispatch.cc"], - hdrs = ["sequencer_dispatch.h"], - deps = [ - ":bytecode_module", - ":bytecode_reader", - ":bytecode_tables_sequencer", - ":bytecode_util", - ":opcode_info", - "//iree/base:logging", - "//iree/base:memory", - "//iree/base:status", - "//iree/hal:buffer_view", - "//iree/hal:command_queue", - "//iree/hal:device", - "//iree/hal:device_placement", - "//iree/hal:heap_buffer", - "//iree/rt", - "//iree/schemas/bytecode:sequencer_bytecode_v0", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( - name = "sequencer_module", - srcs = ["sequencer_module.cc"], - hdrs = ["sequencer_module.h"], - deps = [ - ":bytecode_module", - ":bytecode_tables_sequencer", - ":sequencer_dispatch", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal:buffer_view", - "//iree/rt", - "@com_google_absl//absl/memory", - ], -) - -cc_library( - name = "source_map_resolver", - srcs = ["source_map_resolver.cc"], - hdrs = ["source_map_resolver.h"], - deps = [ - "//iree/base:flatbuffer_util", - "//iree/base:status", - "//iree/rt", - "//iree/schemas", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", - ], -) - -cc_library( - name = "type", - srcs = ["type.cc"], - hdrs = ["type.h"], - deps = [ - "//iree/base:status", - "//iree/schemas", - "//iree/schemas/bytecode:bytecode_v0", - ], -)
diff --git a/iree/vm/api.cc b/iree/vm/api.cc deleted file mode 100644 index 4a6cf06..0000000 --- a/iree/vm/api.cc +++ /dev/null
@@ -1,99 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/api.h" - -#include "iree/base/api.h" -#include "iree/base/api_util.h" -#include "iree/base/flatbuffer_util.h" -#include "iree/base/tracing.h" -#include "iree/vm/sequencer_module.h" - -namespace iree { -namespace vm { - -//===----------------------------------------------------------------------===// -// iree::vm::BytecodeModule -//===----------------------------------------------------------------------===// - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_vm_bytecode_module_create_from_buffer( - iree_const_byte_span_t buffer_data, - void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data), - void* buffer_free_self, iree_allocator_t allocator, - iree_rt_module_t** out_module) { - IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_buffer"); - - if (!out_module) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_module = nullptr; - - if (!buffer_data.data || !buffer_data.data_length) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - IREE_API_ASSIGN_OR_RETURN( - auto module_file, - FlatBufferFile<ModuleDef>::FromBuffer( - ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length}, - [buffer_free_fn, buffer_free_self, buffer_data]() { - if (buffer_free_fn != nullptr) { - buffer_free_fn(buffer_free_self, - {const_cast<uint8_t*>(buffer_data.data), - buffer_data.data_length}); - } - })); - - IREE_API_ASSIGN_OR_RETURN(auto module, - SequencerModule::FromFile(std::move(module_file))); - - *out_module = reinterpret_cast<iree_rt_module_t*>(module.release()); - - return IREE_STATUS_OK; -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_vm_bytecode_module_create_from_file_mapping( - iree_file_mapping_t* file_mapping, iree_allocator_t allocator, - iree_rt_module_t** out_module) { - IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_file_mapping"); - - if (!out_module) { - return IREE_STATUS_INVALID_ARGUMENT; - } - *out_module = nullptr; - - if (!file_mapping) { - return IREE_STATUS_INVALID_ARGUMENT; - } - - auto buffer_data = iree_file_mapping_data(file_mapping); - IREE_API_ASSIGN_OR_RETURN( - auto module_file, - FlatBufferFile<ModuleDef>::FromBuffer( - ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length}, - [file_mapping]() { iree_file_mapping_release(file_mapping); })); - iree_file_mapping_retain(file_mapping); - - IREE_API_ASSIGN_OR_RETURN(auto module, - SequencerModule::FromFile(std::move(module_file))); - - *out_module = reinterpret_cast<iree_rt_module_t*>(module.release()); - - return IREE_STATUS_OK; -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/api.h b/iree/vm/api.h deleted file mode 100644 index f724309..0000000 --- a/iree/vm/api.h +++ /dev/null
@@ -1,60 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// See iree/base/api.h for documentation on the API conventions used. - -#ifndef IREE_VM_API_H_ -#define IREE_VM_API_H_ - -#include <stdint.h> - -#include "iree/base/api.h" -#include "iree/rt/api.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -//===----------------------------------------------------------------------===// -// iree::vm::BytecodeModule -//===----------------------------------------------------------------------===// - -#ifndef IREE_API_NO_PROTOTYPES - -// Creates a VM module from an in-memory ModuleDef FlatBuffer. -// The provided |buffer_free_fn| will be called when the module is destroyed -// and only if this creation function succeeds. If ownership remains with the -// caller then pass nullptr for |buffer_free_fn|. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_vm_bytecode_module_create_from_buffer( - iree_const_byte_span_t buffer_data, - void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data), - void* buffer_free_self, iree_allocator_t allocator, - iree_rt_module_t** out_module); - -// Creates a VM module from a mapped ModuleDef FlatBuffer. -// The provided |file_mapping| will be retained for the life of the module and -// the contents will be accessed by reference. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_vm_bytecode_module_create_from_file_mapping( - iree_file_mapping_t* file_mapping, iree_allocator_t allocator, - iree_rt_module_t** out_module); - -#endif // IREE_API_NO_PROTOTYPES - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_VM_API_H_
diff --git a/iree/vm/bytecode_disassembler.cc b/iree/vm/bytecode_disassembler.cc deleted file mode 100644 index 4868792..0000000 --- a/iree/vm/bytecode_disassembler.cc +++ /dev/null
@@ -1,482 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_disassembler.h" - -#include <iomanip> -#include <sstream> - -#include "absl/base/macros.h" -#include "absl/container/inlined_vector.h" -#include "absl/strings/str_join.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/schemas/bytecode/bytecode_v0.h" -#include "iree/schemas/source_map_def_generated.h" -#include "iree/vm/bytecode_module.h" -#include "iree/vm/bytecode_util.h" -#include "iree/vm/type.h" - -namespace iree { -namespace vm { - -namespace { - -using ::iree::rt::SourceOffset; - -template <typename T> -StatusOr<T> ReadValue(absl::Span<const uint8_t> data, SourceOffset* offset) { - if (*offset + sizeof(T) > data.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode data underrun"; - } - auto value = *reinterpret_cast<const T*>(&data[*offset]); - *offset = *offset + sizeof(T); - return value; -} - -StatusOr<const Type> ReadType(absl::Span<const uint8_t> data, - SourceOffset* offset) { - ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>(data, offset)); - return Type::FromTypeIndex(type_index); -} - -StatusOr<uint8_t> ReadCount(absl::Span<const uint8_t> data, - SourceOffset* offset) { - return ReadValue<uint8_t>(data, offset); -} - -StatusOr<uint16_t> ReadValueSlot(absl::Span<const uint8_t> data, - SourceOffset* offset) { - return ReadValue<uint16_t>(data, offset); -} - -absl::string_view ConstantEncodingToString(ConstantEncoding encoding) { - switch (encoding) { -#define GET_NAME(ordinal, enum_name, str, ...) \ - case ConstantEncoding::enum_name: \ - return str; - IREE_CONSTANT_ENCODING_LIST(GET_NAME) -#undef GET_NAME - default: - return "unknown"; - } -} - -template <typename T> -std::string TypedDataToString(absl::Span<const uint8_t> bytes) { - auto typed_data = absl::Span<const T>{ - reinterpret_cast<const T*>(bytes.data()), bytes.size() / sizeof(T)}; - return absl::StrJoin(typed_data, ","); -} - -std::string ConstantToString(const Type& type, - absl::Span<const uint8_t> bytes) { - if (!type.is_builtin()) { - return absl::StrJoin(bytes, ","); - } - switch (type.builtin_type()) { - case BuiltinType::kI8: - return TypedDataToString<uint8_t>(bytes); - case BuiltinType::kI16: - return TypedDataToString<uint16_t>(bytes); - case BuiltinType::kI32: - return TypedDataToString<uint32_t>(bytes); - case BuiltinType::kI64: - return TypedDataToString<uint64_t>(bytes); - case BuiltinType::kF16: - return TypedDataToString<uint16_t>(bytes); - case BuiltinType::kF32: - return TypedDataToString<float>(bytes); - case BuiltinType::kF64: - return TypedDataToString<double>(bytes); - default: - return "<unsupported>"; - } -} - -} // namespace - -StatusOr<std::vector<rt::Instruction>> -BytecodeDisassembler::DisassembleInstructions(const rt::Function& function, - SourceOffset offset, - int32_t instruction_count) const { - std::vector<rt::Instruction> instructions; - - ASSIGN_OR_RETURN( - auto* function_def, - static_cast<const BytecodeModule*>(function.module()) - ->GetFunctionDef(function.linkage(), function.ordinal())); - auto* bytecode_def = function_def->bytecode(); - if (!bytecode_def) { - return UnavailableErrorBuilder(IREE_LOC) << "Function contains no body"; - } - auto data = absl::MakeSpan( - reinterpret_cast<const uint8_t*>(bytecode_def->contents()->data()), - bytecode_def->contents()->size()); - - // TODO(benvanik): scan and find all branch offsets to insert labels - - while (offset < data.length() && instructions.size() < instruction_count) { - instructions.push_back({}); - auto& instruction = instructions.back(); - instruction.offset = offset; - - uint8_t opcode = data[offset++]; - const auto& opcode_info = GetOpcodeInfo(opcode_table_, opcode); - if (!opcode_info.mnemonic) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Unhandled opcode " << opcode << " at offset " << (offset - 1); - } - int payload_offset = offset; - - std::ostringstream stream; - - // Print out return values, if any. - int base_result_index = 0; - int printed_result_count = 0; - for (int i = base_result_index; i < ABSL_ARRAYSIZE(opcode_info.operands); - ++i) { - if (opcode_info.operands[i] == OperandEncoding::kNone) break; - if (printed_result_count > 0) { - stream << ", "; - } - switch (opcode_info.operands[i]) { - default: - case OperandEncoding::kNone: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unhandled op encoding " - << static_cast<int>(opcode_info.operands[i]) << " at offset " - << (offset - 1); - case OperandEncoding::kInputSlot: - case OperandEncoding::kOutputSlot: { - // Printing handled below. - offset += sizeof(uint16_t); - break; - } - case OperandEncoding::kVariadicInputSlots: - case OperandEncoding::kVariadicOutputSlots: { - // Printing handled below. - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - offset += count * sizeof(uint16_t); - break; - } - case OperandEncoding::kResultSlot: { - ++printed_result_count; - ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset)); - stream << "%" << slot_ordinal; - break; - } - case OperandEncoding::kVariadicResultSlots: { - ++printed_result_count; - stream << "["; - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - for (int j = 0; j < count; ++j) { - ASSIGN_OR_RETURN(uint16_t slot_ordinal, - ReadValueSlot(data, &offset)); - if (j > 0) stream << ", "; - stream << "%" << slot_ordinal; - } - stream << "]"; - break; - } - case OperandEncoding::kVariadicTransferSlots: { - // Printing handled below. - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - offset += count * 2 * sizeof(uint16_t); - break; - } - case OperandEncoding::kConstant: { - // Printing handled below. - ASSIGN_OR_RETURN(auto type, ReadType(data, &offset)); - ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset)); - int element_count = 1; - for (int j = 0; j < rank; ++j) { - ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset)); - element_count *= dim; - } - offset += sizeof(ConstantEncoding); - offset += element_count * type.element_size(); - break; - } - case OperandEncoding::kFunctionOrdinal: { - // Printing handled below. - offset += sizeof(uint32_t); - break; - } - case OperandEncoding::kDispatchOrdinal: { - // Printing handled below. - offset += sizeof(uint32_t) + sizeof(uint16_t); - break; - } - case OperandEncoding::kBlockOffset: { - // Printing handled below. - offset += sizeof(uint32_t); - break; - } - case OperandEncoding::kTypeIndex: { - // Printing handled below. - offset += sizeof(uint8_t); - break; - } - case OperandEncoding::kIndex: { - // Printing handled below. - offset += sizeof(int32_t); - break; - } - case OperandEncoding::kIndexList: { - // Printing handled below. - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - offset += count * sizeof(int32_t); - break; - } - case OperandEncoding::kCmpIPredicate: - case OperandEncoding::kCmpFPredicate: { - // Printing handled below. - offset += sizeof(uint8_t); - break; - } - } - } - if (printed_result_count > 0) { - stream << " = "; - } - offset = payload_offset; - - stream << opcode_info.mnemonic; - - // Print out operands. - int base_operand_index = 0; - int printed_operand_count = 0; - for (int i = base_operand_index; i < ABSL_ARRAYSIZE(opcode_info.operands); - ++i) { - if (opcode_info.operands[i] == OperandEncoding::kNone) break; - if (opcode_info.operands[i] != OperandEncoding::kResultSlot && - opcode_info.operands[i] != OperandEncoding::kVariadicResultSlots) { - if (i == base_operand_index) { - stream << " "; - } else if (printed_operand_count > 0) { - stream << ", "; - } - } - switch (opcode_info.operands[i]) { - default: - case OperandEncoding::kNone: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unhandled op encoding " - << static_cast<int>(opcode_info.operands[i]) << " at offset " - << (offset - 1); - case OperandEncoding::kInputSlot: { - ++printed_operand_count; - ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset)); - stream << "%" << slot_ordinal; - break; - } - case OperandEncoding::kVariadicInputSlots: { - ++printed_operand_count; - stream << "["; - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - for (int j = 0; j < count; ++j) { - ASSIGN_OR_RETURN(uint16_t slot_ordinal, - ReadValueSlot(data, &offset)); - if (j > 0) stream << ", "; - stream << "%" << slot_ordinal; - } - stream << "]"; - break; - } - case OperandEncoding::kOutputSlot: { - ++printed_operand_count; - ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset)); - stream << "&" - << "%" << slot_ordinal; - break; - } - case OperandEncoding::kVariadicOutputSlots: { - ++printed_operand_count; - stream << "["; - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - for (int j = 0; j < count; ++j) { - ASSIGN_OR_RETURN(uint16_t slot_ordinal, - ReadValueSlot(data, &offset)); - if (j > 0) stream << ", "; - stream << "&" - << "%" << slot_ordinal; - } - stream << "]"; - break; - } - case OperandEncoding::kResultSlot: { - // Printing handled above. - offset += sizeof(uint16_t); - break; - } - case OperandEncoding::kVariadicResultSlots: { - // Printing handled above. - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - offset += count * sizeof(uint16_t); - break; - } - case OperandEncoding::kVariadicTransferSlots: { - ++printed_operand_count; - stream << "["; - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - for (int j = 0; j < count; ++j) { - ASSIGN_OR_RETURN(uint16_t src_slot_ordinal, - ReadValueSlot(data, &offset)); - ASSIGN_OR_RETURN(uint16_t dst_slot_ordinal, - ReadValueSlot(data, &offset)); - if (j > 0) stream << ", "; - stream << "%" << src_slot_ordinal << "=>%" << dst_slot_ordinal; - } - stream << "]"; - break; - } - case OperandEncoding::kConstant: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto type, ReadType(data, &offset)); - ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset)); - absl::InlinedVector<int32_t, 4> shape(rank); - int element_count = 1; - for (int j = 0; j < rank; ++j) { - ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset)); - shape[j] = dim; - element_count *= dim; - } - ASSIGN_OR_RETURN(auto encoding, - ReadValue<ConstantEncoding>(data, &offset)); - stream << ConstantEncodingToString(encoding); - int serialized_element_count = 1; - switch (encoding) { - case ConstantEncoding::kDense: - serialized_element_count = element_count; - break; - case ConstantEncoding::kSplat: - serialized_element_count = 1; - break; - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented constant encoding " - << static_cast<int>(encoding); - } - stream << " buffer_view<"; - if (!shape.empty()) { - stream << absl::StrJoin(shape, "x") << "x"; - } - stream << type << ">{"; - size_t element_size = type.element_size(); - auto bytes = data.subspan( - offset, std::min(serialized_element_count, 1024) * element_size); - stream << ConstantToString(type, bytes); - if (serialized_element_count > 1024) stream << "..."; - offset += serialized_element_count * element_size; - stream << "}"; - break; - } - case OperandEncoding::kFunctionOrdinal: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto function_ordinal, - ReadValue<uint32_t>(data, &offset)); - ASSIGN_OR_RETURN( - auto target_function, - function.module()->LookupFunctionByOrdinal( - rt::Function::Linkage::kInternal, function_ordinal)); - stream << "@" << function_ordinal << " " << target_function.name(); - break; - } - case OperandEncoding::kDispatchOrdinal: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto dispatch_ordinal, - ReadValue<uint32_t>(data, &offset)); - ASSIGN_OR_RETURN(auto export_ordinal, - ReadValue<uint16_t>(data, &offset)); - // TODO(benvanik): lookup in executable table. - stream << "@" << dispatch_ordinal << ":" << export_ordinal; - break; - } - case OperandEncoding::kImportOrdinal: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto import_ordinal, - ReadValue<uint32_t>(data, &offset)); - ASSIGN_OR_RETURN(auto target_function, - function.module()->LookupFunctionByOrdinal( - rt::Function::Linkage::kImport, import_ordinal)); - stream << "@i" << import_ordinal << " " << target_function.name(); - break; - } - case OperandEncoding::kBlockOffset: { - ++printed_operand_count; - ASSIGN_OR_RETURN(uint32_t block_offset, - ReadValue<uint32_t>(data, &offset)); - stream << ":" << block_offset; - break; - } - case OperandEncoding::kTypeIndex: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto type, ReadType(data, &offset)); - stream << type; - break; - } - case OperandEncoding::kIndex: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto index, ReadValue<int32_t>(data, &offset)); - stream << "#" << index; - break; - } - case OperandEncoding::kIndexList: { - ++printed_operand_count; - stream << "{"; - ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); - for (int j = 0; j < count; ++j) { - ASSIGN_OR_RETURN(auto dim, ReadValue<int32_t>(data, &offset)); - if (j > 0) stream << ","; - stream << dim; - } - stream << "}"; - break; - } - case OperandEncoding::kCmpIPredicate: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto predicate_value, - ReadValue<uint8_t>(data, &offset)); - stream << "<" - << PredicateToString( - static_cast<CmpIPredicate>(predicate_value)) - << ">"; - break; - } - case OperandEncoding::kCmpFPredicate: { - ++printed_operand_count; - ASSIGN_OR_RETURN(auto predicate_value, - ReadValue<uint8_t>(data, &offset)); - stream << "<" - << PredicateToString( - static_cast<CmpFPredicate>(predicate_value)) - << ">"; - break; - } - } - } - - stream << "\n"; - - instruction.long_text = stream.str(); - instruction.short_text = instruction.long_text; - } - - return instructions; -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_disassembler.h b/iree/vm/bytecode_disassembler.h deleted file mode 100644 index 633e290..0000000 --- a/iree/vm/bytecode_disassembler.h +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_DISASSEMBLER_H_ -#define IREE_VM_BYTECODE_DISASSEMBLER_H_ - -#include <ostream> - -#include "iree/base/status.h" -#include "iree/rt/disassembler.h" -#include "iree/schemas/bytecode_def_generated.h" -#include "iree/schemas/source_map_def_generated.h" -#include "iree/vm/opcode_info.h" - -namespace iree { -namespace vm { - -// Disassembles bytecode with a specific op set to text. -class BytecodeDisassembler final : public rt::Disassembler { - public: - explicit BytecodeDisassembler(OpcodeTable opcode_table) - : opcode_table_(opcode_table) {} - - StatusOr<std::vector<rt::Instruction>> DisassembleInstructions( - const rt::Function& function, rt::SourceOffset offset, - int32_t instruction_count) const override; - - private: - OpcodeTable opcode_table_; -}; - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_DISASSEMBLER_H_
diff --git a/iree/vm/bytecode_module.cc b/iree/vm/bytecode_module.cc deleted file mode 100644 index 7812cb6..0000000 --- a/iree/vm/bytecode_module.cc +++ /dev/null
@@ -1,310 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_module.h" - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/buffer_view.h" -#include "iree/vm/bytecode_disassembler.h" - -namespace iree { -namespace vm { - -namespace { - -using ::iree::hal::BufferView; -using ::iree::rt::Function; -using ::iree::rt::FunctionSignature; -using ::iree::rt::Module; -using ::iree::rt::ModuleSignature; - -Status ValidateElementSize(int element_bit_width, - const ElementTypeDef& expected_element_type) { - switch (expected_element_type.type_union_type()) { - case ElementTypeDefUnion::FloatTypeDef: { - auto expected_bit_width = - expected_element_type.type_union_as_FloatTypeDef()->width(); - if (element_bit_width != expected_bit_width) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Has element bit width " << element_bit_width - << " but expected " << expected_bit_width; - } - return OkStatus(); - } - case ElementTypeDefUnion::IntegerTypeDef: { - auto expected_bit_width = - expected_element_type.type_union_as_IntegerTypeDef()->width(); - if (element_bit_width != expected_bit_width) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Has element bit width " << element_bit_width - << " but expected " << expected_bit_width; - } - return OkStatus(); - } - case ElementTypeDefUnion::UnknownTypeDef: - case ElementTypeDefUnion::NONE: { - } - } - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Defined type has unsupported element type " - << EnumNameElementTypeDefUnion( - expected_element_type.type_union_type()); -} - -Status ValidateTypeStructure(const FunctionTypeDef& type_def) { - // Ensure all fields are populated. - return OkStatus(); -} - -Status ValidateFunctionTableStructure( - const FunctionTableDef& function_table_def) { - if (!function_table_def.functions()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Function table is missing the function listing"; - } - - // All functions must contain a valid type. - const auto& functions = *function_table_def.functions(); - for (int i = 0; i < functions.size(); ++i) { - const auto* function = functions[i]; - if (!function) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Function ordinal " << i << " is missing its contents"; - } - if (!function->type()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Function ordinal " << i << " is missing its type"; - } - RETURN_IF_ERROR(ValidateTypeStructure(*function->type())); - } - - // Imports must also have a name (that we can use to resolve it). - if (function_table_def.imports()) { - const auto& imports = *function_table_def.imports(); - for (int i = 0; i < imports.size(); ++i) { - int function_index = imports[i]; - if (!functions[function_index]->name()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Import ordinal " << i << " is missing its contents"; - } - } - } - - // Exports must also have a name (that others will use to look it up). - if (function_table_def.exports()) { - const auto& exports = *function_table_def.exports(); - for (int i = 0; i < exports.size(); ++i) { - int function_index = exports[i]; - if (!functions[function_index]->name()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Export ordinal " << i << " is missing its contents"; - } - } - } - - return OkStatus(); -} - -Status ValidateExecutableTableStructure( - const ExecutableTableDef& executable_table_def) { - if (!executable_table_def.multi_arch_executables()) { - // May have sequencer only fns. Fine to not have dispatchable executables. - return OkStatus(); - } - - // All fat executables need at least one device-specific executable. - const auto& multi_arch_executables = - *executable_table_def.multi_arch_executables(); - for (int i = 0; i < multi_arch_executables.size(); ++i) { - const auto* multi_arch_executable = multi_arch_executables[i]; - if (!multi_arch_executable || !multi_arch_executable->executables() || - multi_arch_executable->executables()->size() == 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Multi-arch executable ordinal " << i - << " is missing its contents"; - } - } - - return OkStatus(); -} - -} // namespace - -// static -Status BytecodeModule::ValidateStructure(const ModuleDef& module_def) { - IREE_TRACE_SCOPE0("BytecodeModule::ValidateStructure"); - - // Must have a function table. - if (module_def.function_table()) { - RETURN_IF_ERROR( - ValidateFunctionTableStructure(*module_def.function_table())); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "ModuleDef is missing a function table"; - } - - // Must have an executable table. - if (module_def.executable_table()) { - RETURN_IF_ERROR( - ValidateExecutableTableStructure(*module_def.executable_table())); - } else { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "ModuleDef is missing an executable table"; - } - - return OkStatus(); -} - -BytecodeModule::BytecodeModule(std::unique_ptr<ModuleFile> module_file, - OpcodeTable opcode_table) - : module_file_(std::move(module_file)), - module_def_(*module_file_->root()), - source_resolver_(SourceMapResolver::FromModule(module_def_)), - disassembler_(absl::make_unique<BytecodeDisassembler>(opcode_table)) {} - -BytecodeModule::~BytecodeModule() = default; - -const ModuleSignature BytecodeModule::signature() const { - return ModuleSignature(function_table_def().imports()->size(), - function_table_def().exports()->size(), - function_table_def().functions()->size(), 0); -} - -std::string BytecodeModule::DebugStringShort() const { - return std::string(name()); -} - -StatusOr<int32_t> BytecodeModule::MapFunctionOrdinal(Function::Linkage linkage, - int32_t ordinal) const { - const auto& function_table = function_table_def(); - switch (linkage) { - case Function::Linkage::kImport: - if (ordinal < 0 || ordinal >= function_table.imports()->size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Import ordinal " << ordinal - << " is outside the valid range [0, " - << function_table.imports()->size() << ")"; - } - ordinal = function_table.imports()->Get(ordinal); - break; - case Function::Linkage::kExport: - if (ordinal < 0 || ordinal >= function_table.exports()->size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Export ordinal " << ordinal - << " is outside the valid range [0, " - << function_table.exports()->size() << ")"; - } - ordinal = function_table.exports()->Get(ordinal); - break; - default: - break; - } - if (ordinal < 0 || ordinal >= function_table.functions()->size()) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Function ordinal " << ordinal - << " is outside the valid range [0, " - << function_table.functions()->size() << ")"; - } - return ordinal; -} - -StatusOr<const Function> BytecodeModule::LookupFunctionByOrdinal( - Function::Linkage linkage, int32_t ordinal) const { - ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); - return Function(this, Function::Linkage::kInternal, ordinal); -} - -StatusOr<const Function> BytecodeModule::LookupFunctionByName( - Function::Linkage linkage, absl::string_view name) const { - const auto& functions = *function_table_def().functions(); - for (int i = 0; i < functions.size(); ++i) { - const auto* function_def = functions.Get(i); - if (WrapString(function_def->name()) == name) { - return LookupFunctionByOrdinal(Function::Linkage::kInternal, i); - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "Function '" << name - << "' not found in function table (or names have been stripped)"; -} - -StatusOr<absl::string_view> BytecodeModule::GetFunctionName( - Function::Linkage linkage, int32_t ordinal) const { - ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); - const auto* function_def = function_table_def().functions()->Get(ordinal); - return WrapString(function_def->name()); -} - -StatusOr<const FunctionSignature> BytecodeModule::GetFunctionSignature( - Function::Linkage linkage, int32_t ordinal) const { - ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); - const auto* function_def = function_table_def().functions()->Get(ordinal); - const auto* type_def = function_def->type(); - return FunctionSignature( - type_def->inputs() ? type_def->inputs()->size() : 0, - type_def->results() ? type_def->results()->size() : 0); -} - -StatusOr<const FunctionDef*> BytecodeModule::GetFunctionDef( - rt::Function::Linkage linkage, int32_t ordinal) const { - ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); - const auto& function_defs = *function_table_def().functions(); - if (ordinal >= function_defs.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Internal function ordinal " << ordinal - << " out of range of table (" << function_defs.size() << ")"; - } - return function_defs.Get(ordinal); -} - -StatusOr<const MultiArchExecutableDef*> -BytecodeModule::LookupMultiArchExecutable(int executable_ordinal) const { - if (executable_ordinal < 0 || - executable_ordinal >= - executable_table_def().multi_arch_executables()->size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid multi-arch executable ordinal " << executable_ordinal; - } - return executable_table_def().multi_arch_executables()->Get( - executable_ordinal); -} - -// static -Status BytecodeModule::ValidateArgType(const BufferView& arg, - const MemRefTypeDef& expected_type) { - RETURN_IF_ERROR( - ValidateElementSize(arg.element_size * 8, *expected_type.element_type())); - - auto expected_shape = expected_type.shape(); - if (arg.shape.size() != expected_shape->size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Argument should have rank " << expected_shape->size() - << " but has rank " << arg.shape.size(); - } - for (int i = 0; i < expected_shape->size(); ++i) { - auto dim_size = arg.shape[i]; - auto expected_dim_size = expected_shape->Get(i); - if (dim_size != expected_dim_size) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Argument dimension " << i << " should have size " - << expected_dim_size << " but has size " << dim_size; - } - } - return OkStatus(); -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_module.h b/iree/vm/bytecode_module.h deleted file mode 100644 index 36d034e..0000000 --- a/iree/vm/bytecode_module.h +++ /dev/null
@@ -1,103 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_MODULE_H_ -#define IREE_VM_BYTECODE_MODULE_H_ - -#include <memory> - -#include "iree/base/flatbuffer_util.h" -#include "iree/rt/function.h" -#include "iree/rt/module.h" -#include "iree/schemas/executable_table_def_generated.h" -#include "iree/schemas/function_table_def_generated.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/vm/opcode_info.h" -#include "iree/vm/source_map_resolver.h" - -namespace iree { -namespace vm { - -using ModuleFile = FlatBufferFile<ModuleDef>; - -// A loaded bytecode module backed by a FlatBuffer. -class BytecodeModule : public rt::Module { - public: - static Status ValidateStructure(const ModuleDef& module_def); - - ~BytecodeModule() override; - - const ModuleDef& def() const { return module_def_; } - const FunctionTableDef& function_table_def() const { - return *module_def_.function_table(); - } - const ExecutableTableDef& executable_table_def() const { - return *module_def_.executable_table(); - } - - absl::string_view name() const override { - return WrapString(module_def_.name()); - } - - const rt::ModuleSignature signature() const override; - - rt::SourceResolver* source_resolver() const override { - return &source_resolver_; - } - - rt::Disassembler* disassembler() const override { - return disassembler_.get(); - } - - std::string DebugStringShort() const override; - - StatusOr<const rt::Function> LookupFunctionByOrdinal( - rt::Function::Linkage linkage, int32_t ordinal) const override; - - StatusOr<const rt::Function> LookupFunctionByName( - rt::Function::Linkage linkage, absl::string_view name) const override; - - StatusOr<absl::string_view> GetFunctionName(rt::Function::Linkage linkage, - int32_t ordinal) const override; - - StatusOr<const rt::FunctionSignature> GetFunctionSignature( - rt::Function::Linkage linkage, int32_t ordinal) const override; - - StatusOr<const FunctionDef*> GetFunctionDef(rt::Function::Linkage linkage, - int32_t ordinal) const; - - StatusOr<const MultiArchExecutableDef*> LookupMultiArchExecutable( - int executable_ordinal) const; - - protected: - BytecodeModule(std::unique_ptr<ModuleFile> module_file, - OpcodeTable opcode_table); - - static Status ValidateArgType(const hal::BufferView& arg, - const MemRefTypeDef& expected_type); - - private: - StatusOr<int32_t> MapFunctionOrdinal(rt::Function::Linkage linkage, - int32_t ordinal) const; - - std::unique_ptr<ModuleFile> module_file_; - const ModuleDef& module_def_; - mutable SourceMapResolver source_resolver_; - mutable std::unique_ptr<rt::Disassembler> disassembler_; -}; - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_MODULE_H_
diff --git a/iree/vm/bytecode_reader.cc b/iree/vm/bytecode_reader.cc deleted file mode 100644 index 65e69f3..0000000 --- a/iree/vm/bytecode_reader.cc +++ /dev/null
@@ -1,289 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_reader.h" - -#include "iree/base/shape.h" -#include "iree/base/status.h" -#include "iree/hal/heap_buffer.h" -#include "iree/vm/bytecode_module.h" - -namespace iree { -namespace vm { - -namespace { - -using ::iree::hal::BufferView; -using ::iree::rt::StackFrame; - -} // namespace - -StatusOr<const uint8_t*> BytecodeReader::AdvanceOffset() { - *stack_frame_->mutable_offset() = offset(); - // TODO(benvanik): make a flag and/or remove. - DVLOG(1) << "dispatch(" << stack_frame_->function().name() << "@" << offset() - << "): " << int(*bytecode_pc_); - for (int i = 0; i < registers_->buffer_views.size(); ++i) { - DVLOG(1) << "local[" << i << "] " - << registers_->buffer_views[i].DebugStringShort(); - } - return bytecode_pc_++; -} - -Status BytecodeReader::SkipLocals(int count) { - size_t stride = sizeof(uint16_t) * count; - if (bytecode_pc_ + stride >= bytecode_limit_) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow"; - } - bytecode_pc_ += stride; - return OkStatus(); -} - -Status BytecodeReader::ReadShape(Shape* out_shape) { - ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList()); - *out_shape = Shape(shape_dims); - return OkStatus(); -} - -StatusOr<Shape> BytecodeReader::ReadShapePieces() { - // TODO(benvanik): rewrite to be faster (multiple offsets to walk both lists). - ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList()); - if (shape_dims.size() >= kMaxRank) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Shapes limited to rank " << kMaxRank << " right now"; - } - int expected_dynamic_dims = 0; - for (int i = 0; i < shape_dims.size(); ++i) { - if (shape_dims[i] == -1) { - ++expected_dynamic_dims; - } - } - - Shape shape(shape_dims); - ASSIGN_OR_RETURN(int dynamic_dims, ReadCount()); - if (dynamic_dims != expected_dynamic_dims) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Expected " << expected_dynamic_dims << " dynamic dims but only " - << dynamic_dims << " provided"; - } else if (dynamic_dims) { - for (int i = 0; i < shape_dims.size(); ++i) { - if (shape_dims[i] != -1) { - continue; - } - // TODO(benvanik): kill this embarrassment. - ASSIGN_OR_RETURN(auto dims_piece, ReadSlotElements<int32_t>()); - if (dims_piece.size() != 1) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Dims piece has rank " << dims_piece.size() << "; must be 1"; - } - shape[i] = dims_piece[0]; - } - } - return shape; -} - -StatusOr<Shape> BytecodeReader::ReadShapePieces(size_t* out_element_count) { - ASSIGN_OR_RETURN(auto shape, ReadShapePieces()); - *out_element_count = shape.element_count(); - return shape; -} - -StatusOr<absl::Span<const int32_t>> BytecodeReader::ReadIndexList() { - ASSIGN_OR_RETURN(int count, ReadCount()); - int stride = count * sizeof(int32_t); - if (bytecode_pc_ + stride >= bytecode_limit_) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow"; - } - auto list = absl::Span<const int32_t>( - reinterpret_cast<const int32_t*>(bytecode_pc_), count); - bytecode_pc_ += stride; - return list; -} - -Status BytecodeReader::SwitchStackFrame(StackFrame* new_stack_frame) { - // Flush old state. - auto* old_stack_frame = stack_frame_; - if (old_stack_frame) { - *old_stack_frame->mutable_offset() = offset(); - } - - // Switch the frame. The FiberState holds the full stack, this is just the - // current one for easy access. - stack_frame_ = new_stack_frame; - - // Setup state pointers for faster dereferencing. - const auto& function = new_stack_frame->function(); - ASSIGN_OR_RETURN( - const auto* function_def, - static_cast<const BytecodeModule*>(function.module()) - ->GetFunctionDef(function.linkage(), function.ordinal())); - const auto& bytecode = *function_def->bytecode(); - bytecode_base_ = bytecode.contents()->Data(); - bytecode_limit_ = bytecode_base_ + bytecode.contents()->size(); - bytecode_pc_ = bytecode_base_ + new_stack_frame->offset(); - registers_ = new_stack_frame->mutable_registers(); - return OkStatus(); -} - -Status BytecodeReader::CopyInputsAndSwitchStackFrame( - StackFrame* src_stack_frame, StackFrame* dst_stack_frame) { - ASSIGN_OR_RETURN(size_t src_count, ReadCount()); - auto& dst_buffer_views = dst_stack_frame->mutable_registers()->buffer_views; - for (int i = 0; i < std::min(src_count, dst_buffer_views.size()); ++i) { - ASSIGN_OR_RETURN(auto* src_local, - ReadLocal(src_stack_frame->mutable_registers())); - dst_buffer_views[i] = *src_local; - } - return SwitchStackFrame(dst_stack_frame); -} - -Status BytecodeReader::CopyResultsAndSwitchStackFrame( - StackFrame* src_stack_frame, StackFrame* dst_stack_frame) { - ASSIGN_OR_RETURN(int32_t src_count, ReadCount()); - // TODO(benvanik): avoid vector. - absl::InlinedVector<BufferView*, 8> src_locals(src_count); - for (int i = 0; i < src_count; ++i) { - ASSIGN_OR_RETURN(src_locals[i], - ReadLocal(src_stack_frame->mutable_registers())); - } - RETURN_IF_ERROR(SwitchStackFrame(dst_stack_frame)); - ASSIGN_OR_RETURN(int32_t dst_count, ReadCount()); - if (src_count != dst_count) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Src and dst value counts differ: " << src_count << " vs " - << dst_count; - } - for (int i = 0; i < dst_count; ++i) { - ASSIGN_OR_RETURN(auto* dst_local, - ReadLocal(dst_stack_frame->mutable_registers())); - *dst_local = *src_locals[i]; - } - return OkStatus(); -} - -Status BytecodeReader::CopySlots() { - ASSIGN_OR_RETURN(int32_t count, ReadCount()); - for (int i = 0; i < count; ++i) { - ASSIGN_OR_RETURN(auto* src_local, - ReadLocal(stack_frame_->mutable_registers())); - ASSIGN_OR_RETURN(auto* dst_local, - ReadLocal(stack_frame_->mutable_registers())); - *dst_local = *src_local; - } - return OkStatus(); -} - -Status BytecodeReader::BranchToOffset(int32_t offset) { - const uint8_t* new_bytecode_pc = bytecode_base_ + offset; - if (new_bytecode_pc < bytecode_base_ || new_bytecode_pc > bytecode_limit_) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Branch target " << offset - << " is out of bounds of the function bytecode (" - << static_cast<size_t>(bytecode_limit_ - bytecode_base_) - << "b total)"; - } - bytecode_pc_ = new_bytecode_pc; - return OkStatus(); -} - -StatusOr<BufferView> BytecodeReader::ReadConstant() { - BufferView buffer_view; - - // Element type defines the buffer_view size (but we don't really care about - // the data format). - ASSIGN_OR_RETURN(auto element_type, ReadType()); - buffer_view.element_size = element_type.element_size(); - - // Parse shape - constants always define a full shape. - RETURN_IF_ERROR(ReadShape(&buffer_view.shape)); - - // Read encoding to determine how the constant data is stored in the file. - ASSIGN_OR_RETURN(auto encoding, ReadValue<ConstantEncoding>()); - - // Get buffer for the constant data. - switch (encoding) { - case ConstantEncoding::kDense: { - // Validate we have all constant data present. - device_size_t serialized_length = buffer_view.byte_length(); - if (bytecode_pc_ + serialized_length >= bytecode_limit_) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Constant data out of bounds"; - } - - buffer_view.buffer = hal::HeapBuffer::Wrap( - hal::MemoryType::kHostLocal, hal::BufferUsage::kAll, bytecode_pc_, - serialized_length); - bytecode_pc_ += serialized_length; - break; - } - case ConstantEncoding::kSplat: { - // Validate we have at least one element worth of data in the buffer. - if (bytecode_pc_ + buffer_view.element_size >= bytecode_limit_) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Constant data out of bounds"; - } - - // TODO(benvanik): replace with fancy constant pool and such. - // NOTE: this is not much different than if a alloc_heap+broadcast pair - // had been in the IR. - buffer_view.buffer = hal::HeapBuffer::Allocate( - hal::MemoryType::kHostLocal, hal::BufferUsage::kAll, - buffer_view.byte_length()); - switch (buffer_view.element_size) { - case 1: { - uint8_t value = *reinterpret_cast<const uint8_t*>(bytecode_pc_); - RETURN_IF_ERROR(buffer_view.buffer->Fill8(value)); - break; - } - case 2: { - uint16_t value = *reinterpret_cast<const uint16_t*>(bytecode_pc_); - RETURN_IF_ERROR(buffer_view.buffer->Fill16(value)); - break; - } - case 4: { - uint32_t value = *reinterpret_cast<const uint32_t*>(bytecode_pc_); - RETURN_IF_ERROR(buffer_view.buffer->Fill32(value)); - break; - } - case 8: { - // TODO(benvanik): add Fill64. - uint64_t value = *reinterpret_cast<const uint64_t*>(bytecode_pc_); - ASSIGN_OR_RETURN(auto mapping, - buffer_view.buffer->MapMemory<uint64_t>( - hal::MemoryAccess::kDiscardWrite)); - auto mapped_data = mapping.mutable_contents(); - for (int i = 0; i < mapping.size(); ++i) { - mapped_data[i] = value; - } - break; - } - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented splat element stride " - << buffer_view.element_size; - } - bytecode_pc_ += buffer_view.element_size; - break; - } - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented constant encoding " - << static_cast<int>(encoding); - } - - return buffer_view; -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_reader.h b/iree/vm/bytecode_reader.h deleted file mode 100644 index f2b85fd..0000000 --- a/iree/vm/bytecode_reader.h +++ /dev/null
@@ -1,169 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_READER_H_ -#define IREE_VM_BYTECODE_READER_H_ - -#include "absl/base/attributes.h" -#include "absl/container/inlined_vector.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/rt/context.h" -#include "iree/rt/stack.h" -#include "iree/rt/stack_frame.h" -#include "iree/schemas/bytecode/bytecode_v0.h" -#include "iree/vm/type.h" - -namespace iree { -namespace vm { - -class BytecodeReader { - public: - explicit BytecodeReader(rt::Stack* stack) : stack_(stack) {} - - int offset() const { return static_cast<int>(bytecode_pc_ - bytecode_base_); } - - StatusOr<const uint8_t*> AdvanceOffset(); - - Status SwitchStackFrame(rt::StackFrame* new_stack_frame); - Status BranchToOffset(int32_t offset); - - Status CopyInputsAndSwitchStackFrame(rt::StackFrame* src_stack_frame, - rt::StackFrame* dst_stack_frame); - Status CopyResultsAndSwitchStackFrame(rt::StackFrame* src_stack_frame, - rt::StackFrame* dst_stack_frame); - Status CopySlots(); - - StatusOr<hal::BufferView> ReadConstant(); - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int> ReadCount() { - return ReadValue<uint8_t>(); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const Type> ReadType() { - ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>()); - return Type::FromTypeIndex(type_index); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function> ReadFunction() { - ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>()); - const auto& module = stack_frame_->module(); - return module.LookupFunctionByOrdinal(rt::Function::Linkage::kInternal, - value); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function> - ReadImportFunction() { - ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>()); - const auto& module = stack_frame_->module(); - return stack_->context()->ResolveImport(&module, value); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal( - rt::Registers* registers) { - ASSIGN_OR_RETURN(auto value, ReadValue<uint16_t>()); - if (value > registers->buffer_views.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) - << "Out of bounds local access " << value << " of " - << registers->buffer_views.size(); - } - return ®isters->buffer_views[value]; - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal() { - return ReadLocal(registers_); - } - - Status SkipLocals(int count); - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint8_t> ReadUint8_t() { - return ReadValue<uint8_t>(); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint16_t> ReadUint16_t() { - return ReadValue<uint16_t>(); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int32_t> ReadInt32() { - return ReadValue<int32_t>(); - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint32_t> ReadBlockOffset() { - return ReadValue<uint32_t>(); - } - - template <typename T, size_t N = 8> - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<absl::InlinedVector<T, N>> - ReadSlotElements() { - ASSIGN_OR_RETURN(auto* local, ReadLocal(registers_)); - absl::InlinedVector<T, N> result(local->shape.element_count()); - if (sizeof(T) == local->element_size) { - // Fast(ish) path: requested element size matches the actual element size. - RETURN_IF_ERROR( - local->buffer->ReadData(0, result.data(), result.size() * sizeof(T))); - } else { - // Slow path: need to convert the data. - switch (local->element_size) { - case 4: { - ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int32_t>( - hal::MemoryAccess::kRead)); - for (size_t i = 0; i < result.size(); ++i) { - result[i] = static_cast<T>(mapping[i]); - } - break; - } - case 8: { - ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int64_t>( - hal::MemoryAccess::kRead)); - for (size_t i = 0; i < result.size(); ++i) { - result[i] = static_cast<T>(mapping[i]); - } - break; - } - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unsupported local element size: " << local->element_size; - } - } - return result; - } - - Status ReadShape(Shape* out_shape); - - StatusOr<Shape> ReadShapePieces(); - StatusOr<Shape> ReadShapePieces(size_t* out_element_count); - - StatusOr<absl::Span<const int32_t>> ReadIndexList(); - - private: - template <typename T> - ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<T> ReadValue() { - // TODO(benvanik): validate bounds. - T value = *reinterpret_cast<const T*>(bytecode_pc_); - bytecode_pc_ += sizeof(T); - return value; - } - - rt::Stack* stack_ = nullptr; - rt::StackFrame* stack_frame_ = nullptr; - const uint8_t* bytecode_base_ = nullptr; - const uint8_t* bytecode_limit_ = nullptr; - const uint8_t* bytecode_pc_ = nullptr; - rt::Registers* registers_ = nullptr; -}; - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_READER_H_
diff --git a/iree/vm/bytecode_tables_interpreter.cc b/iree/vm/bytecode_tables_interpreter.cc deleted file mode 100644 index 23661c5..0000000 --- a/iree/vm/bytecode_tables_interpreter.cc +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_tables_interpreter.h" - -#include "iree/schemas/bytecode/interpreter_bytecode_v0.h" - -namespace iree { -namespace vm { - -namespace { - -// Info table mapping 1:1 with bytecode ops. -// -// Note that we ensure the table is 256 elements long exactly to make sure -// that unused opcodes are handled gracefully. -static const OpcodeInfo kInfoTable[256] = { -#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \ - OpcodeInfo{ \ - name, \ - flags, \ - {operand_encodings}, \ - }, - IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO) -#undef DECLARE_INFO -}; - -} // namespace - -OpcodeTable interpreter_opcode_table() { return kInfoTable; } - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_tables_interpreter.h b/iree/vm/bytecode_tables_interpreter.h deleted file mode 100644 index ef53902..0000000 --- a/iree/vm/bytecode_tables_interpreter.h +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_TABLES_INTERPRETER_H_ -#define IREE_VM_BYTECODE_TABLES_INTERPRETER_H_ - -#include "iree/vm/opcode_info.h" - -namespace iree { -namespace vm { - -OpcodeTable interpreter_opcode_table(); - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
diff --git a/iree/vm/bytecode_tables_sequencer.cc b/iree/vm/bytecode_tables_sequencer.cc deleted file mode 100644 index 6a64335..0000000 --- a/iree/vm/bytecode_tables_sequencer.cc +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_tables_sequencer.h" - -#include "iree/schemas/bytecode/sequencer_bytecode_v0.h" - -namespace iree { -namespace vm { - -namespace { - -// Info table mapping 1:1 with bytecode ops. -// -// Note that we ensure the table is 256 elements long exactly to make sure -// that unused opcodes are handled gracefully. -static const OpcodeInfo kInfoTable[256] = { -#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \ - OpcodeInfo{ \ - name, \ - flags, \ - {operand_encodings}, \ - }, - IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO) -#undef DECLARE_INFO -}; - -} // namespace - -OpcodeTable sequencer_opcode_table() { return kInfoTable; } - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_tables_sequencer.h b/iree/vm/bytecode_tables_sequencer.h deleted file mode 100644 index 7e690c5..0000000 --- a/iree/vm/bytecode_tables_sequencer.h +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_TABLES_SEQUENCER_H_ -#define IREE_VM_BYTECODE_TABLES_SEQUENCER_H_ - -#include "iree/vm/opcode_info.h" - -namespace iree { -namespace vm { - -OpcodeTable sequencer_opcode_table(); - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
diff --git a/iree/vm/bytecode_util.cc b/iree/vm/bytecode_util.cc deleted file mode 100644 index 9307773..0000000 --- a/iree/vm/bytecode_util.cc +++ /dev/null
@@ -1,43 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_util.h" - -namespace iree { -namespace vm { - -absl::string_view PredicateToString(CmpIPredicate p) { -#define PRED(index, name, str, ...) \ - case CmpIPredicate::name: \ - return str; - switch (p) { - IREE_CMPI_PREDICATE_LIST(PRED) -#undef PRED - } - return "<unknown>"; -} - -absl::string_view PredicateToString(CmpFPredicate p) { -#define PRED(index, name, str, ...) \ - case CmpFPredicate::name: \ - return str; - switch (p) { - IREE_CMPF_PREDICATE_LIST(PRED) -#undef PRED - } - return "<unknown>"; -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_util.h b/iree/vm/bytecode_util.h deleted file mode 100644 index c663570..0000000 --- a/iree/vm/bytecode_util.h +++ /dev/null
@@ -1,31 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_UTIL_H_ -#define IREE_VM_BYTECODE_UTIL_H_ - -#include "absl/strings/string_view.h" -#include "iree/schemas/bytecode/bytecode_v0.h" - -namespace iree { -namespace vm { - -absl::string_view PredicateToString(CmpIPredicate predicate); - -absl::string_view PredicateToString(CmpFPredicate predicate); - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_UTIL_H_
diff --git a/iree/vm/bytecode_validator.cc b/iree/vm/bytecode_validator.cc deleted file mode 100644 index 968b193..0000000 --- a/iree/vm/bytecode_validator.cc +++ /dev/null
@@ -1,28 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/bytecode_validator.h" - -namespace iree { -namespace vm { - -// static -Status BytecodeValidator::Validate(const BytecodeModule& module, - const BytecodeDef& bytecode_def) { - // TODO(benvanik): validate bytecode. - return OkStatus(); -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/bytecode_validator.h b/iree/vm/bytecode_validator.h deleted file mode 100644 index 429c754..0000000 --- a/iree/vm/bytecode_validator.h +++ /dev/null
@@ -1,37 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_BYTECODE_VALIDATOR_H_ -#define IREE_VM_BYTECODE_VALIDATOR_H_ - -#include "iree/base/status.h" -#include "iree/schemas/bytecode_def_generated.h" -#include "iree/vm/bytecode_module.h" - -namespace iree { -namespace vm { - -// Validates bytecode such that success indicates the bytecode does not -// reference undefined types, functions, or required imports and all imports can -// be resolved with matching signatures. -class BytecodeValidator { - public: - static Status Validate(const BytecodeModule& module, - const BytecodeDef& bytecode_def); -}; - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_BYTECODE_VALIDATOR_H_
diff --git a/iree/vm/opcode_info.h b/iree/vm/opcode_info.h deleted file mode 100644 index 751a6e0..0000000 --- a/iree/vm/opcode_info.h +++ /dev/null
@@ -1,45 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_OPCODE_INFO_H_ -#define IREE_VM_OPCODE_INFO_H_ - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "iree/schemas/bytecode/bytecode_v0.h" - -namespace iree { -namespace vm { - -struct OpcodeInfo { - const char* mnemonic; - OpcodeFlagBitfield flag; - union { - const char operands_value[8]; - const OperandEncoding operands[8]; - }; -}; - -using OpcodeTable = absl::Span<const OpcodeInfo>; - -template <typename T> -inline const OpcodeInfo& GetOpcodeInfo(OpcodeTable opcode_table, T opcode) { - return opcode_table[static_cast<uint8_t>(opcode)]; -} - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_OPCODE_INFO_H_
diff --git a/iree/vm/sequencer_dispatch.cc b/iree/vm/sequencer_dispatch.cc deleted file mode 100644 index ce5a408..0000000 --- a/iree/vm/sequencer_dispatch.cc +++ /dev/null
@@ -1,561 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Implements a full bytecode dispatch system for sequencer ops. -// TODO(benvanik): rework to be async against CommandBuffers. - -#include "iree/vm/sequencer_dispatch.h" - -#include <algorithm> - -#include "absl/base/attributes.h" -#include "absl/container/inlined_vector.h" -#include "absl/strings/str_join.h" -#include "absl/time/time.h" -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/command_queue.h" -#include "iree/hal/device.h" -#include "iree/hal/heap_buffer.h" -#include "iree/schemas/bytecode/sequencer_bytecode_v0.h" -#include "iree/vm/bytecode_module.h" -#include "iree/vm/bytecode_reader.h" -#include "iree/vm/bytecode_tables_sequencer.h" -#include "iree/vm/bytecode_util.h" -#include "iree/vm/opcode_info.h" - -namespace iree { -namespace vm { - -namespace { - -using ::iree::hal::Buffer; -using ::iree::hal::BufferView; - -// TODO(benvanik): remove (this should happen via predication). -bool BufferViewIsTrue(const BufferView& buffer_view) { - if (buffer_view.element_size == 0 || !buffer_view.buffer || - buffer_view.byte_length() == 0) { - return false; - } - // TODO(benvanik): map more efficiently (based on element size?). - auto mapping = - buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead); - if (!mapping.ok()) { - return false; - } - for (uint8_t value : mapping.ValueOrDie().contents()) { - if (value) return true; - } - return false; -} - -// TODO(benvanik): insert fence callbacks and wait on fence. -Status CallExternalFunction(rt::Stack* stack, const rt::Function& function) { - // Marshal inputs and outputs. - const auto* stack_frame = stack->current_frame(); - auto buffer_views = absl::MakeSpan(stack_frame->registers().buffer_views); - absl::InlinedVector<hal::BufferView, 8> arguments( - buffer_views.begin(), - buffer_views.begin() + function.signature().argument_count()); - absl::InlinedVector<hal::BufferView, 8> results( - buffer_views.begin() + arguments.size(), buffer_views.end()); - return function.module()->Execute(stack, function, std::move(arguments), - &results); -} - -// Pretty prints an array, e.g. [1, 2, 3, 4] -inline std::string PrettyPrint(absl::Span<const int32_t> arr) { - return "[" + absl::StrJoin(arr, ",") + "]"; -} - -// Calculates the byte offset into a buffer corresponding to the indices in the -// given shape. -StatusOr<device_size_t> CalculateOffset(absl::Span<const int32_t> indices, - Shape shape, uint8_t element_size) { - if (shape.empty() || indices.size() > shape.size()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Indices " << PrettyPrint(indices) << " out of bounds of shape " - << PrettyPrint(shape.subspan()); - } - device_size_t offset = 0; - for (int i = 0; i < indices.size(); ++i) { - if (indices[i] >= shape[i]) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Indices[" << i << "]=" << indices[i] - << " out of bounds of shape " << PrettyPrint(shape.subspan()); - } - device_size_t axis_offset = indices[i]; - for (int j = i + 1; j < shape.size(); ++j) { - axis_offset *= shape[j]; - } - offset += axis_offset; - } - offset *= element_size; - return offset; -} - -} // namespace - -Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack, - rt::StackFrame* entry_stack_frame, - absl::Span<BufferView> entry_results) { - // Dispatch table mapping 1:1 with bytecode ops. - // Each entry is a label within this function that can be used for computed - // goto. You can find more information on computed goto here: - // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables - // - // Note that we ensure the table is 256 elements long exactly to make sure - // that unused opcodes are handled gracefully. - static const void* kDispatchTable[256] = { -#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name, -#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled, - IREE_SEQUENCER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED) -#undef DECLARE_DISPATCH -#undef DECLARE_DISPATCH_RESERVED - }; - - // Primary dispatch state. This is our 'native stack frame' and really just - // enough to make dereferencing common addresses (like the current offset) - // faster. You can think of this like CPU state (like PC). - // - // We hope that LLVM decides to keep these in registers (as they are touched - // for every instruction executed). The stack_frame will change as we call - // into different functions. - BytecodeReader reader(stack); - RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame)); - -#define DISPATCH_NEXT() \ - { \ - uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \ - DVLOG(1) << "Sequencer dispatching op code: " \ - << GetOpcodeInfo(sequencer_opcode_table(), opcode).mnemonic; \ - goto* kDispatchTable[opcode]; \ - } - -#define DISPATCH_CORE_OPCODE(opcode, body) \ - _dispatch_##opcode : {body} DISPATCH_NEXT() - - DISPATCH_NEXT(); - - DISPATCH_CORE_OPCODE(kConstant, { - ASSIGN_OR_RETURN(auto value, reader.ReadConstant()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - // TODO(b/139121143): until we have full command buffers we need to do this. - ASSIGN_OR_RETURN(value.buffer, - placement.device->allocator()->AllocateConstant( - hal::BufferUsage::kConstant | hal::BufferUsage::kAll, - std::move(value.buffer))); - *dst_local = std::move(value); - }); - - DISPATCH_CORE_OPCODE(kCall, { - auto* old_stack_frame = stack->current_frame(); - ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction()); - // TODO(benvanik): rework register storage interface. - ASSIGN_OR_RETURN( - const auto* function_def, - static_cast<const BytecodeModule*>(target_function.module()) - ->GetFunctionDef(target_function.linkage(), - target_function.ordinal())); - ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function)); - new_stack_frame->mutable_registers()->buffer_views.resize( - function_def->bytecode()->local_count()); - RETURN_IF_ERROR( - reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame)); - DVLOG(1) << "Call; stack now: " << stack->DebugString(); - }); - - DISPATCH_CORE_OPCODE(kCallImport, { - auto* old_stack_frame = stack->current_frame(); - ASSIGN_OR_RETURN(const auto& target_function, reader.ReadImportFunction()); - ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function)); - // TODO(benvanik): rework register storage interface. - const auto& signature = target_function.signature(); - new_stack_frame->mutable_registers()->buffer_views.resize( - signature.argument_count() + signature.result_count()); - RETURN_IF_ERROR( - reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame)); - DVLOG(1) << "Call native import; stack now: " << stack->DebugString(); - RETURN_IF_ERROR(CallExternalFunction(stack, target_function)); - RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame, - new_stack_frame)); - RETURN_IF_ERROR(stack->PopFrame()); - DVLOG(1) << "Return from native; stack now: " << stack->DebugString(); - }); - - DISPATCH_CORE_OPCODE(kCallIndirect, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect"; - }); - - DISPATCH_CORE_OPCODE(kReturn, { - auto* old_stack_frame = stack->current_frame(); - auto* new_stack_frame = stack->caller_frame(); - if (old_stack_frame == entry_stack_frame) { - // Returning from entry function. Marshal results from the return stmt. - ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount()); - for (int i = 0; i < src_count; ++i) { - ASSIGN_OR_RETURN( - auto* src_local, - reader.ReadLocal(old_stack_frame->mutable_registers())); - entry_results[i] = std::move(*src_local); - } - DVLOG(1) << "Returning to entry"; - return OkStatus(); - } else if (!new_stack_frame) { - return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow"; - } - RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame, - new_stack_frame)); - RETURN_IF_ERROR(stack->PopFrame()); - DVLOG(1) << "Return; stack now: " << stack->DebugString(); - }); - - DISPATCH_CORE_OPCODE(kBranch, { - ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset()); - RETURN_IF_ERROR(reader.CopySlots()); - RETURN_IF_ERROR(reader.BranchToOffset(offset)); - }); - - DISPATCH_CORE_OPCODE(kCondBranch, { - // Evaluate condition first so we can do the copies as we read them for - // which side of the branch we take. - ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); - bool cond_value = BufferViewIsTrue(*cond_local); - ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset()); - - if (cond_value) { - RETURN_IF_ERROR(reader.CopySlots()); - RETURN_IF_ERROR(reader.BranchToOffset(true_offset)); - } else { - ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount()); - RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count)); - ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset()); - - RETURN_IF_ERROR(reader.CopySlots()); - RETURN_IF_ERROR(reader.BranchToOffset(false_offset)); - } - }); - - DISPATCH_CORE_OPCODE(kDynamicDispatch, { - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented dynamic_dispatch"; - }); - - DISPATCH_CORE_OPCODE(kStaticDispatch, { - // TODO(benvanik): the real sequencer :) - ASSIGN_OR_RETURN(auto dispatch_ordinal, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto export_ordinal, reader.ReadUint16_t()); - ASSIGN_OR_RETURN( - const auto* multi_arch_executable_def, - static_cast<const BytecodeModule&>(stack->current_frame()->module()) - .LookupMultiArchExecutable(dispatch_ordinal)); - if (export_ordinal >= multi_arch_executable_def->entry_point_count()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Invalid executable export ordinal " << export_ordinal; - } - auto* executable_def = multi_arch_executable_def->executables()->Get(0); - hal::ExecutableSpec executable_spec; - executable_spec.format = executable_def->format(); - executable_spec.executable_data = absl::Span<const uint8_t>( - executable_def->contents()->data(), executable_def->contents()->size()); - auto executable_cache = placement.device->CreateExecutableCache(); - ref_ptr<hal::Executable> executable; - for (auto* executable_def : *multi_arch_executable_def->executables()) { - if (!executable_cache->CanPrepareFormat(executable_def->format())) { - continue; - } - hal::ExecutableSpec executable_spec; - executable_spec.format = executable_def->format(); - executable_spec.executable_data = - absl::Span<const uint8_t>(executable_def->contents()->data(), - executable_def->contents()->size()); - ASSIGN_OR_RETURN(executable, - executable_cache->PrepareExecutable( - hal::ExecutableCachingMode::kDefault | - hal::ExecutableCachingMode::kAliasProvidedData, - executable_spec), - _.LogError()); - break; - } - if (!executable) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "No executable found for the current driver"; - } - - ASSIGN_OR_RETURN(int workload_x, reader.ReadInt32()); - ASSIGN_OR_RETURN(int workload_y, reader.ReadInt32()); - ASSIGN_OR_RETURN(int workload_z, reader.ReadInt32()); - - std::vector<hal::BufferBinding> bindings; - ASSIGN_OR_RETURN(int input_count, reader.ReadCount()); - for (int i = 0; i < input_count; ++i) { - ASSIGN_OR_RETURN(auto* input_local, reader.ReadLocal()); - bindings.push_back(hal::BufferBinding( - input_local->buffer->allowed_access() & hal::MemoryAccess::kAll, - *input_local)); - } - ASSIGN_OR_RETURN(int output_count, reader.ReadCount()); - for (int i = 0; i < output_count; ++i) { - ASSIGN_OR_RETURN(auto* output_local, reader.ReadLocal()); - bindings.push_back( - hal::BufferBinding(hal::MemoryAccess::kWrite, *output_local)); - } - ASSIGN_OR_RETURN(int result_count, reader.ReadCount()); - CHECK_EQ(0, result_count) << "Results not yet implemented"; - - ASSIGN_OR_RETURN( - auto cmd, - placement.device->CreateCommandBuffer( - hal::CommandBufferMode::kOneShot, - hal::CommandCategory::kTransfer | hal::CommandCategory::kDispatch), - _.LogError()); - RETURN_IF_ERROR(cmd->Begin()); - hal::DispatchRequest dispatch_request; - dispatch_request.executable = executable.get(); - dispatch_request.entry_point = export_ordinal; - dispatch_request.workload[0] = workload_x; - dispatch_request.workload[1] = workload_y; - dispatch_request.workload[2] = workload_z; - dispatch_request.bindings = bindings; - RETURN_IF_ERROR(cmd->Dispatch(dispatch_request)); - RETURN_IF_ERROR(cmd->End()); - auto* cmd_ptr = cmd.get(); - - auto* queue = placement.device->dispatch_queues().front(); - hal::SubmissionBatch batch; - batch.command_buffers = absl::MakeConstSpan(&cmd_ptr, 1); - ASSIGN_OR_RETURN(auto fence, placement.device->CreateFence(0u)); - RETURN_IF_ERROR(queue->Submit(batch, {fence.get(), 1u})); - RETURN_IF_ERROR(placement.device->WaitAllFences({{fence.get(), 1u}}, - absl::InfiniteFuture())); - }); - - DISPATCH_CORE_OPCODE(kAllocStatic, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static"; - }); - - DISPATCH_CORE_OPCODE(kAllocStack, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack"; - }); - - DISPATCH_CORE_OPCODE(kAllocStackInit, { - return UnimplementedErrorBuilder(IREE_LOC) - << "Unimplemented alloc_stack_init"; - }); - - DISPATCH_CORE_OPCODE(kAllocHeap, { - ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto type, reader.ReadType()); - size_t element_size = type.element_size(); - - // TODO(benvanik): more efficient reading and storage. - size_t element_count = 0; - ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count)); - size_t allocation_size = element_size * element_count; - - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - dst_local->element_size = element_size; - dst_local->shape = shape; - - // TODO(benvanik): pick an allocator and use that instead. - CHECK_EQ(heap_type, 0); - auto* allocator = placement.device->allocator(); - ASSIGN_OR_RETURN( - dst_local->buffer, - allocator->Allocate( - hal::MemoryType::kHostLocal | hal::MemoryType::kDeviceVisible, - hal::BufferUsage::kAll, allocation_size)); - }); - - DISPATCH_CORE_OPCODE(kDiscard, { - // NOTE: if we were an encoder we would actually discard the buffer. - ASSIGN_OR_RETURN(auto* local, reader.ReadLocal()); - *local = {}; - }); - - DISPATCH_CORE_OPCODE(kComputeRange, { - ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto element_size, reader.ReadUint8_t()); - ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_offset_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_length_local, reader.ReadLocal()); - - Shape shape(shape_data); - ASSIGN_OR_RETURN(device_size_t dst_offset, - CalculateOffset(indices, shape, element_size)); - RETURN_IF_ERROR( - dst_offset_local->buffer->WriteData(0, &dst_offset, sizeof(int32_t))); - - // A buffer range can only be computed for contiguous memory. To ensure that - // this only requests such, we validate that the offset in the buffer - // between the start and end indices is the same as the requested size. - device_size_t dst_length = element_size; - for (int i = 0; i < lengths.size(); ++i) { - dst_length *= lengths[i]; - indices[i] += lengths[i] - 1; - } - ASSIGN_OR_RETURN(auto end_offset, - CalculateOffset(indices, shape, element_size)); - auto offset_based_length = end_offset - dst_offset + element_size; - if (dst_length != offset_based_length) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Cannot compute range for non-contiguous region of memory;" - << " shape: " << PrettyPrint(shape.subspan()) - << " indices: " << PrettyPrint(indices) - << " lengths: " << PrettyPrint(lengths); - } - RETURN_IF_ERROR( - dst_length_local->buffer->WriteData(0, &dst_length, sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kShape, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - RETURN_IF_ERROR(dst_local->buffer->WriteData( - 0, src_local->shape.subspan().data(), - src_local->shape.subspan().size() * sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kLength, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - int32_t length = src_local->shape.element_count(); - RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t))); - }); - - DISPATCH_CORE_OPCODE(kDynamicSlice, { - // TODO(b/139299169): implement indirect copies to avoid CPU readback. - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_slice"; - }); - - DISPATCH_CORE_OPCODE(kStaticSlice, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto offset, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto length, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto type, reader.ReadType()); - ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - Shape new_shape = Shape{shape_data}; - if (new_shape.element_count() * type.element_size() != length) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "New element count " << new_shape.element_count() - << " != length slice " << length; - } - ASSIGN_OR_RETURN(dst_local->buffer, - Buffer::Subspan(src_local->buffer, offset, length)); - dst_local->shape = new_shape; - dst_local->element_size = type.element_size(); - }); - - DISPATCH_CORE_OPCODE(kDynamicCopy, { - // TODO(b/139299169): implement indirect copies to avoid CPU readback. - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto src_offset_span, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_offset_span, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto length_span, reader.ReadSlotElements<int32_t>()); - RETURN_IF_ERROR(dst_local->buffer->CopyData( - dst_offset_span.front(), src_local->buffer.get(), - src_offset_span.front(), length_span.front())); - }); - - DISPATCH_CORE_OPCODE(kStaticCopy, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto src_offset, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto length, reader.ReadInt32()); - RETURN_IF_ERROR(dst_local->buffer->CopyData( - dst_offset, src_local->buffer.get(), src_offset, length)); - }); - - DISPATCH_CORE_OPCODE(kDynamicFill, { - // TODO(b/139299169): implement indirect fills to avoid CPU readback. - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_fill"; - }); - - DISPATCH_CORE_OPCODE(kStaticFill, { - ASSIGN_OR_RETURN(auto value, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32()); - ASSIGN_OR_RETURN(auto length, reader.ReadInt32()); - RETURN_IF_ERROR(dst_local->buffer->Fill32(dst_offset, length, value)); - }); - - DISPATCH_CORE_OPCODE(kClone, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - dst_local->element_size = src_local->element_size; - dst_local->shape = src_local->shape; - ASSIGN_OR_RETURN(dst_local->buffer, placement.device->allocator()->Allocate( - src_local->buffer->memory_type(), - src_local->buffer->usage(), - src_local->buffer->byte_length())); - RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get())); - }); - - DISPATCH_CORE_OPCODE(kAssign, { - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - *dst_local = *src_local; - }); - - DISPATCH_CORE_OPCODE(kCondAssign, { - ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local; - }); - - DISPATCH_CORE_OPCODE(kReshape, { - // TODO(benvanik): more logic required if strides differ. - ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); - ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); - ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); - Shape new_shape = Shape{shape_data}; - if (src_local->shape.element_count() != new_shape.element_count()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "New element count " << new_shape.element_count() - << " != source element count " << src_local->shape.element_count(); - } - dst_local->shape = new_shape; - dst_local->buffer = add_ref(src_local->buffer); - dst_local->element_size = src_local->element_size; - }); - - DISPATCH_CORE_OPCODE(kTrace, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace"; - }); - - DISPATCH_CORE_OPCODE(kBreak, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break"; - }); - - DISPATCH_CORE_OPCODE(kCondBreak, { - return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break"; - }); - -_dispatch_unhandled: - // TODO(benvanik): better tracing. - return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode"; -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/sequencer_dispatch.h b/iree/vm/sequencer_dispatch.h deleted file mode 100644 index 0251c17..0000000 --- a/iree/vm/sequencer_dispatch.h +++ /dev/null
@@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_SEQUENCER_DISPATCH_H_ -#define IREE_VM_SEQUENCER_DISPATCH_H_ - -#include "iree/base/status.h" -#include "iree/hal/buffer_view.h" -#include "iree/hal/device_placement.h" -#include "iree/rt/stack.h" -#include "iree/rt/stack_frame.h" - -namespace iree { -namespace vm { - -// TODO(benvanik): API that supports yielding. -Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack, - rt::StackFrame* entry_stack_frame, - absl::Span<hal::BufferView> entry_results); - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_SEQUENCER_DISPATCH_H_
diff --git a/iree/vm/sequencer_module.cc b/iree/vm/sequencer_module.cc deleted file mode 100644 index c3cf58f..0000000 --- a/iree/vm/sequencer_module.cc +++ /dev/null
@@ -1,112 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/sequencer_module.h" - -#include "absl/memory/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/buffer_view.h" -#include "iree/rt/context.h" -#include "iree/rt/instance.h" -#include "iree/vm/bytecode_tables_sequencer.h" -#include "iree/vm/sequencer_dispatch.h" - -namespace iree { -namespace vm { - -namespace { - -using ::iree::hal::BufferView; -using ::iree::rt::Function; -using ::iree::rt::Module; - -} // namespace - -// static -StatusOr<ref_ptr<rt::Module>> SequencerModule::FromDef( - const ModuleDef& module_def) { - ASSIGN_OR_RETURN(auto module_file, ModuleFile::Create(&module_def, []() {})); - return FromFile(std::move(module_file)); -} - -// static -StatusOr<ref_ptr<rt::Module>> SequencerModule::FromFile( - std::unique_ptr<ModuleFile> module_file) { - if (module_file->root() == nullptr) { - return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present"; - } - const auto& module_def = *module_file->root(); - - // Validates the structure of the module (but not bytecode). - // This ensures we don't have flatbuffer vectors will null entries, etc. - RETURN_IF_ERROR(BytecodeModule::ValidateStructure(module_def)); - - auto module = assign_ref(new SequencerModule(std::move(module_file))); - - // TODO(benvanik): validate internals here? or make explicit? - - return {std::move(module)}; -} - -SequencerModule::SequencerModule(std::unique_ptr<ModuleFile> module_file) - : BytecodeModule(std::move(module_file), sequencer_opcode_table()) {} - -SequencerModule::~SequencerModule() = default; - -Status SequencerModule::Execute( - rt::Stack* stack, const Function function, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::InlinedVector<hal::BufferView, 8>* results) const { - IREE_TRACE_SCOPE0("SequencerModule::Execute"); - - // Push stack frame for the function we are calling. - ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function)); - - // TODO(benvanik): rework register storage interface. - ASSIGN_OR_RETURN(const auto* function_def, - GetFunctionDef(function.linkage(), function.ordinal())); - auto* registers = callee_stack_frame->mutable_registers(); - registers->buffer_views.resize(function_def->bytecode()->local_count()); - - // Marshal input arguments. - for (int i = 0; i < arguments.size(); ++i) { - auto arg = arguments[i]; - auto expected_arg_type = function_def->type()->inputs()->Get(i); - RETURN_IF_ERROR(BytecodeModule::ValidateArgType( - arg, *expected_arg_type->type_union_as_MemRefTypeDef())) - << "Function " << function.name() << " argument " << i; - registers->buffer_views[i] = std::move(arg); - } - - // TODO(benvanik): change to: - // get command queue (any command queue) - // make command buffer - // record dispatch - // submit - // wait on fence - ASSIGN_OR_RETURN( - auto placement, - stack->context()->instance()->device_manager()->ResolvePlacement({})); - RETURN_IF_ERROR(DispatchSequence(placement, stack, callee_stack_frame, - absl::MakeSpan(*results))); - - // Pop the callee frame to balance out the stack. - RETURN_IF_ERROR(stack->PopFrame()); - - return OkStatus(); -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/sequencer_module.h b/iree/vm/sequencer_module.h deleted file mode 100644 index b9bb176..0000000 --- a/iree/vm/sequencer_module.h +++ /dev/null
@@ -1,46 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_SEQUENCER_MODULE_H_ -#define IREE_VM_SEQUENCER_MODULE_H_ - -#include <memory> - -#include "iree/vm/bytecode_module.h" - -namespace iree { -namespace vm { - -// A module using the sequencer bytecode ops. -class SequencerModule final : public BytecodeModule { - public: - static StatusOr<ref_ptr<rt::Module>> FromDef(const ModuleDef& module_def); - static StatusOr<ref_ptr<rt::Module>> FromFile( - std::unique_ptr<ModuleFile> module_file); - - ~SequencerModule() override; - - Status Execute( - rt::Stack* stack, const rt::Function function, - absl::InlinedVector<hal::BufferView, 8> arguments, - absl::InlinedVector<hal::BufferView, 8>* results) const override; - - private: - explicit SequencerModule(std::unique_ptr<ModuleFile> module_file); -}; - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_SEQUENCER_MODULE_H_
diff --git a/iree/vm/source_map_resolver.cc b/iree/vm/source_map_resolver.cc deleted file mode 100644 index 96025e4..0000000 --- a/iree/vm/source_map_resolver.cc +++ /dev/null
@@ -1,194 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/source_map_resolver.h" - -#include "iree/base/flatbuffer_util.h" -#include "iree/base/status.h" -#include "iree/schemas/source_map_def_generated.h" - -namespace iree { -namespace vm { - -namespace { - -Status PrintLocation(const SourceMapResolver& source_map, - const FunctionSourceMapDef& function_source_map, - const LocationDef& location, std::ostream* stream); - -Status PrintFileLocation(const SourceMapResolver& source_map, - const FunctionSourceMapDef& function_source_map, - const FileLocationDef& location, - std::ostream* stream) { - ASSIGN_OR_RETURN(auto filename, - source_map.GetUniqueString(location.filename())); - *stream << filename << ":" << location.line() << ":" << location.column(); - return OkStatus(); -} - -Status PrintNameLocation(const SourceMapResolver& source_map, - const FunctionSourceMapDef& function_source_map, - const NameLocationDef& location, - std::ostream* stream) { - ASSIGN_OR_RETURN(auto name, source_map.GetUniqueString(location.name())); - *stream << "\"" << name << "\""; - return OkStatus(); -} - -Status PrintCallSiteLocation(const SourceMapResolver& source_map, - const FunctionSourceMapDef& function_source_map, - const CallSiteLocationDef& location, - std::ostream* stream) { - *stream << "(callsites todo)"; - return OkStatus(); -} - -Status PrintFusedLocation(const SourceMapResolver& source_map, - const FunctionSourceMapDef& function_source_map, - const FusedLocationDef& location, - std::ostream* stream) { - *stream << "fused["; - if (location.locations()) { - for (int i = 0; i < location.locations()->size(); ++i) { - if (i > 0) *stream << ", "; - int location_ordinal = location.locations()->Get(i); - const auto& child_location = - *function_source_map.location_table()->Get(location_ordinal); - RETURN_IF_ERROR(PrintLocation(source_map, function_source_map, - child_location, stream)); - } - } - *stream << "]"; - return OkStatus(); -} - -Status PrintLocation(const SourceMapResolver& source_map, - const FunctionSourceMapDef& function_source_map, - const LocationDef& location, std::ostream* stream) { - switch (location.location_union_type()) { - case LocationDefUnion::FileLocationDef: - return PrintFileLocation(source_map, function_source_map, - *location.location_union_as_FileLocationDef(), - stream); - case LocationDefUnion::NameLocationDef: - return PrintNameLocation(source_map, function_source_map, - *location.location_union_as_NameLocationDef(), - stream); - case LocationDefUnion::CallSiteLocationDef: - return PrintCallSiteLocation( - source_map, function_source_map, - *location.location_union_as_CallSiteLocationDef(), stream); - case LocationDefUnion::FusedLocationDef: - return PrintFusedLocation(source_map, function_source_map, - *location.location_union_as_FusedLocationDef(), - stream); - default: - return UnimplementedErrorBuilder(IREE_LOC) - << "Unhandled location type " - << static_cast<int>(location.location_union_type()); - } -} - -} // namespace - -// static -SourceMapResolver SourceMapResolver::FromModule(const ModuleDef& module_def) { - if (module_def.source_map()) { - return SourceMapResolver{*module_def.source_map()}; - } - return {}; -} - -StatusOr<absl::string_view> SourceMapResolver::GetUniqueString( - int string_index) const { - if (empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No source map present"; - } - const auto* string_table = source_map_def_->string_table(); - if (string_table && string_table->size() > string_index) { - return WrapString(string_table->Get(string_index)); - } - return NotFoundErrorBuilder(IREE_LOC) - << "String index " << string_index << " not present in string table"; -} - -StatusOr<const FunctionSourceMapDef*> SourceMapResolver::GetFunctionSourceMap( - int function_ordinal) const { - if (empty()) { - return NotFoundErrorBuilder(IREE_LOC) << "No source map present"; - } - const auto* function_table = source_map_def_->function_table(); - if (function_table && function_table->size() > function_ordinal) { - const auto* function_source_map = function_table->Get(function_ordinal); - if (function_source_map && function_source_map->location_table() && - function_source_map->bytecode_map()) { - return function_source_map; - } - } - return NotFoundErrorBuilder(IREE_LOC) - << "Function ordinal " << function_ordinal - << " source map not present in function table"; -} - -absl::optional<rt::SourceLocation> SourceMapResolver::ResolveFunctionOffset( - const rt::Function& function, rt::SourceOffset offset) { - if (empty()) return absl::nullopt; - auto function_source_map_or = GetFunctionSourceMap(function.ordinal()); - if (!function_source_map_or.ok()) { - return absl::nullopt; - } - const auto* function_source_map = function_source_map_or.ValueOrDie(); - const auto* bytecode_map = function_source_map->bytecode_map(); - if (!bytecode_map) return absl::nullopt; - - // TODO(benvanik): allow fuzzy offset matching/table sparsity. - int location_ordinal = -1; - for (const auto* map_loc : *bytecode_map) { - if (map_loc->offset() == offset) { - location_ordinal = map_loc->location(); - break; - } - } - if (location_ordinal == -1) { - return absl::nullopt; - } - - return rt::SourceLocation(this, - { - reinterpret_cast<uint64_t>(function_source_map), - static_cast<uint64_t>(location_ordinal), - }); -} - -void SourceMapResolver::PrintSourceLocation( - rt::SourceResolverArgs resolver_args, std::ostream* stream) const { - if (empty()) { - *stream << "<unknown>"; - return; - } - - auto* function_source_map = - reinterpret_cast<FunctionSourceMapDef*>(resolver_args[0]); - int location_ordinal = static_cast<int>(resolver_args[1]); - - const auto& location = - *function_source_map->location_table()->Get(location_ordinal); - auto status = PrintLocation(*this, *function_source_map, location, stream); - if (!status.ok()) { - *stream << status; - } -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/source_map_resolver.h b/iree/vm/source_map_resolver.h deleted file mode 100644 index 5c8f7c2..0000000 --- a/iree/vm/source_map_resolver.h +++ /dev/null
@@ -1,57 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_SOURCE_MAP_RESOLVER_H_ -#define IREE_VM_SOURCE_MAP_RESOLVER_H_ - -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" -#include "iree/base/status.h" -#include "iree/rt/source_resolver.h" -#include "iree/schemas/module_def_generated.h" -#include "iree/schemas/source_map_def_generated.h" - -namespace iree { -namespace vm { - -class SourceMapResolver final : public rt::SourceResolver { - public: - static SourceMapResolver FromModule(const ModuleDef& module_def); - - SourceMapResolver() = default; - explicit SourceMapResolver(const SourceMapDef& source_map_def) - : source_map_def_(&source_map_def) {} - - bool empty() const { return source_map_def_ == nullptr; } - const SourceMapDef* def() const { return source_map_def_; } - - StatusOr<absl::string_view> GetUniqueString(int string_index) const; - - StatusOr<const FunctionSourceMapDef*> GetFunctionSourceMap( - int function_ordinal) const; - - absl::optional<rt::SourceLocation> ResolveFunctionOffset( - const rt::Function& function, rt::SourceOffset offset) override; - - void PrintSourceLocation(rt::SourceResolverArgs resolver_args, - std::ostream* stream) const override; - - private: - const SourceMapDef* source_map_def_ = nullptr; -}; - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_SOURCE_MAP_RESOLVER_H_
diff --git a/iree/vm/type.cc b/iree/vm/type.cc deleted file mode 100644 index 5b2c372..0000000 --- a/iree/vm/type.cc +++ /dev/null
@@ -1,64 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/vm/type.h" - -#include "iree/base/status.h" - -namespace iree { -namespace vm { - -// static -StatusOr<const Type> Type::FromTypeIndex(uint8_t type_index) { - // Currently we only support the builtin types. - if (type_index == static_cast<uint8_t>(BuiltinType::kOpaque)) { - return Type(type_index); - } else if (type_index < kBuiltinTypeCount) { - return Type(type_index); - } - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Type index " << static_cast<int>(type_index) << " not supported"; -} - -// static -const Type Type::FromBuiltin(BuiltinType type) { - return Type(static_cast<uint8_t>(type)); -} - -std::string Type::DebugString() const { - switch (type_index_) { -#define TYPE_NAME(index, name, str, size) \ - case index: \ - return str; - IREE_TYPE_LIST(TYPE_NAME) -#undef TYPE_NAME - default: - return "<invalid>"; - } -} - -size_t Type::element_size() const { - switch (type_index_) { -#define TYPE_SIZE(index, name, str, size) \ - case index: \ - return size; - IREE_TYPE_LIST(TYPE_SIZE) -#undef TYPE_SIZE - default: - return 0; - } -} - -} // namespace vm -} // namespace iree
diff --git a/iree/vm/type.h b/iree/vm/type.h deleted file mode 100644 index b8ab9bd..0000000 --- a/iree/vm/type.h +++ /dev/null
@@ -1,65 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_VM_TYPE_H_ -#define IREE_VM_TYPE_H_ - -#include "iree/base/status.h" -#include "iree/schemas/bytecode/bytecode_v0.h" -#include "iree/schemas/type_def_generated.h" - -namespace iree { -namespace vm { - -class Type { - public: - static StatusOr<const Type> FromTypeIndex(uint8_t type_index); - static const Type FromBuiltin(BuiltinType type); - - std::string DebugString() const; - - uint8_t type_index() const { return type_index_; } - - bool is_opaque() const { - return type_index_ == static_cast<uint8_t>(BuiltinType::kOpaque); - } - bool is_builtin() const { return !is_opaque(); } - BuiltinType builtin_type() const { - DCHECK(is_builtin()); - return static_cast<BuiltinType>(type_index_); - } - - size_t element_size() const; - - private: - explicit Type(uint8_t type_index) : type_index_(type_index) {} - - uint8_t type_index_; -}; - -inline bool operator==(const Type& a, const Type& b) { - return a.type_index() == b.type_index(); -} - -inline bool operator!=(const Type& a, const Type& b) { return !(a == b); } - -inline std::ostream& operator<<(std::ostream& stream, const Type& type) { - stream << type.DebugString(); - return stream; -} - -} // namespace vm -} // namespace iree - -#endif // IREE_VM_TYPE_H_
diff --git a/rt/BUILD b/rt/BUILD new file mode 100644 index 0000000..e06308c --- /dev/null +++ b/rt/BUILD
@@ -0,0 +1,83 @@ +# Runtime API for interacting with IREE modules and invoking functions. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "api", + srcs = ["api.cc"], + hdrs = ["api.h"], + visibility = ["//visibility:public"], + deps = [ + ":api_hdrs", + ":rt", + "///base:api", + "///base:api_util", + "///base:tracing", + "///hal:api", + "///hal:buffer_view", + "///hal:driver_registry", + "///rt/debug:debug_server_interface", + "@com_google_absl//absl/time", + ], +) + +cc_library( + name = "api_hdrs", + hdrs = ["api.h"], + deps = [ + "///base:api_hdrs", + "///hal:api_hdrs", + ], +) + +cc_library( + name = "rt", + srcs = [ + "context.cc", + "function.cc", + "instance.cc", + "invocation.cc", + "module_printer.cc", + "source_location.cc", + "stack.cc", + "stack_frame.cc", + "stack_trace.cc", + ], + hdrs = [ + "context.h", + "disassembler.h", + "function.h", + "function_signature.h", + "instance.h", + "invocation.h", + "module.h", + "module_printer.h", + "module_signature.h", + "policy.h", + "source_location.h", + "source_resolver.h", + "stack.h", + "stack_frame.h", + "stack_trace.h", + ], + deps = [ + "///base:bitfield", + "///base:intrusive_list", + "///base:ref_ptr", + "///base:status", + "///base:tracing", + "///hal:buffer_view", + "///hal:device_manager", + "///rt/debug:debug_server_interface", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +)
diff --git a/iree/rt/CMakeLists.txt b/rt/CMakeLists.txt similarity index 100% rename from iree/rt/CMakeLists.txt rename to rt/CMakeLists.txt
diff --git a/rt/api.cc b/rt/api.cc new file mode 100644 index 0000000..c99f733 --- /dev/null +++ b/rt/api.cc
@@ -0,0 +1,788 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/api.h" + +#include "absl/time/time.h" +#include "base/api.h" +#include "base/api_util.h" +#include "base/tracing.h" +#include "hal/api.h" +#include "hal/api_detail.h" +#include "hal/buffer_view.h" +#include "hal/driver_registry.h" +#include "rt/context.h" +#include "rt/debug/debug_server.h" +#include "rt/function.h" +#include "rt/instance.h" +#include "rt/invocation.h" +#include "rt/module.h" +#include "rt/policy.h" + +namespace iree { +namespace rt { + +//===----------------------------------------------------------------------===// +// iree::rt::Instance +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create( + iree_allocator_t allocator, iree_rt_instance_t** out_instance) { + IREE_TRACE_SCOPE0("iree_rt_instance_create"); + + if (!out_instance) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_instance = nullptr; + + auto instance = make_ref<Instance>(); + *out_instance = reinterpret_cast<iree_rt_instance_t*>(instance.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_instance_retain(iree_rt_instance_t* instance) { + IREE_TRACE_SCOPE0("iree_rt_instance_retain"); + auto* handle = reinterpret_cast<Instance*>(instance); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_instance_release(iree_rt_instance_t* instance) { + IREE_TRACE_SCOPE0("iree_rt_instance_release"); + auto* handle = reinterpret_cast<Instance*>(instance); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex( + iree_rt_instance_t* instance, iree_string_view_t driver_name) { + IREE_TRACE_SCOPE0("iree_rt_instance_register_driver_ex"); + auto* handle = reinterpret_cast<Instance*>(instance); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + IREE_API_ASSIGN_OR_RETURN( + auto driver, hal::DriverRegistry::shared_registry()->Create( + absl::string_view{driver_name.data, driver_name.size})); + IREE_API_ASSIGN_OR_RETURN(auto available_devices, + driver->EnumerateAvailableDevices()); + for (const auto& device_info : available_devices) { + LOG(INFO) << " Device: " << device_info.name(); + } + LOG(INFO) << "Creating default device..."; + IREE_API_ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); + IREE_API_RETURN_IF_ERROR(handle->device_manager()->RegisterDevice(device)); + LOG(INFO) << "Successfully created device '" << device->info().name() << "'"; + + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::rt::Module +//===----------------------------------------------------------------------===// + +namespace { + +class ExternalModule final : public Module { + public: + ExternalModule(iree_rt_external_module_t impl, iree_allocator_t allocator) + : impl_(impl), allocator_(allocator) { + IREE_TRACE_SCOPE0("ExternalModule::ctor"); + } + + ~ExternalModule() override { + IREE_TRACE_SCOPE0("ExternalModule::dtor"); + impl_.destroy(impl_.self); + std::memset(&impl_, 0, sizeof(impl_)); + } + + absl::string_view name() const override { + auto result = impl_.name(impl_.self); + return absl::string_view{result.data, result.size}; + } + + const ModuleSignature signature() const override { + auto signature = impl_.signature(impl_.self); + return ModuleSignature{ + signature.import_function_count, + signature.export_function_count, + signature.internal_function_count, + signature.state_slot_count, + }; + } + + SourceResolver* source_resolver() const override { return nullptr; } + + Disassembler* disassembler() const override { return nullptr; } + + std::string DebugStringShort() const override { return std::string(name()); } + + StatusOr<const Function> LookupFunctionByOrdinal( + Function::Linkage linkage, int32_t ordinal) const override { + IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByOrdinal"); + iree_rt_function_t function; + auto status = impl_.lookup_function_by_ordinal( + impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal, + &function); + if (status != IREE_STATUS_OK) { + return FromApiStatus(status, IREE_LOC); + } + return Function{reinterpret_cast<Module*>(function.module), + static_cast<Function::Linkage>(function.linkage), + function.ordinal}; + } + + StatusOr<const Function> LookupFunctionByName( + Function::Linkage linkage, absl::string_view name) const override { + IREE_TRACE_SCOPE0("ExternalModule::LookupFunctionByName"); + iree_rt_function_t function; + auto status = impl_.lookup_function_by_name( + impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), + iree_string_view_t{name.data(), name.size()}, &function); + if (status != IREE_STATUS_OK) { + return FromApiStatus(status, IREE_LOC); + } + return Function{reinterpret_cast<Module*>(function.module), + static_cast<Function::Linkage>(function.linkage), + function.ordinal}; + } + + StatusOr<absl::string_view> GetFunctionName(Function::Linkage linkage, + int32_t ordinal) const override { + IREE_TRACE_SCOPE0("ExternalModule::GetFunctionName"); + iree_string_view_t name; + auto status = impl_.get_function_name( + impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal, + &name); + RETURN_IF_ERROR(FromApiStatus(status, IREE_LOC)); + return absl::string_view{name.data, name.size}; + } + + StatusOr<const FunctionSignature> GetFunctionSignature( + Function::Linkage linkage, int32_t ordinal) const override { + IREE_TRACE_SCOPE0("ExternalModule::GetFunctionSignature"); + iree_rt_function_signature_t signature; + auto status = impl_.get_function_signature( + impl_.self, static_cast<iree_rt_function_linkage_t>(linkage), ordinal, + &signature); + if (status != IREE_STATUS_OK) { + return FromApiStatus(status, IREE_LOC); + } + return FunctionSignature{signature.argument_count, signature.result_count}; + } + + Status Execute( + Stack* stack, const Function function, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::InlinedVector<hal::BufferView, 8>* results) const override { + // TODO(benvanik): fn ptr callback to external code. Waiting on fibers. + return UnimplementedErrorBuilder(IREE_LOC) + << "External calls not yet implemented"; + } + + private: + iree_rt_external_module_t impl_; + iree_allocator_t allocator_; +}; + +} // namespace + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external( + iree_rt_external_module_t impl, iree_allocator_t allocator, + iree_rt_module_t** out_module) { + IREE_TRACE_SCOPE0("iree_rt_module_create_external"); + + if (!out_module) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_module = nullptr; + + auto module = make_ref<ExternalModule>(impl, allocator); + *out_module = reinterpret_cast<iree_rt_module_t*>(module.release()); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_retain(iree_rt_module_t* module) { + IREE_TRACE_SCOPE0("iree_rt_module_retain"); + auto* handle = reinterpret_cast<Module*>(module); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_release(iree_rt_module_t* module) { + IREE_TRACE_SCOPE0("iree_rt_module_release"); + auto* handle = reinterpret_cast<Module*>(module); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_rt_module_name(const iree_rt_module_t* module) { + IREE_TRACE_SCOPE0("iree_rt_module_name"); + const auto* handle = reinterpret_cast<const Module*>(module); + CHECK(handle) << "NULL module handle"; + return iree_string_view_t{handle->name().data(), handle->name().size()}; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module, + iree_rt_function_linkage_t linkage, + int32_t ordinal, + iree_rt_function_t* out_function) { + IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_ordinal"); + + if (!out_function) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_function, 0, sizeof(*out_function)); + + auto* handle = reinterpret_cast<Module*>(module); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto function_or = handle->LookupFunctionByOrdinal( + static_cast<Function::Linkage>(linkage), ordinal); + if (!function_or.ok()) { + // Map this invalid argument to not found, per the API spec. + if (IsInvalidArgument(function_or.status())) { + return IREE_STATUS_NOT_FOUND; + } + return ToApiStatus(std::move(function_or).status()); + } + auto function = *function_or; + + out_function->module = module; + out_function->linkage = + static_cast<iree_rt_function_linkage_t>(function.linkage()); + out_function->ordinal = function.ordinal(); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_lookup_function_by_name(iree_rt_module_t* module, + iree_rt_function_linkage_t linkage, + iree_string_view_t name, + iree_rt_function_t* out_function) { + IREE_TRACE_SCOPE0("iree_rt_module_lookup_function_by_name"); + + if (!out_function) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_function, 0, sizeof(*out_function)); + + auto* handle = reinterpret_cast<Module*>(module); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + IREE_API_ASSIGN_OR_RETURN( + auto function, + handle->LookupFunctionByName(static_cast<Function::Linkage>(linkage), + absl::string_view{name.data, name.size})); + + out_function->linkage = + static_cast<iree_rt_function_linkage_t>(function.linkage()); + out_function->module = module; + out_function->linkage = linkage; + out_function->ordinal = function.ordinal(); + + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::rt::Function +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_rt_function_name(const iree_rt_function_t* function) { + IREE_TRACE_SCOPE0("iree_rt_function_name"); + CHECK(function && function->module) << "NULL function handle"; + auto* module = reinterpret_cast<Module*>(function->module); + auto name_or = module->GetFunctionName( + static_cast<Function::Linkage>(function->linkage), function->ordinal); + if (!name_or.ok()) return {}; + auto name = name_or.ValueOrDie(); + return iree_string_view_t{name.data(), name.size()}; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_function_signature(const iree_rt_function_t* function, + iree_rt_function_signature_t* out_signature) { + IREE_TRACE_SCOPE0("iree_rt_function_signature"); + + if (!out_signature) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_signature, 0, sizeof(*out_signature)); + + if (!function || !function->module) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto* module = reinterpret_cast<Module*>(function->module); + IREE_API_ASSIGN_OR_RETURN( + auto signature, module->GetFunctionSignature( + static_cast<Function::Linkage>(function->linkage), + function->ordinal)); + out_signature->argument_count = signature.argument_count(); + out_signature->result_count = signature.result_count(); + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::rt::Policy +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t iree_rt_policy_create( + iree_allocator_t allocator, iree_rt_policy_t** out_policy) { + IREE_TRACE_SCOPE0("iree_rt_policy_create"); + + if (!out_policy) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_policy = nullptr; + + auto policy = make_ref<Policy>(); + + *out_policy = reinterpret_cast<iree_rt_policy_t*>(policy.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_policy_retain(iree_rt_policy_t* policy) { + IREE_TRACE_SCOPE0("iree_rt_policy_retain"); + auto* handle = reinterpret_cast<Policy*>(policy); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_policy_release(iree_rt_policy_t* policy) { + IREE_TRACE_SCOPE0("iree_rt_policy_release"); + auto* handle = reinterpret_cast<Policy*>(policy); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::rt::Context +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create( + iree_rt_instance_t* instance, iree_rt_policy_t* policy, + iree_allocator_t allocator, iree_rt_context_t** out_context) { + IREE_TRACE_SCOPE0("iree_rt_context_create"); + + if (!out_context) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_context = nullptr; + + if (!instance || !policy) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto context = + make_ref<Context>(add_ref(reinterpret_cast<Instance*>(instance)), + add_ref(reinterpret_cast<Policy*>(policy))); + + *out_context = reinterpret_cast<iree_rt_context_t*>(context.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_retain(iree_rt_context_t* context) { + IREE_TRACE_SCOPE0("iree_rt_context_retain"); + auto* handle = reinterpret_cast<Context*>(context); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_release(iree_rt_context_t* context) { + IREE_TRACE_SCOPE0("iree_rt_context_release"); + auto* handle = reinterpret_cast<Context*>(context); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT int32_t IREE_API_CALL +iree_rt_context_id(const iree_rt_context_t* context) { + IREE_TRACE_SCOPE0("iree_rt_context_id"); + const auto* handle = reinterpret_cast<const Context*>(context); + CHECK(handle) << "NULL context handle"; + return handle->id(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules( + iree_rt_context_t* context, iree_rt_module_t** modules, + iree_host_size_t module_count) { + IREE_TRACE_SCOPE0("iree_rt_context_register_modules"); + auto* handle = reinterpret_cast<Context*>(context); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + if (module_count && !modules) { + return IREE_STATUS_INVALID_ARGUMENT; + } + for (size_t i = 0; i < module_count; ++i) { + auto* module = reinterpret_cast<Module*>(modules[i]); + if (!module) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR(handle->RegisterModule(add_ref(module))); + } + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL +iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context, + iree_string_view_t module_name) { + IREE_TRACE_SCOPE0("iree_rt_context_lookup_module_by_name"); + const auto* handle = reinterpret_cast<const Context*>(context); + CHECK(handle) << "NULL context handle"; + auto module_or = handle->LookupModuleByName( + absl::string_view{module_name.data, module_name.size}); + if (!module_or.ok()) { + return nullptr; + } + return reinterpret_cast<iree_rt_module_t*>(module_or.ValueOrDie()); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function( + const iree_rt_context_t* context, iree_string_view_t full_name, + iree_rt_function_t* out_function) { + IREE_TRACE_SCOPE0("iree_rt_context_resolve_function"); + + if (!out_function) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_function, 0, sizeof(*out_function)); + + const auto* handle = reinterpret_cast<const Context*>(context); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto full_name_view = absl::string_view{full_name.data, full_name.size}; + size_t last_dot = full_name_view.rfind('.'); + if (last_dot == absl::string_view::npos) { + return IREE_STATUS_INVALID_ARGUMENT; + } + auto module_name = full_name_view.substr(0, last_dot); + auto function_name = full_name_view.substr(last_dot + 1); + + iree_rt_module_t* module = iree_rt_context_lookup_module_by_name( + context, iree_string_view_t{module_name.data(), module_name.size()}); + if (!module) { + return IREE_STATUS_NOT_FOUND; + } + + return iree_rt_module_lookup_function_by_name( + module, IREE_RT_FUNCTION_LINKAGE_EXPORT, + iree_string_view_t{function_name.data(), function_name.size()}, + out_function); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_allocate_device_visible_buffer( + iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, + iree_host_size_t allocation_size, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer) { + IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_visible_buffer"); + + if (!out_buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_buffer, 0, sizeof(*out_buffer)); + + const auto* handle = reinterpret_cast<const Context*>(context); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } else if (!allocation_size) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + // TODO(benvanik): reroute to context based on current policy. + auto* device_manager = handle->instance()->device_manager(); + IREE_API_ASSIGN_OR_RETURN(auto device_placement, + device_manager->ResolvePlacement({})); + IREE_API_ASSIGN_OR_RETURN(auto buffer, + device_manager->AllocateDeviceVisibleBuffer( + static_cast<hal::BufferUsage>(buffer_usage), + allocation_size, {device_placement})); + + *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_allocate_device_local_buffer( + iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, + iree_host_size_t allocation_size, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer) { + IREE_TRACE_SCOPE0("iree_rt_context_allocate_device_local_buffer"); + + if (!out_buffer) { + return IREE_STATUS_INVALID_ARGUMENT; + } + std::memset(out_buffer, 0, sizeof(*out_buffer)); + + const auto* handle = reinterpret_cast<const Context*>(context); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } else if (!allocation_size) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + // TODO(benvanik): reroute to context based on current policy. + auto* device_manager = handle->instance()->device_manager(); + IREE_API_ASSIGN_OR_RETURN(auto device_placement, + device_manager->ResolvePlacement({})); + IREE_API_ASSIGN_OR_RETURN(auto buffer, + device_manager->AllocateDeviceLocalBuffer( + static_cast<hal::BufferUsage>(buffer_usage), + allocation_size, {device_placement})); + + *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release()); + + return IREE_STATUS_OK; +} + +//===----------------------------------------------------------------------===// +// iree::rt::Invocation +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create( + iree_rt_context_t* context, iree_rt_function_t* function, + iree_rt_policy_t* policy, + const iree_rt_invocation_dependencies_t* dependencies, + iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count, + iree_hal_buffer_view_t** results, iree_host_size_t result_count, + iree_allocator_t allocator, iree_rt_invocation_t** out_invocation) { + IREE_TRACE_SCOPE0("iree_rt_invocation_create"); + + if (!out_invocation) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_invocation = nullptr; + + if (!context || !function || !function->module) { + return IREE_STATUS_INVALID_ARGUMENT; + } else if (dependencies && + (dependencies->invocation_count && !dependencies->invocations)) { + return IREE_STATUS_INVALID_ARGUMENT; + } else if ((argument_count && !arguments) || (result_count && !results)) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + // TODO(benvanik): unwrap without needing to retain here. + absl::InlinedVector<ref_ptr<Invocation>, 4> dependent_invocations; + if (dependencies) { + dependent_invocations.resize(dependencies->invocation_count); + for (int i = 0; i < dependencies->invocation_count; ++i) { + dependent_invocations[i] = + add_ref(reinterpret_cast<Invocation*>(dependencies->invocations[i])); + } + } + + // TODO(benvanik): unwrap without needing to retain here. + absl::InlinedVector<hal::BufferView, 8> argument_views(argument_count); + for (int i = 0; i < argument_count; ++i) { + const auto* api_buffer_view = + reinterpret_cast<const hal::iree_hal_buffer_view*>(arguments[i]); + if (!api_buffer_view) { + return IREE_STATUS_INVALID_ARGUMENT; + } + argument_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer), + api_buffer_view->impl.shape, + api_buffer_view->impl.element_size}; + } + + // TODO(benvanik): unwrap without needing to retain here. + absl::InlinedVector<hal::BufferView, 8> result_views(result_count); + for (int i = 0; i < result_count; ++i) { + const auto* api_buffer_view = + reinterpret_cast<const hal::iree_hal_buffer_view*>(results[i]); + if (api_buffer_view) { + result_views[i] = hal::BufferView{add_ref(api_buffer_view->impl.buffer), + api_buffer_view->impl.shape, + api_buffer_view->impl.element_size}; + } + } + + IREE_API_ASSIGN_OR_RETURN( + auto invocation, + Invocation::Create( + add_ref(reinterpret_cast<Context*>(context)), + Function{reinterpret_cast<Module*>(function->module), + static_cast<Function::Linkage>(function->linkage), + function->ordinal}, + add_ref(reinterpret_cast<Policy*>(policy)), + std::move(dependent_invocations), std::move(argument_views), + result_views.empty() + ? absl::optional<absl::InlinedVector<hal::BufferView, 8>>( + absl::nullopt) + : std::move(result_views))); + + *out_invocation = + reinterpret_cast<iree_rt_invocation_t*>(invocation.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_retain(iree_rt_invocation_t* invocation) { + IREE_TRACE_SCOPE0("iree_rt_invocation_retain"); + auto* handle = reinterpret_cast<Invocation*>(invocation); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->AddReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_release(iree_rt_invocation_t* invocation) { + IREE_TRACE_SCOPE0("iree_rt_invocation_release"); + auto* handle = reinterpret_cast<Invocation*>(invocation); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + handle->ReleaseReference(); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_query_status(iree_rt_invocation_t* invocation) { + IREE_TRACE_SCOPE0("iree_rt_invocation_query_status"); + auto* handle = reinterpret_cast<Invocation*>(invocation); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR(handle->QueryStatus()); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results( + iree_rt_invocation_t* invocation, iree_host_size_t result_capacity, + iree_allocator_t allocator, iree_hal_buffer_view_t** out_results, + iree_host_size_t* out_result_count) { + IREE_TRACE_SCOPE0("iree_rt_invocation_consume_results"); + + if (!out_result_count) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_result_count = 0; + if (!out_results) { + std::memset(out_results, 0, + sizeof(iree_hal_buffer_view_t*) * result_capacity); + } + + auto* handle = reinterpret_cast<Invocation*>(invocation); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + const auto& function = handle->function(); + int32_t result_count = function.signature().result_count(); + *out_result_count = result_count; + if (!out_results) { + return IREE_STATUS_OK; + } else if (result_capacity < result_count) { + return IREE_STATUS_OUT_OF_RANGE; + } + + IREE_API_ASSIGN_OR_RETURN(auto results, handle->ConsumeResults()); + iree_status_t status = IREE_STATUS_OK; + int i = 0; + for (i = 0; i < results.size(); ++i) { + iree_shape_t shape; + status = ToApiShape(results[i].shape, &shape); + if (status != IREE_STATUS_OK) break; + status = iree_hal_buffer_view_create( + reinterpret_cast<iree_hal_buffer_t*>(results[i].buffer.get()), shape, + results[i].element_size, allocator, &out_results[i]); + if (status != IREE_STATUS_OK) break; + } + if (status != IREE_STATUS_OK) { + // Release already-retained buffer views on failure. + for (int j = 0; j < i; ++j) { + iree_hal_buffer_view_release(out_results[j]); + out_results[j] = nullptr; + } + } + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await( + iree_rt_invocation_t* invocation, iree_time_t deadline) { + IREE_TRACE_SCOPE0("iree_rt_invocation_await"); + auto* handle = reinterpret_cast<Invocation*>(invocation); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR(handle->Await(ToAbslTime(deadline))); + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_abort(iree_rt_invocation_t* invocation) { + IREE_TRACE_SCOPE0("iree_rt_invocation_abort"); + auto* handle = reinterpret_cast<Invocation*>(invocation); + if (!handle) { + return IREE_STATUS_INVALID_ARGUMENT; + } + IREE_API_RETURN_IF_ERROR(handle->Abort()); + return IREE_STATUS_OK; +} + +} // namespace rt +} // namespace iree
diff --git a/rt/api.h b/rt/api.h new file mode 100644 index 0000000..5bc57c4 --- /dev/null +++ b/rt/api.h
@@ -0,0 +1,400 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// See iree/base/api.h for documentation on the API conventions used. + +#ifndef IREE_RT_API_H_ +#define IREE_RT_API_H_ + +#include <stdint.h> + +#include "base/api.h" +#include "hal/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// Types and Enums +//===----------------------------------------------------------------------===// + +typedef struct iree_rt_instance iree_rt_instance_t; +typedef struct iree_rt_context iree_rt_context_t; +typedef struct iree_rt_policy iree_rt_policy_t; +typedef struct iree_rt_module iree_rt_module_t; +typedef struct iree_rt_invocation iree_rt_invocation_t; + +// Describes the type of a function reference. +typedef enum { + // Function is internal to the module and may not be reflectable. + IREE_RT_FUNCTION_LINKAGE_INTERNAL = 0, + // Function is an import from another module. + IREE_RT_FUNCTION_LINKAGE_IMPORT = 1, + // Function is an export from the module. + IREE_RT_FUNCTION_LINKAGE_EXPORT = 2, +} iree_rt_function_linkage_t; + +// A function reference that can be used with the iree_rt_function_* methods. +// These should be treated as opaque and the accessor functions should be used +// instead. +typedef struct { + // Module the function is contained within. + iree_rt_module_t* module; + // Linkage of the function. Note that IREE_RT_FUNCTION_LINKAGE_INTERNAL + // functions may be missing reflection information. + iree_rt_function_linkage_t linkage; + // Ordinal within the module in the linkage scope. + int32_t ordinal; +} iree_rt_function_t; + +// Describes the expected calling convention and arguments/results of a +// function. +typedef struct { + // Total number of arguments to the function. + int32_t argument_count; + // Total number of results from the function. + int32_t result_count; +} iree_rt_function_signature_t; + +// Describes the imports, exports, and capabilities of a module. +typedef struct { + // Total number of imported functions. + int32_t import_function_count; + // Total number of exported functions. + int32_t export_function_count; + // Total number of internal functions, if debugging info is present and they + // can be queried. + int32_t internal_function_count; + // Total number of state block resource slots consumed. + int32_t state_slot_count; +} iree_rt_module_signature_t; + +// Dependency information used to order invocations. +typedef struct { + // Prior invocations that must complete before the new invocation begins. + iree_rt_invocation_t** invocations; + iree_host_size_t invocation_count; + + // TODO(benvanik): wait semaphores/importing. +} iree_rt_invocation_dependencies_t; + +// Defines an external module that can be used to reflect and execute functions. +// Modules must be thread-safe as lookups and executions may occur in any order +// from any thread. +// +// Modules will have their resolve_imports function called upon registration +// with a context and may use the provided resolver to find imported functions. +typedef struct { + // User-defined pointer passed to all functions. + void* self; + // Destroys |self| when all references to the module have been released. + iree_status_t(IREE_API_PTR* destroy)(void* self); + // Returns the name of the module (used during resolution). + iree_string_view_t(IREE_API_PTR* name)(void* self); + // Sets |out_module_signature| to the reflected signature of the module. + iree_rt_module_signature_t(IREE_API_PTR* signature)(void* self); + // Sets |out_function| to a resolved function by ordinal, if found. + iree_status_t(IREE_API_PTR* lookup_function_by_ordinal)( + void* self, iree_rt_function_linkage_t linkage, int32_t ordinal, + iree_rt_function_t* out_function); + // Sets |out_function| to a resolved function by name, if found. + iree_status_t(IREE_API_PTR* lookup_function_by_name)( + void* self, iree_rt_function_linkage_t linkage, iree_string_view_t name, + iree_rt_function_t* out_function); + // Sets |out_name| to the name of the function with the given ordinal, if + // found. + iree_status_t(IREE_API_PTR* get_function_name)( + void* self, iree_rt_function_linkage_t linkage, int32_t ordinal, + iree_string_view_t* out_name); + // Sets |out_signature| to the reflected signature of the given + // function, if found. + iree_status_t(IREE_API_PTR* get_function_signature)( + void* self, iree_rt_function_linkage_t linkage, int32_t ordinal, + iree_rt_function_signature_t* out_signature); +} iree_rt_external_module_t; + +//===----------------------------------------------------------------------===// +// iree::rt::Instance +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Creates a new instance. This should be shared with all contexts in an +// application to ensure that resources are tracked properly and threads are +// managed correctly. +// |out_instance| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_create( + iree_allocator_t allocator, iree_rt_instance_t** out_instance); + +// Retains the given |instance| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_instance_retain(iree_rt_instance_t* instance); + +// Releases the given |instance| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_instance_release(iree_rt_instance_t* instance); + +// TEMPORARY: until policies and placement are performed this can be used to +// explicitly create and register drivers by name. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_instance_register_driver_ex( + iree_rt_instance_t* instance, iree_string_view_t driver_name); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::rt::Module +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Creates a module with an external backing implementation. +// The provided |external_module| definition will be used to query the module +// state as needed. No caching occurs within the implementation to allow calls +// to return different values per-invocation. +// +// |out_module| must be released by the caller. +// iree_rt_external_module_t::destroy is called when the last reference to the +// iree_rt_module_t is released. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_module_create_external( + iree_rt_external_module_t impl, iree_allocator_t allocator, + iree_rt_module_t** out_module); + +// Retains the given |module| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_retain(iree_rt_module_t* module); + +// Releases the given |module| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_release(iree_rt_module_t* module); + +// Returns the name of the module. +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_rt_module_name(const iree_rt_module_t* module); + +// Sets |out_function| to a function with |ordinal| in the given linkage or +// returns IREE_STATUS_NOT_FOUND. The function reference is valid for the +// lifetime of |module|. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_lookup_function_by_ordinal(iree_rt_module_t* module, + iree_rt_function_linkage_t linkage, + int32_t ordinal, + iree_rt_function_t* out_function); + +// Sets |out_function| to a function with |name| in the given linkage or returns +// IREE_STATUS_NOT_FOUND. The function reference is valid for the lifetime of +// |module|. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_module_lookup_function_by_name(iree_rt_module_t* module, + iree_rt_function_linkage_t linkage, + iree_string_view_t name, + iree_rt_function_t* out_function); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::rt::Function +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Returns the name of the function as exported from the module. +IREE_API_EXPORT iree_string_view_t IREE_API_CALL +iree_rt_function_name(const iree_rt_function_t* function); + +// Sets |out_function_signature| to the reflected signature of the function. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_function_signature(const iree_rt_function_t* function, + iree_rt_function_signature_t* out_signature); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::rt::Policy +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// TODO(benvanik): define policies. For now they are no-ops. +IREE_API_EXPORT iree_status_t iree_rt_policy_create( + iree_allocator_t allocator, iree_rt_policy_t** out_policy); + +// Retains the given |policy| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_policy_retain(iree_rt_policy_t* policy); + +// Releases the given |policy| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_policy_release(iree_rt_policy_t* policy); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::rt::Context +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Creates a new context that uses the given |instance| for device management. +// |out_context| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_create( + iree_rt_instance_t* instance, iree_rt_policy_t* policy, + iree_allocator_t allocator, iree_rt_context_t** out_context); + +// Retains the given |context| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_retain(iree_rt_context_t* context); + +// Releases the given |context| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_release(iree_rt_context_t* context); + +// Returns a process-unique ID for the |context|. +IREE_API_EXPORT int32_t IREE_API_CALL +iree_rt_context_id(const iree_rt_context_t* context); + +// Registers a list of modules with the context and resolves imports. +// The modules will be retained by the context until destruction. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_register_modules( + iree_rt_context_t* context, iree_rt_module_t** modules, + iree_host_size_t module_count); + +// Returns a reference to the module registered with the given name or nullptr +// if not found. The caller must retain the returned module if they want to +// continue using it. +IREE_API_EXPORT iree_rt_module_t* IREE_API_CALL +iree_rt_context_lookup_module_by_name(const iree_rt_context_t* context, + iree_string_view_t module_name); + +// Sets |out_function| to to an exported function with the fully-qualified name +// of |full_name| or returns IREE_STATUS_NOT_FOUND. The function reference is +// valid for the lifetime of |context|. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_context_resolve_function( + const iree_rt_context_t* context, iree_string_view_t full_name, + iree_rt_function_t* out_function); + +// Allocates a host-local buffer that is optimal for use on the host but is +// usable by the given |device_placements| (at a possible performance +// penalty). The buffer can be used for staging uploads to device-local +// buffers and is useful for times when the buffer will be used more on the +// host than the device. If a buffer never needs to be used with a device +// prefer instead HeapBuffer::Allocate. +// +// Fails if it is not possible to allocate and satisfy all placements for the +// requested |buffer_usage|. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_allocate_device_visible_buffer( + iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, + iree_host_size_t allocation_size, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer); + +// Allocates a device-local buffer that is optimal for use with the given +// |device_placements|. The buffer will not be host-visible and can only be +// used from compatible device queues. +// +// Fails if it is not possible to allocate and satisfy all placements for the +// requested |buffer_usage|. +// |out_buffer| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_context_allocate_device_local_buffer( + iree_rt_context_t* context, iree_hal_buffer_usage_t buffer_usage, + iree_host_size_t allocation_size, iree_allocator_t allocator, + iree_hal_buffer_t** out_buffer); + +#endif // IREE_API_NO_PROTOTYPES + +//===----------------------------------------------------------------------===// +// iree::rt::Invocation +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Creates a new invocation tracking object for invoking the given |function| +// from |context|. |arguments| will be retained until the invocation is made. +// If |dependencies| are provided then the invocation will wait until they are +// resolved before executing. If a |policy| is provided it will override the +// context-level policy. +// +// Optionally |results| may be provided with preallocated buffers that will +// receive the outputs of the invocation. Invocation will fail if they do not +// match expected sizes. +// +// Note that it's possible for the invocation to complete prior to the return of +// this function. Any errors that occur will be set on the invocation and +// callers should query its state prior to assuming it is in-flight. +// +// |out_invocation| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_create( + iree_rt_context_t* context, iree_rt_function_t* function, + iree_rt_policy_t* policy, + const iree_rt_invocation_dependencies_t* dependencies, + iree_hal_buffer_view_t** arguments, iree_host_size_t argument_count, + iree_hal_buffer_view_t** results, iree_host_size_t result_count, + iree_allocator_t allocator, iree_rt_invocation_t** out_invocation); + +// Retains the given |invocation| for the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_retain(iree_rt_invocation_t* invocation); + +// Releases the given |invocation| from the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_release(iree_rt_invocation_t* invocation); + +// Queries the completion status of the invocation. +// Returns one of the following: +// IREE_STATUS_OK: the invocation completed successfully. +// IREE_STATUS_UNAVAILABLE: the invocation has not yet completed. +// IREE_STATUS_CANCELLED: the invocation was cancelled internally. +// IREE_STATUS_ABORTED: the invocation was aborted. +// IREE_STATUS_*: an error occurred during invocation. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_query_status(iree_rt_invocation_t* invocation); + +// Populates |out_results| to the values of the results. +// |result_capacity| defines the number of elements available in |out_results| +// and |out_result_count| will be set with the actual number of results +// available. If |result_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be +// returned wtih the required capacity in |out_result_count|. To only query the +// required capacity |out_results| may be passed as nullptr. +// +// Ownership of returned results will be transferred to the caller and they must +// be released if no longer needed. +// +// Returns errors as with iree_rt_invocation_query_status, for example in the +// case of not-yet-completed or aborted invocations. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_consume_results( + iree_rt_invocation_t* invocation, iree_host_size_t result_capacity, + iree_allocator_t allocator, iree_hal_buffer_view_t** out_results, + iree_host_size_t* out_result_count); + +// Blocks the caller until the invocation completes (successfully or otherwise). +// +// Returns IREE_STATUS_DEADLINE_EXCEEDED if |deadline| elapses before the +// invocation completes and otherwise returns iree_rt_invocation_query_status. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_rt_invocation_await( + iree_rt_invocation_t* invocation, iree_time_t deadline); + +// Attempts to abort the invocation if it is in-flight. +// A no-op if the invocation has already completed. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_rt_invocation_abort(iree_rt_invocation_t* invocation); + +#endif // IREE_API_NO_PROTOTYPES + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_RT_API_H_
diff --git a/rt/context.cc b/rt/context.cc new file mode 100644 index 0000000..051deca --- /dev/null +++ b/rt/context.cc
@@ -0,0 +1,167 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/context.h" + +#include <atomic> + +#include "absl/strings/str_cat.h" +#include "base/status.h" +#include "base/tracing.h" +#include "rt/debug/debug_server.h" +#include "rt/instance.h" +#include "rt/invocation.h" + +namespace iree { +namespace rt { + +namespace { + +int32_t NextUniqueContextId() { + static std::atomic<int32_t> next_id = {0}; + return ++next_id; +} + +} // namespace + +Context::Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy) + : id_(NextUniqueContextId()), + instance_(std::move(instance)), + policy_(std::move(policy)) { + IREE_TRACE_SCOPE("Context::ctor", int32_t)(id_); + instance_->RegisterContext(this); +} + +Context::~Context() { + IREE_TRACE_SCOPE("Context::dtor", int32_t)(id_); + instance_->UnregisterContext(this); +} + +std::string Context::DebugStringShort() const { + return absl::StrCat("context_", id_); +} + +Status Context::RegisterModule(ref_ptr<Module> module) { + IREE_TRACE_SCOPE0("Context::RegisterModule"); + + // Ensure no conflicts in naming - we don't support shadowing. + for (const auto& existing_module : modules_) { + if (existing_module->name() == module->name()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Module '" << module->name() + << "' has already been registered in the context"; + } + } + + // Try resolving prior to actually registering; if we can't resolve an import + // then we want to fail the entire registration. + ASSIGN_OR_RETURN(auto import_table, ResolveImports(module.get())); + + auto* debug_server = instance_->debug_server(); + if (debug_server) { + CHECK_OK(debug_server->RegisterContextModule(this, module.get())); + } + + modules_.push_back(std::move(module)); + module_import_tables_.push_back(std::move(import_table)); + return OkStatus(); +} + +StatusOr<ModuleImportTable> Context::ResolveImports(Module* module) { + IREE_TRACE_SCOPE0("Context::ResolveImports"); + + int32_t import_count = module->signature().import_function_count(); + ModuleImportTable import_table; + import_table.first = module; + import_table.second.resize(import_count); + + for (int32_t i = 0; i < import_count; ++i) { + ASSIGN_OR_RETURN(auto import_function_name, + module->GetFunctionName(Function::Linkage::kImport, i)); + ASSIGN_OR_RETURN(import_table.second[i], + ResolveFunction(import_function_name)); + } + + return import_table; +} + +StatusOr<Module*> Context::LookupModuleByName( + absl::string_view module_name) const { + for (const auto& module : modules_) { + if (module->name() == module_name) { + return module.get(); + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "No module with the name '" << module_name + << "' has been registered"; +} + +StatusOr<const Function> Context::ResolveFunction( + absl::string_view full_name) const { + size_t last_dot = full_name.rfind('.'); + if (last_dot == absl::string_view::npos) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "'" << full_name + << "' is not fully qualified (expected 'module.function')"; + } + auto module_name = full_name.substr(0, last_dot); + auto function_name = full_name.substr(last_dot + 1); + ASSIGN_OR_RETURN(auto* module, LookupModuleByName(module_name)); + return module->LookupFunctionByName(Function::Linkage::kExport, + function_name); +} + +StatusOr<const Function> Context::ResolveImport(const Module* module, + int32_t ordinal) const { + for (const auto& import_table_ref : module_import_tables_) { + if (import_table_ref.first == module) { + const auto& import_table = import_table_ref.second; + if (ordinal >= import_table.size()) { + return NotFoundErrorBuilder(IREE_LOC) + << "Import ordinal " << ordinal + << " out of bounds of import table (" << import_table.size() + << ")"; + } + return import_table[ordinal]; + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "Import ordinal " << ordinal << " not found"; +} + +void Context::RegisterInvocation(Invocation* invocation) { + { + absl::MutexLock lock(&invocations_mutex_); + invocations_.push_back(invocation); + } + auto* debug_server = instance_->debug_server(); + if (debug_server) { + CHECK_OK(debug_server->RegisterInvocation(invocation)); + } +} + +void Context::UnregisterInvocation(Invocation* invocation) { + auto* debug_server = instance_->debug_server(); + if (debug_server) { + CHECK_OK(debug_server->UnregisterInvocation(invocation)); + } + { + absl::MutexLock lock(&invocations_mutex_); + invocations_.erase(invocation); + } +} + +} // namespace rt +} // namespace iree
diff --git a/rt/context.h b/rt/context.h new file mode 100644 index 0000000..ecda38b --- /dev/null +++ b/rt/context.h
@@ -0,0 +1,119 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_CONTEXT_H_ +#define IREE_RT_CONTEXT_H_ + +#include <ostream> + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "base/intrusive_list.h" +#include "base/ref_ptr.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "rt/invocation.h" +#include "rt/module.h" +#include "rt/policy.h" + +namespace iree { +namespace rt { + +class Instance; + +using ModuleImportTable = std::pair<Module*, std::vector<Function>>; + +// An isolated execution context. +// Effectively a sandbox where modules can be loaded and run with restricted +// visibility and where they can maintain state. +// +// Modules have imports resolved automatically when registered by searching +// existing modules registered within the context and load order is used for +// resolution. For example, target-specific modules should be loaded prior to +// generic modules that may import functions defined there and if a function is +// not available in the target-specific modules the fallback provided by the +// generic module will be used. +// +// Thread-compatible and must be externally synchronized. +class Context final : public RefObject<Context> { + public: + Context(ref_ptr<Instance> instance, ref_ptr<Policy> policy); + ~Context(); + + // A process-unique ID for the context. + int32_t id() const { return id_; } + + // Instance this context uses for shared resources. + const ref_ptr<Instance>& instance() const { return instance_; } + + // A short human-readable name for the context. + std::string DebugStringShort() const; + + // A list of modules registered with the context. + absl::Span<const ref_ptr<Module>> modules() const { + return absl::MakeConstSpan(modules_); + } + + // Registers a new module with the context. + // Imports from the module will be resolved using the existing modules in the + // context. The module will be retained by the context until destruction. + Status RegisterModule(ref_ptr<Module> module); + + // Looks up a module by name. + StatusOr<Module*> LookupModuleByName(absl::string_view module_name) const; + + // Resolves an exported function by fully-qualified name. The function + // reference is valid for the lifetime of the context. + StatusOr<const Function> ResolveFunction(absl::string_view full_name) const; + + // Resolves an imported function by import ordinal. The function reference is + // valid for the lifetime of the context. + StatusOr<const Function> ResolveImport(const Module* module, + int32_t ordinal) const; + + private: + // Resolves imports for the given module. + StatusOr<ModuleImportTable> ResolveImports(Module* module); + + friend class Invocation; + void RegisterInvocation(Invocation* invocation); + void UnregisterInvocation(Invocation* invocation); + + int32_t id_; + ref_ptr<Instance> instance_; + ref_ptr<Policy> policy_; + + absl::InlinedVector<ref_ptr<Module>, 4> modules_; + absl::InlinedVector<ModuleImportTable, 4> module_import_tables_; + + absl::Mutex invocations_mutex_; + IntrusiveList<Invocation, offsetof(Invocation, context_list_link_)> + invocations_ ABSL_GUARDED_BY(invocations_mutex_); + + friend class Instance; + IntrusiveListLink instance_list_link_; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Context& context) { + stream << context.DebugStringShort(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_CONTEXT_H_
diff --git a/rt/debug/BUILD b/rt/debug/BUILD new file mode 100644 index 0000000..8f12b5c --- /dev/null +++ b/rt/debug/BUILD
@@ -0,0 +1,192 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +# TODO(benvanik): re-enable debugger after refactoring. +# cc_library( +# name = "debug_client", +# srcs = ["debug_client.cc"], +# hdrs = ["debug_client.h"], +# deps = [ +# ":debug_client_interface", +# ":debug_client_tcp", # build-cleaner: keep +# "@com_google_absl//absl/container:flat_hash_map", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/types:optional", +# "@com_google_absl//absl/types:span", +# "///base:source_location", +# "///base:status", +# "///schemas", +# ], +# ) +# +# cc_library( +# name = "debug_client_interface", +# hdrs = ["debug_client.h"], +# deps = [ +# "@com_google_absl//absl/container:flat_hash_map", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/types:optional", +# "@com_google_absl//absl/types:span", +# "///base:status", +# "///schemas", +# ], +# ) +# +# cc_library( +# name = "debug_client_tcp", +# srcs = ["debug_client_tcp.cc"], +# deps = [ +# ":debug_client_interface", +# ":debug_tcp_util", +# "@com_google_absl//absl/container:flat_hash_map", +# "@com_google_absl//absl/memory", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/types:span", +# "@com_github_google_flatbuffers//:flatbuffers", +# "///base:flatbuffer_util", +# "///base:status", +# "///rt", +# "///schemas", +# ], +# ) +# +# cc_library( +# name = "debug_server", +# hdrs = ["debug_server.h"], +# deps = [ +# ":debug_server_interface", +# "//third_party/flatbuffers:flatbuffers", +# "///schemas", +# "///base:status", +# ] + select({ +# "//:debug": [":debug_server_tcp"], +# "//conditions:default": [":debug_server_disabled"], +# }), +# ) + +cc_library( + name = "debug_server", + hdrs = ["debug_server.h"], + deps = [ + ":debug_server_disabled", + ":debug_server_interface", + "///base:status", + ], +) + +cc_library( + name = "debug_server_interface", + hdrs = ["debug_server.h"], + deps = ["///base:status"], +) + +cc_library( + name = "debug_server_disabled", + srcs = ["debug_server_disabled.cc"], + deps = [ + ":debug_server_interface", + "@com_google_absl//absl/memory", + ], +) + +# TODO(benvanik): re-enable debugger after refactoring. +# cc_library( +# name = "debug_server_tcp", +# srcs = ["debug_server_tcp.cc"], +# deps = [ +# ":debug_server_interface", +# ":debug_service", +# ":debug_tcp_util", +# "@com_google_absl//absl/base:core_headers", +# "@com_google_absl//absl/memory", +# "@com_google_absl//absl/synchronization", +# "@com_github_google_flatbuffers//:flatbuffers", +# "///base:status", +# "///schemas", +# ], +# ) + +cc_library( + name = "debug_server_flags", + srcs = ["debug_server_flags.cc"], + hdrs = ["debug_server_flags.h"], + deps = [ + ":debug_server", + "///base:memory", + "///base:status", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + ], +) + +# TODO(benvanik): re-enable debugger after refactoring. +# cc_library( +# name = "debug_server_flags", +# srcs = ["debug_server_flags.cc"], +# hdrs = ["debug_server_flags.h"], +# copts = select({ +# "//:debug": [ +# "-DIREE_DEBUG_EMBEDDED_APP_PRESENT=1", +# ], +# "//conditions:default": [], +# }), +# deps = [ +# ":debug_server", +# "@com_google_absl//absl/flags:flag", +# "@com_google_absl//absl/strings", +# "///base:memory", +# "///base:status", +# ] + select({ +# "//:debug": [ +# "///tools/debugger:debug_app_embedded", +# "//third_party/GL/native:EGL", # build-cleaner: keep +# "//third_party/GL/native:GLESv2", # build-cleaner: keep +# ], +# "//conditions:default": [], +# }), +# ) +# +# cc_library( +# name = "debug_service", +# srcs = ["debug_service.cc"], +# hdrs = ["debug_service.h"], +# deps = [ +# ":debug_session", +# "@com_google_absl//absl/base:core_headers", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/synchronization", +# "@com_github_google_flatbuffers//:flatbuffers", +# "///base:flatbuffer_util", +# "///base:source_location", +# "///base:status", +# "///rt", +# "///schemas", +# "///schemas:reflection_data", +# ], +# ) +# +# cc_library( +# name = "debug_session", +# srcs = ["debug_session.cc"], +# hdrs = ["debug_session.h"], +# deps = [ +# "@com_google_absl//absl/base:core_headers", +# "@com_google_absl//absl/synchronization", +# "///base:source_location", +# "///base:status", +# "///rt", +# "///schemas", +# ], +# ) +# +# cc_library( +# name = "debug_tcp_util", +# hdrs = ["debug_tcp_util.h"], +# deps = [ +# "@com_github_google_flatbuffers//:flatbuffers", +# "///base:status", +# "///schemas", +# ], +# )
diff --git a/iree/rt/debug/CMakeLists.txt b/rt/debug/CMakeLists.txt similarity index 100% rename from iree/rt/debug/CMakeLists.txt rename to rt/debug/CMakeLists.txt
diff --git a/rt/debug/debug_adapter.h b/rt/debug/debug_adapter.h new file mode 100644 index 0000000..f9c149f --- /dev/null +++ b/rt/debug/debug_adapter.h
@@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DEBUG_ADAPTER_H_ +#define IREE_RT_DEBUG_ADAPTER_H_ + +#include <functional> + +#include "base/status.h" +#include "rt/invocation.h" + +namespace iree { +namespace rt { +namespace debug { + +struct StepTarget { + // TODO(benvanik): step target info (matching RPC message). + // module / function / offset + // relative to current: once, out, return, etc +}; + +// TODO(benvanik): move to fiber base. +// Interface for debugging invocations. +// This is only accessible in debug builds where such features are compiled in. +class DebugAdapter { + public: + // Called when an invocation completes suspending (in response to a Suspend or + // Step request). The |suspend_status| will indicate if the suspension was + // successful. + using SuspendCallback = std::function<void(Status suspend_status)>; + + // Returns true if the invocation is suspended. + // This only returns true if the invocation has been requested to suspend with + // Suspend and the runtime has acked the suspend. Once suspended (and until + // resumed) invocation state will not change and may be observed from any + // thread. + // + // Safe to call from any thread. + bool IsSuspended(Invocation* invocation); + + // Suspends the invocation at the next possible chance. + // + // Fibers have a suspension depth and each call to Suspend must be matched + // with a call to Resume. Fibers will only resume excution when all prior + // Suspend calls have their matching Resume called. + // + // Optionally callers may provide a |suspend_callback| that will be called + // from a random thread when the invocation is suspended (or fails to + // suspend). + // + // Safe to call from any thread. + // Returns StatusCode::kUnavailable if debugging is not supported. + Status Suspend(ref_ptr<Invocation> invocation, + SuspendCallback suspend_callback = nullptr); + + // Resumes the invocation if it is suspended (or cancels a pending suspend). + // This may wake threads if they are currently waiting on the invocation to + // execute. + // + // Safe to call from any thread. + // Returns StatusCode::kUnavailable if debugging is not supported. + Status Resume(Invocation* invocation); + + // Steps invocation execution. + // This will attempt to resume the invocation and will complete + // asynchronously. Upon returning the invocation should be assumed resumed and + // callers must query is_suspended to wait until the invocation suspends + // again. Optionally callers may provide a |suspend_callback| that will be + // called from a random thread when the invocation is suspended (or fails to + // suspend). + // + // Safe to call from any thread while the invocation is suspended. + // Returns StatusCode::kUnavailable if debugging is not supported and + // StatusCode::kFailedPrecondition if the invocation is not suspended. + Status Step(ref_ptr<Invocation> invocation, StepTarget step_target, + SuspendCallback suspend_callback = nullptr); + + // Returns a call stack that can be used to query and manipulate the + // invocation state. The behaviors supported depend on the stack frames and + // the backend support and may be conditionally enabled via compile-time or + // run-time flags. + // + // Safe to call from any thread while the invocation is suspended. + // Returns StatusCode::kUnavailable if debugging is not supported and + // StatusCode::kFailedPrecondition if the invocation is not suspended. + // The returned stack will be invalidated when the invocation is stepped or + // resumed. + StatusOr<Stack> CaptureStack(Invocation* invocation); +}; + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_ADAPTER_H_
diff --git a/rt/debug/debug_client.cc b/rt/debug/debug_client.cc new file mode 100644 index 0000000..8ee2cbf --- /dev/null +++ b/rt/debug/debug_client.cc
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/debug/debug_client.h" + +#include "base/source_location.h" +#include "base/status.h" + +namespace iree { +namespace rt { +namespace debug { + +Status DebugClient::GetFunction( + std::string module_name, std::string function_name, + std::function<void(StatusOr<RemoteFunction*> function)> callback) { + return ResolveFunction( + module_name, function_name, + [this, module_name, callback](StatusOr<int> function_ordinal) { + if (!function_ordinal.ok()) { + callback(function_ordinal.status()); + return; + } + auto status = + GetFunction(module_name, function_ordinal.ValueOrDie(), callback); + if (!status.ok()) { + callback(std::move(status)); + } + }); +} + +Status DebugClient::StepInvocationOver(const RemoteInvocation& invocation, + std::function<void()> callback) { + // TODO(benvanik): implement bytecode stepping search. + // int bytecode_offset = 0; + // return StepInvocationToOffset(invocation, bytecode_offset, + // std::move(callback)); + return UnimplementedErrorBuilder(IREE_LOC) + << "StepInvocationOver not yet implemented"; +} + +Status DebugClient::StepInvocationOut(const RemoteInvocation& invocation, + std::function<void()> callback) { + // TODO(benvanik): implement bytecode stepping search. + // int bytecode_offset = 0; + // return StepInvocationToOffset(invocation, bytecode_offset, + // std::move(callback)); + return UnimplementedErrorBuilder(IREE_LOC) + << "StepInvocationOut not yet implemented"; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_client.h b/rt/debug/debug_client.h new file mode 100644 index 0000000..ac074b8 --- /dev/null +++ b/rt/debug/debug_client.h
@@ -0,0 +1,286 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DEBUG_DEBUG_CLIENT_H_ +#define IREE_RT_DEBUG_DEBUG_CLIENT_H_ + +#include <functional> +#include <memory> + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/status.h" +#include "schemas/debug_service_generated.h" + +namespace iree { +namespace rt { +namespace debug { + +// Remote breakpoint currently active on the server. +class RemoteBreakpoint { + public: + enum class Type { + kBytecodeFunction = 0, + kNativeFunction = 1, + }; + + virtual ~RemoteBreakpoint() = default; + + int id() const { return id_; } + Type type() const { return type_; } + + virtual const std::string& module_name() const = 0; + virtual const std::string& function_name() const = 0; + virtual int function_ordinal() const = 0; + virtual int bytecode_offset() const = 0; + + protected: + explicit RemoteBreakpoint(int id, Type type) : id_(id), type_(type) {} + + private: + int id_; + Type type_; +}; + +class RemoteModule; + +class RemoteFunction { + public: + virtual ~RemoteFunction() = default; + + RemoteModule* module() const { return module_; } + int ordinal() const { return function_ordinal_; } + virtual const std::string& name() const = 0; + + virtual const FunctionDef& def() = 0; + + virtual bool is_loaded() const = 0; + virtual bool CheckLoadedOrRequest() = 0; + + using LoadCallback = std::function<void(StatusOr<RemoteFunction*>)>; + virtual void WhenLoaded(LoadCallback callback) = 0; + + virtual const BytecodeDef* bytecode() = 0; + + protected: + RemoteFunction(RemoteModule* module, int function_ordinal) + : module_(module), function_ordinal_(function_ordinal) {} + + RemoteModule* module_; + int function_ordinal_; +}; + +class RemoteModule { + public: + virtual ~RemoteModule() = default; + + int context_id() const { return context_id_; } + const std::string& name() const { return name_; } + + virtual const ModuleDef& def() = 0; + + virtual bool is_loaded() const = 0; + virtual bool CheckLoadedOrRequest() = 0; + + using LoadCallback = std::function<void(StatusOr<RemoteModule*>)>; + virtual void WhenLoaded(LoadCallback callback) = 0; + + virtual absl::Span<RemoteFunction*> functions() = 0; + + protected: + RemoteModule(int context_id, std::string name) + : context_id_(context_id), name_(std::move(name)) {} + + private: + int context_id_; + std::string name_; +}; + +class RemoteContext { + public: + virtual ~RemoteContext() = default; + + int id() const { return id_; } + + virtual absl::Span<RemoteModule* const> modules() const = 0; + + protected: + explicit RemoteContext(int id) : id_(id) {} + + private: + int id_; +}; + +class RemoteInvocation { + public: + virtual ~RemoteInvocation() = default; + + int id() const { return id_; } + const std::string& name() const { return name_; } + + virtual const rpc::InvocationDefT& def() const = 0; + + protected: + explicit RemoteInvocation(int id) + : id_(id), name_(absl::StrCat("Invocation ", id)) {} + + private: + int id_; + std::string name_; +}; + +// Debugger RPC server client. +// Statefully tracks a DebugServer to provide common client operations and +// memoized queries. +// +// Thread-compatible. Do not use the client from multiple threads concurrently. +// All remote updates of local state are performed by the Poll function. See +// Poll for more details. +class DebugClient { + public: + // Debug event listener interface. + // Event methods will be called from within Poll calls (so on that thread). + // + // When the server posts an event it will mark the client as unready and + // suspend execution of all invocations until MakeReady is used to indicate + // that the client is ready for the server to resume. Each event needs a + // matching MakeReady ack. + // + // Listeners can defer acking if they need to perform additional queries or + // state changes to the server or wait for user interaction. Multiple events + // may come in while unready if there was a series of events pending on the + // server. + class Listener { + public: + virtual ~Listener() = default; + + // Signals that a context has been registered on the server. + virtual Status OnContextRegistered(const RemoteContext& context) = 0; + virtual Status OnContextUnregistered(const RemoteContext& context) = 0; + + // Signals that a module has been loaded into a context on the server. + virtual Status OnModuleLoaded(const RemoteContext& context, + const RemoteModule& module) = 0; + + // Signals that a invocation has been registered on the server. + virtual Status OnInvocationRegistered( + const RemoteInvocation& invocation) = 0; + virtual Status OnInvocationUnregistered( + const RemoteInvocation& invocation) = 0; + + // Signals that a breakpoint has been hit by a invocation on the server. + virtual Status OnBreakpointHit(const RemoteBreakpoint& breakpoint, + const RemoteInvocation& invocation) = 0; + }; + + // Connects to a remote debug service at the provided IP:port. + // The specified |listener| will receive async event notifications. + static StatusOr<std::unique_ptr<DebugClient>> Connect( + absl::string_view service_address, Listener* listener); + + virtual ~DebugClient() = default; + + // Returns true if the client is connected to a service. + // virtual bool is_connected() const = 0; + + // A list of all contexts registered with the server. + virtual absl::Span<RemoteContext* const> contexts() const = 0; + + // A list of all invocations registered with the server. + virtual absl::Span<RemoteInvocation* const> invocations() const = 0; + + // A list of all breakpoints registered with the server. + virtual absl::Span<RemoteBreakpoint* const> breakpoints() const = 0; + + // Resolves a function to a module ordinal. + // This will occur asynchronously and the |callback| will be issued on the + // polling thread. + virtual Status ResolveFunction( + std::string module_name, std::string function_name, + std::function<void(StatusOr<int> function_ordinal)> callback) = 0; + + // Gets a function body instance. + // The provided |callback| will be issued on the polling thread when the + // function is available. + virtual Status GetFunction( + std::string module_name, int function_ordinal, + std::function<void(StatusOr<RemoteFunction*> function)> callback) = 0; + Status GetFunction( + std::string module_name, std::string function_name, + std::function<void(StatusOr<RemoteFunction*> function)> callback); + + // Adds a breakpoint for the given module:function:offset. + // The breakpoint will apply to all contexts with the module loaded. + virtual Status AddFunctionBreakpoint( + std::string module_name, std::string function_name, int offset, + std::function<void(const RemoteBreakpoint& breakpoint)> callback = + nullptr) = 0; + + // Removes a breakpoint from the server. + virtual Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) = 0; + + // Notifies the server that the debug session is ready to continue. + // This must be called once on connection to and in acknowledgement to any + // events posted by the server (read: any call to the Listener::On* methods). + virtual Status MakeReady() = 0; + + // Suspends all invocations running on the server. + virtual Status SuspendAllInvocations() = 0; + + // Resumes all invocations running on the server. + virtual Status ResumeAllInvocations() = 0; + + // Suspends a list of invocations running on the server. Invocations not in + // the provided list will not be suspended, such as new invocations created + // while the request is pending. + virtual Status SuspendInvocations( + absl::Span<RemoteInvocation*> invocations) = 0; + + // Resumes a list of invocations running on the server. + virtual Status ResumeInvocations( + absl::Span<RemoteInvocation*> invocations) = 0; + + // Steps a invocation one bytecode operation. + virtual Status StepInvocation(const RemoteInvocation& invocation, + std::function<void()> callback) = 0; + // Steps a invocation over one bytecode operation, not stopping until it + // completes. + Status StepInvocationOver(const RemoteInvocation& invocation, + std::function<void()> callback); + // Steps a invocation out of the current block. + Status StepInvocationOut(const RemoteInvocation& invocation, + std::function<void()> callback); + // Steps a invocation to a specific bytecode offset within the current + // function. + virtual Status StepInvocationToOffset(const RemoteInvocation& invocation, + int bytecode_offset, + std::function<void()> callback) = 0; + + // TODO(benvanik): profiling modes. + + // Polls for the current state of the debug service and processes incoming + // responses. Must be called as frequently as the UI is desired to update. + // Returns CancelledError when the service is being shutdown/disconnected. + // + // Events on the Listener will be called from within this method. + virtual Status Poll() = 0; +}; + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_DEBUG_CLIENT_H_
diff --git a/rt/debug/debug_client_tcp.cc b/rt/debug/debug_client_tcp.cc new file mode 100644 index 0000000..14f90f7 --- /dev/null +++ b/rt/debug/debug_client_tcp.cc
@@ -0,0 +1,1127 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netdb.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> +#include <cstring> +#include <queue> + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/types/span.h" +#include "base/flatbuffer_util.h" +#include "base/status.h" +#include "flatbuffers/base.h" +#include "flatbuffers/flatbuffers.h" +#include "rt/debug/debug_client.h" +#include "rt/debug/debug_tcp_util.h" +#include "rt/module.h" +#include "schemas/debug_service_generated.h" +#include "schemas/module_def_generated.h" + +namespace iree { +namespace rt { +namespace debug { +namespace { + +using ::flatbuffers::FlatBufferBuilder; + +// Parses a host:port address, with support for the RFC 3986 IPv6 [host]:port +// format. Returns a pair of (hostname, port), with port being 0 if none was +// specified. +// +// Parses: +// foo (port 0) / foo:123 +// 1.2.3.4 (port 0) / 1.2.3.4:123 +// [foo] (port 0) / [foo]:123 +// [::1] (port 0) / [::1]:123 +StatusOr<std::pair<std::string, int>> ParseAddress(absl::string_view address) { + address = absl::StripAsciiWhitespace(address); + absl::string_view hostname; + absl::string_view port_str; + size_t bracket_loc = address.find_last_of(']'); + if (bracket_loc != std::string::npos) { + // Has at least a ]. Let's assume it's mostly right. + if (address.find('[') != 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Mismatched brackets in address: " << address; + } + hostname = address.substr(1, bracket_loc - 1); + port_str = address.substr(bracket_loc + 1); + if (port_str.find(':') == 0) { + port_str.remove_prefix(1); + } + } else { + size_t colon_loc = address.find_last_of(':'); + if (colon_loc != std::string::npos) { + hostname = address.substr(0, colon_loc); + port_str = address.substr(colon_loc + 1); + } else { + hostname = address; + port_str = ""; + } + } + int port = 0; + if (!port_str.empty() && !absl::SimpleAtoi(port_str, &port)) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Unable to parse port '" << port_str << "' from " << address; + } + return std::make_pair(std::string(hostname), port); +} + +class TcpDebugClient final : public DebugClient { + public: + class TcpRemoteBreakpoint : public RemoteBreakpoint { + public: + TcpRemoteBreakpoint(int id, Type type, TcpDebugClient* client) + : RemoteBreakpoint(id, type) {} + + const std::string& module_name() const override { return def_.module_name; } + const std::string& function_name() const override { + return def_.function_name; + } + int function_ordinal() const override { return def_.function_ordinal; } + int bytecode_offset() const override { return def_.bytecode_offset; } + + Status MergeFrom(const rpc::BreakpointDef& breakpoint_def) { + breakpoint_def.UnPackTo(&def_); + return OkStatus(); + } + + private: + rpc::BreakpointDefT def_; + }; + + class TcpRemoteFunction final : public RemoteFunction { + public: + TcpRemoteFunction(RemoteModule* module, int function_ordinal, + const FunctionDef* function_def, TcpDebugClient* client) + : RemoteFunction(module, function_ordinal), + def_(function_def), + client_(client) { + name_ = def_->name() ? std::string(WrapString(def_->name())) : ""; + } + + const std::string& name() const override { return name_; } + + const FunctionDef& def() override { return *def_; } + + bool is_loaded() const override { + return contents_.flatbuffers_buffer.size() > 0; + } + + bool CheckLoadedOrRequest() override { + if (!is_loaded()) { + DemandContents(); + } + return is_loaded(); + } + + void WhenLoaded(LoadCallback callback) override { + if (is_loaded()) { + callback(this); + return; + } + load_callbacks_.push_back(std::move(callback)); + } + + const BytecodeDef* bytecode() override { + CHECK(is_loaded()); + return contents_.bytecode_def; + } + + private: + void DemandContents() { + if (!has_requested_contents_) { + VLOG(2) << "Client " << client_->fd() << ": GetFunction(" + << module()->context_id() << ", " << module()->name() << ", " + << ordinal() << ")"; + FlatBufferBuilder fbb; + rpc::GetFunctionRequestT request; + request.session_id = client_->session_id(); + request.context_id = module()->context_id(); + request.module_name = module()->name(); + request.function_ordinal = ordinal(); + auto status = + client_->IssueRequest<rpc::GetFunctionRequest, + rpc::ResponseUnion::GetFunctionResponse>( + rpc::GetFunctionRequest::Pack(fbb, &request), std::move(fbb), + [this](Status status, + const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + const auto& response = + *response_union.message_as_GetFunctionResponse(); + VLOG(2) << "Client " << client_->fd() << ": GetFunction(" + << module()->context_id() << ", " << module()->name() + << ", " << ordinal() << ") = ..."; + RETURN_IF_ERROR(MergeFrom(response)); + for (auto& callback : load_callbacks_) { + callback(this); + } + load_callbacks_.clear(); + return OkStatus(); + }); + if (!status.ok()) { + LOG(ERROR) << "Failed to request module: " << status; + return; + } + has_requested_contents_ = true; + } + } + + Status MergeFrom(const rpc::GetFunctionResponse& response) { + // Clone and retain the contents. + // TODO(benvanik): find a way to steal to avoid the reserialization. + BytecodeDefT bytecode_def_storage; + response.bytecode()->UnPackTo(&bytecode_def_storage); + ::flatbuffers::FlatBufferBuilder fbb; + fbb.Finish(response.bytecode()->Pack(fbb, &bytecode_def_storage)); + contents_.flatbuffers_buffer = fbb.Release(); + contents_.bytecode_def = ::flatbuffers::GetRoot<BytecodeDef>( + contents_.flatbuffers_buffer.data()); + return OkStatus(); + } + + const FunctionDef* def_; + TcpDebugClient* client_; + std::string name_; + bool has_requested_contents_ = false; + std::vector<LoadCallback> load_callbacks_; + struct { + ::flatbuffers::DetachedBuffer flatbuffers_buffer; + const BytecodeDef* bytecode_def = nullptr; + } contents_; + }; + + class TcpRemoteModule final : public RemoteModule { + public: + TcpRemoteModule(int context_id, std::string module_name, + TcpDebugClient* client) + : RemoteModule(context_id, std::move(module_name)), client_(client) {} + + const ModuleDef& def() override { + CHECK(is_loaded()); + return *module_file_->root(); + } + + bool is_loaded() const override { return module_file_ != nullptr; } + + bool CheckLoadedOrRequest() override { + if (!is_loaded()) { + DemandModuleDef(); + } + return is_loaded(); + } + + void WhenLoaded(LoadCallback callback) override { + if (is_loaded()) { + callback(this); + return; + } + load_callbacks_.push_back(std::move(callback)); + } + + absl::Span<RemoteFunction*> functions() override { + auto* module_def = DemandModuleDef(); + if (!module_def) return {}; + return {reinterpret_cast<RemoteFunction**>(functions_.data()), + functions_.size()}; + } + + private: + const ModuleDef* DemandModuleDef() { + if (module_file_) { + return module_file_->root(); + } + if (!has_requested_module_def_) { + VLOG(2) << "Client " << client_->fd() << ": GetModule(" << context_id() + << ", " << name() << ")"; + FlatBufferBuilder fbb; + rpc::GetModuleRequestT request; + request.session_id = client_->session_id(); + request.context_id = context_id(); + request.module_name = name(); + auto status = + client_->IssueRequest<rpc::GetModuleRequest, + rpc::ResponseUnion::GetModuleResponse>( + rpc::GetModuleRequest::Pack(fbb, &request), std::move(fbb), + [this](Status status, + const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + const auto& response = + *response_union.message_as_GetModuleResponse(); + VLOG(2) << "Client " << client_->fd() << ": GetModule(" + << context_id() << ", " << name() << ") = ..."; + RETURN_IF_ERROR(MergeFrom(response)); + for (auto& callback : load_callbacks_) { + callback(this); + } + load_callbacks_.clear(); + return OkStatus(); + }); + if (!status.ok()) { + LOG(ERROR) << "Failed to request module: " << status; + return nullptr; + } + has_requested_module_def_ = true; + } + return nullptr; + } + + Status MergeFrom(const rpc::GetModuleResponse& response) { + // Clone and retain the module. + // TODO(benvanik): find a way to steal to avoid the reserialization. + ModuleDefT module_def_storage; + response.module_()->UnPackTo(&module_def_storage); + FlatBufferBuilder fbb; + auto module_offs = response.module_()->Pack(fbb, &module_def_storage); + FinishModuleDefBuffer(fbb, module_offs); + ASSIGN_OR_RETURN(auto module_file, + ModuleFile::CreateWithBackingBuffer(fbb.Release())); + + const auto& module_def = module_file->root(); + const auto& function_table = *module_def->function_table(); + functions_.reserve(function_table.functions()->size()); + for (int i = 0; i < function_table.functions()->size(); ++i) { + const auto* function_def = function_table.functions()->Get(i); + functions_.push_back(absl::make_unique<TcpRemoteFunction>( + this, i, function_def, client_)); + } + + module_file_ = std::move(module_file); + return OkStatus(); + } + + TcpDebugClient* client_; + bool has_requested_module_def_ = false; + std::vector<LoadCallback> load_callbacks_; + std::unique_ptr<ModuleFile> module_file_; + std::vector<std::unique_ptr<RemoteFunction>> functions_; + }; + + class TcpRemoteContext final : public RemoteContext { + public: + TcpRemoteContext(int context_id, TcpDebugClient* client) + : RemoteContext(context_id), client_(client) {} + + absl::Span<RemoteModule* const> modules() const override { + return absl::MakeConstSpan(modules_); + } + + Status AddModule(std::unique_ptr<TcpRemoteModule> module) { + modules_.push_back(module.get()); + module_map_.insert({module->name(), std::move(module)}); + return OkStatus(); + } + + Status MergeFrom(const rpc::ContextDef& context_def) { return OkStatus(); } + + private: + TcpDebugClient* client_; + std::vector<RemoteModule*> modules_; + absl::flat_hash_map<std::string, std::unique_ptr<TcpRemoteModule>> + module_map_; + }; + + class TcpRemoteInvocation final : public RemoteInvocation { + public: + TcpRemoteInvocation(int invocation_id, TcpDebugClient* client) + : RemoteInvocation(invocation_id), client_(client) {} + + const rpc::InvocationDefT& def() const override { return def_; } + + Status MergeFrom(const rpc::InvocationDef& invocation_def) { + invocation_def.UnPackTo(&def_); + return OkStatus(); + } + + private: + TcpDebugClient* client_; + rpc::InvocationDefT def_; + }; + + static StatusOr<std::unique_ptr<TcpDebugClient>> Create(int fd, + Listener* listener) { + VLOG(2) << "Client " << fd << ": Setting up socket options..."; + // Disable Nagel's algorithm to ensure we have low latency. + RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(fd, false)); + // Enable keepalive assuming the client is local and this high freq is ok. + RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(fd, true)); + // Linger around for a bit to flush all data. + RETURN_IF_ERROR(tcp::ToggleSocketLinger(fd, true)); + // Disable blocking as we are poll based. + RETURN_IF_ERROR(tcp::ToggleSocketBlocking(fd, false)); + + auto client = absl::make_unique<TcpDebugClient>(fd, listener); + RETURN_IF_ERROR(client->Refresh()); + return client; + } + + TcpDebugClient(int fd, Listener* listener) : fd_(fd), listener_(listener) {} + + ~TcpDebugClient() override { + VLOG(2) << "Client " << fd_ << ": Shutting down session socket..."; + ::shutdown(fd_, SHUT_WR); + VLOG(2) << "Client " << fd_ << ": Closing session socket..."; + ::close(fd_); + VLOG(2) << "Client " << fd_ << ": Closed session socket!"; + fd_ = -1; + } + + int fd() const { return fd_; } + int session_id() const { return session_id_; } + + absl::Span<RemoteContext* const> contexts() const override { + return absl::MakeConstSpan(contexts_); + } + + absl::Span<RemoteInvocation* const> invocations() const override { + return absl::MakeConstSpan(invocations_); + } + + absl::Span<RemoteBreakpoint* const> breakpoints() const override { + return absl::MakeConstSpan(breakpoints_); + } + + // Writes the given typed request message to the given fd by wrapping it in + // a size-prefixed rpc::Request union. + // + // Example: + // FlatBufferBuilder fbb; + // rpc::SuspendInvocationRequestBuilder request(fbb); + // RETURN_IF_ERROR(WriteRequest(fd_, request.Finish(), std::move(fbb))); + template <typename T> + Status WriteRequest(int fd, ::flatbuffers::Offset<T> request_offs, + FlatBufferBuilder fbb) { + rpc::RequestBuilder request_builder(fbb); + request_builder.add_message_type(rpc::RequestUnionTraits<T>::enum_value); + request_builder.add_message(request_offs.Union()); + fbb.FinishSizePrefixed(request_builder.Finish()); + auto write_status = tcp::WriteBuffer(fd, fbb.Release()); + if (shutdown_pending_ && IsUnavailable(write_status)) { + return OkStatus(); + } + return write_status; + } + + Status ResolveFunction( + std::string module_name, std::string function_name, + std::function<void(StatusOr<int> function_ordinal)> callback) override { + VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name << ", " + << function_name << ")"; + FlatBufferBuilder fbb; + rpc::ResolveFunctionRequestT request; + request.session_id = session_id_; + request.module_name = module_name; + request.function_name = function_name; + return IssueRequest<rpc::ResolveFunctionRequest, + rpc::ResponseUnion::ResolveFunctionResponse>( + rpc::ResolveFunctionRequest::Pack(fbb, &request), std::move(fbb), + [this, module_name, function_name, callback]( + Status status, const rpc::Response& response_union) -> Status { + if (status.ok()) { + const auto& response = + *response_union.message_as_ResolveFunctionResponse(); + VLOG(2) << "Client " << fd_ << ": ResolveFunction(" << module_name + << ", " << function_name + << ") = " << response.function_ordinal(); + callback(response.function_ordinal()); + } else { + callback(std::move(status)); + } + return OkStatus(); + }); + } + + Status GetFunction(std::string module_name, int function_ordinal, + std::function<void(StatusOr<RemoteFunction*> function)> + callback) override { + // See if we have the module already. If not, we'll fetch it first. + RemoteModule* target_module = nullptr; + for (auto* context : contexts_) { + for (auto* module : context->modules()) { + if (module->name() == module_name) { + target_module = module; + break; + } + } + if (target_module) break; + } + if (!target_module) { + // TODO(benvanik): fetch contexts first. + return UnimplementedErrorBuilder(IREE_LOC) + << "Demand fetch contexts not yet implemented"; + } + // Found at least one module with the right name. + if (target_module->is_loaded()) { + callback(target_module->functions()[function_ordinal]); + return OkStatus(); + } else { + // Wait until the module completes loading. + target_module->WhenLoaded( + [callback, function_ordinal](StatusOr<RemoteModule*> module_or) { + if (!module_or.ok()) { + callback(module_or.status()); + return; + } + callback(module_or.ValueOrDie()->functions()[function_ordinal]); + }); + return OkStatus(); + } + } + + Status AddFunctionBreakpoint( + std::string module_name, std::string function_name, int offset, + std::function<void(const RemoteBreakpoint& breakpoint)> callback) + override { + VLOG(2) << "Client " << fd_ << ": AddFunctionBreakpoint(" << module_name + << ", " << function_name << ", " << offset << ")"; + FlatBufferBuilder fbb; + + auto breakpoint = absl::make_unique<rpc::BreakpointDefT>(); + breakpoint->module_name = module_name; + breakpoint->function_name = function_name; + breakpoint->function_ordinal = -1; + breakpoint->bytecode_offset = offset; + rpc::AddBreakpointRequestT request; + request.session_id = session_id_; + request.breakpoint = std::move(breakpoint); + return IssueRequest<rpc::AddBreakpointRequest, + rpc::ResponseUnion::AddBreakpointResponse>( + rpc::AddBreakpointRequest::Pack(fbb, &request), std::move(fbb), + [this, callback](Status status, + const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + const auto& response = + *response_union.message_as_AddBreakpointResponse(); + RETURN_IF_ERROR(RegisterBreakpoint(*response.breakpoint())); + if (callback) { + ASSIGN_OR_RETURN( + auto breakpoint, + GetBreakpoint(response.breakpoint()->breakpoint_id())); + callback(*breakpoint); + } + return OkStatus(); + }); + } + + Status RemoveBreakpoint(const RemoteBreakpoint& breakpoint) override { + VLOG(2) << "Client " << fd_ << ": RemoveBreakpoint(" << breakpoint.id() + << ")"; + int breakpoint_id = breakpoint.id(); + ASSIGN_OR_RETURN(auto* breakpoint_ptr, GetBreakpoint(breakpoint_id)); + RETURN_IF_ERROR(UnregisterBreakpoint(breakpoint_ptr)); + FlatBufferBuilder fbb; + rpc::RemoveBreakpointRequestBuilder request(fbb); + request.add_session_id(session_id_); + request.add_breakpoint_id(breakpoint_id); + return IssueRequest<rpc::RemoveBreakpointRequest, + rpc::ResponseUnion::RemoveBreakpointResponse>( + request.Finish(), std::move(fbb), + [](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + // No non-error status. + return OkStatus(); + }); + } + + Status MakeReady() override { + FlatBufferBuilder fbb; + rpc::MakeReadyRequestBuilder request(fbb); + request.add_session_id(session_id_); + return IssueRequest<rpc::MakeReadyRequest, + rpc::ResponseUnion::MakeReadyResponse>( + request.Finish(), std::move(fbb), + [](Status status, const rpc::Response& response_union) { + return status; + }); + } + + Status SuspendAllInvocations() override { + VLOG(2) << "Client " << fd_ << ": SuspendAllInvocations()"; + FlatBufferBuilder fbb; + rpc::SuspendInvocationsRequestBuilder request(fbb); + request.add_session_id(session_id_); + return IssueRequest<rpc::SuspendInvocationsRequest, + rpc::ResponseUnion::SuspendInvocationsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + return RefreshInvocations(); + }); + } + + Status ResumeAllInvocations() override { + VLOG(2) << "Client " << fd_ << ": ResumeAllInvocations()"; + FlatBufferBuilder fbb; + rpc::ResumeInvocationsRequestBuilder request(fbb); + request.add_session_id(session_id_); + return IssueRequest<rpc::ResumeInvocationsRequest, + rpc::ResponseUnion::ResumeInvocationsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + return RefreshInvocations(); + }); + } + + Status SuspendInvocations( + absl::Span<RemoteInvocation*> invocations) override { + VLOG(2) << "Client " << fd_ << ": SuspendInvocations(...)"; + FlatBufferBuilder fbb; + auto invocation_ids_offs = fbb.CreateVector<int32_t>( + invocations.size(), + [&invocations](size_t i) { return invocations[i]->id(); }); + rpc::SuspendInvocationsRequestBuilder request(fbb); + request.add_session_id(session_id_); + request.add_invocation_ids(invocation_ids_offs); + return IssueRequest<rpc::SuspendInvocationsRequest, + rpc::ResponseUnion::SuspendInvocationsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + return RefreshInvocations(); + }); + } + + Status ResumeInvocations(absl::Span<RemoteInvocation*> invocations) override { + VLOG(2) << "Client " << fd_ << ": ResumeInvocations(...)"; + FlatBufferBuilder fbb; + auto invocation_ids_offs = fbb.CreateVector<int32_t>( + invocations.size(), + [&invocations](size_t i) { return invocations[i]->id(); }); + rpc::ResumeInvocationsRequestBuilder request(fbb); + request.add_session_id(session_id_); + request.add_invocation_ids(invocation_ids_offs); + return IssueRequest<rpc::ResumeInvocationsRequest, + rpc::ResponseUnion::ResumeInvocationsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + return RefreshInvocations(); + }); + } + + Status StepInvocation(const RemoteInvocation& invocation, + std::function<void()> callback) override { + int step_id = next_step_id_++; + VLOG(2) << "Client " << fd_ << ": StepInvocation(" << invocation.id() + << ") as step_id=" << step_id; + rpc::StepInvocationRequestT step_request; + step_request.step_id = step_id; + step_request.invocation_id = invocation.id(); + step_request.step_mode = rpc::StepMode::STEP_ONCE; + return StepInvocation(&step_request, std::move(callback)); + } + + Status StepInvocationToOffset(const RemoteInvocation& invocation, + int bytecode_offset, + std::function<void()> callback) override { + int step_id = next_step_id_++; + VLOG(2) << "Client " << fd_ << ": StepInvocationToOffset(" + << invocation.id() << ", " << bytecode_offset + << ") as step_id=" << step_id; + rpc::StepInvocationRequestT step_request; + step_request.step_id = step_id; + step_request.invocation_id = invocation.id(); + step_request.step_mode = rpc::StepMode::STEP_TO_OFFSET; + step_request.bytecode_offset = bytecode_offset; + return StepInvocation(&step_request, std::move(callback)); + } + + Status Poll() override { + while (true) { + // If nothing awaiting then return immediately. + if (!tcp::CanReadBuffer(fd_)) { + break; + } + + // Read the pending response and dispatch. + auto packet_buffer_or = tcp::ReadBuffer<rpc::ServicePacket>(fd_); + if (!packet_buffer_or.ok()) { + if (shutdown_pending_ && IsUnavailable(packet_buffer_or.status())) { + // This is a graceful close. + return CancelledErrorBuilder(IREE_LOC) << "Service shutdown"; + } + return packet_buffer_or.status(); + } + const auto& packet = packet_buffer_or.ValueOrDie().GetRoot(); + if (packet.response()) { + RETURN_IF_ERROR(DispatchResponse(*packet.response())); + } + if (packet.event()) { + RETURN_IF_ERROR(DispatchEvent(packet)); + } + } + return OkStatus(); + } + + using ResponseCallback = + std::function<Status(Status status, const rpc::Response& response)>; + + template <typename T, rpc::ResponseUnion response_type> + Status IssueRequest(::flatbuffers::Offset<T> request_offs, + FlatBufferBuilder fbb, ResponseCallback callback) { + RETURN_IF_ERROR(WriteRequest(fd_, request_offs, std::move(fbb))); + pending_responses_.push({response_type, std::move(callback)}); + return OkStatus(); + } + + private: + Status Refresh() { + RETURN_IF_ERROR(RefreshContexts()); + RETURN_IF_ERROR(RefreshInvocations()); + RETURN_IF_ERROR(RefreshBreakpoints()); + return OkStatus(); + } + + Status RefreshContexts() { + VLOG(2) << "Request contexts refresh..."; + FlatBufferBuilder fbb; + rpc::ListContextsRequestBuilder request(fbb); + request.add_session_id(session_id_); + return IssueRequest<rpc::ListContextsRequest, + rpc::ResponseUnion::ListContextsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + VLOG(2) << "Refreshing contexts..."; + const auto& response = + *response_union.message_as_ListContextsResponse(); + for (auto* context_def : *response.contexts()) { + auto context_or = GetContext(context_def->context_id()); + if (!context_or.ok()) { + // Not found; add new. + RETURN_IF_ERROR(RegisterContext(context_def->context_id())); + context_or = GetContext(context_def->context_id()); + } + RETURN_IF_ERROR(context_or.status()); + RETURN_IF_ERROR(context_or.ValueOrDie()->MergeFrom(*context_def)); + } + VLOG(2) << "Refreshed contexts!"; + return OkStatus(); + }); + } + + Status RefreshInvocations() { + VLOG(2) << "Request invocation states refresh..."; + FlatBufferBuilder fbb; + rpc::ListInvocationsRequestBuilder request(fbb); + request.add_session_id(session_id_); + return IssueRequest<rpc::ListInvocationsRequest, + rpc::ResponseUnion::ListInvocationsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + VLOG(2) << "Refreshing invocation states..."; + const auto& response = + *response_union.message_as_ListInvocationsResponse(); + for (auto* invocation_def : *response.invocations()) { + auto invocation_or = GetInvocation(invocation_def->invocation_id()); + if (!invocation_or.ok()) { + // Not found; add new. + RETURN_IF_ERROR( + RegisterInvocation(invocation_def->invocation_id())); + invocation_or = GetInvocation(invocation_def->invocation_id()); + } + RETURN_IF_ERROR(invocation_or.status()); + RETURN_IF_ERROR( + invocation_or.ValueOrDie()->MergeFrom(*invocation_def)); + } + // TODO(benvanik): handle removals/deaths. + VLOG(2) << "Refreshed invocation states!"; + return OkStatus(); + }); + } + + Status RefreshBreakpoints() { + VLOG(2) << "Requesting breakpoint refresh..."; + FlatBufferBuilder fbb; + rpc::ListBreakpointsRequestBuilder request(fbb); + request.add_session_id(session_id_); + return IssueRequest<rpc::ListBreakpointsRequest, + rpc::ResponseUnion::ListBreakpointsResponse>( + request.Finish(), std::move(fbb), + [this](Status status, const rpc::Response& response_union) -> Status { + if (!status.ok()) return status; + VLOG(2) << "Refreshing breakpoints..."; + const auto& response = + *response_union.message_as_ListBreakpointsResponse(); + for (auto* breakpoint_def : *response.breakpoints()) { + auto breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id()); + if (!breakpoint_or.ok()) { + // Not found; add new. + RETURN_IF_ERROR(RegisterBreakpoint(*breakpoint_def)); + breakpoint_or = GetBreakpoint(breakpoint_def->breakpoint_id()); + } + RETURN_IF_ERROR(breakpoint_or.status()); + RETURN_IF_ERROR( + breakpoint_or.ValueOrDie()->MergeFrom(*breakpoint_def)); + } + // TODO(benvanik): handle removals/deaths. + VLOG(2) << "Refreshed breakpoints!"; + return OkStatus(); + }); + } + + Status DispatchResponse(const rpc::Response& response) { + if (pending_responses_.empty()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Response received but no request is pending"; + } + auto type_callback = std::move(pending_responses_.front()); + pending_responses_.pop(); + + if (response.status()) { + const auto& status = *response.status(); + Status client_status = + StatusBuilder(static_cast<StatusCode>(status.code()), IREE_LOC) + << "Server request failed: " << WrapString(status.message()); + return type_callback.second(std::move(client_status), response); + } + + if (!response.message()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Response contains no message body"; + } + + if (response.message_type() != type_callback.first) { + return DataLossErrorBuilder(IREE_LOC) + << "Out of order response (mismatch pending)"; + } + return type_callback.second(OkStatus(), response); + } + + Status DispatchEvent(const rpc::ServicePacket& packet) { + switch (packet.event_type()) { +#define DISPATCH_EVENT(event_name) \ + case rpc::EventUnion::event_name##Event: { \ + VLOG(2) << "EVENT: " << #event_name; \ + return On##event_name(*packet.event_as_##event_name##Event()); \ + } + DISPATCH_EVENT(ServiceShutdown); + DISPATCH_EVENT(ContextRegistered); + DISPATCH_EVENT(ContextUnregistered); + DISPATCH_EVENT(ModuleLoaded); + DISPATCH_EVENT(InvocationRegistered); + DISPATCH_EVENT(InvocationUnregistered); + DISPATCH_EVENT(BreakpointResolved); + DISPATCH_EVENT(BreakpointHit); + DISPATCH_EVENT(StepCompleted); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented debug service event: " + << static_cast<int>(packet.event_type()); + } + } + + StatusOr<TcpRemoteContext*> GetContext(int context_id) { + auto it = context_map_.find(context_id); + if (it == context_map_.end()) { + return NotFoundErrorBuilder(IREE_LOC) << "Context was never registered"; + } + return it->second.get(); + } + + Status OnServiceShutdown(const rpc::ServiceShutdownEvent& event) { + LOG(INFO) << "Service is shutting down; setting pending shutdown flag"; + shutdown_pending_ = true; + return OkStatus(); + } + + Status RegisterContext(int context_id) { + auto context = absl::make_unique<TcpRemoteContext>(context_id, this); + VLOG(2) << "RegisterContext(" << context_id << ")"; + auto context_ptr = context.get(); + context_map_.insert({context_id, std::move(context)}); + contexts_.push_back(context_ptr); + return listener_->OnContextRegistered(*context_ptr); + } + + Status OnContextRegistered(const rpc::ContextRegisteredEvent& event) { + VLOG(2) << "OnContextRegistered(" << event.context_id() << ")"; + auto it = context_map_.find(event.context_id()); + if (it != context_map_.end()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Context already registered"; + } + return RegisterContext(event.context_id()); + } + + Status OnContextUnregistered(const rpc::ContextUnregisteredEvent& event) { + VLOG(2) << "OnContextUnregistered(" << event.context_id() << ")"; + auto it = context_map_.find(event.context_id()); + if (it == context_map_.end()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Context was never registered"; + } + auto context = std::move(it->second); + context_map_.erase(it); + auto list_it = std::find(contexts_.begin(), contexts_.end(), context.get()); + contexts_.erase(list_it); + return listener_->OnContextUnregistered(*context); + } + + Status OnModuleLoaded(const rpc::ModuleLoadedEvent& event) { + VLOG(2) << "OnModuleLoaded(" << event.context_id() << ", " + << WrapString(event.module_name()) << ")"; + ASSIGN_OR_RETURN(auto* context, GetContext(event.context_id())); + auto module_name = WrapString(event.module_name()); + auto module = absl::make_unique<TcpRemoteModule>( + event.context_id(), std::string(module_name), this); + auto* module_ptr = module.get(); + RETURN_IF_ERROR(context->AddModule(std::move(module))); + return listener_->OnModuleLoaded(*context, *module_ptr); + } + + StatusOr<TcpRemoteInvocation*> GetInvocation(int invocation_id) { + auto it = invocation_map_.find(invocation_id); + if (it == invocation_map_.end()) { + return NotFoundErrorBuilder(IREE_LOC) + << "Invocation was never registered"; + } + return it->second.get(); + } + + Status RegisterInvocation(int invocation_id) { + VLOG(2) << "RegisterInvocation(" << invocation_id << ")"; + auto invocation = + absl::make_unique<TcpRemoteInvocation>(invocation_id, this); + auto invocation_ptr = invocation.get(); + invocation_map_.insert({invocation_id, std::move(invocation)}); + invocations_.push_back(invocation_ptr); + RETURN_IF_ERROR(RefreshInvocations()); + return listener_->OnInvocationRegistered(*invocation_ptr); + } + + Status OnInvocationRegistered(const rpc::InvocationRegisteredEvent& event) { + VLOG(2) << "OnInvocationRegistered(" << event.invocation_id() << ")"; + auto it = invocation_map_.find(event.invocation_id()); + if (it != invocation_map_.end()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Invocation already registered"; + } + return RegisterInvocation(event.invocation_id()); + } + + Status OnInvocationUnregistered( + const rpc::InvocationUnregisteredEvent& event) { + VLOG(2) << "OnInvocationUnregistered(" << event.invocation_id() << ")"; + auto it = invocation_map_.find(event.invocation_id()); + if (it == invocation_map_.end()) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Invocation was never registered"; + } + auto invocation = std::move(it->second); + invocation_map_.erase(it); + auto list_it = + std::find(invocations_.begin(), invocations_.end(), invocation.get()); + invocations_.erase(list_it); + return listener_->OnInvocationUnregistered(*invocation); + } + + StatusOr<TcpRemoteBreakpoint*> GetBreakpoint(int breakpoint_id) { + auto it = breakpoint_map_.find(breakpoint_id); + if (it == breakpoint_map_.end()) { + return NotFoundErrorBuilder(IREE_LOC) + << "Breakpoint " << breakpoint_id << " was never registered"; + } + return it->second.get(); + } + + Status RegisterBreakpoint(const rpc::BreakpointDef& breakpoint_def) { + auto it = breakpoint_map_.find(breakpoint_def.breakpoint_id()); + if (it != breakpoint_map_.end()) { + VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id() + << ") (update)"; + return it->second->MergeFrom(breakpoint_def); + } + + VLOG(2) << "RegisterBreakpoint(" << breakpoint_def.breakpoint_id() << ")"; + auto breakpoint = absl::make_unique<TcpRemoteBreakpoint>( + breakpoint_def.breakpoint_id(), + static_cast<RemoteBreakpoint::Type>(breakpoint_def.breakpoint_type()), + this); + RETURN_IF_ERROR(breakpoint->MergeFrom(breakpoint_def)); + breakpoints_.push_back(breakpoint.get()); + breakpoint_map_.insert({breakpoint->id(), std::move(breakpoint)}); + return OkStatus(); + } + + Status UnregisterBreakpoint(RemoteBreakpoint* breakpoint) { + VLOG(2) << "UnregisterBreakpoint(" << breakpoint->id() << ")"; + auto it = breakpoint_map_.find(breakpoint->id()); + if (it == breakpoint_map_.end()) { + return NotFoundErrorBuilder(IREE_LOC) + << "Breakpoint was never registered"; + } + breakpoint_map_.erase(it); + auto list_it = + std::find(breakpoints_.begin(), breakpoints_.end(), breakpoint); + breakpoints_.erase(list_it); + return OkStatus(); + } + + Status OnBreakpointResolved(const rpc::BreakpointResolvedEvent& event) { + VLOG(2) << "OnBreakpointResolved(" << event.breakpoint()->breakpoint_id() + << ")"; + auto it = breakpoint_map_.find(event.breakpoint()->breakpoint_id()); + if (it == breakpoint_map_.end()) { + RETURN_IF_ERROR(RegisterBreakpoint(*event.breakpoint())); + } else { + RETURN_IF_ERROR(it->second->MergeFrom(*event.breakpoint())); + } + return OkStatus(); + } + + Status OnBreakpointHit(const rpc::BreakpointHitEvent& event) { + VLOG(2) << "OnBreakpointHit(" << event.breakpoint_id() << ")"; + ASSIGN_OR_RETURN(auto* breakpoint, GetBreakpoint(event.breakpoint_id())); + auto* invocation_def = event.invocation(); + auto invocation_or = GetInvocation(invocation_def->invocation_id()); + if (!invocation_or.ok()) { + // Not found; add new. + RETURN_IF_ERROR(RegisterInvocation(invocation_def->invocation_id())); + invocation_or = GetInvocation(invocation_def->invocation_id()); + } + RETURN_IF_ERROR(invocation_or.status()); + RETURN_IF_ERROR(invocation_or.ValueOrDie()->MergeFrom(*invocation_def)); + return listener_->OnBreakpointHit(*breakpoint, *invocation_or.ValueOrDie()); + } + + Status StepInvocation(rpc::StepInvocationRequestT* step_request, + std::function<void()> callback) { + FlatBufferBuilder fbb; + auto status = IssueRequest<rpc::StepInvocationRequest, + rpc::ResponseUnion::StepInvocationResponse>( + rpc::StepInvocationRequest::Pack(fbb, step_request), std::move(fbb), + [](Status status, const rpc::Response& response_union) -> Status { + return status; + }); + RETURN_IF_ERROR(status); + pending_step_callbacks_[step_request->step_id] = std::move(callback); + return OkStatus(); + } + + Status OnStepCompleted(const rpc::StepCompletedEvent& event) { + VLOG(2) << "OnStepCompleted(" << event.step_id() << ")"; + + // Update all invocation states that are contained. + // This may only be a subset of relevant states. + for (auto* invocation_def : *event.invocations()) { + ASSIGN_OR_RETURN(auto invocation, + GetInvocation(invocation_def->invocation_id())); + RETURN_IF_ERROR(invocation->MergeFrom(*invocation_def)); + } + + // Dispatch step callback. Note that it may have been cancelled and that's + // ok. We'll just make ready to resume execution. + auto it = pending_step_callbacks_.find(event.step_id()); + if (it != pending_step_callbacks_.end()) { + it->second(); + pending_step_callbacks_.erase(it); + } else { + LOG(WARNING) << "Step " << event.step_id() + << " not found; was cancelled?"; + RETURN_IF_ERROR(MakeReady()); + } + return OkStatus(); + } + + int session_id_ = 123; + + int fd_ = -1; + Listener* listener_; + bool shutdown_pending_ = false; + std::queue<std::pair<rpc::ResponseUnion, ResponseCallback>> + pending_responses_; + + std::vector<RemoteContext*> contexts_; + absl::flat_hash_map<int, std::unique_ptr<TcpRemoteContext>> context_map_; + std::vector<RemoteInvocation*> invocations_; + absl::flat_hash_map<int, std::unique_ptr<TcpRemoteInvocation>> + invocation_map_; + std::vector<RemoteBreakpoint*> breakpoints_; + absl::flat_hash_map<int, std::unique_ptr<TcpRemoteBreakpoint>> + breakpoint_map_; + + int next_step_id_ = 1; + absl::flat_hash_map<int, std::function<void()>> pending_step_callbacks_; +}; + +} // namespace + +// static +StatusOr<std::unique_ptr<DebugClient>> DebugClient::Connect( + absl::string_view service_address, Listener* listener) { + // Parse address into hostname and port. + ASSIGN_OR_RETURN(auto hostname_port, ParseAddress(service_address)); + if (hostname_port.second == 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "No port specified in service address; port must match the " + "server: " + << service_address; + } + + // Attempt to resolve the address. + // Note that if we only wanted local debugging we could remove the dep on + // getaddrinfo/having a valid DNS setup. + addrinfo hints = {0}; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + addrinfo* resolved_address = nullptr; + auto port_str = std::to_string(hostname_port.second); + int getaddrinfo_ret = ::getaddrinfo( + hostname_port.first.c_str(), port_str.c_str(), &hints, &resolved_address); + if (getaddrinfo_ret != 0) { + return UnavailableErrorBuilder(IREE_LOC) + << "Unable to resolve debug service address for " << service_address + << ": (" << getaddrinfo_ret << ") " + << ::gai_strerror(getaddrinfo_ret); + } + + // Attempt to connect with each address returned from the query. + int fd = -1; + for (addrinfo* rp = resolved_address; rp != nullptr; rp = rp->ai_next) { + fd = ::socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (fd == -1) continue; + if (::connect(fd, rp->ai_addr, rp->ai_addrlen) == 0) { + break; // Success! + } + ::close(fd); + fd = -1; + } + ::freeaddrinfo(resolved_address); + if (fd == -1) { + return UnavailableErrorBuilder(IREE_LOC) + << "Unable to connect to " << service_address << " on any address: (" + << errno << ") " << ::strerror(errno); + } + + LOG(INFO) << "Connected to debug service at " << service_address; + + return TcpDebugClient::Create(fd, listener); +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_server.h b/rt/debug/debug_server.h new file mode 100644 index 0000000..6fd416d --- /dev/null +++ b/rt/debug/debug_server.h
@@ -0,0 +1,88 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DEBUG_DEBUG_SERVER_H_ +#define IREE_RT_DEBUG_DEBUG_SERVER_H_ + +#include "base/status.h" + +namespace iree { +namespace rt { +class Context; +class Instance; +class Invocation; +class Module; +} // namespace rt +} // namespace iree + +namespace iree { +namespace rt { +namespace debug { + +// Runtime debugging server. +// Enabled only when compiled in (by defining IREE_DEBUG=1), this provides an +// RPC server that allows debuggers to attach, query, and manipulate contexts. +// This interface is used by various parts of the runtime such as dispatch to +// query the current debug state and signal events. +// +// Thread-safe. Contexts may be registered and unregistered from any thread. +class DebugServer { + public: + // Creates a new debug service listening on the provided |port|. + // Even when disabled the device can still be created however it will not + // perform any actual operations and act as if the debugger is not attached. + static StatusOr<std::unique_ptr<DebugServer>> Create(int listen_port); + + // TODO(benvanik): ensure this gets optimized out when disabled. + // Seems to be the case: https://gcc.godbolt.org/z/0zf-L4 + virtual ~DebugServer() = default; + + // Attaches a callback that will be made when the debug server is shutting + // down. This can be used to keep resources alive that require the debugger. + // The callback will be made from a random thread. + virtual void AtExit(std::function<void()> callback) = 0; + + // Blocks the caller until a client session connects and resumes all fibers. + // Returns AbortedError if a session connects/is connected but disconnects + // during the wait. + virtual Status WaitUntilSessionReady() = 0; + + protected: + friend class ::iree::rt::Instance; + + // Registers a context with the debug service. + // Ownership remains with the caller and UnregisterContext must be called + // prior to the context being destroyed. + virtual Status RegisterContext(Context* context) = 0; + virtual Status UnregisterContext(Context* context) = 0; + + friend class ::iree::rt::Context; + + // Registers a new module linked into an existing Context. + virtual Status RegisterContextModule(Context* context, Module* module) = 0; + + friend class ::iree::rt::Invocation; + + // Registers an invocation with the debug service. + // Ownership remains with the caller and UnregisterInvocation must be called + // prior to the fiber state being destroyed. + virtual Status RegisterInvocation(Invocation* invocation) = 0; + virtual Status UnregisterInvocation(Invocation* invocation) = 0; +}; + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_DEBUG_SERVER_H_
diff --git a/rt/debug/debug_server_disabled.cc b/rt/debug/debug_server_disabled.cc new file mode 100644 index 0000000..28306a9 --- /dev/null +++ b/rt/debug/debug_server_disabled.cc
@@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/debug/debug_server.h" + +namespace iree { +namespace rt { +namespace debug { + +// static +StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) { + return {nullptr}; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_server_flags.cc b/rt/debug/debug_server_flags.cc new file mode 100644 index 0000000..4455bf7 --- /dev/null +++ b/rt/debug/debug_server_flags.cc
@@ -0,0 +1,80 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/debug/debug_server_flags.h" + +#include "absl/flags/flag.h" +#include "absl/strings/str_cat.h" +#include "base/memory.h" +#include "base/status.h" + +#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT) +#include "tools/debugger/debug_app_embedded.h" +#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT + +ABSL_FLAG(int32_t, iree_debug_service_port, 6000, + "TCP port to listen for debug service connections."); +ABSL_FLAG(bool, iree_wait_for_debugger, false, + "Waits until a debugger connects to continue startup."); +ABSL_FLAG(bool, iree_attach_debugger, false, "Attaches a debugger at startup."); + +namespace iree { +namespace rt { +namespace debug { + +StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags() { + // Create the server based on whatever version is compiled in. + // Note that this will return nullptr if no server is available. + ASSIGN_OR_RETURN( + auto debug_server, + DebugServer::Create(absl::GetFlag(FLAGS_iree_debug_service_port))); + if (!debug_server) { + return nullptr; + } + +#if defined(IREE_DEBUG_EMBEDDED_APP_PRESENT) + // If the embedded debug UI is present then we can launch that now. + std::unique_ptr<EmbeddedDebugger> debugger; + if (absl::GetFlag(FLAGS_iree_attach_debugger)) { + LOG(INFO) << "Attaching debugger at startup..."; + ASSIGN_OR_RETURN( + debugger, + AttachDebugger(absl::StrCat( + "localhost:", absl::GetFlag(FLAGS_iree_debug_service_port)))); + RETURN_IF_ERROR(debug_server->WaitUntilSessionReady()); + LOG(INFO) << "Debugger attached"; + // TODO(benvanik): C++14 to avoid this. + auto debugger_baton = IreeMoveToLambda(debugger); + debug_server->AtExit([debugger_baton]() { debugger_baton.value.reset(); }); + } +#else + if (absl::GetFlag(FLAGS_iree_attach_debugger)) { + LOG(WARNING) << "--iree_attach_debugger specified but no embedded debugger " + "is present. Build with --define=IREE_DEBUG=1."; + } +#endif // IREE_DEBUG_EMBEDDED_APP_PRESENT + + // Wait for a debugger to connect. + if (absl::GetFlag(FLAGS_iree_wait_for_debugger)) { + LOG(INFO) << "Waiting for a debugger to connect..."; + RETURN_IF_ERROR(debug_server->WaitUntilSessionReady()); + LOG(INFO) << "Debugger ready, resuming..."; + } + + return std::move(debug_server); +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_server_flags.h b/rt/debug/debug_server_flags.h new file mode 100644 index 0000000..13dc1b2 --- /dev/null +++ b/rt/debug/debug_server_flags.h
@@ -0,0 +1,33 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_ +#define IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_ + +#include "base/status.h" +#include "rt/debug/debug_server.h" + +namespace iree { +namespace rt { +namespace debug { + +// Creates a debug server based on the current --iree_* debug flags. +// Returns nullptr if no server is compiled in or the flags are not set. +StatusOr<std::unique_ptr<DebugServer>> CreateDebugServerFromFlags(); + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_DEBUG_SERVER_FLAGS_H_
diff --git a/rt/debug/debug_server_tcp.cc b/rt/debug/debug_server_tcp.cc new file mode 100644 index 0000000..a4bf868 --- /dev/null +++ b/rt/debug/debug_server_tcp.cc
@@ -0,0 +1,459 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> +#include <cerrno> +#include <exception> +#include <thread> // NOLINT + +#include "absl/base/thread_annotations.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "base/status.h" +#include "flatbuffers/flatbuffers.h" +#include "rt/debug/debug_server.h" +#include "rt/debug/debug_service.h" +#include "rt/debug/debug_tcp_util.h" +#include "schemas/debug_service_generated.h" + +namespace iree { +namespace rt { +namespace debug { +namespace { + +// Writes the given typed response message to the given fd by wrapping it in +// a size-prefixed rpc::Request union. +// +// Example: +// ::flatbuffers::FlatBufferBuilder fbb; +// rpc::SuspendInvocationResponseBuilder response(fbb); +// RETURN_IF_ERROR(WriteResponse(fd_, response.Finish(), std::move(fbb))); +template <typename T> +Status WriteResponse(int fd, ::flatbuffers::Offset<T> message_offs, + ::flatbuffers::FlatBufferBuilder fbb) { + rpc::ResponseBuilder response_builder(fbb); + response_builder.add_message_type(rpc::ResponseUnionTraits<T>::enum_value); + response_builder.add_message(message_offs.Union()); + auto response_offs = response_builder.Finish(); + rpc::ServicePacketBuilder packet_builder(fbb); + packet_builder.add_response(response_offs); + fbb.FinishSizePrefixed(packet_builder.Finish()); + return tcp::WriteBuffer(fd, fbb.Release()); +} + +class TcpDebugSession : public DebugSession { + public: + using ClosedCallback = + std::function<void(TcpDebugSession* session, Status status)>; + + static StatusOr<std::unique_ptr<TcpDebugSession>> Accept( + DebugService* debug_service, int client_fd, + ClosedCallback closed_callback) { + VLOG(2) << "Client " << client_fd << ": Setting up socket options..."; + // Disable Nagel's algorithm to ensure we have low latency. + RETURN_IF_ERROR(tcp::ToggleSocketNagelsAlgorithm(client_fd, false)); + // Enable keepalive assuming the client is local and this high freq is ok. + RETURN_IF_ERROR(tcp::ToggleSocketLocalKeepalive(client_fd, true)); + // Linger around for a bit to flush all data. + RETURN_IF_ERROR(tcp::ToggleSocketLinger(client_fd, true)); + + return absl::make_unique<TcpDebugSession>(debug_service, client_fd, + std::move(closed_callback)); + } + + TcpDebugSession(DebugService* debug_service, int client_fd, + ClosedCallback closed_callback) + : debug_service_(debug_service), + client_fd_(client_fd), + closed_callback_(std::move(closed_callback)) { + CHECK_OK(debug_service_->RegisterDebugSession(this)); + session_thread_ = std::thread([this]() { SessionThread(); }); + } + + ~TcpDebugSession() override { + CHECK_OK(debug_service_->UnregisterDebugSession(this)); + VLOG(2) << "Client " << client_fd_ << ": Shutting down session socket..."; + ::shutdown(client_fd_, SHUT_RD); + if (session_thread_.joinable() && + session_thread_.get_id() != std::this_thread::get_id()) { + VLOG(2) << "Client " << client_fd_ << ": Joining socket thread..."; + session_thread_.join(); + VLOG(2) << "Client " << client_fd_ << ": Joined socket thread!"; + } else { + VLOG(2) << "Client " << client_fd_ << ": Detaching socket thread..."; + session_thread_.detach(); + } + VLOG(2) << "Client " << client_fd_ << ": Closing session socket..."; + ::close(client_fd_); + VLOG(2) << "Client " << client_fd_ << ": Closed session socket!"; + client_fd_ = -1; + } + + Status OnServiceShutdown() { + VLOG(2) << "Client " << client_fd_ << ": Post OnServiceShutdown()"; + ::flatbuffers::FlatBufferBuilder fbb; + rpc::ServiceShutdownEventBuilder event(fbb); + return PostEvent(event.Finish(), std::move(fbb)); + } + + Status OnContextRegistered(Context* context) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnContextRegistered(" + << context->id() << ")"; + ::flatbuffers::FlatBufferBuilder fbb; + rpc::ContextRegisteredEventBuilder event(fbb); + event.add_context_id(context->id()); + return PostEvent(event.Finish(), std::move(fbb)); + } + Status OnContextUnregistered(Context* context) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnContextUnregistered(" + << context->id() << ")"; + ::flatbuffers::FlatBufferBuilder fbb; + rpc::ContextUnregisteredEventBuilder event(fbb); + event.add_context_id(context->id()); + return PostEvent(event.Finish(), std::move(fbb)); + } + + Status OnModuleLoaded(Context* context, Module* module) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnModuleLoaded(" + << context->id() << ", " << module->name() << ")"; + ::flatbuffers::FlatBufferBuilder fbb; + auto module_name_offs = + fbb.CreateString(module->name().data(), module->name().size()); + rpc::ModuleLoadedEventBuilder event(fbb); + event.add_context_id(context->id()); + event.add_module_name(module_name_offs); + return PostEvent(event.Finish(), std::move(fbb)); + } + + Status OnInvocationRegistered(Invocation* invocation) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationRegistered(" + << invocation->id() << ")"; + ::flatbuffers::FlatBufferBuilder fbb; + rpc::InvocationRegisteredEventBuilder event(fbb); + event.add_invocation_id(invocation->id()); + return PostEvent(event.Finish(), std::move(fbb)); + } + Status OnInvocationUnregistered(Invocation* invocation) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnInvocationUnregistered(" + << invocation->id() << ")"; + ::flatbuffers::FlatBufferBuilder fbb; + rpc::InvocationUnregisteredEventBuilder event(fbb); + event.add_invocation_id(invocation->id()); + return PostEvent(event.Finish(), std::move(fbb)); + } + + Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint, + Context* context) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointResolved(" + << breakpoint.breakpoint_id << ", " << context->id() << ", " + << breakpoint.function_ordinal << ")"; + rpc::BreakpointResolvedEventT event; + event.breakpoint = absl::make_unique<rpc::BreakpointDefT>(); + *event.breakpoint = breakpoint; + event.context_id = context->id(); + ::flatbuffers::FlatBufferBuilder fbb; + return PostEvent(rpc::BreakpointResolvedEvent::Pack(fbb, &event), + std::move(fbb)); + } + + Status OnBreakpointHit(int breakpoint_id, + const Invocation& invocation) override { + VLOG(2) << "Client " << client_fd_ << ": Post OnBreakpointHit(" + << breakpoint_id << ", " << invocation.id() << ")"; + ::flatbuffers::FlatBufferBuilder fbb; + ASSIGN_OR_RETURN(auto invocation_offs, + debug_service_->SerializeInvocation(invocation, &fbb)); + rpc::BreakpointHitEventBuilder event(fbb); + event.add_breakpoint_id(breakpoint_id); + event.add_invocation(invocation_offs); + return PostEvent(event.Finish(), std::move(fbb)); + } + + private: + void SessionThread() { + VLOG(2) << "Client " << client_fd_ << ": Thread entry"; + Status session_status = OkStatus(); + while (session_status.ok()) { + auto buffer_or = tcp::ReadBuffer<rpc::Request>(client_fd_); + if (!buffer_or.ok()) { + if (IsCancelled(buffer_or.status())) { + // Graceful shutdown. + VLOG(2) << "Client " << client_fd_ << ": Graceful shutdown requested"; + break; + } + // Error reading. + session_status = std::move(buffer_or).status(); + LOG(ERROR) << "Client " << client_fd_ + << ": Error reading request buffer: " << session_status; + break; + } + auto request_buffer = std::move(buffer_or).ValueOrDie(); + session_status = DispatchRequest(request_buffer.GetRoot()); + if (!session_status.ok()) { + LOG(ERROR) << "Client " << client_fd_ + << ": Error dispatching request: " << session_status; + break; + } + } + VLOG(2) << "Client " << client_fd_ << ": Thread exit"; + AbortSession(session_status); + } + + void AbortSession(Status status) { + if (status.ok()) { + VLOG(2) << "Debug client disconnected"; + } else { + LOG(ERROR) << "Debug session aborted; " << status; + ::flatbuffers::FlatBufferBuilder fbb; + auto message_offs = + fbb.CreateString(status.message().data(), status.message().size()); + rpc::StatusBuilder status_builder(fbb); + status_builder.add_code(static_cast<int>(status.code())); + status_builder.add_message(message_offs); + auto status_offs = status_builder.Finish(); + rpc::ResponseBuilder response(fbb); + response.add_status(status_offs); + fbb.FinishSizePrefixed(response.Finish()); + tcp::WriteBuffer(client_fd_, fbb.Release()).IgnoreError(); + } + closed_callback_(this, std::move(status)); + } + + template <typename T> + Status PostEvent(::flatbuffers::Offset<T> event_offs, + ::flatbuffers::FlatBufferBuilder fbb) { + rpc::ServicePacketBuilder packet_builder(fbb); + packet_builder.add_event_type(rpc::EventUnionTraits<T>::enum_value); + packet_builder.add_event(event_offs.Union()); + fbb.FinishSizePrefixed(packet_builder.Finish()); + return tcp::WriteBuffer(client_fd_, fbb.Release()); + } + + Status DispatchRequest(const rpc::Request& request) { + ::flatbuffers::FlatBufferBuilder fbb; + switch (request.message_type()) { +#define DISPATCH_REQUEST(method_name) \ + case rpc::RequestUnion::method_name##Request: { \ + VLOG(2) << "Client " << client_fd_ \ + << ": DispatchRequest(" #method_name ")..."; \ + ASSIGN_OR_RETURN(auto response_offs, \ + debug_service_->method_name( \ + *request.message_as_##method_name##Request(), &fbb)); \ + return WriteResponse(client_fd_, response_offs, std::move(fbb)); \ + } + DISPATCH_REQUEST(MakeReady); + DISPATCH_REQUEST(GetStatus); + DISPATCH_REQUEST(ListContexts); + DISPATCH_REQUEST(GetModule); + DISPATCH_REQUEST(GetFunction); + DISPATCH_REQUEST(ListInvocations); + DISPATCH_REQUEST(SuspendInvocations); + DISPATCH_REQUEST(ResumeInvocations); + DISPATCH_REQUEST(StepInvocation); + DISPATCH_REQUEST(GetInvocationLocal); + DISPATCH_REQUEST(SetInvocationLocal); + DISPATCH_REQUEST(ListBreakpoints); + DISPATCH_REQUEST(AddBreakpoint); + DISPATCH_REQUEST(RemoveBreakpoint); + DISPATCH_REQUEST(StartProfiling); + DISPATCH_REQUEST(StopProfiling); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented debug service request: " + << static_cast<int>(request.message_type()); + } + } + + DebugService* debug_service_; + int client_fd_; + ClosedCallback closed_callback_; + std::thread session_thread_; +}; + +class TcpDebugServer final : public DebugServer { + public: + static StatusOr<std::unique_ptr<TcpDebugServer>> Listen(int port) { + // We support both IPv4 and IPv6 by using the IN6ADDR_ANY. This requires + // that we setup the socket as INET6 and enable reuse (so the same port can + // be bound for both IPv4 and IPv6). + int listen_fd = ::socket(AF_INET6, SOCK_STREAM, 0); + RETURN_IF_ERROR(tcp::ToggleSocketAddressReuse(listen_fd, true)); + + struct sockaddr_in6 socket_addr = {0}; + socket_addr.sin6_family = AF_INET6; + socket_addr.sin6_port = htons(port); + socket_addr.sin6_addr = in6addr_any; + if (::bind(listen_fd, reinterpret_cast<struct sockaddr*>(&socket_addr), + sizeof(socket_addr)) < 0) { + return AlreadyExistsErrorBuilder(IREE_LOC) + << "Unable to bind socket to port " << port << ": (" << errno + << ") " << ::strerror(errno); + } + if (::listen(listen_fd, 1)) { + ::close(listen_fd); + return AlreadyExistsErrorBuilder(IREE_LOC) + << "Unable to listen on port " << port << ": (" << errno << ") " + << ::strerror(errno); + } + return absl::make_unique<TcpDebugServer>(listen_fd); + } + + TcpDebugServer(int listen_fd) : listen_fd_(listen_fd) { + server_thread_ = std::thread([this]() { ListenThread(); }); + } + + ~TcpDebugServer() ABSL_LOCKS_EXCLUDED(mutex_) override { + absl::ReleasableMutexLock lock(&mutex_); + LOG(INFO) << "Shutting down debug server..."; + + // Notify all sessions. + for (auto& session : sessions_) { + session->OnServiceShutdown().IgnoreError(); + } + + // Shut down listen socket first so that we can't accept new connections. + VLOG(2) << "Shutting down listen socket..."; + ::shutdown(listen_fd_, SHUT_RDWR); + if (server_thread_.joinable()) { + VLOG(2) << "Joining listen thread..."; + server_thread_.join(); + VLOG(2) << "Joined listen thread!"; + } + VLOG(2) << "Closing listen socket..."; + ::close(listen_fd_); + listen_fd_ = -1; + VLOG(2) << "Closed listen socket!"; + + // Kill all active sessions. Note that we must do this outside of our lock. + std::vector<std::unique_ptr<TcpDebugSession>> sessions = + std::move(sessions_); + std::vector<std::function<void()>> at_exit_callbacks = + std::move(at_exit_callbacks_); + lock.Release(); + VLOG(2) << "Clearing live sessions..."; + sessions.clear(); + VLOG(2) << "Calling AtExit callbacks..."; + for (auto& callback : at_exit_callbacks) { + callback(); + } + LOG(INFO) << "Debug server shutdown!"; + } + + DebugService* debug_service() { return &debug_service_; } + + Status AcceptNewSession(int client_fd) { + LOG(INFO) << "Accepting new client session as " << client_fd; + ASSIGN_OR_RETURN(auto session, + TcpDebugSession::Accept( + &debug_service_, client_fd, + [this](TcpDebugSession* session, Status status) { + absl::MutexLock lock(&mutex_); + for (auto it = sessions_.begin(); + it != sessions_.end(); ++it) { + if (it->get() == session) { + sessions_.erase(it); + break; + } + } + return OkStatus(); + })); + + absl::MutexLock lock(&mutex_); + sessions_.push_back(std::move(session)); + return OkStatus(); + } + + void AtExit(std::function<void()> callback) override { + absl::MutexLock lock(&mutex_); + at_exit_callbacks_.push_back(std::move(callback)); + } + + Status WaitUntilSessionReady() override { + return debug_service_.WaitUntilAllSessionsReady(); + } + + protected: + Status RegisterContext(Context* context) override { + return debug_service_.RegisterContext(context); + } + Status UnregisterContext(Context* context) override { + return debug_service_.UnregisterContext(context); + } + Status RegisterContextModule(Context* context, Module* module) override { + return debug_service_.RegisterContextModule(context, module); + } + Status RegisterInvocation(Invocation* invocation) override { + return debug_service_.RegisterInvocation(invocation); + } + Status UnregisterInvocation(Invocation* invocation) override { + return debug_service_.UnregisterInvocation(invocation); + } + + private: + void ListenThread() { + VLOG(2) << "Listen thread entry"; + while (true) { + struct sockaddr_in accept_socket_addr; + socklen_t accept_socket_addr_length = sizeof(accept_socket_addr); + int accepted_fd = ::accept( + listen_fd_, reinterpret_cast<struct sockaddr*>(&accept_socket_addr), + &accept_socket_addr_length); + if (accepted_fd < 0) { + if (errno == EINVAL) { + // Shutting down gracefully. + break; + } + // We may be able to recover from some of these cases, but... shrug. + LOG(FATAL) << "Failed to accept client socket: (" << errno << ") " + << ::strerror(errno); + break; + } + auto accept_status = AcceptNewSession(accepted_fd); + if (!accept_status.ok()) { + LOG(ERROR) << "Failed to accept incoming debug client: " + << accept_status; + } + } + VLOG(2) << "Listen thread exit"; + } + + int listen_fd_; + std::thread server_thread_; + + absl::Mutex mutex_; + std::vector<std::unique_ptr<TcpDebugSession>> sessions_ + ABSL_GUARDED_BY(mutex_); + std::vector<std::function<void()>> at_exit_callbacks_ ABSL_GUARDED_BY(mutex_); + + DebugService debug_service_; +}; + +} // namespace + +// static +StatusOr<std::unique_ptr<DebugServer>> DebugServer::Create(int listen_port) { + ASSIGN_OR_RETURN(auto debug_server, TcpDebugServer::Listen(listen_port)); + LOG(INFO) << "Debug server listening on localhost:" << listen_port; + return debug_server; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_service.cc b/rt/debug/debug_service.cc new file mode 100644 index 0000000..5a21ded --- /dev/null +++ b/rt/debug/debug_service.cc
@@ -0,0 +1,850 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/debug/debug_service.h" + +#include <algorithm> +#include <memory> + +#include "absl/strings/str_join.h" +#include "absl/synchronization/mutex.h" +#include "base/flatbuffer_util.h" +#include "base/source_location.h" +#include "base/status.h" +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/reflection.h" +#include "rt/instance.h" +#include "schemas/debug_service_generated.h" +#include "schemas/reflection_data.h" + +namespace iree { +namespace rt { +namespace debug { +namespace { + +using ::flatbuffers::FlatBufferBuilder; +using ::flatbuffers::Offset; +using ::iree::hal::BufferView; + +int32_t NextUniqueBreakpointId() { + static std::atomic<int32_t> next_id = 0; + return ++next_id; +} + +// Gets an embedded flatbuffers reflection schema. +const ::reflection::Schema& GetSchema(const char* schema_name) { + for (const auto* file_toc = schemas::reflection_data_create(); + file_toc != nullptr; ++file_toc) { + if (std::strcmp(file_toc->name, schema_name) == 0) { + return *::reflection::GetSchema(file_toc->data); + } + } + LOG(FATAL) << "FlatBuffer schema '" << schema_name + << "' not found in binary; ensure it is in :reflection_data"; +} + +// Recursively copies a flatbuffer table, returning the root offset in |fbb|. +template <typename T> +StatusOr<Offset<T>> DeepCopyTable(const char* schema_name, const T& table_def, + FlatBufferBuilder* fbb) { + const auto* root_table = + reinterpret_cast<const ::flatbuffers::Table*>(std::addressof(table_def)); + const auto& schema = GetSchema(schema_name); + return {::flatbuffers::CopyTable(*fbb, schema, *schema.root_table(), + *root_table, + /*use_string_pooling=*/false) + .o}; +} + +// Serializes a buffer_view value, optionally including the entire buffer +// contents. +StatusOr<Offset<rpc::BufferViewDef>> SerializeBufferView( + const BufferView& buffer_view, bool include_buffer_contents, + FlatBufferBuilder* fbb) { + auto shape_offs = fbb->CreateVector(buffer_view.shape.subspan().data(), + buffer_view.shape.subspan().size()); + rpc::BufferViewDefBuilder value(*fbb); + value.add_is_valid(buffer_view.buffer != nullptr); + value.add_shape(shape_offs); + value.add_element_size(buffer_view.element_size); + if (include_buffer_contents) { + // TODO(benvanik): add buffer data. + } + return value.Finish(); +} + +// Serializes a stack frame. +StatusOr<Offset<rpc::StackFrameDef>> SerializeStackFrame( + const StackFrame& stack_frame, FlatBufferBuilder* fbb) { + ASSIGN_OR_RETURN(int function_ordinal, + stack_frame.module().function_table().LookupFunctionOrdinal( + stack_frame.function())); + auto module_name_offs = fbb->CreateString(stack_frame.module().name().data(), + stack_frame.module().name().size()); + std::vector<Offset<rpc::BufferViewDef>> local_offs_list; + for (const auto& local : stack_frame.locals()) { + ASSIGN_OR_RETURN( + auto local_offs, + SerializeBufferView(local, /*include_buffer_contents=*/false, fbb)); + local_offs_list.push_back(local_offs); + } + auto locals_offs = fbb->CreateVector(local_offs_list); + rpc::StackFrameDefBuilder sfb(*fbb); + sfb.add_module_name(module_name_offs); + sfb.add_function_ordinal(function_ordinal); + sfb.add_offset(stack_frame.offset()); + sfb.add_locals(locals_offs); + return sfb.Finish(); +} + +// Resolves a local from a invocation:frame:local_index to a BufferView. +StatusOr<BufferView*> ResolveInvocationLocal(Invocation* invocation, + int frame_index, int local_index) { + auto frames = invocation->mutable_stack()->mutable_frames(); + if (frame_index < 0 || frame_index > frames.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Frame index " << frame_index << " out of bounds (" + << frames.size() << ")"; + } + auto locals = frames[frame_index].mutable_locals(); + if (local_index < 0 || local_index > locals.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Local index " << local_index << " out of bounds (" + << locals.size() << ")"; + } + return &locals[local_index]; +} + +// Suspends a set of invocations and blocks until all have been suspended (or +// one or more fails to suspend). This works only when the caller is *not* one +// of the threads executing a invocation in |invocations| (this normally +// shouldn't happen, but may if we support eval()-like semantics). +Status SuspendInvocationsAndWait(absl::Span<Invocation*> invocations) { + absl::Mutex suspend_mutex; + Status one_suspend_status = OkStatus(); + std::list<int> pending_suspend_ids; + for (auto* invocation : invocations) { + pending_suspend_ids.push_back(invocation->id()); + } + for (auto* invocation : invocations) { + auto suspend_callback = [&, invocation](Status suspend_status) { + absl::MutexLock lock(&suspend_mutex); + auto it = std::find(pending_suspend_ids.begin(), + pending_suspend_ids.end(), invocation->id()); + CHECK(it != pending_suspend_ids.end()); + pending_suspend_ids.erase(it); + if (!suspend_status.ok()) { + one_suspend_status = std::move(suspend_status); + } + }; + RETURN_IF_ERROR(invocation->Suspend(suspend_callback)); + } + suspend_mutex.LockWhen(absl::Condition( + +[](std::list<int>* pending_suspend_ids) { + return pending_suspend_ids->empty(); + }, + &pending_suspend_ids)); + suspend_mutex.Unlock(); + return one_suspend_status; +} + +} // namespace + +Status DebugService::SuspendAllInvocations() { + VLOG(2) << "SuspendAllInvocations"; + for (auto* invocation : invocations_) { + RETURN_IF_ERROR(invocation->Suspend()); + } + return OkStatus(); +} + +Status DebugService::ResumeAllInvocations() { + VLOG(2) << "ResumeAllInvocations"; + for (auto* invocation : invocations_) { + RETURN_IF_ERROR(invocation->Resume()); + } + return OkStatus(); +} + +Status DebugService::RegisterContext(Context* context) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "RegisterContext(" << context->id() << ")"; + RETURN_IF_ERROR(SuspendAllInvocations()); + RETURN_IF_ERROR(UnreadyAllSessions()); + contexts_.push_back(context); + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnContextRegistered(context)); + } + RETURN_IF_ERROR(ResumeAllInvocations()); + return OkStatus(); +} + +Status DebugService::UnregisterContext(Context* context) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "UnregisterContext(" << context->id() << ")"; + auto it = std::find(contexts_.begin(), contexts_.end(), context); + if (it == contexts_.end()) { + return NotFoundErrorBuilder(IREE_LOC) << "Context not registered"; + } + RETURN_IF_ERROR(SuspendAllInvocations()); + RETURN_IF_ERROR(UnreadyAllSessions()); + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnContextUnregistered(context)); + } + contexts_.erase(it); + RETURN_IF_ERROR(ResumeAllInvocations()); + return OkStatus(); +} + +StatusOr<Context*> DebugService::GetContext(int context_id) const { + for (auto* context : contexts_) { + if (context->id() == context_id) { + return context; + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "Context with ID " << context_id + << " not registered with the debug service"; +} + +Status DebugService::RegisterContextModule(Context* context, Module* module) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "RegisterContextModule(" << context->id() << ", " << module->name() + << ")"; + RETURN_IF_ERROR(SuspendAllInvocations()); + RETURN_IF_ERROR(UnreadyAllSessions()); + RETURN_IF_ERROR(RegisterModuleBreakpoints(context, module)); + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnModuleLoaded(context, module)); + } + RETURN_IF_ERROR(ResumeAllInvocations()); + return OkStatus(); +} + +StatusOr<Module*> DebugService::GetModule(int context_id, + absl::string_view module_name) const { + ASSIGN_OR_RETURN(auto* context, GetContext(context_id)); + for (const auto& module : context->modules()) { + if (module->name() == module_name) { + return module.get(); + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "Module '" << module_name << "' not found on context " + << context_id; +} + +Status DebugService::RegisterInvocation(Invocation* invocation) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "RegisterInvocation(" << invocation->id() << ")"; + RETURN_IF_ERROR(SuspendAllInvocations()); + RETURN_IF_ERROR(UnreadyAllSessions()); + invocations_.push_back(invocation); + if (sessions_unready_) { + // Suspend immediately as a debugger is not yet read. + RETURN_IF_ERROR(invocation->Suspend()); + } + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnInvocationRegistered(invocation)); + } + RETURN_IF_ERROR(ResumeAllInvocations()); + return OkStatus(); +} + +Status DebugService::UnregisterInvocation(Invocation* invocation) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "UnregisterInvocation(" << invocation->id() << ")"; + auto it = std::find(invocations_.begin(), invocations_.end(), invocation); + if (it == invocations_.end()) { + return NotFoundErrorBuilder(IREE_LOC) << "Invocation state not registered"; + } + RETURN_IF_ERROR(SuspendAllInvocations()); + RETURN_IF_ERROR(UnreadyAllSessions()); + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnInvocationUnregistered(invocation)); + } + invocations_.erase(it); + RETURN_IF_ERROR(ResumeAllInvocations()); + return OkStatus(); +} + +StatusOr<Invocation*> DebugService::GetInvocation(int invocation_id) const { + for (auto* invocation : invocations_) { + if (invocation->id() == invocation_id) { + return invocation; + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "Invocation state with ID " << invocation_id + << " not registered with the debug service"; +} + +StatusOr<Offset<rpc::InvocationDef>> DebugService::SerializeInvocation( + const Invocation& invocation, FlatBufferBuilder* fbb) { + std::vector<Offset<rpc::StackFrameDef>> frame_offs_list; + for (const auto& frame : invocation.stack().frames()) { + ASSIGN_OR_RETURN(auto frame_offs, SerializeStackFrame(frame, fbb)); + frame_offs_list.push_back(frame_offs); + } + auto frames_offs = fbb->CreateVector(frame_offs_list); + rpc::InvocationDefBuilder fsb(*fbb); + fsb.add_invocation_id(invocation.id()); + fsb.add_frames(frames_offs); + return fsb.Finish(); +} + +Status DebugService::RegisterDebugSession(DebugSession* session) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "RegisterDebugSession(" << session->id() << ")"; + sessions_.push_back(session); + if (session->is_ready()) { + ++sessions_ready_; + } else { + // Immediately suspend all invocations until the session readies up (or + // disconnects). + ++sessions_unready_; + RETURN_IF_ERROR(SuspendAllInvocations()); + } + return OkStatus(); +} + +Status DebugService::UnregisterDebugSession(DebugSession* session) { + absl::MutexLock lock(&mutex_); + VLOG(2) << "UnregisterDebugSession(" << session->id() << ")"; + auto it = std::find(sessions_.begin(), sessions_.end(), session); + if (it == sessions_.end()) { + return NotFoundErrorBuilder(IREE_LOC) << "Session not registered"; + } + sessions_.erase(it); + if (session->is_ready()) { + --sessions_ready_; + } else { + // If the session never readied up then we still have all invocations + // suspended waiting for it. We should resume so that we don't block + // forever. + --sessions_unready_; + RETURN_IF_ERROR(ResumeAllInvocations()); + } + return OkStatus(); +} + +Status DebugService::WaitUntilAllSessionsReady() { + VLOG(1) << "Waiting until all sessions are ready..."; + struct CondState { + DebugService* service; + bool had_sessions; + bool consider_aborted; + } cond_state; + { + absl::MutexLock lock(&mutex_); + cond_state.service = this; + cond_state.had_sessions = !sessions_.empty(); + cond_state.consider_aborted = false; + } + mutex_.LockWhen(absl::Condition( + +[](CondState* cond_state) { + cond_state->service->mutex_.AssertHeld(); + if (cond_state->service->sessions_ready_ > 0) { + // One or more sessions are ready. + return true; + } + if (cond_state->service->sessions_unready_ > 0) { + // One or more sessions are connected but not yet ready. + cond_state->had_sessions = true; + return false; + } + if (cond_state->had_sessions && + cond_state->service->sessions_.empty()) { + // We had sessions but now we don't, consider this an error and bail. + // This can happen when a session connects but never readies up. + cond_state->consider_aborted = true; + return true; + } + return false; + }, + &cond_state)); + mutex_.Unlock(); + if (cond_state.consider_aborted) { + return AbortedErrorBuilder(IREE_LOC) + << "At least one session connected but never readied up"; + } + VLOG(1) << "Sessions ready, resuming"; + return OkStatus(); +} + +StatusOr<Offset<rpc::MakeReadyResponse>> DebugService::MakeReady( + const rpc::MakeReadyRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: MakeReady()"; + // TODO(benvanik): support more than one session. + CHECK_LE(sessions_.size(), 1) << "Only one session is currently supported"; + if (!sessions_.empty()) { + RETURN_IF_ERROR(sessions_[0]->OnReady()); + } + sessions_ready_ = 0; + sessions_unready_ = 0; + for (auto* session : sessions_) { + sessions_ready_ += session->is_ready() ? 1 : 0; + sessions_unready_ += session->is_ready() ? 0 : 1; + } + rpc::MakeReadyResponseBuilder response(*fbb); + return response.Finish(); +} + +Status DebugService::UnreadyAllSessions() { + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnUnready()); + } + sessions_ready_ = 0; + sessions_unready_ = sessions_.size(); + return OkStatus(); +} + +StatusOr<Offset<rpc::GetStatusResponse>> DebugService::GetStatus( + const rpc::GetStatusRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: GetStatus()"; + rpc::GetStatusResponseBuilder response(*fbb); + response.add_protocol(0); + return response.Finish(); +} + +StatusOr<Offset<rpc::ListContextsResponse>> DebugService::ListContexts( + const rpc::ListContextsRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: ListContexts()"; + std::vector<Offset<rpc::ContextDef>> context_offs; + for (auto* context : contexts_) { + std::vector<Offset<rpc::NativeFunctionDef>> native_function_offs_list; + for (const auto& pair : context->native_functions()) { + auto name_offs = fbb->CreateString(pair.first); + rpc::NativeFunctionDefBuilder native_function(*fbb); + native_function.add_name(name_offs); + native_function_offs_list.push_back(native_function.Finish()); + } + auto native_functions_offs = fbb->CreateVector(native_function_offs_list); + + std::vector<std::string> module_names; + for (const auto& module : context->modules()) { + module_names.push_back(std::string(module->name())); + } + auto module_names_offs = fbb->CreateVectorOfStrings(module_names); + + rpc::ContextDefBuilder context_def(*fbb); + context_def.add_context_id(context->id()); + context_def.add_native_functions(native_functions_offs); + context_def.add_module_names(module_names_offs); + context_offs.push_back(context_def.Finish()); + } + + auto contexts_offs = fbb->CreateVector(context_offs); + rpc::ListContextsResponseBuilder response(*fbb); + response.add_contexts(contexts_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::GetModuleResponse>> DebugService::GetModule( + const rpc::GetModuleRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: GetModule(" << request.context_id() << ", " + << WrapString(request.module_name()) << ")"; + ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(), + WrapString(request.module_name()))); + // TODO(benvanik): find a way to do this without possibly duping all memory. + // I suspect that when we make constants poolable then there's only one + // place to kill and there may be magic we could use to do that during a + // reflection pass. + ModuleDefT module_t; + module->def().UnPackTo(&module_t); + for (auto& function : module_t.function_table->functions) { + function->bytecode->contents.clear(); + } + auto trimmed_module_offs = ModuleDef::Pack(*fbb, &module_t); + rpc::GetModuleResponseBuilder response(*fbb); + response.add_module_(trimmed_module_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::GetFunctionResponse>> DebugService::GetFunction( + const rpc::GetFunctionRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: GetFunction(" << WrapString(request.module_name()) << ", " + << request.function_ordinal() << ")"; + ASSIGN_OR_RETURN(auto* module, GetModule(request.context_id(), + WrapString(request.module_name()))); + ASSIGN_OR_RETURN(auto& function, module->function_table().LookupFunction( + request.function_ordinal())); + Offset<BytecodeDef> bytecode_offs; + if (function.def().bytecode()) { + ASSIGN_OR_RETURN( + bytecode_offs, + DeepCopyTable("bytecode_def.bfbs", *function.def().bytecode(), fbb)); + } + rpc::GetFunctionResponseBuilder response(*fbb); + response.add_bytecode(bytecode_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::ResolveFunctionResponse>> DebugService::ResolveFunction( + const rpc::ResolveFunctionRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: ResolveFunction(" << WrapString(request.module_name()) + << ", " << WrapString(request.function_name()) << ")"; + std::vector<int32_t> context_ids; + auto context_ids_offs = fbb->CreateVector(context_ids); + int function_ordinal = -1; + for (auto* context : contexts_) { + for (const auto& module : context->modules()) { + if (module->name() == WrapString(request.module_name())) { + ASSIGN_OR_RETURN(function_ordinal, + module->function_table().LookupFunctionOrdinalByName( + WrapString(request.function_name()))); + context_ids.push_back(context->id()); + break; + } + } + } + rpc::ResolveFunctionResponseBuilder response(*fbb); + response.add_context_ids(context_ids_offs); + response.add_function_ordinal(function_ordinal); + return response.Finish(); +} + +StatusOr<Offset<rpc::ListInvocationsResponse>> DebugService::ListInvocations( + const rpc::ListInvocationsRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: ListInvocations()"; + std::vector<Offset<rpc::InvocationDef>> invocation_offsets; + for (auto* invocation : invocations_) { + ASSIGN_OR_RETURN(auto invocation_offs, + SerializeInvocation(*invocation, fbb)); + invocation_offsets.push_back(invocation_offs); + } + auto invocations_offs = fbb->CreateVector(invocation_offsets); + rpc::ListInvocationsResponseBuilder response(*fbb); + response.add_invocations(invocations_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::SuspendInvocationsResponse>> +DebugService::SuspendInvocations(const rpc::SuspendInvocationsRequest& request, + FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: SuspendInvocations(invocation_ids=[" + << (request.invocation_ids() + ? absl::StrJoin(*request.invocation_ids(), ", ") + : "") + << "])"; + std::vector<Offset<rpc::InvocationDef>> invocation_offsets; + if (request.invocation_ids() && request.invocation_ids()->size() > 0) { + // Suspending a list of invocations. + std::vector<Invocation*> invocations_to_suspend; + for (int invocation_id : *request.invocation_ids()) { + ASSIGN_OR_RETURN(auto* invocation, GetInvocation(invocation_id)); + invocations_to_suspend.push_back(invocation); + } + RETURN_IF_ERROR( + SuspendInvocationsAndWait(absl::MakeSpan(invocations_to_suspend))); + for (auto* invocation : invocations_to_suspend) { + ASSIGN_OR_RETURN(auto invocation_offs, + SerializeInvocation(*invocation, fbb)); + invocation_offsets.push_back(invocation_offs); + } + } else { + // Suspending all invocations. + RETURN_IF_ERROR(SuspendAllInvocations()); + for (auto* invocation : invocations_) { + ASSIGN_OR_RETURN(auto invocation_offs, + SerializeInvocation(*invocation, fbb)); + invocation_offsets.push_back(invocation_offs); + } + } + auto invocations_offs = fbb->CreateVector(invocation_offsets); + rpc::SuspendInvocationsResponseBuilder response(*fbb); + response.add_invocations(invocations_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::ResumeInvocationsResponse>> +DebugService::ResumeInvocations(const rpc::ResumeInvocationsRequest& request, + FlatBufferBuilder* fbb) { + VLOG(1) << "RPC: ResumeInvocations(invocation_ids=[" + << (request.invocation_ids() + ? absl::StrJoin(*request.invocation_ids(), ", ") + : "") + << "])"; + absl::MutexLock lock(&mutex_); + if (request.invocation_ids() && request.invocation_ids()->size() > 0) { + // Resuming a list of invocations. + for (int invocation_id : *request.invocation_ids()) { + ASSIGN_OR_RETURN(auto* invocation, GetInvocation(invocation_id)); + RETURN_IF_ERROR(invocation->Resume()); + } + } else { + // Resuming all invocations. + RETURN_IF_ERROR(ResumeAllInvocations()); + } + rpc::ResumeInvocationsResponseBuilder response(*fbb); + return response.Finish(); +} + +StatusOr<Offset<rpc::StepInvocationResponse>> DebugService::StepInvocation( + const rpc::StepInvocationRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: StepInvocation(" << request.invocation_id() << ")"; + ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id())); + Invocation::StepTarget step_target; + // TODO(benvanik): step settings. + RETURN_IF_ERROR(invocation->Step(step_target)); + rpc::StepInvocationResponseBuilder response(*fbb); + return response.Finish(); +} + +StatusOr<Offset<rpc::GetInvocationLocalResponse>> +DebugService::GetInvocationLocal(const rpc::GetInvocationLocalRequest& request, + FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: GetInvocationLocal(" << request.invocation_id() << ", " + << request.frame_index() << ", " << request.local_index() << ")"; + ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id())); + ASSIGN_OR_RETURN(auto* local, + ResolveInvocationLocal(invocation, request.frame_index(), + request.local_index())); + + ASSIGN_OR_RETURN( + auto value_offs, + SerializeBufferView(*local, /*include_buffer_contents=*/true, fbb)); + rpc::GetInvocationLocalResponseBuilder response(*fbb); + response.add_value(value_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::SetInvocationLocalResponse>> +DebugService::SetInvocationLocal(const rpc::SetInvocationLocalRequest& request, + FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: SetInvocationLocal(" << request.invocation_id() << ", " + << request.frame_index() << ", " << request.local_index() << ")"; + ASSIGN_OR_RETURN(auto* invocation, GetInvocation(request.invocation_id())); + ASSIGN_OR_RETURN(auto* local, + ResolveInvocationLocal(invocation, request.frame_index(), + request.local_index())); + + if (!request.value()) { + local->shape.clear(); + local->element_size = 0; + local->buffer.reset(); + } else { + const auto& value = *request.value(); + local->shape.clear(); + if (value.shape()) { + for (int dim : *value.shape()) { + local->shape.push_back(dim); + } + } + local->element_size = value.element_size(); + // TODO(benvanik): copy buffer data. + } + + ASSIGN_OR_RETURN( + auto value_offs, + SerializeBufferView(*local, /*include_buffer_contents=*/true, fbb)); + rpc::SetInvocationLocalResponseBuilder response(*fbb); + response.add_value(value_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::ListBreakpointsResponse>> DebugService::ListBreakpoints( + const rpc::ListBreakpointsRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: ListBreakpoints()"; + std::vector<Offset<rpc::BreakpointDef>> breakpoint_offs; + for (const auto& breakpoint : breakpoints_) { + breakpoint_offs.push_back(rpc::BreakpointDef::Pack(*fbb, &breakpoint)); + } + auto breakpoints_offs = fbb->CreateVector(breakpoint_offs); + rpc::ListBreakpointsResponseBuilder response(*fbb); + response.add_breakpoints(breakpoints_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::AddBreakpointResponse>> DebugService::AddBreakpoint( + const rpc::AddBreakpointRequest& request, FlatBufferBuilder* fbb) { + if (!request.breakpoint()) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "No breakpoint specified"; + } + absl::MutexLock lock(&mutex_); + int breakpoint_id = NextUniqueBreakpointId(); + VLOG(1) << "RPC: AddBreakpoint(" << breakpoint_id << ")"; + + RETURN_IF_ERROR(SuspendAllInvocations()); + + rpc::BreakpointDefT breakpoint; + request.breakpoint()->UnPackTo(&breakpoint); + breakpoint.breakpoint_id = breakpoint_id; + switch (breakpoint.breakpoint_type) { + case rpc::BreakpointType::BYTECODE_FUNCTION: + case rpc::BreakpointType::NATIVE_FUNCTION: + for (auto* context : contexts_) { + auto module_or = context->LookupModule(breakpoint.module_name); + if (!module_or.ok()) continue; + auto* module = module_or.ValueOrDie(); + RETURN_IF_ERROR( + RegisterFunctionBreakpoint(context, module, &breakpoint)); + } + break; + default: + return UnimplementedErrorBuilder(IREE_LOC) << "Unhandled breakpoint type"; + } + breakpoints_.push_back(std::move(breakpoint)); + + RETURN_IF_ERROR(ResumeAllInvocations()); + + auto breakpoint_offs = rpc::BreakpointDef::Pack(*fbb, &breakpoints_.back()); + rpc::AddBreakpointResponseBuilder response(*fbb); + response.add_breakpoint(breakpoint_offs); + return response.Finish(); +} + +StatusOr<Offset<rpc::RemoveBreakpointResponse>> DebugService::RemoveBreakpoint( + const rpc::RemoveBreakpointRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: RemoveBreakpoint(" << request.breakpoint_id() << ")"; + RETURN_IF_ERROR(SuspendAllInvocations()); + + bool found = false; + for (auto it = breakpoints_.begin(); it != breakpoints_.end(); ++it) { + if (it->breakpoint_id == request.breakpoint_id()) { + auto& breakpoint = *it; + found = true; + switch (breakpoint.breakpoint_type) { + case rpc::BreakpointType::BYTECODE_FUNCTION: + case rpc::BreakpointType::NATIVE_FUNCTION: + RETURN_IF_ERROR(UnregisterFunctionBreakpoint(breakpoint)); + break; + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unhandled breakpoint type"; + } + breakpoints_.erase(it); + break; + } + } + + RETURN_IF_ERROR(ResumeAllInvocations()); + if (!found) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Breakpoint ID " << request.breakpoint_id() << " not found"; + } + + rpc::RemoveBreakpointResponseBuilder response(*fbb); + return response.Finish(); +} + +Status DebugService::RegisterModuleBreakpoints(Context* context, + Module* module) { + for (auto& breakpoint : breakpoints_) { + switch (breakpoint.breakpoint_type) { + case rpc::BreakpointType::BYTECODE_FUNCTION: + if (breakpoint.module_name == module->name()) { + RETURN_IF_ERROR( + RegisterFunctionBreakpoint(context, module, &breakpoint)); + } + break; + default: + // Not relevant to modules. + break; + } + } + return OkStatus(); +} + +Status DebugService::RegisterFunctionBreakpoint( + Context* context, Module* module, rpc::BreakpointDefT* breakpoint) { + if (!breakpoint->function_name.empty()) { + ASSIGN_OR_RETURN(breakpoint->function_ordinal, + module->function_table().LookupFunctionOrdinalByName( + breakpoint->function_name)); + } + RETURN_IF_ERROR(module->mutable_function_table()->RegisterBreakpoint( + breakpoint->function_ordinal, breakpoint->bytecode_offset, + std::bind(&DebugService::OnFunctionBreakpointHit, this, + breakpoint->breakpoint_id, std::placeholders::_1))); + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnBreakpointResolved(*breakpoint, context)); + } + return OkStatus(); +} + +Status DebugService::UnregisterFunctionBreakpoint( + const rpc::BreakpointDefT& breakpoint) { + for (auto* context : contexts_) { + auto module_or = context->LookupModule(breakpoint.module_name); + if (!module_or.ok()) continue; + auto* module = module_or.ValueOrDie(); + RETURN_IF_ERROR(module->mutable_function_table()->UnregisterBreakpoint( + breakpoint.function_ordinal, breakpoint.bytecode_offset)); + } + return OkStatus(); +} + +Status DebugService::OnFunctionBreakpointHit(int breakpoint_id, + const Invocation& invocation) { + absl::ReleasableMutexLock lock(&mutex_); + LOG(INFO) << "Breakpoint hit: " << breakpoint_id; + RETURN_IF_ERROR(UnreadyAllSessions()); + for (auto* session : sessions_) { + RETURN_IF_ERROR(session->OnBreakpointHit(breakpoint_id, invocation)); + } + lock.Release(); + + // TODO(benvanik): on-demand attach if desired? + + // Wait until all clients are ready. + auto wait_status = WaitUntilAllSessionsReady(); + if (IsAborted(wait_status)) { + // This means we lost all sessions. Just continue. + VLOG(1) << "No sessions active; ignoring breakpoint and continuing"; + return OkStatus(); + } + return wait_status; +} + +StatusOr<Offset<rpc::StartProfilingResponse>> DebugService::StartProfiling( + const rpc::StartProfilingRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: StartProfiling()"; + // TODO(benvanik): implement profiling. + // ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id())); + // rpc::StartProfilingResponseBuilder response(*fbb); + // return response.Finish(); + return UnimplementedErrorBuilder(IREE_LOC) + << "StartProfiling not yet implemented"; +} + +StatusOr<Offset<rpc::StopProfilingResponse>> DebugService::StopProfiling( + const rpc::StopProfilingRequest& request, FlatBufferBuilder* fbb) { + absl::MutexLock lock(&mutex_); + VLOG(1) << "RPC: StopProfiling()"; + // TODO(benvanik): implement profiling. + // ASSIGN_OR_RETURN(auto* context, GetContext(request.context_id())); + // rpc::StopProfilingResponseBuilder response(*fbb); + // return response.Finish(); + return UnimplementedErrorBuilder(IREE_LOC) + << "StopProfiling not yet implemented"; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_service.h b/rt/debug/debug_service.h new file mode 100644 index 0000000..54a4744 --- /dev/null +++ b/rt/debug/debug_service.h
@@ -0,0 +1,175 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DEBUG_DEBUG_SERVICE_H_ +#define IREE_RT_DEBUG_DEBUG_SERVICE_H_ + +#include <vector> + +#include "absl/base/thread_annotations.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "base/status.h" +#include "flatbuffers/flatbuffers.h" +#include "rt/context.h" +#include "rt/debug/debug_session.h" +#include "schemas/debug_service_generated.h" + +namespace iree { +namespace rt { +namespace debug { + +// Debugging service used to implement the DebugService RPC methods in a +// transport-independent way. Specific DebugServer implementations can compose +// with a DebugService to avoid needing to maintain state themselves. Multiple +// DebugServer instances could share the same DebugService instance to ensure +// all clients - regardless of transport - share the same state. +// +// Thread-safe. +class DebugService { + public: + // Registers a context with the debug service. + // Ownership remains with the caller and UnregisterContext must be called + // prior to the context being destroyed. + Status RegisterContext(Context* context); + Status UnregisterContext(Context* context); + + // Registers a new module linked into an existing Context. + Status RegisterContextModule(Context* context, Module* module); + + // Registers a invocation state with the debug service. + // Ownership remains with the caller and UnregisterInvocation must be called + // prior to the invocation state being destroyed. + Status RegisterInvocation(Invocation* invocation); + Status UnregisterInvocation(Invocation* invocation); + + // Registers a debug session with the service. + Status RegisterDebugSession(DebugSession* session); + Status UnregisterDebugSession(DebugSession* session); + + // Blocks the caller until all sessions are ready. + // Returns AbortedError if a session connects/is already connected but + // disconnects during the wait. + Status WaitUntilAllSessionsReady(); + + StatusOr<::flatbuffers::Offset<rpc::MakeReadyResponse>> MakeReady( + const rpc::MakeReadyRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + StatusOr<::flatbuffers::Offset<rpc::GetStatusResponse>> GetStatus( + const rpc::GetStatusRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + StatusOr<::flatbuffers::Offset<rpc::ListContextsResponse>> ListContexts( + const rpc::ListContextsRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + StatusOr<::flatbuffers::Offset<rpc::GetModuleResponse>> GetModule( + const rpc::GetModuleRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::GetFunctionResponse>> GetFunction( + const rpc::GetFunctionRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::ResolveFunctionResponse>> ResolveFunction( + const rpc::ResolveFunctionRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + StatusOr<::flatbuffers::Offset<rpc::ListInvocationsResponse>> ListInvocations( + const rpc::ListInvocationsRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::SuspendInvocationsResponse>> + SuspendInvocations(const rpc::SuspendInvocationsRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::ResumeInvocationsResponse>> + ResumeInvocations(const rpc::ResumeInvocationsRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::StepInvocationResponse>> StepInvocation( + const rpc::StepInvocationRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::GetInvocationLocalResponse>> + GetInvocationLocal(const rpc::GetInvocationLocalRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::SetInvocationLocalResponse>> + SetInvocationLocal(const rpc::SetInvocationLocalRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + StatusOr<::flatbuffers::Offset<rpc::ListBreakpointsResponse>> ListBreakpoints( + const rpc::ListBreakpointsRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::AddBreakpointResponse>> AddBreakpoint( + const rpc::AddBreakpointRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::RemoveBreakpointResponse>> + RemoveBreakpoint(const rpc::RemoveBreakpointRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + StatusOr<::flatbuffers::Offset<rpc::StartProfilingResponse>> StartProfiling( + const rpc::StartProfilingRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + StatusOr<::flatbuffers::Offset<rpc::StopProfilingResponse>> StopProfiling( + const rpc::StopProfilingRequest& request, + ::flatbuffers::FlatBufferBuilder* fbb); + + // Serializes an invocation and its stack frames. + StatusOr<::flatbuffers::Offset<rpc::InvocationDef>> SerializeInvocation( + const Invocation& invocation, ::flatbuffers::FlatBufferBuilder* fbb); + + private: + StatusOr<Context*> GetContext(int context_id) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + StatusOr<Module*> GetModule(int context_id, + absl::string_view module_name) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + StatusOr<Invocation*> GetInvocation(int invocation_id) const + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Suspends all invocations on all contexts. Returns only once all invocations + // have been suspended successfully. Fails if any invocation fails to suspend. + Status SuspendAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Resumes all invocations on all contexts (the inverse of + // SuspendAllInvocations). Returns immediately. + Status ResumeAllInvocations() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Marks all sessions as unready. + Status UnreadyAllSessions() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Attempts to re-register all breakpoints for a module. + Status RegisterModuleBreakpoints(Context* context, Module* module) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + Status RegisterFunctionBreakpoint(Context* context, Module* module, + rpc::BreakpointDefT* breakpoint) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + Status UnregisterFunctionBreakpoint(const rpc::BreakpointDefT& breakpoint) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Signals that the given breakpoint was hit by the specified invocation. + // Called without the debug lock held. + Status OnFunctionBreakpointHit(int breakpoint_id, + const Invocation& invocation); + + absl::Mutex mutex_; + std::vector<Context*> contexts_ ABSL_GUARDED_BY(mutex_); + std::vector<Invocation*> invocations_ ABSL_GUARDED_BY(mutex_); + std::vector<DebugSession*> sessions_ ABSL_GUARDED_BY(mutex_); + int sessions_unready_ ABSL_GUARDED_BY(mutex_) = 0; + int sessions_ready_ ABSL_GUARDED_BY(mutex_) = 0; + + std::vector<rpc::BreakpointDefT> breakpoints_ ABSL_GUARDED_BY(mutex_); +}; + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_DEBUG_SERVICE_H_
diff --git a/rt/debug/debug_session.cc b/rt/debug/debug_session.cc new file mode 100644 index 0000000..43d8197 --- /dev/null +++ b/rt/debug/debug_session.cc
@@ -0,0 +1,49 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/debug/debug_session.h" + +#include "base/source_location.h" +#include "base/status.h" + +namespace iree { +namespace rt { +namespace debug { + +bool DebugSession::is_ready() const { + absl::MutexLock lock(&mutex_); + return ready_ == 0; +} + +Status DebugSession::OnReady() { + absl::MutexLock lock(&mutex_); + if (ready_ > 0) { + return FailedPreconditionErrorBuilder(IREE_LOC) + << "Session has already readied up"; + } + ++ready_; + VLOG(2) << "Session " << id() << ": ++ready = " << ready_; + return OkStatus(); +} + +Status DebugSession::OnUnready() { + absl::MutexLock lock(&mutex_); + --ready_; + VLOG(2) << "Session " << id() << ": --ready = " << ready_; + return OkStatus(); +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/rt/debug/debug_session.h b/rt/debug/debug_session.h new file mode 100644 index 0000000..3cdf903 --- /dev/null +++ b/rt/debug/debug_session.h
@@ -0,0 +1,93 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DEBUG_DEBUG_SESSION_H_ +#define IREE_RT_DEBUG_DEBUG_SESSION_H_ + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "base/status.h" +#include "rt/context.h" +#include "rt/invocation.h" +#include "rt/module.h" +#include "schemas/debug_service_generated.h" + +namespace iree { +namespace rt { +namespace debug { + +// An active debugging session maintained by the DebugService. +// Each connected client gets a session and transport-specific implementations +// use the event methods to receive signals from the service. +// +// All methods are called only while the debug lock is held and may be called +// from any thread. +class DebugSession { + public: + virtual ~DebugSession() = default; + + // Session ID used in all RPCs related to this session. + // This can be used for attributing RPCs to the originating session when + // multiple sessions may be active at a time/over the same transport. + int id() const { return session_id_; } + + // Returns true if the session has issued a MakeReady request and is ok if + // execution resumes. + bool is_ready() const; + + // Signals that the session has readied up and is now active. + // Called with the global debug lock held. + virtual Status OnReady(); + + // Signals that the session has gone unready (from an event/etc) and the + // service is now awaiting it to ready up. + // Called with the global debug lock held. + virtual Status OnUnready(); + + // Signals that a context has been registered. + // Called with the global debug lock held. + virtual Status OnContextRegistered(Context* context) = 0; + virtual Status OnContextUnregistered(Context* context) = 0; + + // Signals that a module has been loaded in a context. + // Called with the global debug lock held. + virtual Status OnModuleLoaded(Context* context, Module* module) = 0; + + // Signals that a invocation has been registered. + // Called with the global debug lock held. + virtual Status OnInvocationRegistered(Invocation* invocation) = 0; + virtual Status OnInvocationUnregistered(Invocation* invocation) = 0; + + // Signals that a breakpoint has been resolved to a particular function in a + // context. + // Called with the global debug lock held. + virtual Status OnBreakpointResolved(const rpc::BreakpointDefT& breakpoint, + Context* context) = 0; + + // Signals that the given breakpoint has been hit during execution. + // Called with the global debug lock held. + virtual Status OnBreakpointHit(int breakpoint_id, + const Invocation& invocation) = 0; + + private: + mutable absl::Mutex mutex_; + int session_id_ = 0; + int ready_ ABSL_GUARDED_BY(mutex_) = -1; +}; + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_DEBUG_SESSION_H_
diff --git a/rt/debug/debug_tcp_util.h b/rt/debug/debug_tcp_util.h new file mode 100644 index 0000000..67b6e31 --- /dev/null +++ b/rt/debug/debug_tcp_util.h
@@ -0,0 +1,217 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Utilities for working with TCP sockets. +// These are (mostly) portable to systems implementing BSD sockets. + +#ifndef IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_ +#define IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_ + +#include <fcntl.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include <cstddef> + +#include "base/status.h" +#include "flatbuffers/base.h" +#include "flatbuffers/flatbuffers.h" +#include "schemas/debug_service_generated.h" + +namespace iree { +namespace rt { +namespace debug { +namespace tcp { + +// Toggles address reuse on a socket. Call prior to binding. +// This is useful if a socket is sitting in close_wait from a previous process +// while a new one is trying to bind to it. +inline Status ToggleSocketAddressReuse(int fd, bool is_enabled) { + int toggle = is_enabled ? 1 : 0; + ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &toggle, sizeof(toggle)); + return OkStatus(); +} + +// Toggles the linger option on a socket. +// Enabling linger will ensure all data on the socket is sent (if it can be +// sent within N sec) prior to closing. Disabling linger will cause the socket +// to close gracefully. +inline Status ToggleSocketLinger(int fd, bool is_enabled) { + struct linger linger; + linger.l_onoff = is_enabled ? 1 : 0; + linger.l_linger = 1; + ::setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger, sizeof(linger)); + return OkStatus(); +} + +// Toggles Nagel's algorithm on a socket. +// Enabled by default, sockets have ~250ms delay for small packets. Disabling +// the algorithm will make socket flushes actually send data. +inline Status ToggleSocketNagelsAlgorithm(int fd, bool is_enabled) { + int toggle = is_enabled ? 1 : 0; + ::setsockopt(fd, SOL_TCP, TCP_NODELAY, &toggle, sizeof(toggle)); + return OkStatus(); +} + +// Toggles TCP keepalive on a socket. +// Assumes that the remote side is on the local machine/network and that we can +// spam it with packets. +// +// NOTE: we may want to adjust this when real debuggers are attached (to prevent +// dropping our own connections). Need to figure out how to reliably detect +// debug suspends vs. actual death. +inline Status ToggleSocketLocalKeepalive(int fd, bool is_enabled) { + // Toggle keepalive. + int keepalive_enable = is_enabled ? 1 : 0; + ::setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &keepalive_enable, + sizeof(keepalive_enable)); + // Begin sending keepalive probes after N sec. + int keepalive_idle_delay = 3; + ::setsockopt(fd, SOL_TCP, TCP_KEEPIDLE, &keepalive_idle_delay, + sizeof(keepalive_idle_delay)); + // Try one probe and bail (faster detection). + int keepalive_retry_count = 1; + ::setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &keepalive_retry_count, + sizeof(keepalive_retry_count)); + // Send keepalives every N sec. + int keepalive_interval = 1; + ::setsockopt(fd, SOL_TCP, TCP_KEEPINTVL, &keepalive_interval, + sizeof(keepalive_interval)); + return OkStatus(); +} + +// Toggles the blocking state of a socket. +// If a socket has been set to non-blocking methods like read and write will +// return EWOULDBLOCK if they would have blocked on the specific operation. +inline Status ToggleSocketBlocking(int fd, bool is_blocking) { + if (is_blocking) { + ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) & ~O_NONBLOCK); + } else { + ::fcntl(fd, F_SETFL, ::fcntl(fd, F_GETFL) | O_NONBLOCK); + } + return OkStatus(); +} + +// RAII wrapper for messages containing flatbuffer roots of type T. +template <typename T> +struct MessageBuffer { + public: + explicit MessageBuffer(std::vector<uint8_t> buffer) + : buffer_(std::move(buffer)) {} + MessageBuffer(const MessageBuffer&) = delete; + MessageBuffer& operator=(const MessageBuffer&) = delete; + MessageBuffer(MessageBuffer&&) = default; + MessageBuffer& operator=(MessageBuffer&&) = default; + + const T& GetRoot() const { + return *::flatbuffers::GetRoot<T>(buffer_.data()); + } + + private: + std::vector<uint8_t> buffer_; +}; + +// Reads a size prefix value from the given fd. +// If |poll_only| is true then the size prefix is not consumed from the stream +// and the call will return 0 if there is no size prefix available. +// Returns CancelledError if a (probably) graceful close is detected. +inline StatusOr<size_t> ReadSizePrefix(int fd, bool poll_only) { + ::flatbuffers::uoffset_t size_prefix = 0; + int read_bytes = ::recv(fd, &size_prefix, sizeof(size_prefix), + poll_only ? (MSG_PEEK | MSG_DONTWAIT) : 0); + if (read_bytes == 0) { + // Remote side disconnected. + return CancelledErrorBuilder(IREE_LOC) << "Graceful remote close"; + } else if (read_bytes < 0) { + if (errno == ECONNRESET) { + return CancelledErrorBuilder(IREE_LOC) << "Ungraceful remote close"; + } + return DataLossErrorBuilder(IREE_LOC) + << "Failed to read size prefix from socket: (" << errno << ") " + << ::strerror(errno); + } else if (read_bytes != sizeof(size_prefix)) { + if (poll_only) { + // No data available. + return 0; + } else { + return DataLossErrorBuilder(IREE_LOC) + << "Failed to read full size prefix (got " << read_bytes << "b of " + << sizeof(size_prefix) << "b expected)"; + } + } + return size_prefix; +} + +// Returns true if ReadBuffer will (likely) not block when called. +// Returns CancelledError if a (probably) graceful close is detected. +inline StatusOr<bool> CanReadBuffer(int fd) { + ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/true)); + return size_prefix != 0; +} + +// Reads a size-prefixed message from the given fd. +// This will block until the entire message contents are available. +// Returns a buffer reference that will deallocate the buffer automatically or +// CancelledError if a (probably) graceful close is detected. +template <typename T> +StatusOr<MessageBuffer<T>> ReadBuffer(int fd) { + // Read the size prefix (written as a uoffset_t by the Write* methods). + ASSIGN_OR_RETURN(size_t size_prefix, ReadSizePrefix(fd, /*poll_only=*/false)); + + // Allocate the buffer for the entire message. + // We'll use the BufferRef to free() it when it's no longer required. + std::vector<uint8_t> buffer(size_prefix); + + // Read the entire message contents. + int full_read_bytes = ::recv(fd, buffer.data(), buffer.size(), 0); + if (full_read_bytes < 0) { + return DataLossErrorBuilder(IREE_LOC) + << "Failed to read full message contents from socket: (" << errno + << ") " << ::strerror(errno); + } else if (full_read_bytes != buffer.size()) { + return DataLossErrorBuilder(IREE_LOC) + << "Failed to read full message contents (got " << full_read_bytes + << "b of " << buffer.size() << "b expected)"; + } + + // Verify the contents. Not strictly required (as we won't ever ship this to + // prod), but useful in ensuring our socket code isn't corrupting things. + ::flatbuffers::Verifier verifier(buffer.data(), buffer.size()); + if (!verifier.VerifyBuffer<T>()) { + return DataLossErrorBuilder(IREE_LOC) + << "Verification of input buffer of type " << typeid(T).name() + << " (" << buffer.size() << "b) failed"; + } + + // Wrap the buffer to get some RAII goodness. + return MessageBuffer<T>(std::move(buffer)); +} + +// Writes a buffer to the given fd. +inline Status WriteBuffer(int fd, ::flatbuffers::DetachedBuffer buffer) { + if (::send(fd, buffer.data(), buffer.size(), 0) < 0) { + return UnavailableErrorBuilder(IREE_LOC) + << "Write failed: (" << errno << ") " << ::strerror(errno); + } + return OkStatus(); +} + +} // namespace tcp +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DEBUG_DEBUG_TCP_UTIL_H_
diff --git a/rt/disassembler.h b/rt/disassembler.h new file mode 100644 index 0000000..cb3e0de --- /dev/null +++ b/rt/disassembler.h
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_DISASSEMBLER_H_ +#define IREE_RT_DISASSEMBLER_H_ + +#include <cstdint> +#include <ostream> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/status.h" +#include "rt/function.h" +#include "rt/source_location.h" + +namespace iree { +namespace rt { + +// A single disassembled instruction. +struct Instruction { + // Offset of the instruction within the function. + // The meaning of this is backend-dependent. + SourceOffset offset; + + // The first line of |long_text|. + absl::string_view short_text; + + // Human-readable text of the instruction. May contain multiple lines. + std::string long_text; +}; + +// Disassembles functions into instructions. +// +// Thread-safe. +class Disassembler { + public: + virtual ~Disassembler() = default; + + // Disassembles one or more instructions within the given function based on + // source offsets. + virtual StatusOr<std::vector<Instruction>> DisassembleInstructions( + const Function& function, SourceOffset offset, + int32_t instruction_count = INT32_MAX) const = 0; + + protected: + Disassembler() = default; +}; + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_DISASSEMBLER_H_
diff --git a/rt/function.cc b/rt/function.cc new file mode 100644 index 0000000..b3ec339 --- /dev/null +++ b/rt/function.cc
@@ -0,0 +1,33 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/function.h" + +#include "rt/module.h" + +namespace iree { +namespace rt { + +absl::string_view Function::name() const { + auto result_or = module_->GetFunctionName(linkage_, ordinal_); + return result_or.ok() ? result_or.ValueOrDie() : absl::string_view(); +} + +const FunctionSignature Function::signature() const { + auto result_or = module_->GetFunctionSignature(linkage_, ordinal_); + return result_or.ok() ? result_or.ValueOrDie() : FunctionSignature(); +} + +} // namespace rt +} // namespace iree
diff --git a/rt/function.h b/rt/function.h new file mode 100644 index 0000000..9c3a2ca --- /dev/null +++ b/rt/function.h
@@ -0,0 +1,84 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_FUNCTION_H_ +#define IREE_RT_FUNCTION_H_ + +#include <cstdint> +#include <ostream> + +#include "absl/strings/string_view.h" +#include "rt/function_signature.h" + +namespace iree { +namespace rt { + +class Module; + +// Reference to a function within a module. +// Functions are either visible or hidden from the module interface and may be +// of one Linkage type. Imports and exports are always visible (as they are +// required for dynamic linking) however functions with internal linkage may be +// hidden in optimized builds to reduce the amount of reflection metadata +// required. +class Function final { + public: + enum class Linkage { + // Function is internal to the module and may not be reflectable. + kInternal = 0, + // Function is an import from another module. + kImport = 1, + // Function is an export from the module. + kExport = 2, + }; + + Function() = default; + Function(const Module* module, Linkage linkage, int32_t ordinal) + : module_(module), linkage_(linkage), ordinal_(ordinal) {} + + // Module the function is contained within. + const Module* module() const { return module_; } + + // Linkage of the function. Note that Linkage::kInternal functions may be + // missing reflection information. + Linkage linkage() const { return linkage_; } + + // Ordinal within the module in the linkage scope. + int32_t ordinal() const { return ordinal_; } + + // Returns the original name of the function. + // Internal functions may return empty if debugging info has been stripped. + absl::string_view name() const; + + // Returns the signature of the function. + // Always present for imports and exports but may be empty for internal + // functions if debugging info has been stripped. + const FunctionSignature signature() const; + + private: + const Module* module_ = nullptr; + Linkage linkage_ = Linkage::kInternal; + int32_t ordinal_ = -1; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const Function& function) { + stream << '@' << function.name() << '#' << function.ordinal(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_FUNCTION_H_
diff --git a/iree/rt/function_signature.h b/rt/function_signature.h similarity index 100% rename from iree/rt/function_signature.h rename to rt/function_signature.h
diff --git a/rt/instance.cc b/rt/instance.cc new file mode 100644 index 0000000..3ad8b35 --- /dev/null +++ b/rt/instance.cc
@@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/instance.h" + +#include "base/tracing.h" +#include "rt/debug/debug_server.h" + +namespace iree { +namespace rt { + +Instance::Instance(std::unique_ptr<debug::DebugServer> debug_server) + : debug_server_(std::move(debug_server)) { + IREE_TRACE_SCOPE0("Instance::ctor"); +} + +Instance::~Instance() { IREE_TRACE_SCOPE0("Instance::dtor"); } + +void Instance::RegisterContext(Context* context) { + { + absl::MutexLock lock(&contexts_mutex_); + contexts_.push_back(context); + } + if (debug_server_) { + CHECK_OK(debug_server_->RegisterContext(context)); + } +} + +void Instance::UnregisterContext(Context* context) { + if (debug_server_) { + CHECK_OK(debug_server_->UnregisterContext(context)); + } + { + absl::MutexLock lock(&contexts_mutex_); + contexts_.erase(context); + } +} + +} // namespace rt +} // namespace iree
diff --git a/rt/instance.h b/rt/instance.h new file mode 100644 index 0000000..6578ca4 --- /dev/null +++ b/rt/instance.h
@@ -0,0 +1,70 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_INSTANCE_H_ +#define IREE_RT_INSTANCE_H_ + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "base/intrusive_list.h" +#include "base/ref_ptr.h" +#include "hal/device_manager.h" +#include "rt/context.h" +#include "rt/debug/debug_server.h" + +namespace iree { +namespace rt { + +// Shared runtime instance responsible for routing Context events, enumerating +// and creating hardware device interfaces, and managing thread pools. +// +// A single runtime instance can service multiple contexts and hosting +// applications should try to reuse instances as much as possible. This ensures +// that resource allocation across contexts is handled and extraneous device +// interaction is avoided. For devices that may have exclusive access +// restrictions it is mandatory to share instances, so plan accordingly. +// +// Thread-safe. +class Instance final : public RefObject<Instance> { + public: + // Creates an instance with an optional attached |debug_server|. + Instance() : Instance(nullptr) {} + explicit Instance(std::unique_ptr<debug::DebugServer> debug_server); + ~Instance(); + Instance(const Instance&) = delete; + Instance& operator=(const Instance&) = delete; + + // Optional debug server that has access to contexts in this instance. + debug::DebugServer* debug_server() const { return debug_server_.get(); } + + // Device manager used to enumerate available devices. + hal::DeviceManager* device_manager() const { return &device_manager_; } + + private: + friend class Context; + void RegisterContext(Context* context); + void UnregisterContext(Context* context); + + std::unique_ptr<debug::DebugServer> debug_server_; + mutable hal::DeviceManager device_manager_; + + absl::Mutex contexts_mutex_; + IntrusiveList<Context, offsetof(Context, instance_list_link_)> contexts_ + ABSL_GUARDED_BY(contexts_mutex_); +}; + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_INSTANCE_H_
diff --git a/rt/invocation.cc b/rt/invocation.cc new file mode 100644 index 0000000..004c7bf --- /dev/null +++ b/rt/invocation.cc
@@ -0,0 +1,185 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/invocation.h" + +#include <atomic> +#include <iterator> + +#include "absl/strings/str_cat.h" +#include "base/status.h" +#include "base/tracing.h" +#include "rt/context.h" + +namespace iree { +namespace rt { + +namespace { + +int32_t NextUniqueInvocationId() { + static std::atomic<int32_t> next_id = {0}; + return ++next_id; +} + +} // namespace + +// static +StatusOr<ref_ptr<Invocation>> Invocation::Create( + ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, + absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::optional<absl::InlinedVector<hal::BufferView, 8>> results) { + IREE_TRACE_SCOPE0("Invocation::Create"); + + const auto& signature = function.signature(); + if (arguments.size() != signature.argument_count()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Argument count mismatch; expected " << signature.argument_count() + << " but received " << arguments.size(); + } else if (results.has_value() && + results.value().size() != signature.result_count()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Result count mismatch; expected " << signature.result_count() + << " but received " << results.value().size(); + } + + absl::InlinedVector<hal::BufferView, 8> results_value; + if (results.has_value()) { + results_value = std::move(results.value()); + } else { + results_value.resize(signature.result_count()); + } + + auto invocation = assign_ref( + new Invocation(std::move(context), function, std::move(policy))); + + // TODO(benvanik): grab execution state, insert deps, etc. + if (!dependencies.empty()) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Dependencies are not yet implemented"; + } + + // TODO(benvanik): fiber scheduling and such. + auto execute_status = function.module()->Execute( + &invocation->stack_, function, std::move(arguments), &results_value); + if (execute_status.ok()) { + invocation->CompleteSuccess(std::move(results_value)); + } else { + invocation->CompleteFailure(std::move(execute_status), nullptr); + } + + return invocation; +} + +// static +StatusOr<ref_ptr<Invocation>> Invocation::Create( + ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, + absl::Span<const ref_ptr<Invocation>> dependencies, + absl::Span<const hal::BufferView> arguments) { + absl::InlinedVector<ref_ptr<Invocation>, 4> dependency_list; + dependency_list.reserve(dependencies.size()); + for (auto& dependency : dependencies) { + dependency_list.push_back(add_ref(dependency)); + } + absl::InlinedVector<hal::BufferView, 8> argument_list; + argument_list.reserve(arguments.size()); + for (auto& buffer_view : arguments) { + argument_list.push_back(buffer_view); + } + return Invocation::Create(std::move(context), function, std::move(policy), + std::move(dependency_list), + std::move(argument_list)); +} + +Invocation::Invocation(ref_ptr<Context> context, const Function function, + ref_ptr<Policy> policy) + : id_(NextUniqueInvocationId()), + context_(std::move(context)), + function_(function), + policy_(std::move(policy)), + stack_(context_.get()) { + IREE_TRACE_SCOPE0("Invocation::ctor"); + context_->RegisterInvocation(this); +} + +Invocation::~Invocation() { + IREE_TRACE_SCOPE0("Invocation::dtor"); + context_->UnregisterInvocation(this); +} + +std::string Invocation::DebugStringShort() const { + return absl::StrCat("invocation_", id_); +} + +std::string Invocation::DebugString() const { return DebugStringShort(); } + +Status Invocation::QueryStatus() { + IREE_TRACE_SCOPE0("Invocation::QueryStatus"); + absl::MutexLock lock(&status_mutex_); + return completion_status_; +} + +StatusOr<absl::InlinedVector<hal::BufferView, 8>> Invocation::ConsumeResults() { + IREE_TRACE_SCOPE0("Invocation::ConsumeResults"); + absl::MutexLock lock(&status_mutex_); + if (!completion_status_.ok()) { + return completion_status_; + } + return std::move(results_); +} + +Status Invocation::Await(absl::Time deadline) { + IREE_TRACE_SCOPE0("Invocation::Await"); + absl::MutexLock lock(&status_mutex_); + // TODO(benvanik): implement async invocation behavior. + return completion_status_; +} + +Status Invocation::Abort() { + IREE_TRACE_SCOPE0("Invocation::Abort"); + // TODO(benvanik): implement async invocation behavior. + return UnimplementedErrorBuilder(IREE_LOC) + << "Async invocations not yet implemented"; +} + +void Invocation::CompleteSuccess( + absl::InlinedVector<hal::BufferView, 8> results) { + IREE_TRACE_SCOPE0("Invocation::CompleteSuccess"); + absl::MutexLock lock(&status_mutex_); + if (IsAborted(completion_status_)) { + // Ignore as the invocation was already aborted prior to completion. + return; + } + DCHECK(IsUnavailable(completion_status_)); + completion_status_ = OkStatus(); + failure_stack_trace_.reset(); + results_ = std::move(results); +} + +void Invocation::CompleteFailure( + Status completion_status, std::unique_ptr<StackTrace> failure_stack_trace) { + IREE_TRACE_SCOPE0("Invocation::CompleteFailure"); + absl::MutexLock lock(&status_mutex_); + if (IsAborted(completion_status_)) { + // Ignore as the invocation was already aborted prior to completion. + return; + } + DCHECK(IsUnavailable(completion_status_)); + completion_status_ = std::move(completion_status); + failure_stack_trace_ = std::move(failure_stack_trace); + results_.clear(); +} + +} // namespace rt +} // namespace iree
diff --git a/rt/invocation.h b/rt/invocation.h new file mode 100644 index 0000000..3226ea3 --- /dev/null +++ b/rt/invocation.h
@@ -0,0 +1,157 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_INVOCATION_H_ +#define IREE_RT_INVOCATION_H_ + +#include <ostream> +#include <string> + +#include "absl/base/thread_annotations.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "base/intrusive_list.h" +#include "base/ref_ptr.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "rt/function.h" +#include "rt/policy.h" +#include "rt/stack.h" +#include "rt/stack_trace.h" + +namespace iree { +namespace rt { + +class Context; + +// An asynchronous invocation of a function. +// Holds the invocation state and allows querying and waiting on completion. +// Invocations are conceptually fibers and may suspend and resume execution +// several times before completing. +// +// Thread-safe. +class Invocation final : public RefObject<Invocation> { + public: + // TODO(benvanik): define error propagation semantics across dependencies. + // TODO(benvanik): support more dependency types (semaphores, etc). + // Creates a new invocation tracking object for invoking the given |function| + // from |context|. |arguments| will be retained until the invocation is made. + // If |dependencies| are provided then the invocation will wait until they are + // resolved before executing. If a |policy| is provided it will override the + // context-level policy. + // + // Optionally |results| may be provided with preallocated buffers that will + // receive the outputs of the invocation. Invocation will fail if they do not + // match expected sizes. + // + // Note that it's possible for the invocation to complete prior to the return + // of this function. Any errors that occur will be set on the invocation and + // callers should query its state prior to assuming it is in-flight. + static StatusOr<ref_ptr<Invocation>> Create( + ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, + absl::InlinedVector<ref_ptr<Invocation>, 4> dependencies, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::optional<absl::InlinedVector<hal::BufferView, 8>> results = + absl::nullopt); + static StatusOr<ref_ptr<Invocation>> Create( + ref_ptr<Context> context, const Function function, ref_ptr<Policy> policy, + absl::Span<const ref_ptr<Invocation>> dependencies, + absl::Span<const hal::BufferView> arguments); + + ~Invocation(); + + // A process-unique ID for the invocation. + int32_t id() const { return id_; } + + // Context this invocation is running within. + const ref_ptr<Context>& context() const { return context_; } + + // Function being invoked. + const Function& function() const { return function_; } + + // A single-line human-readable debug string for the invocation. + std::string DebugStringShort() const; + + // A long-form debug string with stack trace (if available). + std::string DebugString() const; + + // Queries the completion status of the invocation. + // Returns one of the following: + // StatusCode::kOk: the invocation completed successfully. + // StatusCode::kUnavailable: the invocation has not yet completed. + // StatusCode::kCancelled: the invocation was cancelled internally. + // StatusCode::kAborted: the invocation was aborted. + // StatusCode::*: an error occurred during invocation. + Status QueryStatus(); + + // Returns ownership of the results of the operation to the caller. + // If the invocation failed then the result will be returned as if Query had + // been called. + StatusOr<absl::InlinedVector<hal::BufferView, 8>> ConsumeResults(); + + // Blocks the caller until the invocation completes (successfully or + // otherwise). + // + // Returns StatusCode::kDeadlineExceeded if |deadline| elapses before the + // invocation completes and otherwise returns the result of Query(). + Status Await(absl::Time deadline); + + // Attempts to abort the invocation if it is in-flight. + // A no-op if the invocation has already completed. + Status Abort(); + + // TODO(benvanik): export a hal::TimelineSemaphore. + + private: + friend class Context; + + Invocation(ref_ptr<Context> context, const Function function, + ref_ptr<Policy> policy); + + // Completes the invocation with a successful result. + void CompleteSuccess(absl::InlinedVector<hal::BufferView, 8> results); + + // Completes the invocation with a failure, including an optional stack trace. + void CompleteFailure(Status completion_status, + std::unique_ptr<StackTrace> failure_stack_trace); + + int32_t id_; + ref_ptr<Context> context_; + const Function function_; + ref_ptr<Policy> policy_; + + Stack stack_; + + absl::Mutex status_mutex_; + Status completion_status_ ABSL_GUARDED_BY(status_mutex_) = + UnavailableErrorBuilder(IREE_LOC); + std::unique_ptr<StackTrace> failure_stack_trace_ + ABSL_GUARDED_BY(status_mutex_); + absl::InlinedVector<hal::BufferView, 8> results_ + ABSL_GUARDED_BY(status_mutex_); + + friend class Context; + IntrusiveListLink context_list_link_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const Invocation& invocation) { + stream << invocation.DebugStringShort(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_INVOCATION_H_
diff --git a/rt/module.h b/rt/module.h new file mode 100644 index 0000000..777d454 --- /dev/null +++ b/rt/module.h
@@ -0,0 +1,109 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_MODULE_H_ +#define IREE_RT_MODULE_H_ + +#include <ostream> + +#include "absl/container/inlined_vector.h" +#include "absl/strings/string_view.h" +#include "base/ref_ptr.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "rt/function.h" +#include "rt/module_signature.h" + +namespace iree { +namespace rt { + +class Disassembler; +class SourceResolver; +class Stack; + +// Abstract compiled module interface for resolving functions. +// +// Modules are (generally) stateless, immutable, and may exist in multiple +// contexts at the same time. +class Module : public RefObject<Module> { + public: + virtual ~Module() = default; + + // Name of the module used to resolve fully-qualified references. + // The lifetime of the returned reference is not guaranteed beyond the current + // calling scope and callers must clone it if they want to retain it. + virtual absl::string_view name() const = 0; + + // A description of the module imports, exports, and other metadata. + virtual const ModuleSignature signature() const = 0; + + // Returns a resolver capable of resolving functions to source and performing + // basic debugging logic (such as offset calculation). + // May be nullptr if debugging info has been stripped. + virtual SourceResolver* source_resolver() const = 0; + + // Returns a disassembler that can be used to disassemble functions in the + // module. May be nullptr if debugging info has been stripped or disassembly + // has been disabled as a compile option. + virtual Disassembler* disassembler() const = 0; + + // A short human-readable string that matches the compiler formatting. + virtual std::string DebugStringShort() const = 0; + + // Looks up a visible function by ordinal. + // Internal functions may not be found if debugging info has been stripped. + virtual StatusOr<const Function> LookupFunctionByOrdinal( + Function::Linkage linkage, int32_t ordinal) const = 0; + + // Looks up a visible function by name. + // Internal functions may not be found if debugging info has been stripped. + virtual StatusOr<const Function> LookupFunctionByName( + Function::Linkage linkage, absl::string_view name) const = 0; + + // Returns the name of the visible function as a string reference. + // + // May return empty for functions with internal linkage if debugging info has + // been stripped. + // + // The lifetime of the returned reference is not guaranteed beyond the current + // calling scope and callers must clone it if they want to retain it. + virtual StatusOr<absl::string_view> GetFunctionName( + Function::Linkage linkage, int32_t ordinal) const = 0; + + // Returns the full function signature for the given |ordinal|. + // + // May return empty for functions with internal linkage if the debugging info + // has been stripped. + virtual StatusOr<const FunctionSignature> GetFunctionSignature( + Function::Linkage linkage, int32_t ordinal) const = 0; + + // Temporary until scheduler is built. + virtual Status Execute( + Stack* stack, const Function function, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::InlinedVector<hal::BufferView, 8>* results) const = 0; + + protected: + Module() = default; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Module& module) { + stream << module.DebugStringShort(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_MODULE_H_
diff --git a/rt/module_printer.cc b/rt/module_printer.cc new file mode 100644 index 0000000..a5ff3e6 --- /dev/null +++ b/rt/module_printer.cc
@@ -0,0 +1,65 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/module_printer.h" + +#include <iomanip> + +#include "rt/disassembler.h" +#include "rt/source_resolver.h" + +namespace iree { +namespace rt { + +Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags, + std::ostream* stream) { + *stream << "Imports:\n"; + for (int i = 0; i < module.signature().import_function_count(); ++i) { + ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal( + Function::Linkage::kImport, i)); + *stream << " " << i << ": " << function << "\n"; + } + *stream << "Exports:\n"; + for (int i = 0; i < module.signature().export_function_count(); ++i) { + ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal( + Function::Linkage::kExport, i)); + *stream << " " << i << ": " << function << "\n"; + } + if (module.signature().internal_function_count()) { + *stream << "Internal:\n"; + auto* disassembler = module.disassembler(); + for (int i = 0; i < module.signature().internal_function_count(); ++i) { + ASSIGN_OR_RETURN(auto function, module.LookupFunctionByOrdinal( + Function::Linkage::kInternal, i)); + *stream << " " << i << ": " << function << "\n"; + if (disassembler && AllBitsSet(flags, PrintModuleFlag::kDisassemble)) { + auto instructions_or = + disassembler->DisassembleInstructions(function, 0); + if (IsUnavailable(instructions_or.status())) continue; + for (const auto& instruction : instructions_or.ValueOrDie()) { + *stream << " " << std::setw(6) << instruction.offset << ": " + << instruction.long_text << "\n"; + } + } + } + } + return OkStatus(); +} + +Status PrintModuleToStream(const Module& module, std::ostream* stream) { + return PrintModuleToStream(module, PrintModuleFlag::kNone, stream); +} + +} // namespace rt +} // namespace iree
diff --git a/rt/module_printer.h b/rt/module_printer.h new file mode 100644 index 0000000..48abcff --- /dev/null +++ b/rt/module_printer.h
@@ -0,0 +1,42 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_MODULE_PRINTER_H_ +#define IREE_RT_MODULE_PRINTER_H_ + +#include <ostream> + +#include "base/bitfield.h" +#include "base/status.h" +#include "rt/module.h" + +namespace iree { +namespace rt { + +enum class PrintModuleFlag { + kNone = 0, + kDisassemble = 1 << 0, +}; +IREE_BITFIELD(PrintModuleFlag); +using PrintModuleFlagBitfield = PrintModuleFlag; + +// Prints all functions within the module to the given |stream|. +Status PrintModuleToStream(const Module& module, std::ostream* stream); +Status PrintModuleToStream(const Module& module, PrintModuleFlagBitfield flags, + std::ostream* stream); + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_MODULE_PRINTER_H_
diff --git a/iree/rt/module_signature.h b/rt/module_signature.h similarity index 100% rename from iree/rt/module_signature.h rename to rt/module_signature.h
diff --git a/rt/policy.h b/rt/policy.h new file mode 100644 index 0000000..0c3c8f9 --- /dev/null +++ b/rt/policy.h
@@ -0,0 +1,43 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_POLICY_H_ +#define IREE_RT_POLICY_H_ + +#include "base/ref_ptr.h" + +namespace iree { +namespace rt { + +// Defines how invocation scheduling is to be performed. +// The policy instance is used by the scheduler to determine when submissions +// should be flushed to target queues. +// +// Thread-safe; the policy may be evaluated from arbitrary threads after an +// invocation has began processing. +class Policy : public RefObject<Policy> { + public: + virtual ~Policy() = default; + + // TODO(benvanik): constraints: + // - max memory usage + // - max delay + // - max in-flight items/etc + // - allowed device types +}; + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_POLICY_H_
diff --git a/rt/source_location.cc b/rt/source_location.cc new file mode 100644 index 0000000..0cf035b --- /dev/null +++ b/rt/source_location.cc
@@ -0,0 +1,32 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/source_location.h" + +#include <sstream> + +#include "rt/source_resolver.h" + +namespace iree { +namespace rt { + +std::string SourceLocation::DebugStringShort() const { + if (is_unknown()) return "(unknown)"; + std::ostringstream stream; + resolver_->PrintSourceLocation(resolver_args_, &stream); + return stream.str(); +} + +} // namespace rt +} // namespace iree
diff --git a/iree/rt/source_location.h b/rt/source_location.h similarity index 100% rename from iree/rt/source_location.h rename to rt/source_location.h
diff --git a/rt/source_resolver.h b/rt/source_resolver.h new file mode 100644 index 0000000..3b27aaf --- /dev/null +++ b/rt/source_resolver.h
@@ -0,0 +1,67 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_SOURCE_RESOLVER_H_ +#define IREE_RT_SOURCE_RESOLVER_H_ + +#include <cstdint> +#include <ostream> +#include <vector> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/status.h" +#include "rt/function.h" +#include "rt/source_location.h" + +namespace iree { +namespace rt { + +// Resolves offsets within functions to SourceLocations and provides source +// language services. +// +// Thread-safe. +class SourceResolver { + public: + virtual ~SourceResolver() = default; + + // Resolves a function-relative offset to a source location. + // Not all offsets within a function may have source mapping information. + virtual absl::optional<SourceLocation> ResolveFunctionOffset( + const Function& function, SourceOffset offset) = 0; + + // Converts a source location to a human-readable string, commonly in a single + // line denoting an original source file location (such as path:line:col). + virtual void PrintSourceLocation(SourceResolverArgs resolver_args, + std::ostream* stream) const = 0; + + // TODO(benvanik): query local variable names. + + // TODO(benvanik): step target calculation (relative mapping). + // TODO(benvanik): step target based on SourceLocation delta. + + // TODO(benvanik): expression evaluator? (setting variables) + + protected: + friend class SourceLocation; + + SourceResolver() = default; + + // TODO(benvanik): get line mapping information. +}; + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_SOURCE_RESOLVER_H_
diff --git a/rt/stack.cc b/rt/stack.cc new file mode 100644 index 0000000..6e0cd13 --- /dev/null +++ b/rt/stack.cc
@@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/stack.h" + +#include <iterator> + +#include "absl/strings/str_join.h" +#include "base/status.h" + +namespace iree { +namespace rt { + +constexpr int Stack::kMaxStackDepth; + +Stack::Stack(Context* context) : context_(context) {} + +Stack::~Stack() = default; + +StatusOr<StackFrame*> Stack::PushFrame(Function function) { + if (stack_depth_ + 1 > kMaxStackDepth) { + return InternalErrorBuilder(IREE_LOC) + << "Max stack depth of " << kMaxStackDepth << " exceeded"; + } + frames_[stack_depth_++] = StackFrame(function); + + // TODO(benvanik): WTF scope enter. + + return current_frame(); +} + +Status Stack::PopFrame() { + if (stack_depth_ == 0) { + return InternalErrorBuilder(IREE_LOC) << "Unbalanced stack pop"; + } + + // TODO(benvanik): WTF scope leave. + + --stack_depth_; + frames_[stack_depth_] = {}; + return OkStatus(); +} + +std::string Stack::DebugString() const { + return absl::StrJoin(frames(), "\n", StackFrameFormatter()); +} + +} // namespace rt +} // namespace iree
diff --git a/rt/stack.h b/rt/stack.h new file mode 100644 index 0000000..34a2c8d --- /dev/null +++ b/rt/stack.h
@@ -0,0 +1,85 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_STACK_H_ +#define IREE_RT_STACK_H_ + +#include <functional> + +#include "absl/types/span.h" +#include "base/status.h" +#include "rt/stack_frame.h" + +namespace iree { +namespace rt { + +class Context; + +// A runtime call stack for managing stack frames. +// The frames within a stack may be from different backends and may provide +// varying levels of information based on capabilities. +// +// Thread-compatible. Do not attempt to investigate a stack while another thread +// may be mutating it! +class Stack final { + public: + static constexpr int kMaxStackDepth = 32; + + explicit Stack(Context* context); + Stack(const Stack&) = delete; + Stack& operator=(const Stack&) = delete; + ~Stack(); + + // Context defining the module and global workspaces. + Context* context() const { return context_; } + + // All stack frames within the stack. + absl::Span<StackFrame> frames() { + return absl::MakeSpan(frames_).subspan(0, stack_depth_); + } + absl::Span<const StackFrame> frames() const { + return absl::MakeConstSpan(frames_).subspan(0, stack_depth_); + } + + // The current stack frame. + StackFrame* current_frame() { + return stack_depth_ > 0 ? &frames_[stack_depth_ - 1] : nullptr; + } + + // The stack frame of the caller of the current function. + StackFrame* caller_frame() { + return stack_depth_ > 1 ? &frames_[stack_depth_ - 2] : nullptr; + } + + StatusOr<StackFrame*> PushFrame(Function function); + Status PopFrame(); + + // Returns a full stack frame listing in human-readable form. + std::string DebugString() const; + + private: + Context* context_ = nullptr; + std::array<StackFrame, kMaxStackDepth> frames_; + int stack_depth_ = 0; +}; + +inline std::ostream& operator<<(std::ostream& stream, const Stack& stack) { + stream << stack.DebugString(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_STACK_H_
diff --git a/rt/stack_frame.cc b/rt/stack_frame.cc new file mode 100644 index 0000000..6faf9e7 --- /dev/null +++ b/rt/stack_frame.cc
@@ -0,0 +1,34 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/stack_frame.h" + +#include "absl/strings/str_cat.h" +#include "rt/source_resolver.h" + +namespace iree { +namespace rt { + +absl::optional<SourceLocation> StackFrame::source_location() const { + auto* source_resolver = function_.module()->source_resolver(); + if (!source_resolver) return absl::nullopt; + return source_resolver->ResolveFunctionOffset(function_, offset_); +} + +std::string StackFrame::DebugStringShort() const { + return absl::StrCat(module().name(), ":", function().name(), "@", offset()); +} + +} // namespace rt +} // namespace iree
diff --git a/rt/stack_frame.h b/rt/stack_frame.h new file mode 100644 index 0000000..f2014b8 --- /dev/null +++ b/rt/stack_frame.h
@@ -0,0 +1,106 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_STACK_FRAME_H_ +#define IREE_RT_STACK_FRAME_H_ + +#include <ostream> + +#include "absl/types/span.h" +#include "rt/function.h" +#include "rt/module.h" +#include "rt/source_location.h" + +namespace iree { +namespace rt { + +// TODO(benvanik): allocate in-place from an arena. +// Register table used within a stack frame. +struct Registers { + std::vector<hal::BufferView> buffer_views; +}; + +// A single frame on the call stack containing current execution state and +// register values. +// +// As different backends may support different features this interface exposes +// only the things we want to view in our debugger/stack dumps. This allows us +// to ignore the actual implementation (bytecode VM, compiled C code, etc) so +// long as it can respond to queries for register values. This has the benefit +// of keeping the actual frame very lightweight as we are not storing the values +// but instead just routing to the real storage via indirection. If the debugger +// is not attached and no errors are hit then no additional bookkeeping is done. +// +// Thread-compatible, as is the owning Stack/StackTrace. +class StackFrame final { + public: + StackFrame() = default; + explicit StackFrame(Function function) : function_(function) {} + StackFrame(Function function, SourceOffset offset, Registers registers) + : function_(function), + offset_(offset), + registers_(std::move(registers)) {} + StackFrame(const StackFrame&) = delete; + StackFrame& operator=(const StackFrame&) = delete; + StackFrame(StackFrame&&) = default; + StackFrame& operator=(StackFrame&&) = default; + + // Module that owns the function this stack frame represents. + const Module& module() const { return *function_.module(); } + + // Function the stack frame represents. + const Function& function() const { return function_; } + + // Current virtual offset within the function. + // The exact meaning of the offset is backend dependent and callers should + // treat them as opaque and must use the SourceResolver to compute new + // offsets (such as 'next offset'). + SourceOffset offset() const { return offset_; } + SourceOffset* mutable_offset() { return &offset_; } + + // Returns a source location, if available, for the current offset within the + // target function. + absl::optional<SourceLocation> source_location() const; + + // Registers used within the stack frame. + // Storage is implementation-defined and is valid only for the lifetime of the + // frame. + const Registers& registers() const { return registers_; } + Registers* mutable_registers() { return ®isters_; } + + // A short human-readable string for the frame; a single line. + std::string DebugStringShort() const; + + private: + Function function_; + SourceOffset offset_ = 0; + Registers registers_; +}; + +struct StackFrameFormatter { + void operator()(std::string* out, const StackFrame& stack_frame) const { + out->append(stack_frame.DebugStringShort()); + } +}; + +inline std::ostream& operator<<(std::ostream& stream, + const StackFrame& stack_frame) { + stream << stack_frame.DebugStringShort(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_STACK_FRAME_H_
diff --git a/rt/stack_trace.cc b/rt/stack_trace.cc new file mode 100644 index 0000000..a3ad17f --- /dev/null +++ b/rt/stack_trace.cc
@@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rt/stack_trace.h" + +#include "absl/strings/str_join.h" +#include "base/status.h" + +namespace iree { +namespace rt { + +std::string StackTrace::DebugString() const { + return absl::StrJoin(frames_, "\n", StackFrameFormatter()); +} + +} // namespace rt +} // namespace iree
diff --git a/rt/stack_trace.h b/rt/stack_trace.h new file mode 100644 index 0000000..56d0d47 --- /dev/null +++ b/rt/stack_trace.h
@@ -0,0 +1,65 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_RT_STACK_TRACE_H_ +#define IREE_RT_STACK_TRACE_H_ + +#include <ostream> +#include <vector> + +#include "absl/types/span.h" +#include "base/status.h" +#include "rt/stack_frame.h" + +namespace iree { +namespace rt { + +// A snapshot of a stack at a point in time. +// The frames within a stack may be from different backends and may provide +// varying levels of information based on capabilities. +// +// Depending on the capture options the trace may contain references to register +// values (such as buffers) from the time of capture. If the buffers were +// modified after the capture was taken those results will be reflected! +class StackTrace final { + public: + StackTrace() = default; + explicit StackTrace(std::vector<StackFrame> frames) + : frames_(std::move(frames)) {} + StackTrace(const StackTrace&) = delete; + StackTrace& operator=(const StackTrace&) = delete; + ~StackTrace() = default; + + // All stack frames within the stack. + absl::Span<const StackFrame> frames() const { + return absl::MakeConstSpan(frames_); + } + + // Returns a full stack frame listing in human-readable form. + std::string DebugString() const; + + private: + std::vector<StackFrame> frames_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const StackTrace& stack_trace) { + stream << stack_trace.DebugString(); + return stream; +} + +} // namespace rt +} // namespace iree + +#endif // IREE_RT_STACK_TRACE_H_
diff --git a/iree/samples/CMakeLists.txt b/samples/CMakeLists.txt similarity index 100% rename from iree/samples/CMakeLists.txt rename to samples/CMakeLists.txt
diff --git a/samples/hal/BUILD b/samples/hal/BUILD new file mode 100644 index 0000000..0976831 --- /dev/null +++ b/samples/hal/BUILD
@@ -0,0 +1,48 @@ +# Samples demonstrating use of the HAL API. +# These do not rely on higher layers of the system (such as the VM or runtime). + +load("//:build_defs.google.bzl", "PLATFORM_VULKAN_TEST_DEPS") +load("///tools:compilation.bzl", "iree_bytecode_module") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_bytecode_module( + name = "simple_compute_test_module", + srcs = ["simple_compute_test.mlir"], + cc_namespace = "iree::hal::samples", +) + +cc_test( + name = "simple_compute_test", + srcs = ["simple_compute_test.cc"], + data = [ + # When building with --config=asan you must specify the following + # envvar when using Vulkan + a local Nvidia GPU: + # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt + "///tools:sanitizer_suppressions.txt", + ], + deps = [ + ":simple_compute_test_module_cc", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "///base:flatbuffer_util", + "///base:status_matchers", + "///hal:command_buffer", + "///hal:command_queue", + "///hal:driver_registry", + "///schemas", + "///base:status", + + # These are the drivers we support running with and can produce + # executables for from the source MLIR. + "///hal/interpreter:interpreter_driver_module", # build-cleaner: keep + "///hal/vulkan:vulkan_driver_module", # build-cleaner: keep + + # TODO(b/142004903): enable when Dawn HAL implementation is functional + # "///hal/dawn:dawn_driver_module", # build-cleaner: keep + ] + PLATFORM_VULKAN_TEST_DEPS, +)
diff --git a/iree/samples/hal/CMakeLists.txt b/samples/hal/CMakeLists.txt similarity index 100% rename from iree/samples/hal/CMakeLists.txt rename to samples/hal/CMakeLists.txt
diff --git a/samples/hal/simple_compute_test.cc b/samples/hal/simple_compute_test.cc new file mode 100644 index 0000000..1c34905 --- /dev/null +++ b/samples/hal/simple_compute_test.cc
@@ -0,0 +1,221 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A simple backend-agnostic compute test for the HAL API. +// This will load an IREE module containing one or more executables and attempt +// to run them against all registered driver backends. +// +// The input file, simple_compute_test.mlir, is as generic as possible to ensure +// we don't need too many variants. This means that it does not use any FFI +// imports requiring runtime support, uses floats exclusively (as that's assumed +// available everywhere), etc. +// +// The `iree_bytecode_module` build rule is used to translate the MLIR to the +// module flatbuffer. Additional target support can be defined there. + +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_replace.h" +#include "absl/time/time.h" +#include "base/flatbuffer_util.h" +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/command_buffer.h" +#include "hal/command_queue.h" +#include "hal/driver_registry.h" +#include "samples/hal/simple_compute_test_module.h" +#include "schemas/module_def_generated.h" + +namespace iree { +namespace hal { +namespace samples { +namespace { + +using ModuleFile = FlatBufferFile<ModuleDef>; + +struct TestParams { + // HAL driver to use for the test. + std::string driver_name; + // Ordinal within the module to execute. + int executable_ordinal; + // Name of the executable (just for prettier logging). + std::string executable_name; +}; + +std::ostream& operator<<(std::ostream& os, const TestParams& params) { + return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}) << "_ex" + << params.executable_ordinal << "_" << params.executable_name; +} + +// Loads the precompiled module file (from simple_compute_test.mlir). +std::unique_ptr<ModuleFile> LoadModuleFile() { + const auto* file_toc = simple_compute_test_module_create(); + return ModuleFile::WrapBuffer( + ModuleDefIdentifier(), + absl::MakeSpan(reinterpret_cast<const uint8_t*>(file_toc->data), + file_toc->size)) + .ValueOrDie(); +} + +// Builds a list of tests to run for each [driver x available executable]. +std::vector<TestParams> GetAvailableDriverTestParams() { + auto module_file = LoadModuleFile(); + auto& executable_table = *module_file->root()->executable_table(); + std::vector<TestParams> all_test_params; + for (const auto& driver_name : + DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) { + int executable_ordinal = 0; + for (const auto* multi_arch_executable_def : + *executable_table.multi_arch_executables()) { + TestParams test_params; + test_params.driver_name = driver_name; + test_params.executable_ordinal = executable_ordinal--; + test_params.executable_name = + std::string(WrapString(multi_arch_executable_def->name())); + all_test_params.push_back(std::move(test_params)); + } + } + return all_test_params; +} + +class SimpleComputeTest : public ::testing::Test, + public ::testing::WithParamInterface<TestParams> { + protected: + virtual void SetUp() { module_file_ = LoadModuleFile(); } + + std::unique_ptr<ModuleFile> module_file_; +}; + +TEST_P(SimpleComputeTest, RunOnce) { + const auto& test_params = GetParam(); + + // Create driver for this test (based on params) and then get a default + // device. + LOG(INFO) << "Creating driver '" << test_params.driver_name << "'..."; + auto driver_or = + DriverRegistry::shared_registry()->Create(test_params.driver_name); + if (IsUnavailable(driver_or.status())) { + LOG(WARNING) << "Skipping test as driver is unavailable: " + << driver_or.status(); + GTEST_SKIP(); + return; + } + ASSERT_OK_AND_ASSIGN(auto driver, driver_or); + ASSERT_OK_AND_ASSIGN(auto available_devices, + driver->EnumerateAvailableDevices()); + for (const auto& device_info : available_devices) { + LOG(INFO) << " Device: " << device_info.name(); + } + LOG(INFO) << "Creating default device..."; + ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice()); + LOG(INFO) << "Successfully created device '" << device->info().name() << "'"; + + // Attempt to compile the appropriate executable. This may fail if there's no + // executable available in the input file that the driver can load. + auto executable_cache = device->CreateExecutableCache(); + auto& executable_table = *module_file_->root()->executable_table(); + auto multi_arch_executable_def = + executable_table.multi_arch_executables()->Get( + test_params.executable_ordinal); + ref_ptr<Executable> executable; + for (auto executable_def : *multi_arch_executable_def->executables()) { + if (!executable_cache->CanPrepareFormat(executable_def->format())) { + continue; + } + ExecutableSpec spec; + spec.format = executable_def->format(); + spec.executable_data = *executable_def->contents(); + ASSERT_OK_AND_ASSIGN(executable, + executable_cache->PrepareExecutable( + ExecutableCachingMode::kDefault, spec)); + break; + } + ASSERT_NE(executable, nullptr) + << "No executable found that has a supported format for driver " + << test_params.driver_name; + + // Create I/O buffers. + ASSERT_OK_AND_ASSIGN(auto arg0_buffer, + device->allocator()->Allocate( + MemoryType::kHostLocal | MemoryType::kDeviceVisible, + BufferUsage::kAll, 4 * sizeof(float))); + ASSERT_OK_AND_ASSIGN(auto arg1_buffer, + device->allocator()->Allocate( + MemoryType::kHostLocal | MemoryType::kDeviceVisible, + BufferUsage::kAll, 4 * sizeof(float))); + ASSERT_OK_AND_ASSIGN(auto ret0_buffer, + device->allocator()->Allocate( + MemoryType::kHostLocal | MemoryType::kDeviceVisible, + BufferUsage::kAll, 4 * sizeof(float))); + + // Populate initial values for 4 * 2 = 8. + // We scribble into the result buffer so that it's easy to ensure it's + // overwritten. + ASSERT_OK(arg0_buffer->Fill32(4.0f)); + ASSERT_OK(arg1_buffer->Fill32(2.0f)); + ASSERT_OK(ret0_buffer->Fill32(99999.0f)); + + // Record the command buffer that dispatches the executable. + ASSERT_OK_AND_ASSIGN( + auto cmd, device->CreateCommandBuffer( + CommandBufferMode::kOneShot, + CommandCategory::kTransfer | CommandCategory::kDispatch)); + ASSERT_OK(cmd->Begin()); + DispatchRequest dispatch_request; + dispatch_request.executable = executable.get(); + dispatch_request.entry_point = 0; + dispatch_request.workload[0] = 4; + dispatch_request.workload[1] = 1; + dispatch_request.workload[2] = 1; + BufferBinding bindings[3]; + bindings[0].buffer = arg0_buffer.get(); + bindings[0].access = MemoryAccess::kRead; + bindings[0].element_size = sizeof(float); + bindings[0].shape = {4}; + bindings[1].buffer = arg1_buffer.get(); + bindings[1].access = MemoryAccess::kRead; + bindings[1].element_size = sizeof(float); + bindings[1].shape = {4}; + bindings[2].buffer = ret0_buffer.get(); + bindings[2].access = MemoryAccess::kDiscardWrite; + bindings[2].element_size = sizeof(float); + bindings[2].shape = {4}; + dispatch_request.bindings = bindings; + ASSERT_OK(cmd->Dispatch(dispatch_request)); + ASSERT_OK(cmd->End()); + + // Schedule and wait for completion. + ASSERT_FALSE(device->dispatch_queues().empty()); + CommandQueue* queue = device->dispatch_queues().front(); + ASSERT_OK_AND_ASSIGN(auto fence, device->CreateFence(0u)); + ASSERT_OK( + queue->Submit(SubmissionBatch{{}, {cmd.get()}, {}}, {fence.get(), 1u})); + ASSERT_OK(device->WaitAllFences({{fence.get(), 1u}}, absl::InfiniteFuture())); + + // Read back the results. + ASSERT_OK_AND_ASSIGN(auto ret0_mapping, + ret0_buffer->MapMemory<float>(MemoryAccess::kRead)); + EXPECT_THAT(ret0_mapping.contents(), + ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); +} + +INSTANTIATE_TEST_SUITE_P(AllDrivers, SimpleComputeTest, + ::testing::ValuesIn(GetAvailableDriverTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace samples +} // namespace hal +} // namespace iree
diff --git a/iree/samples/hal/simple_compute_test.mlir b/samples/hal/simple_compute_test.mlir similarity index 100% rename from iree/samples/hal/simple_compute_test.mlir rename to samples/hal/simple_compute_test.mlir
diff --git a/samples/rt/BUILD b/samples/rt/BUILD new file mode 100644 index 0000000..72b7343 --- /dev/null +++ b/samples/rt/BUILD
@@ -0,0 +1,72 @@ +# Samples demonstrating use of the RT API. + +load("///tools:compilation.bzl", "iree_bytecode_module") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +iree_bytecode_module( + name = "simple_module_test_bytecode_module", + srcs = ["simple_module_test.mlir"], + cc_namespace = "iree::rt::samples", +) + +cc_test( + name = "bytecode_module_test", + srcs = ["bytecode_module_test.cc"], + data = [ + # When building with --config=asan you must specify the following + # envvar when using Vulkan + a local Nvidia GPU: + # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt + "///tools:sanitizer_suppressions.txt", + ], + deps = [ + ":simple_module_test_bytecode_module_cc", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "///base:flatbuffer_util", + "///base:status", + "///base:status_matchers", + "///hal:buffer_view", + "///hal:driver_registry", + "///rt", + "///schemas", + "///vm:bytecode_module", + "///vm:sequencer_module", + + # These are the drivers we support running with and can produce + # executables for from the source MLIR. + "///hal/interpreter:interpreter_driver_module", # build-cleaner: keep + # TODO(benvanik): include SPIR-V. + # "///hal/vulkan:vulkan_driver_module", # build-cleaner: keep + ], +) + +cc_test( + name = "bytecode_module_api_test", + srcs = ["bytecode_module_api_test.cc"], + data = [ + # When building with --config=asan you must specify the following + # envvar when using Vulkan + a local Nvidia GPU: + # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt + "///tools:sanitizer_suppressions.txt", + ], + deps = [ + ":simple_module_test_bytecode_module_cc", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "///base:api", + "///hal:api", + "///hal:driver_registry", + "///rt:api", + "///vm:api", + + # These are the drivers we support running with and can produce + # executables for from the source MLIR. + "///hal/interpreter:interpreter_driver_module", # build-cleaner: keep + # TODO(benvanik): include SPIR-V. + # "///hal/vulkan:vulkan_driver_module", # build-cleaner: keep + ], +)
diff --git a/iree/samples/rt/CMakeLists.txt b/samples/rt/CMakeLists.txt similarity index 100% rename from iree/samples/rt/CMakeLists.txt rename to samples/rt/CMakeLists.txt
diff --git a/samples/rt/bytecode_module_api_test.cc b/samples/rt/bytecode_module_api_test.cc new file mode 100644 index 0000000..fbb980b --- /dev/null +++ b/samples/rt/bytecode_module_api_test.cc
@@ -0,0 +1,188 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/str_replace.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +// C API: +#include "base/api.h" +#include "hal/api.h" +#include "rt/api.h" +#include "vm/api.h" + +// Only temporary, used for test device registration: +#include "hal/driver_registry.h" + +// Compiled module embedded here to avoid file IO: +#include "samples/rt/simple_module_test_bytecode_module.h" + +namespace iree { +namespace rt { +namespace samples { +namespace { + +#define ASSERT_API_OK(expr) ASSERT_EQ(IREE_STATUS_OK, (expr)) + +struct TestParams { + // HAL driver to use for the test. + std::string driver_name; +}; + +std::ostream& operator<<(std::ostream& os, const TestParams& params) { + return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}); +} + +// Builds a list of tests to run based on the linked in driver modules. +std::vector<TestParams> GetAvailableDriverTestParams() { + std::vector<TestParams> all_test_params; + for (const auto& driver_name : + hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) { + TestParams test_params; + test_params.driver_name = driver_name; + all_test_params.push_back(std::move(test_params)); + } + return all_test_params; +} + +class BytecodeModuleApiTest : public ::testing::Test, + public ::testing::WithParamInterface<TestParams> { + protected: +}; + +TEST_P(BytecodeModuleApiTest, RunOnce) { + iree_rt_instance_t* instance = nullptr; + ASSERT_API_OK(iree_rt_instance_create(IREE_ALLOCATOR_DEFAULT, &instance)); + + // TEMPORARY: until policies and placement are performed with manually + // register drivers via a magic function. + const auto& driver_name = GetParam().driver_name; + LOG(INFO) << "Creating driver '" << driver_name << "'..."; + ASSERT_API_OK(iree_rt_instance_register_driver_ex( + instance, iree_string_view_t{driver_name.data(), driver_name.size()})); + + // Allocate a context that will hold the module state across invocations. + iree_rt_policy_t* dummy_policy = nullptr; + ASSERT_API_OK(iree_rt_policy_create(IREE_ALLOCATOR_DEFAULT, &dummy_policy)); + iree_rt_context_t* context = nullptr; + ASSERT_API_OK(iree_rt_context_create(instance, dummy_policy, + IREE_ALLOCATOR_DEFAULT, &context)); + iree_rt_policy_release(dummy_policy); + + // Load bytecode module from the embedded data. + LOG(INFO) << "Loading simple_module_test.mlir..."; + const auto* module_file_toc = simple_module_test_bytecode_module_create(); + iree_rt_module_t* bytecode_module = nullptr; + ASSERT_API_OK(iree_vm_bytecode_module_create_from_buffer( + iree_const_byte_span_t{ + reinterpret_cast<const uint8_t*>(module_file_toc->data), + module_file_toc->size}, + nullptr, nullptr, IREE_ALLOCATOR_DEFAULT, &bytecode_module)); + + // Register modules that we want to be able to use in the context. + std::vector<iree_rt_module_t*> modules; + modules.push_back(bytecode_module); + ASSERT_API_OK( + iree_rt_context_register_modules(context, &modules[0], modules.size())); + iree_rt_module_release(bytecode_module); + LOG(INFO) << "Module loaded and context is ready for use"; + + // Lookup the entry point function. + iree_rt_function_t main_function; + const char kMainFunctionName[] = "module.simple_mul"; + ASSERT_API_OK(iree_rt_context_resolve_function( + context, + iree_string_view_t{kMainFunctionName, sizeof(kMainFunctionName) - 1}, + &main_function)); + + // Allocate buffers that can be mapped on the CPU and that can also be used + // on the device. Not all devices support this, but the ones we have now do. + LOG(INFO) << "Creating I/O buffers..."; + constexpr int kElementCount = 4; + iree_hal_buffer_t* arg0_buffer = nullptr; + iree_hal_buffer_t* arg1_buffer = nullptr; + ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer( + context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, + IREE_ALLOCATOR_DEFAULT, &arg0_buffer)); + ASSERT_API_OK(iree_rt_context_allocate_device_visible_buffer( + context, IREE_HAL_BUFFER_USAGE_ALL, sizeof(float) * kElementCount, + IREE_ALLOCATOR_DEFAULT, &arg1_buffer)); + + // Populate initial values for 4 * 2 = 8. + float kFloat4 = 4.0f; + float kFloat2 = 2.0f; + ASSERT_API_OK(iree_hal_buffer_fill(arg0_buffer, 0, IREE_WHOLE_BUFFER, + &kFloat4, sizeof(float))); + ASSERT_API_OK(iree_hal_buffer_fill(arg1_buffer, 0, IREE_WHOLE_BUFFER, + &kFloat2, sizeof(float))); + + // Wrap buffers in buffer views to provide shape information. + std::array<iree_hal_buffer_view_t*, 2> arg_buffer_views; + ASSERT_API_OK(iree_hal_buffer_view_create( + arg0_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float), + IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[0])); + ASSERT_API_OK(iree_hal_buffer_view_create( + arg1_buffer, iree_shape_t{1, {kElementCount}}, sizeof(float), + IREE_ALLOCATOR_DEFAULT, &arg_buffer_views[1])); + iree_hal_buffer_release(arg0_buffer); + iree_hal_buffer_release(arg1_buffer); + + // Call into the @simple_mul function. + LOG(INFO) << "Calling @simple_mul..."; + iree_rt_invocation_t* invocation = nullptr; + ASSERT_API_OK(iree_rt_invocation_create( + context, &main_function, nullptr, nullptr, arg_buffer_views.data(), 2, + nullptr, 0, IREE_ALLOCATOR_DEFAULT, &invocation)); + ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[0])); + ASSERT_API_OK(iree_hal_buffer_view_release(arg_buffer_views[1])); + ASSERT_API_OK( + iree_rt_invocation_await(invocation, IREE_TIME_INFINITE_FUTURE)); + + // Get the result buffers from the invocation. + LOG(INFO) << "Retreiving results..."; + std::array<iree_hal_buffer_view_t*, 2> result_buffer_views; + iree_host_size_t result_count; + ASSERT_API_OK(iree_rt_invocation_consume_results( + invocation, result_buffer_views.size(), IREE_ALLOCATOR_DEFAULT, + result_buffer_views.data(), &result_count)); + iree_rt_invocation_release(invocation); + + // Read back the results and ensure we got the right values. + LOG(INFO) << "Reading back results..."; + iree_hal_buffer_t* result_buffer = + iree_hal_buffer_view_buffer(result_buffer_views[0]); + iree_hal_mapped_memory_t mapped_memory; + ASSERT_API_OK(iree_hal_buffer_map(result_buffer, IREE_HAL_MEMORY_ACCESS_READ, + 0, IREE_WHOLE_BUFFER, &mapped_memory)); + ASSERT_THAT(absl::Span<const float>( + reinterpret_cast<const float*>(mapped_memory.contents.data), + mapped_memory.contents.data_length / sizeof(float)), + ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); + ASSERT_API_OK(iree_hal_buffer_unmap(result_buffer, &mapped_memory)); + LOG(INFO) << "Results match!"; + + iree_hal_buffer_view_release(result_buffer_views[0]); + + iree_rt_context_release(context); + iree_rt_instance_release(instance); +} + +INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleApiTest, + ::testing::ValuesIn(GetAvailableDriverTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace samples +} // namespace rt +} // namespace iree
diff --git a/samples/rt/bytecode_module_test.cc b/samples/rt/bytecode_module_test.cc new file mode 100644 index 0000000..67c2dc9 --- /dev/null +++ b/samples/rt/bytecode_module_test.cc
@@ -0,0 +1,170 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// A simple sample demonstrating simple synchronous module loading and VM use. +// This will load an IREE module containing a @simple_mul method that performs +// an element-wise multiplication. It will invoke @simple_mul in the VM, once +// for each available HAL driver linked into the binary. +// +// The synchronous invocation method (Context::Invoke) used here waits until all +// asynchronous HAL work completes before returning. It's still possible get +// overlapped execution by invoking methods from other threads with their own +// FiberState, though it's best to use the asynchronous API instead. +// +// The `iree_module` build rule is used to translate the MLIR to the module +// flatbuffer. Additional HAL backend target support can be defined there. + +#include "vm/bytecode_module.h" + +#include "absl/strings/str_replace.h" +#include "base/flatbuffer_util.h" +#include "base/status.h" +#include "base/status_matchers.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "hal/buffer_view.h" +#include "hal/driver_registry.h" +#include "rt/context.h" +#include "rt/instance.h" +#include "samples/rt/simple_module_test_bytecode_module.h" +#include "schemas/module_def_generated.h" +#include "vm/sequencer_module.h" + +namespace iree { +namespace rt { +namespace samples { +namespace { + +using ::iree::hal::BufferView; +using ::iree::vm::ModuleFile; + +struct TestParams { + // HAL driver to use for the test. + std::string driver_name; +}; + +std::ostream& operator<<(std::ostream& os, const TestParams& params) { + return os << absl::StrReplaceAll(params.driver_name, {{":", "_"}}); +} + +// Builds a list of tests to run based on the linked in driver modules. +std::vector<TestParams> GetAvailableDriverTestParams() { + std::vector<TestParams> all_test_params; + for (const auto& driver_name : + hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers()) { + TestParams test_params; + test_params.driver_name = driver_name; + all_test_params.push_back(std::move(test_params)); + } + return all_test_params; +} + +class BytecodeModuleTest : public ::testing::Test, + public ::testing::WithParamInterface<TestParams> { + protected: +}; + +TEST_P(BytecodeModuleTest, RunOnce) { + auto instance = make_ref<Instance>(); + + // Create driver for this test (based on params) and then get a default + // device. + const auto& test_params = GetParam(); + LOG(INFO) << "Creating driver '" << test_params.driver_name << "'..."; + auto driver_or = + hal::DriverRegistry::shared_registry()->Create(test_params.driver_name); + if (IsUnavailable(driver_or.status())) { + LOG(WARNING) << "Skipping test as driver is unavailable: " + << driver_or.status(); + GTEST_SKIP(); + return; + } + ASSERT_OK_AND_ASSIGN(auto driver, driver_or); + ASSERT_OK_AND_ASSIGN(auto available_devices, + driver->EnumerateAvailableDevices()); + for (const auto& device_info : available_devices) { + LOG(INFO) << " Device: " << device_info.name(); + } + LOG(INFO) << "Creating default device..."; + ASSERT_OK_AND_ASSIGN(auto device, driver->CreateDefaultDevice()); + ASSERT_OK(instance->device_manager()->RegisterDevice(device)); + LOG(INFO) << "Successfully created device '" << device->info().name() << "'"; + + // Make a new context and load the precompiled module file (from + // simple_module_test.mlir) into it. + LOG(INFO) << "Loading simple_module_test.mlir..."; + auto policy = make_ref<Policy>(); + Context context(add_ref(instance), add_ref(policy)); + const auto* module_file_toc = simple_module_test_bytecode_module_create(); + ASSERT_OK_AND_ASSIGN(auto module_file, + vm::ModuleFile::WrapBuffer( + ModuleDefIdentifier(), + absl::MakeSpan(reinterpret_cast<const uint8_t*>( + module_file_toc->data), + module_file_toc->size))); + ASSERT_OK_AND_ASSIGN(auto main_module, + vm::SequencerModule::FromFile(std::move(module_file))); + ASSERT_OK(context.RegisterModule(std::move(main_module))); + LOG(INFO) << "Module loaded and context is ready for use"; + + // Allocate buffers that can be mapped on the CPU and that can also be used + // on the device. Not all devices support this, but the ones we have now do. + LOG(INFO) << "Creating I/O buffers..."; + constexpr int kElementCount = 4; + ASSERT_OK_AND_ASSIGN( + auto arg0_buffer, + instance->device_manager()->AllocateDeviceVisibleBuffer( + hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}})); + ASSERT_OK_AND_ASSIGN( + auto arg1_buffer, + instance->device_manager()->AllocateDeviceVisibleBuffer( + hal::BufferUsage::kAll, sizeof(float) * kElementCount, {{device}})); + + // Populate initial values for 4 * 2 = 8. + ASSERT_OK(arg0_buffer->Fill32(4.0f)); + ASSERT_OK(arg1_buffer->Fill32(2.0f)); + + // Call into the @simple_mul function. + LOG(INFO) << "Calling @simple_mul..."; + absl::InlinedVector<BufferView, 8> args{ + BufferView{add_ref(arg0_buffer), {kElementCount}, sizeof(float)}, + BufferView{add_ref(arg1_buffer), {kElementCount}, sizeof(float)}, + }; + ASSERT_OK_AND_ASSIGN(auto simple_mul, + context.ResolveFunction("module.simple_mul")); + ASSERT_OK_AND_ASSIGN(auto invocation, + Invocation::Create(add_ref(&context), simple_mul, + nullptr, {}, std::move(args))); + ASSERT_OK(invocation->Await(absl::InfiniteFuture())); + ASSERT_OK_AND_ASSIGN(auto results, invocation->ConsumeResults()); + + // Read back the results and ensure we got the right values. + LOG(INFO) << "Reading back results..."; + auto& ret_buffer_view = results[0]; + ASSERT_OK_AND_ASSIGN( + auto ret_mapping, + ret_buffer_view.buffer->MapMemory<float>(hal::MemoryAccess::kRead)); + ASSERT_THAT(ret_mapping.contents(), + ::testing::ElementsAreArray({8.0f, 8.0f, 8.0f, 8.0f})); + LOG(INFO) << "Results match!"; +} + +INSTANTIATE_TEST_SUITE_P(AllDrivers, BytecodeModuleTest, + ::testing::ValuesIn(GetAvailableDriverTestParams()), + ::testing::PrintToStringParamName()); + +} // namespace +} // namespace samples +} // namespace rt +} // namespace iree
diff --git a/iree/samples/rt/simple_module_test.mlir b/samples/rt/simple_module_test.mlir similarity index 100% rename from iree/samples/rt/simple_module_test.mlir rename to samples/rt/simple_module_test.mlir
diff --git a/schemas/BUILD b/schemas/BUILD new file mode 100644 index 0000000..cbcbfc0 --- /dev/null +++ b/schemas/BUILD
@@ -0,0 +1,252 @@ +load("//:build_defs.google.bzl", "FLATBUFFER_SUPPORTS_REFLECTIONS", "iree_build_test", "iree_flatbuffer_cc_library") +load("///build_tools/embed_data:build_defs.bzl", "cc_embed_data") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +FLATC_ARGS = [ + # Preserve workspace-relative include paths in generated code. + "--keep-prefix", + # Use C++11 'enum class' for enums. + "--scoped-enums", + # Include reflection tables used for dumping debug representations. + "--reflect-names", + # Generate FooT types for unpack/pack support. Note that this should only + # be used in tooling as the code size/runtime overhead is non-trivial. + "--gen-object-api", +] + +iree_flatbuffer_cc_library( + name = "archive_def_cc_fbs", + srcs = ["archive_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":bytecode_def_cc_fbs_includes", + ":device_def_cc_fbs_includes", + ":device_group_def_cc_fbs_includes", + ":device_table_def_cc_fbs_includes", + ":executable_def_cc_fbs_includes", + ":executable_table_def_cc_fbs_includes", + ":function_def_cc_fbs_includes", + ":function_table_def_cc_fbs_includes", + ":module_def_cc_fbs_includes", + ":source_map_def_cc_fbs_includes", + ":type_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "bytecode_def_cc_fbs", + srcs = ["bytecode_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_flatbuffer_cc_library( + name = "debug_service_cc_fbs", + srcs = ["debug_service.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":bytecode_def_cc_fbs_includes", + ":device_def_cc_fbs_includes", + ":device_group_def_cc_fbs_includes", + ":device_table_def_cc_fbs_includes", + ":executable_def_cc_fbs_includes", + ":executable_table_def_cc_fbs_includes", + ":function_def_cc_fbs_includes", + ":function_table_def_cc_fbs_includes", + ":module_def_cc_fbs_includes", + ":source_map_def_cc_fbs_includes", + ":type_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "device_def_cc_fbs", + srcs = ["device_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_flatbuffer_cc_library( + name = "device_group_def_cc_fbs", + srcs = ["device_group_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_flatbuffer_cc_library( + name = "device_table_def_cc_fbs", + srcs = ["device_table_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":device_def_cc_fbs_includes", + ":device_group_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "executable_def_cc_fbs", + srcs = ["executable_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_flatbuffer_cc_library( + name = "executable_table_def_cc_fbs", + srcs = ["executable_table_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":executable_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "function_def_cc_fbs", + srcs = ["function_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":bytecode_def_cc_fbs_includes", + ":type_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "function_table_def_cc_fbs", + srcs = ["function_table_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":bytecode_def_cc_fbs_includes", + ":function_def_cc_fbs_includes", + ":type_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "module_def_cc_fbs", + srcs = ["module_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ":bytecode_def_cc_fbs_includes", + ":device_def_cc_fbs_includes", + ":device_group_def_cc_fbs_includes", + ":device_table_def_cc_fbs_includes", + ":executable_def_cc_fbs_includes", + ":executable_table_def_cc_fbs_includes", + ":function_def_cc_fbs_includes", + ":function_table_def_cc_fbs_includes", + ":source_map_def_cc_fbs_includes", + ":type_def_cc_fbs_includes", + ], +) + +iree_flatbuffer_cc_library( + name = "source_map_def_cc_fbs", + srcs = ["source_map_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_flatbuffer_cc_library( + name = "spirv_executable_def_cc_fbs", + srcs = ["spirv_executable_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_flatbuffer_cc_library( + name = "type_def_cc_fbs", + srcs = ["type_def.fbs"], + flatc_args = FLATC_ARGS, + includes = [ + ], +) + +iree_build_test( + name = "schema_build_test", + targets = [ + ":archive_def_cc_fbs", + ":bytecode_def_cc_fbs", + ":debug_service_cc_fbs", + ":device_def_cc_fbs", + ":device_group_def_cc_fbs", + ":device_table_def_cc_fbs", + ":executable_def_cc_fbs", + ":executable_table_def_cc_fbs", + ":function_def_cc_fbs", + ":function_table_def_cc_fbs", + ":module_def_cc_fbs", + ":source_map_def_cc_fbs", + ":spirv_executable_def_cc_fbs", + ":type_def_cc_fbs", + ], +) + +cc_library( + name = "schemas", + hdrs = [ + ":archive_def_generated.h", + ":bytecode_def_generated.h", + ":debug_service_generated.h", + ":device_def_generated.h", + ":device_group_def_generated.h", + ":device_table_def_generated.h", + ":executable_def_generated.h", + ":executable_table_def_generated.h", + ":function_def_generated.h", + ":function_table_def_generated.h", + ":module_def_generated.h", + ":source_map_def_generated.h", + ":type_def_generated.h", + ], + deps = [ + ":archive_def_cc_fbs", + ":bytecode_def_cc_fbs", + ":debug_service_cc_fbs", + ":device_def_cc_fbs", + ":device_group_def_cc_fbs", + ":device_table_def_cc_fbs", + ":executable_def_cc_fbs", + ":executable_table_def_cc_fbs", + ":function_def_cc_fbs", + ":function_table_def_cc_fbs", + ":module_def_cc_fbs", + ":source_map_def_cc_fbs", + ":spirv_executable_def_cc_fbs", + ":type_def_cc_fbs", + "@com_github_google_flatbuffers//:flatbuffers", + ], +) + +REFLECTION_SRCS = [] if not FLATBUFFER_SUPPORTS_REFLECTIONS else [ + "archive_def.bfbs", + "bytecode_def.bfbs", + "debug_service.bfbs", + "executable_def.bfbs", + "executable_table_def.bfbs", + "function_def.bfbs", + "function_table_def.bfbs", + "module_def.bfbs", + "source_map_def.bfbs", + "spirv_executable_def.bfbs", + "type_def.bfbs", + "device_def.bfbs", + "device_group_def.bfbs", + "device_table_def.bfbs", +] + +cc_embed_data( + name = "reflection_data", + srcs = REFLECTION_SRCS, + cc_file_output = "reflection_data.cc", + cpp_namespace = "iree::schemas", + h_file_output = "reflection_data.h", +)
diff --git a/iree/schemas/CMakeLists.txt b/schemas/CMakeLists.txt similarity index 100% rename from iree/schemas/CMakeLists.txt rename to schemas/CMakeLists.txt
diff --git a/schemas/archive_def.fbs b/schemas/archive_def.fbs new file mode 100644 index 0000000..114415c --- /dev/null +++ b/schemas/archive_def.fbs
@@ -0,0 +1,14 @@ +include "schemas/module_def.fbs"; + +namespace iree; + +// 'Executable ARChive'. +file_identifier "EARC"; +file_extension "earc"; + +table ArchiveDef { + name:string; + modules:[ModuleDef]; +} + +root_type ArchiveDef;
diff --git a/schemas/bytecode/BUILD b/schemas/bytecode/BUILD new file mode 100644 index 0000000..7b7dcdd --- /dev/null +++ b/schemas/bytecode/BUILD
@@ -0,0 +1,29 @@ +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "bytecode_v0", + hdrs = ["bytecode_v0.h"], + deps = [ + "///base:bitfield", + "@com_google_absl//absl/base:core_headers", + ], +) + +cc_library( + name = "interpreter_bytecode_v0", + hdrs = ["interpreter_bytecode_v0.h"], + deps = [ + ":bytecode_v0", + ], +) + +cc_library( + name = "sequencer_bytecode_v0", + hdrs = ["sequencer_bytecode_v0.h"], + deps = [ + ":bytecode_v0", + ], +)
diff --git a/iree/schemas/bytecode/CMakeLists.txt b/schemas/bytecode/CMakeLists.txt similarity index 100% rename from iree/schemas/bytecode/CMakeLists.txt rename to schemas/bytecode/CMakeLists.txt
diff --git a/schemas/bytecode/bytecode_v0.h b/schemas/bytecode/bytecode_v0.h new file mode 100644 index 0000000..c9e6398 --- /dev/null +++ b/schemas/bytecode/bytecode_v0.h
@@ -0,0 +1,143 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Opcode table for the V0 binary format. +// Additions are fine but changing the behavior or order of any opcodes will +// break pasring of existing files. +// +// Opcodes have been selected on frequency of use, general applicability, and +// relative stability. Experimental ops should be implemented via the FFI fisrt +// before graduating into the core set. Ops that may only be present on certain +// targets should also be kept as imports via the FFI. +// +// Opcodes may be specified for particular types (int32_t), categories of types +// (all floating-point types), or implicit types (output matches input). Saving +// opcode space by sharing a single opcode for multiple types is preferred +// except where hot operations are performed (for example, comparison used in +// loop iteratosr). + +#ifndef IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_ +#define IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_ + +#include <cstdint> + +#include "base/bitfield.h" + +namespace iree { + +#define IREE_CONSTANT_ENCODING_LIST(ENC) \ + ENC(0x00, kDense, "dense") \ + ENC(0x01, kSplat, "splat") + +#define IREE_TYPE_LIST(TYP) \ + TYP(0x00, kI8, "i8", 1) \ + TYP(0x01, kI16, "i16", 2) \ + TYP(0x02, kI32, "i32", 4) \ + TYP(0x03, kI64, "i64", 8) \ + TYP(0x04, kF16, "f16", 2) \ + TYP(0x05, kF32, "f32", 4) \ + TYP(0x06, kF64, "f64", 8) \ + TYP(0x80, kDevice, "device", 0) \ + TYP(0x81, kCommandBuffer, "command_buffer", 0) \ + TYP(0x82, kEvent, "event", 0) \ + TYP(0x83, kSemaphore, "semaphore", 0) \ + TYP(0x84, kFence, "fence", 0) \ + TYP(0xFF, kOpaque, "opaque", 0) + +#define IREE_CMPI_PREDICATE_LIST(PRED) \ + PRED(0, kEq, "eq") \ + PRED(1, kNe, "ne") \ + PRED(2, kSlt, "slt") \ + PRED(3, kSle, "sle") \ + PRED(4, kSgt, "sgt") \ + PRED(5, kSge, "sge") \ + PRED(6, kUlt, "ult") \ + PRED(7, kUle, "ule") \ + PRED(8, kUgt, "ugt") \ + PRED(9, kUge, "uge") + +#define IREE_CMPF_PREDICATE_LIST(PRED) \ + PRED(0, kFalse, "false") \ + PRED(1, kOeq, "oeq") \ + PRED(2, kOgt, "ogt") \ + PRED(3, kOge, "oge") \ + PRED(4, kOlt, "olt") \ + PRED(5, kOle, "ole") \ + PRED(6, kOne, "one") \ + PRED(7, kOrd, "ord") \ + PRED(8, kUeq, "ueq") \ + PRED(9, kUgt, "ugt") \ + PRED(10, kUge, "uge") \ + PRED(11, kUlt, "ult") \ + PRED(12, kUle, "ule") \ + PRED(13, kUne, "une") \ + PRED(14, kUno, "uno") \ + PRED(15, kTrue, "true") + +// NOTE: FF is a to-be-defined flag value for encoding/decoding. +#define FLAG(V) ::iree::OpcodeFlag::V + +#define RSV(opcode, RESERVED_OPC) \ + RESERVED_OPC(opcode, kReserved##opcode, "rsv." #opcode, FLAG(kDefault), "", \ + FF) + +#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal, + +enum class ConstantEncoding : uint8_t { + IREE_CONSTANT_ENCODING_LIST(DECLARE_ENUM) +}; + +enum class BuiltinType : uint8_t { IREE_TYPE_LIST(DECLARE_ENUM) }; + +enum class CmpIPredicate : uint8_t { IREE_CMPI_PREDICATE_LIST(DECLARE_ENUM) }; + +enum class CmpFPredicate : uint8_t { IREE_CMPF_PREDICATE_LIST(DECLARE_ENUM) }; + +#undef DECLARE_ENUM + +static constexpr uint8_t kBuiltinTypeCount = + static_cast<uint8_t>(BuiltinType::kF64) + 1; + +enum class OpcodeFlag : uint8_t { + kDefault = 0, +}; +IREE_BITFIELD(OpcodeFlag); +using OpcodeFlagBitfield = OpcodeFlag; + +enum class OperandEncoding : char { + kNone = '\0', + kInputSlot = 's', + kVariadicInputSlots = 'S', + kOutputSlot = 'o', + kVariadicOutputSlots = 'O', + kResultSlot = 'r', + kVariadicResultSlots = 'R', + kVariadicTransferSlots = 'T', + kConstant = 'c', + kFunctionOrdinal = 'f', + kImportOrdinal = 'F', + kDispatchOrdinal = 'd', + kBlockOffset = 'b', + kTypeIndex = 't', + kIndex = 'i', + kIndexList = 'I', + kCmpIPredicate = 'p', + kCmpFPredicate = 'P', +}; +IREE_BITFIELD(OperandEncoding); +using OperandEncodingBitfield = OperandEncoding; + +} // namespace iree + +#endif // IREE_SCHEMAS_BYTECODE_BYTECODE_V0_H_
diff --git a/schemas/bytecode/interpreter_bytecode_v0.h b/schemas/bytecode/interpreter_bytecode_v0.h new file mode 100644 index 0000000..9d21d32 --- /dev/null +++ b/schemas/bytecode/interpreter_bytecode_v0.h
@@ -0,0 +1,325 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Opcode table for the V0 binary format. +// Additions are fine but changing the behavior or order of any opcodes will +// break parsing of existing files. +// +// Opcodes have been selected on frequency of use, general applicability, and +// relative stability. Experimental ops should be implemented via the Foreign +// Function Interface (FFI) first before graduating into the core set. Ops that +// may only be present on certain targets should also be kept as imports via the +// FFI. +// +// Opcodes may be specified for particular types (int32_t), categories of types +// (all floating-point types), or implicit types (output matches input). Saving +// opcode space by sharing a single opcode for multiple types is preferred +// except where hot operations are performed (for example, comparison used in +// loop iterators). + +#ifndef IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_ +#define IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_ + +#include "schemas/bytecode/bytecode_v0.h" + +namespace iree { + +#define IREE_INTERPRETER_OPCODE_LIST(OPC, RESERVED_OPC) \ + OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \ + \ + OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \ + OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \ + OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \ + OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \ + OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \ + OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \ + OPC(0x07, kCmpI, "cmp_i", FLAG(kDefault), "psso", FF) \ + OPC(0x08, kCmpF, "cmp_f", FLAG(kDefault), "Psso", FF) \ + \ + RSV(0x09, RESERVED_OPC) \ + RSV(0x0A, RESERVED_OPC) \ + RSV(0x0B, RESERVED_OPC) \ + RSV(0x0C, RESERVED_OPC) \ + RSV(0x0D, RESERVED_OPC) \ + RSV(0x0E, RESERVED_OPC) \ + RSV(0x0F, RESERVED_OPC) \ + RSV(0x10, RESERVED_OPC) \ + RSV(0x11, RESERVED_OPC) \ + RSV(0x12, RESERVED_OPC) \ + RSV(0x13, RESERVED_OPC) \ + RSV(0x14, RESERVED_OPC) \ + RSV(0x15, RESERVED_OPC) \ + RSV(0x16, RESERVED_OPC) \ + RSV(0x17, RESERVED_OPC) \ + RSV(0x18, RESERVED_OPC) \ + RSV(0x19, RESERVED_OPC) \ + RSV(0x1A, RESERVED_OPC) \ + RSV(0x1B, RESERVED_OPC) \ + RSV(0x1C, RESERVED_OPC) \ + RSV(0x1D, RESERVED_OPC) \ + RSV(0x1E, RESERVED_OPC) \ + RSV(0x1F, RESERVED_OPC) \ + \ + OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \ + OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \ + OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \ + OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \ + OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \ + \ + RSV(0x25, RESERVED_OPC) \ + RSV(0x26, RESERVED_OPC) \ + RSV(0x27, RESERVED_OPC) \ + RSV(0x28, RESERVED_OPC) \ + RSV(0x29, RESERVED_OPC) \ + RSV(0x2A, RESERVED_OPC) \ + RSV(0x2B, RESERVED_OPC) \ + RSV(0x2C, RESERVED_OPC) \ + RSV(0x2D, RESERVED_OPC) \ + RSV(0x2E, RESERVED_OPC) \ + RSV(0x2F, RESERVED_OPC) \ + \ + OPC(0x30, kRank, "rank", FLAG(kDefault), "so", FF) \ + OPC(0x31, kDim, "dim", FLAG(kDefault), "iso", FF) \ + OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \ + OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \ + OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "sssr", FF) \ + OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "sIIr", FF) \ + OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \ + OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sIoII", FF) \ + OPC(0x38, kClone, "clone", FLAG(kDefault), "sr", FF) \ + RSV(0x39, RESERVED_OPC) \ + OPC(0x3A, kSplit, "split", FLAG(kDefault), "isR", FF) \ + OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \ + OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \ + OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \ + OPC(0x3E, kSelect, "select", FLAG(kDefault), "ssso", FF) \ + OPC(0x3F, kTranspose, "transpose", FLAG(kDefault), "sso", FF) \ + OPC(0x40, kBroadcast, "broadcast", FLAG(kDefault), "sso", FF) \ + OPC(0x41, kTile, "tile", FLAG(kDefault), "sso", FF) \ + OPC(0x42, kReverse, "reverse", FLAG(kDefault), "sso", FF) \ + OPC(0x43, kPad, "pad", FLAG(kDefault), "ssssso", FF) \ + \ + RSV(0x44, RESERVED_OPC) \ + RSV(0x45, RESERVED_OPC) \ + RSV(0x46, RESERVED_OPC) \ + RSV(0x47, RESERVED_OPC) \ + RSV(0x48, RESERVED_OPC) \ + RSV(0x49, RESERVED_OPC) \ + RSV(0x4A, RESERVED_OPC) \ + RSV(0x4B, RESERVED_OPC) \ + RSV(0x4C, RESERVED_OPC) \ + RSV(0x4D, RESERVED_OPC) \ + RSV(0x4E, RESERVED_OPC) \ + RSV(0x4F, RESERVED_OPC) \ + \ + OPC(0x50, kNot, "not", FLAG(kDefault), "so", FF) \ + OPC(0x51, kAnd, "and", FLAG(kDefault), "sso", FF) \ + OPC(0x52, kOr, "or", FLAG(kDefault), "sso", FF) \ + OPC(0x53, kXor, "xor", FLAG(kDefault), "sso", FF) \ + OPC(0x54, kShiftLeft, "sll", FLAG(kDefault), "sso", FF) \ + OPC(0x55, kShiftRightLogical, "srl", FLAG(kDefault), "sso", FF) \ + OPC(0x56, kShiftRightArithmetic, "sra", FLAG(kDefault), "sso", FF) \ + \ + RSV(0x57, RESERVED_OPC) \ + RSV(0x58, RESERVED_OPC) \ + RSV(0x59, RESERVED_OPC) \ + RSV(0x5A, RESERVED_OPC) \ + RSV(0x5B, RESERVED_OPC) \ + RSV(0x5C, RESERVED_OPC) \ + RSV(0x5D, RESERVED_OPC) \ + RSV(0x5E, RESERVED_OPC) \ + RSV(0x5F, RESERVED_OPC) \ + RSV(0x60, RESERVED_OPC) \ + RSV(0x61, RESERVED_OPC) \ + RSV(0x62, RESERVED_OPC) \ + RSV(0x63, RESERVED_OPC) \ + RSV(0x64, RESERVED_OPC) \ + RSV(0x65, RESERVED_OPC) \ + RSV(0x66, RESERVED_OPC) \ + RSV(0x67, RESERVED_OPC) \ + RSV(0x68, RESERVED_OPC) \ + RSV(0x69, RESERVED_OPC) \ + RSV(0x6A, RESERVED_OPC) \ + RSV(0x6B, RESERVED_OPC) \ + RSV(0x6C, RESERVED_OPC) \ + RSV(0x6D, RESERVED_OPC) \ + RSV(0x6E, RESERVED_OPC) \ + RSV(0x6F, RESERVED_OPC) \ + \ + /* TODO(benvanik): remove ones we don't need/can emulate */ \ + OPC(0x70, kAddI, "add_i", FLAG(kDefault), "sso", FF) \ + OPC(0x71, kAddF, "add_f", FLAG(kDefault), "sso", FF) \ + OPC(0x72, kSubI, "sub_i", FLAG(kDefault), "sso", FF) \ + OPC(0x73, kSubF, "sub_f", FLAG(kDefault), "sso", FF) \ + OPC(0x74, kAbsI, "abs_i", FLAG(kDefault), "so", FF) \ + OPC(0x75, kAbsF, "abs_f", FLAG(kDefault), "so", FF) \ + OPC(0x76, kMulI, "mul_i", FLAG(kDefault), "sso", FF) \ + OPC(0x77, kMulF, "mul_f", FLAG(kDefault), "sso", FF) \ + OPC(0x78, kDivIS, "div_i_s", FLAG(kDefault), "sso", FF) \ + OPC(0x79, kDivIU, "div_i_u", FLAG(kDefault), "sso", FF) \ + OPC(0x7A, kDivF, "div_f", FLAG(kDefault), "sso", FF) \ + OPC(0x7B, kMulAddI, "madd_i", FLAG(kDefault), "ssso", FF) \ + OPC(0x7C, kMulAddF, "madd_f", FLAG(kDefault), "ssso", FF) \ + OPC(0x7D, kCosF, "cos_f", FLAG(kDefault), "so", FF) \ + OPC(0x7E, kSinF, "sin_f", FLAG(kDefault), "so", FF) \ + OPC(0x7F, kTanhF, "tanh_f", FLAG(kDefault), "so", FF) \ + OPC(0x80, kAtan2F, "atan2_f", FLAG(kDefault), "sso", FF) \ + OPC(0x81, kExpF, "exp_f", FLAG(kDefault), "so", FF) \ + OPC(0x82, kLogF, "log_f", FLAG(kDefault), "so", FF) \ + OPC(0x83, kRsqrtF, "rsqrt_f", FLAG(kDefault), "so", FF) \ + \ + RSV(0x84, RESERVED_OPC) \ + RSV(0x85, RESERVED_OPC) \ + RSV(0x86, RESERVED_OPC) \ + RSV(0x87, RESERVED_OPC) \ + RSV(0x88, RESERVED_OPC) \ + RSV(0x89, RESERVED_OPC) \ + RSV(0x8A, RESERVED_OPC) \ + RSV(0x8B, RESERVED_OPC) \ + RSV(0x8C, RESERVED_OPC) \ + RSV(0x8D, RESERVED_OPC) \ + RSV(0x8E, RESERVED_OPC) \ + RSV(0x8F, RESERVED_OPC) \ + \ + OPC(0x90, kMinIS, "min_i_s", FLAG(kDefault), "sso", FF) \ + OPC(0x91, kMinIU, "min_i_u", FLAG(kDefault), "sso", FF) \ + OPC(0x92, kMinF, "min_f", FLAG(kDefault), "sso", FF) \ + OPC(0x93, kMaxIS, "max_i_s", FLAG(kDefault), "sso", FF) \ + OPC(0x94, kMaxIU, "max_i_u", FLAG(kDefault), "sso", FF) \ + OPC(0x95, kMaxF, "max_f", FLAG(kDefault), "sso", FF) \ + OPC(0x96, kClampIS, "clamp_i_s", FLAG(kDefault), "ssso", FF) \ + OPC(0x97, kClampIU, "clamp_i_u", FLAG(kDefault), "ssso", FF) \ + OPC(0x98, kClampF, "clamp_f", FLAG(kDefault), "ssso", FF) \ + OPC(0x99, kFloorF, "floor_f", FLAG(kDefault), "so", FF) \ + OPC(0x9A, kCeilF, "ceil_f", FLAG(kDefault), "so", FF) \ + \ + OPC(0x9B, kConvertSS, "convert_s_s", FLAG(kDefault), "tsto", FF) \ + OPC(0x9C, kConvertUU, "convert_u_u", FLAG(kDefault), "tsto", FF) \ + OPC(0x9D, kConvertSU, "convert_s_u", FLAG(kDefault), "tsto", FF) \ + OPC(0x9E, kConvertUS, "convert_u_s", FLAG(kDefault), "tsto", FF) \ + \ + RSV(0x9F, RESERVED_OPC) \ + \ + /* TODO(benvanik): reduction/sum/etc */ \ + /* TODO(benvanik): sort */ \ + \ + OPC(0xA0, kMatMulI, "matmul_i", FLAG(kDefault), "sssso", FF) \ + OPC(0xA1, kMatMulF, "matmul_f", FLAG(kDefault), "sso", FF) \ + /* TODO(benvanik): convolution */ \ + \ + OPC(0xA2, kReduceSumI, "reduce_sum_i", FLAG(kDefault), "ssio", FF) \ + OPC(0xA3, kReduceSumF, "reduce_sum_f", FLAG(kDefault), "ssio", FF) \ + OPC(0xA4, kReduceMinI, "reduce_min_i", FLAG(kDefault), "ssio", FF) \ + OPC(0xA5, kReduceMinF, "reduce_min_f", FLAG(kDefault), "ssio", FF) \ + OPC(0xA6, kReduceMaxI, "reduce_max_i", FLAG(kDefault), "ssio", FF) \ + OPC(0xA7, kReduceMaxF, "reduce_max_f", FLAG(kDefault), "ssio", FF) \ + RSV(0xA8, RESERVED_OPC) \ + RSV(0xA9, RESERVED_OPC) \ + RSV(0xAA, RESERVED_OPC) \ + RSV(0xAB, RESERVED_OPC) \ + RSV(0xAC, RESERVED_OPC) \ + RSV(0xAD, RESERVED_OPC) \ + RSV(0xAE, RESERVED_OPC) \ + RSV(0xAF, RESERVED_OPC) \ + RSV(0xB0, RESERVED_OPC) \ + RSV(0xB1, RESERVED_OPC) \ + RSV(0xB2, RESERVED_OPC) \ + RSV(0xB3, RESERVED_OPC) \ + RSV(0xB4, RESERVED_OPC) \ + RSV(0xB5, RESERVED_OPC) \ + RSV(0xB6, RESERVED_OPC) \ + RSV(0xB7, RESERVED_OPC) \ + RSV(0xB8, RESERVED_OPC) \ + RSV(0xB9, RESERVED_OPC) \ + RSV(0xBA, RESERVED_OPC) \ + RSV(0xBB, RESERVED_OPC) \ + RSV(0xBC, RESERVED_OPC) \ + RSV(0xBD, RESERVED_OPC) \ + RSV(0xBE, RESERVED_OPC) \ + RSV(0xBF, RESERVED_OPC) \ + RSV(0xC0, RESERVED_OPC) \ + RSV(0xC1, RESERVED_OPC) \ + RSV(0xC2, RESERVED_OPC) \ + RSV(0xC3, RESERVED_OPC) \ + RSV(0xC4, RESERVED_OPC) \ + RSV(0xC5, RESERVED_OPC) \ + RSV(0xC6, RESERVED_OPC) \ + RSV(0xC7, RESERVED_OPC) \ + RSV(0xC8, RESERVED_OPC) \ + RSV(0xC9, RESERVED_OPC) \ + RSV(0xCA, RESERVED_OPC) \ + RSV(0xCB, RESERVED_OPC) \ + RSV(0xCC, RESERVED_OPC) \ + RSV(0xCD, RESERVED_OPC) \ + RSV(0xCE, RESERVED_OPC) \ + RSV(0xCF, RESERVED_OPC) \ + RSV(0xD0, RESERVED_OPC) \ + RSV(0xD1, RESERVED_OPC) \ + RSV(0xD2, RESERVED_OPC) \ + RSV(0xD3, RESERVED_OPC) \ + RSV(0xD4, RESERVED_OPC) \ + RSV(0xD5, RESERVED_OPC) \ + RSV(0xD6, RESERVED_OPC) \ + RSV(0xD7, RESERVED_OPC) \ + RSV(0xD8, RESERVED_OPC) \ + RSV(0xD9, RESERVED_OPC) \ + RSV(0xDA, RESERVED_OPC) \ + RSV(0xDB, RESERVED_OPC) \ + RSV(0xDC, RESERVED_OPC) \ + RSV(0xDD, RESERVED_OPC) \ + RSV(0xDE, RESERVED_OPC) \ + RSV(0xDF, RESERVED_OPC) \ + RSV(0xE0, RESERVED_OPC) \ + RSV(0xE1, RESERVED_OPC) \ + RSV(0xE2, RESERVED_OPC) \ + RSV(0xE3, RESERVED_OPC) \ + RSV(0xE4, RESERVED_OPC) \ + RSV(0xE5, RESERVED_OPC) \ + RSV(0xE6, RESERVED_OPC) \ + RSV(0xE7, RESERVED_OPC) \ + RSV(0xE8, RESERVED_OPC) \ + RSV(0xE9, RESERVED_OPC) \ + RSV(0xEA, RESERVED_OPC) \ + RSV(0xEB, RESERVED_OPC) \ + RSV(0xEC, RESERVED_OPC) \ + RSV(0xED, RESERVED_OPC) \ + RSV(0xEE, RESERVED_OPC) \ + RSV(0xEF, RESERVED_OPC) \ + RSV(0xF0, RESERVED_OPC) \ + RSV(0xF1, RESERVED_OPC) \ + RSV(0xF2, RESERVED_OPC) \ + RSV(0xF3, RESERVED_OPC) \ + RSV(0xF4, RESERVED_OPC) \ + RSV(0xF5, RESERVED_OPC) \ + RSV(0xF6, RESERVED_OPC) \ + RSV(0xF7, RESERVED_OPC) \ + RSV(0xF8, RESERVED_OPC) \ + RSV(0xF9, RESERVED_OPC) \ + RSV(0xFA, RESERVED_OPC) \ + RSV(0xFB, RESERVED_OPC) \ + RSV(0xFC, RESERVED_OPC) \ + \ + OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \ + OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \ + OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF) + +#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal, +enum class InterpreterOpcode : uint8_t { + IREE_INTERPRETER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM) +}; +#undef DECLARE_ENUM + +} // namespace iree + +#endif // IREE_SCHEMAS_BYTECODE_INTERPRETER_BYTECODE_V0_H_
diff --git a/schemas/bytecode/sequencer_bytecode_v0.h b/schemas/bytecode/sequencer_bytecode_v0.h new file mode 100644 index 0000000..a199011 --- /dev/null +++ b/schemas/bytecode/sequencer_bytecode_v0.h
@@ -0,0 +1,313 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Opcode table for the V0 binary format. +// Additions are fine but changing the behavior or order of any opcodes will +// break parsing of existing files. +// +// Opcodes have been selected on frequency of use, general applicability, and +// relative stability. Experimental ops should be implemented via the Foreign +// Function Interface (FFI) first before graduating into the core set. Ops that +// may only be present on certain targets should also be kept as imports via the +// FFI. +// +// Opcodes may be specified for particular types (int32_t), categories of types +// (all floating-point types), or implicit types (output matches input). Saving +// opcode space by sharing a single opcode for multiple types is preferred +// except where hot operations are performed (for example, comparison used in +// loop iterators). + +#ifndef IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_ +#define IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_ + +#include "schemas/bytecode/bytecode_v0.h" + +namespace iree { + +#define IREE_SEQUENCER_OPCODE_LIST(OPC, RESERVED_OPC) \ + OPC(0x00, kConstant, "constant", FLAG(kDefault), "cr", FF) \ + \ + OPC(0x01, kCall, "call", FLAG(kDefault), "fSR", FF) \ + OPC(0x02, kCallImport, "call_import", FLAG(kDefault), "FSR", FF) \ + OPC(0x03, kCallIndirect, "call_indirect", FLAG(kDefault), "tsSR", FF) \ + OPC(0x04, kReturn, "return", FLAG(kDefault), "S", FF) \ + OPC(0x05, kBranch, "br", FLAG(kDefault), "bT", FF) \ + OPC(0x06, kCondBranch, "cond_br", FLAG(kDefault), "sbTbT", FF) \ + \ + RSV(0x07, RESERVED_OPC) \ + RSV(0x08, RESERVED_OPC) \ + RSV(0x09, RESERVED_OPC) \ + RSV(0x0A, RESERVED_OPC) \ + RSV(0x0B, RESERVED_OPC) \ + RSV(0x0C, RESERVED_OPC) \ + RSV(0x0D, RESERVED_OPC) \ + RSV(0x0E, RESERVED_OPC) \ + RSV(0x0F, RESERVED_OPC) \ + \ + OPC(0x10, kDynamicDispatch, "dynamic_dispatch", FLAG(kDefault), "dsSOR", FF) \ + OPC(0x11, kStaticDispatch, "static_dispatch", FLAG(kDefault), "diiiSOR", FF) \ + \ + RSV(0x12, RESERVED_OPC) \ + RSV(0x13, RESERVED_OPC) \ + RSV(0x14, RESERVED_OPC) \ + RSV(0x15, RESERVED_OPC) \ + RSV(0x16, RESERVED_OPC) \ + RSV(0x17, RESERVED_OPC) \ + RSV(0x18, RESERVED_OPC) \ + RSV(0x19, RESERVED_OPC) \ + RSV(0x1A, RESERVED_OPC) \ + RSV(0x1B, RESERVED_OPC) \ + RSV(0x1C, RESERVED_OPC) \ + RSV(0x1D, RESERVED_OPC) \ + RSV(0x1E, RESERVED_OPC) \ + RSV(0x1F, RESERVED_OPC) \ + \ + OPC(0x20, kAllocStatic, "alloc_static", FLAG(kDefault), "Icr", FF) \ + OPC(0x21, kAllocStack, "alloc_stack", FLAG(kDefault), "itISr", FF) \ + OPC(0x22, kAllocStackInit, "alloc_stack_init", FLAG(kDefault), "tIScr", FF) \ + OPC(0x23, kAllocHeap, "alloc_heap", FLAG(kDefault), "itISr", FF) \ + OPC(0x24, kDiscard, "discard", FLAG(kDefault), "s", FF) \ + \ + RSV(0x25, RESERVED_OPC) \ + RSV(0x26, RESERVED_OPC) \ + RSV(0x27, RESERVED_OPC) \ + RSV(0x28, RESERVED_OPC) \ + RSV(0x29, RESERVED_OPC) \ + RSV(0x2A, RESERVED_OPC) \ + RSV(0x2B, RESERVED_OPC) \ + RSV(0x2C, RESERVED_OPC) \ + RSV(0x2D, RESERVED_OPC) \ + RSV(0x2E, RESERVED_OPC) \ + RSV(0x2F, RESERVED_OPC) \ + RSV(0x30, RESERVED_OPC) \ + \ + OPC(0x31, kComputeRange, "compute_range", FLAG(kDefault), "sissoo", FF) \ + OPC(0x32, kShape, "shape", FLAG(kDefault), "so", FF) \ + OPC(0x33, kLength, "length", FLAG(kDefault), "so", FF) \ + OPC(0x34, kDynamicSlice, "dynamic_slice", FLAG(kDefault), "ssstsr", FF) \ + OPC(0x35, kStaticSlice, "static_slice", FLAG(kDefault), "siitIr", FF) \ + OPC(0x36, kDynamicCopy, "dynamic_copy", FLAG(kDefault), "ssoss", FF) \ + OPC(0x37, kStaticCopy, "static_copy", FLAG(kDefault), "sioii", FF) \ + OPC(0x38, kDynamicFill, "dynamic_fill", FLAG(kDefault), "soss", FF) \ + OPC(0x39, kStaticFill, "static_fill", FLAG(kDefault), "ioii", FF) \ + OPC(0x3A, kClone, "clone", FLAG(kDefault), "sr", FF) \ + OPC(0x3B, kAssign, "assign", FLAG(kDefault), "sr", FF) \ + OPC(0x3C, kCondAssign, "cond_assign", FLAG(kDefault), "sssr", FF) \ + OPC(0x3D, kReshape, "reshape", FLAG(kDefault), "ssr", FF) \ + \ + RSV(0x3E, RESERVED_OPC) \ + RSV(0x3F, RESERVED_OPC) \ + RSV(0x40, RESERVED_OPC) \ + RSV(0x41, RESERVED_OPC) \ + RSV(0x42, RESERVED_OPC) \ + RSV(0x43, RESERVED_OPC) \ + RSV(0x44, RESERVED_OPC) \ + RSV(0x45, RESERVED_OPC) \ + RSV(0x46, RESERVED_OPC) \ + RSV(0x47, RESERVED_OPC) \ + RSV(0x48, RESERVED_OPC) \ + RSV(0x49, RESERVED_OPC) \ + RSV(0x4A, RESERVED_OPC) \ + RSV(0x4B, RESERVED_OPC) \ + RSV(0x4C, RESERVED_OPC) \ + RSV(0x4D, RESERVED_OPC) \ + RSV(0x4E, RESERVED_OPC) \ + RSV(0x4F, RESERVED_OPC) \ + RSV(0x50, RESERVED_OPC) \ + RSV(0x51, RESERVED_OPC) \ + RSV(0x52, RESERVED_OPC) \ + RSV(0x53, RESERVED_OPC) \ + RSV(0x54, RESERVED_OPC) \ + RSV(0x55, RESERVED_OPC) \ + RSV(0x56, RESERVED_OPC) \ + RSV(0x57, RESERVED_OPC) \ + RSV(0x58, RESERVED_OPC) \ + RSV(0x59, RESERVED_OPC) \ + RSV(0x5A, RESERVED_OPC) \ + RSV(0x5B, RESERVED_OPC) \ + RSV(0x5C, RESERVED_OPC) \ + RSV(0x5D, RESERVED_OPC) \ + RSV(0x5E, RESERVED_OPC) \ + RSV(0x5F, RESERVED_OPC) \ + RSV(0x60, RESERVED_OPC) \ + RSV(0x61, RESERVED_OPC) \ + RSV(0x62, RESERVED_OPC) \ + RSV(0x63, RESERVED_OPC) \ + RSV(0x64, RESERVED_OPC) \ + RSV(0x65, RESERVED_OPC) \ + RSV(0x66, RESERVED_OPC) \ + RSV(0x67, RESERVED_OPC) \ + RSV(0x68, RESERVED_OPC) \ + RSV(0x69, RESERVED_OPC) \ + RSV(0x6A, RESERVED_OPC) \ + RSV(0x6B, RESERVED_OPC) \ + RSV(0x6C, RESERVED_OPC) \ + RSV(0x6D, RESERVED_OPC) \ + RSV(0x6E, RESERVED_OPC) \ + RSV(0x6F, RESERVED_OPC) \ + RSV(0x70, RESERVED_OPC) \ + RSV(0x71, RESERVED_OPC) \ + RSV(0x72, RESERVED_OPC) \ + RSV(0x73, RESERVED_OPC) \ + RSV(0x74, RESERVED_OPC) \ + RSV(0x75, RESERVED_OPC) \ + RSV(0x76, RESERVED_OPC) \ + RSV(0x77, RESERVED_OPC) \ + RSV(0x78, RESERVED_OPC) \ + RSV(0x79, RESERVED_OPC) \ + RSV(0x7A, RESERVED_OPC) \ + RSV(0x7B, RESERVED_OPC) \ + RSV(0x7C, RESERVED_OPC) \ + RSV(0x7D, RESERVED_OPC) \ + RSV(0x7E, RESERVED_OPC) \ + RSV(0x7F, RESERVED_OPC) \ + RSV(0x80, RESERVED_OPC) \ + RSV(0x81, RESERVED_OPC) \ + RSV(0x82, RESERVED_OPC) \ + RSV(0x83, RESERVED_OPC) \ + RSV(0x84, RESERVED_OPC) \ + RSV(0x85, RESERVED_OPC) \ + RSV(0x86, RESERVED_OPC) \ + RSV(0x87, RESERVED_OPC) \ + RSV(0x88, RESERVED_OPC) \ + RSV(0x89, RESERVED_OPC) \ + RSV(0x8A, RESERVED_OPC) \ + RSV(0x8B, RESERVED_OPC) \ + RSV(0x8C, RESERVED_OPC) \ + RSV(0x8D, RESERVED_OPC) \ + RSV(0x8E, RESERVED_OPC) \ + RSV(0x8F, RESERVED_OPC) \ + RSV(0x90, RESERVED_OPC) \ + RSV(0x91, RESERVED_OPC) \ + RSV(0x92, RESERVED_OPC) \ + RSV(0x93, RESERVED_OPC) \ + RSV(0x94, RESERVED_OPC) \ + RSV(0x95, RESERVED_OPC) \ + RSV(0x96, RESERVED_OPC) \ + RSV(0x97, RESERVED_OPC) \ + RSV(0x98, RESERVED_OPC) \ + RSV(0x99, RESERVED_OPC) \ + RSV(0x9A, RESERVED_OPC) \ + RSV(0x9B, RESERVED_OPC) \ + RSV(0x9C, RESERVED_OPC) \ + RSV(0x9D, RESERVED_OPC) \ + RSV(0x9E, RESERVED_OPC) \ + RSV(0x9F, RESERVED_OPC) \ + RSV(0xA0, RESERVED_OPC) \ + RSV(0xA1, RESERVED_OPC) \ + RSV(0xA2, RESERVED_OPC) \ + RSV(0xA3, RESERVED_OPC) \ + RSV(0xA4, RESERVED_OPC) \ + RSV(0xA5, RESERVED_OPC) \ + RSV(0xA6, RESERVED_OPC) \ + RSV(0xA7, RESERVED_OPC) \ + RSV(0xA8, RESERVED_OPC) \ + RSV(0xA9, RESERVED_OPC) \ + RSV(0xAA, RESERVED_OPC) \ + RSV(0xAB, RESERVED_OPC) \ + RSV(0xAC, RESERVED_OPC) \ + RSV(0xAD, RESERVED_OPC) \ + RSV(0xAE, RESERVED_OPC) \ + RSV(0xAF, RESERVED_OPC) \ + RSV(0xB0, RESERVED_OPC) \ + RSV(0xB1, RESERVED_OPC) \ + RSV(0xB2, RESERVED_OPC) \ + RSV(0xB3, RESERVED_OPC) \ + RSV(0xB4, RESERVED_OPC) \ + RSV(0xB5, RESERVED_OPC) \ + RSV(0xB6, RESERVED_OPC) \ + RSV(0xB7, RESERVED_OPC) \ + RSV(0xB8, RESERVED_OPC) \ + RSV(0xB9, RESERVED_OPC) \ + RSV(0xBA, RESERVED_OPC) \ + RSV(0xBB, RESERVED_OPC) \ + RSV(0xBC, RESERVED_OPC) \ + RSV(0xBD, RESERVED_OPC) \ + RSV(0xBE, RESERVED_OPC) \ + RSV(0xBF, RESERVED_OPC) \ + RSV(0xC0, RESERVED_OPC) \ + RSV(0xC1, RESERVED_OPC) \ + RSV(0xC2, RESERVED_OPC) \ + RSV(0xC3, RESERVED_OPC) \ + RSV(0xC4, RESERVED_OPC) \ + RSV(0xC5, RESERVED_OPC) \ + RSV(0xC6, RESERVED_OPC) \ + RSV(0xC7, RESERVED_OPC) \ + RSV(0xC8, RESERVED_OPC) \ + RSV(0xC9, RESERVED_OPC) \ + RSV(0xCA, RESERVED_OPC) \ + RSV(0xCB, RESERVED_OPC) \ + RSV(0xCC, RESERVED_OPC) \ + RSV(0xCD, RESERVED_OPC) \ + RSV(0xCE, RESERVED_OPC) \ + RSV(0xCF, RESERVED_OPC) \ + RSV(0xD0, RESERVED_OPC) \ + RSV(0xD1, RESERVED_OPC) \ + RSV(0xD2, RESERVED_OPC) \ + RSV(0xD3, RESERVED_OPC) \ + RSV(0xD4, RESERVED_OPC) \ + RSV(0xD5, RESERVED_OPC) \ + RSV(0xD6, RESERVED_OPC) \ + RSV(0xD7, RESERVED_OPC) \ + RSV(0xD8, RESERVED_OPC) \ + RSV(0xD9, RESERVED_OPC) \ + RSV(0xDA, RESERVED_OPC) \ + RSV(0xDB, RESERVED_OPC) \ + RSV(0xDC, RESERVED_OPC) \ + RSV(0xDD, RESERVED_OPC) \ + RSV(0xDE, RESERVED_OPC) \ + RSV(0xDF, RESERVED_OPC) \ + RSV(0xE0, RESERVED_OPC) \ + RSV(0xE1, RESERVED_OPC) \ + RSV(0xE2, RESERVED_OPC) \ + RSV(0xE3, RESERVED_OPC) \ + RSV(0xE4, RESERVED_OPC) \ + RSV(0xE5, RESERVED_OPC) \ + RSV(0xE6, RESERVED_OPC) \ + RSV(0xE7, RESERVED_OPC) \ + RSV(0xE8, RESERVED_OPC) \ + RSV(0xE9, RESERVED_OPC) \ + RSV(0xEA, RESERVED_OPC) \ + RSV(0xEB, RESERVED_OPC) \ + RSV(0xEC, RESERVED_OPC) \ + RSV(0xED, RESERVED_OPC) \ + RSV(0xEE, RESERVED_OPC) \ + RSV(0xEF, RESERVED_OPC) \ + RSV(0xF0, RESERVED_OPC) \ + RSV(0xF1, RESERVED_OPC) \ + RSV(0xF2, RESERVED_OPC) \ + RSV(0xF3, RESERVED_OPC) \ + RSV(0xF4, RESERVED_OPC) \ + RSV(0xF5, RESERVED_OPC) \ + RSV(0xF6, RESERVED_OPC) \ + RSV(0xF7, RESERVED_OPC) \ + RSV(0xF8, RESERVED_OPC) \ + RSV(0xF9, RESERVED_OPC) \ + RSV(0xFA, RESERVED_OPC) \ + RSV(0xFB, RESERVED_OPC) \ + RSV(0xFC, RESERVED_OPC) \ + \ + OPC(0xFD, kTrace, "trace", FLAG(kDefault), "s", FF) \ + OPC(0xFE, kCondBreak, "cond_break", FLAG(kDefault), "s", FF) \ + OPC(0xFF, kBreak, "break", FLAG(kDefault), "", FF) + +#define DECLARE_ENUM(ordinal, enum_name, ...) enum_name = ordinal, +enum class SequencerOpcode : uint8_t { + IREE_SEQUENCER_OPCODE_LIST(DECLARE_ENUM, DECLARE_ENUM) +}; +#undef DECLARE_ENUM + +} // namespace iree + +#endif // IREE_SCHEMAS_BYTECODE_SEQUENCER_BYTECODE_V0_H_
diff --git a/iree/schemas/bytecode_def.fbs b/schemas/bytecode_def.fbs similarity index 100% rename from iree/schemas/bytecode_def.fbs rename to schemas/bytecode_def.fbs
diff --git a/schemas/debug_service.fbs b/schemas/debug_service.fbs new file mode 100644 index 0000000..0f31294 --- /dev/null +++ b/schemas/debug_service.fbs
@@ -0,0 +1,347 @@ +include "schemas/function_def.fbs"; +include "schemas/module_def.fbs"; + +namespace iree.rt.debug.rpc; + +table Status { + code:int; + message:string; +} + +table CreateSessionRequest { +} +table CreateSessionResponse { + session_id:int; +} + +table MakeReadyRequest { + session_id:int; +} +table MakeReadyResponse { +} + +table GetStatusRequest { + session_id:int; + // TODO(benvanik): caps debugger supports? version expected? +} +table GetStatusResponse { + protocol:int; + // TODO(benvanik): run state. + // TODO(benvanik): profiling state. +} + +table NativeFunctionDef { + name:string; + // TODO(benvanik): more information about the fns (stack trace of registrant?) +} + +table ContextDef { + context_id:int; + native_functions:[NativeFunctionDef]; + module_names:[string]; +} + +table ListContextsRequest { + session_id:int; +} + +table ListContextsResponse { + contexts:[ContextDef]; +} + +table GetModuleRequest { + session_id:int; + context_id:int; + module_name:string; +} +table GetModuleResponse { + module:ModuleDef; +} + +table GetFunctionRequest { + session_id:int; + context_id:int; + module_name:string; + function_ordinal:int; +} +table GetFunctionResponse { + bytecode:BytecodeDef; + // TODO(benvanik): import info (linked module, etc). +} + +table ResolveFunctionRequest { + session_id:int; + module_name:string; + function_name:string; +} +table ResolveFunctionResponse { + context_ids:[int]; + function_ordinal:int; +} + +table BufferViewDef { + is_valid:bool; + shape:[int]; + element_size:int; + // TODO(benvanik): buffer attrs (type, access, usage). + // TODO(benvanik): buffer size/allocated_size. + // TODO(benvanik): buffer data (if accessible). +} + +table StackFrameDef { + module_name:string; + function_ordinal:int; + offset:int; + locals:[BufferViewDef]; +} + +table InvocationDef { + invocation_id:int; + frames:[StackFrameDef]; +} + +table ListInvocationsRequest { + session_id:int; +} +table ListInvocationsResponse { + invocations:[InvocationDef]; +} + +table SuspendInvocationsRequest { + session_id:int; + invocation_ids:[int]; +} +table SuspendInvocationsResponse { + invocations:[InvocationDef]; +} + +table ResumeInvocationsRequest { + session_id:int; + invocation_ids:[int]; +} +table ResumeInvocationsResponse { +} + +enum StepMode : uint8 { + STEP_ONCE = 0, + STEP_TO_OFFSET = 1, +} + +table StepInvocationRequest { + session_id:int; + step_id:int; + invocation_id:int; + step_mode:StepMode; + bytecode_offset:int; +} +table StepInvocationResponse {} + +table GetInvocationLocalRequest { + session_id:int; + invocation_id:int; + frame_index:int; + local_index:int; +} +table GetInvocationLocalResponse { + value:BufferViewDef; +} + +table SetInvocationLocalRequest { + session_id:int; + invocation_id:int; + frame_index:int; + local_index:int; + value:BufferViewDef; +} +table SetInvocationLocalResponse { + value:BufferViewDef; +} + +enum BreakpointType : uint8 { + BYTECODE_FUNCTION = 0, + NATIVE_FUNCTION = 1, +} + +table BreakpointDef { + breakpoint_id:int; + breakpoint_type:BreakpointType; + + module_name:string; + function_name:string; + function_ordinal:int; + bytecode_offset:int; +} + +table ListBreakpointsRequest { + session_id:int; +} +table ListBreakpointsResponse { + breakpoints:[BreakpointDef]; +} + +table AddBreakpointRequest { + session_id:int; + breakpoint:BreakpointDef; +} +table AddBreakpointResponse { + breakpoint:BreakpointDef; +} + +table RemoveBreakpointRequest { + session_id:int; + breakpoint_id:int; +} +table RemoveBreakpointResponse { +} + +table StartProfilingRequest { + session_id:int; + context_id:int; + // TODO(benvanik): profiling mode. + // mode: sampling_timing, instrumented_coverage, instrumented_log, + // invoke_log +} +table StartProfilingResponse { + // TODO(benvanik): current/new mode. +} + +table StopProfilingRequest { + session_id:int; + context_id:int; +} +table StopProfilingResponse { + // TODO(benvanik): profiling data. +} + +// TODO(benvanik): streaming profiling data query. + +table ServiceShutdownEvent { +} + +table ContextRegisteredEvent { + context_id:int; +} +table ContextUnregisteredEvent { + context_id:int; +} + +table ModuleLoadedEvent { + context_id:int; + module_name:string; +} + +table InvocationRegisteredEvent { + invocation_id:int; +} +table InvocationUnregisteredEvent { + invocation_id:int; +} + +table BreakpointResolvedEvent { + breakpoint:BreakpointDef; + context_id:int; +} + +table BreakpointHitEvent { + breakpoint_id:int; + invocation:InvocationDef; +} + +table StepCompletedEvent { + step_id:int; + invocations:[InvocationDef]; +} + +union RequestUnion { + CreateSessionRequest, + MakeReadyRequest, + GetStatusRequest, + ListContextsRequest, + GetModuleRequest, + GetFunctionRequest, + ResolveFunctionRequest, + ListInvocationsRequest, + SuspendInvocationsRequest, + ResumeInvocationsRequest, + StepInvocationRequest, + GetInvocationLocalRequest, + SetInvocationLocalRequest, + ListBreakpointsRequest, + AddBreakpointRequest, + RemoveBreakpointRequest, + StartProfilingRequest, + StopProfilingRequest, +} + +union ResponseUnion { + CreateSessionResponse, + MakeReadyResponse, + GetStatusResponse, + ListContextsResponse, + GetModuleResponse, + GetFunctionResponse, + ResolveFunctionResponse, + ListInvocationsResponse, + SuspendInvocationsResponse, + ResumeInvocationsResponse, + StepInvocationResponse, + GetInvocationLocalResponse, + SetInvocationLocalResponse, + ListBreakpointsResponse, + AddBreakpointResponse, + RemoveBreakpointResponse, + StartProfilingResponse, + StopProfilingResponse, +} + +union EventUnion { + ServiceShutdownEvent, + ContextRegisteredEvent, + ContextUnregisteredEvent, + ModuleLoadedEvent, + InvocationRegisteredEvent, + InvocationUnregisteredEvent, + BreakpointResolvedEvent, + BreakpointHitEvent, + StepCompletedEvent, +} + +table Request { + message:RequestUnion; +} + +table Response { + status:Status; + message:ResponseUnion; +} + +table ServicePacket { + response:Response; + event:EventUnion; +} + +// NOTE: we aren't using this yet as the FlatBuffers gRPC code is... suspect. +rpc_service DebugServiceRpc { + MakeReady(MakeReadyRequest):MakeReadyResponse; + + GetStatus(GetStatusRequest):GetStatusResponse; + + ListContexts(ListContextsRequest):ListContextsResponse; + GetModule(GetModuleRequest):GetModuleResponse; + GetFunction(GetFunctionRequest):GetFunctionResponse; + ResolveFunction(ResolveFunctionRequest):ResolveFunctionResponse; + + ListInvocations(ListInvocationsRequest):ListInvocationsResponse; + SuspendInvocations(SuspendInvocationsRequest):SuspendInvocationsResponse; + ResumeInvocations(ResumeInvocationsRequest):ResumeInvocationsResponse; + StepInvocation(StepInvocationRequest):StepInvocationResponse; + GetInvocationLocal(GetInvocationLocalRequest):GetInvocationLocalResponse; + SetInvocationLocal(SetInvocationLocalRequest):SetInvocationLocalResponse; + + ListBreakpoints(ListBreakpointsRequest):ListBreakpointsResponse; + AddBreakpoint(AddBreakpointRequest):AddBreakpointResponse; + RemoveBreakpoint(RemoveBreakpointRequest):RemoveBreakpointResponse; + + StartProfiling(StartProfilingRequest):StartProfilingResponse; + StopProfiling(StopProfilingRequest):StopProfilingResponse; +}
diff --git a/iree/schemas/device_def.fbs b/schemas/device_def.fbs similarity index 100% rename from iree/schemas/device_def.fbs rename to schemas/device_def.fbs
diff --git a/iree/schemas/device_group_def.fbs b/schemas/device_group_def.fbs similarity index 100% rename from iree/schemas/device_group_def.fbs rename to schemas/device_group_def.fbs
diff --git a/schemas/device_table_def.fbs b/schemas/device_table_def.fbs new file mode 100644 index 0000000..597e126 --- /dev/null +++ b/schemas/device_table_def.fbs
@@ -0,0 +1,13 @@ +include "schemas/device_def.fbs"; +include "schemas/device_group_def.fbs"; + +namespace iree; + +// A table of devices used for runtime device resolution and referencing. +table DeviceTableDef { + // One or more virtual devices referenced by ordinal in the sequencer ops. + devices:[DeviceDef]; + + // Zero or more device groups that specify which devices must be compatible. + device_groups:[DeviceGroupDef]; +}
diff --git a/iree/schemas/executable_def.fbs b/schemas/executable_def.fbs similarity index 100% rename from iree/schemas/executable_def.fbs rename to schemas/executable_def.fbs
diff --git a/schemas/executable_table_def.fbs b/schemas/executable_table_def.fbs new file mode 100644 index 0000000..6c55f76 --- /dev/null +++ b/schemas/executable_table_def.fbs
@@ -0,0 +1,28 @@ +include "schemas/executable_def.fbs"; + +namespace iree; + +// A fat executable containing multiple format variants for the same logical +// entry points. +table MultiArchExecutableDef { + // Friendly name of the executable used for diagnostics. + name:string; + + // Number of available entry points. + // This is used for bytecode verification even when the executable is not + // fully loaded into a device. All executables must have the same entry + // points. + entry_point_count:uint; + + // A set of executables of various formats and supported feature sets. + // The runtime will select the appropriate executable based on the dispatch + // requirements. + executables:[ExecutableDef]; +} + +// A table of executables used for runtime dispatch lookup. +table ExecutableTableDef { + // One or more top level executables referenced by sequencer dispatch ops. + // Ordinal is referenced by dispatch ops to index into the table. + multi_arch_executables:[MultiArchExecutableDef]; +}
diff --git a/schemas/function_def.fbs b/schemas/function_def.fbs new file mode 100644 index 0000000..b6eed64 --- /dev/null +++ b/schemas/function_def.fbs
@@ -0,0 +1,18 @@ +include "schemas/bytecode_def.fbs"; +include "schemas/type_def.fbs"; + +namespace iree; + +table FunctionAttributeDef { + key:string; + value:string; +} + +table FunctionDef { + name:string; + type:FunctionTypeDef; + + attrs:[FunctionAttributeDef]; + + bytecode:BytecodeDef; +}
diff --git a/schemas/function_table_def.fbs b/schemas/function_table_def.fbs new file mode 100644 index 0000000..3ad9c0a --- /dev/null +++ b/schemas/function_table_def.fbs
@@ -0,0 +1,9 @@ +include "schemas/function_def.fbs"; + +namespace iree; + +table FunctionTableDef { + functions:[FunctionDef]; + imports:[int]; + exports:[int]; +}
diff --git a/schemas/module_def.fbs b/schemas/module_def.fbs new file mode 100644 index 0000000..cebcbcb --- /dev/null +++ b/schemas/module_def.fbs
@@ -0,0 +1,20 @@ +include "schemas/executable_table_def.fbs"; +include "schemas/device_table_def.fbs"; +include "schemas/function_table_def.fbs"; +include "schemas/source_map_def.fbs"; + +namespace iree; + +// 'Executable MODule'. +file_identifier "EMOD"; +file_extension "emod"; + +table ModuleDef { + name:string; + device_table:DeviceTableDef; + function_table:FunctionTableDef; + executable_table:ExecutableTableDef; + source_map:SourceMapDef; +} + +root_type ModuleDef;
diff --git a/iree/schemas/source_map_def.fbs b/schemas/source_map_def.fbs similarity index 100% rename from iree/schemas/source_map_def.fbs rename to schemas/source_map_def.fbs
diff --git a/iree/schemas/spirv_executable_def.fbs b/schemas/spirv_executable_def.fbs similarity index 100% rename from iree/schemas/spirv_executable_def.fbs rename to schemas/spirv_executable_def.fbs
diff --git a/iree/schemas/type_def.fbs b/schemas/type_def.fbs similarity index 100% rename from iree/schemas/type_def.fbs rename to schemas/type_def.fbs
diff --git a/test/e2e/BUILD b/test/e2e/BUILD index bba2d68..c83e6fc 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD
@@ -1,6 +1,6 @@ # Tests for end-to-end IREE support. -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") package( default_visibility = ["//visibility:public"], @@ -9,7 +9,7 @@ iree_setup_lit_package( data = [ - "//iree/tools:iree-run-mlir", + "///tools:iree-run-mlir", ], )
diff --git a/test/e2e/xla/BUILD b/test/e2e/xla/BUILD index f3eda03..0066c08 100644 --- a/test/e2e/xla/BUILD +++ b/test/e2e/xla/BUILD
@@ -1,6 +1,6 @@ # Tests for end-to-end IREE support starting from the XLA HLO dialect. -load("//iree:build_defs.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") +load("//:build_defs.google.bzl", "iree_glob_lit_tests", "iree_setup_lit_package") package( default_visibility = ["//visibility:public"], @@ -9,7 +9,7 @@ iree_setup_lit_package( data = [ - "//iree/tools:iree-run-mlir", + "///tools:iree-run-mlir", ], )
diff --git a/tools/BUILD b/tools/BUILD new file mode 100644 index 0000000..c35fa96 --- /dev/null +++ b/tools/BUILD
@@ -0,0 +1,124 @@ +# Misc tools used to optimize, translate, and evaluate IREE. +# Most of these are not designed to run on-device. + +load("//:build_defs.google.bzl", "PLATFORM_VULKAN_DEPS") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +exports_files([ + "run_lit.sh", + "sanitizer_suppressions.txt", +]) + +cc_binary( + name = "iree-opt", + deps = [ + "///compiler/Transforms", + "///compiler/Transforms/Interpreter", + "///compiler/Transforms/Sequencer", + "///compiler/Translation/SPIRV", + "@llvm//:support", + "@local_config_mlir//:AffineDialectRegistration", + "@local_config_mlir//:MlirOptLib", + "@local_config_mlir//:MlirOptMain", + "@local_config_mlir//:StandardDialectRegistration", + "@org_tensorflow//tensorflow/compiler/mlir/xla:hlo", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", + ], +) + +cc_binary( + name = "iree-run-mlir", + srcs = ["run_mlir_main.cc"], + deps = [ + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "///base:source_location", + "///rt", + "///vm:sequencer_module", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", + "@local_config_mlir//:Support", + "///base:init", + "///base:status", + "///compiler/Translation/Sequencer", + "///compiler/Translation/Interpreter", + "///compiler/Translation/SPIRV", + "///hal:buffer_view_string_util", + "///hal:driver_registry", + "///schemas", + "///rt/debug:debug_server_flags", + ] + PLATFORM_VULKAN_DEPS + [ + "///hal/interpreter:interpreter_driver_module", + # TODO(b/142004903): enable when Dawn HAL implementation is functional + # "///hal/dawn:dawn_driver_module", + "///hal/vulkan:vulkan_driver_module", + ], +) + +cc_binary( + name = "iree-translate", + srcs = ["iree_translate_main.cc"], + deps = [ + "///compiler/Translation/Interpreter", + "///compiler/Translation/SPIRV", + "///compiler/Translation/Sequencer", + "@llvm//:support", + "@local_config_mlir//:AffineDialectRegistration", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardDialectRegistration", + "@local_config_mlir//:Support", + "@local_config_mlir//:TranslateClParser", + "@local_config_mlir//:Translation", + "@org_tensorflow//tensorflow/compiler/mlir/xla:xla_dialect_registration", + ], +) + +cc_binary( + name = "run_module", + srcs = ["run_module_main.cc"], + deps = [ + "///base:file_io", + "///base:file_path", + "///base:init", + "///base:source_location", + "///base:status", + "///hal:buffer_view_string_util", + "///hal:driver_registry", + "///hal/interpreter:interpreter_driver_module", + "///rt", + "///rt/debug:debug_server_flags", + "///schemas", + "///vm:sequencer_module", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + ], +) + +cc_binary( + name = "benchmark_module", + testonly = 1, + srcs = ["benchmark_module.cc"], + deps = [ + "///base:file_io", + "///base:file_path", + "///base:init", + "///base:source_location", + "///base:status", + "///hal:buffer_view_string_util", + "///hal:driver_registry", + "///hal/interpreter:interpreter_driver_module", + "///rt", + "///rt/debug:debug_server_flags", + "///schemas", + "///vm:sequencer_module", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_benchmark//:benchmark", + ], +)
diff --git a/iree/tools/CMakeLists.txt b/tools/CMakeLists.txt similarity index 100% rename from iree/tools/CMakeLists.txt rename to tools/CMakeLists.txt
diff --git a/tools/benchmark_module.cc b/tools/benchmark_module.cc new file mode 100644 index 0000000..b5f45b9 --- /dev/null +++ b/tools/benchmark_module.cc
@@ -0,0 +1,157 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <iostream> +#include <vector> + +#include "absl/flags/flag.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "base/file_io.h" +#include "base/file_path.h" +#include "base/init.h" +#include "base/source_location.h" +#include "base/status.h" +#include "benchmark/benchmark.h" +#include "hal/buffer_view_string_util.h" +#include "hal/driver_registry.h" +#include "rt/context.h" +#include "rt/debug/debug_server_flags.h" +#include "rt/instance.h" +#include "rt/module_printer.h" +#include "schemas/module_def_generated.h" +#include "vm/sequencer_module.h" + +ABSL_FLAG(std::string, main_module, "", "Main module with entry point."); +ABSL_FLAG(std::string, main_function, "", + "Function within the main module to execute."); + +ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values."); +ABSL_FLAG(std::string, input_file, "", + "Input shapes and optional values serialized in a file."); + +namespace iree { +namespace { + +// Parses a list of input shapes and values from a string of newline-separated +// inputs. Expects the contents to have one value per line with each value +// listed as +// [shape]xtype=[value] +// Example: +// 4x4xi8=0,1,2,3 +StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags( + hal::Allocator* allocator) { + std::string file_contents; + if (!absl::GetFlag(FLAGS_input_values).empty()) { + file_contents = + absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}}); + } else if (!absl::GetFlag(FLAGS_input_file).empty()) { + ASSIGN_OR_RETURN(file_contents, + file_io::GetFileContents(absl::GetFlag(FLAGS_input_file))); + } + std::vector<hal::BufferView> inputs; + for (const auto& line : + absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) { + ASSIGN_OR_RETURN(auto input, + hal::ParseBufferViewFromString(line, allocator)); + inputs.push_back(input); + } + return inputs; +} + +Status Run(benchmark::State& state) { + ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags()); + auto instance = make_ref<rt::Instance>(std::move(debug_server)); + ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( + "interpreter")); + ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); + RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); + auto policy = make_ref<rt::Policy>(); + auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy)); + + // Load main module. + ASSIGN_OR_RETURN( + auto main_module_file, + vm::ModuleFile::LoadFile(ModuleDefIdentifier(), + absl::GetFlag(FLAGS_main_module)), + _ << "while loading module file " << absl::GetFlag(FLAGS_main_module)); + ASSIGN_OR_RETURN(auto main_module, + vm::SequencerModule::FromFile(std::move(main_module_file))); + + // Register the main module with the context. + // We could add additional modules (specializations, shared libraries, etc). + // ModuleFiles are stateless so we could have the same module_file used by + // multiple contexts simultaneously. + RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module))); + + rt::Function main_function; + if (!absl::GetFlag(FLAGS_main_function).empty()) { + // User-specified main function. + ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName( + rt::Function::Linkage::kExport, + absl::GetFlag(FLAGS_main_function))); + } else { + // No main function specified; to prevent non-deterministic behavior we + // require one unless there's exactly one exported function in the module. + if (main_module->signature().export_function_count() == 1) { + ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal( + rt::Function::Linkage::kExport, 0)); + } else { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "--main_function= must be specified to disambiguate the " + "function to run"; + } + } + + // Call into the main function. + ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator())); + + for (auto _ : state) { + ASSIGN_OR_RETURN(auto invocation, + rt::Invocation::Create(add_ref(context), main_function, + make_ref<rt::Policy>(), {}, + absl::MakeConstSpan(arguments))); + RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); + } + + return OkStatus(); +} + +void BM_RunModule(benchmark::State& state) { + // Delegate to a status-returning function so we can use the status macros. + CHECK_OK(Run(state)); +} + +// By default only the main thread is included in CPU time. Include all the +// threads instead. To make single and multi-threaded benchmarks more +// comparable, use the wall time to determine how many iterations to run. +// See https://github.com/google/benchmark#cpu-timers, +BENCHMARK(BM_RunModule)->MeasureProcessCPUTime()->UseRealTime(); + +} // namespace + +extern "C" int main(int argc, char** argv) { + // The benchmark library uses a different mechanism for its flags. This + // consumes any arguments it understands from argv. It must come before + // InitializeEnvironment to avoid failures on unknown flags. + ::benchmark::Initialize(&argc, argv); + InitializeEnvironment(&argc, &argv); + size_t run_benchmark_count = ::benchmark::RunSpecifiedBenchmarks(); + CHECK_GT(run_benchmark_count, 0) << "No benchmarks were run"; + return 0; +} + +} // namespace iree
diff --git a/tools/compilation.bzl b/tools/compilation.bzl new file mode 100644 index 0000000..65b2f6b --- /dev/null +++ b/tools/compilation.bzl
@@ -0,0 +1,43 @@ +"""Rules for compiling IREE executables, modules, and archives.""" + +load("///build_tools/embed_data:build_defs.bzl", "cc_embed_data") + +# TODO(benvanik): port to a full starlark rule, document, etc. +def iree_bytecode_module( + name, + srcs, + cc_namespace = None, + visibility = None): + native.genrule( + name = name, + srcs = srcs, + outs = [ + "%s.emod" % (name), + ], + cmd = " && ".join([ + " ".join([ + "$(location ///tools:iree-translate)", + "-mlir-to-iree-module", + "-o $(location %s.emod)" % (name), + ] + ["$(locations %s)" % (src) for src in srcs]), + ]), + tools = [ + "///tools:iree-translate", + ], + message = "Compiling IREE module %s..." % (name), + output_to_bindir = 1, + ) + + # Embed the module for use in C++. This avoids the need for file IO in + # tests and samples that would otherwise complicate execution/porting. + if cc_namespace: + cc_embed_data( + name = "%s_cc" % (name), + identifier = name, + srcs = ["%s.emod" % (name)], + cc_file_output = "%s.cc" % (name), + h_file_output = "%s.h" % (name), + cpp_namespace = cc_namespace, + visibility = visibility, + flatten = True, + )
diff --git a/tools/debugger/BUILD b/tools/debugger/BUILD new file mode 100644 index 0000000..1146b8e --- /dev/null +++ b/tools/debugger/BUILD
@@ -0,0 +1,173 @@ +# IREE Debugger UIs. +# +# The main debugger UI can be used in standalone mode connected to a remote +# host (via :debugger) or can be directly embedded into the IREE runtime to +# allow for attaching (--iree_attach_debugger). +# +# By default the IREE runtime does not compile in debug support. To link it in +# pass --define=IREE_DEBUG=1 to bazel builds of the runtime. + +# TODO(benvanik): re-enable debugger after refactoring. +# load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +# TODO(benvanik): re-enable debugger after refactoring. +# alias( +# name = "debugger", +# actual = select({ +# "//tools/cc_target_os:emscripten": ":debug_app_emscripten_files", +# "//conditions:default": ":debug_app_native", +# }), +# ) +# +# cc_library( +# name = "debug_app_library", +# srcs = ["debug_app.cc"], +# hdrs = ["debug_app.h"], +# deps = [ +# "//third_party/GL:GLES2_headers", +# "//third_party/SDL2", +# "@com_google_absl//absl/flags:flag", +# "@com_google_absl//absl/memory", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/types:optional", +# "//third_party/dear_imgui", +# "//third_party/dear_imgui:imgui_sdl_opengl3", +# "///base:memory", +# "///base:source_location", +# "///base:status", +# "///rt/debug:debug_client", +# "///schemas", +# ], +# ) +# +# # NOTE: users must also link in a GL implementation, like: +# # "//third_party/GL/native:GLESv2", # build-cleaner: keep +# cc_library( +# name = "debug_app_embedded", +# srcs = ["debug_app_embedded.cc"], +# hdrs = ["debug_app_embedded.h"], +# deps = [ +# ":debug_app_library", +# "//third_party/SDL2", +# "@com_google_absl//absl/base:core_headers", +# "@com_google_absl//absl/memory", +# "@com_google_absl//absl/strings", +# "@com_google_absl//absl/synchronization", +# "//third_party/dear_imgui", +# "///base:memory", +# "///base:status", +# ], +# ) +# +# EMSCRIPTEN_LINKOPTS_COMMON = [ +# # Error at compile time on unresolved symbols. +# "-s ERROR_ON_UNDEFINED_SYMBOLS=1", +# +# # Required by SDL. +# "-s EXTRA_EXPORTED_RUNTIME_METHODS=Pointer_stringify", +# +# # TODO(benvanik): tweak to enable support when needed. +# "-s ALLOW_MEMORY_GROWTH=1", +# # "-s WASM_MEM_MAX=268435456", # 256MB +# # "-s TOTAL_MEMORY=268435456", # 256MB +# ] +# +# EMSCRIPTEN_LINKOPTS_DBG = [ +# # Show WASM stack trace in Chrome debugger. +# "-g2", +# "-s DEMANGLE_SUPPORT=1", +# +# # Enable verbose assertions. +# "-s ASSERTIONS=2", +# "-s SAFE_HEAP=1", +# "-s STACK_OVERFLOW_CHECK=2", +# ] +# +# EMSCRIPTEN_LINKOPTS_OPT = [] +# +# cc_binary( +# name = "debug_app_emscripten", +# srcs = ["debug_app_main_emscripten.cc"], +# linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({ +# "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG, +# "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT, +# "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT, +# }), +# tags = [ +# "manual", +# "notap", # TODO(b/137088911): Build/test on TAP +# "wasm", +# ], +# deps = [ +# ":debug_app_library", +# "//third_party/SDL2", +# "@com_google_absl//absl/memory", +# "//third_party/dear_imgui", +# "//third_party/dear_imgui:imgui_sdl_opengl3", +# "///base:init", +# "///base:source_location", +# "///base:status", +# ], +# ) +# +# auto_wasm_binary( +# name = "debug_app_emscripten_binary", +# cc_target = ":debug_app_emscripten", +# tags = ["manual"], +# ) +# +# Fileset( +# name = "debug_app_emscripten_files", +# out = "wasm_files", +# entries = [ +# FilesetEntry( +# files = [":debug_app_emscripten_binary"], +# strip_prefix = "debug_app_emscripten_binary", +# destdir = "wasm", +# ), +# FilesetEntry( +# files = ["debug_app.html"], +# destdir = "wasm", +# ), +# ], +# tags = ["manual"], +# ) +# +# cc_binary( +# name = "debug_app_native", +# srcs = ["debug_app_main_native.cc"], +# deps = [ +# ":debug_app_embedded", +# "//third_party/GL/native:EGL", # build-cleaner: keep +# "//third_party/GL/native:GLESv2", # build-cleaner: keep +# "///base:init", +# "///base:status", +# ], +# ) +# +# cc_binary( +# name = "debug_cli", +# srcs = ["debug_cli_main.cc"], +# deps = [ +# ":debug_prompt", +# "@com_google_absl//absl/flags:flag", +# "///base:init", +# "///base:status", +# ], +# ) +# +# cc_library( +# name = "debug_prompt", +# srcs = ["debug_prompt.cc"], +# hdrs = ["debug_prompt.h"], +# deps = [ +# "@com_google_absl//absl/strings", +# "///base:status", +# "///rt/debug:debug_client", +# ], +# )
diff --git a/tools/debugger/debug_app.cc b/tools/debugger/debug_app.cc new file mode 100644 index 0000000..5136b04 --- /dev/null +++ b/tools/debugger/debug_app.cc
@@ -0,0 +1,1422 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/debugger/debug_app.h" + +#include <GLES2/gl2.h> + +#include <algorithm> +#include <cstdio> + +#include "absl/flags/flag.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/types/optional.h" +#include "base/memory.h" +#include "base/source_location.h" +#include "base/status.h" +#include "rt/debug/debug_client.h" +#include "schemas/debug_service_generated.h" +#include "third_party/dear_imgui/imgui.h" +#include "third_party/dear_imgui/imgui_internal.h" +#include "vm/bytecode_module.h" +#include "vm/bytecode_tables_sequencer.h" + +namespace iree { +namespace rt { +namespace debug { +namespace { + +void PushButtonHue(float hue) { + ImGui::PushStyleColor(ImGuiCol_Button, + (ImVec4)ImColor::HSV(hue / 7.0f, 0.6f, 0.6f)); + ImGui::PushStyleColor(ImGuiCol_ButtonHovered, + (ImVec4)ImColor::HSV(hue / 7.0f, 0.7f, 0.7f)); + ImGui::PushStyleColor(ImGuiCol_ButtonActive, + (ImVec4)ImColor::HSV(hue / 7.0f, 0.8f, 0.8f)); +} + +void PushButtonColor(const ImVec4& color) { + ImGui::PushStyleColor(ImGuiCol_Button, color); + ImGui::PushStyleColor(ImGuiCol_ButtonHovered, color); + ImGui::PushStyleColor(ImGuiCol_ButtonActive, color); +} + +void PopButtonStyle() { ImGui::PopStyleColor(3); } + +bool AreBreakpointsEqual(const RemoteBreakpoint& breakpoint, + const DebugApp::UserBreakpoint& user_breakpoint) { + if (user_breakpoint.active_breakpoint == &breakpoint) { + return true; + } else if (user_breakpoint.type != breakpoint.type()) { + return false; + } + switch (breakpoint.type()) { + case RemoteBreakpoint::Type::kBytecodeFunction: + if (user_breakpoint.function_ordinal != -1 && + user_breakpoint.function_ordinal != breakpoint.function_ordinal()) { + return false; + } + return breakpoint.module_name() == user_breakpoint.module_name && + breakpoint.function_name() == user_breakpoint.function_name && + breakpoint.bytecode_offset() == user_breakpoint.bytecode_offset; + case RemoteBreakpoint::Type::kNativeFunction: + return breakpoint.function_name() == user_breakpoint.native_function; + default: + return false; + } +} + +} // namespace + +// static +void DebugApp::PumpMainLoopThunk(void* arg) { + auto status = reinterpret_cast<DebugApp*>(arg)->PumpMainLoop(); + if (IsCancelled(status)) { + return; + } else if (!status.ok()) { + CHECK_OK(status); + } +} + +DebugApp::DebugApp(SDL_Window* window, SDL_GLContext gl_context, + const char* glsl_version) + : window_(window), gl_context_(gl_context) { + VLOG(1) << "DebugApp initializing..."; + IMGUI_CHECKVERSION(); + ImGui::CreateContext(); + ImGuiIO& io = ImGui::GetIO(); + io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard; + io.ConfigFlags |= ImGuiConfigFlags_DockingEnable; + + // TODO(benvanik): ini file for settings. + io.IniFilename = nullptr; + // ImGui::LoadIniSettingsFromMemory() + // ImGui::SaveIniSettingsToMemory() + + // TODO(benvanik): theming. + ImGui::StyleColorsDark(); + + // Setup Platform/Renderer bindings + ImGui_ImplSDL2_InitForOpenGL(window_, gl_context_); + ImGui_ImplOpenGL3_Init(glsl_version); + SDL_GL_MakeCurrent(nullptr, nullptr); + VLOG(1) << "DebugApp initialized"; +} + +DebugApp::~DebugApp() { + VLOG(1) << "DebugApp shutting down..."; + ImGui_ImplOpenGL3_Shutdown(); + ImGui_ImplSDL2_Shutdown(); + ImGui::DestroyContext(); + + SDL_GL_DeleteContext(gl_context_); + SDL_GL_MakeCurrent(nullptr, nullptr); + SDL_DestroyWindow(window_); + SDL_Quit(); + VLOG(1) << "DebugApp shut down (SDL_Quit)"; +} + +Status DebugApp::Connect(absl::string_view service_address) { + VLOG(1) << "Connecting to debug service at " << service_address << "..."; + ASSIGN_OR_RETURN(debug_client_, DebugClient::Connect(service_address, this)); + + // TODO(benvanik): load breakpoints from file. + UserBreakpoint user_breakpoint; + user_breakpoint.module_name = "module"; + user_breakpoint.function_name = "main"; + user_breakpoint.bytecode_offset = 0; + user_breakpoint.wants_enabled = true; + user_breakpoint_list_.push_back(std::move(user_breakpoint)); + RETURN_IF_ERROR(RefreshActiveBreakpoints()); + + // Set paused so that we need to resume to continue execution. + is_paused_ = true; + return OkStatus(); +} + +Status DebugApp::Disconnect() { + VLOG(1) << "Disconnecting from debug service"; + debug_client_.reset(); + return OkStatus(); +} + +bool DebugApp::is_paused() const { + if (!debug_client_) { + return false; + } + if (!hit_breakpoints_.empty()) { + return true; // One or more breakpoints hit. + } + return is_paused_ || !is_stepping_; +} + +RemoteInvocation* DebugApp::GetSelectedInvocation() const { + if (!debug_client_ || !selected_invocation_id_.has_value()) { + return nullptr; + } + for (auto* invocation : debug_client_->invocations()) { + if (invocation->id() == selected_invocation_id_.value()) { + return invocation; + } + } + return nullptr; +} + +Status DebugApp::RefreshActiveBreakpoints() { + // Set all breakpoints to disabled. We'll re-enable them as we find them + // below. + for (auto& user_breakpoint : user_breakpoint_list_) { + user_breakpoint.active_breakpoint = nullptr; + } + + // If not connected then no breakpoints are active. + if (!debug_client_) { + return OkStatus(); + } + + // Reconcile the user breakpoint list with the breakpoints available on the + // server. + for (auto* breakpoint : debug_client_->breakpoints()) { + auto it = + std::find_if(user_breakpoint_list_.begin(), user_breakpoint_list_.end(), + [breakpoint](const UserBreakpoint& user_breakpoint) { + return AreBreakpointsEqual(*breakpoint, user_breakpoint); + }); + if (it == user_breakpoint_list_.end()) { + // Breakpoint not found - add to user list. + UserBreakpoint user_breakpoint; + user_breakpoint.type = breakpoint->type(); + user_breakpoint.active_breakpoint = breakpoint; + user_breakpoint.module_name = breakpoint->module_name(); + user_breakpoint.function_name = breakpoint->function_name(); + user_breakpoint.function_ordinal = breakpoint->function_ordinal(); + user_breakpoint.bytecode_offset = breakpoint->bytecode_offset(); + user_breakpoint_list_.push_back(std::move(user_breakpoint)); + } else { + // Breakpoint found - set the active pointer. + UserBreakpoint& user_breakpoint = *it; + user_breakpoint.active_breakpoint = breakpoint; + user_breakpoint.is_enabling = false; + user_breakpoint.module_name = breakpoint->module_name(); + user_breakpoint.function_name = breakpoint->function_name(); + user_breakpoint.function_ordinal = breakpoint->function_ordinal(); + user_breakpoint.bytecode_offset = breakpoint->bytecode_offset(); + } + } + + // Ensure any breakpoint the user wants enabled is active/otherwise. + for (auto& user_breakpoint : user_breakpoint_list_) { + if (user_breakpoint.wants_enabled && !user_breakpoint.is_enabling && + !user_breakpoint.active_breakpoint) { + // Add breakpoint on server. + switch (user_breakpoint.type) { + case RemoteBreakpoint::Type::kBytecodeFunction: + RETURN_IF_ERROR(debug_client_->AddFunctionBreakpoint( + user_breakpoint.module_name, user_breakpoint.function_name, + user_breakpoint.bytecode_offset, + [&user_breakpoint](const RemoteBreakpoint& breakpoint) { + user_breakpoint.function_ordinal = + breakpoint.function_ordinal(); + })); + break; + case RemoteBreakpoint::Type::kNativeFunction: + // TODO(benvanik): native breakpoint support. + return UnimplementedErrorBuilder(IREE_LOC) + << "Native function breakpoints are TODO"; + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented breakpoint type"; + } + user_breakpoint.is_enabling = true; + } else if (!user_breakpoint.wants_enabled && + user_breakpoint.active_breakpoint) { + // Remove breakpoint from server. + RETURN_IF_ERROR( + debug_client_->RemoveBreakpoint(*user_breakpoint.active_breakpoint)); + + user_breakpoint.active_breakpoint = nullptr; + } + } + + return OkStatus(); +} + +bool DebugApp::IsStoppedAtBreakpoint( + const UserBreakpoint& user_breakpoint) const { + return std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(), + user_breakpoint.active_breakpoint) != hit_breakpoints_.end(); +} + +int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name, + int function_ordinal, + int offset) { + for (int i = 0; i < user_breakpoint_list_.size(); ++i) { + auto& user_breakpoint = user_breakpoint_list_[i]; + if (user_breakpoint.module_name == module_name && + user_breakpoint.function_ordinal == function_ordinal && + user_breakpoint.bytecode_offset == offset) { + return i; + } + } + return -1; +} + +int DebugApp::FindMatchingUserBreakpointIndex(absl::string_view module_name, + absl::string_view function_name, + int offset) { + for (int i = 0; i < user_breakpoint_list_.size(); ++i) { + auto& user_breakpoint = user_breakpoint_list_[i]; + if (user_breakpoint.module_name == module_name && + user_breakpoint.function_name == function_name && + user_breakpoint.bytecode_offset == offset) { + return i; + } + } + return -1; +} + +Status DebugApp::ResumeFromBreakpoint(UserBreakpoint* user_breakpoint) { + if (!user_breakpoint->active_breakpoint) { + return FailedPreconditionErrorBuilder(IREE_LOC) << "Breakpoint not active"; + } + VLOG(1) << "Resuming from breakpoint " + << user_breakpoint->active_breakpoint->id() << "..."; + auto it = std::find(hit_breakpoints_.begin(), hit_breakpoints_.end(), + user_breakpoint->active_breakpoint); + if (it == hit_breakpoints_.end()) { + return NotFoundErrorBuilder(IREE_LOC) << "Breakpoint not found"; + } + hit_breakpoints_.erase(it); + return debug_client_->MakeReady(); +} + +Status DebugApp::OnContextRegistered(const RemoteContext& context) { + // Ack event. + return debug_client_->MakeReady(); +} + +Status DebugApp::OnContextUnregistered(const RemoteContext& context) { + // Close documents that may reference modules in the context. + std::vector<CodeViewDocument*> closing_documents; + for (auto& document : documents_) { + auto* module = document->function->module(); + if (module->context_id() != context.id()) { + // Document is not from this context so it's fine. + continue; + } + + // See if any other live context still has the module loaded. We can change + // the document over to that. + RemoteModule* replacement_module = nullptr; + for (auto* context : debug_client_->contexts()) { + for (auto* other_module : context->modules()) { + if (other_module->name() == module->name()) { + replacement_module = other_module; + break; + } + } + if (replacement_module) break; + } + if (replacement_module && replacement_module->is_loaded()) { + // Replace document module reference. + int function_ordinal = document->function->ordinal(); + auto functions = replacement_module->functions(); + if (function_ordinal < functions.size()) { + document->function = functions[function_ordinal]; + } else { + document->function = nullptr; + } + } else { + document->function = nullptr; + } + + if (!document->function) { + // Close the document if we don't have a valid function for it. + VLOG(1) + << "Closing document " << document->title + << " because the last context using the module is being unregistered"; + closing_documents.push_back(document.get()); + } + } + for (auto* document : closing_documents) { + auto it = std::find_if( + documents_.begin(), documents_.end(), + [document](const std::unique_ptr<CodeViewDocument>& open_document) { + return document == open_document.get(); + }); + documents_.erase(it); + } + + // Ack event. + return debug_client_->MakeReady(); +} + +Status DebugApp::OnModuleLoaded(const RemoteContext& context, + const RemoteModule& module) { + // Ack event. + return debug_client_->MakeReady(); +} + +Status DebugApp::OnInvocationRegistered(const RemoteInvocation& invocation) { + if (!selected_invocation_id_.has_value()) { + selected_invocation_id_ = invocation.id(); + selected_stack_frame_index_ = {}; + } + + // Ack event. + return debug_client_->MakeReady(); +} + +Status DebugApp::OnInvocationUnregistered(const RemoteInvocation& invocation) { + if (selected_invocation_id_.has_value() && + selected_invocation_id_.value() == invocation.id()) { + selected_invocation_id_ = {}; + selected_stack_frame_index_ = {}; + } + + // Ack event. + return debug_client_->MakeReady(); +} + +Status DebugApp::OnBreakpointHit(const RemoteBreakpoint& breakpoint, + const RemoteInvocation& invocation) { + // Keep track of where we are stopped. + hit_breakpoints_.push_back(&breakpoint); + return NavigateToCodeView(invocation, -1, NavigationMode::kMatchDocument); +} + +Status DebugApp::PumpMainLoop() { + ImGuiIO& io = ImGui::GetIO(); + + if (debug_client_) { + RETURN_IF_ERROR(debug_client_->Poll()); + } + RETURN_IF_ERROR(RefreshActiveBreakpoints()); + + SDL_GL_MakeCurrent(window_, gl_context_); + + SDL_Event event; + while (SDL_PollEvent(&event)) { + ImGui_ImplSDL2_ProcessEvent(&event); + if (event.type == SDL_QUIT) { + return CancelledErrorBuilder(IREE_LOC) << "Quit hotkey"; + } else if (event.type == SDL_WINDOWEVENT && + event.window.event == SDL_WINDOWEVENT_CLOSE && + event.window.windowID == SDL_GetWindowID(window_)) { + return CancelledErrorBuilder(IREE_LOC) << "Window closed"; + } + } + ImGui_ImplOpenGL3_NewFrame(); + ImGui_ImplSDL2_NewFrame(window_); + ImGui::NewFrame(); + + auto draw_status = DrawUI(); + if (!draw_status.ok()) { + // TODO(benvanik): show on screen? Probably all messed up. + LOG(ERROR) << draw_status; + } + + // Blit the entire ImGui UI. + ImGui::Render(); + SDL_GL_MakeCurrent(window_, gl_context_); + glViewport(0, 0, (int)io.DisplaySize.x, (int)io.DisplaySize.y); + glClearColor(0.45f, 0.55f, 0.60f, 1.0f); + glClear(GL_COLOR_BUFFER_BIT); + // Workaround for terrible bad SDL/graphics driver leaks. + IREE_DISABLE_LEAK_CHECKS(); + ImGui_ImplOpenGL3_RenderDrawData(ImGui::GetDrawData()); + IREE_ENABLE_LEAK_CHECKS(); + + // Render additional viewport windows (desktop only). + if (io.ConfigFlags & ImGuiConfigFlags_ViewportsEnable) { + SDL_Window* backup_current_window = SDL_GL_GetCurrentWindow(); + SDL_GLContext backup_current_context = SDL_GL_GetCurrentContext(); + ImGui::UpdatePlatformWindows(); + ImGui::RenderPlatformWindowsDefault(); + SDL_GL_MakeCurrent(backup_current_window, backup_current_context); + } + + SDL_GL_SwapWindow(window_); + return OkStatus(); +} + +Status DebugApp::LayoutInitialDockSpace() { + dockspace_id_ = ImGui::GetID("MainDockSpace"); + if (ImGui::DockBuilderGetNode(dockspace_id_)) { + // Already configured. + return OkStatus(); + } + ImGui::DockBuilderAddNode(dockspace_id_, ImGuiDockNodeFlags_DockSpace); + + dock_content_id_ = dockspace_id_; + dock_top_id_ = ImGui::DockBuilderSplitNode(dock_content_id_, ImGuiDir_Up, + 0.05f, nullptr, &dock_content_id_); + dock_left_id_ = ImGui::DockBuilderSplitNode( + dock_content_id_, ImGuiDir_Left, 0.20f, nullptr, &dock_content_id_); + dock_bottom_id_ = ImGui::DockBuilderSplitNode( + dock_content_id_, ImGuiDir_Down, 0.20f, nullptr, &dock_content_id_); + dock_right_id_ = ImGui::DockBuilderSplitNode( + dock_content_id_, ImGuiDir_Right, 0.20f, nullptr, &dock_content_id_); + dock_bottom_left_id_ = ImGui::DockBuilderSplitNode( + dock_bottom_id_, ImGuiDir_Left, 0.50f, nullptr, &dock_bottom_right_id_); + + ImGui::DockBuilderDockWindow("Toolbar", dock_top_id_); + auto* dock_top_node = ImGui::DockBuilderGetNode(dock_top_id_); + dock_top_node->LocalFlags = ImGuiDockNodeFlags_NoSplit | + ImGuiDockNodeFlags_NoResize | + ImGuiDockNodeFlags_AutoHideTabBar; + + ImGui::DockBuilderDockWindow("Modules", dock_left_id_); + ImGui::DockBuilderDockWindow("Locals", dock_bottom_left_id_); + ImGui::DockBuilderDockWindow("Invocations", dock_bottom_right_id_); + ImGui::DockBuilderDockWindow("Breakpoints", dock_bottom_right_id_); + + ImGui::DockBuilderFinish(dockspace_id_); + return OkStatus(); +} + +Status DebugApp::DrawUI() { + ImGuiWindowFlags window_flags = + ImGuiWindowFlags_MenuBar | ImGuiWindowFlags_NoDocking; + window_flags |= ImGuiWindowFlags_NoTitleBar | ImGuiWindowFlags_NoCollapse | + ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoMove | + ImGuiWindowFlags_NoNavFocus; + + ImGuiViewport* viewport = ImGui::GetMainViewport(); + ImGui::SetNextWindowPos(viewport->Pos); + ImGui::SetNextWindowSize(viewport->Size); + ImGui::SetNextWindowViewport(viewport->ID); + ImGui::PushStyleVar(ImGuiStyleVar_WindowRounding, 0.0f); + ImGui::PushStyleVar(ImGuiStyleVar_WindowBorderSize, 0.0f); + ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0.0f, 0.0f)); + ImGui::Begin("IREEDebugRoot", nullptr, window_flags); + ImGui::PopStyleVar(3); + + RETURN_IF_ERROR(LayoutInitialDockSpace()); + ImGui::DockSpace(dockspace_id_, ImVec2(0.0f, 0.0f), ImGuiDockNodeFlags_None); + + RETURN_IF_ERROR(DrawMainMenu()); + RETURN_IF_ERROR(DrawToolbar()); + + ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(2, 2)); + RETURN_IF_ERROR(DrawBreakpointListPanel()); + RETURN_IF_ERROR(DrawModuleListPanel()); + RETURN_IF_ERROR(DrawLocalListPanel()); + RETURN_IF_ERROR(DrawInvocationListPanel()); + ImGui::PopStyleVar(); + + RETURN_IF_ERROR(DrawCodeViewPanels()); + + ImGui::End(); + return OkStatus(); +} + +Status DebugApp::DrawMainMenu() { + if (!ImGui::BeginMenuBar()) return OkStatus(); + + // TODO(benvanik): main menu. + if (ImGui::BeginMenu("File")) { + ImGui::EndMenu(); + } + + ImGui::EndMenuBar(); + return OkStatus(); +} + +Status DebugApp::DrawToolbar() { + // TODO(benvanik): figure out how to make this not grow. + ImGui::Begin("Toolbar", nullptr, + ImGuiWindowFlags_NoResize | ImGuiWindowFlags_NoTitleBar | + ImGuiWindowFlags_NoMove | ImGuiWindowFlags_NoCollapse | + ImGuiWindowFlags_NoScrollbar); + ImGui::BeginGroup(); + +#if !defined(IMGUI_DISABLE_DEMO_WINDOWS) + static bool show_demo_window = false; + if (ImGui::Button("Demo")) { + show_demo_window = !show_demo_window; + } + if (show_demo_window) { + ImGui::SetNextWindowDockID(dock_content_id_); + ImGui::ShowDemoWindow(&show_demo_window); + } +#endif // !IMGUI_DISABLE_DEMO_WINDOWS + + ImGui::SameLine(); + if (!debug_client_) { + if (ImGui::Button("Connect")) { + // TODO(benvanik): connection dialog and/or autoconnect. + } + } else { + if (ImGui::Button("Disconnect")) { + debug_client_.reset(); + } + } + + ImGui::SameLine(); + if (debug_client_) { + ImGui::Text("<status>"); + } else { + ImGui::TextDisabled("disconnected"); + } + + ImGui::SameLine(); + ImGui::Spacing(); + ImGui::SameLine(); + ImGui::Spacing(); + + ImGui::SameLine(); + ImGui::BeginGroup(); + ImGui::Text("Invocation: "); + ImGui::SameLine(); + ImGui::SetNextItemWidth(300); + auto* selected_invocation = GetSelectedInvocation(); + const std::string& active_invocation_name = + selected_invocation ? selected_invocation->name() : ""; + if (ImGui::BeginCombo("##active_invocation", active_invocation_name.c_str(), + ImGuiComboFlags_PopupAlignLeft)) { + if (debug_client_) { + for (auto* invocation : debug_client_->invocations()) { + ImGui::PushID(invocation->id()); + bool is_selected = invocation == selected_invocation; + if (ImGui::Selectable(invocation->name().c_str(), is_selected)) { + RETURN_IF_ERROR(NavigateToCodeView(*invocation, -1, + NavigationMode::kMatchDocument)); + } + if (is_selected) { + ImGui::SetItemDefaultFocus(); + } + ImGui::PopID(); + } + } + ImGui::EndCombo(); + } + ImGui::EndGroup(); + + ImGui::SameLine(); + ImGui::BeginGroup(); + static const float kPauseButtonHue = 0.0f; + static const float kResumeButtonHue = 2.0f; + static const float kStepButtonHue = 1.0f; + if (debug_client_ && !is_paused()) { + PushButtonHue(kPauseButtonHue); + if (ImGui::Button("Pause")) { + RETURN_IF_ERROR(debug_client_->SuspendAllInvocations()); + } + PopButtonStyle(); + } else if (debug_client_ && is_paused()) { + ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666); + ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA); + ImGui::ButtonEx("Pause", {}, ImGuiButtonFlags_Disabled); + ImGui::PopStyleColor(2); + } + if (debug_client_ && is_paused()) { + ImGui::SameLine(); + PushButtonHue(kResumeButtonHue); + if (ImGui::Button("Resume")) { + if (is_paused_) { + is_paused_ = false; + RETURN_IF_ERROR(debug_client_->MakeReady()); + } + while (!hit_breakpoints_.empty()) { + hit_breakpoints_.pop_back(); + RETURN_IF_ERROR(debug_client_->MakeReady()); + } + } + PopButtonStyle(); + } else { + ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666); + ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA); + ImGui::SameLine(); + ImGui::ButtonEx("Resume", {}, ImGuiButtonFlags_Disabled); + ImGui::PopStyleColor(2); + } + + if (debug_client_ && is_paused() && selected_invocation) { + ImGui::SameLine(); + PushButtonHue(kStepButtonHue); + if (ImGui::Button("Step Into")) { + RETURN_IF_ERROR( + debug_client_->StepInvocation(*selected_invocation, [this]() { + is_paused_ = true; + is_stepping_ = false; + })); + is_stepping_ = true; + } + PopButtonStyle(); + ImGui::SameLine(); + if (ImGui::Button("Step Over")) { + RETURN_IF_ERROR( + debug_client_->StepInvocationOver(*selected_invocation, [this]() { + is_paused_ = true; + is_stepping_ = false; + })); + is_stepping_ = true; + } + ImGui::SameLine(); + if (ImGui::Button("Step Out")) { + RETURN_IF_ERROR( + debug_client_->StepInvocationOut(*selected_invocation, [this]() { + is_paused_ = true; + is_stepping_ = false; + })); + is_stepping_ = true; + } + if (ImGui::BeginPopup("Step to...")) { + // TODO(benvanik): step to Invoke exit, next FFI call, etc + ImGui::MenuItem("(stuff)"); + ImGui::EndPopup(); + } + ImGui::SameLine(); + if (ImGui::Button("Step to...")) { + ImGui::OpenPopup("Step to..."); + } + } else { + ImGui::PushStyleColor(ImGuiCol_Button, 0xFF666666); + ImGui::PushStyleColor(ImGuiCol_Text, 0xFFAAAAAA); + ImGui::SameLine(); + ImGui::ButtonEx("Step Into", {}, ImGuiButtonFlags_Disabled); + ImGui::SameLine(); + ImGui::ButtonEx("Step Over", {}, ImGuiButtonFlags_Disabled); + ImGui::SameLine(); + ImGui::ButtonEx("Step Out", {}, ImGuiButtonFlags_Disabled); + ImGui::SameLine(); + ImGui::ButtonEx("Step to...", {}, ImGuiButtonFlags_Disabled); + ImGui::PopStyleColor(2); + } + ImGui::EndGroup(); + + ImGui::EndGroup(); + ImGui::End(); + return OkStatus(); +} + +Status DebugApp::DrawBreakpointListPanel() { + static bool is_panel_visible = true; + if (!ImGui::Begin("Breakpoints", &is_panel_visible, ImGuiWindowFlags_None)) { + ImGui::End(); + return OkStatus(); + } + + ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8)); + absl::optional<RemoteBreakpoint::Type> add_breakpoint_type; + if (ImGui::BeginPopup("+ Function")) { + if (ImGui::MenuItem("Bytecode Function")) { + add_breakpoint_type = RemoteBreakpoint::Type::kBytecodeFunction; + } + if (ImGui::MenuItem("Native Function")) { + add_breakpoint_type = RemoteBreakpoint::Type::kNativeFunction; + } + ImGui::EndPopup(); + } + ImGui::PopStyleVar(); + if (ImGui::Button("+ Function")) { + ImGui::OpenPopup("+ Function"); + } + RETURN_IF_ERROR(DrawAddBreakpointDialogs(add_breakpoint_type)); + + ImGui::SameLine(); + if (ImGui::Button("Remove All")) { + // TODO(benvanik): removal all is broken - need removebreakpoints or a + // 'want_removal' flag so that RefreshActiveBreakpoints handles things. + // Right now if you have 2 breakpoints and hit remove all the second will + // come back during the next refresh (as the server hasn't removed it yet). + for (auto& user_breakpoint : user_breakpoint_list_) { + if (user_breakpoint.active_breakpoint) { + RETURN_IF_ERROR(debug_client_->RemoveBreakpoint( + *user_breakpoint.active_breakpoint)); + user_breakpoint.active_breakpoint = nullptr; + } + } + user_breakpoint_list_.clear(); + } + ImGui::Separator(); + + ImGui::BeginChild("BreakpointList", ImVec2(-1, -1), false, + ImGuiWindowFlags_AlwaysVerticalScrollbar); + std::vector<UserBreakpoint*> dead_breakpoints; + for (auto& user_breakpoint : user_breakpoint_list_) { + ASSIGN_OR_RETURN(bool should_keep, DrawBreakpoint(&user_breakpoint)); + if (!should_keep) { + dead_breakpoints.push_back(&user_breakpoint); + } + } + for (auto* user_breakpoint : dead_breakpoints) { + for (auto it = user_breakpoint_list_.begin(); + it != user_breakpoint_list_.end(); ++it) { + if (&*it == user_breakpoint) { + if (user_breakpoint->active_breakpoint) { + RETURN_IF_ERROR(debug_client_->RemoveBreakpoint( + *user_breakpoint->active_breakpoint)); + } + user_breakpoint_list_.erase(it); + break; + } + } + } + ImGui::EndChild(); + + ImGui::End(); + return OkStatus(); +} + +StatusOr<bool> DebugApp::DrawBreakpoint(UserBreakpoint* user_breakpoint) { + std::string breakpoint_name; + switch (user_breakpoint->type) { + case RemoteBreakpoint::Type::kBytecodeFunction: + breakpoint_name = + absl::StrCat("[bytecode] ", user_breakpoint->module_name, ":", + user_breakpoint->function_name, ":", + user_breakpoint->bytecode_offset); + if (user_breakpoint->function_ordinal != -1) { + absl::StrAppend(&breakpoint_name, " @", + user_breakpoint->function_ordinal); + } + break; + case RemoteBreakpoint::Type::kNativeFunction: + breakpoint_name = + absl::StrCat("[native ] ", user_breakpoint->native_function); + break; + } + ImGui::BeginGroup(); + bool is_closing = true; + bool is_expanded = ImGui::CollapsingHeader( + ("##" + breakpoint_name).c_str(), &is_closing, + ImGuiTreeNodeFlags_Framed | ImGuiTreeNodeFlags_NoTreePushOnOpen | + ImGuiTreeNodeFlags_NoAutoOpenOnLog | ImGuiTreeNodeFlags_OpenOnArrow | + ImGuiTreeNodeFlags_OpenOnDoubleClick); + ImGui::SameLine(); + ImGui::Checkbox(breakpoint_name.c_str(), &user_breakpoint->wants_enabled); + ImGui::EndGroup(); + if (!is_expanded) { + return is_closing; + } + ImGui::PushID(breakpoint_name.c_str()); + + ImGui::Text("(breakpoint stats/etc)"); + + ImGui::PopID(); + return is_closing; +} + +Status DebugApp::DrawAddBreakpointDialogs( + absl::optional<RemoteBreakpoint::Type> add_breakpoint_type) { + if (add_breakpoint_type.has_value()) { + switch (add_breakpoint_type.value()) { + case RemoteBreakpoint::Type::kBytecodeFunction: + ImGui::OpenPopup("Add Bytecode Function Breakpoint"); + break; + case RemoteBreakpoint::Type::kNativeFunction: + ImGui::OpenPopup("Add Native Function Breakpoint"); + break; + } + } + ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(8, 8)); + RETURN_IF_ERROR(DrawAddBytecodeFunctionBreakpointDialog()); + RETURN_IF_ERROR(DrawAddNativeFunctionBreakpointDialog()); + ImGui::PopStyleVar(); + return OkStatus(); +} + +Status DebugApp::DrawAddBytecodeFunctionBreakpointDialog() { + ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver); + bool close_popup = true; + if (!ImGui::BeginPopupModal("Add Bytecode Function Breakpoint", &close_popup, + ImGuiWindowFlags_None)) { + return OkStatus(); + } + ImGui::BeginGroup(); + ImGui::BeginChild("##data_entry", + ImVec2(0, -ImGui::GetFrameHeightWithSpacing())); + + ImGui::TextWrapped( + "Adds a breakpoint set on the entry of the function (offset=0)."); + ImGui::Separator(); + + // TODO(benvanik): fancy list, filtering, etc. + + static char module_name[256] = {0}; + ImGui::InputText("Module", module_name, sizeof(module_name)); + ImGui::SetItemDefaultFocus(); + + static char function_name[256] = {0}; + ImGui::InputText("Function", function_name, sizeof(function_name)); + + ImGui::EndChild(); + ImGui::Separator(); + + if (ImGui::Button("Add")) { + int offset = 0; + if (FindMatchingUserBreakpointIndex(module_name, function_name, offset) == + -1) { + UserBreakpoint user_breakpoint; + user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction; + user_breakpoint.module_name = module_name; + user_breakpoint.function_name = function_name; + user_breakpoint.bytecode_offset = offset; + user_breakpoint.wants_enabled = true; + user_breakpoint_list_.push_back(std::move(user_breakpoint)); + } + ImGui::CloseCurrentPopup(); + } + ImGui::SameLine(); + if (ImGui::Button("Cancel")) { + ImGui::CloseCurrentPopup(); + } + + ImGui::EndGroup(); + ImGui::EndPopup(); + return OkStatus(); +} + +Status DebugApp::DrawAddNativeFunctionBreakpointDialog() { + ImGui::SetNextWindowSize(ImVec2(400, 400), ImGuiCond_FirstUseEver); + bool close_popup = true; + if (!ImGui::BeginPopupModal("Add Native Function Breakpoint", &close_popup, + ImGuiWindowFlags_None)) { + return OkStatus(); + } + ImGui::BeginGroup(); + ImGui::BeginChild("##data_entry", + ImVec2(0, -ImGui::GetFrameHeightWithSpacing())); + + ImGui::TextWrapped( + "Adds a breakpoint set on any call to the given FFI imported " + "function."); + ImGui::Separator(); + + static char function_name[256] = {0}; + ImGui::InputText("Function", function_name, sizeof(function_name)); + ImGui::SetItemDefaultFocus(); + + ImGui::EndChild(); + ImGui::Separator(); + + if (ImGui::Button("Add")) { + UserBreakpoint user_breakpoint; + user_breakpoint.type = RemoteBreakpoint::Type::kNativeFunction; + user_breakpoint.native_function = function_name; + user_breakpoint.wants_enabled = true; + user_breakpoint_list_.push_back(std::move(user_breakpoint)); + ImGui::CloseCurrentPopup(); + } + ImGui::SameLine(); + if (ImGui::Button("Cancel")) { + ImGui::CloseCurrentPopup(); + } + + ImGui::EndGroup(); + ImGui::EndPopup(); + return OkStatus(); +} + +Status DebugApp::DrawModuleListPanel() { + static bool is_panel_visible = true; + if (!ImGui::Begin("Modules", &is_panel_visible, ImGuiWindowFlags_None)) { + ImGui::End(); + return OkStatus(); + } else if (!debug_client_) { + ImGui::TextDisabled("disconnected"); + ImGui::End(); + return OkStatus(); + } + ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(4, 4)); + + ImGui::BeginGroup(); + ImGui::SetNextItemWidth(ImGui::GetContentRegionAvailWidth()); + static char function_name_filter_text[256] = {0}; + ImGui::InputTextWithHint( + "##function_name_filter", "Filter functions", function_name_filter_text, + sizeof(function_name_filter_text), ImGuiInputTextFlags_AutoSelectAll); + ImGuiTextFilter function_name_filter(function_name_filter_text); + ImGui::EndGroup(); + + ImGui::Separator(); + + ImGui::BeginGroup(); + ImGui::BeginChild("##context_list", ImVec2(0, -ImGui::GetFrameHeight())); + for (auto* context : debug_client_->contexts()) { + RETURN_IF_ERROR(DrawContext(*context, function_name_filter)); + } + ImGui::EndChild(); + ImGui::EndGroup(); + + ImGui::PopStyleVar(); + ImGui::End(); + return OkStatus(); +} + +Status DebugApp::DrawContext(const RemoteContext& context, + const ImGuiTextFilter& filter) { + std::string context_name = absl::StrCat("Context ", context.id()); + if (!ImGui::CollapsingHeader(context_name.c_str(), nullptr, + ImGuiTreeNodeFlags_DefaultOpen | + ImGuiTreeNodeFlags_Framed | + ImGuiTreeNodeFlags_NoTreePushOnOpen | + ImGuiTreeNodeFlags_NoAutoOpenOnLog | + ImGuiTreeNodeFlags_OpenOnArrow | + ImGuiTreeNodeFlags_OpenOnDoubleClick)) { + return OkStatus(); + } + ImGui::PushID(context.id()); + for (auto* module : context.modules()) { + RETURN_IF_ERROR(DrawModule(module, filter)); + } + ImGui::PopID(); + return OkStatus(); +} + +Status DebugApp::DrawModule(RemoteModule* module, + const ImGuiTextFilter& filter) { + ImGui::PushID(module->name().c_str()); + if (ImGui::TreeNodeEx(module->name().c_str(), + ImGuiTreeNodeFlags_Framed | + ImGuiTreeNodeFlags_DefaultOpen | + ImGuiTreeNodeFlags_OpenOnDoubleClick | + ImGuiTreeNodeFlags_OpenOnArrow)) { + if (module->CheckLoadedOrRequest()) { + for (auto* function : module->functions()) { + char function_name[128]; + if (function->name().empty()) { + std::snprintf(function_name, sizeof(function_name), "@%d", + function->ordinal()); + } else { + std::snprintf(function_name, sizeof(function_name), "@%d %s", + function->ordinal(), function->name().c_str()); + } + if (filter.IsActive() && !filter.PassFilter(function_name)) { + continue; + } + ImGui::PushID(function->ordinal()); + bool is_selected = false; + if (ImGui::Selectable("##selectable", &is_selected, + ImGuiSelectableFlags_AllowDoubleClick | + ImGuiSelectableFlags_DrawFillAvailWidth)) { + if (is_selected) { + RETURN_IF_ERROR(NavigateToCodeView(module->name(), + function->ordinal(), 0, + NavigationMode::kMatchDocument)); + } + } + ImGui::SameLine(); + // TODO(benvanik): detect if breakpoint active at offset 0. + ImGui::BulletText("%s", function_name); + ImGui::PopID(); + } + } else { + ImGui::TextDisabled("Loading..."); + } + ImGui::TreePop(); + } + ImGui::PopID(); + return OkStatus(); +} + +Status DebugApp::DrawLocalListPanel() { + static bool is_panel_visible = true; + if (!ImGui::Begin("Locals", &is_panel_visible, ImGuiWindowFlags_None)) { + ImGui::End(); + return OkStatus(); + } else if (!debug_client_) { + ImGui::TextDisabled("disconnected"); + ImGui::End(); + return OkStatus(); + } + auto* invocation = GetSelectedInvocation(); + if (!invocation) { + ImGui::TextDisabled("select a invocation to view locals"); + ImGui::End(); + return OkStatus(); + } else if (invocation->def().frames.empty()) { + ImGui::TextDisabled("(invocation has no frames)"); + ImGui::End(); + return OkStatus(); + } + int stack_frame_index = selected_stack_frame_index_.value_or(-1); + if (stack_frame_index == -1) { + stack_frame_index = invocation->def().frames.size() - 1; + } + auto& stack_frame = invocation->def().frames[stack_frame_index]; + + // TODO(benvanik): toggle for IREE VM locals vs. source locals. + for (int i = 0; i < stack_frame->locals.size(); ++i) { + auto& local = stack_frame->locals[i]; + RETURN_IF_ERROR(DrawLocal(invocation, stack_frame_index, i, *local)); + } + + ImGui::End(); + return OkStatus(); +} + +Status DebugApp::DrawLocal(RemoteInvocation* invocation, int stack_frame_index, + int local_index, const rpc::BufferViewDefT& local) { + // TODO(benvanik): columns and such in fancy table. + ImGui::Text("l%d", local_index); + ImGui::SameLine(50); + if (local.is_valid) { + auto shape_str = + absl::StrCat(absl::StrJoin(local.shape, "x"), "x", local.element_size); + ImGui::Text("%s", shape_str.c_str()); + } else { + ImGui::TextDisabled("∅"); + } + // TODO(benvanik): editing options (change shape, change contents, upload). + // TODO(benvanik): save/download/log options. + return OkStatus(); +} + +Status DebugApp::DrawInvocationListPanel() { + static bool is_panel_visible = true; + if (!ImGui::Begin("Invocations", &is_panel_visible, ImGuiWindowFlags_None)) { + ImGui::End(); + return OkStatus(); + } else if (!debug_client_) { + ImGui::TextDisabled("disconnected"); + ImGui::End(); + return OkStatus(); + } + for (auto* invocation : debug_client_->invocations()) { + RETURN_IF_ERROR(DrawInvocation(*invocation)); + } + ImGui::End(); + return OkStatus(); +} + +Status DebugApp::DrawInvocation(const RemoteInvocation& invocation) { + // TODO(benvanik): expand if any breakpoints are stopped in invocation. + if (selected_invocation_id_.has_value() && + selected_invocation_id_.value() == invocation.id()) { + ImGui::SetNextTreeNodeOpen(true); + } + if (!ImGui::CollapsingHeader(invocation.name().c_str())) { + return OkStatus(); + } + ImGui::PushID(invocation.id()); + + for (int i = 0; i < invocation.def().frames.size(); ++i) { + const auto& stack_frame = invocation.def().frames[i]; + ImGui::PushID(i); + // TODO(benvanik): highlight frames with breakpoints in them. + bool is_selected = selected_invocation_id_.has_value() && + selected_invocation_id_.value() == invocation.id() && + selected_stack_frame_index_.has_value() && + selected_stack_frame_index_.value() == i; + if (ImGui::Selectable("##selectable", &is_selected, + ImGuiSelectableFlags_AllowDoubleClick | + ImGuiSelectableFlags_DrawFillAvailWidth)) { + // TODO(benvanik): detect when clicking but already selected. + if (is_selected) { + RETURN_IF_ERROR( + NavigateToCodeView(invocation, i, NavigationMode::kMatchDocument)); + } + } + ImGui::SameLine(); + ImGui::Bullet(); + ImGui::SameLine(); + // TODO(benvanik): better naming/etc (resolve function). + ImGui::Text("%s:%d:%d", stack_frame->module_name.c_str(), + stack_frame->function_ordinal, stack_frame->offset); + + ImGui::PopID(); + } + + ImGui::PopID(); + return OkStatus(); +} + +DebugApp::CodeViewDocument* DebugApp::FindMatchingDocument( + absl::string_view module_name, int function_ordinal) { + for (auto& document : documents_) { + if (document->function->module()->name() == module_name && + document->function->ordinal() == function_ordinal) { + return document.get(); + } + } + return nullptr; +} + +Status DebugApp::NavigateToCodeView(absl::string_view module_name, + int function_ordinal, int offset, + NavigationMode navigation_mode) { + if (!debug_client_) { + return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; + } + VLOG(1) << "NavigateToCodeView(" << module_name << ", " << function_ordinal + << ", " << offset << ")"; + CodeViewDocument* existing_document = nullptr; + switch (navigation_mode) { + case NavigationMode::kNewDocument: + // Fall through and create below. + break; + case NavigationMode::kCurrentDocument: + // Not yet done - treat as a new document. + break; + case NavigationMode::kMatchDocument: + existing_document = FindMatchingDocument(module_name, function_ordinal); + break; + } + if (existing_document) { + ImGui::SetWindowFocus(existing_document->title.c_str()); + return OkStatus(); + } + + // TODO(benvanik): make this common code. + RETURN_IF_ERROR(debug_client_->GetFunction( + std::string(module_name), function_ordinal, + [this, offset](StatusOr<RemoteFunction*> function_or) { + if (!function_or.ok()) { + // TODO(benvanik): error dialog. + CHECK_OK(function_or.status()); + } + auto* function = function_or.ValueOrDie(); + auto document = absl::make_unique<CodeViewDocument>(); + document->title = + absl::StrCat(function->module()->name(), ":", function->name()); + document->function = function; + document->focus_offset = offset; + ImGui::SetWindowFocus(document->title.c_str()); + documents_.push_back(std::move(document)); + })); + return OkStatus(); +} + +Status DebugApp::NavigateToCodeView(absl::string_view module_name, + absl::string_view function_name, int offset, + NavigationMode navigation_mode) { + if (!debug_client_) { + return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; + } + return debug_client_->ResolveFunction( + std::string(module_name), std::string(function_name), + [this, navigation_mode, module_name, + offset](StatusOr<int> function_ordinal) { + CHECK_OK(function_ordinal.status()); + CHECK_OK(NavigateToCodeView(module_name, function_ordinal.ValueOrDie(), + offset, navigation_mode)); + }); +} + +Status DebugApp::NavigateToCodeView(const RemoteInvocation& invocation, + int stack_frame_index, + NavigationMode navigation_mode) { + if (!debug_client_) { + return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; + } + const auto& stack_frame = stack_frame_index == -1 + ? *invocation.def().frames.back() + : *invocation.def().frames[stack_frame_index]; + selected_invocation_id_ = invocation.id(); + selected_stack_frame_index_ = stack_frame_index; + return NavigateToCodeView(stack_frame.module_name, + stack_frame.function_ordinal, stack_frame.offset, + NavigationMode::kMatchDocument); +} + +Status DebugApp::NavigateToCodeView(const UserBreakpoint& user_breakpoint, + NavigationMode navigation_mode) { + if (!debug_client_) { + return UnavailableErrorBuilder(IREE_LOC) << "No connection established"; + } + switch (user_breakpoint.type) { + case RemoteBreakpoint::Type::kBytecodeFunction: + if (user_breakpoint.function_ordinal != -1) { + return NavigateToCodeView( + user_breakpoint.module_name, user_breakpoint.function_ordinal, + user_breakpoint.bytecode_offset, navigation_mode); + } else { + return NavigateToCodeView( + user_breakpoint.module_name, user_breakpoint.function_name, + user_breakpoint.bytecode_offset, navigation_mode); + } + case RemoteBreakpoint::Type::kNativeFunction: + return UnimplementedErrorBuilder(IREE_LOC) + << "Navigation to non-bytecode functions unimplemented"; + } +} + +Status DebugApp::DrawCodeViewPanels() { + // If we've disconnected then we need to clear bodies. + // TODO(benvanik): allow documents to persist by caching all required info. + if (!debug_client_) { + documents_.clear(); + return OkStatus(); + } + + std::vector<CodeViewDocument*> closing_documents; + for (auto& document : documents_) { + ASSIGN_OR_RETURN(bool is_open, DrawCodeViewDocument(document.get())); + if (!is_open) { + closing_documents.push_back(document.get()); + } + } + for (auto* closing_document : closing_documents) { + auto it = std::find_if( + documents_.begin(), documents_.end(), + [closing_document](const std::unique_ptr<CodeViewDocument>& document) { + return document.get() == closing_document; + }); + documents_.erase(it); + } + return OkStatus(); +} + +StatusOr<bool> DebugApp::DrawCodeViewDocument(CodeViewDocument* document) { + ImGui::SetNextWindowDockID(dockspace_id_, ImGuiCond_FirstUseEver); + ImGui::PushStyleVar(ImGuiStyleVar_WindowPadding, ImVec2(0, 0)); + bool is_open = true; + bool is_visible = + ImGui::Begin(document->title.c_str(), &is_open, ImGuiWindowFlags_None); + if (!is_open || !is_visible) { + ImGui::End(); + ImGui::PopStyleVar(); + return is_open; + } + ImGui::PopStyleVar(); + + auto* remote_module = document->function->module(); + auto* remote_function = document->function; + if (remote_module->CheckLoadedOrRequest() && + remote_function->CheckLoadedOrRequest()) { + // TODO(benvanik): draw function signature. + if (remote_function->bytecode()) { + RETURN_IF_ERROR(DrawBytecodeCodeView(document)); + } else { + // TODO(benvanik): display native registration info. + ImGui::TextDisabled("(native)"); + } + } else { + ImGui::TextDisabled("loading..."); + } + + ImGui::End(); + return true; +} + +Status DebugApp::PrepareBytecodeCodeView(CodeViewDocument* document) { + auto* remote_module = document->function->module(); + auto* remote_function = document->function; + + document->bytecode_info.lines = remote_function->name(); + + return OkStatus(); +} + +Status DebugApp::DrawBytecodeCodeView(CodeViewDocument* document) { + // Ensure we have cached our line information. + RETURN_IF_ERROR(PrepareBytecodeCodeView(document)); + + auto* remote_module = document->function->module(); + auto* remote_function = document->function; + + ImGui::BeginGroup(); + ImGui::BeginChild("##bytecode_view", ImVec2(0, 0), false, + ImGuiWindowFlags_AlwaysVerticalScrollbar); + ImGui::PushStyleVar(ImGuiStyleVar_ItemSpacing, ImVec2(0, 0)); + + // TODO(benvanik): cache breakpoints for this function for faster lookup. + + auto& bytecode_info = document->bytecode_info; + ImGuiListClipper clipper(bytecode_info.lines.size(), + ImGui::GetTextLineHeightWithSpacing()); + while (clipper.Step()) { + for (int i = clipper.DisplayStart; i < clipper.DisplayEnd; ++i) { + ImGui::PushID(i); + + // TODO(benvanik): lookup line info. + int bytecode_offset = 0; + int breakpoint_index = FindMatchingUserBreakpointIndex( + remote_module->name(), remote_function->ordinal(), bytecode_offset); + bool has_breakpoint = breakpoint_index != -1; + bool active_on_any_invocation = false; + bool active_on_selected_invocation = false; + + ImGui::Dummy(ImVec2(4, 0)); + + // Gutter breakpoint button. + ImGui::SameLine(); + if (has_breakpoint) { + PushButtonHue(0.0f); // Red + if (ImGui::Button(" ##toggle_breakpoint")) { + CHECK_GE(breakpoint_index, 0); + auto& user_breakpoint = user_breakpoint_list_[breakpoint_index]; + if (user_breakpoint.active_breakpoint) { + RETURN_IF_ERROR(debug_client_->RemoveBreakpoint( + *user_breakpoint.active_breakpoint)); + } + user_breakpoint_list_.erase(user_breakpoint_list_.begin() + + breakpoint_index); + } + PopButtonStyle(); + if (ImGui::IsItemHovered()) { + ImGui::SetTooltip("Remove the breakpoint at this offset."); + } + } else { + PushButtonColor(ImGui::GetStyleColorVec4(ImGuiCol_ChildBg)); + if (ImGui::Button(" ##toggle_breakpoint")) { + UserBreakpoint user_breakpoint; + user_breakpoint.type = RemoteBreakpoint::Type::kBytecodeFunction; + user_breakpoint.module_name = remote_module->name(); + user_breakpoint.function_name = remote_function->name(); + user_breakpoint.bytecode_offset = bytecode_offset; + user_breakpoint.wants_enabled = true; + user_breakpoint_list_.push_back(std::move(user_breakpoint)); + } + PopButtonStyle(); + if (ImGui::IsItemHovered()) { + ImGui::SetTooltip("Add a breakpoint at this offset."); + } + } + + // Active execution chevron (shows when active or any invocation is + // executing this region). + ImGui::SameLine(); + if (active_on_selected_invocation) { + // The selected invocation is active here. + ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_SeparatorActive), + " > "); + } else if (active_on_any_invocation) { + // At least one other invocation is active here. + ImGui::TextColored(ImGui::GetStyleColorVec4(ImGuiCol_Separator), " > "); + } else { + // Not active. + ImGui::Text(" "); + } + + // Line contents. + ImGui::SameLine(); + ImGui::Text("%s", bytecode_info.lines[i].c_str()); + + if (document->focus_offset.has_value() && + bytecode_offset == document->focus_offset.value()) { + document->bytecode_offset = document->focus_offset.value(); + document->focus_offset = {}; + ImGui::SetScrollHereY(); + } + + ImGui::PopID(); + } + } + + ImGui::PopStyleVar(); + ImGui::EndChild(); + ImGui::EndGroup(); + + return OkStatus(); +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/tools/debugger/debug_app.h b/tools/debugger/debug_app.h new file mode 100644 index 0000000..a7954cf --- /dev/null +++ b/tools/debugger/debug_app.h
@@ -0,0 +1,200 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_H_ +#define IREE_TOOLS_DEBUGGER_DEBUG_APP_H_ + +#include <SDL.h> + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/status.h" +#include "rt/debug/debug_client.h" + +// NOTE: order matters here, imgui must come first: +#include "third_party/dear_imgui/imgui.h" +// NOTE: must follow imgui.h: +#include "third_party/dear_imgui/examples/imgui_impl_opengl3.h" +#include "third_party/dear_imgui/examples/imgui_impl_sdl.h" + +namespace iree { +namespace rt { +namespace debug { + +// Debug client app UI. +// Uses a DebugClient to communicate with a remote DebugServer and ImGui to +// display a nifty UI. +// +// See the ImGui site for more info: https://github.com/ocornut/imgui +// The most useful thing is the imgui_demo.cpp file that contains example usage +// of most features. +class DebugApp : private DebugClient::Listener { + public: + struct UserBreakpoint { + RemoteBreakpoint::Type type = RemoteBreakpoint::Type::kBytecodeFunction; + const RemoteBreakpoint* active_breakpoint = nullptr; + bool wants_enabled = true; + bool is_enabling = false; + // TODO(benvanik): reuse BreakpointDef here? + std::string module_name; + std::string function_name; + int function_ordinal = -1; + int bytecode_offset = 0; + std::string native_function; + }; + + static void PumpMainLoopThunk(void* arg); + + DebugApp(SDL_Window* window, SDL_GLContext gl_context, + const char* glsl_version); + ~DebugApp(); + + // Connects to the service at the specified address. + Status Connect(absl::string_view service_address); + // Disconnects from the currently connected service, if any. + Status Disconnect(); + + // Returns true if the remote service is paused at our request. + bool is_paused() const; + + // Pumps the main UI loop once. + // This polls the DebugClient, SDL input, and renders the UI. + // It should be called as frequently as possible to ensure snappy UI updates. + // Returns CancelledError if the app is being closed by the user. + Status PumpMainLoop(); + + // Defines how NavigationToCodeView methods behave. + enum class NavigationMode { + // The target will be opened in a new document tab. + kNewDocument, + // The target will be opened in the current document tab, replacing the + // current contents. + kCurrentDocument, + // The target will be opened in a document tab that mostly matches (like + // the same function in a module at a different offset), otherwise a new + // document will be opened. + kMatchDocument, + }; + + // Navigates to a particular function offset based on resolution of the given + // arguments. Navigation may happen asynchronously if targets need to be + // resolved or contents fetched. + Status NavigateToCodeView(absl::string_view module_name, int function_ordinal, + int offset, NavigationMode navigation_mode); + Status NavigateToCodeView(absl::string_view module_name, + absl::string_view function_name, int offset, + NavigationMode navigation_mode); + Status NavigateToCodeView(const RemoteInvocation& invocation, + int stack_frame_index, + NavigationMode navigation_mode); + Status NavigateToCodeView(const UserBreakpoint& user_breakpoint, + NavigationMode navigation_mode); + + private: + struct CodeViewDocument { + // Document display title (and ID). + std::string title; + // Function (and offset within the function) being displayed. + RemoteFunction* function = nullptr; + int bytecode_offset = 0; + // Set to a bytecode offset to have the document focus there. + absl::optional<int> focus_offset; + // Cached info for bytecode display. + struct { + std::vector<std::string> lines; + } bytecode_info; + }; + + CodeViewDocument* FindMatchingDocument(absl::string_view module_name, + int function_ordinal); + RemoteInvocation* GetSelectedInvocation() const; + + Status RefreshActiveBreakpoints(); + bool IsStoppedAtBreakpoint(const UserBreakpoint& user_breakpoint) const; + int FindMatchingUserBreakpointIndex(absl::string_view module_name, + int function_ordinal, int offset); + int FindMatchingUserBreakpointIndex(absl::string_view module_name, + absl::string_view function_name, + int offset); + Status ResumeFromBreakpoint(UserBreakpoint* user_breakpoint); + + Status OnContextRegistered(const RemoteContext& context) override; + Status OnContextUnregistered(const RemoteContext& context) override; + Status OnModuleLoaded(const RemoteContext& context, + const RemoteModule& module) override; + Status OnInvocationRegistered(const RemoteInvocation& invocation) override; + Status OnInvocationUnregistered(const RemoteInvocation& invocation) override; + Status OnBreakpointHit(const RemoteBreakpoint& breakpoint, + const RemoteInvocation& invocation) override; + + Status LayoutInitialDockSpace(); + + Status DrawUI(); + Status DrawMainMenu(); + Status DrawToolbar(); + + Status DrawBreakpointListPanel(); + StatusOr<bool> DrawBreakpoint(UserBreakpoint* user_breakpoint); + Status DrawAddBreakpointDialogs( + absl::optional<RemoteBreakpoint::Type> add_breakpoint_type); + Status DrawAddBytecodeFunctionBreakpointDialog(); + Status DrawAddNativeFunctionBreakpointDialog(); + + Status DrawModuleListPanel(); + Status DrawContext(const RemoteContext& context, + const ImGuiTextFilter& filter); + Status DrawModule(RemoteModule* module, const ImGuiTextFilter& filter); + + Status DrawLocalListPanel(); + Status DrawLocal(RemoteInvocation* invocation, int stack_frame_index, + int local_index, const rpc::BufferViewDefT& local); + + Status DrawInvocationListPanel(); + Status DrawInvocation(const RemoteInvocation& invocation); + + Status DrawCodeViewPanels(); + StatusOr<bool> DrawCodeViewDocument(CodeViewDocument* document); + Status PrepareBytecodeCodeView(CodeViewDocument* document); + Status DrawBytecodeCodeView(CodeViewDocument* document); + + SDL_Window* window_ = nullptr; + SDL_GLContext gl_context_ = nullptr; + + ImGuiID dockspace_id_; + ImGuiID dock_top_id_; + ImGuiID dock_left_id_; + ImGuiID dock_bottom_id_; + ImGuiID dock_bottom_left_id_; + ImGuiID dock_bottom_right_id_; + ImGuiID dock_right_id_; + ImGuiID dock_content_id_; + + std::unique_ptr<DebugClient> debug_client_; + std::vector<UserBreakpoint> user_breakpoint_list_; + + bool is_paused_ = false; + std::vector<const RemoteBreakpoint*> hit_breakpoints_; + bool is_stepping_ = false; + + absl::optional<int> selected_invocation_id_; + absl::optional<int> selected_stack_frame_index_; + + std::vector<std::unique_ptr<CodeViewDocument>> documents_; +}; + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_H_
diff --git a/iree/tools/debugger/debug_app.html b/tools/debugger/debug_app.html similarity index 100% rename from iree/tools/debugger/debug_app.html rename to tools/debugger/debug_app.html
diff --git a/tools/debugger/debug_app_embedded.cc b/tools/debugger/debug_app_embedded.cc new file mode 100644 index 0000000..ba70792 --- /dev/null +++ b/tools/debugger/debug_app_embedded.cc
@@ -0,0 +1,153 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/debugger/debug_app_embedded.h" + +#include <SDL.h> + +#include <thread> // NOLINT + +#include "absl/base/thread_annotations.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "base/memory.h" +#include "base/status.h" +#include "third_party/SDL2/include/SDL_thread.h" +#include "tools/debugger/debug_app.h" + +namespace iree { +namespace rt { +namespace debug { + +class InProcessEmbeddedDebugger : public EmbeddedDebugger { + public: + explicit InProcessEmbeddedDebugger(std::unique_ptr<DebugApp> app) + : app_(std::move(app)) { + thread_ = + SDL_CreateThread(&ThreadMainThunk, "InProcessEmbeddedDebugger", this); + } + + ~InProcessEmbeddedDebugger() override { + VLOG(1) << "Setting shutdown flag and waiting on thread..."; + shutdown_flag_ = true; + int status = 0; + SDL_WaitThread(thread_, &status); + VLOG(1) << "Thread shutdown, killing app..."; + app_.reset(); + } + + Status AwaitClose() override { + await_mutex_.LockWhen(absl::Condition( + +[](bool* is_shutdown) { return *is_shutdown; }, &is_shutdown_)); + auto status = std::move(shutdown_status_); + await_mutex_.Unlock(); + return status; + } + + private: + static int ThreadMainThunk(void* arg) { + return reinterpret_cast<InProcessEmbeddedDebugger*>(arg)->ThreadMain(); + } + + int ThreadMain() { + VLOG(1) << "Thread entry"; + while (!shutdown_flag_) { + auto status = app_->PumpMainLoop(); + if (IsCancelled(status)) { + shutdown_flag_ = true; + break; + } else if (!shutdown_flag_ && !status.ok()) { + absl::MutexLock lock(&await_mutex_); + shutdown_status_ = std::move(status); + // TODO(benvanik): don't check unless no one is watching. + CHECK_OK(shutdown_status_); + } + } + app_.reset(); + { + absl::MutexLock lock(&await_mutex_); + is_shutdown_ = true; + } + VLOG(1) << "Thread exit"; + return 0; + } + + std::unique_ptr<DebugApp> app_; + SDL_Thread* thread_; + std::atomic<bool> shutdown_flag_ = {false}; + absl::Mutex await_mutex_; + bool is_shutdown_ ABSL_GUARDED_BY(await_mutex_) = false; + Status shutdown_status_ ABSL_GUARDED_BY(await_mutex_); +}; + +StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger() { + return AttachDebugger(""); +} + +StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger( + absl::string_view service_address) { + LOG(INFO) << "Launching embedded debugger; service=" << service_address; + // Workaround for terrible bad SDL/graphics driver leaks. + IREE_DISABLE_LEAK_CHECKS(); + + if (SDL_Init(SDL_INIT_VIDEO | SDL_INIT_TIMER) != 0) { + return InternalErrorBuilder(IREE_LOC) + << "Unable to init SDL: " << SDL_GetError(); + } + +#if __APPLE__ + // GL 3.2 Core + GLSL 150 + const char* glsl_version = "#version 150"; + SDL_GL_SetAttribute( + SDL_GL_CONTEXT_FLAGS, + SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG); // Always required on Mac + SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 2); +#else + // GL 3.0 + GLSL 130 + const char* glsl_version = "#version 130"; + SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0); +#endif + + SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1); + SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24); + SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8); + SDL_DisplayMode current; + SDL_GetCurrentDisplayMode(0, ¤t); + SDL_WindowFlags window_flags = (SDL_WindowFlags)( + SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI); + SDL_Window* window = + SDL_CreateWindow("IREE Debugger (embedded)", SDL_WINDOWPOS_CENTERED, + SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags); + SDL_GLContext gl_context = SDL_GL_CreateContext(window); + SDL_GL_MakeCurrent(nullptr, nullptr); + + IREE_ENABLE_LEAK_CHECKS(); + + auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version); + if (!service_address.empty()) { + RETURN_IF_ERROR(app->Connect(service_address)); + } + + auto handle = absl::make_unique<InProcessEmbeddedDebugger>(std::move(app)); + return handle; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/tools/debugger/debug_app_embedded.h b/tools/debugger/debug_app_embedded.h new file mode 100644 index 0000000..e597ad2 --- /dev/null +++ b/tools/debugger/debug_app_embedded.h
@@ -0,0 +1,52 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_ +#define IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_ + +#include <memory> + +#include "absl/strings/string_view.h" +#include "base/status.h" + +namespace iree { +namespace rt { +namespace debug { + +// RAII handle for keeping the debugger alive. +// When the instance is destroyed the debugger app will be closed. +class EmbeddedDebugger { + public: + virtual ~EmbeddedDebugger() = default; + + // Blocks the caller until the debugger is closed by the user. + virtual Status AwaitClose() = 0; +}; + +// Launches the debugger app. +// Returns a handle that can be used to wait for the debugger to close or +// force it to close. +StatusOr<std::unique_ptr<EmbeddedDebugger>> LaunchDebugger(); + +// Launches the debugger app and attaches to the given server address. +// Returns a handle that can be used to wait for the debugger to close or +// force it to close. +StatusOr<std::unique_ptr<EmbeddedDebugger>> AttachDebugger( + absl::string_view service_address); + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_TOOLS_DEBUGGER_DEBUG_APP_EMBEDDED_H_
diff --git a/tools/debugger/debug_app_main_emscripten.cc b/tools/debugger/debug_app_main_emscripten.cc new file mode 100644 index 0000000..2746681 --- /dev/null +++ b/tools/debugger/debug_app_main_emscripten.cc
@@ -0,0 +1,69 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Emscripten debug_app entry point. +// Though we are using SDL here we need to do some emscripten-specific magic to +// handle the different main looping mode (as we can't block in main() like on +// other platforms) as well as support some emscripten-specific features for +// file upload/download/etc. + +#include <SDL.h> +#include <emscripten.h> + +#include "base/init.h" +#include "tools/debugger/debug_app.h" + +namespace iree { +namespace rt { +namespace debug { + +extern "C" int main(int argc, char** argv) { + InitializeEnvironment(&argc, &argv); + + if (SDL_Init(SDL_INIT_VIDEO) != 0) { + printf("Error: %s\n", SDL_GetError()); + return -1; + } + + const char* glsl_version = "#version 100"; + SDL_GL_SetAttribute(SDL_GL_CONTEXT_FLAGS, 0); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_ES); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 2); + SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 0); + + SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1); + SDL_GL_SetAttribute(SDL_GL_DEPTH_SIZE, 24); + SDL_GL_SetAttribute(SDL_GL_STENCIL_SIZE, 8); + SDL_DisplayMode current; + SDL_GetCurrentDisplayMode(0, ¤t); + SDL_WindowFlags window_flags = (SDL_WindowFlags)( + SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE | SDL_WINDOW_ALLOW_HIGHDPI); + SDL_Window* window = + SDL_CreateWindow("IREE Debugger", SDL_WINDOWPOS_CENTERED, + SDL_WINDOWPOS_CENTERED, 1280, 720, window_flags); + SDL_GLContext gl_context = SDL_GL_CreateContext(window); + if (!gl_context) { + printf("Failed to initialize WebGL context!\n"); + return 1; + } + + auto app = absl::make_unique<DebugApp>(window, gl_context, glsl_version); + ::emscripten_set_main_loop_arg(DebugApp::PumpMainLoopThunk, app.release(), 0, + false); + return 0; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/tools/debugger/debug_app_main_native.cc b/tools/debugger/debug_app_main_native.cc new file mode 100644 index 0000000..1753e9f --- /dev/null +++ b/tools/debugger/debug_app_main_native.cc
@@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Native (linux/etc) debug_app entry point. +// This should work on any platform with pthreads and SDL support. + +#include "base/init.h" +#include "base/status.h" +#include "tools/debugger/debug_app_embedded.h" + +namespace iree { +namespace rt { +namespace debug { + +Status Run() { + ASSIGN_OR_RETURN(auto handle, LaunchDebugger()); + RETURN_IF_ERROR(handle->AwaitClose()); + handle.reset(); + return OkStatus(); +} + +extern "C" int main(int argc, char** argv) { + InitializeEnvironment(&argc, &argv); + auto status = Run(); + if (!status.ok()) { + LOG(ERROR) << "Debugger error: " << status; + return 1; + } + return 0; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/tools/debugger/debug_cli_main.cc b/tools/debugger/debug_cli_main.cc new file mode 100644 index 0000000..4ff33e4 --- /dev/null +++ b/tools/debugger/debug_cli_main.cc
@@ -0,0 +1,40 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/flags/flag.h" +#include "base/init.h" +#include "base/status.h" +#include "tools/debugger/debug_prompt.h" + +ABSL_FLAG(std::string, debug_service_uri, "0.0.0.0:6000", + "IP/port of debug service to connect to."); + +namespace iree { +namespace rt { +namespace debug { + +Status Run() { + // TODO(benvanik): retry until connected? would allow auto-build reconnects. + return AttachDebugPrompt(absl::GetFlag(FLAGS_debug_service_uri)); +} + +extern "C" int main(int argc, char** argv) { + InitializeEnvironment(&argc, &argv); + CHECK_OK(Run()); + return 0; +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/tools/debugger/debug_prompt.cc b/tools/debugger/debug_prompt.cc new file mode 100644 index 0000000..729d432 --- /dev/null +++ b/tools/debugger/debug_prompt.cc
@@ -0,0 +1,90 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/debugger/debug_prompt.h" + +#include "base/status.h" +#include "rt/debug/debug_client.h" + +namespace iree { +namespace rt { +namespace debug { +namespace { + +class DebugPrompt : private DebugClient::Listener { + public: + Status Connect(absl::string_view debug_service_uri) { + // Connect to the debug service. + ASSIGN_OR_RETURN(debug_client_, + DebugClient::Connect(debug_service_uri, this)); + return OkStatus(); + } + + Status Run() { + // Query commands, transmit requests, and dispatch responses. + while (true) { + RETURN_IF_ERROR(debug_client_->Poll()); + + // TODO(benvanik): ask for a command. + // TODO(benvanik): display stuff. + } + } + + private: + Status OnContextRegistered(const RemoteContext& context) override { + // Ack. + return debug_client_->MakeReady(); + } + + Status OnContextUnregistered(const RemoteContext& context) override { + // Ack. + return debug_client_->MakeReady(); + } + + Status OnModuleLoaded(const RemoteContext& context, + const RemoteModule& module) override { + // Ack. + return debug_client_->MakeReady(); + } + + Status OnInvocationRegistered(const RemoteInvocation& invocation) override { + // Ack. + return debug_client_->MakeReady(); + } + + Status OnInvocationUnregistered(const RemoteInvocation& invocation) override { + // Ack. + return debug_client_->MakeReady(); + } + + Status OnBreakpointHit(const RemoteBreakpoint& breakpoint, + const RemoteInvocation& invocation) override { + // Ack. + return debug_client_->MakeReady(); + } + + std::unique_ptr<DebugClient> debug_client_; +}; + +} // namespace + +Status AttachDebugPrompt(absl::string_view debug_service_uri) { + DebugPrompt debug_prompt; + RETURN_IF_ERROR(debug_prompt.Connect(debug_service_uri)); + return debug_prompt.Run(); +} + +} // namespace debug +} // namespace rt +} // namespace iree
diff --git a/tools/debugger/debug_prompt.h b/tools/debugger/debug_prompt.h new file mode 100644 index 0000000..1eb2183 --- /dev/null +++ b/tools/debugger/debug_prompt.h
@@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_ +#define IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_ + +#include "absl/strings/string_view.h" +#include "base/status.h" + +namespace iree { +namespace rt { +namespace debug { + +// TODO(benvanik): take stdin/stdout as arguments. +// Attaches a debug prompt reading stdin for commands and printing results to +// stdout. The calling thread will block until the debugger is exited or the +// debug service closes. +Status AttachDebugPrompt(absl::string_view debug_service_uri); + +} // namespace debug +} // namespace rt +} // namespace iree + +#endif // IREE_TOOLS_DEBUGGER_DEBUG_PROMPT_H_
diff --git a/iree/tools/iree_translate_main.cc b/tools/iree_translate_main.cc similarity index 100% rename from iree/tools/iree_translate_main.cc rename to tools/iree_translate_main.cc
diff --git a/iree/tools/run_lit.sh b/tools/run_lit.oss.sh similarity index 100% rename from iree/tools/run_lit.sh rename to tools/run_lit.oss.sh
diff --git a/tools/run_mlir_main.cc b/tools/run_mlir_main.cc new file mode 100644 index 0000000..5c832fb --- /dev/null +++ b/tools/run_mlir_main.cc
@@ -0,0 +1,360 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IREE source.mlir -> execution output test runner. +// This is meant to be called from LIT for FileCheck tests, and tries to match +// the interface of mlir-opt (featuring -split-input-file, etc) so it's easier +// to work with there. If you want a more generalized runner for standalone +// precompiled IREE modules use //third_party/iree/tools:run_module. +// +// By default all exported functions in the module will be run in order. +// All input values, provided via -input-values, will be passed to the +// functions (this means all input signatures must match). Results from the +// executed functions will be printed to stdout for checking. +// Use -output_types to set the function output data types, which like args will +// be used for all functions executed. +// +// Example input: +// // RUN: iree-run %s | FileCheck %s +// // CHECK-LABEL: @foo +// // CHECK: 1xf32: 2 +// func @foo() -> memref<f32> attributes {iree.module.export} { +// %0 = "iree.constant"() {value: dense<tensor<f32>, 2.0>} : () -> memref<f32> +// return %0 : memref<f32> +// } + +#include <iostream> + +#include "absl/flags/flag.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "base/init.h" +#include "base/source_location.h" +#include "base/status.h" +#include "compiler/Translation/Sequencer/SequencerModuleTranslation.h" +#include "hal/buffer_view_string_util.h" +#include "hal/driver_registry.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/Parser.h" +#include "mlir/Support/FileUtilities.h" +#include "rt/context.h" +#include "rt/debug/debug_server_flags.h" +#include "rt/instance.h" +#include "rt/invocation.h" +#include "rt/module.h" +#include "rt/module_printer.h" +#include "schemas/module_def_generated.h" +#include "vm/sequencer_module.h" + +ABSL_FLAG(bool, split_input_file, true, + "Split the input file into multiple modules."); + +ABSL_FLAG(std::string, target_backends, "", + "Comma-separated list of target backends to translate executables " + "into. Omit to translate using all linked-in backend translators."); +ABSL_FLAG( + bool, export_all, true, + "Automatically add the iree.module.export attribute to all functions."); + +ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values."); +ABSL_FLAG(std::string, output_types, "", + "Output data types (comma delimited list of b/i/u/f for " + "binary/signed int/unsigned int/float)."); + +// TODO(benvanik): is there a more canonical flag we can use? +ABSL_FLAG(bool, print_mlir, true, "Prints MLIR IR during translation."); + +ABSL_FLAG(bool, print_bytecode, false, + "Prints IREE bytecode after translation."); + +ABSL_FLAG(bool, run, true, + "Option to run the file. Setting it to false just compiles it."); + +namespace iree { +namespace { + +using ::iree::hal::BufferView; +using ::iree::rt::Function; +using ::iree::rt::Module; + +// Returns a driver name capable of handling input from the given backend. +std::string BackendToDriverName(std::string backend) { + size_t dash = backend.find('-'); + if (dash == std::string::npos) { + return backend; + } else { + return backend.substr(0, dash); + } +} + +// Prepares a module for evaluation by running MLIR import and IREE translation. +StatusOr<ref_ptr<Module>> PrepareModule( + std::string target_backend, + std::unique_ptr<llvm::MemoryBuffer> file_buffer) { + mlir::MLIRContext context; + + // Parse input MLIR module. + llvm::SourceMgr source_mgr; + source_mgr.AddNewSourceBuffer(std::move(file_buffer), llvm::SMLoc()); + mlir::OwningModuleRef mlir_module = + mlir::parseSourceFile(source_mgr, &context); + + if (absl::GetFlag(FLAGS_export_all)) { + for (auto function : mlir_module->getOps<mlir::FuncOp>()) { + function.setAttr("iree.module.export", mlir::UnitAttr::get(&context)); + } + } + + // Translate from MLIR to IREE bytecode. + mlir::iree_compiler::ModuleTranslationOptions options; + options.print_mlir = absl::GetFlag(FLAGS_print_mlir); + options.target_backends = {target_backend}; + auto iree_module_bytes = + mlir::iree_compiler::translateMlirToIreeSequencerModule(mlir_module.get(), + options); + if (iree_module_bytes.empty()) { + return iree::InternalErrorBuilder(IREE_LOC) + << "Error translating MLIR to an IREE sequencer module"; + } + + if (absl::GetFlag(FLAGS_print_mlir)) { + mlir_module->dump(); + } + + // Wrap module in a file handle. + ASSIGN_OR_RETURN(auto iree_module_file, + vm::ModuleFile::FromBuffer(ModuleDefIdentifier(), + std::move(iree_module_bytes))); + return vm::SequencerModule::FromFile(std::move(iree_module_file)); +} + +// Parses a list of input shapes and values from a string of newline-separated +// inputs. Expects the contents to have one value per line with each value +// listed as +// [shape]xtype=[value] +// Example: +// 4x4xi8=0,1,2,3 +StatusOr<std::vector<BufferView>> ParseInputsFromFlags( + hal::Allocator *allocator) { + std::string file_contents = + absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}}); + std::vector<BufferView> inputs; + std::vector<std::string> lines = absl::StrSplit( + file_contents, absl::ByAnyChar("\n;"), absl::SkipWhitespace()); + for (const auto &line : lines) { + ASSIGN_OR_RETURN(auto input, + hal::ParseBufferViewFromString(line, allocator)); + inputs.push_back(input); + } + return inputs; +} + +// Outputs all results from the function to stdout in IREE BufferView format. +Status OutputFunctionResults(const Function &function, + absl::Span<BufferView> results) { + std::vector<std::string> output_types = + absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "), + absl::SkipWhitespace()); + if (!output_types.empty() && output_types.size() != results.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "--output_types= specified but has " << output_types.size() + << " types when the function returns " << results.size(); + } + + for (int i = 0; i < results.size(); ++i) { + const auto &result = results[i]; + auto print_mode = hal::BufferViewPrintMode::kFloatingPoint; + if (!output_types.empty()) { + ASSIGN_OR_RETURN(print_mode, + hal::ParseBufferViewPrintMode(output_types[i])); + } + ASSIGN_OR_RETURN(auto result_str, + hal::PrintBufferViewToString(result, print_mode, 1024)); + LOG(INFO) << "result[" << i << "]: " << result.buffer->DebugString(); + std::cout << result_str << "\n"; + } + + return OkStatus(); +} + +// Evaluates a single function in its own fiber, printing the results to stdout. +Status EvaluateFunction(const ref_ptr<rt::Context> &context, + hal::Allocator *allocator, const Function &function) { + std::cout << "EXEC @" << function.name() << std::endl; + + // Create invocation that will perform the execution. + ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(allocator)); + ASSIGN_OR_RETURN( + auto invocation, + rt::Invocation::Create(add_ref(context), function, make_ref<rt::Policy>(), + {}, absl::MakeConstSpan(arguments))); + + // Wait until invocation completes. + RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); + + // Print outputs. + ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults()); + RETURN_IF_ERROR(OutputFunctionResults(function, absl::MakeSpan(results))); + + return OkStatus(); +} + +// Evaluates all exported functions within given module. +Status EvaluateFunctions(absl::string_view target_backend, + ref_ptr<Module> module) { + // Create the context we'll use for this (ensuring that we can't interfere + // with other running evaluations, such as when in a multithreaded test + // runner). + ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags()); + auto instance = make_ref<rt::Instance>(std::move(debug_server)); + ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( + target_backend)); + ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); + RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); + + if (absl::GetFlag(FLAGS_print_bytecode)) { + RETURN_IF_ERROR(rt::PrintModuleToStream( + *module, rt::PrintModuleFlag::kDisassemble, &std::cout)); + } + + // Evaluate all exported functions. + auto policy = make_ref<rt::Policy>(); + auto run_function = [&](int ordinal) -> Status { + // Setup a new context for this invocation. + auto context = make_ref<rt::Context>(add_ref(instance), add_ref(policy)); + RETURN_IF_ERROR(context->RegisterModule(add_ref(module))); + + // Invoke the function and print results. + ASSIGN_OR_RETURN(auto function, + module->LookupFunctionByOrdinal( + rt::Function::Linkage::kExport, ordinal)); + RETURN_IF_ERROR(EvaluateFunction(context, device->allocator(), function)); + return OkStatus(); + }; + + Status evaluate_status = OkStatus(); + for (int i = 0; i < module->signature().export_function_count(); ++i) { + evaluate_status = run_function(i); + if (!evaluate_status.ok()) { + break; + } + } + + RETURN_IF_ERROR(instance->device_manager()->UnregisterDevice(device.get())); + device.reset(); + driver.reset(); + + return evaluate_status; +} + +// Translates and runs a single LLVM file buffer. +Status EvaluateFile(std::unique_ptr<llvm::MemoryBuffer> file_buffer) { + std::vector<std::string> target_backends; + if (absl::GetFlag(FLAGS_target_backends).empty()) { + target_backends = + hal::DriverRegistry::shared_registry()->EnumerateAvailableDrivers(); + } else { + // We need to map specific backends names to drivers (like 'vulkan-spirv' to + // the driver 'vulkan'). + target_backends = absl::StrSplit(absl::GetFlag(FLAGS_target_backends), ','); + } + + for (auto target_backend : target_backends) { + // Prepare the module for execution and evaluate it. + auto cloned_file_buffer = llvm::MemoryBuffer::getMemBufferCopy( + file_buffer->getBuffer(), file_buffer->getBufferIdentifier()); + ASSIGN_OR_RETURN(auto module, PrepareModule(target_backend + '*', + std::move(cloned_file_buffer))); + if (!absl::GetFlag(FLAGS_run)) { + continue; + } + RETURN_IF_ERROR(EvaluateFunctions(BackendToDriverName(target_backend), + std::move(module))); + } + + return OkStatus(); +} + +// Runs the given .mlir file based on the current flags. +Status RunFile(std::string mlir_filename) { + // Load input file/from stdin. + std::string error_message; + auto file = mlir::openInputFile(mlir_filename, &error_message); + if (!file) { + return NotFoundErrorBuilder(IREE_LOC) + << "Unable to open input file " << mlir_filename << ": " + << error_message; + } + + if (!absl::GetFlag(FLAGS_split_input_file)) { + // Use entire buffer as a single module. + return EvaluateFile(std::move(file)); + } + + // Split the buffer into separate modules and evaluate independently. + // This matches the -split-input-file arg to mlir-opt. + const char kSplitMarker[] = "// -----\n"; + auto *full_buffer = file.get(); + llvm::SmallVector<llvm::StringRef, 8> source_buffers; + full_buffer->getBuffer().split(source_buffers, kSplitMarker); + + // Add the original buffer to the source manager. + llvm::SourceMgr fileSourceMgr; + fileSourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + + // Process each chunk in turn. Only return the first error (but log all). + Status any_failure; + for (auto &sub_source_buffer : source_buffers) { + auto split_loc = llvm::SMLoc::getFromPointer(sub_source_buffer.data()); + unsigned split_line = fileSourceMgr.getLineAndColumn(split_loc).first; + auto sub_buffer = llvm::MemoryBuffer::getMemBufferCopy( + sub_source_buffer, full_buffer->getBufferIdentifier() + + llvm::Twine(" split at line #") + + llvm::Twine(split_line)); + auto sub_failure = EvaluateFile(std::move(sub_buffer)); + if (!sub_failure.ok()) { + LOG(ERROR) << sub_failure; + if (any_failure.ok()) { + any_failure = std::move(sub_failure); + } + } + } + + return any_failure; +} + +} // namespace + +extern "C" int main(int argc, char **argv) { + InitializeEnvironment(&argc, &argv); + if (argc < 2) { + LOG(ERROR) << "Must supply an input .mlir file."; + return 1; + } + auto status = RunFile(argv[1]); + if (!status.ok()) { + std::cerr << "ERROR running file (" << argv[1] << "): " << status << "\n"; + return 1; + } + return 0; +} + +} // namespace iree
diff --git a/tools/run_module_main.cc b/tools/run_module_main.cc new file mode 100644 index 0000000..93b1553 --- /dev/null +++ b/tools/run_module_main.cc
@@ -0,0 +1,183 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <iostream> +#include <vector> + +#include "absl/flags/flag.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "base/file_io.h" +#include "base/file_path.h" +#include "base/init.h" +#include "base/source_location.h" +#include "base/status.h" +#include "hal/buffer_view_string_util.h" +#include "hal/driver_registry.h" +#include "rt/context.h" +#include "rt/debug/debug_server_flags.h" +#include "rt/instance.h" +#include "rt/module_printer.h" +#include "schemas/module_def_generated.h" +#include "vm/sequencer_module.h" + +ABSL_FLAG(std::string, main_module, "", "Main module with entry point."); +ABSL_FLAG(std::string, main_function, "", + "Function within the main module to execute."); + +ABSL_FLAG(bool, print_disassembly, true, + "Prints bytecode disassembly for the module."); + +ABSL_FLAG(std::string, input_values, "", "Input shapes and optional values."); +ABSL_FLAG(std::string, input_file, "", + "Input shapes and optional values serialized in a file."); + +ABSL_FLAG(std::string, output_types, "", + "Output data types (comma delimited list of b/i/u/f for " + "binary/signed int/unsigned int/float)."); + +namespace iree { +namespace { + +// Parses a list of input shapes and values from a string of newline-separated +// inputs. Expects the contents to have one value per line with each value +// listed as +// [shape]xtype=[value] +// Example: +// 4x4xi8=0,1,2,3 +StatusOr<std::vector<hal::BufferView>> ParseInputsFromFlags( + hal::Allocator* allocator) { + std::string file_contents; + if (!absl::GetFlag(FLAGS_input_values).empty()) { + file_contents = + absl::StrReplaceAll(absl::GetFlag(FLAGS_input_values), {{"\\n", "\n"}}); + } else if (!absl::GetFlag(FLAGS_input_file).empty()) { + ASSIGN_OR_RETURN(file_contents, + file_io::GetFileContents(absl::GetFlag(FLAGS_input_file))); + } + std::vector<hal::BufferView> inputs; + for (const auto& line : + absl::StrSplit(file_contents, '\n', absl::SkipWhitespace())) { + ASSIGN_OR_RETURN(auto input, + hal::ParseBufferViewFromString(line, allocator)); + inputs.push_back(input); + } + return inputs; +} + +} // namespace + +Status Run() { + ASSIGN_OR_RETURN(auto debug_server, rt::debug::CreateDebugServerFromFlags()); + auto instance = make_ref<rt::Instance>(std::move(debug_server)); + ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( + "interpreter")); + ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); + RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); + auto policy = make_ref<rt::Policy>(); + auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy)); + + // Load main module. + ASSIGN_OR_RETURN( + auto main_module_file, + vm::ModuleFile::LoadFile(ModuleDefIdentifier(), + absl::GetFlag(FLAGS_main_module)), + _ << "while loading module file " << absl::GetFlag(FLAGS_main_module)); + ASSIGN_OR_RETURN(auto main_module, + vm::SequencerModule::FromFile(std::move(main_module_file))); + + // Register the main module with the context. + // We could add additional modules (specializations, shared libraries, etc). + // ModuleFioles are stateless so we could have the same module_file used by + // multiple contexts simultaneously. + RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module))); + + // Dump the registered modules. + rt::PrintModuleFlagBitfield print_flags = rt::PrintModuleFlag::kNone; + if (absl::GetFlag(FLAGS_print_disassembly)) { + print_flags |= rt::PrintModuleFlag::kDisassemble; + } + for (const auto& module : context->modules()) { + RETURN_IF_ERROR(PrintModuleToStream(*module, print_flags, &std::cout)); + } + + rt::Function main_function; + if (!absl::GetFlag(FLAGS_main_function).empty()) { + // User-specified main function. + ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByName( + rt::Function::Linkage::kExport, + absl::GetFlag(FLAGS_main_function))); + } else { + // No main function specified; to prevent non-deterministic behavior we + // require one unless there's exactly one exported function in the module. + if (main_module->signature().export_function_count() == 1) { + ASSIGN_OR_RETURN(main_function, main_module->LookupFunctionByOrdinal( + rt::Function::Linkage::kExport, 0)); + } else { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "--main_function= must be specified to disambiguate the " + "function to run"; + } + } + + // Call into the main function. + ASSIGN_OR_RETURN(auto arguments, ParseInputsFromFlags(device->allocator())); + ASSIGN_OR_RETURN(auto invocation, + rt::Invocation::Create(add_ref(context), main_function, + make_ref<rt::Policy>(), {}, + absl::MakeConstSpan(arguments))); + + // Wait until invocation completes. + RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); + ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults()); + + // Dump all results to stdout. + std::vector<std::string> output_types = + absl::StrSplit(absl::GetFlag(FLAGS_output_types), absl::ByAnyChar(", "), + absl::SkipWhitespace()); + if (!output_types.empty() && output_types.size() != results.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "--output_types= specified but has " << output_types.size() + << " types when the function returns " << results.size(); + } + for (int i = 0; i < results.size(); ++i) { + const auto& result = results[i]; + auto print_mode = hal::BufferViewPrintMode::kFloatingPoint; + if (!output_types.empty()) { + ASSIGN_OR_RETURN(print_mode, + hal::ParseBufferViewPrintMode(output_types[i])); + } + ASSIGN_OR_RETURN(auto result_str, + PrintBufferViewToString(result, print_mode, 1024)); + const auto& buffer = result.buffer; + if (!buffer) { + return InternalErrorBuilder(IREE_LOC) + << "result[" << i << "] unexpectedly has no buffer"; + } + LOG(INFO) << "result[" << i << "]: " << buffer->DebugString(); + std::cout << result_str << "\n"; + } + + return OkStatus(); +} + +extern "C" int main(int argc, char** argv) { + InitializeEnvironment(&argc, &argv); + CHECK_OK(Run()); + return 0; +} + +} // namespace iree
diff --git a/iree/tools/sanitizer_suppressions.txt b/tools/sanitizer_suppressions.txt similarity index 100% rename from iree/tools/sanitizer_suppressions.txt rename to tools/sanitizer_suppressions.txt
diff --git a/tools/web/BUILD b/tools/web/BUILD new file mode 100644 index 0000000..c24b4b2 --- /dev/null +++ b/tools/web/BUILD
@@ -0,0 +1,89 @@ +# IREE web tools. + +load("//third_party/emscripten:split_transition_defs.bzl", "auto_wasm_binary") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +EMSCRIPTEN_LINKOPTS_COMMON = [ + # Error at compile time on unresolved symbols. + "-s ERROR_ON_UNDEFINED_SYMBOLS=1", + + # Note: If pthreads and memory growth are enabled, WASM_MEM_MAX must be set. + # Also, USE_PTHREADS + ALLOW_MEMORY_GROWTH may run non-wasm code slowly. + # "-s ALLOW_MEMORY_GROWTH=1", + # "-s WASM_MEM_MAX=268435456", # 256MB + # "-s TOTAL_MEMORY=268435456", # 256MB + + # Request a prepopulated pool of web workers for pthreads to use. + # Without this, threads may not start until the javascript thread yields. + # See considerations at https://emscripten.org/docs/porting/pthreads.html. + "-s PTHREAD_POOL_SIZE=1", +] + +EMSCRIPTEN_LINKOPTS_DBG = [ + # Show WASM stack trace in Chrome debugger. + "-g2", + "-s DEMANGLE_SUPPORT=1", + + # Enable verbose assertions. + "-s ASSERTIONS=2", + "-s SAFE_HEAP=1", + "-s STACK_OVERFLOW_CHECK=2", +] + +EMSCRIPTEN_LINKOPTS_OPT = [] + +# To use run_module_emscripten: +# > bazel build third_party/iree/tools/web:run_module_emscripten_files + +cc_binary( + name = "run_module_emscripten", + srcs = ["run_module_emscripten.cc"], + linkopts = EMSCRIPTEN_LINKOPTS_COMMON + select({ + "//tools/compilation_mode:dbg": EMSCRIPTEN_LINKOPTS_DBG, + "//tools/compilation_mode:opt": EMSCRIPTEN_LINKOPTS_OPT, + "//conditions:default": EMSCRIPTEN_LINKOPTS_OPT, + }), + tags = [ + "manual", + "notap", # TODO(b/137088911): Build/test on TAP + "wasm", + ], + deps = [ + "///base:init", + "///base:status", + "///hal:buffer_view_string_util", + "///hal:driver_registry", + "///hal/interpreter:interpreter_driver_module", + "///rt", + "///vm:sequencer_module", + "//third_party/emscripten:embind", + ], +) + +auto_wasm_binary( + name = "run_module_emscripten_binary", + cc_target = ":run_module_emscripten", + tags = ["manual"], + threads = "emscripten", +) + +Fileset( + name = "run_module_emscripten_files", + out = "wasm_files", + entries = [ + FilesetEntry( + files = [":run_module_emscripten_binary"], + strip_prefix = "run_module_emscripten_binary", + destdir = "wasm", + ), + FilesetEntry( + files = ["run_module.html"], + destdir = "wasm", + ), + ], + tags = ["manual"], +)
diff --git a/iree/tools/web/run_module.html b/tools/web/run_module.html similarity index 100% rename from iree/tools/web/run_module.html rename to tools/web/run_module.html
diff --git a/tools/web/run_module_emscripten.cc b/tools/web/run_module_emscripten.cc new file mode 100644 index 0000000..d433f1f --- /dev/null +++ b/tools/web/run_module_emscripten.cc
@@ -0,0 +1,140 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <emscripten.h> +#include <emscripten/bind.h> + +#include <vector> + +#include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "base/flatbuffer_util.h" +#include "base/init.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "hal/buffer_view_string_util.h" +#include "hal/driver_registry.h" +#include "rt/context.h" +#include "rt/instance.h" +#include "schemas/module_def_generated.h" +#include "vm/sequencer_module.h" + +namespace iree { + +// Parses a list of input shapes and values from a string of newline-separated +// inputs. Expects the contents to have one value per line with each value +// listed as +// [shape]xtype=[value] +// Example: +// 4x4xi8=0,1,2,3 +StatusOr<std::vector<hal::BufferView>> ParseInputs( + absl::string_view inputs_string, hal::Allocator* allocator) { + std::string input_lines = absl::StrReplaceAll(inputs_string, {{"\\n", "\n"}}); + std::vector<hal::BufferView> input_buffer_views; + for (const auto& input_line : + absl::StrSplit(input_lines, '\n', absl::SkipWhitespace())) { + ASSIGN_OR_RETURN(auto input_buffer_view, + hal::ParseBufferViewFromString(input_line, allocator)); + input_buffer_views.push_back(input_buffer_view); + } + return input_buffer_views; +} + +// Runs an IREE module with the provided inputs and returns its outputs. +StatusOr<std::string> RunIreeModule(std::string module_file_data, + absl::string_view inputs_string) { + auto instance = make_ref<rt::Instance>(); + + // Create driver and device. + ASSIGN_OR_RETURN(auto driver, hal::DriverRegistry::shared_registry()->Create( + "interpreter")); + ASSIGN_OR_RETURN(auto device, driver->CreateDefaultDevice()); + RETURN_IF_ERROR(instance->device_manager()->RegisterDevice(device)); + + auto policy = make_ref<rt::Policy>(); + auto context = make_ref<rt::Context>(add_ref(instance), std::move(policy)); + + // Load main module FlatBuffer. + ASSIGN_OR_RETURN(auto main_module_file, + FlatBufferFile<ModuleDef>::FromString(ModuleDefIdentifier(), + module_file_data)); + ASSIGN_OR_RETURN(auto main_module, + vm::SequencerModule::FromFile(std::move(main_module_file))); + + // Register the main module with the context. + RETURN_IF_ERROR(context->RegisterModule(add_ref(main_module))); + + // Setup arguments and storage for results. + // TODO(scotttodd): Receive main function name from JS. + ASSIGN_OR_RETURN(auto main_function, + main_module->LookupFunctionByName( + rt::Function::Linkage::kExport, "main")); + + ASSIGN_OR_RETURN(auto arguments, + ParseInputs(inputs_string, device->allocator())); + + // Call into the main function. + ASSIGN_OR_RETURN(auto invocation, + rt::Invocation::Create(add_ref(context), main_function, + make_ref<rt::Policy>(), {}, + absl::MakeConstSpan(arguments))); + + // Wait until invocation completes. + // TODO(scotttodd): make this an async callback. + RETURN_IF_ERROR(invocation->Await(absl::InfiniteFuture())); + ASSIGN_OR_RETURN(auto results, invocation->ConsumeResults()); + + // Dump all results to stdout. + // TODO(scotttodd): Receive output types / print mode from JS. + // TODO(scotttodd): Return list of outputs instead of just the first (proto?) + for (int i = 0; i < results.size(); ++i) { + const auto& result = results[i]; + auto print_mode = hal::BufferViewPrintMode::kFloatingPoint; + ASSIGN_OR_RETURN(auto result_str, + PrintBufferViewToString(result, print_mode, 1024)); + const auto& buffer = result.buffer; + if (!buffer) { + return InternalErrorBuilder(IREE_LOC) + << "result[" << i << "] unexpectedly has no buffer"; + } + + return result_str; + } + + return InternalErrorBuilder(IREE_LOC) << "Received no results"; +} + +std::string RunIreeModuleEntry(std::string module_file_data, + std::string inputs_string) { + // TODO(scotttodd): optimize, minimize copies + // https://groups.google.com/d/msg/emscripten-discuss/CMfYljLWMvY/Di52WB2QAgAJ + auto result_or = RunIreeModule(std::move(module_file_data), inputs_string); + if (!result_or.ok()) { + return "Error: " + result_or.status().ToString(); + } else { + return result_or.ValueOrDie(); + } +} + +EMSCRIPTEN_BINDINGS(iree) { + emscripten::function("runIreeModule", &RunIreeModuleEntry); +} + +extern "C" int main(int argc, char** argv) { + InitializeEnvironment(&argc, &argv); + return 0; +} + +} // namespace iree
diff --git a/vm/BUILD b/vm/BUILD new file mode 100644 index 0000000..c7fe7a2 --- /dev/null +++ b/vm/BUILD
@@ -0,0 +1,200 @@ +# Bytecode VM used by the IREE sequencer and interpreter. + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "api", + srcs = ["api.cc"], + hdrs = ["api.h"], + visibility = ["//visibility:public"], + deps = [ + ":api_hdrs", + ":sequencer_module", + "///base:api", + "///base:api_util", + "///base:flatbuffer_util", + "///base:tracing", + "///rt:api", + ], +) + +cc_library( + name = "api_hdrs", + hdrs = ["api.h"], + deps = [ + "///base:api_hdrs", + "///rt:api_hdrs", + ], +) + +cc_library( + name = "bytecode_module", + srcs = [ + "bytecode_disassembler.cc", + "bytecode_module.cc", + ], + hdrs = [ + "bytecode_disassembler.h", + "bytecode_module.h", + ], + deps = [ + ":bytecode_util", + ":opcode_info", + ":source_map_resolver", + ":type", + "///base:flatbuffer_util", + "///base:status", + "///base:tracing", + "///hal:buffer_view", + "///rt", + "///schemas", + "///schemas/bytecode:bytecode_v0", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "bytecode_reader", + srcs = ["bytecode_reader.cc"], + hdrs = ["bytecode_reader.h"], + deps = [ + ":bytecode_module", + ":type", + "///base:shape", + "///base:status", + "///hal:buffer_view", + "///hal:heap_buffer", + "///rt", + "///schemas/bytecode:bytecode_v0", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + ], +) + +cc_library( + name = "bytecode_tables_interpreter", + srcs = ["bytecode_tables_interpreter.cc"], + hdrs = ["bytecode_tables_interpreter.h"], + deps = [ + ":opcode_info", + "///schemas/bytecode:interpreter_bytecode_v0", + ], +) + +cc_library( + name = "bytecode_tables_sequencer", + srcs = ["bytecode_tables_sequencer.cc"], + hdrs = ["bytecode_tables_sequencer.h"], + deps = [ + ":opcode_info", + "///schemas/bytecode:sequencer_bytecode_v0", + ], +) + +cc_library( + name = "bytecode_util", + srcs = ["bytecode_util.cc"], + hdrs = ["bytecode_util.h"], + deps = [ + "///schemas/bytecode:bytecode_v0", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "bytecode_validator", + srcs = ["bytecode_validator.cc"], + hdrs = ["bytecode_validator.h"], + deps = [ + ":bytecode_module", + "///base:status", + "///schemas", + ], +) + +cc_library( + name = "opcode_info", + hdrs = ["opcode_info.h"], + deps = [ + "///schemas/bytecode:bytecode_v0", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "sequencer_dispatch", + srcs = ["sequencer_dispatch.cc"], + hdrs = ["sequencer_dispatch.h"], + deps = [ + ":bytecode_module", + ":bytecode_reader", + ":bytecode_tables_sequencer", + ":bytecode_util", + ":opcode_info", + "///base:logging", + "///base:memory", + "///base:status", + "///hal:buffer_view", + "///hal:command_queue", + "///hal:device", + "///hal:device_placement", + "///hal:heap_buffer", + "///rt", + "///schemas/bytecode:sequencer_bytecode_v0", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "sequencer_module", + srcs = ["sequencer_module.cc"], + hdrs = ["sequencer_module.h"], + deps = [ + ":bytecode_module", + ":bytecode_tables_sequencer", + ":sequencer_dispatch", + "///base:status", + "///base:tracing", + "///hal:buffer_view", + "///rt", + "@com_google_absl//absl/memory", + ], +) + +cc_library( + name = "source_map_resolver", + srcs = ["source_map_resolver.cc"], + hdrs = ["source_map_resolver.h"], + deps = [ + "///base:flatbuffer_util", + "///base:status", + "///rt", + "///schemas", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "type", + srcs = ["type.cc"], + hdrs = ["type.h"], + deps = [ + "///base:status", + "///schemas", + "///schemas/bytecode:bytecode_v0", + ], +)
diff --git a/iree/vm/CMakeLists.txt b/vm/CMakeLists.txt similarity index 100% rename from iree/vm/CMakeLists.txt rename to vm/CMakeLists.txt
diff --git a/vm/api.cc b/vm/api.cc new file mode 100644 index 0000000..08bd6e0 --- /dev/null +++ b/vm/api.cc
@@ -0,0 +1,99 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/api.h" + +#include "base/api.h" +#include "base/api_util.h" +#include "base/flatbuffer_util.h" +#include "base/tracing.h" +#include "vm/sequencer_module.h" + +namespace iree { +namespace vm { + +//===----------------------------------------------------------------------===// +// iree::vm::BytecodeModule +//===----------------------------------------------------------------------===// + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_vm_bytecode_module_create_from_buffer( + iree_const_byte_span_t buffer_data, + void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data), + void* buffer_free_self, iree_allocator_t allocator, + iree_rt_module_t** out_module) { + IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_buffer"); + + if (!out_module) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_module = nullptr; + + if (!buffer_data.data || !buffer_data.data_length) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + IREE_API_ASSIGN_OR_RETURN( + auto module_file, + FlatBufferFile<ModuleDef>::FromBuffer( + ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length}, + [buffer_free_fn, buffer_free_self, buffer_data]() { + if (buffer_free_fn != nullptr) { + buffer_free_fn(buffer_free_self, + {const_cast<uint8_t*>(buffer_data.data), + buffer_data.data_length}); + } + })); + + IREE_API_ASSIGN_OR_RETURN(auto module, + SequencerModule::FromFile(std::move(module_file))); + + *out_module = reinterpret_cast<iree_rt_module_t*>(module.release()); + + return IREE_STATUS_OK; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_vm_bytecode_module_create_from_file_mapping( + iree_file_mapping_t* file_mapping, iree_allocator_t allocator, + iree_rt_module_t** out_module) { + IREE_TRACE_SCOPE0("iree_vm_bytecode_module_create_from_file_mapping"); + + if (!out_module) { + return IREE_STATUS_INVALID_ARGUMENT; + } + *out_module = nullptr; + + if (!file_mapping) { + return IREE_STATUS_INVALID_ARGUMENT; + } + + auto buffer_data = iree_file_mapping_data(file_mapping); + IREE_API_ASSIGN_OR_RETURN( + auto module_file, + FlatBufferFile<ModuleDef>::FromBuffer( + ModuleDefIdentifier(), {buffer_data.data, buffer_data.data_length}, + [file_mapping]() { iree_file_mapping_release(file_mapping); })); + iree_file_mapping_retain(file_mapping); + + IREE_API_ASSIGN_OR_RETURN(auto module, + SequencerModule::FromFile(std::move(module_file))); + + *out_module = reinterpret_cast<iree_rt_module_t*>(module.release()); + + return IREE_STATUS_OK; +} + +} // namespace vm +} // namespace iree
diff --git a/vm/api.h b/vm/api.h new file mode 100644 index 0000000..adaf774 --- /dev/null +++ b/vm/api.h
@@ -0,0 +1,60 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// See iree/base/api.h for documentation on the API conventions used. + +#ifndef IREE_VM_API_H_ +#define IREE_VM_API_H_ + +#include <stdint.h> + +#include "base/api.h" +#include "rt/api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +//===----------------------------------------------------------------------===// +// iree::vm::BytecodeModule +//===----------------------------------------------------------------------===// + +#ifndef IREE_API_NO_PROTOTYPES + +// Creates a VM module from an in-memory ModuleDef FlatBuffer. +// The provided |buffer_free_fn| will be called when the module is destroyed +// and only if this creation function succeeds. If ownership remains with the +// caller then pass nullptr for |buffer_free_fn|. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_vm_bytecode_module_create_from_buffer( + iree_const_byte_span_t buffer_data, + void (*buffer_free_fn)(void* self, iree_byte_span_t buffer_data), + void* buffer_free_self, iree_allocator_t allocator, + iree_rt_module_t** out_module); + +// Creates a VM module from a mapped ModuleDef FlatBuffer. +// The provided |file_mapping| will be retained for the life of the module and +// the contents will be accessed by reference. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_vm_bytecode_module_create_from_file_mapping( + iree_file_mapping_t* file_mapping, iree_allocator_t allocator, + iree_rt_module_t** out_module); + +#endif // IREE_API_NO_PROTOTYPES + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_VM_API_H_
diff --git a/vm/bytecode_disassembler.cc b/vm/bytecode_disassembler.cc new file mode 100644 index 0000000..532abc7 --- /dev/null +++ b/vm/bytecode_disassembler.cc
@@ -0,0 +1,482 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_disassembler.h" + +#include <iomanip> +#include <sstream> + +#include "absl/base/macros.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "base/status.h" +#include "schemas/bytecode/bytecode_v0.h" +#include "schemas/source_map_def_generated.h" +#include "vm/bytecode_module.h" +#include "vm/bytecode_util.h" +#include "vm/type.h" + +namespace iree { +namespace vm { + +namespace { + +using ::iree::rt::SourceOffset; + +template <typename T> +StatusOr<T> ReadValue(absl::Span<const uint8_t> data, SourceOffset* offset) { + if (*offset + sizeof(T) > data.size()) { + return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode data underrun"; + } + auto value = *reinterpret_cast<const T*>(&data[*offset]); + *offset = *offset + sizeof(T); + return value; +} + +StatusOr<const Type> ReadType(absl::Span<const uint8_t> data, + SourceOffset* offset) { + ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>(data, offset)); + return Type::FromTypeIndex(type_index); +} + +StatusOr<uint8_t> ReadCount(absl::Span<const uint8_t> data, + SourceOffset* offset) { + return ReadValue<uint8_t>(data, offset); +} + +StatusOr<uint16_t> ReadValueSlot(absl::Span<const uint8_t> data, + SourceOffset* offset) { + return ReadValue<uint16_t>(data, offset); +} + +absl::string_view ConstantEncodingToString(ConstantEncoding encoding) { + switch (encoding) { +#define GET_NAME(ordinal, enum_name, str, ...) \ + case ConstantEncoding::enum_name: \ + return str; + IREE_CONSTANT_ENCODING_LIST(GET_NAME) +#undef GET_NAME + default: + return "unknown"; + } +} + +template <typename T> +std::string TypedDataToString(absl::Span<const uint8_t> bytes) { + auto typed_data = absl::Span<const T>{ + reinterpret_cast<const T*>(bytes.data()), bytes.size() / sizeof(T)}; + return absl::StrJoin(typed_data, ","); +} + +std::string ConstantToString(const Type& type, + absl::Span<const uint8_t> bytes) { + if (!type.is_builtin()) { + return absl::StrJoin(bytes, ","); + } + switch (type.builtin_type()) { + case BuiltinType::kI8: + return TypedDataToString<uint8_t>(bytes); + case BuiltinType::kI16: + return TypedDataToString<uint16_t>(bytes); + case BuiltinType::kI32: + return TypedDataToString<uint32_t>(bytes); + case BuiltinType::kI64: + return TypedDataToString<uint64_t>(bytes); + case BuiltinType::kF16: + return TypedDataToString<uint16_t>(bytes); + case BuiltinType::kF32: + return TypedDataToString<float>(bytes); + case BuiltinType::kF64: + return TypedDataToString<double>(bytes); + default: + return "<unsupported>"; + } +} + +} // namespace + +StatusOr<std::vector<rt::Instruction>> +BytecodeDisassembler::DisassembleInstructions(const rt::Function& function, + SourceOffset offset, + int32_t instruction_count) const { + std::vector<rt::Instruction> instructions; + + ASSIGN_OR_RETURN( + auto* function_def, + static_cast<const BytecodeModule*>(function.module()) + ->GetFunctionDef(function.linkage(), function.ordinal())); + auto* bytecode_def = function_def->bytecode(); + if (!bytecode_def) { + return UnavailableErrorBuilder(IREE_LOC) << "Function contains no body"; + } + auto data = absl::MakeSpan( + reinterpret_cast<const uint8_t*>(bytecode_def->contents()->data()), + bytecode_def->contents()->size()); + + // TODO(benvanik): scan and find all branch offsets to insert labels + + while (offset < data.length() && instructions.size() < instruction_count) { + instructions.push_back({}); + auto& instruction = instructions.back(); + instruction.offset = offset; + + uint8_t opcode = data[offset++]; + const auto& opcode_info = GetOpcodeInfo(opcode_table_, opcode); + if (!opcode_info.mnemonic) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Unhandled opcode " << opcode << " at offset " << (offset - 1); + } + int payload_offset = offset; + + std::ostringstream stream; + + // Print out return values, if any. + int base_result_index = 0; + int printed_result_count = 0; + for (int i = base_result_index; i < ABSL_ARRAYSIZE(opcode_info.operands); + ++i) { + if (opcode_info.operands[i] == OperandEncoding::kNone) break; + if (printed_result_count > 0) { + stream << ", "; + } + switch (opcode_info.operands[i]) { + default: + case OperandEncoding::kNone: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unhandled op encoding " + << static_cast<int>(opcode_info.operands[i]) << " at offset " + << (offset - 1); + case OperandEncoding::kInputSlot: + case OperandEncoding::kOutputSlot: { + // Printing handled below. + offset += sizeof(uint16_t); + break; + } + case OperandEncoding::kVariadicInputSlots: + case OperandEncoding::kVariadicOutputSlots: { + // Printing handled below. + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + offset += count * sizeof(uint16_t); + break; + } + case OperandEncoding::kResultSlot: { + ++printed_result_count; + ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset)); + stream << "%" << slot_ordinal; + break; + } + case OperandEncoding::kVariadicResultSlots: { + ++printed_result_count; + stream << "["; + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + for (int j = 0; j < count; ++j) { + ASSIGN_OR_RETURN(uint16_t slot_ordinal, + ReadValueSlot(data, &offset)); + if (j > 0) stream << ", "; + stream << "%" << slot_ordinal; + } + stream << "]"; + break; + } + case OperandEncoding::kVariadicTransferSlots: { + // Printing handled below. + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + offset += count * 2 * sizeof(uint16_t); + break; + } + case OperandEncoding::kConstant: { + // Printing handled below. + ASSIGN_OR_RETURN(auto type, ReadType(data, &offset)); + ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset)); + int element_count = 1; + for (int j = 0; j < rank; ++j) { + ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset)); + element_count *= dim; + } + offset += sizeof(ConstantEncoding); + offset += element_count * type.element_size(); + break; + } + case OperandEncoding::kFunctionOrdinal: { + // Printing handled below. + offset += sizeof(uint32_t); + break; + } + case OperandEncoding::kDispatchOrdinal: { + // Printing handled below. + offset += sizeof(uint32_t) + sizeof(uint16_t); + break; + } + case OperandEncoding::kBlockOffset: { + // Printing handled below. + offset += sizeof(uint32_t); + break; + } + case OperandEncoding::kTypeIndex: { + // Printing handled below. + offset += sizeof(uint8_t); + break; + } + case OperandEncoding::kIndex: { + // Printing handled below. + offset += sizeof(int32_t); + break; + } + case OperandEncoding::kIndexList: { + // Printing handled below. + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + offset += count * sizeof(int32_t); + break; + } + case OperandEncoding::kCmpIPredicate: + case OperandEncoding::kCmpFPredicate: { + // Printing handled below. + offset += sizeof(uint8_t); + break; + } + } + } + if (printed_result_count > 0) { + stream << " = "; + } + offset = payload_offset; + + stream << opcode_info.mnemonic; + + // Print out operands. + int base_operand_index = 0; + int printed_operand_count = 0; + for (int i = base_operand_index; i < ABSL_ARRAYSIZE(opcode_info.operands); + ++i) { + if (opcode_info.operands[i] == OperandEncoding::kNone) break; + if (opcode_info.operands[i] != OperandEncoding::kResultSlot && + opcode_info.operands[i] != OperandEncoding::kVariadicResultSlots) { + if (i == base_operand_index) { + stream << " "; + } else if (printed_operand_count > 0) { + stream << ", "; + } + } + switch (opcode_info.operands[i]) { + default: + case OperandEncoding::kNone: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unhandled op encoding " + << static_cast<int>(opcode_info.operands[i]) << " at offset " + << (offset - 1); + case OperandEncoding::kInputSlot: { + ++printed_operand_count; + ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset)); + stream << "%" << slot_ordinal; + break; + } + case OperandEncoding::kVariadicInputSlots: { + ++printed_operand_count; + stream << "["; + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + for (int j = 0; j < count; ++j) { + ASSIGN_OR_RETURN(uint16_t slot_ordinal, + ReadValueSlot(data, &offset)); + if (j > 0) stream << ", "; + stream << "%" << slot_ordinal; + } + stream << "]"; + break; + } + case OperandEncoding::kOutputSlot: { + ++printed_operand_count; + ASSIGN_OR_RETURN(uint16_t slot_ordinal, ReadValueSlot(data, &offset)); + stream << "&" + << "%" << slot_ordinal; + break; + } + case OperandEncoding::kVariadicOutputSlots: { + ++printed_operand_count; + stream << "["; + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + for (int j = 0; j < count; ++j) { + ASSIGN_OR_RETURN(uint16_t slot_ordinal, + ReadValueSlot(data, &offset)); + if (j > 0) stream << ", "; + stream << "&" + << "%" << slot_ordinal; + } + stream << "]"; + break; + } + case OperandEncoding::kResultSlot: { + // Printing handled above. + offset += sizeof(uint16_t); + break; + } + case OperandEncoding::kVariadicResultSlots: { + // Printing handled above. + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + offset += count * sizeof(uint16_t); + break; + } + case OperandEncoding::kVariadicTransferSlots: { + ++printed_operand_count; + stream << "["; + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + for (int j = 0; j < count; ++j) { + ASSIGN_OR_RETURN(uint16_t src_slot_ordinal, + ReadValueSlot(data, &offset)); + ASSIGN_OR_RETURN(uint16_t dst_slot_ordinal, + ReadValueSlot(data, &offset)); + if (j > 0) stream << ", "; + stream << "%" << src_slot_ordinal << "=>%" << dst_slot_ordinal; + } + stream << "]"; + break; + } + case OperandEncoding::kConstant: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto type, ReadType(data, &offset)); + ASSIGN_OR_RETURN(int rank, ReadCount(data, &offset)); + absl::InlinedVector<int32_t, 4> shape(rank); + int element_count = 1; + for (int j = 0; j < rank; ++j) { + ASSIGN_OR_RETURN(int dim, ReadValue<int32_t>(data, &offset)); + shape[j] = dim; + element_count *= dim; + } + ASSIGN_OR_RETURN(auto encoding, + ReadValue<ConstantEncoding>(data, &offset)); + stream << ConstantEncodingToString(encoding); + int serialized_element_count = 1; + switch (encoding) { + case ConstantEncoding::kDense: + serialized_element_count = element_count; + break; + case ConstantEncoding::kSplat: + serialized_element_count = 1; + break; + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented constant encoding " + << static_cast<int>(encoding); + } + stream << " buffer_view<"; + if (!shape.empty()) { + stream << absl::StrJoin(shape, "x") << "x"; + } + stream << type << ">{"; + size_t element_size = type.element_size(); + auto bytes = data.subspan( + offset, std::min(serialized_element_count, 1024) * element_size); + stream << ConstantToString(type, bytes); + if (serialized_element_count > 1024) stream << "..."; + offset += serialized_element_count * element_size; + stream << "}"; + break; + } + case OperandEncoding::kFunctionOrdinal: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto function_ordinal, + ReadValue<uint32_t>(data, &offset)); + ASSIGN_OR_RETURN( + auto target_function, + function.module()->LookupFunctionByOrdinal( + rt::Function::Linkage::kInternal, function_ordinal)); + stream << "@" << function_ordinal << " " << target_function.name(); + break; + } + case OperandEncoding::kDispatchOrdinal: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto dispatch_ordinal, + ReadValue<uint32_t>(data, &offset)); + ASSIGN_OR_RETURN(auto export_ordinal, + ReadValue<uint16_t>(data, &offset)); + // TODO(benvanik): lookup in executable table. + stream << "@" << dispatch_ordinal << ":" << export_ordinal; + break; + } + case OperandEncoding::kImportOrdinal: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto import_ordinal, + ReadValue<uint32_t>(data, &offset)); + ASSIGN_OR_RETURN(auto target_function, + function.module()->LookupFunctionByOrdinal( + rt::Function::Linkage::kImport, import_ordinal)); + stream << "@i" << import_ordinal << " " << target_function.name(); + break; + } + case OperandEncoding::kBlockOffset: { + ++printed_operand_count; + ASSIGN_OR_RETURN(uint32_t block_offset, + ReadValue<uint32_t>(data, &offset)); + stream << ":" << block_offset; + break; + } + case OperandEncoding::kTypeIndex: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto type, ReadType(data, &offset)); + stream << type; + break; + } + case OperandEncoding::kIndex: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto index, ReadValue<int32_t>(data, &offset)); + stream << "#" << index; + break; + } + case OperandEncoding::kIndexList: { + ++printed_operand_count; + stream << "{"; + ASSIGN_OR_RETURN(int count, ReadCount(data, &offset)); + for (int j = 0; j < count; ++j) { + ASSIGN_OR_RETURN(auto dim, ReadValue<int32_t>(data, &offset)); + if (j > 0) stream << ","; + stream << dim; + } + stream << "}"; + break; + } + case OperandEncoding::kCmpIPredicate: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto predicate_value, + ReadValue<uint8_t>(data, &offset)); + stream << "<" + << PredicateToString( + static_cast<CmpIPredicate>(predicate_value)) + << ">"; + break; + } + case OperandEncoding::kCmpFPredicate: { + ++printed_operand_count; + ASSIGN_OR_RETURN(auto predicate_value, + ReadValue<uint8_t>(data, &offset)); + stream << "<" + << PredicateToString( + static_cast<CmpFPredicate>(predicate_value)) + << ">"; + break; + } + } + } + + stream << "\n"; + + instruction.long_text = stream.str(); + instruction.short_text = instruction.long_text; + } + + return instructions; +} + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_disassembler.h b/vm/bytecode_disassembler.h new file mode 100644 index 0000000..4639777 --- /dev/null +++ b/vm/bytecode_disassembler.h
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_DISASSEMBLER_H_ +#define IREE_VM_BYTECODE_DISASSEMBLER_H_ + +#include <ostream> + +#include "base/status.h" +#include "rt/disassembler.h" +#include "schemas/bytecode_def_generated.h" +#include "schemas/source_map_def_generated.h" +#include "vm/opcode_info.h" + +namespace iree { +namespace vm { + +// Disassembles bytecode with a specific op set to text. +class BytecodeDisassembler final : public rt::Disassembler { + public: + explicit BytecodeDisassembler(OpcodeTable opcode_table) + : opcode_table_(opcode_table) {} + + StatusOr<std::vector<rt::Instruction>> DisassembleInstructions( + const rt::Function& function, rt::SourceOffset offset, + int32_t instruction_count) const override; + + private: + OpcodeTable opcode_table_; +}; + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_DISASSEMBLER_H_
diff --git a/vm/bytecode_module.cc b/vm/bytecode_module.cc new file mode 100644 index 0000000..9ba5e84 --- /dev/null +++ b/vm/bytecode_module.cc
@@ -0,0 +1,310 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_module.h" + +#include "absl/memory/memory.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/buffer_view.h" +#include "vm/bytecode_disassembler.h" + +namespace iree { +namespace vm { + +namespace { + +using ::iree::hal::BufferView; +using ::iree::rt::Function; +using ::iree::rt::FunctionSignature; +using ::iree::rt::Module; +using ::iree::rt::ModuleSignature; + +Status ValidateElementSize(int element_bit_width, + const ElementTypeDef& expected_element_type) { + switch (expected_element_type.type_union_type()) { + case ElementTypeDefUnion::FloatTypeDef: { + auto expected_bit_width = + expected_element_type.type_union_as_FloatTypeDef()->width(); + if (element_bit_width != expected_bit_width) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Has element bit width " << element_bit_width + << " but expected " << expected_bit_width; + } + return OkStatus(); + } + case ElementTypeDefUnion::IntegerTypeDef: { + auto expected_bit_width = + expected_element_type.type_union_as_IntegerTypeDef()->width(); + if (element_bit_width != expected_bit_width) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Has element bit width " << element_bit_width + << " but expected " << expected_bit_width; + } + return OkStatus(); + } + case ElementTypeDefUnion::UnknownTypeDef: + case ElementTypeDefUnion::NONE: { + } + } + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Defined type has unsupported element type " + << EnumNameElementTypeDefUnion( + expected_element_type.type_union_type()); +} + +Status ValidateTypeStructure(const FunctionTypeDef& type_def) { + // Ensure all fields are populated. + return OkStatus(); +} + +Status ValidateFunctionTableStructure( + const FunctionTableDef& function_table_def) { + if (!function_table_def.functions()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Function table is missing the function listing"; + } + + // All functions must contain a valid type. + const auto& functions = *function_table_def.functions(); + for (int i = 0; i < functions.size(); ++i) { + const auto* function = functions[i]; + if (!function) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Function ordinal " << i << " is missing its contents"; + } + if (!function->type()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Function ordinal " << i << " is missing its type"; + } + RETURN_IF_ERROR(ValidateTypeStructure(*function->type())); + } + + // Imports must also have a name (that we can use to resolve it). + if (function_table_def.imports()) { + const auto& imports = *function_table_def.imports(); + for (int i = 0; i < imports.size(); ++i) { + int function_index = imports[i]; + if (!functions[function_index]->name()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Import ordinal " << i << " is missing its contents"; + } + } + } + + // Exports must also have a name (that others will use to look it up). + if (function_table_def.exports()) { + const auto& exports = *function_table_def.exports(); + for (int i = 0; i < exports.size(); ++i) { + int function_index = exports[i]; + if (!functions[function_index]->name()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Export ordinal " << i << " is missing its contents"; + } + } + } + + return OkStatus(); +} + +Status ValidateExecutableTableStructure( + const ExecutableTableDef& executable_table_def) { + if (!executable_table_def.multi_arch_executables()) { + // May have sequencer only fns. Fine to not have dispatchable executables. + return OkStatus(); + } + + // All fat executables need at least one device-specific executable. + const auto& multi_arch_executables = + *executable_table_def.multi_arch_executables(); + for (int i = 0; i < multi_arch_executables.size(); ++i) { + const auto* multi_arch_executable = multi_arch_executables[i]; + if (!multi_arch_executable || !multi_arch_executable->executables() || + multi_arch_executable->executables()->size() == 0) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Multi-arch executable ordinal " << i + << " is missing its contents"; + } + } + + return OkStatus(); +} + +} // namespace + +// static +Status BytecodeModule::ValidateStructure(const ModuleDef& module_def) { + IREE_TRACE_SCOPE0("BytecodeModule::ValidateStructure"); + + // Must have a function table. + if (module_def.function_table()) { + RETURN_IF_ERROR( + ValidateFunctionTableStructure(*module_def.function_table())); + } else { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "ModuleDef is missing a function table"; + } + + // Must have an executable table. + if (module_def.executable_table()) { + RETURN_IF_ERROR( + ValidateExecutableTableStructure(*module_def.executable_table())); + } else { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "ModuleDef is missing an executable table"; + } + + return OkStatus(); +} + +BytecodeModule::BytecodeModule(std::unique_ptr<ModuleFile> module_file, + OpcodeTable opcode_table) + : module_file_(std::move(module_file)), + module_def_(*module_file_->root()), + source_resolver_(SourceMapResolver::FromModule(module_def_)), + disassembler_(absl::make_unique<BytecodeDisassembler>(opcode_table)) {} + +BytecodeModule::~BytecodeModule() = default; + +const ModuleSignature BytecodeModule::signature() const { + return ModuleSignature(function_table_def().imports()->size(), + function_table_def().exports()->size(), + function_table_def().functions()->size(), 0); +} + +std::string BytecodeModule::DebugStringShort() const { + return std::string(name()); +} + +StatusOr<int32_t> BytecodeModule::MapFunctionOrdinal(Function::Linkage linkage, + int32_t ordinal) const { + const auto& function_table = function_table_def(); + switch (linkage) { + case Function::Linkage::kImport: + if (ordinal < 0 || ordinal >= function_table.imports()->size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Import ordinal " << ordinal + << " is outside the valid range [0, " + << function_table.imports()->size() << ")"; + } + ordinal = function_table.imports()->Get(ordinal); + break; + case Function::Linkage::kExport: + if (ordinal < 0 || ordinal >= function_table.exports()->size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Export ordinal " << ordinal + << " is outside the valid range [0, " + << function_table.exports()->size() << ")"; + } + ordinal = function_table.exports()->Get(ordinal); + break; + default: + break; + } + if (ordinal < 0 || ordinal >= function_table.functions()->size()) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Function ordinal " << ordinal + << " is outside the valid range [0, " + << function_table.functions()->size() << ")"; + } + return ordinal; +} + +StatusOr<const Function> BytecodeModule::LookupFunctionByOrdinal( + Function::Linkage linkage, int32_t ordinal) const { + ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); + return Function(this, Function::Linkage::kInternal, ordinal); +} + +StatusOr<const Function> BytecodeModule::LookupFunctionByName( + Function::Linkage linkage, absl::string_view name) const { + const auto& functions = *function_table_def().functions(); + for (int i = 0; i < functions.size(); ++i) { + const auto* function_def = functions.Get(i); + if (WrapString(function_def->name()) == name) { + return LookupFunctionByOrdinal(Function::Linkage::kInternal, i); + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "Function '" << name + << "' not found in function table (or names have been stripped)"; +} + +StatusOr<absl::string_view> BytecodeModule::GetFunctionName( + Function::Linkage linkage, int32_t ordinal) const { + ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); + const auto* function_def = function_table_def().functions()->Get(ordinal); + return WrapString(function_def->name()); +} + +StatusOr<const FunctionSignature> BytecodeModule::GetFunctionSignature( + Function::Linkage linkage, int32_t ordinal) const { + ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); + const auto* function_def = function_table_def().functions()->Get(ordinal); + const auto* type_def = function_def->type(); + return FunctionSignature( + type_def->inputs() ? type_def->inputs()->size() : 0, + type_def->results() ? type_def->results()->size() : 0); +} + +StatusOr<const FunctionDef*> BytecodeModule::GetFunctionDef( + rt::Function::Linkage linkage, int32_t ordinal) const { + ASSIGN_OR_RETURN(ordinal, MapFunctionOrdinal(linkage, ordinal)); + const auto& function_defs = *function_table_def().functions(); + if (ordinal >= function_defs.size()) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Internal function ordinal " << ordinal + << " out of range of table (" << function_defs.size() << ")"; + } + return function_defs.Get(ordinal); +} + +StatusOr<const MultiArchExecutableDef*> +BytecodeModule::LookupMultiArchExecutable(int executable_ordinal) const { + if (executable_ordinal < 0 || + executable_ordinal >= + executable_table_def().multi_arch_executables()->size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Invalid multi-arch executable ordinal " << executable_ordinal; + } + return executable_table_def().multi_arch_executables()->Get( + executable_ordinal); +} + +// static +Status BytecodeModule::ValidateArgType(const BufferView& arg, + const MemRefTypeDef& expected_type) { + RETURN_IF_ERROR( + ValidateElementSize(arg.element_size * 8, *expected_type.element_type())); + + auto expected_shape = expected_type.shape(); + if (arg.shape.size() != expected_shape->size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Argument should have rank " << expected_shape->size() + << " but has rank " << arg.shape.size(); + } + for (int i = 0; i < expected_shape->size(); ++i) { + auto dim_size = arg.shape[i]; + auto expected_dim_size = expected_shape->Get(i); + if (dim_size != expected_dim_size) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Argument dimension " << i << " should have size " + << expected_dim_size << " but has size " << dim_size; + } + } + return OkStatus(); +} + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_module.h b/vm/bytecode_module.h new file mode 100644 index 0000000..b2df6d9 --- /dev/null +++ b/vm/bytecode_module.h
@@ -0,0 +1,103 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_MODULE_H_ +#define IREE_VM_BYTECODE_MODULE_H_ + +#include <memory> + +#include "base/flatbuffer_util.h" +#include "rt/function.h" +#include "rt/module.h" +#include "schemas/executable_table_def_generated.h" +#include "schemas/function_table_def_generated.h" +#include "schemas/module_def_generated.h" +#include "vm/opcode_info.h" +#include "vm/source_map_resolver.h" + +namespace iree { +namespace vm { + +using ModuleFile = FlatBufferFile<ModuleDef>; + +// A loaded bytecode module backed by a FlatBuffer. +class BytecodeModule : public rt::Module { + public: + static Status ValidateStructure(const ModuleDef& module_def); + + ~BytecodeModule() override; + + const ModuleDef& def() const { return module_def_; } + const FunctionTableDef& function_table_def() const { + return *module_def_.function_table(); + } + const ExecutableTableDef& executable_table_def() const { + return *module_def_.executable_table(); + } + + absl::string_view name() const override { + return WrapString(module_def_.name()); + } + + const rt::ModuleSignature signature() const override; + + rt::SourceResolver* source_resolver() const override { + return &source_resolver_; + } + + rt::Disassembler* disassembler() const override { + return disassembler_.get(); + } + + std::string DebugStringShort() const override; + + StatusOr<const rt::Function> LookupFunctionByOrdinal( + rt::Function::Linkage linkage, int32_t ordinal) const override; + + StatusOr<const rt::Function> LookupFunctionByName( + rt::Function::Linkage linkage, absl::string_view name) const override; + + StatusOr<absl::string_view> GetFunctionName(rt::Function::Linkage linkage, + int32_t ordinal) const override; + + StatusOr<const rt::FunctionSignature> GetFunctionSignature( + rt::Function::Linkage linkage, int32_t ordinal) const override; + + StatusOr<const FunctionDef*> GetFunctionDef(rt::Function::Linkage linkage, + int32_t ordinal) const; + + StatusOr<const MultiArchExecutableDef*> LookupMultiArchExecutable( + int executable_ordinal) const; + + protected: + BytecodeModule(std::unique_ptr<ModuleFile> module_file, + OpcodeTable opcode_table); + + static Status ValidateArgType(const hal::BufferView& arg, + const MemRefTypeDef& expected_type); + + private: + StatusOr<int32_t> MapFunctionOrdinal(rt::Function::Linkage linkage, + int32_t ordinal) const; + + std::unique_ptr<ModuleFile> module_file_; + const ModuleDef& module_def_; + mutable SourceMapResolver source_resolver_; + mutable std::unique_ptr<rt::Disassembler> disassembler_; +}; + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_MODULE_H_
diff --git a/vm/bytecode_reader.cc b/vm/bytecode_reader.cc new file mode 100644 index 0000000..7730857 --- /dev/null +++ b/vm/bytecode_reader.cc
@@ -0,0 +1,289 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_reader.h" + +#include "base/shape.h" +#include "base/status.h" +#include "hal/heap_buffer.h" +#include "vm/bytecode_module.h" + +namespace iree { +namespace vm { + +namespace { + +using ::iree::hal::BufferView; +using ::iree::rt::StackFrame; + +} // namespace + +StatusOr<const uint8_t*> BytecodeReader::AdvanceOffset() { + *stack_frame_->mutable_offset() = offset(); + // TODO(benvanik): make a flag and/or remove. + DVLOG(1) << "dispatch(" << stack_frame_->function().name() << "@" << offset() + << "): " << int(*bytecode_pc_); + for (int i = 0; i < registers_->buffer_views.size(); ++i) { + DVLOG(1) << "local[" << i << "] " + << registers_->buffer_views[i].DebugStringShort(); + } + return bytecode_pc_++; +} + +Status BytecodeReader::SkipLocals(int count) { + size_t stride = sizeof(uint16_t) * count; + if (bytecode_pc_ + stride >= bytecode_limit_) { + return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow"; + } + bytecode_pc_ += stride; + return OkStatus(); +} + +Status BytecodeReader::ReadShape(Shape* out_shape) { + ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList()); + *out_shape = Shape(shape_dims); + return OkStatus(); +} + +StatusOr<Shape> BytecodeReader::ReadShapePieces() { + // TODO(benvanik): rewrite to be faster (multiple offsets to walk both lists). + ASSIGN_OR_RETURN(auto shape_dims, ReadIndexList()); + if (shape_dims.size() >= kMaxRank) { + return UnimplementedErrorBuilder(IREE_LOC) + << "Shapes limited to rank " << kMaxRank << " right now"; + } + int expected_dynamic_dims = 0; + for (int i = 0; i < shape_dims.size(); ++i) { + if (shape_dims[i] == -1) { + ++expected_dynamic_dims; + } + } + + Shape shape(shape_dims); + ASSIGN_OR_RETURN(int dynamic_dims, ReadCount()); + if (dynamic_dims != expected_dynamic_dims) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Expected " << expected_dynamic_dims << " dynamic dims but only " + << dynamic_dims << " provided"; + } else if (dynamic_dims) { + for (int i = 0; i < shape_dims.size(); ++i) { + if (shape_dims[i] != -1) { + continue; + } + // TODO(benvanik): kill this embarrassment. + ASSIGN_OR_RETURN(auto dims_piece, ReadSlotElements<int32_t>()); + if (dims_piece.size() != 1) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Dims piece has rank " << dims_piece.size() << "; must be 1"; + } + shape[i] = dims_piece[0]; + } + } + return shape; +} + +StatusOr<Shape> BytecodeReader::ReadShapePieces(size_t* out_element_count) { + ASSIGN_OR_RETURN(auto shape, ReadShapePieces()); + *out_element_count = shape.element_count(); + return shape; +} + +StatusOr<absl::Span<const int32_t>> BytecodeReader::ReadIndexList() { + ASSIGN_OR_RETURN(int count, ReadCount()); + int stride = count * sizeof(int32_t); + if (bytecode_pc_ + stride >= bytecode_limit_) { + return OutOfRangeErrorBuilder(IREE_LOC) << "Bytecode underflow"; + } + auto list = absl::Span<const int32_t>( + reinterpret_cast<const int32_t*>(bytecode_pc_), count); + bytecode_pc_ += stride; + return list; +} + +Status BytecodeReader::SwitchStackFrame(StackFrame* new_stack_frame) { + // Flush old state. + auto* old_stack_frame = stack_frame_; + if (old_stack_frame) { + *old_stack_frame->mutable_offset() = offset(); + } + + // Switch the frame. The FiberState holds the full stack, this is just the + // current one for easy access. + stack_frame_ = new_stack_frame; + + // Setup state pointers for faster dereferencing. + const auto& function = new_stack_frame->function(); + ASSIGN_OR_RETURN( + const auto* function_def, + static_cast<const BytecodeModule*>(function.module()) + ->GetFunctionDef(function.linkage(), function.ordinal())); + const auto& bytecode = *function_def->bytecode(); + bytecode_base_ = bytecode.contents()->Data(); + bytecode_limit_ = bytecode_base_ + bytecode.contents()->size(); + bytecode_pc_ = bytecode_base_ + new_stack_frame->offset(); + registers_ = new_stack_frame->mutable_registers(); + return OkStatus(); +} + +Status BytecodeReader::CopyInputsAndSwitchStackFrame( + StackFrame* src_stack_frame, StackFrame* dst_stack_frame) { + ASSIGN_OR_RETURN(size_t src_count, ReadCount()); + auto& dst_buffer_views = dst_stack_frame->mutable_registers()->buffer_views; + for (int i = 0; i < std::min(src_count, dst_buffer_views.size()); ++i) { + ASSIGN_OR_RETURN(auto* src_local, + ReadLocal(src_stack_frame->mutable_registers())); + dst_buffer_views[i] = *src_local; + } + return SwitchStackFrame(dst_stack_frame); +} + +Status BytecodeReader::CopyResultsAndSwitchStackFrame( + StackFrame* src_stack_frame, StackFrame* dst_stack_frame) { + ASSIGN_OR_RETURN(int32_t src_count, ReadCount()); + // TODO(benvanik): avoid vector. + absl::InlinedVector<BufferView*, 8> src_locals(src_count); + for (int i = 0; i < src_count; ++i) { + ASSIGN_OR_RETURN(src_locals[i], + ReadLocal(src_stack_frame->mutable_registers())); + } + RETURN_IF_ERROR(SwitchStackFrame(dst_stack_frame)); + ASSIGN_OR_RETURN(int32_t dst_count, ReadCount()); + if (src_count != dst_count) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Src and dst value counts differ: " << src_count << " vs " + << dst_count; + } + for (int i = 0; i < dst_count; ++i) { + ASSIGN_OR_RETURN(auto* dst_local, + ReadLocal(dst_stack_frame->mutable_registers())); + *dst_local = *src_locals[i]; + } + return OkStatus(); +} + +Status BytecodeReader::CopySlots() { + ASSIGN_OR_RETURN(int32_t count, ReadCount()); + for (int i = 0; i < count; ++i) { + ASSIGN_OR_RETURN(auto* src_local, + ReadLocal(stack_frame_->mutable_registers())); + ASSIGN_OR_RETURN(auto* dst_local, + ReadLocal(stack_frame_->mutable_registers())); + *dst_local = *src_local; + } + return OkStatus(); +} + +Status BytecodeReader::BranchToOffset(int32_t offset) { + const uint8_t* new_bytecode_pc = bytecode_base_ + offset; + if (new_bytecode_pc < bytecode_base_ || new_bytecode_pc > bytecode_limit_) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Branch target " << offset + << " is out of bounds of the function bytecode (" + << static_cast<size_t>(bytecode_limit_ - bytecode_base_) + << "b total)"; + } + bytecode_pc_ = new_bytecode_pc; + return OkStatus(); +} + +StatusOr<BufferView> BytecodeReader::ReadConstant() { + BufferView buffer_view; + + // Element type defines the buffer_view size (but we don't really care about + // the data format). + ASSIGN_OR_RETURN(auto element_type, ReadType()); + buffer_view.element_size = element_type.element_size(); + + // Parse shape - constants always define a full shape. + RETURN_IF_ERROR(ReadShape(&buffer_view.shape)); + + // Read encoding to determine how the constant data is stored in the file. + ASSIGN_OR_RETURN(auto encoding, ReadValue<ConstantEncoding>()); + + // Get buffer for the constant data. + switch (encoding) { + case ConstantEncoding::kDense: { + // Validate we have all constant data present. + device_size_t serialized_length = buffer_view.byte_length(); + if (bytecode_pc_ + serialized_length >= bytecode_limit_) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Constant data out of bounds"; + } + + buffer_view.buffer = hal::HeapBuffer::Wrap( + hal::MemoryType::kHostLocal, hal::BufferUsage::kAll, bytecode_pc_, + serialized_length); + bytecode_pc_ += serialized_length; + break; + } + case ConstantEncoding::kSplat: { + // Validate we have at least one element worth of data in the buffer. + if (bytecode_pc_ + buffer_view.element_size >= bytecode_limit_) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Constant data out of bounds"; + } + + // TODO(benvanik): replace with fancy constant pool and such. + // NOTE: this is not much different than if a alloc_heap+broadcast pair + // had been in the IR. + buffer_view.buffer = hal::HeapBuffer::Allocate( + hal::MemoryType::kHostLocal, hal::BufferUsage::kAll, + buffer_view.byte_length()); + switch (buffer_view.element_size) { + case 1: { + uint8_t value = *reinterpret_cast<const uint8_t*>(bytecode_pc_); + RETURN_IF_ERROR(buffer_view.buffer->Fill8(value)); + break; + } + case 2: { + uint16_t value = *reinterpret_cast<const uint16_t*>(bytecode_pc_); + RETURN_IF_ERROR(buffer_view.buffer->Fill16(value)); + break; + } + case 4: { + uint32_t value = *reinterpret_cast<const uint32_t*>(bytecode_pc_); + RETURN_IF_ERROR(buffer_view.buffer->Fill32(value)); + break; + } + case 8: { + // TODO(benvanik): add Fill64. + uint64_t value = *reinterpret_cast<const uint64_t*>(bytecode_pc_); + ASSIGN_OR_RETURN(auto mapping, + buffer_view.buffer->MapMemory<uint64_t>( + hal::MemoryAccess::kDiscardWrite)); + auto mapped_data = mapping.mutable_contents(); + for (int i = 0; i < mapping.size(); ++i) { + mapped_data[i] = value; + } + break; + } + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented splat element stride " + << buffer_view.element_size; + } + bytecode_pc_ += buffer_view.element_size; + break; + } + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented constant encoding " + << static_cast<int>(encoding); + } + + return buffer_view; +} + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_reader.h b/vm/bytecode_reader.h new file mode 100644 index 0000000..499d679 --- /dev/null +++ b/vm/bytecode_reader.h
@@ -0,0 +1,169 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_READER_H_ +#define IREE_VM_BYTECODE_READER_H_ + +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "rt/context.h" +#include "rt/stack.h" +#include "rt/stack_frame.h" +#include "schemas/bytecode/bytecode_v0.h" +#include "vm/type.h" + +namespace iree { +namespace vm { + +class BytecodeReader { + public: + explicit BytecodeReader(rt::Stack* stack) : stack_(stack) {} + + int offset() const { return static_cast<int>(bytecode_pc_ - bytecode_base_); } + + StatusOr<const uint8_t*> AdvanceOffset(); + + Status SwitchStackFrame(rt::StackFrame* new_stack_frame); + Status BranchToOffset(int32_t offset); + + Status CopyInputsAndSwitchStackFrame(rt::StackFrame* src_stack_frame, + rt::StackFrame* dst_stack_frame); + Status CopyResultsAndSwitchStackFrame(rt::StackFrame* src_stack_frame, + rt::StackFrame* dst_stack_frame); + Status CopySlots(); + + StatusOr<hal::BufferView> ReadConstant(); + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int> ReadCount() { + return ReadValue<uint8_t>(); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const Type> ReadType() { + ASSIGN_OR_RETURN(uint8_t type_index, ReadValue<uint8_t>()); + return Type::FromTypeIndex(type_index); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function> ReadFunction() { + ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>()); + const auto& module = stack_frame_->module(); + return module.LookupFunctionByOrdinal(rt::Function::Linkage::kInternal, + value); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<const rt::Function> + ReadImportFunction() { + ASSIGN_OR_RETURN(auto value, ReadValue<uint32_t>()); + const auto& module = stack_frame_->module(); + return stack_->context()->ResolveImport(&module, value); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal( + rt::Registers* registers) { + ASSIGN_OR_RETURN(auto value, ReadValue<uint16_t>()); + if (value > registers->buffer_views.size()) { + return OutOfRangeErrorBuilder(IREE_LOC) + << "Out of bounds local access " << value << " of " + << registers->buffer_views.size(); + } + return ®isters->buffer_views[value]; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<hal::BufferView*> ReadLocal() { + return ReadLocal(registers_); + } + + Status SkipLocals(int count); + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint8_t> ReadUint8_t() { + return ReadValue<uint8_t>(); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint16_t> ReadUint16_t() { + return ReadValue<uint16_t>(); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<int32_t> ReadInt32() { + return ReadValue<int32_t>(); + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<uint32_t> ReadBlockOffset() { + return ReadValue<uint32_t>(); + } + + template <typename T, size_t N = 8> + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<absl::InlinedVector<T, N>> + ReadSlotElements() { + ASSIGN_OR_RETURN(auto* local, ReadLocal(registers_)); + absl::InlinedVector<T, N> result(local->shape.element_count()); + if (sizeof(T) == local->element_size) { + // Fast(ish) path: requested element size matches the actual element size. + RETURN_IF_ERROR( + local->buffer->ReadData(0, result.data(), result.size() * sizeof(T))); + } else { + // Slow path: need to convert the data. + switch (local->element_size) { + case 4: { + ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int32_t>( + hal::MemoryAccess::kRead)); + for (size_t i = 0; i < result.size(); ++i) { + result[i] = static_cast<T>(mapping[i]); + } + break; + } + case 8: { + ASSIGN_OR_RETURN(auto mapping, local->buffer->MapMemory<int64_t>( + hal::MemoryAccess::kRead)); + for (size_t i = 0; i < result.size(); ++i) { + result[i] = static_cast<T>(mapping[i]); + } + break; + } + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unsupported local element size: " << local->element_size; + } + } + return result; + } + + Status ReadShape(Shape* out_shape); + + StatusOr<Shape> ReadShapePieces(); + StatusOr<Shape> ReadShapePieces(size_t* out_element_count); + + StatusOr<absl::Span<const int32_t>> ReadIndexList(); + + private: + template <typename T> + ABSL_ATTRIBUTE_ALWAYS_INLINE StatusOr<T> ReadValue() { + // TODO(benvanik): validate bounds. + T value = *reinterpret_cast<const T*>(bytecode_pc_); + bytecode_pc_ += sizeof(T); + return value; + } + + rt::Stack* stack_ = nullptr; + rt::StackFrame* stack_frame_ = nullptr; + const uint8_t* bytecode_base_ = nullptr; + const uint8_t* bytecode_limit_ = nullptr; + const uint8_t* bytecode_pc_ = nullptr; + rt::Registers* registers_ = nullptr; +}; + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_READER_H_
diff --git a/vm/bytecode_tables_interpreter.cc b/vm/bytecode_tables_interpreter.cc new file mode 100644 index 0000000..5744bf9 --- /dev/null +++ b/vm/bytecode_tables_interpreter.cc
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_tables_interpreter.h" + +#include "schemas/bytecode/interpreter_bytecode_v0.h" + +namespace iree { +namespace vm { + +namespace { + +// Info table mapping 1:1 with bytecode ops. +// +// Note that we ensure the table is 256 elements long exactly to make sure +// that unused opcodes are handled gracefully. +static const OpcodeInfo kInfoTable[256] = { +#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \ + OpcodeInfo{ \ + name, \ + flags, \ + {operand_encodings}, \ + }, + IREE_INTERPRETER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO) +#undef DECLARE_INFO +}; + +} // namespace + +OpcodeTable interpreter_opcode_table() { return kInfoTable; } + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_tables_interpreter.h b/vm/bytecode_tables_interpreter.h new file mode 100644 index 0000000..6407e84 --- /dev/null +++ b/vm/bytecode_tables_interpreter.h
@@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_TABLES_INTERPRETER_H_ +#define IREE_VM_BYTECODE_TABLES_INTERPRETER_H_ + +#include "vm/opcode_info.h" + +namespace iree { +namespace vm { + +OpcodeTable interpreter_opcode_table(); + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_TABLES_INTERPRETER_H_
diff --git a/vm/bytecode_tables_sequencer.cc b/vm/bytecode_tables_sequencer.cc new file mode 100644 index 0000000..02880dc --- /dev/null +++ b/vm/bytecode_tables_sequencer.cc
@@ -0,0 +1,44 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_tables_sequencer.h" + +#include "schemas/bytecode/sequencer_bytecode_v0.h" + +namespace iree { +namespace vm { + +namespace { + +// Info table mapping 1:1 with bytecode ops. +// +// Note that we ensure the table is 256 elements long exactly to make sure +// that unused opcodes are handled gracefully. +static const OpcodeInfo kInfoTable[256] = { +#define DECLARE_INFO(ordinal, enum_value, name, flags, operand_encodings, ...) \ + OpcodeInfo{ \ + name, \ + flags, \ + {operand_encodings}, \ + }, + IREE_SEQUENCER_OPCODE_LIST(DECLARE_INFO, DECLARE_INFO) +#undef DECLARE_INFO +}; + +} // namespace + +OpcodeTable sequencer_opcode_table() { return kInfoTable; } + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_tables_sequencer.h b/vm/bytecode_tables_sequencer.h new file mode 100644 index 0000000..0ead8f6 --- /dev/null +++ b/vm/bytecode_tables_sequencer.h
@@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_TABLES_SEQUENCER_H_ +#define IREE_VM_BYTECODE_TABLES_SEQUENCER_H_ + +#include "vm/opcode_info.h" + +namespace iree { +namespace vm { + +OpcodeTable sequencer_opcode_table(); + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_TABLES_SEQUENCER_H_
diff --git a/vm/bytecode_util.cc b/vm/bytecode_util.cc new file mode 100644 index 0000000..48deb32 --- /dev/null +++ b/vm/bytecode_util.cc
@@ -0,0 +1,43 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_util.h" + +namespace iree { +namespace vm { + +absl::string_view PredicateToString(CmpIPredicate p) { +#define PRED(index, name, str, ...) \ + case CmpIPredicate::name: \ + return str; + switch (p) { + IREE_CMPI_PREDICATE_LIST(PRED) +#undef PRED + } + return "<unknown>"; +} + +absl::string_view PredicateToString(CmpFPredicate p) { +#define PRED(index, name, str, ...) \ + case CmpFPredicate::name: \ + return str; + switch (p) { + IREE_CMPF_PREDICATE_LIST(PRED) +#undef PRED + } + return "<unknown>"; +} + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_util.h b/vm/bytecode_util.h new file mode 100644 index 0000000..560ab36 --- /dev/null +++ b/vm/bytecode_util.h
@@ -0,0 +1,31 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_UTIL_H_ +#define IREE_VM_BYTECODE_UTIL_H_ + +#include "absl/strings/string_view.h" +#include "schemas/bytecode/bytecode_v0.h" + +namespace iree { +namespace vm { + +absl::string_view PredicateToString(CmpIPredicate predicate); + +absl::string_view PredicateToString(CmpFPredicate predicate); + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_UTIL_H_
diff --git a/vm/bytecode_validator.cc b/vm/bytecode_validator.cc new file mode 100644 index 0000000..54e8a82 --- /dev/null +++ b/vm/bytecode_validator.cc
@@ -0,0 +1,28 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/bytecode_validator.h" + +namespace iree { +namespace vm { + +// static +Status BytecodeValidator::Validate(const BytecodeModule& module, + const BytecodeDef& bytecode_def) { + // TODO(benvanik): validate bytecode. + return OkStatus(); +} + +} // namespace vm +} // namespace iree
diff --git a/vm/bytecode_validator.h b/vm/bytecode_validator.h new file mode 100644 index 0000000..1ebb580 --- /dev/null +++ b/vm/bytecode_validator.h
@@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_BYTECODE_VALIDATOR_H_ +#define IREE_VM_BYTECODE_VALIDATOR_H_ + +#include "base/status.h" +#include "schemas/bytecode_def_generated.h" +#include "vm/bytecode_module.h" + +namespace iree { +namespace vm { + +// Validates bytecode such that success indicates the bytecode does not +// reference undefined types, functions, or required imports and all imports can +// be resolved with matching signatures. +class BytecodeValidator { + public: + static Status Validate(const BytecodeModule& module, + const BytecodeDef& bytecode_def); +}; + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_BYTECODE_VALIDATOR_H_
diff --git a/vm/opcode_info.h b/vm/opcode_info.h new file mode 100644 index 0000000..1b61366 --- /dev/null +++ b/vm/opcode_info.h
@@ -0,0 +1,45 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_OPCODE_INFO_H_ +#define IREE_VM_OPCODE_INFO_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "schemas/bytecode/bytecode_v0.h" + +namespace iree { +namespace vm { + +struct OpcodeInfo { + const char* mnemonic; + OpcodeFlagBitfield flag; + union { + const char operands_value[8]; + const OperandEncoding operands[8]; + }; +}; + +using OpcodeTable = absl::Span<const OpcodeInfo>; + +template <typename T> +inline const OpcodeInfo& GetOpcodeInfo(OpcodeTable opcode_table, T opcode) { + return opcode_table[static_cast<uint8_t>(opcode)]; +} + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_OPCODE_INFO_H_
diff --git a/vm/sequencer_dispatch.cc b/vm/sequencer_dispatch.cc new file mode 100644 index 0000000..d8e8fa0 --- /dev/null +++ b/vm/sequencer_dispatch.cc
@@ -0,0 +1,561 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Implements a full bytecode dispatch system for sequencer ops. +// TODO(benvanik): rework to be async against CommandBuffers. + +#include "vm/sequencer_dispatch.h" + +#include <algorithm> + +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_join.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "base/logging.h" +#include "base/memory.h" +#include "base/status.h" +#include "hal/buffer_view.h" +#include "hal/command_queue.h" +#include "hal/device.h" +#include "hal/heap_buffer.h" +#include "schemas/bytecode/sequencer_bytecode_v0.h" +#include "vm/bytecode_module.h" +#include "vm/bytecode_reader.h" +#include "vm/bytecode_tables_sequencer.h" +#include "vm/bytecode_util.h" +#include "vm/opcode_info.h" + +namespace iree { +namespace vm { + +namespace { + +using ::iree::hal::Buffer; +using ::iree::hal::BufferView; + +// TODO(benvanik): remove (this should happen via predication). +bool BufferViewIsTrue(const BufferView& buffer_view) { + if (buffer_view.element_size == 0 || !buffer_view.buffer || + buffer_view.byte_length() == 0) { + return false; + } + // TODO(benvanik): map more efficiently (based on element size?). + auto mapping = + buffer_view.buffer->MapMemory<uint8_t>(hal::MemoryAccess::kRead); + if (!mapping.ok()) { + return false; + } + for (uint8_t value : mapping.ValueOrDie().contents()) { + if (value) return true; + } + return false; +} + +// TODO(benvanik): insert fence callbacks and wait on fence. +Status CallExternalFunction(rt::Stack* stack, const rt::Function& function) { + // Marshal inputs and outputs. + const auto* stack_frame = stack->current_frame(); + auto buffer_views = absl::MakeSpan(stack_frame->registers().buffer_views); + absl::InlinedVector<hal::BufferView, 8> arguments( + buffer_views.begin(), + buffer_views.begin() + function.signature().argument_count()); + absl::InlinedVector<hal::BufferView, 8> results( + buffer_views.begin() + arguments.size(), buffer_views.end()); + return function.module()->Execute(stack, function, std::move(arguments), + &results); +} + +// Pretty prints an array, e.g. [1, 2, 3, 4] +inline std::string PrettyPrint(absl::Span<const int32_t> arr) { + return "[" + absl::StrJoin(arr, ",") + "]"; +} + +// Calculates the byte offset into a buffer corresponding to the indices in the +// given shape. +StatusOr<device_size_t> CalculateOffset(absl::Span<const int32_t> indices, + Shape shape, uint8_t element_size) { + if (shape.empty() || indices.size() > shape.size()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Indices " << PrettyPrint(indices) << " out of bounds of shape " + << PrettyPrint(shape.subspan()); + } + device_size_t offset = 0; + for (int i = 0; i < indices.size(); ++i) { + if (indices[i] >= shape[i]) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Indices[" << i << "]=" << indices[i] + << " out of bounds of shape " << PrettyPrint(shape.subspan()); + } + device_size_t axis_offset = indices[i]; + for (int j = i + 1; j < shape.size(); ++j) { + axis_offset *= shape[j]; + } + offset += axis_offset; + } + offset *= element_size; + return offset; +} + +} // namespace + +Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack, + rt::StackFrame* entry_stack_frame, + absl::Span<BufferView> entry_results) { + // Dispatch table mapping 1:1 with bytecode ops. + // Each entry is a label within this function that can be used for computed + // goto. You can find more information on computed goto here: + // https://eli.thegreenplace.net/2012/07/12/computed-goto-for-efficient-dispatch-tables + // + // Note that we ensure the table is 256 elements long exactly to make sure + // that unused opcodes are handled gracefully. + static const void* kDispatchTable[256] = { +#define DECLARE_DISPATCH(ordinal, name, ...) &&_dispatch_##name, +#define DECLARE_DISPATCH_RESERVED(ordinal, name, ...) &&_dispatch_unhandled, + IREE_SEQUENCER_OPCODE_LIST(DECLARE_DISPATCH, DECLARE_DISPATCH_RESERVED) +#undef DECLARE_DISPATCH +#undef DECLARE_DISPATCH_RESERVED + }; + + // Primary dispatch state. This is our 'native stack frame' and really just + // enough to make dereferencing common addresses (like the current offset) + // faster. You can think of this like CPU state (like PC). + // + // We hope that LLVM decides to keep these in registers (as they are touched + // for every instruction executed). The stack_frame will change as we call + // into different functions. + BytecodeReader reader(stack); + RETURN_IF_ERROR(reader.SwitchStackFrame(entry_stack_frame)); + +#define DISPATCH_NEXT() \ + { \ + uint8_t opcode = *reader.AdvanceOffset().ValueOrDie(); \ + DVLOG(1) << "Sequencer dispatching op code: " \ + << GetOpcodeInfo(sequencer_opcode_table(), opcode).mnemonic; \ + goto* kDispatchTable[opcode]; \ + } + +#define DISPATCH_CORE_OPCODE(opcode, body) \ + _dispatch_##opcode : {body} DISPATCH_NEXT() + + DISPATCH_NEXT(); + + DISPATCH_CORE_OPCODE(kConstant, { + ASSIGN_OR_RETURN(auto value, reader.ReadConstant()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + // TODO(b/139121143): until we have full command buffers we need to do this. + ASSIGN_OR_RETURN(value.buffer, + placement.device->allocator()->AllocateConstant( + hal::BufferUsage::kConstant | hal::BufferUsage::kAll, + std::move(value.buffer))); + *dst_local = std::move(value); + }); + + DISPATCH_CORE_OPCODE(kCall, { + auto* old_stack_frame = stack->current_frame(); + ASSIGN_OR_RETURN(const auto& target_function, reader.ReadFunction()); + // TODO(benvanik): rework register storage interface. + ASSIGN_OR_RETURN( + const auto* function_def, + static_cast<const BytecodeModule*>(target_function.module()) + ->GetFunctionDef(target_function.linkage(), + target_function.ordinal())); + ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function)); + new_stack_frame->mutable_registers()->buffer_views.resize( + function_def->bytecode()->local_count()); + RETURN_IF_ERROR( + reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame)); + DVLOG(1) << "Call; stack now: " << stack->DebugString(); + }); + + DISPATCH_CORE_OPCODE(kCallImport, { + auto* old_stack_frame = stack->current_frame(); + ASSIGN_OR_RETURN(const auto& target_function, reader.ReadImportFunction()); + ASSIGN_OR_RETURN(auto* new_stack_frame, stack->PushFrame(target_function)); + // TODO(benvanik): rework register storage interface. + const auto& signature = target_function.signature(); + new_stack_frame->mutable_registers()->buffer_views.resize( + signature.argument_count() + signature.result_count()); + RETURN_IF_ERROR( + reader.CopyInputsAndSwitchStackFrame(old_stack_frame, new_stack_frame)); + DVLOG(1) << "Call native import; stack now: " << stack->DebugString(); + RETURN_IF_ERROR(CallExternalFunction(stack, target_function)); + RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame, + new_stack_frame)); + RETURN_IF_ERROR(stack->PopFrame()); + DVLOG(1) << "Return from native; stack now: " << stack->DebugString(); + }); + + DISPATCH_CORE_OPCODE(kCallIndirect, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented call_indirect"; + }); + + DISPATCH_CORE_OPCODE(kReturn, { + auto* old_stack_frame = stack->current_frame(); + auto* new_stack_frame = stack->caller_frame(); + if (old_stack_frame == entry_stack_frame) { + // Returning from entry function. Marshal results from the return stmt. + ASSIGN_OR_RETURN(int32_t src_count, reader.ReadCount()); + for (int i = 0; i < src_count; ++i) { + ASSIGN_OR_RETURN( + auto* src_local, + reader.ReadLocal(old_stack_frame->mutable_registers())); + entry_results[i] = std::move(*src_local); + } + DVLOG(1) << "Returning to entry"; + return OkStatus(); + } else if (!new_stack_frame) { + return FailedPreconditionErrorBuilder(IREE_LOC) << "Stack underflow"; + } + RETURN_IF_ERROR(reader.CopyResultsAndSwitchStackFrame(old_stack_frame, + new_stack_frame)); + RETURN_IF_ERROR(stack->PopFrame()); + DVLOG(1) << "Return; stack now: " << stack->DebugString(); + }); + + DISPATCH_CORE_OPCODE(kBranch, { + ASSIGN_OR_RETURN(int32_t offset, reader.ReadBlockOffset()); + RETURN_IF_ERROR(reader.CopySlots()); + RETURN_IF_ERROR(reader.BranchToOffset(offset)); + }); + + DISPATCH_CORE_OPCODE(kCondBranch, { + // Evaluate condition first so we can do the copies as we read them for + // which side of the branch we take. + ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); + bool cond_value = BufferViewIsTrue(*cond_local); + ASSIGN_OR_RETURN(int32_t true_offset, reader.ReadBlockOffset()); + + if (cond_value) { + RETURN_IF_ERROR(reader.CopySlots()); + RETURN_IF_ERROR(reader.BranchToOffset(true_offset)); + } else { + ASSIGN_OR_RETURN(int32_t true_op_count, reader.ReadCount()); + RETURN_IF_ERROR(reader.SkipLocals(2 * true_op_count)); + ASSIGN_OR_RETURN(int32_t false_offset, reader.ReadBlockOffset()); + + RETURN_IF_ERROR(reader.CopySlots()); + RETURN_IF_ERROR(reader.BranchToOffset(false_offset)); + } + }); + + DISPATCH_CORE_OPCODE(kDynamicDispatch, { + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented dynamic_dispatch"; + }); + + DISPATCH_CORE_OPCODE(kStaticDispatch, { + // TODO(benvanik): the real sequencer :) + ASSIGN_OR_RETURN(auto dispatch_ordinal, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto export_ordinal, reader.ReadUint16_t()); + ASSIGN_OR_RETURN( + const auto* multi_arch_executable_def, + static_cast<const BytecodeModule&>(stack->current_frame()->module()) + .LookupMultiArchExecutable(dispatch_ordinal)); + if (export_ordinal >= multi_arch_executable_def->entry_point_count()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Invalid executable export ordinal " << export_ordinal; + } + auto* executable_def = multi_arch_executable_def->executables()->Get(0); + hal::ExecutableSpec executable_spec; + executable_spec.format = executable_def->format(); + executable_spec.executable_data = absl::Span<const uint8_t>( + executable_def->contents()->data(), executable_def->contents()->size()); + auto executable_cache = placement.device->CreateExecutableCache(); + ref_ptr<hal::Executable> executable; + for (auto* executable_def : *multi_arch_executable_def->executables()) { + if (!executable_cache->CanPrepareFormat(executable_def->format())) { + continue; + } + hal::ExecutableSpec executable_spec; + executable_spec.format = executable_def->format(); + executable_spec.executable_data = + absl::Span<const uint8_t>(executable_def->contents()->data(), + executable_def->contents()->size()); + ASSIGN_OR_RETURN(executable, + executable_cache->PrepareExecutable( + hal::ExecutableCachingMode::kDefault | + hal::ExecutableCachingMode::kAliasProvidedData, + executable_spec), + _.LogError()); + break; + } + if (!executable) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "No executable found for the current driver"; + } + + ASSIGN_OR_RETURN(int workload_x, reader.ReadInt32()); + ASSIGN_OR_RETURN(int workload_y, reader.ReadInt32()); + ASSIGN_OR_RETURN(int workload_z, reader.ReadInt32()); + + std::vector<hal::BufferBinding> bindings; + ASSIGN_OR_RETURN(int input_count, reader.ReadCount()); + for (int i = 0; i < input_count; ++i) { + ASSIGN_OR_RETURN(auto* input_local, reader.ReadLocal()); + bindings.push_back(hal::BufferBinding( + input_local->buffer->allowed_access() & hal::MemoryAccess::kAll, + *input_local)); + } + ASSIGN_OR_RETURN(int output_count, reader.ReadCount()); + for (int i = 0; i < output_count; ++i) { + ASSIGN_OR_RETURN(auto* output_local, reader.ReadLocal()); + bindings.push_back( + hal::BufferBinding(hal::MemoryAccess::kWrite, *output_local)); + } + ASSIGN_OR_RETURN(int result_count, reader.ReadCount()); + CHECK_EQ(0, result_count) << "Results not yet implemented"; + + ASSIGN_OR_RETURN( + auto cmd, + placement.device->CreateCommandBuffer( + hal::CommandBufferMode::kOneShot, + hal::CommandCategory::kTransfer | hal::CommandCategory::kDispatch), + _.LogError()); + RETURN_IF_ERROR(cmd->Begin()); + hal::DispatchRequest dispatch_request; + dispatch_request.executable = executable.get(); + dispatch_request.entry_point = export_ordinal; + dispatch_request.workload[0] = workload_x; + dispatch_request.workload[1] = workload_y; + dispatch_request.workload[2] = workload_z; + dispatch_request.bindings = bindings; + RETURN_IF_ERROR(cmd->Dispatch(dispatch_request)); + RETURN_IF_ERROR(cmd->End()); + auto* cmd_ptr = cmd.get(); + + auto* queue = placement.device->dispatch_queues().front(); + hal::SubmissionBatch batch; + batch.command_buffers = absl::MakeConstSpan(&cmd_ptr, 1); + ASSIGN_OR_RETURN(auto fence, placement.device->CreateFence(0u)); + RETURN_IF_ERROR(queue->Submit(batch, {fence.get(), 1u})); + RETURN_IF_ERROR(placement.device->WaitAllFences({{fence.get(), 1u}}, + absl::InfiniteFuture())); + }); + + DISPATCH_CORE_OPCODE(kAllocStatic, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_static"; + }); + + DISPATCH_CORE_OPCODE(kAllocStack, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented alloc_stack"; + }); + + DISPATCH_CORE_OPCODE(kAllocStackInit, { + return UnimplementedErrorBuilder(IREE_LOC) + << "Unimplemented alloc_stack_init"; + }); + + DISPATCH_CORE_OPCODE(kAllocHeap, { + ASSIGN_OR_RETURN(auto heap_type, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto type, reader.ReadType()); + size_t element_size = type.element_size(); + + // TODO(benvanik): more efficient reading and storage. + size_t element_count = 0; + ASSIGN_OR_RETURN(auto shape, reader.ReadShapePieces(&element_count)); + size_t allocation_size = element_size * element_count; + + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + dst_local->element_size = element_size; + dst_local->shape = shape; + + // TODO(benvanik): pick an allocator and use that instead. + CHECK_EQ(heap_type, 0); + auto* allocator = placement.device->allocator(); + ASSIGN_OR_RETURN( + dst_local->buffer, + allocator->Allocate( + hal::MemoryType::kHostLocal | hal::MemoryType::kDeviceVisible, + hal::BufferUsage::kAll, allocation_size)); + }); + + DISPATCH_CORE_OPCODE(kDiscard, { + // NOTE: if we were an encoder we would actually discard the buffer. + ASSIGN_OR_RETURN(auto* local, reader.ReadLocal()); + *local = {}; + }); + + DISPATCH_CORE_OPCODE(kComputeRange, { + ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto element_size, reader.ReadUint8_t()); + ASSIGN_OR_RETURN(auto indices, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto lengths, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_offset_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_length_local, reader.ReadLocal()); + + Shape shape(shape_data); + ASSIGN_OR_RETURN(device_size_t dst_offset, + CalculateOffset(indices, shape, element_size)); + RETURN_IF_ERROR( + dst_offset_local->buffer->WriteData(0, &dst_offset, sizeof(int32_t))); + + // A buffer range can only be computed for contiguous memory. To ensure that + // this only requests such, we validate that the offset in the buffer + // between the start and end indices is the same as the requested size. + device_size_t dst_length = element_size; + for (int i = 0; i < lengths.size(); ++i) { + dst_length *= lengths[i]; + indices[i] += lengths[i] - 1; + } + ASSIGN_OR_RETURN(auto end_offset, + CalculateOffset(indices, shape, element_size)); + auto offset_based_length = end_offset - dst_offset + element_size; + if (dst_length != offset_based_length) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Cannot compute range for non-contiguous region of memory;" + << " shape: " << PrettyPrint(shape.subspan()) + << " indices: " << PrettyPrint(indices) + << " lengths: " << PrettyPrint(lengths); + } + RETURN_IF_ERROR( + dst_length_local->buffer->WriteData(0, &dst_length, sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kShape, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + RETURN_IF_ERROR(dst_local->buffer->WriteData( + 0, src_local->shape.subspan().data(), + src_local->shape.subspan().size() * sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kLength, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + int32_t length = src_local->shape.element_count(); + RETURN_IF_ERROR(dst_local->buffer->WriteData(0, &length, sizeof(int32_t))); + }); + + DISPATCH_CORE_OPCODE(kDynamicSlice, { + // TODO(b/139299169): implement indirect copies to avoid CPU readback. + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_slice"; + }); + + DISPATCH_CORE_OPCODE(kStaticSlice, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto offset, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto length, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto type, reader.ReadType()); + ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + Shape new_shape = Shape{shape_data}; + if (new_shape.element_count() * type.element_size() != length) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "New element count " << new_shape.element_count() + << " != length slice " << length; + } + ASSIGN_OR_RETURN(dst_local->buffer, + Buffer::Subspan(src_local->buffer, offset, length)); + dst_local->shape = new_shape; + dst_local->element_size = type.element_size(); + }); + + DISPATCH_CORE_OPCODE(kDynamicCopy, { + // TODO(b/139299169): implement indirect copies to avoid CPU readback. + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto src_offset_span, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_offset_span, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto length_span, reader.ReadSlotElements<int32_t>()); + RETURN_IF_ERROR(dst_local->buffer->CopyData( + dst_offset_span.front(), src_local->buffer.get(), + src_offset_span.front(), length_span.front())); + }); + + DISPATCH_CORE_OPCODE(kStaticCopy, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto src_offset, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto length, reader.ReadInt32()); + RETURN_IF_ERROR(dst_local->buffer->CopyData( + dst_offset, src_local->buffer.get(), src_offset, length)); + }); + + DISPATCH_CORE_OPCODE(kDynamicFill, { + // TODO(b/139299169): implement indirect fills to avoid CPU readback. + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented dynamic_fill"; + }); + + DISPATCH_CORE_OPCODE(kStaticFill, { + ASSIGN_OR_RETURN(auto value, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto dst_offset, reader.ReadInt32()); + ASSIGN_OR_RETURN(auto length, reader.ReadInt32()); + RETURN_IF_ERROR(dst_local->buffer->Fill32(dst_offset, length, value)); + }); + + DISPATCH_CORE_OPCODE(kClone, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + dst_local->element_size = src_local->element_size; + dst_local->shape = src_local->shape; + ASSIGN_OR_RETURN(dst_local->buffer, placement.device->allocator()->Allocate( + src_local->buffer->memory_type(), + src_local->buffer->usage(), + src_local->buffer->byte_length())); + RETURN_IF_ERROR(dst_local->buffer->CopyData(0, src_local->buffer.get())); + }); + + DISPATCH_CORE_OPCODE(kAssign, { + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + *dst_local = *src_local; + }); + + DISPATCH_CORE_OPCODE(kCondAssign, { + ASSIGN_OR_RETURN(auto* cond_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* lhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* rhs_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + *dst_local = BufferViewIsTrue(*cond_local) ? *lhs_local : *rhs_local; + }); + + DISPATCH_CORE_OPCODE(kReshape, { + // TODO(benvanik): more logic required if strides differ. + ASSIGN_OR_RETURN(auto* src_local, reader.ReadLocal()); + ASSIGN_OR_RETURN(auto shape_data, reader.ReadSlotElements<int32_t>()); + ASSIGN_OR_RETURN(auto* dst_local, reader.ReadLocal()); + Shape new_shape = Shape{shape_data}; + if (src_local->shape.element_count() != new_shape.element_count()) { + return InvalidArgumentErrorBuilder(IREE_LOC) + << "New element count " << new_shape.element_count() + << " != source element count " << src_local->shape.element_count(); + } + dst_local->shape = new_shape; + dst_local->buffer = add_ref(src_local->buffer); + dst_local->element_size = src_local->element_size; + }); + + DISPATCH_CORE_OPCODE(kTrace, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented trace"; + }); + + DISPATCH_CORE_OPCODE(kBreak, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented break"; + }); + + DISPATCH_CORE_OPCODE(kCondBreak, { + return UnimplementedErrorBuilder(IREE_LOC) << "Unimplemented cond_break"; + }); + +_dispatch_unhandled: + // TODO(benvanik): better tracing. + return UnimplementedErrorBuilder(IREE_LOC) << "Unknown dispatch opcode"; +} + +} // namespace vm +} // namespace iree
diff --git a/vm/sequencer_dispatch.h b/vm/sequencer_dispatch.h new file mode 100644 index 0000000..814f1d5 --- /dev/null +++ b/vm/sequencer_dispatch.h
@@ -0,0 +1,35 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_SEQUENCER_DISPATCH_H_ +#define IREE_VM_SEQUENCER_DISPATCH_H_ + +#include "base/status.h" +#include "hal/buffer_view.h" +#include "hal/device_placement.h" +#include "rt/stack.h" +#include "rt/stack_frame.h" + +namespace iree { +namespace vm { + +// TODO(benvanik): API that supports yielding. +Status DispatchSequence(const hal::DevicePlacement& placement, rt::Stack* stack, + rt::StackFrame* entry_stack_frame, + absl::Span<hal::BufferView> entry_results); + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_SEQUENCER_DISPATCH_H_
diff --git a/vm/sequencer_module.cc b/vm/sequencer_module.cc new file mode 100644 index 0000000..52b1ca7 --- /dev/null +++ b/vm/sequencer_module.cc
@@ -0,0 +1,112 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/sequencer_module.h" + +#include "absl/memory/memory.h" +#include "base/status.h" +#include "base/tracing.h" +#include "hal/buffer_view.h" +#include "rt/context.h" +#include "rt/instance.h" +#include "vm/bytecode_tables_sequencer.h" +#include "vm/sequencer_dispatch.h" + +namespace iree { +namespace vm { + +namespace { + +using ::iree::hal::BufferView; +using ::iree::rt::Function; +using ::iree::rt::Module; + +} // namespace + +// static +StatusOr<ref_ptr<rt::Module>> SequencerModule::FromDef( + const ModuleDef& module_def) { + ASSIGN_OR_RETURN(auto module_file, ModuleFile::Create(&module_def, []() {})); + return FromFile(std::move(module_file)); +} + +// static +StatusOr<ref_ptr<rt::Module>> SequencerModule::FromFile( + std::unique_ptr<ModuleFile> module_file) { + if (module_file->root() == nullptr) { + return InvalidArgumentErrorBuilder(IREE_LOC) << "No root ModuleDef present"; + } + const auto& module_def = *module_file->root(); + + // Validates the structure of the module (but not bytecode). + // This ensures we don't have flatbuffer vectors will null entries, etc. + RETURN_IF_ERROR(BytecodeModule::ValidateStructure(module_def)); + + auto module = assign_ref(new SequencerModule(std::move(module_file))); + + // TODO(benvanik): validate internals here? or make explicit? + + return {std::move(module)}; +} + +SequencerModule::SequencerModule(std::unique_ptr<ModuleFile> module_file) + : BytecodeModule(std::move(module_file), sequencer_opcode_table()) {} + +SequencerModule::~SequencerModule() = default; + +Status SequencerModule::Execute( + rt::Stack* stack, const Function function, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::InlinedVector<hal::BufferView, 8>* results) const { + IREE_TRACE_SCOPE0("SequencerModule::Execute"); + + // Push stack frame for the function we are calling. + ASSIGN_OR_RETURN(auto* callee_stack_frame, stack->PushFrame(function)); + + // TODO(benvanik): rework register storage interface. + ASSIGN_OR_RETURN(const auto* function_def, + GetFunctionDef(function.linkage(), function.ordinal())); + auto* registers = callee_stack_frame->mutable_registers(); + registers->buffer_views.resize(function_def->bytecode()->local_count()); + + // Marshal input arguments. + for (int i = 0; i < arguments.size(); ++i) { + auto arg = arguments[i]; + auto expected_arg_type = function_def->type()->inputs()->Get(i); + RETURN_IF_ERROR(BytecodeModule::ValidateArgType( + arg, *expected_arg_type->type_union_as_MemRefTypeDef())) + << "Function " << function.name() << " argument " << i; + registers->buffer_views[i] = std::move(arg); + } + + // TODO(benvanik): change to: + // get command queue (any command queue) + // make command buffer + // record dispatch + // submit + // wait on fence + ASSIGN_OR_RETURN( + auto placement, + stack->context()->instance()->device_manager()->ResolvePlacement({})); + RETURN_IF_ERROR(DispatchSequence(placement, stack, callee_stack_frame, + absl::MakeSpan(*results))); + + // Pop the callee frame to balance out the stack. + RETURN_IF_ERROR(stack->PopFrame()); + + return OkStatus(); +} + +} // namespace vm +} // namespace iree
diff --git a/vm/sequencer_module.h b/vm/sequencer_module.h new file mode 100644 index 0000000..6be5464 --- /dev/null +++ b/vm/sequencer_module.h
@@ -0,0 +1,46 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_SEQUENCER_MODULE_H_ +#define IREE_VM_SEQUENCER_MODULE_H_ + +#include <memory> + +#include "vm/bytecode_module.h" + +namespace iree { +namespace vm { + +// A module using the sequencer bytecode ops. +class SequencerModule final : public BytecodeModule { + public: + static StatusOr<ref_ptr<rt::Module>> FromDef(const ModuleDef& module_def); + static StatusOr<ref_ptr<rt::Module>> FromFile( + std::unique_ptr<ModuleFile> module_file); + + ~SequencerModule() override; + + Status Execute( + rt::Stack* stack, const rt::Function function, + absl::InlinedVector<hal::BufferView, 8> arguments, + absl::InlinedVector<hal::BufferView, 8>* results) const override; + + private: + explicit SequencerModule(std::unique_ptr<ModuleFile> module_file); +}; + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_SEQUENCER_MODULE_H_
diff --git a/vm/source_map_resolver.cc b/vm/source_map_resolver.cc new file mode 100644 index 0000000..c5da2f7 --- /dev/null +++ b/vm/source_map_resolver.cc
@@ -0,0 +1,194 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/source_map_resolver.h" + +#include "base/flatbuffer_util.h" +#include "base/status.h" +#include "schemas/source_map_def_generated.h" + +namespace iree { +namespace vm { + +namespace { + +Status PrintLocation(const SourceMapResolver& source_map, + const FunctionSourceMapDef& function_source_map, + const LocationDef& location, std::ostream* stream); + +Status PrintFileLocation(const SourceMapResolver& source_map, + const FunctionSourceMapDef& function_source_map, + const FileLocationDef& location, + std::ostream* stream) { + ASSIGN_OR_RETURN(auto filename, + source_map.GetUniqueString(location.filename())); + *stream << filename << ":" << location.line() << ":" << location.column(); + return OkStatus(); +} + +Status PrintNameLocation(const SourceMapResolver& source_map, + const FunctionSourceMapDef& function_source_map, + const NameLocationDef& location, + std::ostream* stream) { + ASSIGN_OR_RETURN(auto name, source_map.GetUniqueString(location.name())); + *stream << "\"" << name << "\""; + return OkStatus(); +} + +Status PrintCallSiteLocation(const SourceMapResolver& source_map, + const FunctionSourceMapDef& function_source_map, + const CallSiteLocationDef& location, + std::ostream* stream) { + *stream << "(callsites todo)"; + return OkStatus(); +} + +Status PrintFusedLocation(const SourceMapResolver& source_map, + const FunctionSourceMapDef& function_source_map, + const FusedLocationDef& location, + std::ostream* stream) { + *stream << "fused["; + if (location.locations()) { + for (int i = 0; i < location.locations()->size(); ++i) { + if (i > 0) *stream << ", "; + int location_ordinal = location.locations()->Get(i); + const auto& child_location = + *function_source_map.location_table()->Get(location_ordinal); + RETURN_IF_ERROR(PrintLocation(source_map, function_source_map, + child_location, stream)); + } + } + *stream << "]"; + return OkStatus(); +} + +Status PrintLocation(const SourceMapResolver& source_map, + const FunctionSourceMapDef& function_source_map, + const LocationDef& location, std::ostream* stream) { + switch (location.location_union_type()) { + case LocationDefUnion::FileLocationDef: + return PrintFileLocation(source_map, function_source_map, + *location.location_union_as_FileLocationDef(), + stream); + case LocationDefUnion::NameLocationDef: + return PrintNameLocation(source_map, function_source_map, + *location.location_union_as_NameLocationDef(), + stream); + case LocationDefUnion::CallSiteLocationDef: + return PrintCallSiteLocation( + source_map, function_source_map, + *location.location_union_as_CallSiteLocationDef(), stream); + case LocationDefUnion::FusedLocationDef: + return PrintFusedLocation(source_map, function_source_map, + *location.location_union_as_FusedLocationDef(), + stream); + default: + return UnimplementedErrorBuilder(IREE_LOC) + << "Unhandled location type " + << static_cast<int>(location.location_union_type()); + } +} + +} // namespace + +// static +SourceMapResolver SourceMapResolver::FromModule(const ModuleDef& module_def) { + if (module_def.source_map()) { + return SourceMapResolver{*module_def.source_map()}; + } + return {}; +} + +StatusOr<absl::string_view> SourceMapResolver::GetUniqueString( + int string_index) const { + if (empty()) { + return NotFoundErrorBuilder(IREE_LOC) << "No source map present"; + } + const auto* string_table = source_map_def_->string_table(); + if (string_table && string_table->size() > string_index) { + return WrapString(string_table->Get(string_index)); + } + return NotFoundErrorBuilder(IREE_LOC) + << "String index " << string_index << " not present in string table"; +} + +StatusOr<const FunctionSourceMapDef*> SourceMapResolver::GetFunctionSourceMap( + int function_ordinal) const { + if (empty()) { + return NotFoundErrorBuilder(IREE_LOC) << "No source map present"; + } + const auto* function_table = source_map_def_->function_table(); + if (function_table && function_table->size() > function_ordinal) { + const auto* function_source_map = function_table->Get(function_ordinal); + if (function_source_map && function_source_map->location_table() && + function_source_map->bytecode_map()) { + return function_source_map; + } + } + return NotFoundErrorBuilder(IREE_LOC) + << "Function ordinal " << function_ordinal + << " source map not present in function table"; +} + +absl::optional<rt::SourceLocation> SourceMapResolver::ResolveFunctionOffset( + const rt::Function& function, rt::SourceOffset offset) { + if (empty()) return absl::nullopt; + auto function_source_map_or = GetFunctionSourceMap(function.ordinal()); + if (!function_source_map_or.ok()) { + return absl::nullopt; + } + const auto* function_source_map = function_source_map_or.ValueOrDie(); + const auto* bytecode_map = function_source_map->bytecode_map(); + if (!bytecode_map) return absl::nullopt; + + // TODO(benvanik): allow fuzzy offset matching/table sparsity. + int location_ordinal = -1; + for (const auto* map_loc : *bytecode_map) { + if (map_loc->offset() == offset) { + location_ordinal = map_loc->location(); + break; + } + } + if (location_ordinal == -1) { + return absl::nullopt; + } + + return rt::SourceLocation(this, + { + reinterpret_cast<uint64_t>(function_source_map), + static_cast<uint64_t>(location_ordinal), + }); +} + +void SourceMapResolver::PrintSourceLocation( + rt::SourceResolverArgs resolver_args, std::ostream* stream) const { + if (empty()) { + *stream << "<unknown>"; + return; + } + + auto* function_source_map = + reinterpret_cast<FunctionSourceMapDef*>(resolver_args[0]); + int location_ordinal = static_cast<int>(resolver_args[1]); + + const auto& location = + *function_source_map->location_table()->Get(location_ordinal); + auto status = PrintLocation(*this, *function_source_map, location, stream); + if (!status.ok()) { + *stream << status; + } +} + +} // namespace vm +} // namespace iree
diff --git a/vm/source_map_resolver.h b/vm/source_map_resolver.h new file mode 100644 index 0000000..63a9f01 --- /dev/null +++ b/vm/source_map_resolver.h
@@ -0,0 +1,57 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_SOURCE_MAP_RESOLVER_H_ +#define IREE_VM_SOURCE_MAP_RESOLVER_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "base/status.h" +#include "rt/source_resolver.h" +#include "schemas/module_def_generated.h" +#include "schemas/source_map_def_generated.h" + +namespace iree { +namespace vm { + +class SourceMapResolver final : public rt::SourceResolver { + public: + static SourceMapResolver FromModule(const ModuleDef& module_def); + + SourceMapResolver() = default; + explicit SourceMapResolver(const SourceMapDef& source_map_def) + : source_map_def_(&source_map_def) {} + + bool empty() const { return source_map_def_ == nullptr; } + const SourceMapDef* def() const { return source_map_def_; } + + StatusOr<absl::string_view> GetUniqueString(int string_index) const; + + StatusOr<const FunctionSourceMapDef*> GetFunctionSourceMap( + int function_ordinal) const; + + absl::optional<rt::SourceLocation> ResolveFunctionOffset( + const rt::Function& function, rt::SourceOffset offset) override; + + void PrintSourceLocation(rt::SourceResolverArgs resolver_args, + std::ostream* stream) const override; + + private: + const SourceMapDef* source_map_def_ = nullptr; +}; + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_SOURCE_MAP_RESOLVER_H_
diff --git a/vm/type.cc b/vm/type.cc new file mode 100644 index 0000000..8906371 --- /dev/null +++ b/vm/type.cc
@@ -0,0 +1,64 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "vm/type.h" + +#include "base/status.h" + +namespace iree { +namespace vm { + +// static +StatusOr<const Type> Type::FromTypeIndex(uint8_t type_index) { + // Currently we only support the builtin types. + if (type_index == static_cast<uint8_t>(BuiltinType::kOpaque)) { + return Type(type_index); + } else if (type_index < kBuiltinTypeCount) { + return Type(type_index); + } + return InvalidArgumentErrorBuilder(IREE_LOC) + << "Type index " << static_cast<int>(type_index) << " not supported"; +} + +// static +const Type Type::FromBuiltin(BuiltinType type) { + return Type(static_cast<uint8_t>(type)); +} + +std::string Type::DebugString() const { + switch (type_index_) { +#define TYPE_NAME(index, name, str, size) \ + case index: \ + return str; + IREE_TYPE_LIST(TYPE_NAME) +#undef TYPE_NAME + default: + return "<invalid>"; + } +} + +size_t Type::element_size() const { + switch (type_index_) { +#define TYPE_SIZE(index, name, str, size) \ + case index: \ + return size; + IREE_TYPE_LIST(TYPE_SIZE) +#undef TYPE_SIZE + default: + return 0; + } +} + +} // namespace vm +} // namespace iree
diff --git a/vm/type.h b/vm/type.h new file mode 100644 index 0000000..0de5357 --- /dev/null +++ b/vm/type.h
@@ -0,0 +1,65 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_VM_TYPE_H_ +#define IREE_VM_TYPE_H_ + +#include "base/status.h" +#include "schemas/bytecode/bytecode_v0.h" +#include "schemas/type_def_generated.h" + +namespace iree { +namespace vm { + +class Type { + public: + static StatusOr<const Type> FromTypeIndex(uint8_t type_index); + static const Type FromBuiltin(BuiltinType type); + + std::string DebugString() const; + + uint8_t type_index() const { return type_index_; } + + bool is_opaque() const { + return type_index_ == static_cast<uint8_t>(BuiltinType::kOpaque); + } + bool is_builtin() const { return !is_opaque(); } + BuiltinType builtin_type() const { + DCHECK(is_builtin()); + return static_cast<BuiltinType>(type_index_); + } + + size_t element_size() const; + + private: + explicit Type(uint8_t type_index) : type_index_(type_index) {} + + uint8_t type_index_; +}; + +inline bool operator==(const Type& a, const Type& b) { + return a.type_index() == b.type_index(); +} + +inline bool operator!=(const Type& a, const Type& b) { return !(a == b); } + +inline std::ostream& operator<<(std::ostream& stream, const Type& type) { + stream << type.DebugString(); + return stream; +} + +} // namespace vm +} // namespace iree + +#endif // IREE_VM_TYPE_H_