Porting a majority of Vulkan/Metal to C.
diff --git a/iree/base/api.h b/iree/base/api.h index 5930d73..baca1e1 100644 --- a/iree/base/api.h +++ b/iree/base/api.h
@@ -382,7 +382,7 @@ #if !defined(IREE_STATUS_MODE) #ifdef NDEBUG // Release mode: just source location. -#define IREE_STATUS_MODE 1 +#define IREE_STATUS_MODE 2 #else // Debug mode: annotations and stack traces. #define IREE_STATUS_MODE 3
diff --git a/iree/hal/buffer.c b/iree/hal/buffer.c index 1a924cf..9ffe8da 100644 --- a/iree/hal/buffer.c +++ b/iree/hal/buffer.c
@@ -43,6 +43,7 @@ &buffer->resource); buffer->allocator = allocated_buffer->allocator; buffer->allocated_buffer = allocated_buffer; + iree_hal_buffer_retain(buffer->allocated_buffer); buffer->allocation_size = allocated_buffer->allocation_size; buffer->byte_offset = byte_offset; buffer->byte_length = byte_length; @@ -71,21 +72,22 @@ iree_hal_buffer_t* buffer, iree_device_size_t byte_offset, iree_device_size_t byte_length, const void* pattern, iree_host_size_t pattern_length) { - return _VTABLE_DISPATCH(buffer, fill)(buffer->allocated_buffer, byte_offset, - byte_length, pattern, pattern_length); + return _VTABLE_DISPATCH(buffer->allocated_buffer, fill)( + buffer->allocated_buffer, byte_offset, byte_length, pattern, + pattern_length); } static iree_status_t iree_hal_subspan_buffer_read_data( iree_hal_buffer_t* buffer, iree_device_size_t source_offset, void* target_buffer, iree_device_size_t data_length) { - return _VTABLE_DISPATCH(buffer, read_data)( + return _VTABLE_DISPATCH(buffer->allocated_buffer, read_data)( buffer->allocated_buffer, source_offset, target_buffer, data_length); } static iree_status_t iree_hal_subspan_buffer_write_data( iree_hal_buffer_t* buffer, iree_device_size_t target_offset, const void* source_buffer, iree_device_size_t data_length) { - return _VTABLE_DISPATCH(buffer, write_data)( + return _VTABLE_DISPATCH(buffer->allocated_buffer, write_data)( buffer->allocated_buffer, target_offset, source_buffer, data_length); } @@ -93,7 +95,7 @@ iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t data_length) { - return _VTABLE_DISPATCH(target_buffer, copy_data)( + return _VTABLE_DISPATCH(target_buffer->allocated_buffer, copy_data)( source_buffer, source_offset, target_buffer->allocated_buffer, target_offset, data_length); } @@ -103,30 +105,30 @@ iree_hal_memory_access_t memory_access, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, void** out_data_ptr) { - return _VTABLE_DISPATCH(buffer, map_range)(buffer, mapping_mode, - memory_access, local_byte_offset, - local_byte_length, out_data_ptr); + return _VTABLE_DISPATCH(buffer->allocated_buffer, map_range)( + buffer->allocated_buffer, mapping_mode, memory_access, local_byte_offset, + local_byte_length, out_data_ptr); } static void iree_hal_subspan_buffer_unmap_range( iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, void* data_ptr) { - return _VTABLE_DISPATCH(buffer, unmap_range)(buffer, local_byte_offset, - local_byte_length, data_ptr); + _VTABLE_DISPATCH(buffer->allocated_buffer, unmap_range) + (buffer->allocated_buffer, local_byte_offset, local_byte_length, data_ptr); } static iree_status_t iree_hal_subspan_buffer_invalidate_range( iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length) { - return _VTABLE_DISPATCH(buffer, invalidate_range)(buffer, local_byte_offset, - local_byte_length); + return _VTABLE_DISPATCH(buffer->allocated_buffer, invalidate_range)( + buffer->allocated_buffer, local_byte_offset, local_byte_length); } static iree_status_t iree_hal_subspan_buffer_flush_range( iree_hal_buffer_t* buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length) { - return _VTABLE_DISPATCH(buffer, flush_range)(buffer, local_byte_offset, - local_byte_length); + return _VTABLE_DISPATCH(buffer->allocated_buffer, flush_range)( + buffer->allocated_buffer, local_byte_offset, local_byte_length); } static const iree_hal_buffer_vtable_t iree_hal_subspan_buffer_vtable = {
diff --git a/iree/hal/buffer_view.c b/iree/hal/buffer_view.c index ef0385f..00875c5 100644 --- a/iree/hal/buffer_view.c +++ b/iree/hal/buffer_view.c
@@ -52,7 +52,7 @@ host_allocator, sizeof(*buffer_view) + sizeof(iree_hal_dim_t) * shape_rank, (void**)&buffer_view); - if (iree_status_is_ok(buffer_view)) { + if (iree_status_is_ok(status)) { iree_atomic_ref_count_init(&buffer_view->ref_count); buffer_view->buffer = buffer; iree_hal_buffer_retain(buffer_view->buffer); @@ -112,11 +112,11 @@ IREE_RETURN_IF_ERROR(iree_hal_buffer_subspan( buffer_view->buffer, start_offset, subview_length, &subview_buffer)); - iree_status_t result = + iree_status_t status = iree_hal_buffer_view_create(subview_buffer, lengths, lengths_count, buffer_view->element_type, out_buffer_view); iree_hal_buffer_release(subview_buffer); - return result; + return status; } IREE_API_EXPORT iree_hal_buffer_t* iree_hal_buffer_view_buffer(
diff --git a/iree/hal/cc/BUILD b/iree/hal/cc/BUILD deleted file mode 100644 index c8390d8..0000000 --- a/iree/hal/cc/BUILD +++ /dev/null
@@ -1,74 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# HAL (Hardware Abstraction Layer). -# Subdirectories contain implementations for different hardware and -# software backends. - -package( - default_visibility = ["//visibility:public"], - features = ["layering_check"], - licenses = ["notice"], # Apache 2.0 -) - -cc_library( - name = "cc", - srcs = [ - "buffer.cc", - ], - hdrs = [ - "allocator.h", - "buffer.h", - "command_buffer.h", - "command_queue.h", - "debug_capture_manager.h", - "descriptor_set.h", - "descriptor_set_layout.h", - "device.h", - "device_info.h", - "driver.h", - "event.h", - "executable.h", - "executable_cache.h", - "executable_format.h", - "executable_layout.h", - "resource.h", - "semaphore.h", - ], - deps = [ - "//iree/base:core_headers", - "//iree/base:logging", - "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:time", - "//iree/base:tracing", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -cc_test( - name = "buffer_test", - srcs = [ - "buffer_mapping_test.cc", - "buffer_test.cc", - ], - deps = [ - ":cc", - "//iree/base:status", - "//iree/testing:gtest", - "//iree/testing:gtest_main", - "@com_google_absl//absl/types:span", - ], -)
diff --git a/iree/hal/cc/CMakeLists.txt b/iree/hal/cc/CMakeLists.txt deleted file mode 100644 index 09b874e..0000000 --- a/iree/hal/cc/CMakeLists.txt +++ /dev/null
@@ -1,64 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -iree_add_all_subdirs() - -iree_cc_library( - NAME - cc - HDRS - "allocator.h" - "buffer.h" - "command_buffer.h" - "command_queue.h" - "debug_capture_manager.h" - "descriptor_set.h" - "descriptor_set_layout.h" - "device.h" - "device_info.h" - "driver.h" - "event.h" - "executable.h" - "executable_cache.h" - "executable_format.h" - "executable_layout.h" - "resource.h" - "semaphore.h" - SRCS - "buffer.cc" - DEPS - absl::span - absl::strings - iree::base::core_headers - iree::base::logging - iree::base::ref_ptr - iree::base::status - iree::base::time - iree::base::tracing - PUBLIC -) - -iree_cc_test( - NAME - buffer_test - SRCS - "buffer_mapping_test.cc" - "buffer_test.cc" - DEPS - ::cc - absl::span - iree::base::status - iree::testing::gtest - iree::testing::gtest_main -)
diff --git a/iree/hal/cc/allocator.h b/iree/hal/cc/allocator.h deleted file mode 100644 index 9c2afbd..0000000 --- a/iree/hal/cc/allocator.h +++ /dev/null
@@ -1,145 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_ALLOCATOR_H_ -#define IREE_HAL_CC_ALLOCATOR_H_ - -#include <cstddef> -#include <memory> - -#include "absl/types/span.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/cc/buffer.h" - -namespace iree { -namespace hal { - -// Allocates buffers for a particular device memory space. -// -// Buffers allocated are only guaranteed to work with the driver that the -// allocator services. Any attempt to use buffers on drivers they were not -// allocated from must first be checked with CanUseBuffer. -// -// Thread-safe. -class Allocator : public RefObject<Allocator> { - public: - virtual ~Allocator() = default; - - // Returns true if the device can use the given buffer for the provided usage. - // For buffers allocated from this allocator it's expected that the result - // will always be true. For buffers that originate from another allocator - // there may be limited support for cross-device usage. - // - // Returning false indicates that the buffer must be transferred externally - // into a buffer compatible with the device this allocator services. - bool CanUseBuffer(Buffer* buffer, - iree_hal_buffer_usage_t intended_usage) const { - return CanUseBufferLike(buffer->allocator(), buffer->memory_type(), - buffer->usage(), intended_usage); - } - virtual bool CanUseBufferLike( - Allocator* source_allocator, iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - iree_hal_buffer_usage_t intended_usage) const = 0; - - // Returns true if the allocator can allocate a buffer with the given - // attributes. - virtual bool CanAllocate(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) const = 0; - - // Adjusts allocation parameters to be compatible with the allocator. - // Certain allocators may require particular memory types to function. By - // adjusting the parameters prior to allocation callers can be sure they are - // able to successfully Allocate a buffer later on with the same parameters. - virtual Status MakeCompatible(iree_hal_memory_type_t* memory_type, - iree_hal_buffer_usage_t* buffer_usage) const { - return OkStatus(); - } - - // Allocates a buffer from the allocator. - // Fails if the memory type requested for the given usage cannot be serviced. - // Callers can use CanAllocate to decide their memory use strategy. - // - // The memory type of the buffer returned may differ from the requested value - // if the device can provide more functionality; for example, if requesting - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE but the memory is really host cached you - // may get a buffer back with IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | - // IREE_HAL_MEMORY_TYPE_HOST_CACHED. The only requirement is that the buffer - // satisfy the required bits. - virtual StatusOr<ref_ptr<Buffer>> Allocate( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) = 0; - - // Wraps an existing host heap allocation in a buffer. - // Ownership of the host allocation remains with the caller and the memory - // must remain valid for so long as the Buffer may be in use. - // Will have IREE_HAL_MEMORY_TYPE_HOST_LOCAL in most cases and may not be - // usable by the device. - // - // The inference optimizer makes assumptions about buffer aliasing based on - // Buffer instances and because of this wrapping the same host buffer in - // multiple Buffers will create potential memory aliasing issues that can be - // difficult to track down. There's no checking as to whether a host buffer - // has already been wrapped so it's best for callers to ensure this is never - // possible (the simplest way being to never use Wrap and always just allocate - // new Buffers). - // - // Fails if the allocator cannot access host memory in this way. - StatusOr<ref_ptr<Buffer>> Wrap(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - const void* data, size_t data_length) { - return WrapMutable(memory_type, IREE_HAL_MEMORY_ACCESS_READ, buffer_usage, - const_cast<void*>(data), data_length); - } - virtual StatusOr<ref_ptr<Buffer>> WrapMutable( - iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, void* data, size_t data_length) { - return UnimplementedErrorBuilder(IREE_LOC) - << "Allocator does not support wrapping host memory"; - } - template <typename T> - StatusOr<ref_ptr<Buffer>> Wrap(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - absl::Span<const T> data); - template <typename T> - StatusOr<ref_ptr<Buffer>> WrapMutable(iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, - absl::Span<T> data); -}; - -// Inline functions and template definitions follow: - -template <typename T> -StatusOr<ref_ptr<Buffer>> Allocator::Wrap(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - absl::Span<const T> data) { - return Wrap(memory_type, buffer_usage, data.data(), data.size() * sizeof(T)); -} - -template <typename T> -StatusOr<ref_ptr<Buffer>> Allocator::WrapMutable( - iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, absl::Span<T> data) { - return WrapMutable(memory_type, allowed_access, buffer_usage, data.data(), - data.size() * sizeof(T)); -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_ALLOCATOR_H_
diff --git a/iree/hal/cc/buffer.cc b/iree/hal/cc/buffer.cc deleted file mode 100644 index 49baa8d..0000000 --- a/iree/hal/cc/buffer.cc +++ /dev/null
@@ -1,565 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/cc/buffer.h" - -#include <algorithm> -#include <atomic> -#include <cstdint> -#include <cstring> -#include <sstream> - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -std::string MemoryTypeString(iree_hal_memory_type_t memory_type) { - return "TODO"; - // return FormatBitfieldValue( - // memory_type, { - // // Combined: - // {IREE_HAL_MEMORY_TYPE_HOST_LOCAL, "kHostLocal"}, - // {IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, "kDeviceLocal"}, - // // Separate: - // {IREE_HAL_MEMORY_TYPE_TRANSIENT, "kTransient"}, - // {IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, "kHostVisible"}, - // {IREE_HAL_MEMORY_TYPE_HOST_COHERENT, "kHostCoherent"}, - // {IREE_HAL_MEMORY_TYPE_HOST_CACHED, "kHostCached"}, - // {IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, - // "kDeviceVisible"}, - // }); -} - -std::string MemoryAccessString(iree_hal_memory_access_t memory_access) { - return "TODO"; - // return FormatBitfieldValue( - // memory_access, - // { - // // Combined: - // {IREE_HAL_MEMORY_ACCESS_ALL, "kAll"}, - // {IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, "kDiscardWrite"}, - // // Separate: - // {IREE_HAL_MEMORY_ACCESS_READ, "kRead"}, - // {IREE_HAL_MEMORY_ACCESS_WRITE, "kWrite"}, - // {IREE_HAL_MEMORY_ACCESS_DISCARD, "kDiscard"}, - // {IREE_HAL_MEMORY_ACCESS_MAY_ALIAS, "kMayAlias"}, - // }); -} - -std::string BufferUsageString(iree_hal_buffer_usage_t buffer_usage) { - return "TODO"; - // return FormatBitfieldValue(buffer_usage, - // { - // // Combined: - // {IREE_HAL_BUFFER_USAGE_ALL, "kAll"}, - // // Separate: - // {IREE_HAL_BUFFER_USAGE_CONSTANT, - // "kConstant"}, - // {IREE_HAL_BUFFER_USAGE_TRANSFER, - // "kTransfer"}, - // {IREE_HAL_BUFFER_USAGE_MAPPING, "kMapping"}, - // {IREE_HAL_BUFFER_USAGE_DISPATCH, - // "kDispatch"}, - // }); -} - -// Special router for buffers that just reference other buffers. -// We keep this out of the base Buffer so that it's a bit easier to track -// delegation. -class SubspanBuffer : public Buffer { - public: - SubspanBuffer(ref_ptr<Buffer> parent_buffer, iree_device_size_t byte_offset, - iree_device_size_t byte_length) - : Buffer(parent_buffer->allocator(), parent_buffer->memory_type(), - parent_buffer->allowed_access(), parent_buffer->usage(), - parent_buffer->allocation_size(), byte_offset, byte_length) { - allocated_buffer_ = parent_buffer.get(); - parent_buffer_ = std::move(parent_buffer); - } - - protected: - Status FillImpl(iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_device_size_t pattern_length) override { - return parent_buffer_->FillImpl(byte_offset, byte_length, pattern, - pattern_length); - } - - Status ReadDataImpl(iree_device_size_t source_offset, void* data, - iree_device_size_t data_length) override { - return parent_buffer_->ReadDataImpl(source_offset, data, data_length); - } - - Status WriteDataImpl(iree_device_size_t target_offset, const void* data, - iree_device_size_t data_length) override { - return parent_buffer_->WriteDataImpl(target_offset, data, data_length); - } - - Status CopyDataImpl(iree_device_size_t target_offset, Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length) override { - return parent_buffer_->CopyDataImpl(target_offset, source_buffer, - source_offset, data_length); - } - - Status MapMemoryImpl(MappingMode mapping_mode, - iree_hal_memory_access_t memory_access, - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void** out_data) override { - return parent_buffer_->MapMemoryImpl(mapping_mode, memory_access, - local_byte_offset, local_byte_length, - out_data); - } - - Status UnmapMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void* data) override { - return parent_buffer_->UnmapMemoryImpl(local_byte_offset, local_byte_length, - data); - } - - Status InvalidateMappedMemoryImpl( - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) override { - return parent_buffer_->InvalidateMappedMemoryImpl(local_byte_offset, - local_byte_length); - } - - Status FlushMappedMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) override { - return parent_buffer_->FlushMappedMemoryImpl(local_byte_offset, - local_byte_length); - } -}; - -// static -StatusOr<ref_ptr<Buffer>> Buffer::Subspan(const ref_ptr<Buffer>& buffer, - iree_device_size_t byte_offset, - iree_device_size_t byte_length) { - IREE_RETURN_IF_ERROR(buffer->CalculateRange(byte_offset, byte_length, - &byte_offset, &byte_length)); - if (byte_offset == 0 && byte_length == buffer->byte_length()) { - // Asking for the same buffer. - return add_ref(buffer); - } - - // To avoid heavy nesting of subspans that just add indirection we go to the - // parent buffer directly. If we wanted better accounting (to track where - // buffers came from) we'd want to avoid this but I'm not sure that's worth - // the super deep indirection that could arise. - if (buffer->allocated_buffer() != buffer.get()) { - IREE_CHECK(buffer->parent_buffer_); - return Buffer::Subspan(buffer->parent_buffer_, byte_offset, byte_length); - } else { - return {make_ref<SubspanBuffer>(add_ref(buffer), byte_offset, byte_length)}; - } -} - -// static -Buffer::Overlap Buffer::TestOverlap(Buffer* lhs_buffer, - iree_device_size_t lhs_offset, - iree_device_size_t lhs_length, - Buffer* rhs_buffer, - iree_device_size_t rhs_offset, - iree_device_size_t rhs_length) { - if (lhs_buffer->allocated_buffer() != rhs_buffer->allocated_buffer()) { - // Not even the same buffers. - return Overlap::kDisjoint; - } - // Resolve offsets into the underlying allocation. - iree_device_size_t lhs_alloc_offset = lhs_buffer->byte_offset() + lhs_offset; - iree_device_size_t rhs_alloc_offset = rhs_buffer->byte_offset() + rhs_offset; - iree_device_size_t lhs_alloc_length = - lhs_length == IREE_WHOLE_BUFFER ? lhs_buffer->byte_length() - lhs_offset - : lhs_length; - iree_device_size_t rhs_alloc_length = - rhs_length == IREE_WHOLE_BUFFER ? rhs_buffer->byte_length() - rhs_offset - : rhs_length; - if (!lhs_alloc_length || !rhs_alloc_length) { - return Overlap::kDisjoint; - } - if (lhs_alloc_offset == rhs_alloc_offset && - lhs_alloc_length == rhs_alloc_length) { - return Overlap::kComplete; - } - return lhs_alloc_offset + lhs_alloc_length > rhs_alloc_offset && - rhs_alloc_offset + rhs_alloc_length > lhs_alloc_offset - ? Overlap::kPartial - : Overlap::kDisjoint; -} - -// static -bool Buffer::DoesOverlap(Buffer* lhs_buffer, iree_device_size_t lhs_offset, - iree_device_size_t lhs_length, Buffer* rhs_buffer, - iree_device_size_t rhs_offset, - iree_device_size_t rhs_length) { - return TestOverlap(lhs_buffer, lhs_offset, lhs_length, rhs_buffer, rhs_offset, - rhs_length) != Overlap::kDisjoint; -} - -Buffer::Buffer(Allocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, - iree_device_size_t allocation_size, - iree_device_size_t byte_offset, iree_device_size_t byte_length) - : allocated_buffer_(const_cast<Buffer*>(this)), - allocator_(allocator), - memory_type_(memory_type), - allowed_access_(allowed_access), - usage_(usage), - allocation_size_(allocation_size), - byte_offset_(byte_offset), - byte_length_(byte_length) {} - -Buffer* Buffer::allocated_buffer() const noexcept { - Buffer* allocated_buffer = allocated_buffer_; - while (allocated_buffer != this && - allocated_buffer != allocated_buffer->allocated_buffer()) { - allocated_buffer = allocated_buffer->allocated_buffer(); - } - return allocated_buffer; -} - -std::string Buffer::DebugString() const { - std::ostringstream stream; - stream << allocated_buffer()->debug_name() << "[" - << (allocation_size() == IREE_WHOLE_BUFFER - ? "?" - : std::to_string(allocation_size())) - << "]."; - if (iree_any_bit_set(memory_type(), IREE_HAL_MEMORY_TYPE_TRANSIENT)) - stream << "Z"; - if ((memory_type() & IREE_HAL_MEMORY_TYPE_HOST_LOCAL) == - IREE_HAL_MEMORY_TYPE_HOST_LOCAL) { - stream << "h"; - } else { - if (iree_any_bit_set(memory_type(), IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) - stream << "v"; - if (iree_any_bit_set(memory_type(), IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) - stream << "x"; - if (iree_any_bit_set(memory_type(), IREE_HAL_MEMORY_TYPE_HOST_CACHED)) - stream << "c"; - } - if (iree_all_bits_set(memory_type(), IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { - stream << "D"; - } else { - if (iree_any_bit_set(memory_type(), IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) - stream << "V"; - } - stream << "."; - if (iree_any_bit_set(usage(), IREE_HAL_BUFFER_USAGE_CONSTANT)) stream << "c"; - if (iree_any_bit_set(usage(), IREE_HAL_BUFFER_USAGE_TRANSFER)) stream << "t"; - if (iree_any_bit_set(usage(), IREE_HAL_BUFFER_USAGE_MAPPING)) stream << "m"; - if (iree_any_bit_set(usage(), IREE_HAL_BUFFER_USAGE_DISPATCH)) stream << "d"; - if (byte_offset_ || byte_length_ != allocation_size_) { - stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) - << ")"; - } - return stream.str(); -} - -std::string Buffer::DebugStringShort() const { - // TODO(benvanik): figure out what's most useful here. Maybe a long variant? - std::ostringstream stream; - stream << allocated_buffer()->debug_name() << "[" - << (allocation_size() == IREE_WHOLE_BUFFER - ? "?" - : std::to_string(allocation_size())) - << "]"; - if (byte_offset_ || byte_length_ != allocation_size_) { - stream << "(" << byte_offset_ << "-" << (byte_offset_ + byte_length_ - 1) - << ")"; - } - return stream.str(); -} - -Status Buffer::ValidateCompatibleMemoryType( - iree_hal_memory_type_t memory_type) const { - if ((memory_type_ & memory_type) != memory_type) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is not compatible with the requested " - "operation; buffer has " - << MemoryTypeString(memory_type_) << ", operation requires " - << MemoryTypeString(memory_type); - } - return OkStatus(); -} - -Status Buffer::ValidateAccess(iree_hal_memory_access_t memory_access) const { - if (!iree_any_bit_set(memory_access, (IREE_HAL_MEMORY_ACCESS_READ | - IREE_HAL_MEMORY_ACCESS_WRITE))) { - // No actual access bits defined. - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Memory access must specify one or more of kRead or kWrite"; - } else if ((allowed_access_ & memory_access) != memory_access) { - // Bits must match exactly. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "The buffer does not support the requested access type; buffer " - "allows " - << MemoryAccessString(allowed_access_) << ", operation requires " - << MemoryAccessString(memory_access); - } - return OkStatus(); -} - -Status Buffer::ValidateUsage(iree_hal_buffer_usage_t usage) const { - if ((usage_ & usage) != usage) { - // Missing one or more bits. - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Requested usage was not specified when the buffer was " - "allocated; buffer allows " - << BufferUsageString(usage_) << ", operation requires " - << BufferUsageString(usage); - } - return OkStatus(); -} - -Status Buffer::CalculateRange(iree_device_size_t base_offset, - iree_device_size_t max_length, - iree_device_size_t offset, - iree_device_size_t length, - iree_device_size_t* out_adjusted_offset, - iree_device_size_t* out_adjusted_length) { - // Check if the start of the range runs off the end of the buffer. - if (offset > max_length) { - *out_adjusted_offset = 0; - if (out_adjusted_length) *out_adjusted_length = 0; - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address off the end of the valid buffer " - "range (offset=" - << offset << ", length=" << length - << ", buffer byte_length=" << max_length << ")"; - } - - // Handle length as IREE_WHOLE_BUFFER by adjusting it (if allowed). - if (length == IREE_WHOLE_BUFFER && !out_adjusted_length) { - *out_adjusted_offset = 0; - return InvalidArgumentErrorBuilder(IREE_LOC) - << "IREE_WHOLE_BUFFER may only be used with buffer ranges, not " - "external " - "pointer ranges"; - } - - // Calculate the real ranges adjusted for our region within the allocation. - iree_device_size_t adjusted_offset = base_offset + offset; - iree_device_size_t adjusted_length = - length == IREE_WHOLE_BUFFER ? max_length - offset : length; - if (adjusted_length == 0) { - // Fine to have a zero length. - *out_adjusted_offset = adjusted_offset; - if (out_adjusted_length) *out_adjusted_length = adjusted_length; - return OkStatus(); - } - - // Check if the end runs over the allocation. - iree_device_size_t end = offset + adjusted_length - 1; - if (end >= max_length) { - *out_adjusted_offset = 0; - if (out_adjusted_length) *out_adjusted_length = 0; - return OutOfRangeErrorBuilder(IREE_LOC) - << "Attempted to access an address outside of the valid buffer " - "range (offset=" - << offset << ", adjusted_length=" << adjusted_length - << ", end=" << end << ", buffer byte_length=" << max_length << ")"; - } - - *out_adjusted_offset = adjusted_offset; - if (out_adjusted_length) *out_adjusted_length = adjusted_length; - return OkStatus(); -} - -Status Buffer::CalculateRange(iree_device_size_t offset, - iree_device_size_t length, - iree_device_size_t* out_adjusted_offset, - iree_device_size_t* out_adjusted_length) const { - return CalculateRange(byte_offset_, byte_length_, offset, length, - out_adjusted_offset, out_adjusted_length); -} - -Status Buffer::CalculateLocalRange(iree_device_size_t max_length, - iree_device_size_t offset, - iree_device_size_t length, - iree_device_size_t* out_adjusted_offset, - iree_device_size_t* out_adjusted_length) { - return CalculateRange(0, max_length, offset, length, out_adjusted_offset, - out_adjusted_length); -} - -Status Buffer::Fill(iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_device_size_t pattern_length) { - // If not host visible we'll need to issue command buffers. - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_WRITE)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - IREE_RETURN_IF_ERROR( - CalculateRange(byte_offset, byte_length, &byte_offset, &byte_length)); - if (pattern_length != 1 && pattern_length != 2 && pattern_length != 4) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Fill patterns must be 1, 2, or 4 bytes"; - } - if ((byte_offset % pattern_length) != 0 || - (byte_length % pattern_length) != 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Attempting to fill a range with " << pattern_length - << " byte values that is not " - "aligned (offset=" - << byte_offset << ", length=" << byte_length << ")"; - } - if (byte_length == 0) { - return OkStatus(); // No-op. - } - const uint32_t kZero = 0; - if (std::memcmp(pattern, &kZero, pattern_length) == 0) { - // We can turn all-zero values into single-byte fills as that can be much - // faster on devices (doing a fill8 vs fill32). - pattern_length = 1; - } - return FillImpl(byte_offset, byte_length, pattern, pattern_length); -} - -Status Buffer::ReadData(iree_device_size_t source_offset, void* data, - iree_device_size_t data_length) { - // If not host visible we'll need to issue command buffers. - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_READ)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - IREE_RETURN_IF_ERROR( - CalculateRange(source_offset, data_length, &source_offset)); - if (data_length == 0) { - return OkStatus(); // No-op. - } - return ReadDataImpl(source_offset, data, data_length); -} - -Status Buffer::WriteData(iree_device_size_t target_offset, const void* data, - iree_device_size_t data_length) { - // If not host visible we'll need to issue command buffers. - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_WRITE)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - IREE_RETURN_IF_ERROR( - CalculateRange(target_offset, data_length, &target_offset)); - if (data_length == 0) { - return OkStatus(); // No-op. - } - return WriteDataImpl(target_offset, data, data_length); -} - -Status Buffer::CopyData(iree_device_size_t target_offset, Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length) { - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_WRITE)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - IREE_RETURN_IF_ERROR(source_buffer->ValidateCompatibleMemoryType( - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR( - source_buffer->ValidateAccess(IREE_HAL_MEMORY_ACCESS_READ)); - IREE_RETURN_IF_ERROR( - source_buffer->ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - - // We need to validate both buffers. - iree_device_size_t source_data_length = data_length; - iree_device_size_t target_data_length = data_length; - iree_device_size_t adjusted_source_offset; - IREE_RETURN_IF_ERROR(source_buffer->CalculateRange( - source_offset, source_data_length, &adjusted_source_offset, - &source_data_length)); - IREE_RETURN_IF_ERROR(CalculateRange(target_offset, target_data_length, - &target_offset, &target_data_length)); - iree_device_size_t adjusted_data_length; - if (data_length == IREE_WHOLE_BUFFER) { - // Whole buffer copy requested - that could mean either, so take the min. - adjusted_data_length = std::min(source_data_length, target_data_length); - } else { - // Specific length requested - validate that we have matching lengths. - IREE_CHECK_EQ(source_data_length, target_data_length); - adjusted_data_length = source_data_length; - } - - // Elide zero length copies. - if (adjusted_data_length == 0) { - return OkStatus(); - } - - // Check for overlap. - if (this == source_buffer && - adjusted_source_offset <= target_offset + adjusted_data_length && - target_offset <= adjusted_source_offset + adjusted_data_length) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Source and target ranges overlap within the same buffer"; - } - - return CopyDataImpl(target_offset, source_buffer, source_offset, - adjusted_data_length); -} - -Status Buffer::MapMemory(MappingMode mapping_mode, - iree_hal_memory_access_t memory_access, - iree_device_size_t* byte_offset, - iree_device_size_t* byte_length, void** out_data) { - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR(ValidateAccess(memory_access)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - IREE_RETURN_IF_ERROR( - CalculateRange(*byte_offset, *byte_length, byte_offset, byte_length)); - *out_data = nullptr; - return MapMemoryImpl(mapping_mode, memory_access, *byte_offset, *byte_length, - out_data); -} - -Status Buffer::UnmapMemory(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, void* data) { - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return UnmapMemoryImpl(local_byte_offset, local_byte_length, data); -} - -Status Buffer::InvalidateMappedMemory(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) { - IREE_RETURN_IF_ERROR( - ValidateCompatibleMemoryType(IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)); - if (iree_any_bit_set(memory_type_, IREE_HAL_MEMORY_TYPE_HOST_COHERENT)) { - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer memory type is coherent and invalidation is not required"; - } - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return InvalidateMappedMemoryImpl(local_byte_offset, local_byte_length); -} - -Status Buffer::FlushMappedMemory(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) { - IREE_RETURN_IF_ERROR(ValidateCompatibleMemoryType( - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED)); - IREE_RETURN_IF_ERROR(ValidateUsage(IREE_HAL_BUFFER_USAGE_MAPPING)); - // NOTE: local_byte_offset/local_byte_length are already adjusted. - return FlushMappedMemoryImpl(local_byte_offset, local_byte_length); -} - -} // namespace hal -} // namespace iree
diff --git a/iree/hal/cc/buffer.h b/iree/hal/cc/buffer.h deleted file mode 100644 index 25b7636..0000000 --- a/iree/hal/cc/buffer.h +++ /dev/null
@@ -1,772 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Allocated memory buffer wrapper type and utilities. -// -// Buffers are the basic unit of memory used by the inference system. They may -// be allocated such that they are accessible from the host (normal C++ code -// running on the main CPU), a particular device (such as an accelerator) or -// family of devices, or from some mix of all of those. -// -// The type of memory a buffer is allocated within has implications on it's -// performance and lifetime. For example if an application attempts to use a -// host-allocated buffer (IREE_HAL_MEMORY_TYPE_HOST_LOCAL) on an accelerator -// with discrete memory the accelerator may either be unable to access the -// memory or take a non-trivial performance hit when attempting to do so -// (involving setting up kernel mappings, doing DMA transfers, etc). Likewise, -// trying to access a device-allocated buffer -// (IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL) may incur similar overhead or not be -// possible at all. This may be due to restrictions in the memory visibility, -// address spaces, mixed endianness or pointer widths, and other weirdness. -// -// The memory types (defined by a bitfield of MemoryType values) that a -// particular context (host or device) may use vary from device to device and -// must be queried by the application when allocating buffers. It's strongly -// recommended that the most specific memory type be set as possible. For -// example allocating a buffer with IREE_HAL_MEMORY_TYPE_HOST_COHERENT even when -// it will never be used in a way that requires coherency may occupy address -// space reservations or memory mapping that would otherwise not be needed. -// -// As buffers may sometimes not be accessible from the host the base Buffer type -// does not allow for direct void* access and instead buffers must be either -// manipulated using utility functions (such as ReadData or WriteData) or by -// mapping them into a host-accessible address space via MapMemory. Buffer must -// be unmapped before any command may use it. -// -// Buffers may map (roughly) 1:1 with an allocation either from the host heap or -// a device. Buffer::Subspan can be used to reference subspans of buffers like -// absl::Span - though unlike absl::Span the returned Buffer holds a reference -// to the parent buffer. - -#ifndef IREE_HAL_CC_BUFFER_H_ -#define IREE_HAL_CC_BUFFER_H_ - -#include <cstddef> -#include <cstdint> -#include <memory> -#include <string> -#include <utility> - -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/status.h" -#include "iree/hal/api.h" -#include "iree/hal/cc/resource.h" - -namespace iree { -namespace hal { - -class Allocator; -template <typename T> -class MappedMemory; - -std::string MemoryTypeString(iree_hal_memory_type_t memory_type); -std::string MemoryAccessString(iree_hal_memory_access_t memory_access); -std::string BufferUsageString(iree_hal_buffer_usage_t buffer_usage); - -// A memory buffer. -// Buffers have a specific memory_type that is used to describe the capabilities -// and behavior of the backing memory of the buffer. Buffers may be any mix of -// host-accessible, host-coherent, or device-accessible for various usages. -// Depending on these memory types the buffers may be mapped for access on the -// host as memory though certain restrictions may be imposed. -// -// See MemoryType for more information about the types and what operations they -// support. -class Buffer : public Resource { - public: - // Returns a reference to a subspan of the buffer. - // If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer - // after |byte_offset| (possibly 0) will be selected. - // - // The parent buffer will remain alive for the lifetime of the subspan - // returned. If the subspan is a small portion this may cause additional - // memory to remain allocated longer than required. - // - // Returns the given |buffer| if the requested span covers the entire range. - static StatusOr<ref_ptr<Buffer>> Subspan(const ref_ptr<Buffer>& buffer, - iree_device_size_t byte_offset, - iree_device_size_t byte_length); - - // Overlap test results. - enum class Overlap { - // No overlap between the two buffers. - kDisjoint, - // Partial overlap between the two buffers. - kPartial, - // Complete overlap between the two buffers (they are the same). - kComplete, - }; - - // Tests whether the given buffers overlap, including support for subspans. - // IREE_WHOLE_BUFFER may be used for |lhs_length| and/or |rhs_length| to use - // the lengths of those buffers, respectively. - static Overlap TestOverlap(Buffer* lhs_buffer, iree_device_size_t lhs_offset, - iree_device_size_t lhs_length, Buffer* rhs_buffer, - iree_device_size_t rhs_offset, - iree_device_size_t rhs_length); - - // Returns true if the two buffer ranges overlap at all. - static bool DoesOverlap(Buffer* lhs_buffer, iree_device_size_t lhs_offset, - iree_device_size_t lhs_length, Buffer* rhs_buffer, - iree_device_size_t rhs_offset, - iree_device_size_t rhs_length); - - // Disallow copies (as copying requires real work). - Buffer(const Buffer&) = delete; - Buffer& operator=(const Buffer&) = delete; - - ~Buffer() override = default; - - absl::string_view debug_name() const { return ""; } - void set_debug_name(std::string debug_name) {} - - // Memory allocator this buffer was allocated from. - // May be nullptr if the buffer has no particular allocator and should be - // assumed to be allocated from the host heap. - constexpr Allocator* allocator() const { - return allocated_buffer_ == this ? allocator_ - : allocated_buffer_->allocator(); - } - - // Memory type this buffer is allocated from. - iree_hal_memory_type_t memory_type() const { return memory_type_; } - - // Memory access operations allowed on the buffer. - iree_hal_memory_access_t allowed_access() const { return allowed_access_; } - - // Bitfield describing how the buffer is to be used. - iree_hal_buffer_usage_t usage() const { return usage_; } - - // Returns the underlying buffer that represents the allocated memory for the - // Buffer. In most cases this is the buffer itself but for buffer subspan - // references it will point to the parent buffer. - Buffer* allocated_buffer() const noexcept; - - // Size of the resource memory allocation in bytes. - // This may be rounded up from the originally requested size or the ideal - // size for the resource based on device restrictions. - constexpr iree_device_size_t allocation_size() const { - return allocated_buffer_ == this ? allocation_size_ - : allocated_buffer_->allocation_size(); - } - - // Range within the underlying allocation this buffer occupies. - // For buffers that map 1:1 with an allocation this should be - // [0, allocation_size()), however may still differ if the allocation needed - // to be aligned. - // - // The offset is most often manipulated by Subspan, however it's important to - // note that the offset may not be what was passed to Subspan as it refers to - // the offset in the original ancestor buffer, not the buffer from which the - // subspan was taken. - constexpr iree_device_size_t byte_offset() const noexcept { - return byte_offset_; - } - constexpr iree_device_size_t byte_length() const noexcept { - return byte_length_; - } - - // TODO(benvanik): add debug_name. - - // Returns a longer debug string describing the buffer and its attributes. - std::string DebugString() const override; - // Returns a short debug string describing the buffer. - std::string DebugStringShort() const override; - - // Sets a range of the buffer to the given value. - // This requires that the resource was allocated with - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE and IREE_HAL_BUFFER_USAGE_MAPPING. - // If |byte_length| is IREE_WHOLE_BUFFER the remaining bytes in the buffer - // after |byte_offset| (possibly 0) will be filled. - // - // The |byte_offset| and |byte_length| must be aligned to the size of the fill - // value. Multi-byte values will be written in host order for host buffers and - // device order for device buffers. - // - // Only |pattern_length| values with 1, 2, or 4 bytes are supported. - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status Fill(iree_device_size_t byte_offset, iree_device_size_t byte_length, - const void* pattern, iree_device_size_t pattern_length); - template <typename T> - Status Fill8(iree_device_size_t byte_offset, iree_device_size_t byte_length, - T value); - template <typename T> - Status Fill16(iree_device_size_t byte_offset, iree_device_size_t byte_length, - T value); - template <typename T> - Status Fill32(iree_device_size_t byte_offset, iree_device_size_t byte_length, - T value); - template <typename T> - Status Fill8(T value); - template <typename T> - Status Fill16(T value); - template <typename T> - Status Fill32(T value); - - // Reads a block of byte data from the resource at the given offset. - // This requires that the resource was allocated with - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE and IREE_HAL_BUFFER_USAGE_MAPPING. - // - // Fails if the read could not be performed; either the bounds are out of - // range or the memory type does not support reading in this way. - Status ReadData(iree_device_size_t source_offset, void* data, - iree_device_size_t data_length); - - // Writes a block of byte data into the resource at the given offset. - // This requires that the resource was allocated with - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE and IREE_HAL_BUFFER_USAGE_MAPPING. - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status WriteData(iree_device_size_t target_offset, const void* data, - iree_device_size_t data_length); - - // Copies data from the provided source_buffer into the buffer. - // This requires that the resource was allocated with - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE and IREE_HAL_BUFFER_USAGE_MAPPING. - // The source and destination may be the same buffer but the ranges must not - // overlap (a la memcpy). - // - // Fails if the write could not be performed; either the bounds are out of - // range or the memory type does not support writing in this way. - Status CopyData(iree_device_size_t target_offset, Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length); - Status CopyData(iree_device_size_t target_offset, Buffer* source_buffer) { - return CopyData(target_offset, source_buffer, 0, IREE_WHOLE_BUFFER); - } - - // Maps the resource memory for direct access from the host. - // This requires that the resource was allocated with - // IREE_HAL_MEMORY_TYPE_HOST_VISIBLE and IREE_HAL_BUFFER_USAGE_MAPPING. - // - // If IREE_HAL_MEMORY_TYPE_HOST_COHERENT was not specified then explicit - // Invalidate and Flush calls must be used to control visibility of the data - // on the device. If IREE_HAL_MEMORY_TYPE_HOST_CACHED is not set callers must - // not attempt to read from the mapped memory as doing so may produce - // undefined results and/or ultra slow reads. - // - // If the IREE_HAL_MEMORY_ACCESS_DISCARD bit is set when mapping for writes - // the caller guarantees that they will be overwriting all data in the mapped - // range. This is used as a hint to the device that the prior contents are no - // longer required and can enable optimizations that save on synchronization - // and readback. Note however that it is strictly a hint and the contents are - // not guaranteed to be zeroed during mapping. - // - // This allows mapping the memory as a C++ type. Care must be taken to ensure - // the data layout in C++ matches the expected data layout in the executables - // that consume this data. For simple primitives like uint8_t or float this is - // usually not a problem however struct packing may have many restrictions. - // - // The returned mapping should be unmapped when it is no longer required. - // Unmapping does not implicitly flush. - // - // Fails if the memory could not be mapped due to mapping exhaustion, invalid - // arguments, or unsupported memory types. - // - // Example: - // IREE_ASSIGN_OR_RETURN(auto mapping, buffer->MapForRead<MyStruct>()); - // mapping[5].foo = 3; - // std::memcpy(mapping.data(), source_data, mapping.size()); - // mapping.reset(); - template <typename T> - StatusOr<MappedMemory<T>> MapMemory( - iree_hal_memory_access_t memory_access, - iree_device_size_t element_offset = 0, - iree_device_size_t element_length = IREE_WHOLE_BUFFER); - - protected: - template <typename T> - friend class MappedMemory; - - // Defines the mode of a MapMemory operation. - enum class MappingMode { - // The call to MapMemory will always be matched with UnmapMemory. - kScoped, - }; - - Buffer(Allocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t usage, - iree_device_size_t allocation_size, iree_device_size_t byte_offset, - iree_device_size_t byte_length); - - // Allows subclasses to override the allowed access bits. - // This should only be done when known safe by the allocation scheme. - void set_allowed_access(iree_hal_memory_access_t allowed_access) { - allowed_access_ = allowed_access; - } - - // Sets a range of the buffer to the given value. - // State and parameters have already been validated. For the >8bit variants - // the offset and length have already been validated to be aligned to the - // natural alignment of the type. - virtual Status FillImpl(iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_device_size_t pattern_length) = 0; - - // Reads a block of byte data from the resource at the given offset. - // State and parameters have already been validated. - virtual Status ReadDataImpl(iree_device_size_t source_offset, void* data, - iree_device_size_t data_length) = 0; - - // Writes a block of byte data into the resource at the given offset. - // State and parameters have already been validated. - virtual Status WriteDataImpl(iree_device_size_t target_offset, - const void* data, - iree_device_size_t data_length) = 0; - - // Copies a block of byte data into the resource at the given offset. - // State and parameters have already been validated. - virtual Status CopyDataImpl(iree_device_size_t target_offset, - Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length) = 0; - - // Maps memory directly. - // The output data pointer will be properly aligned to the start of the data. - // |local_byte_offset| and |local_byte_length| are the adjusted values that - // should map into the local space of the buffer. - // - // Fails if the memory could not be mapped (invalid access type, invalid - // range, or unsupported memory type). - // State and parameters have already been validated. - virtual Status MapMemoryImpl(MappingMode mapping_mode, - iree_hal_memory_access_t memory_access, - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void** out_data) = 0; - - // Unmaps previously mapped memory. - // No-op if the memory is not mapped. As this is often used in destructors - // we can't rely on failures here propagating with anything but - // IREE_CHECK/IREE_DCHECK. State and parameters have already been validated. - virtual Status UnmapMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void* data) = 0; - - // Invalidates ranges of non-coherent memory from the host caches. - // Use this before reading from non-coherent memory. - // This guarantees that device writes to the memory ranges provided are - // visible on the host. - // This is only required for memory types without kHostCoherent set. - // State and parameters have already been validated. - virtual Status InvalidateMappedMemoryImpl( - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) = 0; - - // Flushes ranges of non-coherent memory from the host caches. - // Use this after writing to non-coherent memory. - // This guarantees that host writes to the memory ranges provided are made - // available for device access. - // This is only required for memory types without kHostCoherent set. - // State and parameters have already been validated. - virtual Status FlushMappedMemoryImpl( - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) = 0; - - // Validates the given buffer range and adjusts the offset and length if the - // provided length is IREE_WHOLE_BUFFER or the buffer is offset within its - // allocation. This calculates the range in the given domain without adjusting - // to any particular buffer base offsets. - static Status CalculateLocalRange(iree_device_size_t max_length, - iree_device_size_t offset, - iree_device_size_t length, - iree_device_size_t* out_adjusted_offset, - iree_device_size_t* out_adjusted_length); - - private: - friend class Allocator; - - // This is not great and deserves cleanup. - friend class DeferredBuffer; - friend class SubspanBuffer; - friend class HeapBuffer; - - // Maps memory directly. - // The byte offset and byte length may be adjusted for device alignment. - // The output data pointer will be properly aligned to the start of the data. - // Fails if the memory could not be mapped (invalid access type, invalid - // range, or unsupported memory type). - Status MapMemory(MappingMode mapping_mode, - iree_hal_memory_access_t memory_access, - iree_device_size_t* byte_offset, - iree_device_size_t* byte_length, void** out_data); - - // Unmaps previously mapped memory. - // No-op if the memory is not mapped. As this is often used in destructors - // we can't rely on failures here propagating with anything but - // IREE_CHECK/IREE_DCHECK. - Status UnmapMemory(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, void* data); - - // Invalidates ranges of non-coherent memory from the host caches. - // Use this before reading from non-coherent memory. - // This guarantees that device writes to the memory ranges provided are - // visible on the host. - // This is only required for memory types without kHostCoherent set. - Status InvalidateMappedMemory(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length); - - // Flushes ranges of non-coherent memory from the host caches. - // Use this after writing to non-coherent memory. - // This guarantees that host writes to the memory ranges provided are made - // available for device access. - // This is only required for memory types without kHostCoherent set. - Status FlushMappedMemory(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length); - - // Returns a failure if the memory type the buffer was allocated from is not - // compatible with the given type. - Status ValidateCompatibleMemoryType(iree_hal_memory_type_t memory_type) const; - // Returns a failure if the buffer memory type or usage disallows the given - // access type. - Status ValidateAccess(iree_hal_memory_access_t memory_access) const; - // Returns a failure if the buffer was not allocated for the given usage. - Status ValidateUsage(iree_hal_buffer_usage_t usage) const; - // Validates the given buffer range and optionally adjusts the offset and - // length if the provided length is IREE_WHOLE_BUFFER or the buffer is offset - // within its allocation. - static Status CalculateRange( - iree_device_size_t base_offset, iree_device_size_t max_length, - iree_device_size_t offset, iree_device_size_t length, - iree_device_size_t* out_adjusted_offset, - iree_device_size_t* out_adjusted_length = nullptr); - Status CalculateRange( - iree_device_size_t offset, iree_device_size_t length, - iree_device_size_t* out_adjusted_offset, - iree_device_size_t* out_adjusted_length = nullptr) const; - - // Points to either this or parent_buffer_.get(). - Buffer* allocated_buffer_ = nullptr; - - Allocator* allocator_ = nullptr; - iree_hal_memory_type_t memory_type_ = IREE_HAL_MEMORY_TYPE_NONE; - iree_hal_memory_access_t allowed_access_ = IREE_HAL_MEMORY_ACCESS_NONE; - iree_hal_buffer_usage_t usage_ = IREE_HAL_BUFFER_USAGE_NONE; - - iree_device_size_t allocation_size_ = 0; - iree_device_size_t byte_offset_ = 0; - iree_device_size_t byte_length_ = 0; - - // Defined when this buffer is a subspan of another buffer. - ref_ptr<Buffer> parent_buffer_; -}; - -// A memory mapping RAII object. -// The mapping will stay active until it is reset and will retain the buffer. -template <typename T> -class MappedMemory { - public: - using unspecified_bool_type = const T* MappedMemory<T>::*; - - MappedMemory() = default; - MappedMemory(iree_hal_memory_access_t access, ref_ptr<Buffer> buffer, - iree_device_size_t byte_offset, iree_device_size_t byte_length, - iree_device_size_t element_size, T* data); - - // Allow moving but disallow copying as the mapping is stateful. - MappedMemory(MappedMemory&& rhs) noexcept; - MappedMemory& operator=(MappedMemory&& rhs) noexcept; - MappedMemory(const MappedMemory&) = delete; - MappedMemory& operator=(const MappedMemory&) = delete; - - ~MappedMemory(); - - // The buffer resource that this mapping references. - const ref_ptr<Buffer>& buffer() const noexcept { return buffer_; } - // Offset, in bytes, into the resource allocation. - // This value is *informative only*, as it may vary from device to device. - iree_device_size_t byte_offset() const noexcept { return byte_offset_; } - // Length, in bytes, of the resource mapping. - // This may be larger than the originally requested length due to alignment. - // This value is *informative only*, as it may vary from device to device. - iree_device_size_t byte_length() const noexcept { return byte_length_; } - - // True if the mapping is empty. - bool empty() const noexcept { return element_size_ == 0; } - // The size of the mapping as requested in elements. - size_t size() const noexcept { return static_cast<size_t>(element_size_); } - - // Returns a read-only pointer to the mapped memory. - // This will be nullptr if the mapping failed or the mapping is not readable. - const T* data() const noexcept; - absl::Span<const T> contents() const noexcept { return {data(), size()}; } - - // Returns a mutable pointer to the mapped memory. - // This will be nullptr if the mapping failed or the mapping is not writable. - // If the mapping was not made with read access it may still be possible to - // read from this memory but behavior is undefined. - T* mutable_data() noexcept; - absl::Span<T> mutable_contents() noexcept { return {mutable_data(), size()}; } - - // Returns a raw pointer to the mapped data without any access checks. - T* unsafe_data() const noexcept { return data_; } - - // Equivalent to absl::Span::subspan(). - // May return a 0-length span. - // Fails if the buffer is not mapped or not mapped for the requested access. - StatusOr<absl::Span<const T>> Subspan( - iree_device_size_t element_offset = 0, - iree_device_size_t element_length = IREE_WHOLE_BUFFER) const noexcept; - StatusOr<absl::Span<T>> MutableSubspan( - iree_device_size_t element_offset = 0, - iree_device_size_t element_length = IREE_WHOLE_BUFFER) noexcept; - - // Accesses an element in the mapped memory. - // Must be called with a valid index in [0, size()). - const T& operator[](iree_device_size_t i) const noexcept { return data_[i]; } - - // Invalidates a range of non-coherent elements from the host caches. - Status Invalidate( - iree_device_size_t element_offset = 0, - iree_device_size_t element_length = IREE_WHOLE_BUFFER) const; - - // Flushes a range of non-coherent elements from the host caches. - Status Flush(iree_device_size_t element_offset = 0, - iree_device_size_t element_length = IREE_WHOLE_BUFFER); - - // Unmaps the mapped memory. - // The memory will not be implicitly flushed when unmapping. - void reset(); - - private: - Status ValidateAccess(iree_hal_memory_access_t memory_access) const; - Status CalculateDataRange( - iree_device_size_t element_offset, iree_device_size_t element_length, - iree_device_size_t* out_adjusted_element_offset, - iree_device_size_t* out_adjusted_element_length) const; - - iree_hal_memory_access_t access_ = IREE_HAL_MEMORY_ACCESS_NONE; - ref_ptr<Buffer> buffer_; - iree_device_size_t byte_offset_ = 0; - iree_device_size_t byte_length_ = 0; - iree_device_size_t element_size_ = 0; - T* data_ = nullptr; -}; - -// Inline functions and template definitions follow: - -template <typename T> -Status Buffer::Fill8(iree_device_size_t byte_offset, - iree_device_size_t byte_length, T value) { - auto sized_value = reinterpret_cast<uint8_t*>(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template <typename T> -Status Buffer::Fill16(iree_device_size_t byte_offset, - iree_device_size_t byte_length, T value) { - auto sized_value = reinterpret_cast<uint16_t*>(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template <typename T> -Status Buffer::Fill32(iree_device_size_t byte_offset, - iree_device_size_t byte_length, T value) { - auto sized_value = reinterpret_cast<uint32_t*>(&value); - return Fill(byte_offset, byte_length, sized_value, sizeof(*sized_value)); -} - -template <typename T> -Status Buffer::Fill8(T value) { - return Fill8(0, IREE_WHOLE_BUFFER, value); -} - -template <typename T> -Status Buffer::Fill16(T value) { - return Fill16(0, IREE_WHOLE_BUFFER, value); -} - -template <typename T> -Status Buffer::Fill32(T value) { - return Fill32(0, IREE_WHOLE_BUFFER, value); -} - -template <typename T> -StatusOr<MappedMemory<T>> Buffer::MapMemory( - iree_hal_memory_access_t memory_access, iree_device_size_t element_offset, - iree_device_size_t element_length) { - iree_device_size_t byte_offset = element_offset * sizeof(T); - iree_device_size_t byte_length = element_length == IREE_WHOLE_BUFFER - ? IREE_WHOLE_BUFFER - : element_length * sizeof(T); - void* data = nullptr; - IREE_RETURN_IF_ERROR(MapMemory(MappingMode::kScoped, memory_access, - &byte_offset, &byte_length, &data)); - return MappedMemory<T>{ - memory_access, add_ref(this), byte_offset, - byte_length, byte_length / sizeof(T), static_cast<T*>(data)}; -} - -template <typename T> -MappedMemory<T>::MappedMemory(iree_hal_memory_access_t access, - ref_ptr<Buffer> buffer, - iree_device_size_t byte_offset, - iree_device_size_t byte_length, - iree_device_size_t element_size, T* data) - : access_(access), - buffer_(std::move(buffer)), - byte_offset_(byte_offset), - byte_length_(byte_length), - element_size_(element_size), - data_(data) {} - -template <typename T> -MappedMemory<T>::MappedMemory(MappedMemory<T>&& rhs) noexcept - : access_(rhs.access_), - buffer_(std::move(rhs.buffer_)), - byte_offset_(rhs.byte_offset_), - byte_length_(rhs.byte_length_), - element_size_(rhs.element_size_), - data_(rhs.data_) { - rhs.access_ = IREE_HAL_MEMORY_ACCESS_NONE; - rhs.buffer_.reset(); - rhs.byte_offset_ = 0; - rhs.byte_length_ = 0; - rhs.element_size_ = 0; - rhs.data_ = nullptr; -} - -template <typename T> -MappedMemory<T>& MappedMemory<T>::operator=(MappedMemory<T>&& rhs) noexcept { - if (this != &rhs) { - reset(); - access_ = rhs.access_; - buffer_ = std::move(rhs.buffer_); - byte_offset_ = rhs.byte_offset_; - byte_length_ = rhs.byte_length_; - element_size_ = rhs.element_size_; - data_ = rhs.data_; - - rhs.access_ = IREE_HAL_MEMORY_ACCESS_NONE; - rhs.buffer_.reset(); - rhs.byte_offset_ = 0; - rhs.byte_length_ = 0; - rhs.element_size_ = 0; - rhs.data_ = nullptr; - } - return *this; -} - -template <typename T> -MappedMemory<T>::~MappedMemory() { - // Unmap (if needed) - note that we can't fail gracefully here :( - reset(); -} - -template <typename T> -const T* MappedMemory<T>::data() const noexcept { - if (!data_ || !iree_any_bit_set(access_, IREE_HAL_MEMORY_ACCESS_READ)) { - return nullptr; - } - return data_; -} - -template <typename T> -T* MappedMemory<T>::mutable_data() noexcept { - if (!data_ || !iree_any_bit_set(access_, IREE_HAL_MEMORY_ACCESS_WRITE)) { - return nullptr; - } - return data_; -} - -template <typename T> -Status MappedMemory<T>::ValidateAccess( - iree_hal_memory_access_t memory_access) const { - if (!data_) { - return FailedPreconditionErrorBuilder(IREE_LOC) << "Buffer is not mapped"; - } else if (!iree_any_bit_set(access_, memory_access)) { - return PermissionDeniedErrorBuilder(IREE_LOC) - << "Buffer is not mapped for the desired access"; - } - return OkStatus(); -} - -template <typename T> -Status MappedMemory<T>::CalculateDataRange( - iree_device_size_t element_offset, iree_device_size_t element_length, - iree_device_size_t* out_adjusted_element_offset, - iree_device_size_t* out_adjusted_element_length) const { - IREE_RETURN_IF_ERROR(Buffer::CalculateLocalRange( - element_size_ * sizeof(T), element_offset * sizeof(T), - element_length == IREE_WHOLE_BUFFER ? IREE_WHOLE_BUFFER - : element_length * sizeof(T), - out_adjusted_element_offset, out_adjusted_element_length)); - *out_adjusted_element_offset /= sizeof(T); - *out_adjusted_element_length /= sizeof(T); - return OkStatus(); -} - -template <typename T> -inline StatusOr<absl::Span<const T>> MappedMemory<T>::Subspan( - iree_device_size_t element_offset, - iree_device_size_t element_length) const noexcept { - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_READ)); - IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - return absl::Span<const T>(data_ + element_offset, element_length); -} - -template <typename T> -inline StatusOr<absl::Span<T>> MappedMemory<T>::MutableSubspan( - iree_device_size_t element_offset, - iree_device_size_t element_length) noexcept { - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_WRITE)); - IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - return absl::Span<T>(data_ + element_offset, element_length); -} - -template <typename T> -Status MappedMemory<T>::Invalidate(iree_device_size_t element_offset, - iree_device_size_t element_length) const { - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_READ)); - IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - if (!element_length) return OkStatus(); - return buffer_->InvalidateMappedMemory( - byte_offset_ + element_offset * sizeof(T), element_length * sizeof(T)); -} - -template <typename T> -Status MappedMemory<T>::Flush(iree_device_size_t element_offset, - iree_device_size_t element_length) { - IREE_RETURN_IF_ERROR(ValidateAccess(IREE_HAL_MEMORY_ACCESS_WRITE)); - IREE_RETURN_IF_ERROR(CalculateDataRange(element_offset, element_length, - &element_offset, &element_length)); - if (!element_length) return OkStatus(); - return buffer_->FlushMappedMemory(byte_offset_ + element_offset * sizeof(T), - element_length * sizeof(T)); -} - -template <typename T> -void MappedMemory<T>::reset() { - if (!buffer_) return; - // TODO(benvanik): better handling of errors? may be fine to always warn. - buffer_->UnmapMemory(byte_offset_, byte_length_, data_).IgnoreError(); - buffer_.reset(); - access_ = IREE_HAL_MEMORY_ACCESS_NONE; - byte_offset_ = 0; - byte_length_ = 0; - element_size_ = 0; - data_ = nullptr; -} - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_BUFFER_H_
diff --git a/iree/hal/cc/buffer_mapping_test.cc b/iree/hal/cc/buffer_mapping_test.cc deleted file mode 100644 index 41bc001..0000000 --- a/iree/hal/cc/buffer_mapping_test.cc +++ /dev/null
@@ -1,560 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for the MemoryMapping RAII wrapper. -// This uses a mock buffer implementation such that it is only testing -// MemoryMapping and not any real underlying memory mapping behavior. - -#include <cstdint> -#include <memory> -#include <utility> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/hal/cc/buffer.h" -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -namespace iree { -namespace hal { -class Allocator; - -namespace { - -using ::testing::_; -using ::testing::DoAll; -using ::testing::Return; -using ::testing::SetArgPointee; - -static void* const kValidPtr = reinterpret_cast<void*>(0xBEEFCAFEF00D1234ull); - -class MockBuffer : public Buffer { - public: - using MappingMode = Buffer::MappingMode; - - MockBuffer(Allocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, iree_device_size_t allocation_size) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, - 0, allocation_size) {} - - MOCK_METHOD(Status, FillImpl, - (iree_device_size_t byte_offset, iree_device_size_t byte_length, - const void* pattern, iree_device_size_t pattern_length), - (override)); - - MOCK_METHOD(Status, ReadDataImpl, - (iree_device_size_t source_offset, void* data, - iree_device_size_t data_length), - (override)); - MOCK_METHOD(Status, WriteDataImpl, - (iree_device_size_t target_offset, const void* data, - iree_device_size_t data_length), - (override)); - MOCK_METHOD(Status, CopyDataImpl, - (iree_device_size_t target_offset, Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length), - (override)); - - MOCK_METHOD(Status, MapMemoryImpl, - (MappingMode mapping_mode, iree_hal_memory_access_t memory_access, - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, void** out_data), - (override)); - MOCK_METHOD(Status, UnmapMemoryImpl, - (iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, void* data), - (override)); - MOCK_METHOD(Status, InvalidateMappedMemoryImpl, - (iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length), - (override)); - MOCK_METHOD(Status, FlushMappedMemoryImpl, - (iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length), - (override)); -}; - -TEST(MemoryMappingTest, MapWholeBuffer) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mapping.reset(); -} - -TEST(MemoryMappingTest, MapPartialBuffer) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 4, 12, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ, 4, 12)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(4, 12, kValidPtr)) - .WillOnce(Return(OkStatus())); - mapping.reset(); -} - -TEST(MemoryMappingTest, EmptyHandle) { - MappedMemory<uint8_t> mm_a; - MappedMemory<uint8_t> mm_b; - mm_a = std::move(mm_b); - EXPECT_EQ(nullptr, mm_a.buffer()); - EXPECT_EQ(0, mm_a.byte_offset()); - EXPECT_EQ(0, mm_a.byte_length()); - EXPECT_TRUE(mm_a.empty()); - EXPECT_EQ(0, mm_a.size()); - EXPECT_EQ(nullptr, mm_a.data()); - EXPECT_EQ(nullptr, mm_a.mutable_data()); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Subspan().status())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.MutableSubspan().status())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Invalidate())); - EXPECT_TRUE(IsFailedPrecondition(mm_a.Flush())); - mm_a.reset(); -} - -TEST(MemoryMappingTest, MoveHandle) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_a, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Should be able to move the handle around without having any calls. - auto mm_b = std::move(mm_a); - mm_a = std::move(mm_b); - mm_b = std::move(mm_a); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_b.reset(); -} - -TEST(MemoryMappingTest, ReadOnlyAccess) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_READ, - IREE_HAL_BUFFER_USAGE_ALL, 128); - - // Should succeed to map for reading. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Non-mutable access is fine. - EXPECT_EQ(kValidPtr, mm_r.data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mm_r.Subspan()); - (void)span; - - // Read-only mappings should not be able to get mutable access. - EXPECT_EQ(nullptr, mm_r.mutable_data()); - EXPECT_TRUE(IsPermissionDenied(mm_r.MutableSubspan().status())); - - // Read-only mappings should not be able to call Flush. - EXPECT_TRUE(IsPermissionDenied(mm_r.Flush())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); - - // Should fail to map for writing. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE).status())); -} - -TEST(MemoryMappingTest, ReadWriteAccess) { - auto buffer = make_ref<MockBuffer>( - nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE, - IREE_HAL_BUFFER_USAGE_ALL, 128); - - // Should succeed to map for reading and/or writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ | - IREE_HAL_MEMORY_ACCESS_WRITE, - 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_rw, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ | - IREE_HAL_MEMORY_ACCESS_WRITE)); - - // Everything valid. - EXPECT_EQ(kValidPtr, mm_rw.data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mm_rw.Subspan()); - EXPECT_EQ(kValidPtr, mm_rw.mutable_data()); - IREE_ASSERT_OK_AND_ASSIGN(span, mm_rw.MutableSubspan()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_rw.reset(); - - // Should fail to map for discard. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) - .status())); -} - -TEST(MemoryMappingTest, WriteOnlyAccess) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_WRITE, - IREE_HAL_BUFFER_USAGE_ALL, 128); - - // Should succeed to map for writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE)); - - // Mutable access is valid. - EXPECT_EQ(kValidPtr, mm_w.mutable_data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mm_w.MutableSubspan()); - (void)span; - - // Write-only mappings should not be able to get non-mutable access. - EXPECT_EQ(nullptr, mm_w.data()); - EXPECT_TRUE(IsPermissionDenied(mm_w.Subspan().status())); - - // Write-only mappings should not be able to call Invalidate. - EXPECT_TRUE(IsPermissionDenied(mm_w.Invalidate())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); - - // Should fail to map for reading. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ).status())); - - // Should fail to map for discard. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) - .status())); -} - -TEST(MemoryMappingTest, WriteDiscardAccess) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, - IREE_HAL_BUFFER_USAGE_ALL, 128); - - // Should succeed to map for writing with discard. - EXPECT_CALL(*buffer, - MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_dw, - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_dw.reset(); - - // Should also be ok to map for just writing. - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE)); - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); - - // Should fail to map for reading. - EXPECT_TRUE(IsPermissionDenied( - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ).status())); -} - -TEST(MemoryMappingTest, Subspan) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Request some valid ranges and ensure the byte offsets are correct. - IREE_ASSERT_OK_AND_ASSIGN(auto ss, mm_r.Subspan()); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_EQ(128, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, 2)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(2, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(100, IREE_WHOLE_BUFFER)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(28, ss.size()); - - // Zero length ranges are fine. - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(0, 0)); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, 0)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_r.Subspan(128, IREE_WHOLE_BUFFER)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, SubspanOutOfRange) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Try some invalid ranges that would overrun the span. - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 0).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, 2).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(1234, IREE_WHOLE_BUFFER).status())); - EXPECT_TRUE(IsOutOfRange(mm_r.Subspan(100, 1234).status())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, MutableSubspan) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE)); - - // Request some valid ranges and ensure the byte offsets are correct. - IREE_ASSERT_OK_AND_ASSIGN(auto ss, mm_w.MutableSubspan()); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_EQ(128, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, 2)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(2, ss.size()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(100, IREE_WHOLE_BUFFER)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 100, ss.data()); - EXPECT_EQ(28, ss.size()); - - // Zero length ranges are fine. - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(0, 0)); - EXPECT_EQ(kValidPtr, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, 0)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - IREE_ASSERT_OK_AND_ASSIGN(ss, mm_w.MutableSubspan(128, IREE_WHOLE_BUFFER)); - EXPECT_EQ(static_cast<const uint8_t*>(kValidPtr) + 128, ss.data()); - EXPECT_TRUE(ss.empty()); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, MutableSubspanOutOfRange) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE)); - - // Try some invalid ranges that would overrun the span. - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 0).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(1234, 2).status())); - EXPECT_TRUE( - IsOutOfRange(mm_w.MutableSubspan(1234, IREE_WHOLE_BUFFER).status())); - EXPECT_TRUE(IsOutOfRange(mm_w.MutableSubspan(100, 1234).status())); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, ElementOperator) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Just verify we are getting the expected pointer back. - EXPECT_EQ(kValidPtr, &mm_r[0]); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, Invalidate) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Invalidate a few ways. - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(0, 128)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_r.Invalidate()); - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 2)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_r.Invalidate(100, 2)); - EXPECT_CALL(*buffer, InvalidateMappedMemoryImpl(100, 28)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_r.Invalidate(100, IREE_WHOLE_BUFFER)); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, InvalidateOutOfRange) { - auto buffer = make_ref<MockBuffer>(nullptr, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, - IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_r, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ)); - - // Try to invalidate invalid ranges. - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 0))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, 12345))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1234, IREE_WHOLE_BUFFER))); - EXPECT_TRUE(IsOutOfRange(mm_r.Invalidate(1, 1234))); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, InvalidateBadMode) { - // Invalidate is not required on coherent memory. - auto coherent_buffer = make_ref<MockBuffer>( - nullptr, IREE_HAL_MEMORY_TYPE_HOST_LOCAL, IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*coherent_buffer, - MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_READ, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_r, coherent_buffer->MapMemory<uint8_t>( - IREE_HAL_MEMORY_ACCESS_READ)); - EXPECT_TRUE(IsPermissionDenied(mm_r.Invalidate())); - EXPECT_CALL(*coherent_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_r.reset(); -} - -TEST(MemoryMappingTest, Flush) { - auto buffer = make_ref<MockBuffer>( - nullptr, - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED, - IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE)); - - // Flush a few ways. - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(0, 128)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_w.Flush()); - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 2)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_w.Flush(100, 2)); - EXPECT_CALL(*buffer, FlushMappedMemoryImpl(100, 28)) - .WillOnce(Return(OkStatus())); - IREE_EXPECT_OK(mm_w.Flush(100, IREE_WHOLE_BUFFER)); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, FlushOutOfRange) { - auto buffer = make_ref<MockBuffer>( - nullptr, - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED, - IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*buffer, MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN( - auto mm_w, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE)); - - // Try to flush invalid ranges. - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 0))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, 12345))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1234, IREE_WHOLE_BUFFER))); - EXPECT_TRUE(IsOutOfRange(mm_w.Flush(1, 1234))); - - EXPECT_CALL(*buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -TEST(MemoryMappingTest, FlushBadMode) { - // Flush is not required on uncached memory. - auto uncached_buffer = make_ref<MockBuffer>( - nullptr, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE, IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_CALL(*uncached_buffer, - MapMemoryImpl(MockBuffer::MappingMode::kScoped, - IREE_HAL_MEMORY_ACCESS_WRITE, 0, 128, _)) - .WillOnce(DoAll(SetArgPointee<4>(kValidPtr), Return(OkStatus()))); - IREE_ASSERT_OK_AND_ASSIGN(auto mm_w, uncached_buffer->MapMemory<uint8_t>( - IREE_HAL_MEMORY_ACCESS_WRITE)); - EXPECT_TRUE(IsPermissionDenied(mm_w.Flush())); - EXPECT_CALL(*uncached_buffer, UnmapMemoryImpl(0, 128, kValidPtr)) - .WillOnce(Return(OkStatus())); - mm_w.reset(); -} - -} // namespace -} // namespace hal -} // namespace iree
diff --git a/iree/hal/cc/buffer_test.cc b/iree/hal/cc/buffer_test.cc deleted file mode 100644 index 61c2c9a..0000000 --- a/iree/hal/cc/buffer_test.cc +++ /dev/null
@@ -1,1048 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Tests for the shared buffer functionality and host heap buffers. -// This does not test device-specific buffer implementations; see the device -// code for associated tests. - -#include "iree/hal/cc/buffer.h" - -#include <vector> - -#include "iree/testing/gtest.h" -#include "iree/testing/status_matchers.h" - -#if 0 // DISABLED: this will have changes in future commits in this branch. - -namespace iree { -namespace hal { -namespace { - -using ::testing::_; -using ::testing::ElementsAre; -using ::testing::Eq; -using ::testing::Not; - -TEST(BufferTest, Allocate) { - auto buffer = HeapBuffer::Allocate( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, 14); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(IREE_HAL_MEMORY_ACCESS_ALL, buffer->allowed_access()); - EXPECT_EQ(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, buffer->memory_type()); - EXPECT_EQ(IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - buffer->usage()); - - // We don't currently do any padding on the host. - // Other implementations may differ. - EXPECT_LE(14, buffer->allocation_size()); - EXPECT_EQ(0, buffer->byte_offset()); - EXPECT_EQ(14, buffer->byte_length()); - - // Data should be zeroed by default. - std::vector<uint8_t> zero_data(buffer->allocation_size()); - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(zero_data)); -} - -TEST(BufferTest, AllocateZeroLength) { - auto buffer = HeapBuffer::Allocate( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, 0); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, buffer->memory_type()); - EXPECT_EQ(IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - buffer->usage()); - EXPECT_EQ(0, buffer->allocation_size()); -} - -TEST(BufferTest, AllocateCopy) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_LE(src_data.size(), buffer->allocation_size()); - - // Data should have been copied. - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data and ensure it is not reflected in the buffer. - src_data[0] = 0x88; - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Not(Eq(src_data))); -} - -TEST(BufferTest, AllocateCopyZeroLength) { - std::vector<uint8_t> src_data; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(0, buffer->allocation_size()); -} - -TEST(BufferTest, AllocateCopyTyped) { - std::vector<int32_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - absl::MakeConstSpan(src_data)); - EXPECT_NE(nullptr, buffer->allocator()); - EXPECT_EQ(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, buffer->memory_type()); - EXPECT_EQ(IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - buffer->usage()); - EXPECT_LE(src_data.size() * sizeof(int32_t), buffer->allocation_size()); - - // Data should have been copied. - std::vector<int32_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), - actual_data.size() * sizeof(int32_t))); - EXPECT_THAT(actual_data, Eq(src_data)); -} - -TEST(BufferTest, WrapConstant) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::Wrap( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - absl::MakeConstSpan(src_data)); - EXPECT_EQ(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, buffer->memory_type()); - EXPECT_EQ(IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - buffer->usage()); - EXPECT_EQ(src_data.size(), buffer->allocation_size()); - - // src_data and buffer should match after the wrapping. - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data directly. - src_data[0] = 123; - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Attempts to modify the buffer should fail. - std::vector<uint8_t> new_data = {3, 2, 1, 0}; - EXPECT_TRUE(IsPermissionDenied( - buffer->WriteData(0, new_data.data(), new_data.size()))); -} - -TEST(BufferTest, WrapMutable) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::WrapMutable( - IREE_HAL_MEMORY_TYPE_HOST_LOCAL, IREE_HAL_MEMORY_ACCESS_ALL, - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - absl::MakeSpan(src_data)); - EXPECT_EQ(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, buffer->memory_type()); - EXPECT_EQ(IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - buffer->usage()); - EXPECT_EQ(src_data.size(), buffer->allocation_size()); - - // src_data and buffer should match after the wrapping. - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data directly. - src_data[0] = 123; - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Modify the source data via the Buffer and ensure reflected in src_data. - std::vector<uint8_t> new_data = {3, 2, 1, 0}; - IREE_EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); - EXPECT_THAT(src_data, Eq(new_data)); -} - -TEST(BufferTest, WrapExternal) { - // This is not fully supported yet, but does let us verify that the validation - // of memory types is working. - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::Wrap(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, - IREE_HAL_BUFFER_USAGE_ALL, - absl::MakeConstSpan(src_data)); - EXPECT_EQ(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, buffer->memory_type()); - - // Should fail (for now) as the buffer is not host visible. - EXPECT_TRUE(IsPermissionDenied(buffer->Fill8(0, IREE_WHOLE_BUFFER, 0x99u))); -} - -TEST(BufferTest, DoesOverlap) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - - // A buffer should overlap with itself. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - parent_buffer.get(), 1, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - parent_buffer.get(), 0, 1)); - - // Zero length buffers never overlap. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 1, 1, - parent_buffer.get(), 1, 0)); - - // Subspans should offset within their allocation. - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer_0, - Buffer::Subspan(parent_buffer, 1, 2)); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer_1, - Buffer::Subspan(parent_buffer, 2, 2)); - EXPECT_FALSE(Buffer::DoesOverlap(subspan_buffer_0.get(), 0, 1, - subspan_buffer_1.get(), 0, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(subspan_buffer_0.get(), 1, 1, - subspan_buffer_1.get(), 0, 1)); - - // Mixing subspans and normal buffers. - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, 1, - subspan_buffer_0.get(), 0, 1)); - EXPECT_TRUE(Buffer::DoesOverlap(parent_buffer.get(), 1, 2, - subspan_buffer_0.get(), 1, 1)); - - // Independent buffers should not be able to overlap. - auto other_buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_ALL, 128); - EXPECT_FALSE(Buffer::DoesOverlap(parent_buffer.get(), 0, IREE_WHOLE_BUFFER, - other_buffer.get(), 0, IREE_WHOLE_BUFFER)); -} - -TEST(BufferTest, Subspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - - // Create a subspan of the buffer. - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_TRUE(subspan_buffer); - EXPECT_EQ(1, subspan_buffer->byte_offset()); - EXPECT_EQ(2, subspan_buffer->byte_length()); - - // Modifications to either buffer should appear in the other. - IREE_EXPECT_OK(subspan_buffer->Fill8(1, IREE_WHOLE_BUFFER, 0xFFu)); - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK( - parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xFF, 3)); - - // Subspans should be able to create subspans. - // NOTE: offset is from the original buffer. - IREE_ASSERT_OK_AND_ASSIGN(auto subsubspan_buffer, - Buffer::Subspan(subspan_buffer, 1, 1)); - ASSERT_TRUE(subsubspan_buffer); - EXPECT_EQ(2, subsubspan_buffer->byte_offset()); - EXPECT_EQ(1, subsubspan_buffer->byte_length()); - - // Zero length subspans are fine. - IREE_ASSERT_OK_AND_ASSIGN(auto zero_subspan_buffer, - Buffer::Subspan(parent_buffer, 0, 0)); - ASSERT_TRUE(zero_subspan_buffer); - EXPECT_EQ(0, zero_subspan_buffer->byte_offset()); - EXPECT_EQ(0, zero_subspan_buffer->byte_length()); - - // Subspan with IREE_WHOLE_BUFFER should get the remaining size (or zero). - IREE_ASSERT_OK_AND_ASSIGN( - auto whole_subspan_buffer, - Buffer::Subspan(parent_buffer, 1, IREE_WHOLE_BUFFER)); - ASSERT_TRUE(whole_subspan_buffer); - EXPECT_EQ(1, whole_subspan_buffer->byte_offset()); - EXPECT_EQ(3, whole_subspan_buffer->byte_length()); - - // Zero length subspans are fine. - IREE_ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, 0)); - IREE_ASSERT_OK(Buffer::Subspan(subspan_buffer, 2, IREE_WHOLE_BUFFER)); -} - -TEST(BufferTest, SubspanIdentity) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - - // Asking for a subspan of the entire buffer should return the same buffer. - // Mostly an optimization. - EXPECT_EQ(parent_buffer.get(), - Buffer::Subspan(parent_buffer, 0, IREE_WHOLE_BUFFER).value().get()); - EXPECT_EQ(parent_buffer.get(), - Buffer::Subspan(parent_buffer, 0, 4).value().get()); -} - -TEST(BufferTest, SubspanOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto parent_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - - // Create a subspan of the buffer. - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 2)); - ASSERT_TRUE(subspan_buffer); - EXPECT_EQ(1, subspan_buffer->byte_offset()); - EXPECT_EQ(2, subspan_buffer->byte_length()); - - // Try to make subspans from invalid ranges. - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 5, 0).status())); - EXPECT_TRUE(IsOutOfRange( - Buffer::Subspan(parent_buffer, 5, IREE_WHOLE_BUFFER).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 4, 1).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(parent_buffer, 0, 123).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 1, 2).status())); - EXPECT_TRUE(IsOutOfRange(Buffer::Subspan(subspan_buffer, 0, 44).status())); -} - -TEST(BufferTest, Fill8) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 5); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0)); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); - - // Verify data. - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); - - // Zero fills are fine. - IREE_EXPECT_OK(buffer->Fill8(0, 0, 0x44u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); - - // Fill the remaining parts of the buffer by using IREE_WHOLE_BUFFER. - IREE_EXPECT_OK(buffer->Fill8(2, IREE_WHOLE_BUFFER, 0x55u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x55, 0x55, 0x55)); - - // Fill a small region of the buffer. - IREE_EXPECT_OK(buffer->Fill8(1, 1, 0x66u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x66, 0x55, 0x55, 0x55)); - - // Whole buffer helper. - IREE_EXPECT_OK(buffer->Fill8(0x99u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x99, 0x99, 0x99, 0x99, 0x99)); -} - -TEST(BufferTest, Fill8OutOfRange) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 5); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill8(0, buffer->allocation_size(), 0x33u)); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 444, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(123, 1, 0x44u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill8(1, 444, 0x44u))); - - // Ensure nothing happened with the bad ranges. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x33, 0x33, 0x33, 0x33, 0x33)); -} - -TEST(BufferTest, Fill8BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->Fill8(0, IREE_WHOLE_BUFFER, 0x99u))); - - // Fail to fill constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = HeapBuffer::Wrap(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_BUFFER_USAGE_MAPPING, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->Fill8(0, IREE_WHOLE_BUFFER, 0x99u))); -} - -TEST(BufferTest, Fill8Subspan) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 5); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 3)); - IREE_EXPECT_OK(subspan_buffer->Fill8(2, IREE_WHOLE_BUFFER, 0xDDu)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0xDD, 0)); -} - -TEST(BufferTest, Fill16) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); - - // Verify data. - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); - - // Zero fills are fine. - IREE_EXPECT_OK(buffer->Fill16(0, 0, 0x5566u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0x22, 0x11, 0x22, 0x11, 0, 0, 0, 0, 0)); - - // Fill the remaining parts of the buffer by using IREE_WHOLE_BUFFER. - auto aligned_buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 8); - IREE_EXPECT_OK(aligned_buffer->Fill16(4, IREE_WHOLE_BUFFER, 0x5566u)); - std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size()); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0, 0, 0, 0, 0x66, 0x55, 0x66, 0x55)); - - // Whole buffer helper. - IREE_EXPECT_OK(aligned_buffer->Fill16(0x5566u)); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0x66, 0x55, 0x66, 0x55, 0x66, 0x55, 0x66, 0x55)); -} - -TEST(BufferTest, Fill16OutOfRange) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 444, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(128, 4, 0x5566u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill16(4, 444, 0x5566u))); -} - -TEST(BufferTest, Fill16Unaligned) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Try to fill with unaligned ranges. - EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(1, 4, 0x5566u))); - EXPECT_TRUE(IsInvalidArgument(buffer->Fill16(0, 5, 0x5566u))); -} - -TEST(BufferTest, Fill16BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->Fill16(0, IREE_WHOLE_BUFFER, 0x99AAu))); - - // Fail to fill constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = HeapBuffer::Wrap(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_BUFFER_USAGE_MAPPING, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->Fill16(0, IREE_WHOLE_BUFFER, 0x99AAu))); -} - -TEST(BufferTest, Fill16Subspan) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill16(0, 4, 0x1122u)); - - // Test on subspan. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 2, 4)); - IREE_EXPECT_OK(subspan_buffer->Fill16(2, IREE_WHOLE_BUFFER, 0xAABBu)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x22, 0x11, 0x22, 0x11, 0xBB, 0xAA, 0, 0, 0)); -} - -TEST(BufferTest, Fill32) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Data should be zeroed by default. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 0)); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); - - // Verify data. - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); - - // Zero fills are fine. - IREE_EXPECT_OK(buffer->Fill32(0, 0, 0x55667788u)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0x44, 0x33, 0x22, 0x11, 0)); - - // Fill the remaining parts of the buffer by using IREE_WHOLE_BUFFER. - auto aligned_buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 8); - IREE_EXPECT_OK(aligned_buffer->Fill32(4, IREE_WHOLE_BUFFER, 0x55667788u)); - std::vector<uint8_t> aligned_actual_data(aligned_buffer->allocation_size()); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0, 0, 0, 0, 0x88, 0x77, 0x66, 0x55)); - - // Whole buffer helper. - IREE_EXPECT_OK(aligned_buffer->Fill32(0x55667788u)); - IREE_EXPECT_OK(aligned_buffer->ReadData(0, aligned_actual_data.data(), - aligned_actual_data.size())); - EXPECT_THAT(aligned_actual_data, - ElementsAre(0x88, 0x77, 0x66, 0x55, 0x88, 0x77, 0x66, 0x55)); -} - -TEST(BufferTest, Fill32OutOfRange) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Try to fill with invalid ranges. - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 444, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(128, 4, 0x55667788u))); - EXPECT_TRUE(IsOutOfRange(buffer->Fill32(4, 444, 0x55667788u))); -} - -TEST(BufferTest, Fill32Unaligned) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Try to fill with unaligned ranges. - EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(1, 4, 0x55667788u))); - EXPECT_TRUE(IsInvalidArgument(buffer->Fill32(0, 5, 0x55667788u))); -} - -TEST(BufferTest, Fill32BadMode) { - // Fail to fill buffers not supporting mapping. - auto nonmapping_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->Fill32(0, IREE_WHOLE_BUFFER, 0x99AABBCCu))); - - // Fail to fill constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = HeapBuffer::Wrap(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_BUFFER_USAGE_MAPPING, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->Fill32(0, IREE_WHOLE_BUFFER, 0x99AABBCCu))); -} - -TEST(BufferTest, Fill32Subspan) { - auto buffer = HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_MAPPING, 9); - ASSERT_TRUE(buffer); - - // Fill with a sentinel. - IREE_EXPECT_OK(buffer->Fill32(0, 8, 0x11223344u)); - - // Test on subspan. - std::vector<uint8_t> actual_data(buffer->allocation_size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 4, 4)); - IREE_EXPECT_OK(subspan_buffer->Fill32(0, IREE_WHOLE_BUFFER, 0xAABBCCDDu)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, - ElementsAre(0x44, 0x33, 0x22, 0x11, 0xDD, 0xCC, 0xBB, 0xAA, 0)); -} - -TEST(BufferTest, ReadData) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Read the data back. - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Reading zero bytes is valid. - std::vector<uint8_t> zero_data(0); - IREE_EXPECT_OK(buffer->ReadData(1, zero_data.data(), 0)); - - // Read a portion of the data. - std::vector<uint8_t> partial_data(2); - IREE_EXPECT_OK(buffer->ReadData(1, partial_data.data(), 2)); - EXPECT_THAT(partial_data, ElementsAre(1, 2)); -} - -TEST(BufferTest, ReadDataOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Try to read out of range. - std::vector<uint8_t> partial_data(2); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(0, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->ReadData(1230, partial_data.data(), 1))); - EXPECT_TRUE(IsInvalidArgument( - buffer->ReadData(0, partial_data.data(), IREE_WHOLE_BUFFER))); -} - -TEST(BufferTest, ReadDataBadMode) { - // Fail to read buffers not supporting mapping. - std::vector<uint8_t> actual_data(1); - auto nonmapping_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->ReadData(0, actual_data.data(), 1))); -} - -TEST(BufferTest, ReadDataSubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector<uint8_t> subspan_data(1); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); - IREE_EXPECT_OK(subspan_buffer->ReadData(1, subspan_data.data(), 1)); - EXPECT_THAT(subspan_data, ElementsAre(2)); -} - -TEST(BufferTest, WriteData) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Read the data back - should still match. - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(src_data)); - - // Write over the entire buffer. - std::vector<uint8_t> new_data = {10, 20, 30, 40}; - IREE_EXPECT_OK(buffer->WriteData(0, new_data.data(), new_data.size())); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(new_data)); - - // Writing zero bytes is valid. - std::vector<uint8_t> zero_data; - IREE_EXPECT_OK(buffer->WriteData(0, zero_data.data(), 0)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(new_data)); - - // Write over a portion of the buffer. - std::vector<uint8_t> partial_data = {99}; - IREE_EXPECT_OK( - buffer->WriteData(1, partial_data.data(), partial_data.size())); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(10, 99, 30, 40)); -} - -TEST(BufferTest, WriteDataOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Try to write out of range. - std::vector<uint8_t> partial_data = {99}; - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(0, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 444))); - EXPECT_TRUE(IsOutOfRange(buffer->WriteData(1230, partial_data.data(), 1))); - EXPECT_TRUE(IsInvalidArgument( - buffer->WriteData(0, partial_data.data(), IREE_WHOLE_BUFFER))); -} - -TEST(BufferTest, WriteDataBadMode) { - std::vector<uint8_t> actual_data(4); - - // Fail to write buffers not supporting mapping. - auto nonmapping_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - EXPECT_TRUE(IsPermissionDenied( - nonmapping_buffer->WriteData(0, actual_data.data(), 1))); - - // Fail to write to constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = HeapBuffer::Wrap(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_BUFFER_USAGE_TRANSFER, - absl::MakeConstSpan(const_data)); - EXPECT_TRUE( - IsPermissionDenied(constant_buffer->WriteData(0, actual_data.data(), 2))); -} - -TEST(BufferTest, WriteDataSubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test on subspan. - std::vector<uint8_t> subspan_data = {0xAA}; - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, Buffer::Subspan(buffer, 1, 2)); - IREE_EXPECT_OK(subspan_buffer->WriteData(1, subspan_data.data(), 1)); - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xAA, 3)); -} - -TEST(BufferTest, CopyData) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Copy of length 0 should not change the dest buffer. - IREE_EXPECT_OK(dst_buffer->CopyData(0, src_buffer.get(), 0, 0)); - std::vector<uint8_t> actual_data(dst_data.size()); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, Eq(dst_data)); - - // Copy a subrange of the buffer. - IREE_EXPECT_OK(dst_buffer->CopyData(1, src_buffer.get(), 2, 2)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 3, 4)); - - // Copy the entire buffer using IREE_WHOLE_BUFFER. This will adjust sizes - // to ensure that the min buffer is taken. We test both src and dst buffer - // offset/length calculations (note that some may end up as 0 copies). - IREE_EXPECT_OK( - dst_buffer->CopyData(3, src_buffer.get(), 0, IREE_WHOLE_BUFFER)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 2, 3, 0, 1)); - IREE_EXPECT_OK( - dst_buffer->CopyData(0, src_buffer.get(), 2, IREE_WHOLE_BUFFER)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(2, 3, 3, 0, 1)); - IREE_EXPECT_OK( - dst_buffer->CopyData(0, src_buffer.get(), 3, IREE_WHOLE_BUFFER)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 1)); - IREE_EXPECT_OK( - dst_buffer->CopyData(4, src_buffer.get(), 0, IREE_WHOLE_BUFFER)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(3, 3, 3, 0, 0)); -} - -TEST(BufferTest, CopyDataOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Try to copy out of range of source and dest. - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 0, 1))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(4, src_buffer.get(), 0, 4))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 1))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 0, 123))); - EXPECT_TRUE( - IsOutOfRange(dst_buffer->CopyData(123, src_buffer.get(), 123, 123))); - EXPECT_TRUE(IsOutOfRange(dst_buffer->CopyData(0, src_buffer.get(), 123, 0))); -} - -TEST(BufferTest, CopyDataOverlapping) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Test overlap. Non-overlapping regions should be fine, otherwise fail. - std::vector<uint8_t> actual_data(dst_data.size()); - IREE_EXPECT_OK(dst_buffer->CopyData(0, dst_buffer.get(), 4, 1)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(4, 1, 2, 3, 4)); - EXPECT_TRUE( - IsInvalidArgument(dst_buffer->CopyData(2, dst_buffer.get(), 0, 3))); - EXPECT_TRUE( - IsInvalidArgument(dst_buffer->CopyData(0, dst_buffer.get(), 0, 3))); -} - -TEST(BufferTest, CopyDataBadMode) { - // Both source and target buffers must support mapping. - auto nonmapping_src_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - auto nonmapping_dst_buffer = - HeapBuffer::Allocate(IREE_HAL_BUFFER_USAGE_TRANSFER, 4); - EXPECT_TRUE(IsPermissionDenied(nonmapping_dst_buffer->CopyData( - 0, nonmapping_src_buffer.get(), 0, IREE_WHOLE_BUFFER))); - EXPECT_TRUE(IsPermissionDenied(nonmapping_src_buffer->CopyData( - 0, nonmapping_dst_buffer.get(), 0, IREE_WHOLE_BUFFER))); - - // Fail to copy into to constant buffers. - std::vector<uint8_t> const_data = {1, 2, 3}; - auto constant_buffer = HeapBuffer::Wrap(IREE_HAL_MEMORY_TYPE_HOST_LOCAL, - IREE_HAL_BUFFER_USAGE_TRANSFER, - absl::MakeConstSpan(const_data)); - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - src_data.data(), src_data.size()); - EXPECT_TRUE(IsPermissionDenied( - constant_buffer->CopyData(0, src_buffer.get(), 0, IREE_WHOLE_BUFFER))); -} - -TEST(BufferTest, CopyDataSubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3}; - auto src_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - src_data.data(), src_data.size()); - ASSERT_TRUE(src_buffer); - std::vector<uint8_t> dst_data = {0, 1, 2, 3, 4}; - auto dst_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_MAPPING | IREE_HAL_BUFFER_USAGE_TRANSFER, - dst_data.data(), dst_data.size()); - ASSERT_TRUE(dst_buffer); - - // Test on subspan. - std::vector<uint8_t> actual_data(dst_data.size()); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_src_buffer, - Buffer::Subspan(src_buffer, 1, 3)); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_dst_buffer, - Buffer::Subspan(dst_buffer, 2, 3)); - IREE_EXPECT_OK( - subspan_dst_buffer->CopyData(1, subspan_src_buffer.get(), 1, 2)); - IREE_EXPECT_OK( - dst_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 2, 2, 3)); -} - -// NOTE: more tests related specifically to MappedMemory are in -// buffer_mapping_test.cc. This tests the MapMemory operation and enough to -// ensure the memory was mapped to the correct range and the HostBuffer and -// SubspanBuffer work as intended for basic usage. -TEST(BufferTest, MapMemory) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_READ, src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // 0-length mappings are valid. - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ, 0, 0)); - EXPECT_TRUE(mapping.empty()); - EXPECT_EQ(0, mapping.size()); - EXPECT_EQ(0, mapping.byte_length()); - EXPECT_NE(nullptr, mapping.data()); - IREE_ASSERT_OK_AND_ASSIGN(auto span, mapping.Subspan()); - EXPECT_TRUE(span.empty()); - mapping.reset(); - - // Map the whole buffer for reading. - IREE_ASSERT_OK_AND_ASSIGN( - mapping, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ, 0, - IREE_WHOLE_BUFFER)); - EXPECT_EQ(src_data.size(), mapping.size()); - IREE_ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); - EXPECT_THAT(span, ElementsAre(0, 1, 2, 3, 4, 5, 6)); - mapping.reset(); - - // Map a portion of the buffer for reading. - IREE_ASSERT_OK_AND_ASSIGN( - mapping, buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ, 1, 2)); - EXPECT_EQ(2, mapping.size()); - IREE_ASSERT_OK_AND_ASSIGN(span, mapping.Subspan()); - EXPECT_THAT(span, ElementsAre(1, 2)); - mapping.reset(); -} - -TEST(BufferTest, MapMemoryNonByte) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_READ, src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map the buffer as non-byte values. - // Note that we'll round down to the number of valid elements at the - // alignment. - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping16, buffer->MapMemory<uint16_t>(IREE_HAL_MEMORY_ACCESS_READ)); - EXPECT_EQ(3, mapping16.size()); - EXPECT_LE(6, mapping16.byte_length()); - IREE_ASSERT_OK_AND_ASSIGN(auto span16, mapping16.Subspan()); - EXPECT_THAT(span16, ElementsAre(0x0100, 0x0302, 0x0504)); - mapping16.reset(); -} - -TEST(BufferTest, MapMemoryOutOfRange) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_READ, src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Test invalid mapping ranges. - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(IREE_HAL_MEMORY_ACCESS_READ, 0, 123) - .status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(IREE_HAL_MEMORY_ACCESS_READ, 5, 1231) - .status())); - EXPECT_TRUE( - IsOutOfRange(buffer - ->MapMemory<uint16_t>(IREE_HAL_MEMORY_ACCESS_READ, 6, - IREE_WHOLE_BUFFER) - .status())); - EXPECT_TRUE(IsOutOfRange( - buffer->MapMemory<uint16_t>(IREE_HAL_MEMORY_ACCESS_READ, 1236, 1) - .status())); -} - -TEST(BufferTest, MapMemoryBadMode) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto read_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_READ, src_data.data(), src_data.size()); - ASSERT_TRUE(read_buffer); - - // Test mapping the read-only buffer for writing. - EXPECT_TRUE(IsPermissionDenied( - read_buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE).status())); - EXPECT_TRUE(IsPermissionDenied( - read_buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE) - .status())); - EXPECT_TRUE(IsPermissionDenied( - read_buffer - ->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ | - IREE_HAL_MEMORY_ACCESS_DISCARD) - .status())); - EXPECT_TRUE(IsInvalidArgument( - read_buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_NONE).status())); -} - -TEST(BufferTest, MapMemoryWrite) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_ALL, src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map and modify the data. We should see it when we read back. - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_WRITE, 1, 2)); - auto mutable_data = mapping.mutable_data(); - mutable_data[0] = 0xAA; - mutable_data[1] = 0xBB; - mapping.reset(); - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 0xAA, 0xBB, 3, 4, 5, 6)); -} - -TEST(BufferTest, MapMemoryDiscard) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_ALL, src_data.data(), src_data.size()); - ASSERT_TRUE(buffer); - - // Map for discard. Note that we can't really rely on the value of the data - // so we just trust that it's been discarded. It's a hint, anyway. We can be - // sure that the data we didn't want to discard is the same though. - std::vector<uint8_t> actual_data(src_data.size()); - IREE_ASSERT_OK_AND_ASSIGN( - auto mapping, - buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 1, 2)); - IREE_EXPECT_OK(buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, _, _, 3, 4, 5, 6)); - mapping.reset(); -} - -TEST(BufferTest, MapMemorySubspan) { - std::vector<uint8_t> src_data = {0, 1, 2, 3, 4, 5, 6}; - auto parent_buffer = HeapBuffer::AllocateCopy( - IREE_HAL_BUFFER_USAGE_TRANSFER | IREE_HAL_BUFFER_USAGE_MAPPING, - IREE_HAL_MEMORY_ACCESS_ALL, src_data.data(), src_data.size()); - ASSERT_TRUE(parent_buffer); - IREE_ASSERT_OK_AND_ASSIGN(auto subspan_buffer, - Buffer::Subspan(parent_buffer, 1, 3)); - IREE_ASSERT_OK_AND_ASSIGN(auto mapping, - subspan_buffer->MapMemory<uint8_t>( - IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, 1, 2)); - auto* mutable_data = mapping.mutable_data(); - mutable_data[0] = 0xCC; - mutable_data[1] = 0xDD; - mapping.reset(); - - std::vector<uint8_t> actual_data(src_data.size()); - IREE_EXPECT_OK( - parent_buffer->ReadData(0, actual_data.data(), actual_data.size())); - EXPECT_THAT(actual_data, ElementsAre(0, 1, 0xCC, 0xDD, 4, 5, 6)); - - // Just here to make coverage happy; they are currently no-ops on the host. - // buffer_mapping_test.cc contains tests that ensure they are called - // correctly. - std::vector<uint8_t> external_data = {0, 1, 2, 3, 4}; - auto external_buffer = HeapBuffer::WrapMutable( - IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_MEMORY_TYPE_HOST_CACHED, - IREE_HAL_MEMORY_ACCESS_ALL, IREE_HAL_BUFFER_USAGE_ALL, - absl::MakeSpan(external_data)); - IREE_ASSERT_OK_AND_ASSIGN(auto external_subspan_buffer, - Buffer::Subspan(external_buffer, 0, 1)); - IREE_ASSERT_OK_AND_ASSIGN( - mapping, - external_subspan_buffer->MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_ALL)); - IREE_EXPECT_OK(mapping.Invalidate()); - IREE_EXPECT_OK(mapping.Flush()); -} - -} // namespace -} // namespace hal -} // namespace iree - -#endif // 0
diff --git a/iree/hal/cc/command_buffer.h b/iree/hal/cc/command_buffer.h deleted file mode 100644 index a2866e8..0000000 --- a/iree/hal/cc/command_buffer.h +++ /dev/null
@@ -1,264 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_COMMAND_BUFFER_H_ -#define IREE_HAL_CC_COMMAND_BUFFER_H_ - -#include <cstdint> - -#include "iree/base/status.h" -#include "iree/hal/cc/buffer.h" -#include "iree/hal/cc/descriptor_set.h" -#include "iree/hal/cc/event.h" -#include "iree/hal/cc/executable.h" -#include "iree/hal/cc/executable_layout.h" -#include "iree/hal/cc/resource.h" - -namespace iree { -namespace hal { - -std::string CommandBufferModeString(iree_hal_command_buffer_mode_t mode); -inline std::string CommandCategoryString( - iree_hal_command_category_t categories) { - return "TODO"; - // return FormatBitfieldValue( - // categories, { - // {IREE_HAL_COMMAND_CATEGORY_TRANSFER, "kTransfer"}, - // {IREE_HAL_COMMAND_CATEGORY_DISPATCH, "kDispatch"}, - // }); -} - -// Asynchronous command buffer recording interface. -// Commands are recorded by the implementation for later submission to command -// queues. -// -// Buffers and synchronization objects referenced must remain valid and not be -// modified or read while there are commands in-flight. The usual flow is to -// populate input buffers, Dispatch using those buffers, wait on a Semaphore -// until the buffers are guaranteed to no longer be in use, and then reuse or -// release the buffers. -// -// Errors that can be recognized when operations are enqueued will be returned -// immediately, such as invalid argument errors. Errors that can only be -// determined at execution time will be returned on semaphores. Once a failure -// occurs the device queue will enter an error state that invalidates all -// operations on the device queue (as ordering is not strict and any may still -// be in-flight). In this case the user of the device queue should treat all -// in-flight operations as cancelled and fully reset themselves. Other device -// queues that may be waiting on events from the device queue will also enter -// error states. Only once a user has acknowledged and cleared the error state -// with a Reset the queue will become usable, and otherwise all operations will -// return errors. -// -// Command buffers are thread-compatible. Use multiple command buffers if trying -// to record commands from multiple threads. Command buffers must not be mutated -// between when they have are submitted for execution on a queue and when the -// semaphore fires indicating the completion of their execution. -class CommandBuffer : public Resource { - public: - virtual CommandBuffer* impl() { return this; } - - // Command buffer operation mode. - iree_hal_command_buffer_mode_t mode() const { return mode_; } - - // Command categories that may be recorded into the buffer. - iree_hal_command_category_t command_categories() const { - return command_categories_; - } - - // True if the command buffer is between a Begin/End recording block. - virtual bool is_recording() const = 0; - - // Resets and begins recording into the command buffer, clearing all - // previously recorded contents. - // The command buffer must not be in-flight. - virtual Status Begin() = 0; - - // Ends recording into the command buffer. - // This must be called prior to submitting the command buffer for execution. - virtual Status End() = 0; - - // TODO(benvanik): annotations for debugging and tracing: - // enter/exit - // stack frame manipulation - // explicit timers? or profiling buffer? - - // TODO(b/138719910): cross-queue and external acquire/release. - // virtual Status AcquireBuffer() = 0; - // virtual Status ReleaseBuffer() = 0; - - // Defines a memory dependency between commands recorded before and after the - // barrier. One or more memory or buffer barriers can be specified to indicate - // between which stages or buffers the dependencies exist. - virtual Status ExecutionBarrier( - iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - absl::Span<const iree_hal_memory_barrier_t> memory_barriers, - absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) = 0; - - // Sets an event to the signaled state. - // |source_stage_mask| specifies when the event is signaled. - // - // Events are only valid within a single command buffer. Events can only be - // used on non-transfer queues. - virtual Status SignalEvent(Event* event, - iree_hal_execution_stage_t source_stage_mask) = 0; - - // Resets an event to the non-signaled state. - // |source_stage_mask| specifies when the event is unsignaled. - // - // Events are only valid within a single command buffer. Events can only be - // used on non-transfer queues. - virtual Status ResetEvent(Event* event, - iree_hal_execution_stage_t source_stage_mask) = 0; - - // Waits for one or more events to be signaled and defines a memory dependency - // between the synchronization scope of the signal operations and the commands - // following the wait. - // - // |source_stage_mask| must include IREE_HAL_EXECUTION_STAGE_HOST for - // Event::Signal to be visibile. - // - // Events are only valid within a single command buffer. Events remain - // signaled even after waiting and must be reset to be reused. Events can only - // be used on non-transfer queues. - virtual Status WaitEvents( - absl::Span<Event*> events, iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - absl::Span<const iree_hal_memory_barrier_t> memory_barriers, - absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) = 0; - - // Fills the target buffer with the given repeating value. - // Expects that value_length is one of 1, 2, or 4 and that the offset and - // length are aligned to the natural alignment of the value. - // The target buffer must be compatible with the devices owned by this - // device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER. - virtual Status FillBuffer(Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length, const void* pattern, - size_t pattern_length) = 0; - - // Hints to the device queue that the given buffer will not be used again. - // After encoding a discard the buffer contents will be considered undefined. - // This is because the discard may be used to elide write backs to host memory - // or aggressively reuse the allocation for other purposes. - // - // For buffers allocated with IREE_HAL_MEMORY_TYPE_TRANSIENT this may allow - // the device queue to reclaim the memory used by the buffer earlier than - // otherwise possible. - virtual Status DiscardBuffer(Buffer* buffer) = 0; - - // Updates a range of the given target buffer from the source host memory. - // The source host memory is copied immediately into the command buffer and - // occupies command buffer space. It is strongly recommended that large buffer - // updates are performed via CopyBuffer where there is the possibility of a - // zero-copy path. - // The |source_buffer| may be releaed by the caller immediately after this - // call returns. - // The |target_buffer| must be compatible with the devices owned by this - // device queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER. - virtual Status UpdateBuffer(const void* source_buffer, - iree_device_size_t source_offset, - Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length) = 0; - - // Copies a range of one buffer to another. - // Both buffers must be compatible with the devices owned by this device - // queue and be allocated with IREE_HAL_BUFFER_USAGE_TRANSFER. Though the - // source and target buffer may be the same the ranges must not overlap (as - // with memcpy). - // - // This can be used to perform device->host, host->device, and device->device - // copies. - virtual Status CopyBuffer(Buffer* source_buffer, - iree_device_size_t source_offset, - Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length) = 0; - - // Pushes an inline set of constants that can be accessed by subsequent - // dispatches using a compatible executable layout. - // - // Push constants are always 4-byte values and treated as opaque, meaning that - // they may be bit-casted floats, bit-packed booleans, etc. - virtual Status PushConstants(ExecutableLayout* executable_layout, - size_t offset, - absl::Span<const uint32_t> values) = 0; - - // Pushes a descriptor set and associates it with |set|. - // This uses an internal ringbuffer inside of the command buffer to avoid the - // need for creating and binding descriptor sets and managing their lifetime. - // - // The descriptor set will remain bound and valid so long as the executable - // layouts used by dispatches are compatible (same descriptor layouts and push - // constant sizes). - virtual Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) = 0; - - // Binds a descriptor set to the given |set| matching that used in the - // executable layout interface. - // - // The descriptor set will remain bound and valid so long as the executable - // layouts used by dispatches are compatible (same descriptor layouts and push - // constant sizes). - // - // If any dynamic descriptor types are defined in the descriptor set layout - // then the dynamic offsets must be provided. These offsets will be added to - // the base offset of the descriptor layout binding. - virtual Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span<const iree_device_size_t> dynamic_offsets) = 0; - - // Dispatches an execution request. - // The request may execute overlapped with any other transfer operation or - // dispatch made within the same barrier-defined sequence. - // - // The executable specified must be registered for use with the device driver - // owning this queue. It must not be unregistered until all requests that use - // it have completed. - // - // Fails if the queue does not support dispatch operations (as indicated by - // can_dispatch). - virtual Status Dispatch(Executable* executable, int32_t entry_point, - std::array<uint32_t, 3> workgroups) = 0; - - // Dispatches an execution request with deferred workgroup counts. - // This is the same as Dispatch but the workgroup counts are read from the - // given |workgroups_buffer| at offset |workgroups_offset| as 3 uint32_t XYZ - // values before performing the dispatch. This allows prior dispatches within - // the command sequence to populate the workgroup counts. - // - // The buffer must have been allocated with IREE_HAL_BUFFER_USAGE_DISPATCH and - // be of IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE. - virtual Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - iree_device_size_t workgroups_offset) = 0; - - protected: - CommandBuffer(iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories) - : mode_(mode), command_categories_(command_categories) {} - - private: - const iree_hal_command_buffer_mode_t mode_; - const iree_hal_command_category_t command_categories_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_COMMAND_BUFFER_H_
diff --git a/iree/hal/cc/command_queue.h b/iree/hal/cc/command_queue.h deleted file mode 100644 index 0e6c058..0000000 --- a/iree/hal/cc/command_queue.h +++ /dev/null
@@ -1,111 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_COMMAND_QUEUE_H_ -#define IREE_HAL_CC_COMMAND_QUEUE_H_ - -#include <cstdint> -#include <string> - -#include "absl/types/span.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/cc/command_buffer.h" -#include "iree/hal/cc/semaphore.h" - -namespace iree { -namespace hal { - -// A batch of command buffers with synchronization information for submission. -struct SubmissionBatch { - // A set of semaphores that must have their payload values meet or exceed the - // specified values prior to any command buffer within this batch executing. - absl::Span<const SemaphoreValue> wait_semaphores; - - // Command buffers that will execute in this batch. - // The command buffers will begin execution in order but may complete out of - // order. - absl::Span<CommandBuffer* const> command_buffers; - - // Semaphores to signal after execution of all command buffers complete. - // Semaphore playloads will be set to the maximum of the specified payload or - // their current payload. - absl::Span<const SemaphoreValue> signal_semaphores; -}; - -// Asynchronous command execution queue. -// -// CommandQueues may capture device status at Semaphore barriers, including -// information about device state such as thermal throttling. This information -// is a snapshot of the state at the time the semaphore was signaled and not -// necessarily live at the time of the application query. -// -// Command queues are thread-safe and submissions may occur from multiple -// threads. -class CommandQueue { - public: - virtual ~CommandQueue() = default; - - // Name of the queue used for logging purposes. - // Try to keep at 4 characters total for prettier logging. - const std::string& name() const { return name_; } - - // Capabilities of the command queue. - iree_hal_command_category_t supported_categories() const { - return supported_categories_; - } - - // Whether this queue may be used for transfer commands. - bool can_transfer() const { - return iree_all_bits_set(supported_categories_, - IREE_HAL_COMMAND_CATEGORY_TRANSFER); - } - - // Whether this queue may be used for dispatch commands. - bool can_dispatch() const { - return iree_all_bits_set(supported_categories_, - IREE_HAL_COMMAND_CATEGORY_DISPATCH); - } - - // Submits one or more command batches for execution on the queue. - virtual Status Submit(absl::Span<const SubmissionBatch> batches) = 0; - inline Status Submit(const SubmissionBatch& batch) { - return Submit(absl::MakeConstSpan(&batch, 1)); - } - - // Blocks until all outstanding requests have been completed. - // This is equivalent to having waited on all outstanding semaphores. - // Implicitly calls Flush to ensure delayed requests are scheduled. - // - // If the command queue has encountered an error during submission at any - // point it will be returned here (repeatedly). - virtual Status WaitIdle(Time deadline_ns) = 0; - inline Status WaitIdle(Duration timeout_ns) { - return WaitIdle(RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - inline Status WaitIdle() { return WaitIdle(InfiniteFuture()); } - - protected: - CommandQueue(std::string name, - iree_hal_command_category_t supported_categories) - : name_(std::move(name)), supported_categories_(supported_categories) {} - - const std::string name_; - const iree_hal_command_category_t supported_categories_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_COMMAND_QUEUE_H_
diff --git a/iree/hal/cc/debug_capture_manager.h b/iree/hal/cc/debug_capture_manager.h deleted file mode 100644 index 0c3a588..0000000 --- a/iree/hal/cc/debug_capture_manager.h +++ /dev/null
@@ -1,62 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_DEBUG_CAPTURE_MANAGER_H_ -#define IREE_HAL_CC_DEBUG_CAPTURE_MANAGER_H_ - -#include "iree/base/status.h" - -namespace iree { -namespace hal { - -// Interface for interacting with command recorders / debuggers. -// -// Subclasses connect to tools like RenderDoc or MTLCaptureManager and use them -// to record commands sent to underlying APIs like Vulkan or Metal, for future -// debugging and analysis. -class DebugCaptureManager { - public: - DebugCaptureManager() {} - virtual ~DebugCaptureManager() = default; - - // Attempts to connect to a command recorder, if not already connected. - // - // This should be called *before* the underlying system and its devices (such - // as a VkInstance and its VkDevices) are initialized, so the command recorder - // can inject any necessary hooks. - virtual Status Connect() = 0; - - // Disconnects from a connected command recorder, if connected. - // This implicitly stops capture if currently capturing. - virtual void Disconnect() = 0; - - // Returns true if connected to a command recorder. - virtual bool is_connected() const = 0; - - // Starts capturing commands. - // Must already be connected and must not already be capturing. - virtual void StartCapture() = 0; - - // Stops capturing commands and saves the capture. - // Must already be connected and capturing. - virtual void StopCapture() = 0; - - // Returns true if currently capturing commands. - virtual bool is_capturing() const = 0; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_DEBUG_CAPTURE_MANAGER_H_
diff --git a/iree/hal/cc/descriptor_set.h b/iree/hal/cc/descriptor_set.h deleted file mode 100644 index 1244196..0000000 --- a/iree/hal/cc/descriptor_set.h +++ /dev/null
@@ -1,36 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/strings/str_cat.h" -#include "iree/hal/cc/buffer.h" -#include "iree/hal/cc/resource.h" - -#ifndef IREE_HAL_CC_DESCRIPTOR_SET_H_ -#define IREE_HAL_CC_DESCRIPTOR_SET_H_ - -namespace iree { -namespace hal { - -// Opaque handle to a descriptor set object. -// -// Maps to VkDescriptorSet: -// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSet.html -class DescriptorSet : public Resource { - public: -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_DESCRIPTOR_SET_H_
diff --git a/iree/hal/cc/descriptor_set_layout.h b/iree/hal/cc/descriptor_set_layout.h deleted file mode 100644 index 049840e..0000000 --- a/iree/hal/cc/descriptor_set_layout.h +++ /dev/null
@@ -1,36 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "absl/strings/str_cat.h" -#include "iree/hal/cc/buffer.h" -#include "iree/hal/cc/resource.h" - -#ifndef IREE_HAL_CC_DESCRIPTOR_SET_LAYOUT_H_ -#define IREE_HAL_CC_DESCRIPTOR_SET_LAYOUT_H_ - -namespace iree { -namespace hal { - -// Opaque handle to a descriptor set layout object. -// -// Maps to VkDescriptorSetLayout: -// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkDescriptorSetLayout.html -class DescriptorSetLayout : public Resource { - public: -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_DESCRIPTOR_SET_LAYOUT_H_
diff --git a/iree/hal/cc/device.h b/iree/hal/cc/device.h deleted file mode 100644 index 6d0e45d..0000000 --- a/iree/hal/cc/device.h +++ /dev/null
@@ -1,186 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_DEVICE_H_ -#define IREE_HAL_CC_DEVICE_H_ - -#include <memory> - -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" -#include "iree/base/time.h" -#include "iree/hal/cc/allocator.h" -#include "iree/hal/cc/buffer.h" -#include "iree/hal/cc/command_queue.h" -#include "iree/hal/cc/descriptor_set.h" -#include "iree/hal/cc/descriptor_set_layout.h" -#include "iree/hal/cc/device_info.h" -#include "iree/hal/cc/event.h" -#include "iree/hal/cc/executable_cache.h" -#include "iree/hal/cc/executable_layout.h" -#include "iree/hal/cc/semaphore.h" - -#if defined(IREE_PLATFORM_WINDOWS) -// Win32 macro name conflicts: -#undef CreateEvent -#undef CreateSemaphore -#endif // IREE_PLATFORM_WINDOWS - -namespace iree { -namespace hal { - -class Device : public RefObject<Device> { - public: - virtual ~Device() = default; - - // Information about device capabilities. - const DeviceInfo& info() const { return device_info_; } - - // Returns a debug string describing the device. - virtual std::string DebugString() const { return device_info_.DebugString(); } - - // TODO(benvanik): status (thermal, power mode, etc). - - // TODO(benvanik): throttling adjustment/power profile. - - // TODO(benvanik): control (suspend/resume, delay, etc). - - // An allocator providing buffers usable by the device. - // This allocator may be shared with other devices in the same family. - virtual Allocator* allocator() const = 0; - - // Returns a list of all general-purpose dispatch queues provided by the - // device. In general these map 1:1 with independent execution contexts, - // though some devices may hide that and expose only a single queue that is - // scheduled internally. - virtual absl::Span<CommandQueue*> dispatch_queues() const = 0; - - // Returns a list of transfer queues provided by the device. These queues may - // perform transfer operations asynchronously with respect to execution on the - // dispatch queues. For large sequences of transfer operations always prefer - // using one of these queues. - // Note that if the device does not support a dedicated transfer queue this - // list may be the same as (or a subset of) dispatch_queues. - virtual absl::Span<CommandQueue*> transfer_queues() const = 0; - - // TODO(b/137153339): accept initial cache data. - // Creates a device-specific cache for executables prepared for dispatch. - // The cache manages executable compilation, caching (on disk or in memory), - // and lifetime. Users can decide to use one or more caches to allow differing - // lifetimes (such as unloading modules), persistent on disk caching of only - // specific hot executables, etc. - // - // Returns a thread-safe cache that must remain alive until all executables - // using the cache are no longer in-flight. - virtual ref_ptr<ExecutableCache> CreateExecutableCache() = 0; - - // Creates a descriptor set layout with the given bindings. - virtual StatusOr<ref_ptr<DescriptorSetLayout>> CreateDescriptorSetLayout( - iree_hal_descriptor_set_layout_usage_type_t usage_type, - absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings) = 0; - - // Creates an executable layout composed of the given descriptor set layouts. - // The returned executable layout can be used by multiple executables with the - // same compatible resource binding layouts. - virtual StatusOr<ref_ptr<ExecutableLayout>> CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, - size_t push_constants) = 0; - - // Creates a descriptor set of the given layout and bindings. - // Descriptor sets are immutable and retain their bindings. - virtual StatusOr<ref_ptr<DescriptorSet>> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) = 0; - - // Creates a command buffer for recording commands to submit to queues owned - // by this device. The command buffer may come from a pool but will be reset - // prior to being returned to the caller. - virtual StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, - iree_hal_command_buffer_t** out_command_buffer) = 0; - - // Creates an event for recording into command buffers. - // The returned event object is only usable with this device and events must - // only be used to synchronize within the same queue. - virtual StatusOr<ref_ptr<Event>> CreateEvent() = 0; - - // Creates a semaphore that can be used with command queues owned by this - // device. To use the semaphores with other devices or instances they must - // first be exported. - virtual StatusOr<ref_ptr<Semaphore>> CreateSemaphore( - uint64_t initial_value) = 0; - - // TODO(benvanik): import/export semaphore utilities. - // TODO(benvanik): semaphores to wait handles. - - // Blocks the caller until all passed |semaphores| reach or exceed the - // specified payload values or the |deadline| elapses. All |semaphores| must - // be created from this device (or be imported into it). - // - // Returns success if the wait is successful and all semaphores have been - // signaled. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without all semaphores - // having been signaled. Note that a subset of the |semaphores| may have been - // signaled and each can be queried to see which ones. - virtual Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) = 0; - inline Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores, - Duration timeout_ns) { - return WaitAllSemaphores(semaphores, - RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - - // Blocks the caller until at least one of the |semaphores| reaches or exceeds - // the specified payload value or the |deadline| elapses. All |semaphores| - // must be created from this device (or be imported into it). - // - // Returns an arbitrary index into |semaphores| of a semaphore that was - // signaled. Note that more than one semaphore may have been signaled and all - // of the other |semaphores| should be queried or waited on again until waits - // for them succeed. - // - // Returns DEADLINE_EXCEEDED if the |deadline| elapses without any semaphores - // having been signaled. - virtual StatusOr<int> WaitAnySemaphore( - absl::Span<const SemaphoreValue> semaphores, Time deadline_ns) = 0; - inline StatusOr<int> WaitAnySemaphore( - absl::Span<const SemaphoreValue> semaphores, Duration timeout_ns) { - return WaitAnySemaphore(semaphores, - RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - - // Blocks until all outstanding requests on all queues have been - // completed. This is equivalent to having waited on all outstanding - // semaphores. - virtual Status WaitIdle(Time deadline_ns) = 0; - inline Status WaitIdle(Duration timeout_ns) { - return WaitIdle(RelativeTimeoutToDeadlineNanos(timeout_ns)); - } - inline Status WaitIdle() { return WaitIdle(InfiniteFuture()); } - - protected: - explicit Device(DeviceInfo device_info) - : device_info_(std::move(device_info)) {} - - private: - const DeviceInfo device_info_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_DEVICE_H_
diff --git a/iree/hal/cc/device_info.h b/iree/hal/cc/device_info.h deleted file mode 100644 index 585ecc6..0000000 --- a/iree/hal/cc/device_info.h +++ /dev/null
@@ -1,85 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_DEVICE_INFO_H_ -#define IREE_HAL_CC_DEVICE_INFO_H_ - -#include <cstdint> -#include <string> -#include <utility> - -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "iree/hal/api.h" - -namespace iree { -namespace hal { - -// TODO(benvanik): device info (caps, physical mappings, etc). -class DeviceInfo { - public: - DeviceInfo(std::string id, std::string name, - iree_hal_device_feature_t supported_features, - iree_hal_device_id_t device_id = 0) - : id_(std::move(id)), - name_(std::move(name)), - supported_features_(supported_features), - device_id_(device_id) {} - - // Machine-friendly device identifier used to match the device against - // compiler-generated patterns. This should be consistent with the device IDs - // emitted by the compiler. For example: `vulkan-v1.1-spec`. - const std::string& id() const { return id_; } - - // Human-friendly device name. - const std::string& name() const { return name_; } - - // Features supported by the device. - iree_hal_device_feature_t supported_features() const { - return supported_features_; - } - - // Opaque handle used by drivers to correlate this device with their internal - // listing. This handle will not be valid across driver instances or outside - // of the current process. - iree_hal_device_id_t device_id() const { return device_id_; } - - // Returns a debug string describing the device information. - std::string DebugString() const { - std::string features = "TODO"; - // FormatBitfieldValue( - // supported_features_, - // { - // {IREE_HAL_DEVICE_FEATURE_SUPPORTS_DEBUGGING, "kDebugging"}, - // {IREE_HAL_DEVICE_FEATURE_SUPPORTS_COVERAGE, "kCoverage"}, - // {IREE_HAL_DEVICE_FEATURE_SUPPORTS_PROFILING, "kProfiling"}, - // }); - - return absl::StrCat("[DeviceInfo]", // - "\n Name: ", name_, // - "\n Supported features: [", features, "]", // - "\n Device ID: ", device_id_); - } - - private: - const std::string id_; - const std::string name_; - const iree_hal_device_feature_t supported_features_; - iree_hal_device_id_t device_id_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_DEVICE_INFO_H_
diff --git a/iree/hal/cc/driver.h b/iree/hal/cc/driver.h deleted file mode 100644 index 64a50c6..0000000 --- a/iree/hal/cc/driver.h +++ /dev/null
@@ -1,69 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_DRIVER_H_ -#define IREE_HAL_CC_DRIVER_H_ - -#include <memory> -#include <string> -#include <vector> - -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/cc/debug_capture_manager.h" -#include "iree/hal/cc/device.h" -#include "iree/hal/cc/device_info.h" - -namespace iree { -namespace hal { - -class Driver : public RefObject<Driver> { - public: - virtual ~Driver() = default; - - // Driver name used during registration. - const std::string& name() const { return name_; } - - // TODO(benvanik): info/query (version number, etc). - - // Enumerates devices available for creation from the driver. - // This may fail if the driver is in an invalid state but otherwise will - // return an empty list if no devices are available. - virtual StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() = 0; - - // Creates the driver-defined 'default' device. - // This may simply be the first device enumerated. - virtual StatusOr<ref_ptr<Device>> CreateDefaultDevice() = 0; - - // Creates a device as queried with the given |driver_handle|. - virtual StatusOr<ref_ptr<Device>> CreateDevice( - iree_hal_device_id_t device_id) = 0; - StatusOr<ref_ptr<Device>> CreateDevice(const DeviceInfo& device_info) { - return CreateDevice(device_info.device_id()); - } - - // Gets the capture manager for this driver, if one exists. - virtual DebugCaptureManager* debug_capture_manager() { return nullptr; } - - protected: - explicit Driver(std::string name) : name_(std::move(name)) {} - - private: - const std::string name_; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_DRIVER_H_
diff --git a/iree/hal/cc/event.h b/iree/hal/cc/event.h deleted file mode 100644 index 030b141..0000000 --- a/iree/hal/cc/event.h +++ /dev/null
@@ -1,35 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_EVENT_H_ -#define IREE_HAL_CC_EVENT_H_ - -#include "iree/hal/cc/resource.h" - -namespace iree { -namespace hal { - -// Events are used for defining synchronization scopes within CommandBuffers. -// An event only exists within a single CommandBuffer and must not be used -// across CommandBuffers from the same device or others. -// -// See CommandBuffer::SignalEvent and CommandBuffer::WaitEvents for more info. -class Event : public Resource { - public: -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_EVENT_H_
diff --git a/iree/hal/cc/executable.h b/iree/hal/cc/executable.h deleted file mode 100644 index 3965e17..0000000 --- a/iree/hal/cc/executable.h +++ /dev/null
@@ -1,30 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_EXECUTABLE_H_ -#define IREE_HAL_CC_EXECUTABLE_H_ - -#include "iree/hal/cc/resource.h" - -namespace iree { -namespace hal { - -class Executable : public Resource { - public: -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_EXECUTABLE_H_
diff --git a/iree/hal/cc/executable_cache.h b/iree/hal/cc/executable_cache.h deleted file mode 100644 index f6b6a75..0000000 --- a/iree/hal/cc/executable_cache.h +++ /dev/null
@@ -1,79 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_EXECUTABLE_CACHE_H_ -#define IREE_HAL_CC_EXECUTABLE_CACHE_H_ - -#include "iree/base/api.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/api.h" -#include "iree/hal/cc/executable.h" -#include "iree/hal/cc/executable_format.h" -#include "iree/hal/cc/executable_layout.h" - -namespace iree { -namespace hal { - -// A cache of prepared executables for a particular device. -// Caches may be shared across multiple devices from the same driver or specific -// to individual devices. Caches may persist prepared executables across process -// launches or re-prepare them each run. Callers should assume that the cache is -// a no-op and the returned Executables only live for as long as the cache does. -// -// The term 'cache' here is rather optimistic - it's perfectly acceptable for -// implementations to not cache at all and return new Executables for each -// PrepareExecutable called (even for the same executable). Callers should -// expect such behavior and try to retain the results of the PrepareExecutable -// calls to reduce overhead in re-preparing executables. -// -// Thread-safe - multiple threads may prepare executables (including the *same* -// executable) simultaneously. -class ExecutableCache : public RefObject<ExecutableCache> { - public: - virtual ~ExecutableCache() = default; - - // TODO(benvanik): status/queries (size, etc). - - // TODO(b/137153339): serialization/deserialization. - - // Returns true if the executable cache can prepare the given executable input - // format. Preparation may still fail if the particular version or features - // required by the executable are not supported. - virtual bool CanPrepareFormat(ExecutableFormat format) const = 0; - - // Prepares an executable for use. - // The provided |spec| and |executable_data| will be used to either lookup a - // previously prepared executable in the cache or prepare a new one. - // - // Depending on the driver preparation may take a non-trivial amount of time - // (such as when JITing/etc). As the cache is internally synchronized callers - // can issue preparation requests from multiple threads - even for the same - // executables - and calls will block until preparation completes. - // - // When preparing a large number of executables it's recommended to use the - // PrepareExecutables method to batch and wait on the results. - virtual StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableLayout* executable_layout, - iree_hal_executable_caching_mode_t mode, - iree_const_byte_span_t executable_data) = 0; - - protected: - ExecutableCache() = default; -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/cc/executable_format.h b/iree/hal/cc/executable_format.h deleted file mode 100644 index bdd0773..0000000 --- a/iree/hal/cc/executable_format.h +++ /dev/null
@@ -1,75 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Defines the ExecutableFormat 4cc type and a few well-known formats. -// Not all formats need to be defined here, however any format expected to be -// supported by debuggers/tooling will likely want to be here to ensure easier -// referencing. - -#ifndef IREE_HAL_CC_EXECUTABLE_FORMAT_H_ -#define IREE_HAL_CC_EXECUTABLE_FORMAT_H_ - -#include <cstdint> - -namespace iree { -namespace hal { - -// Executable format 4cc identifier. -using ExecutableFormat = uint32_t; - -// Constructs an ExecutableFormat 4cc at compile-time. -constexpr ExecutableFormat MakeExecutableFormatID(char const four_cc[5]) { - return (four_cc[0] << 24) | (four_cc[1] << 16) | (four_cc[2] << 8) | - four_cc[3]; -} - -// Keep these in sync with iree/compiler/Dialect/HAL/IR/HALBase.td - -// Undefined (or unknown). The format may be derived from the executable -// contents (such as file magic bytes). -constexpr ExecutableFormat kExecutableFormatUnspecified = - MakeExecutableFormatID(" "); - -// MLIR text form. -constexpr ExecutableFormat kExecutableFormatMlir = - MakeExecutableFormatID("MLIR"); - -// IREE v0 bytecode. -constexpr ExecutableFormat kExecutableFormatIreeBytecode = - MakeExecutableFormatID("IREE"); - -// IREE VMLA executable in FlatBuffer format using the -// iree/schemas/vmla_executable_def.fbs schema. -constexpr ExecutableFormat kExecutableFormatVMLA = - MakeExecutableFormatID("VMLA"); - -// SPIR-V executable in FlatBuffer format using the -// iree/schemas/spirv_executable_def.fbs schema. -constexpr ExecutableFormat kExecutableFormatSpirV = - MakeExecutableFormatID("SPVE"); - -// Metal executable in FlatBuffer format using the -// iree/schemas/metal_executable_def.fbs schema. -constexpr ExecutableFormat kExecutableFormatMetal = - MakeExecutableFormatID("MTLE"); - -// Dynamic Library (dylib) executable in FlatBuffer format using the -// iree/schemas/dylib_executable_def.fbs schema -constexpr ExecutableFormat kExecutableFormatDyLib = - MakeExecutableFormatID("DLIB"); - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_EXECUTABLE_FORMAT_H_
diff --git a/iree/hal/cc/executable_layout.h b/iree/hal/cc/executable_layout.h deleted file mode 100644 index a8a4a25..0000000 --- a/iree/hal/cc/executable_layout.h +++ /dev/null
@@ -1,39 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/cc/resource.h" - -#ifndef IREE_HAL_CC_EXECUTABLE_LAYOUT_H_ -#define IREE_HAL_CC_EXECUTABLE_LAYOUT_H_ - -namespace iree { -namespace hal { - -// Defines the resource binding layout used by an executable. -// -// Executables can share the same layout even if they do not use all of the -// resources referenced by descriptor sets referenced by the layout. Doing so -// allows for more efficient binding as bound descriptor sets can be reused when -// command buffer executable bindings change. -// -// Maps to VkPipelineLayout: -// https://www.khronos.org/registry/vulkan/specs/1.2-extensions/man/html/VkPipelineLayout.html -class ExecutableLayout : public Resource { - public: -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_EXECUTABLE_LAYOUT_H_
diff --git a/iree/hal/cc/resource.h b/iree/hal/cc/resource.h deleted file mode 100644 index ac2e1bb..0000000 --- a/iree/hal/cc/resource.h +++ /dev/null
@@ -1,52 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_RESOURCE_H_ -#define IREE_HAL_CC_RESOURCE_H_ - -#include <ostream> -#include <string> - -#include "iree/base/ref_ptr.h" - -namespace iree { -namespace hal { - -// Abstract resource type whose lifetime is managed by a ResourceSet. -// Used mostly just to get a virtual dtor, though we could add nicer logging -// by allowing resources to capture debug names, stack traces of creation, etc. -class Resource : public RefObject<Resource> { - public: - virtual ~Resource() = default; - - // Returns a longer debug string describing the resource and its attributes. - virtual std::string DebugString() const { return DebugStringShort(); } - // Returns a short debug string describing the resource. - virtual std::string DebugStringShort() const { - // TODO(benvanik): remove this when all resource types have custom logic. - return std::string("resource_") + std::to_string(static_cast<uint64_t>( - reinterpret_cast<uintptr_t>(this))); - } -}; - -} // namespace hal -} // namespace iree - -inline std::ostream& operator<<(std::ostream& stream, - const iree::hal::Resource& resource) { - stream << resource.DebugStringShort(); - return stream; -} - -#endif // IREE_HAL_CC_RESOURCE_H_
diff --git a/iree/hal/cc/semaphore.h b/iree/hal/cc/semaphore.h deleted file mode 100644 index 74988f9..0000000 --- a/iree/hal/cc/semaphore.h +++ /dev/null
@@ -1,102 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_CC_SEMAPHORE_H_ -#define IREE_HAL_CC_SEMAPHORE_H_ - -#include <cstdint> - -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/cc/resource.h" - -namespace iree { -namespace hal { - -class Semaphore; - -// A reference to a semaphore and associated payload value. -struct SemaphoreValue { - Semaphore* semaphore = nullptr; - uint64_t value = 0; -}; - -// Synchronization mechanism for host->device, device->host, host->host, -// and device->device notification. Semaphores behave like Vulkan timeline -// semaphores (or D3D12 fences) and contain a monotonically increasing -// uint64_t payload. They may be waited on any number of times even if they -// have already been signaled for a particular value. They may also be waited -// on for a particular value prior to the signal for that value. -// -// A semaphore is updated to its new value after all prior commands have -// completed but the delay between completion and the host being woken varies. -// Some implementations may coalesce semaphores to avoid spurious waking while -// others will immediately synchronize with the host. -// -// One use of semaphores is for resource lifetime management: all resources used -// by a set of submission batches must be considered live until the semaphore -// attached to the submission has signaled. -// -// Another use of semaphores is device->device synchronization for setting up -// the DAG of command buffers across queue submissions. This allows devices to -// perform non-trivial scheduling behavior without the need to wake the host. -// -// Semaphores may be set to a permanently failed state by implementations when -// errors occur during asynchronous execution. Users are expected to propagate -// the failures and possibly reset the entire device that produced the error. -// -// For more information on semaphores see the following docs describing how -// timelines are generally used (specifically in the device->host case): -// https://www.youtube.com/watch?v=SpE--Rf516Y -// https://www.khronos.org/assets/uploads/developers/library/2018-xdc/Vulkan-Timeline-Semaphores-Part-1_Sep18.pdf -// https://docs.microsoft.com/en-us/windows/win32/direct3d12/user-mode-heap-synchronization -class Semaphore : public Resource { - public: - // Queries the current payload of the semaphore. As the payload is - // monotonically increasing it is guaranteed that the value is at least equal - // to the previous result of a Query call and coherent with any waits for - // a specified value via Device::WaitAllSemaphores. - // - // Returns the status/payload at the time the method is called without - // blocking and as such is only valid after a semaphore has been signaled. The - // same failure status will be returned regardless of when in the timeline the - // error occurred. - virtual StatusOr<uint64_t> Query() = 0; - - // Signals the semaphore to the given payload value. - // The call is ignored if the current payload value exceeds |value|. - virtual Status Signal(uint64_t value) = 0; - - // Signals the semaphore with a failure. The |status| will be returned from - // Query and Signal for the lifetime of the semaphore. - virtual void Fail(Status status) = 0; - - // Blocks the caller until the semaphore reaches or exceedes the specified - // payload value or the |deadline_ns| elapses. - // - // Returns success if the wait is successful and the semaphore has met or - // exceeded the required payload value. - // - // Returns DEADLINE_EXCEEDED if the |deadline_ns| elapses without the - // semaphore reaching the required value. - virtual Status Wait(uint64_t value, Time deadline_ns) = 0; - inline Status Wait(uint64_t value, Duration timeout_ns) { - return Wait(value, RelativeTimeoutToDeadlineNanos(timeout_ns)); - } -}; - -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_CC_SEMAPHORE_H_
diff --git a/iree/hal/metal/README.md b/iree/hal/metal/README.md index 301d026..d0bbc85 100644 --- a/iree/hal/metal/README.md +++ b/iree/hal/metal/README.md
@@ -1,5 +1,7 @@ # Metal HAL Driver +**TODO(antiagainst)**: move the docs here - having them separate is suboptimal. + This directory contains the source code for the Metal HAL driver. See the [design doc](https://google.github.io/iree/design-docs/metal-hal-driver) for more details.
diff --git a/iree/hal/metal/metal_buffer.h b/iree/hal/metal/metal_buffer.h index 730c3b7..4238d51 100644 --- a/iree/hal/metal/metal_buffer.h +++ b/iree/hal/metal/metal_buffer.h
@@ -17,7 +17,9 @@ #import <Metal/Metal.h> -#include "iree/hal/cc/buffer.h" +#include "iree/hal/api.h" + +id<MTLBuffer> iree_hal_metal_buffer_handle(iree_hal_buffer_t* base_buffer); namespace iree { namespace hal { @@ -36,14 +38,6 @@ iree_device_size_t byte_length, id<MTLBuffer> buffer, id<MTLCommandQueue> transfer_queue); - // Creates a MetalBuffer instance without retaining the given id<MTLBuffer>. - static StatusOr<ref_ptr<MetalBuffer>> CreateUnretained( - MetalDirectAllocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t usage, - iree_device_size_t allocation_size, iree_device_size_t byte_offset, - iree_device_size_t byte_length, id<MTLBuffer> buffer, - id<MTLCommandQueue> transfer_queue); - ~MetalBuffer() override; id<MTLBuffer> handle() const { return metal_handle_; }
diff --git a/iree/hal/metal/metal_buffer.mm b/iree/hal/metal/metal_buffer.mm index cf52ec0..a6cb965 100644 --- a/iree/hal/metal/metal_buffer.mm +++ b/iree/hal/metal/metal_buffer.mm
@@ -18,6 +18,8 @@ #include "iree/base/tracing.h" #include "iree/hal/metal/metal_direct_allocator.h" +id<MTLBuffer> iree_hal_metal_buffer_handle(iree_hal_buffer_t* base_buffer); + namespace iree { namespace hal { namespace metal { @@ -30,17 +32,6 @@ id<MTLCommandQueue> transfer_queue) { IREE_TRACE_SCOPE0("MetalBuffer::Create"); return assign_ref(new MetalBuffer(allocator, memory_type, allowed_access, usage, allocation_size, - byte_offset, byte_length, [buffer retain], transfer_queue)); -} - -// static -StatusOr<ref_ptr<MetalBuffer>> MetalBuffer::CreateUnretained( - MetalDirectAllocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t usage, iree_device_size_t allocation_size, - iree_device_size_t byte_offset, iree_device_size_t byte_length, id<MTLBuffer> buffer, - id<MTLCommandQueue> transfer_queue) { - IREE_TRACE_SCOPE0("MetalBuffer::Create"); - return assign_ref(new MetalBuffer(allocator, memory_type, allowed_access, usage, allocation_size, byte_offset, byte_length, buffer, transfer_queue)); }
diff --git a/iree/hal/metal/metal_capture_manager.h b/iree/hal/metal/metal_capture_manager.h index bdfefb9..22d17dd 100644 --- a/iree/hal/metal/metal_capture_manager.h +++ b/iree/hal/metal/metal_capture_manager.h
@@ -20,7 +20,6 @@ #import <Metal/Metal.h> #include "iree/base/status.h" -#include "iree/hal/cc/debug_capture_manager.h" namespace iree { namespace hal {
diff --git a/iree/hal/metal/metal_command_buffer.h b/iree/hal/metal/metal_command_buffer.h index 1b346e7..d9b4321 100644 --- a/iree/hal/metal/metal_command_buffer.h +++ b/iree/hal/metal/metal_command_buffer.h
@@ -19,7 +19,6 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" -#include "iree/hal/cc/command_buffer.h" #include "iree/hal/metal/metal_buffer.h" namespace iree { @@ -40,8 +39,6 @@ id<MTLCommandBuffer> handle() const { return metal_handle_; } - bool is_recording() const override { return is_recording_; } - Status Begin() override; Status End() override; @@ -51,43 +48,49 @@ absl::Span<const iree_hal_memory_barrier_t> memory_barriers, absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) override; - Status SignalEvent(Event* event, + Status SignalEvent(iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) override; - Status ResetEvent(Event* event, + Status ResetEvent(iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) override; Status WaitEvents( - absl::Span<Event*> events, iree_hal_execution_stage_t source_stage_mask, + absl::Span<iree_hal_event_t*> events, + iree_hal_execution_stage_t source_stage_mask, iree_hal_execution_stage_t target_stage_mask, absl::Span<const iree_hal_memory_barrier_t> memory_barriers, absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) override; - Status FillBuffer(Buffer* target_buffer, iree_device_size_t target_offset, - iree_device_size_t length, const void* pattern, - size_t pattern_length) override; - Status DiscardBuffer(Buffer* buffer) override; + Status FillBuffer(iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length, + const void* pattern, size_t pattern_length) override; + Status DiscardBuffer(iree_hal_buffer_t* buffer) override; Status UpdateBuffer(const void* source_buffer, - iree_device_size_t source_offset, Buffer* target_buffer, + iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length) override; - Status CopyBuffer(Buffer* source_buffer, iree_device_size_t source_offset, - Buffer* target_buffer, iree_device_size_t target_offset, + Status CopyBuffer(iree_hal_buffer_t* source_buffer, + iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) override; - Status PushConstants(ExecutableLayout* executable_layout, size_t offset, + Status PushConstants(iree_hal_executable_layout_t* executable_layout, + size_t offset, absl::Span<const uint32_t> values) override; Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, + iree_hal_executable_layout_t* executable_layout, uint32_t set, absl::Span<const iree_hal_descriptor_set_binding_t> bindings) override; Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, absl::Span<const iree_device_size_t> dynamic_offsets) override; - Status Dispatch(Executable* executable, int32_t entry_point, + Status Dispatch(iree_hal_executable_t* executable, int32_t entry_point, std::array<uint32_t, 3> workgroups) override; - Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, + Status DispatchIndirect(iree_hal_executable_t* executable, + int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset) override; private: @@ -117,9 +120,6 @@ iree_hal_command_category_t command_categories, id<MTLCommandBuffer> command_buffer); - StatusOr<MetalBuffer*> CastBuffer(Buffer* buffer) const; - StatusOr<MetalBuffer*> CastBuffer(iree_hal_buffer_t* buffer) const; - // Gets or begins an active MTLBlitCommandEncoder. This also ends all previous // encoded compute commands if any. id<MTLBlitCommandEncoder> GetOrBeginBlitEncoder(); @@ -137,7 +137,7 @@ id<MTLComputeCommandEncoder> current_compute_encoder_ = nil; id<MTLBlitCommandEncoder> current_blit_encoder_ = nil; - absl::flat_hash_map<ExecutableLayout*, PipelineStateObject> + absl::flat_hash_map<iree_hal_executable_layout_t*, PipelineStateObject> pipeline_state_objects_; };
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm index 61b932f..46607f6 100644 --- a/iree/hal/metal/metal_command_buffer.mm +++ b/iree/hal/metal/metal_command_buffer.mm
@@ -54,18 +54,6 @@ [metal_handle_ release]; } -StatusOr<MetalBuffer*> MetalCommandBuffer::CastBuffer(Buffer* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast<MetalBuffer*>(buffer->allocated_buffer()); -} - -StatusOr<MetalBuffer*> MetalCommandBuffer::CastBuffer(iree_hal_buffer_t* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return reinterpret_cast<MetalBuffer*>(iree_hal_buffer_allocated_buffer(buffer)); -} - id<MTLBlitCommandEncoder> MetalCommandBuffer::GetOrBeginBlitEncoder() { IREE_TRACE_SCOPE0("MetalCommandBuffer::GetOrBeginBlitEncoder"); @@ -155,17 +143,17 @@ return OkStatus(); } -Status MetalCommandBuffer::SignalEvent(Event* event, iree_hal_execution_stage_t source_stage_mask) { +Status MetalCommandBuffer::SignalEvent(iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) { IREE_TRACE_SCOPE0("MetalCommandBuffer::SignalEvent"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::SignalEvent"; } -Status MetalCommandBuffer::ResetEvent(Event* event, iree_hal_execution_stage_t source_stage_mask) { +Status MetalCommandBuffer::ResetEvent(iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) { IREE_TRACE_SCOPE0("MetalCommandBuffer::ResetEvent"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::ResetEvent"; } -Status MetalCommandBuffer::WaitEvents(absl::Span<Event*> events, +Status MetalCommandBuffer::WaitEvents(absl::Span<iree_hal_event_t*> events, iree_hal_execution_stage_t source_stage_mask, iree_hal_execution_stage_t target_stage_mask, absl::Span<const iree_hal_memory_barrier_t> memory_barriers, @@ -174,13 +162,13 @@ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::WaitEvents"; } -Status MetalCommandBuffer::FillBuffer(Buffer* target_buffer, iree_device_size_t target_offset, +Status MetalCommandBuffer::FillBuffer(iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length, const void* pattern, size_t pattern_length) { IREE_TRACE_SCOPE0("MetalCommandBuffer::FillBuffer"); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); + id<MTLBuffer> target_device_buffer = iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); - target_offset += target_buffer->byte_offset(); + target_offset += iree_hal_buffer_byte_offset(target_buffer); // Per the spec for fillBuffer:range:value: "The alignment and length of the range must both be a // multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS." Although iOS/tvOS is more relaxed on @@ -207,29 +195,29 @@ return OkStatus(); } -Status MetalCommandBuffer::DiscardBuffer(Buffer* buffer) { +Status MetalCommandBuffer::DiscardBuffer(iree_hal_buffer_t* buffer) { IREE_TRACE_SCOPE0("MetalCommandBuffer::DiscardBuffer"); // This is a hint. Nothing to do for Metal. return OkStatus(); } Status MetalCommandBuffer::UpdateBuffer(const void* source_buffer, iree_device_size_t source_offset, - Buffer* target_buffer, iree_device_size_t target_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length) { IREE_TRACE_SCOPE0("MetalCommandBuffer::UpdateBuffer"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::UpdateBuffer"; } -Status MetalCommandBuffer::CopyBuffer(Buffer* source_buffer, iree_device_size_t source_offset, - Buffer* target_buffer, iree_device_size_t target_offset, +Status MetalCommandBuffer::CopyBuffer(iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, iree_device_size_t length) { IREE_TRACE_SCOPE0("MetalCommandBuffer::CopyBuffer"); - IREE_ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer)); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); + id<MTLBuffer> source_device_buffer = iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(source_buffer)); + id<MTLBuffer> target_device_buffer = iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(target_buffer)); - source_offset += source_buffer->byte_offset(); - target_offset += target_buffer->byte_offset(); + source_offset += iree_hal_buffer_byte_offset(source_buffer); + target_offset += iree_hal_buffer_byte_offset(target_buffer); // Per the spec for copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size, the source/target // offset must be a multiple of 4 bytes in macOS, and 1 byte in iOS and tvOS. Although iOS/tvOS @@ -248,13 +236,13 @@ return OkStatus(); } -Status MetalCommandBuffer::PushConstants(ExecutableLayout* executable_layout, size_t offset, +Status MetalCommandBuffer::PushConstants(iree_hal_executable_layout_t* executable_layout, size_t offset, absl::Span<const uint32_t> values) { IREE_TRACE_SCOPE0("MetalCommandBuffer::PushConstants"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::PushConstants"; } -Status MetalCommandBuffer::PushDescriptorSet(ExecutableLayout* executable_layout, int32_t set, +Status MetalCommandBuffer::PushDescriptorSet(iree_hal_executable_layout_t* executable_layout, int32_t set, absl::Span<const iree_hal_descriptor_set_binding_t> bindings) { IREE_TRACE_SCOPE0("MetalCommandBuffer::PushDescriptorSet"); if (set != 0) { @@ -266,8 +254,8 @@ return OkStatus(); } -Status MetalCommandBuffer::BindDescriptorSet(ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, +Status MetalCommandBuffer::BindDescriptorSet(iree_hal_executable_layout_t* executable_layout, int32_t set, + iree_hal_descriptor_set_t* descriptor_set, absl::Span<const iree_device_size_t> dynamic_offsets) { IREE_TRACE_SCOPE0("MetalCommandBuffer::BindDescriptorSet"); if (set != 0) { @@ -282,7 +270,7 @@ return OkStatus(); } -Status MetalCommandBuffer::Dispatch(Executable* executable, int32_t entry_point, +Status MetalCommandBuffer::Dispatch(iree_hal_executable_t* executable, int32_t entry_point, std::array<uint32_t, 3> workgroups) { IREE_TRACE_SCOPE0("MetalCommandBuffer::Dispatch"); IREE_DVLOG(2) << "MetalCommandBuffer::Dispatch"; @@ -299,7 +287,6 @@ // TODO(antiagainst): only update the PSO for the current executable. for (const auto& pso_kv : pipeline_state_objects_) { const auto* pipeline_layout = static_cast<MetalPipelineArgumentBufferLayout*>(pso_kv.first); - IREE_DVLOG(3) << "Current pipeline layout: " << pipeline_layout->DebugString(); const auto& pso = pso_kv.second; if (pso.push_states.size() > 1) { @@ -317,7 +304,7 @@ IREE_DVLOG(3) << "Encoding push descriptors.."; for (const auto& push_kv : pso.push_states) { - int32_t set_number = push_kv.first; + uint32_t set_number = push_kv.first; const PipelineStateObject::PushState& push_state = push_kv.second; IREE_DVLOG(3) << " For set #" << set_number; @@ -345,7 +332,6 @@ [argument_encoder setArgumentBuffer:argument_buffer offset:0]; for (const auto& resource_binding : push_state.resource_bindings) { - IREE_DVLOG(3) << " Resource @[" << resource_binding.DebugStringShort() << "]"; if (resource_binding.length != IREE_WHOLE_BUFFER && resource_binding.length != resource_binding.buffer->allocation_size()) { @@ -353,8 +339,7 @@ << "MetalCommandBuffer::Dispatch with sub-buffer"; } - IREE_ASSIGN_OR_RETURN(auto buffer, CastBuffer(resource_binding.buffer)); - [argument_encoder setBuffer:buffer->handle() + [argument_encoder setBuffer:iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(resource_binding.buffer)) offset:resource_binding.offset atIndex:resource_binding.binding]; @@ -383,8 +368,8 @@ return OkStatus(); } -Status MetalCommandBuffer::DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, +Status MetalCommandBuffer::DispatchIndirect(iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset) { IREE_TRACE_SCOPE0("MetalCommandBuffer::DispatchIndirect"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::DispatchIndirect";
diff --git a/iree/hal/metal/metal_command_queue.h b/iree/hal/metal/metal_command_queue.h index 49d7581..ef4d64c 100644 --- a/iree/hal/metal/metal_command_queue.h +++ b/iree/hal/metal/metal_command_queue.h
@@ -20,7 +20,7 @@ #include "iree/base/arena.h" #include "iree/base/status.h" #include "iree/base/time.h" -#include "iree/hal/cc/command_queue.h" +#include "iree/hal/api.h" namespace iree { namespace hal {
diff --git a/iree/hal/metal/metal_device.h b/iree/hal/metal/metal_device.h index 99fdd4d..aea32ce 100644 --- a/iree/hal/metal/metal_device.h +++ b/iree/hal/metal/metal_device.h
@@ -21,11 +21,8 @@ #include "absl/types/span.h" #include "iree/base/memory.h" -#include "iree/hal/cc/allocator.h" -#include "iree/hal/cc/debug_capture_manager.h" #include "iree/hal/cc/device.h" #include "iree/hal/cc/driver.h" -#include "iree/hal/cc/semaphore.h" namespace iree { namespace hal { @@ -37,56 +34,51 @@ // Creates a device that retains the underlying Metal GPU device. // The iree_hal_device_id_t in |device_info| is expected to be an // id<MTLDevice>. - static StatusOr<ref_ptr<MetalDevice>> Create( - ref_ptr<Driver> driver, const DeviceInfo& device_info, - DebugCaptureManager* debug_capture_manager); + static StatusOr<ref_ptr<MetalDevice>> Create(ref_ptr<Driver> driver, + const DeviceInfo& device_info); ~MetalDevice() override; - std::string DebugString() const override; - Allocator* allocator() const override { return allocator_.get(); } - absl::Span<CommandQueue*> dispatch_queues() const override { - return absl::MakeSpan(&common_queue_, 1); - } + Status CreateExecutableCache( + iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) override; - absl::Span<CommandQueue*> transfer_queues() const override { - return absl::MakeSpan(&common_queue_, 1); - } - - ref_ptr<ExecutableCache> CreateExecutableCache() override; - - StatusOr<ref_ptr<DescriptorSetLayout>> CreateDescriptorSetLayout( + Status CreateDescriptorSetLayout( iree_hal_descriptor_set_layout_usage_type_t usage_type, - absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings) - override; + absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) override; - StatusOr<ref_ptr<ExecutableLayout>> CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, - size_t push_constants) override; + Status CreateExecutableLayout( + absl::Span<iree_hal_descriptor_set_layout_t*> set_layouts, + size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout) override; - StatusOr<ref_ptr<DescriptorSet>> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) override; + Status CreateDescriptorSet( + iree_hal_descriptor_set_layout_t* set_layout, + absl::Span<const iree_hal_descriptor_set_binding_t> bindings, + iree_hal_descriptor_set_t** out_descriptor_set) override; - StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( + Status CreateCommandBuffer( iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories) override; + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer) override; - StatusOr<ref_ptr<Event>> CreateEvent() override; + Status CreateEvent(iree_hal_event_t** out_event) override; - StatusOr<ref_ptr<Semaphore>> CreateSemaphore(uint64_t initial_value) override; - Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) override; - StatusOr<int> WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) override; + Status CreateSemaphore(uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) override; + Status WaitAllSemaphores(const iree_hal_semaphore_list_t* semaphore_list, + iree_time_t deadline_ns) override; + StatusOr<int> WaitAnySemaphore( + const iree_hal_semaphore_list_t* semaphore_list, + iree_time_t deadline_ns) override; - Status WaitIdle(Time deadline_ns) override; + Status WaitIdle(iree_time_t deadline_ns) override; private: - MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info, - DebugCaptureManager* debug_capture_manager); + MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info); ref_ptr<Driver> driver_; id<MTLDevice> metal_handle_; @@ -107,8 +99,6 @@ // semaphore. dispatch_queue_t wait_notifier_; MTLSharedEventListener* event_listener_; - - DebugCaptureManager* debug_capture_manager_ = nullptr; }; } // namespace metal
diff --git a/iree/hal/metal/metal_device.mm b/iree/hal/metal/metal_device.mm index 90075d7..58e0000 100644 --- a/iree/hal/metal/metal_device.mm +++ b/iree/hal/metal/metal_device.mm
@@ -19,7 +19,6 @@ #include "iree/base/status.h" #include "iree/base/time.h" #include "iree/base/tracing.h" -#include "iree/hal/cc/allocator.h" #include "iree/hal/metal/dispatch_time_util.h" #include "iree/hal/metal/metal_capture_manager.h" #include "iree/hal/metal/metal_command_buffer.h" @@ -83,40 +82,37 @@ [metal_handle_ release]; } -std::string MetalDevice::DebugString() const { - return absl::StrCat(Device::DebugString(), // - "\n[MetalDevice]", // - "\n - Dispatch Queues: 1", // - "\n - Transfer Queues: 1"); -} - -ref_ptr<ExecutableCache> MetalDevice::CreateExecutableCache() { +Status MetalDevice::CreateExecutableCache(iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableCache"); return make_ref<MetalPipelineCache>(metal_handle_); } -StatusOr<ref_ptr<DescriptorSetLayout>> MetalDevice::CreateDescriptorSetLayout( +Status MetalDevice::CreateDescriptorSetLayout( iree_hal_descriptor_set_layout_usage_type_t usage_type, - absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings) { + absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSetLayout"); return make_ref<MetalArgumentBufferLayout>(usage_type, bindings); } -StatusOr<ref_ptr<ExecutableLayout>> MetalDevice::CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) { +Status MetalDevice::CreateExecutableLayout( + absl::Span<iree_hal_descriptor_set_layout_t*> set_layouts, size_t push_constants, iree_hal_executable_layout_t** out_executable_layout) { IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableLayout"); return make_ref<MetalPipelineArgumentBufferLayout>(set_layouts, push_constants); } -StatusOr<ref_ptr<DescriptorSet>> MetalDevice::CreateDescriptorSet( - DescriptorSetLayout* set_layout, absl::Span<const iree_hal_descriptor_set_binding_t> bindings) { +Status MetalDevice::CreateDescriptorSet( + iree_hal_descriptor_set_layout_t* set_layout, absl::Span<const iree_hal_descriptor_set_binding_t> bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSet"); return make_ref<MetalArgumentBuffer>(static_cast<MetalArgumentBufferLayout*>(set_layout), bindings); } -StatusOr<ref_ptr<CommandBuffer>> MetalDevice::CreateCommandBuffer( - iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories) { +Status MetalDevice::CreateCommandBuffer( + iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer) { IREE_TRACE_SCOPE0("MetalDevice::CreateCommandBuffer"); @autoreleasepool { StatusOr<ref_ptr<CommandBuffer>> command_buffer; @@ -129,18 +125,18 @@ } } -StatusOr<ref_ptr<Event>> MetalDevice::CreateEvent() { +Status MetalDevice::CreateEvent(iree_hal_event_t** out_event) { IREE_TRACE_SCOPE0("MetalDevice::CreateEvent"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateEvent"; } -StatusOr<ref_ptr<Semaphore>> MetalDevice::CreateSemaphore(uint64_t initial_value) { +Status MetalDevice::CreateSemaphore(uint64_t initial_value, iree_hal_semaphore_t** out_semaphore) { IREE_TRACE_SCOPE0("MetalDevice::CreateSemaphore"); return MetalSharedEvent::Create(metal_handle_, event_listener_, initial_value); } -Status MetalDevice::WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) { +Status MetalDevice::WaitAllSemaphores(const iree_hal_semaphore_list_t* semaphore_list, + iree_time_t deadline_ns) { IREE_TRACE_SCOPE0("MetalDevice::WaitAllSemaphores"); // Go through all MetalSharedEvents and wait on each of them given we need all of them to be // signaled anyway. @@ -151,8 +147,8 @@ return OkStatus(); } -StatusOr<int> MetalDevice::WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) { +StatusOr<int> MetalDevice::WaitAnySemaphore(const iree_hal_semaphore_list_t* semaphore_list, + iree_time_t deadline_ns) { IREE_TRACE_SCOPE0("MetalDevice::WaitAnySemaphore"); if (semaphores.empty()) { @@ -204,7 +200,7 @@ return signaled_index; } -Status MetalDevice::WaitIdle(Time deadline_ns) { +Status MetalDevice::WaitIdle(iree_time_t deadline_ns) { IREE_TRACE_SCOPE0("MetalDevice::WaitIdle"); return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::WaitIdle"; }
diff --git a/iree/hal/metal/metal_direct_allocator.h b/iree/hal/metal/metal_direct_allocator.h index db85f14..117d358 100644 --- a/iree/hal/metal/metal_direct_allocator.h +++ b/iree/hal/metal/metal_direct_allocator.h
@@ -20,7 +20,7 @@ #include <memory> #include "iree/base/status.h" -#include "iree/hal/cc/allocator.h" +#include "iree/hal/api.h" namespace iree { namespace hal { @@ -43,23 +43,10 @@ iree_hal_buffer_usage_t buffer_usage, iree_hal_buffer_usage_t intended_usage) const override; - bool CanAllocate(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) const override; - - Status MakeCompatible(iree_hal_memory_type_t* memory_type, - iree_hal_buffer_usage_t* buffer_usage) const override; - StatusOr<ref_ptr<Buffer>> Allocate(iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage, size_t allocation_size) override; - StatusOr<ref_ptr<Buffer>> WrapMutable(iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, - void* data, - size_t data_length) override; - private: explicit MetalDirectAllocator(id<MTLDevice> device, id<MTLCommandQueue> transfer_queue);
diff --git a/iree/hal/metal/metal_direct_allocator.mm b/iree/hal/metal/metal_direct_allocator.mm index 5b879c7..833006d 100644 --- a/iree/hal/metal/metal_direct_allocator.mm +++ b/iree/hal/metal/metal_direct_allocator.mm
@@ -93,19 +93,6 @@ return source_allocator == this; } -bool MetalDirectAllocator::CanAllocate(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) const { - // TODO(benvanik): ensure there is a memory type that can satisfy the request. - return true; -} - -Status MetalDirectAllocator::MakeCompatible(iree_hal_memory_type_t* memory_type, - iree_hal_buffer_usage_t* buffer_usage) const { - // TODO(benvanik): mutate to match supported memory types. - return OkStatus(); -} - StatusOr<ref_ptr<MetalBuffer>> MetalDirectAllocator::AllocateInternal( iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage, iree_hal_memory_access_t allowed_access, size_t allocation_size) { @@ -120,7 +107,7 @@ id<MTLBuffer> metal_buffer = [metal_device_ newBufferWithLength:allocation_size options:resource_options]; // retained - return MetalBuffer::CreateUnretained( + return MetalBuffer::Create( this, memory_type, allowed_access, buffer_usage, allocation_size, /*byte_offset=*/0, /*byte_length=*/allocation_size, metal_buffer, metal_transfer_queue_); } @@ -132,14 +119,6 @@ return AllocateInternal(memory_type, buffer_usage, IREE_HAL_MEMORY_ACCESS_ALL, allocation_size); } -StatusOr<ref_ptr<Buffer>> MetalDirectAllocator::WrapMutable(iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, - void* data, size_t data_length) { - IREE_TRACE_SCOPE0("MetalDirectAllocator::WrapMutable"); - return UnimplementedErrorBuilder(IREE_LOC) << "MetalDirectAllocator::WrapMutable"; -} - } // namespace metal } // namespace hal } // namespace iree
diff --git a/iree/hal/metal/metal_driver.h b/iree/hal/metal/metal_driver.h index b8c0d48..bb2603e 100644 --- a/iree/hal/metal/metal_driver.h +++ b/iree/hal/metal/metal_driver.h
@@ -18,7 +18,6 @@ #include <memory> #include <string> -#include "iree/hal/cc/debug_capture_manager.h" #include "iree/hal/cc/driver.h" namespace iree { @@ -51,12 +50,9 @@ iree_hal_device_id_t device_id) override; private: - MetalDriver(std::vector<DeviceInfo> devices, - std::unique_ptr<DebugCaptureManager> debug_capture_manager); + MetalDriver(std::vector<DeviceInfo> devices); std::vector<DeviceInfo> devices_; - - std::unique_ptr<DebugCaptureManager> debug_capture_manager_; }; } // namespace metal
diff --git a/iree/hal/metal/metal_driver.mm b/iree/hal/metal/metal_driver.mm index 4c5fa40..b561a31 100644 --- a/iree/hal/metal/metal_driver.mm +++ b/iree/hal/metal/metal_driver.mm
@@ -69,11 +69,9 @@ } } -MetalDriver::MetalDriver(std::vector<DeviceInfo> devices, - std::unique_ptr<DebugCaptureManager> debug_capture_manager) +MetalDriver::MetalDriver(std::vector<DeviceInfo> devices) : Driver("metal"), - devices_(std::move(devices)), - debug_capture_manager_(std::move(debug_capture_manager)) { + devices_(std::move(devices)) { // Retain all the retained Metal GPU devices. for (const auto& device : devices_) { [(__bridge id<MTLDevice>)device.device_id() retain]; @@ -109,7 +107,7 @@ for (const DeviceInfo& info : devices_) { if (info.device_id() == device_id) { - return MetalDevice::Create(add_ref(this), info, debug_capture_manager_.get()); + return MetalDevice::Create(add_ref(this), info); } } return InvalidArgumentErrorBuilder(IREE_LOC) << "unknown driver device id: " << device_id;
diff --git a/iree/hal/metal/metal_kernel_library.h b/iree/hal/metal/metal_kernel_library.h index cb5378a..6916ddc 100644 --- a/iree/hal/metal/metal_kernel_library.h +++ b/iree/hal/metal/metal_kernel_library.h
@@ -21,7 +21,6 @@ #include "absl/container/inlined_vector.h" #include "iree/base/status.h" -#include "iree/hal/cc/executable.h" #include "iree/hal/cc/executable_cache.h" // flatcc schemas:
diff --git a/iree/hal/metal/metal_pipeline_argument_buffer.cc b/iree/hal/metal/metal_pipeline_argument_buffer.cc index 7683c93..1879a81 100644 --- a/iree/hal/metal/metal_pipeline_argument_buffer.cc +++ b/iree/hal/metal/metal_pipeline_argument_buffer.cc
@@ -34,16 +34,6 @@ return nullptr; } -std::string MetalArgumentBufferLayout::DebugString() const { - std::vector<std::string> binding_strings; - binding_strings.reserve(bindings_.size()); - for (const auto& binding : bindings_) { - binding_strings.push_back( - absl::StrCat("[", binding.DebugStringShort(), "]")); - } - return absl::StrCat("bindings=[", absl::StrJoin(binding_strings, ", "), "]"); -} - MetalPipelineArgumentBufferLayout::MetalPipelineArgumentBufferLayout( absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) : set_layouts_(set_layouts.size()), push_constants_(push_constants) { @@ -57,16 +47,6 @@ for (auto* layout : set_layouts_) layout->ReleaseReference(); } -std::string MetalPipelineArgumentBufferLayout::DebugString() const { - std::vector<std::string> set_strings; - set_strings.reserve(set_layouts_.size()); - for (int i = 0; i < set_layouts_.size(); ++i) { - set_strings.push_back( - absl::StrCat("{set=", i, ", ", set_layouts_[i]->DebugString(), "}")); - } - return absl::StrCat("sets={", absl::StrJoin(set_strings, "; "), "}"); -} - MetalArgumentBuffer::MetalArgumentBuffer( MetalArgumentBufferLayout* layout, absl::Span<const iree_hal_descriptor_set_binding_t> resources)
diff --git a/iree/hal/metal/metal_pipeline_argument_buffer.h b/iree/hal/metal/metal_pipeline_argument_buffer.h index a83a3e4..0d5456a 100644 --- a/iree/hal/metal/metal_pipeline_argument_buffer.h +++ b/iree/hal/metal/metal_pipeline_argument_buffer.h
@@ -41,8 +41,6 @@ absl::Span<const Binding> bindings() const { return bindings_; } const Binding* GetBindingForIndex(int index) const; - std::string DebugString() const override; - private: UsageType usage_type_; absl::InlinedVector<Binding, 8> bindings_; @@ -59,8 +57,6 @@ return set_layouts_; } - std::string DebugString() const override; - private: absl::InlinedVector<MetalArgumentBufferLayout*, 2> set_layouts_; size_t push_constants_;
diff --git a/iree/hal/metal/metal_pipeline_cache.h b/iree/hal/metal/metal_pipeline_cache.h index 62c0190..7b74909 100644 --- a/iree/hal/metal/metal_pipeline_cache.h +++ b/iree/hal/metal/metal_pipeline_cache.h
@@ -17,7 +17,6 @@ #import <Metal/Metal.h> -#include "iree/hal/cc/executable.h" #include "iree/hal/cc/executable_cache.h" namespace iree { @@ -30,7 +29,7 @@ explicit MetalPipelineCache(id<MTLDevice> device); ~MetalPipelineCache() override; - bool CanPrepareFormat(ExecutableFormat format) const override; + bool CanPrepareFormat(iree_hal_executable_format_t format) const override; StatusOr<ref_ptr<Executable>> PrepareExecutable( ExecutableLayout* executable_layout,
diff --git a/iree/hal/metal/metal_pipeline_cache.mm b/iree/hal/metal/metal_pipeline_cache.mm index ad60aaf..ec328e4 100644 --- a/iree/hal/metal/metal_pipeline_cache.mm +++ b/iree/hal/metal/metal_pipeline_cache.mm
@@ -16,18 +16,21 @@ #include "iree/base/status.h" #include "iree/base/tracing.h" -#include "iree/hal/cc/executable_format.h" +#include "iree/hal/api.h" #include "iree/hal/metal/metal_kernel_library.h" namespace iree { namespace hal { namespace metal { +static const iree_hal_executable_format_t kExecutableFormatMetal = + iree_hal_make_executable_format("MTLE"); + MetalPipelineCache::MetalPipelineCache(id<MTLDevice> device) : metal_device_([device retain]) {} MetalPipelineCache::~MetalPipelineCache() { [metal_device_ release]; } -bool MetalPipelineCache::CanPrepareFormat(ExecutableFormat format) const { +bool MetalPipelineCache::CanPrepareFormat(iree_hal_executable_format_t format) const { return format == kExecutableFormatMetal; }
diff --git a/iree/hal/resource.h b/iree/hal/resource.h index e3579a6..d906a9d 100644 --- a/iree/hal/resource.h +++ b/iree/hal/resource.h
@@ -20,6 +20,7 @@ #include "iree/base/api.h" #include "iree/base/atomics.h" +#include "iree/base/debugging.h" #ifdef __cplusplus extern "C" { @@ -55,6 +56,24 @@ out_resource->vtable = vtable; } +// Returns true if the |resource| has the given |vtable| type. +// This is *not* a way to ensure that an instance is of a specific type but +// instead that it has a compatible vtable. This is because LTO may very rarely +// dedupe identical vtables and cause the pointer comparison to succeed even if +// the spellings of the types differs. +static inline bool iree_hal_resource_is(const void* resource, + const void* vtable) { + return resource ? ((const iree_hal_resource_t*)resource)->vtable == vtable + : false; +} + +// Asserts (**DEBUG ONLY**) that the |resource| has the given |vtable| type. +// This is only useful to check for programmer error and may have false +// positives - do not rely on it for handling untrusted user input. +#define IREE_HAL_ASSERT_TYPE(resource, vtable) \ + IREE_ASSERT_TRUE(iree_hal_resource_is(resource, vtable), \ + "type does not match expected " #vtable) + #ifdef __cplusplus } // extern "C" #endif // __cplusplus
diff --git a/iree/hal/vulkan/BUILD b/iree/hal/vulkan/BUILD index ac51e89..901ca88 100644 --- a/iree/hal/vulkan/BUILD +++ b/iree/hal/vulkan/BUILD
@@ -31,57 +31,109 @@ ) cc_library( - name = "api", - srcs = ["api.cc"], - hdrs = ["api.h"], - visibility = ["//visibility:public"], - deps = [ - ":utils", - ":vulkan", - "//iree/base:api", - "//iree/base:tracing", - "//iree/hal:api", - ], -) - -cc_library( - name = "utils", + name = "vulkan", srcs = [ + "api.cc", + "command_queue.h", "debug_reporter.cc", - "dynamic_symbols.cc", - "extensibility_util.cc", - "renderdoc_capture_manager.cc", - "status_util.cc", - "timepoint_util.cc", - ], - hdrs = [ "debug_reporter.h", - "dynamic_symbol_tables.h", - "dynamic_symbols.h", + "descriptor_pool_cache.cc", + "descriptor_pool_cache.h", + "descriptor_set_arena.cc", + "descriptor_set_arena.h", + "direct_command_buffer.cc", + "direct_command_buffer.h", + "direct_command_queue.cc", + "direct_command_queue.h", + "emulated_semaphore.cc", + "emulated_semaphore.h", + "extensibility_util.cc", "extensibility_util.h", "handle_util.h", - "renderdoc_capture_manager.h", + "internal_vk_mem_alloc.cc", + "internal_vk_mem_alloc.h", + "native_descriptor_set.cc", + "native_descriptor_set.h", + "native_descriptor_set_layout.cc", + "native_descriptor_set_layout.h", + "native_event.cc", + "native_event.h", + "native_executable.cc", + "native_executable.h", + "native_executable_layout.cc", + "native_executable_layout.h", + "native_semaphore.cc", + "native_semaphore.h", + "nop_executable_cache.cc", + "nop_executable_cache.h", + "serializing_command_queue.cc", + "serializing_command_queue.h", + "status_util.c", "status_util.h", + "timepoint_util.cc", "timepoint_util.h", + "vma_allocator.cc", + "vma_allocator.h", + "vma_buffer.cc", + "vma_buffer.h", + "vulkan_device.cc", + "vulkan_driver.cc", "vulkan_headers.h", ], + hdrs = [ + # TODO(benvanik): hide all but api.h. + "api.h", + "vulkan_device.h", + "vulkan_driver.h", + ], + visibility = ["//visibility:public"], deps = [ + ":dynamic_symbols", + "//iree/base:api", + "//iree/base:arena", "//iree/base:core_headers", - "//iree/base:dynamic_library", + "//iree/base:flatcc", "//iree/base:intrusive_list", "//iree/base:logging", "//iree/base:ref_ptr", "//iree/base:status", - "//iree/base:time", + "//iree/base:synchronization", "//iree/base:tracing", - "//iree/hal/cc", + "//iree/hal:api", + "//iree/schemas:spirv_executable_def_c_fbs", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@iree_vulkan_headers//:vulkan_headers", - "@renderdoc_api//:renderdoc_app", + "@vulkan_memory_allocator//:impl_header_only", + ], +) + +cc_library( + name = "dynamic_symbols", + srcs = [ + "dynamic_symbols.cc", + "vulkan_headers.h", + ], + hdrs = [ + "dynamic_symbol_tables.h", + "dynamic_symbols.h", + ], + deps = [ + "//iree/base:core_headers", + "//iree/base:dynamic_library", + "//iree/base:ref_ptr", + "//iree/base:status", + "//iree/base:tracing", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@iree_vulkan_headers//:vulkan_headers", ], ) @@ -90,91 +142,8 @@ srcs = ["dynamic_symbols_test.cc"], tags = ["driver=vulkan"], deps = [ - ":utils", + ":dynamic_symbols", "//iree/testing:gtest", "//iree/testing:gtest_main", ], ) - -cc_library( - name = "vma_allocator", - srcs = [ - "internal_vk_mem_alloc.cc", - "internal_vk_mem_alloc.h", - "vma_allocator.cc", - "vma_buffer.cc", - ], - hdrs = [ - "vma_allocator.h", - "vma_buffer.h", - ], - deps = [ - ":utils", - "//iree/base:core_headers", - "//iree/base:logging", - "//iree/base:status", - "//iree/base:tracing", - "//iree/hal/cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@vulkan_memory_allocator//:impl_header_only", - ], -) - -cc_library( - name = "vulkan", - srcs = [ - "descriptor_pool_cache.cc", - "descriptor_set_arena.cc", - "direct_command_buffer.cc", - "direct_command_queue.cc", - "emulated_timeline_semaphore.cc", - "native_descriptor_set.cc", - "native_event.cc", - "native_timeline_semaphore.cc", - "pipeline_cache.cc", - "pipeline_executable.cc", - "pipeline_executable_layout.cc", - "serializing_command_queue.cc", - "vulkan_device.cc", - "vulkan_driver.cc", - ], - hdrs = [ - "descriptor_pool_cache.h", - "descriptor_set_arena.h", - "direct_command_buffer.h", - "direct_command_queue.h", - "emulated_timeline_semaphore.h", - "native_descriptor_set.h", - "native_event.h", - "native_timeline_semaphore.h", - "pipeline_cache.h", - "pipeline_executable.h", - "pipeline_executable_layout.h", - "serializing_command_queue.h", - "vulkan_device.h", - "vulkan_driver.h", - ], - deps = [ - ":utils", - ":vma_allocator", - "//iree/base:api", - "//iree/base:arena", - "//iree/base:core_headers", - "//iree/base:flatcc", - "//iree/base:intrusive_list", - "//iree/base:ref_ptr", - "//iree/base:status", - "//iree/base:time", - "//iree/base:tracing", - "//iree/hal/cc", - "//iree/schemas:spirv_executable_def_c_fbs", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - ], -)
diff --git a/iree/hal/vulkan/CMakeLists.txt b/iree/hal/vulkan/CMakeLists.txt index 17f3d50..016e114 100644 --- a/iree/hal/vulkan/CMakeLists.txt +++ b/iree/hal/vulkan/CMakeLists.txt
@@ -20,135 +20,63 @@ iree_cc_library( NAME - api - HDRS - "api.h" - SRCS - "api.cc" - DEPS - ::utils - ::vulkan - iree::base::api - iree::base::tracing - iree::hal::api - PUBLIC -) - -iree_cc_library( - NAME - utils - HDRS - "debug_reporter.h" - "dynamic_symbol_tables.h" - "dynamic_symbols.h" - "extensibility_util.h" - "handle_util.h" - "renderdoc_capture_manager.h" - "status_util.h" - "timepoint_util.h" - "vulkan_headers.h" - SRCS - "debug_reporter.cc" - "dynamic_symbols.cc" - "extensibility_util.cc" - "renderdoc_capture_manager.cc" - "status_util.cc" - "timepoint_util.cc" - DEPS - Vulkan::Headers - absl::core_headers - absl::memory - absl::span - absl::strings - absl::synchronization - iree::base::core_headers - iree::base::dynamic_library - iree::base::intrusive_list - iree::base::logging - iree::base::ref_ptr - iree::base::status - iree::base::time - iree::base::tracing - iree::hal::cc - renderdoc_api::renderdoc_app - PUBLIC -) - -iree_cc_test( - NAME - dynamic_symbols_test - SRCS - "dynamic_symbols_test.cc" - DEPS - ::utils - iree::testing::gtest - iree::testing::gtest_main - LABELS - "driver=vulkan" -) - -iree_cc_library( - NAME - vma_allocator - HDRS - "vma_allocator.h" - "vma_buffer.h" - SRCS - "internal_vk_mem_alloc.cc" - "internal_vk_mem_alloc.h" - "vma_allocator.cc" - "vma_buffer.cc" - DEPS - ::utils - absl::flat_hash_map - absl::memory - absl::synchronization - iree::base::core_headers - iree::base::logging - iree::base::status - iree::base::tracing - iree::hal::cc - vulkan_memory_allocator - PUBLIC -) - -iree_cc_library( - NAME vulkan HDRS - "descriptor_pool_cache.h" - "descriptor_set_arena.h" - "direct_command_buffer.h" - "direct_command_queue.h" - "emulated_timeline_semaphore.h" - "native_descriptor_set.h" - "native_event.h" - "native_timeline_semaphore.h" - "pipeline_cache.h" - "pipeline_executable.h" - "pipeline_executable_layout.h" - "serializing_command_queue.h" + "api.h" "vulkan_device.h" "vulkan_driver.h" SRCS + "api.cc" + "command_queue.h" + "debug_reporter.cc" + "debug_reporter.h" "descriptor_pool_cache.cc" + "descriptor_pool_cache.h" "descriptor_set_arena.cc" + "descriptor_set_arena.h" "direct_command_buffer.cc" + "direct_command_buffer.h" "direct_command_queue.cc" - "emulated_timeline_semaphore.cc" + "direct_command_queue.h" + "emulated_semaphore.cc" + "emulated_semaphore.h" + "extensibility_util.cc" + "extensibility_util.h" + "handle_util.h" + "internal_vk_mem_alloc.cc" + "internal_vk_mem_alloc.h" "native_descriptor_set.cc" + "native_descriptor_set.h" + "native_descriptor_set_layout.cc" + "native_descriptor_set_layout.h" "native_event.cc" - "native_timeline_semaphore.cc" - "pipeline_cache.cc" - "pipeline_executable.cc" - "pipeline_executable_layout.cc" + "native_event.h" + "native_executable.cc" + "native_executable.h" + "native_executable_layout.cc" + "native_executable_layout.h" + "native_semaphore.cc" + "native_semaphore.h" + "nop_executable_cache.cc" + "nop_executable_cache.h" "serializing_command_queue.cc" + "serializing_command_queue.h" + "status_util.c" + "status_util.h" + "timepoint_util.cc" + "timepoint_util.h" + "vma_allocator.cc" + "vma_allocator.h" + "vma_buffer.cc" + "vma_buffer.h" "vulkan_device.cc" "vulkan_driver.cc" + "vulkan_headers.h" DEPS - ::utils - ::vma_allocator + ::dynamic_symbols + Vulkan::Headers absl::core_headers + absl::flat_hash_map absl::inlined_vector absl::memory absl::span @@ -159,11 +87,49 @@ iree::base::core_headers iree::base::flatcc iree::base::intrusive_list + iree::base::logging iree::base::ref_ptr iree::base::status - iree::base::time + iree::base::synchronization iree::base::tracing - iree::hal::cc + iree::hal::api iree::schemas::spirv_executable_def_c_fbs + vulkan_memory_allocator PUBLIC ) + +iree_cc_library( + NAME + dynamic_symbols + HDRS + "dynamic_symbol_tables.h" + "dynamic_symbols.h" + SRCS + "dynamic_symbols.cc" + "vulkan_headers.h" + DEPS + Vulkan::Headers + absl::core_headers + absl::memory + absl::span + absl::strings + iree::base::core_headers + iree::base::dynamic_library + iree::base::ref_ptr + iree::base::status + iree::base::tracing + PUBLIC +) + +iree_cc_test( + NAME + dynamic_symbols_test + SRCS + "dynamic_symbols_test.cc" + DEPS + ::dynamic_symbols + iree::testing::gtest + iree::testing::gtest_main + LABELS + "driver=vulkan" +)
diff --git a/iree/hal/vulkan/api.cc b/iree/hal/vulkan/api.cc index d664577..5653ef9 100644 --- a/iree/hal/vulkan/api.cc +++ b/iree/hal/vulkan/api.cc
@@ -21,16 +21,17 @@ #include "iree/hal/vulkan/vulkan_device.h" #include "iree/hal/vulkan/vulkan_driver.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; + +// TODO(benvanik): move these into the appropriate files and delete this .cc. //===----------------------------------------------------------------------===// // iree::hal::vulkan::DynamicSymbols //===----------------------------------------------------------------------===// IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create( - void* vkGetInstanceProcAddr_fn, iree_hal_vulkan_syms_t** out_syms) { + void* vkGetInstanceProcAddr_fn, iree_allocator_t host_allocator, + iree_hal_vulkan_syms_t** out_syms) { IREE_TRACE_SCOPE0("iree_hal_vulkan_syms_create"); IREE_ASSERT_ARGUMENT(out_syms); *out_syms = nullptr; @@ -53,7 +54,7 @@ IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create_from_system_loader( - iree_hal_vulkan_syms_t** out_syms) { + iree_allocator_t host_allocator, iree_hal_vulkan_syms_t** out_syms) { IREE_TRACE_SCOPE0("iree_hal_vulkan_syms_create_from_system_loader"); IREE_ASSERT_ARGUMENT(out_syms); *out_syms = nullptr; @@ -63,248 +64,20 @@ return iree_ok_status(); } -IREE_API_EXPORT iree_status_t -iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_syms_release"); +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_retain(iree_hal_vulkan_syms_t* syms) { IREE_ASSERT_ARGUMENT(syms); auto* handle = reinterpret_cast<DynamicSymbols*>(syms); - handle->ReleaseReference(); - return iree_ok_status(); + if (handle) { + handle->AddReference(); + } } -//===----------------------------------------------------------------------===// -// iree::hal::vulkan Extensibility Util -//===----------------------------------------------------------------------===// - -namespace { - -ExtensibilitySpec GetInstanceExtensibilitySpec( - const iree_hal_vulkan_features_t& features) { - ExtensibilitySpec spec; - - // Multiple extensions depend on VK_KHR_get_physical_device_properties2. - // This extension was deprecated in Vulkan 1.1 as its functionality was - // promoted to core, so we list it as optional even though we require it. - spec.optional_extensions.push_back( - VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); - - if (features & IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS) { - spec.optional_layers.push_back("VK_LAYER_KHRONOS_standard_validation"); - } - - if (features & IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS) { - spec.optional_extensions.push_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME); - } - - // Polyfill layer - enable if present. - spec.optional_layers.push_back("VK_LAYER_KHRONOS_timeline_semaphore"); - - return spec; -} - -ExtensibilitySpec GetDeviceExtensibilitySpec( - const iree_hal_vulkan_features_t& features) { - ExtensibilitySpec spec; - - // REQUIRED: these are required extensions that must be present for IREE to - // work (such as those relied upon by SPIR-V kernels, etc). - spec.required_extensions.push_back( - VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); - // Timeline semaphore support is required. - spec.required_extensions.push_back(VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME); - - if (features & IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS) { - spec.optional_extensions.push_back(VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); - } - - return spec; -} - -} // namespace - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_extensions( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t extensions_capacity, - const char** out_extensions, iree_host_size_t* out_extensions_count) { - IREE_ASSERT_ARGUMENT(out_extensions_count); - *out_extensions_count = 0; - - bool is_instance = extensibility_set & IREE_HAL_VULKAN_INSTANCE_BIT; - bool is_required = extensibility_set & IREE_HAL_VULKAN_REQUIRED_BIT; - - ExtensibilitySpec spec = is_instance ? GetInstanceExtensibilitySpec(features) - : GetDeviceExtensibilitySpec(features); - *out_extensions_count = is_required ? spec.required_extensions.size() - : spec.optional_extensions.size(); - - // Return early if only querying number of extensions in this configuration. - if (!out_extensions) { - return iree_ok_status(); - } - - if (extensions_capacity < *out_extensions_count) { - // Not an error; just a size query. - return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); - } - - const std::vector<const char*>& extensions = - is_required ? spec.required_extensions : spec.optional_extensions; - for (int i = 0; i < extensions.size(); ++i) { - out_extensions[i] = extensions[i]; - } - - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_layers( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t layers_capacity, - const char** out_layers, iree_host_size_t* out_layers_count) { - IREE_ASSERT_ARGUMENT(out_layers_count); - *out_layers_count = 0; - - // Device layers are deprecated and unsupported here. - if (!(extensibility_set & IREE_HAL_VULKAN_INSTANCE_BIT)) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "device layers are deprecated in Vulkan"); - } - - bool is_required = extensibility_set & IREE_HAL_VULKAN_REQUIRED_BIT; - - ExtensibilitySpec spec = GetInstanceExtensibilitySpec(features); - *out_layers_count = - is_required ? spec.required_layers.size() : spec.optional_layers.size(); - - // Return early if only querying number of layers in this configuration. - if (!out_layers) { - return iree_ok_status(); - } - - if (layers_capacity < *out_layers_count) { - // Not an error; just a size query. - return iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); - } - - const std::vector<const char*>& layers = - is_required ? spec.required_layers : spec.optional_layers; - for (int i = 0; i < layers.size(); ++i) { - out_layers[i] = layers[i]; - } - - return iree_ok_status(); -} - -//===----------------------------------------------------------------------===// -// iree::hal::vulkan::VulkanDriver -//===----------------------------------------------------------------------===// - -namespace { - -VulkanDriver::Options ConvertDriverOptions( - iree_hal_vulkan_driver_options_t options) { - VulkanDriver::Options driver_options; - driver_options.api_version = options.api_version; - driver_options.instance_extensibility = - GetInstanceExtensibilitySpec(options.features); - driver_options.device_options.extensibility_spec = - GetDeviceExtensibilitySpec(options.features); - return driver_options; -} - -} // namespace - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - iree_hal_driver_t** out_driver) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create"); +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms) { IREE_ASSERT_ARGUMENT(syms); - IREE_ASSERT_ARGUMENT(out_driver); - *out_driver = nullptr; - - IREE_ASSIGN_OR_RETURN( - auto driver, - VulkanDriver::Create(ConvertDriverOptions(options), - add_ref(reinterpret_cast<DynamicSymbols*>(syms)))); - *out_driver = reinterpret_cast<iree_hal_driver_t*>(driver.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_using_instance( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - VkInstance instance, iree_hal_driver_t** out_driver) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create_using_instance"); - IREE_ASSERT_ARGUMENT(syms); - IREE_ASSERT_ARGUMENT(instance); - IREE_ASSERT_ARGUMENT(out_driver); - *out_driver = nullptr; - - IREE_ASSIGN_OR_RETURN( - auto driver, - VulkanDriver::CreateUsingInstance( - ConvertDriverOptions(options), - add_ref(reinterpret_cast<DynamicSymbols*>(syms)), instance)); - *out_driver = reinterpret_cast<iree_hal_driver_t*>(driver.release()); - return iree_ok_status(); -} - -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_default_device(iree_hal_driver_t* driver, - iree_hal_device_t** out_device) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create_default_device"); - IREE_ASSERT_ARGUMENT(driver); - IREE_ASSERT_ARGUMENT(out_device); - *out_device = nullptr; - - auto* handle = reinterpret_cast<VulkanDriver*>(driver); - - IREE_LOG(INFO) << "Enumerating available Vulkan devices..."; - IREE_ASSIGN_OR_RETURN(auto available_devices, - handle->EnumerateAvailableDevices()); - for (const auto& device_info : available_devices) { - IREE_LOG(INFO) << " Device: " << device_info.name(); + auto* handle = reinterpret_cast<DynamicSymbols*>(syms); + if (handle) { + handle->ReleaseReference(); } - IREE_LOG(INFO) << "Creating default device..."; - IREE_ASSIGN_OR_RETURN(auto device, handle->CreateDefaultDevice()); - IREE_LOG(INFO) << "Successfully created device '" << device->info().name() - << "'"; - - *out_device = reinterpret_cast<iree_hal_device_t*>(device.release()); - return iree_ok_status(); } - -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_wrap_device( - iree_hal_driver_t* driver, VkPhysicalDevice physical_device, - VkDevice logical_device, iree_hal_vulkan_queue_set_t compute_queue_set, - iree_hal_vulkan_queue_set_t transfer_queue_set, - iree_hal_device_t** out_device) { - IREE_TRACE_SCOPE0("iree_hal_vulkan_driver_create_device"); - IREE_ASSERT_ARGUMENT(driver); - IREE_ASSERT_ARGUMENT(physical_device); - IREE_ASSERT_ARGUMENT(logical_device); - IREE_ASSERT_ARGUMENT(out_device); - *out_device = nullptr; - - auto* handle = reinterpret_cast<VulkanDriver*>(driver); - - IREE_LOG(INFO) << "Creating VulkanDevice..."; - QueueSet compute_qs; - compute_qs.queue_family_index = compute_queue_set.queue_family_index; - compute_qs.queue_indices = compute_queue_set.queue_indices; - QueueSet transfer_qs; - transfer_qs.queue_family_index = transfer_queue_set.queue_family_index; - transfer_qs.queue_indices = transfer_queue_set.queue_indices; - IREE_ASSIGN_OR_RETURN(auto device, - handle->WrapDevice(physical_device, logical_device, - compute_qs, transfer_qs)); - IREE_LOG(INFO) << "Successfully created device '" << device->info().name() - << "'"; - - *out_device = reinterpret_cast<iree_hal_device_t*>(device.release()); - - return iree_ok_status(); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/api.h b/iree/hal/vulkan/api.h index 56d1346..3f5912a 100644 --- a/iree/hal/vulkan/api.h +++ b/iree/hal/vulkan/api.h
@@ -29,46 +29,115 @@ #endif // __cplusplus //===----------------------------------------------------------------------===// -// Types and Enums +// iree_hal_vulkan_device_t extensibility util //===----------------------------------------------------------------------===// -// Describes the type of a set of Vulkan extensions. -typedef enum { - IREE_HAL_VULKAN_REQUIRED_BIT = 1 << 0, - IREE_HAL_VULKAN_INSTANCE_BIT = 1 << 1, - - // A set of required instance extension names. - IREE_HAL_VULKAN_INSTANCE_REQUIRED = - IREE_HAL_VULKAN_INSTANCE_BIT | IREE_HAL_VULKAN_REQUIRED_BIT, - // A set of optional instance extension names. - IREE_HAL_VULKAN_INSTANCE_OPTIONAL = IREE_HAL_VULKAN_INSTANCE_BIT, - // A set of required device extension names. - IREE_HAL_VULKAN_DEVICE_REQUIRED = IREE_HAL_VULKAN_REQUIRED_BIT, - // A set of optional device extension names. - IREE_HAL_VULKAN_DEVICE_OPTIONAL = 0, -} iree_hal_vulkan_extensibility_set_t; - +// TODO(benvanik): replace with feature list (easier to version). // Bitfield that defines sets of Vulkan features. -typedef enum { - // Use VK_LAYER_KHRONOS_standard_validation. - IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS = 1 << 0, +enum iree_hal_vulkan_feature_e { + // Use VK_LAYER_KHRONOS_standard_validation to validate Vulkan API usage. + // Has a significant performance penalty and is *not* a security mechanism. + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS = 1 << 0, // Use VK_EXT_debug_utils, record markers, and log errors. - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS = 1 << 1, + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS = 1 << 1, +}; +typedef uint64_t iree_hal_vulkan_features_t; - // Use vkCmdPushDescriptorSetKHR. - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS = 1 << 2, -} iree_hal_vulkan_features_t; +// Describes the type of a set of Vulkan extensions. +enum iree_hal_vulkan_extensibility_set_e { + // A set of required instance layer names. These must all be enabled on + // the VkInstance for IREE to function. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED = 0, -// Vulkan driver creation options. -typedef struct { - // Vulkan version that will be requested, e.g. `VK_API_VERSION_1_0`. - // Driver creation will fail if the required version is not available. - uint32_t api_version; + // A set of optional instance layer names. If omitted fallbacks may be + // used or debugging features may not be available. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL = 1, - // Vulkan features to request. - iree_hal_vulkan_features_t features; -} iree_hal_vulkan_driver_options_t; + // A set of required instance extension names. These must all be enabled on + // the VkInstance for IREE to function. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED = 2, + + // A set of optional instance extension names. If omitted fallbacks may be + // used or debugging features may not be available. + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL = 3, + + // A set of required device extension names. These must all be enabled on + // the VkDevice for IREE to function. + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED = 4, + + // A set of optional device extension names. If omitted fallbacks may be + // used or debugging features may not be available. + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL = 5, + + IREE_HAL_VULKAN_EXTENSIBILITY_SET_COUNT, +}; +typedef uint32_t iree_hal_vulkan_extensibility_set_t; + +// Queries the names of the Vulkan layers and extensions used for a given set of +// IREE |requested_features|. All devices used by IREE must have the required +// layers and extensions as defined by these sets. Optional layers and +// extensions will be used when needed and otherwise have fallbacks for when +// they are not available. +// +// Instance extensions should be enabled on VkInstances passed to +// |iree_hal_vulkan_driver_create_using_instance| and device extensions should +// be enabled on VkDevices passed to |iree_hal_vulkan_driver_wrap_device|. +// +// |string_capacity| defines the number of elements available in +// |out_string_values| and |out_string_count| will be set with the actual number +// of strings returned. If |string_capacity| is too small then +// IREE_STATUS_OUT_OF_RANGE will be returned with the required capacity in +// |out_string_count|. To only query the required capacity then +// |out_string_values| may be passed as NULL. +// +// The returned strings originate from the _EXTENSION_NAME Vulkan macros +// (such as 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME') and have a +// lifetime matching whatever module they are defined in. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree_host_size_t string_capacity, + const char** out_string_values, iree_host_size_t* out_string_count); + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_syms_t +//===----------------------------------------------------------------------===// + +typedef struct iree_hal_vulkan_syms_s iree_hal_vulkan_syms_t; + +// Loads Vulkan functions by invoking |vkGetInstanceProcAddr|. +// +// |vkGetInstanceProcAddr| can be obtained in whatever way suites the calling +// application, such as via `dlsym` or `GetProcAddress` when dynamically +// loading Vulkan, or `reinterpret_cast<void*>(&vkGetInstanceProcAddr)` when +// statically linking Vulkan. +// +// |out_syms| must be released by the caller. +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create( + void* vkGetInstanceProcAddr_fn, iree_allocator_t host_allocator, + iree_hal_vulkan_syms_t** out_syms); + +// Loads Vulkan functions from the Vulkan loader. +// This will look for a Vulkan loader on the system (like libvulkan.so) and +// dlsym the functions from that. +// +// |out_syms| must be released by the caller with iree_hal_vulkan_syms_release. +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_syms_create_from_system_loader( + iree_allocator_t host_allocator, iree_hal_vulkan_syms_t** out_syms); + +// Retains the given |syms| for the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_retain(iree_hal_vulkan_syms_t* syms); + +// Releases the given |syms| from the caller. +IREE_API_EXPORT void IREE_API_CALL +iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms); + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_device_t +//===----------------------------------------------------------------------===// // A set of queues within a specific queue family on a VkDevice. typedef struct { @@ -80,114 +149,24 @@ uint64_t queue_indices; } iree_hal_vulkan_queue_set_t; -typedef struct iree_hal_vulkan_syms iree_hal_vulkan_syms_t; +// TODO(benvanik): replace with flag list (easier to version). +enum iree_hal_vulkan_device_flag_e { + // Uses timeline semaphore emulation even if native support exists. + // May be removed in future versions when timeline semaphores can be assumed + // present on all platforms (looking at you, Android ಠ_ಠ). + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION = 1 << 0, +}; +typedef uint64_t iree_hal_vulkan_device_flags_t; -//===----------------------------------------------------------------------===// -// iree::hal::vulkan::DynamicSymbols -//===----------------------------------------------------------------------===// +typedef struct { + // Flags controlling device behavior. + iree_hal_vulkan_device_flags_t flags; +} iree_hal_vulkan_device_options_t; -// Loads Vulkan functions by invoking |vkGetInstanceProcAddr|. -// -// |vkGetInstanceProcAddr| can be obtained in whatever way suites the calling -// application, such as via `dlsym` or `GetProcAddress` when dynamically -// loading Vulkan, or `reinterpret_cast<void*>(&vkGetInstanceProcAddr)` when -// statically linking Vulkan. -// -// |out_syms| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_syms_create( - void* vkGetInstanceProcAddr_fn, iree_hal_vulkan_syms_t** out_syms); +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_device_options_initialize( + iree_hal_vulkan_device_options_t* out_options); -// Loads Vulkan functions from the Vulkan loader. -// This will look for a Vulkan loader on the system (like libvulkan.so) and -// dlsym the functions from that. -// -// |out_syms| must be released by the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_syms_create_from_system_loader( - iree_hal_vulkan_syms_t** out_syms); - -// Releases the given |syms| from the caller. -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_syms_release(iree_hal_vulkan_syms_t* syms); - -//===----------------------------------------------------------------------===// -// iree::hal::vulkan Extensibility Util -//===----------------------------------------------------------------------===// - -// Gets the names of the Vulkan extensions used for a given set of |features|. -// -// Instance extensions should be enabled on VkInstances passed to -// |iree_hal_vulkan_driver_create_using_instance| and device extensions should -// be enabled on VkDevices passed to |iree_hal_vulkan_driver_wrap_device|. -// -// |extensions_capacity| defines the number of elements available in -// |out_extensions| and |out_extensions_count| will be set with the actual -// number of extensions returned. If |extensions_capacity| is too small -// IREE_STATUS_OUT_OF_RANGE will be returned with the required capacity in -// |out_extensions_count|. To only query the required capacity |out_extensions| -// may be passed as nullptr. -// -// Extension string lifetime is tied to the loader shared object or instance, -// depending on where they came from. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_extensions( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t extensions_capacity, - const char** out_extensions, iree_host_size_t* out_extensions_count); - -// Gets the names of the Vulkan layers used for a given set of |features|. -// -// Instance layers should be enabled on VkInstances passed to -// |iree_hal_vulkan_driver_create_using_instance|. Device layers are deprecated -// and unsupported here. -// -// |layers_capacity| defines the number of elements available in |out_layers| -// and |out_layers_count| will be set with the actual number of layers returned. -// If |layers_capacity| is too small IREE_STATUS_OUT_OF_RANGE will be returned -// with the required capacity in |out_layers_count|. To only query the required -// capacity |out_layers| may be passed as nullptr. -// -// Layer string lifetime is tied to the loader shared object or instance, -// depending on where they came from. -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_get_layers( - iree_hal_vulkan_extensibility_set_t extensibility_set, - iree_hal_vulkan_features_t features, iree_host_size_t layers_capacity, - const char** out_layers, iree_host_size_t* out_layers_count); - -//===----------------------------------------------------------------------===// -// iree::hal::vulkan::VulkanDriver -//===----------------------------------------------------------------------===// - -// TODO(scotttodd): Allow applications to provide their own allocators here - -// Creates a Vulkan HAL driver that manages its own VkInstance. -// -// |out_driver| must be released by the caller (see |iree_hal_driver_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - iree_hal_driver_t** out_driver); - -// Creates a Vulkan HAL driver that shares an existing VkInstance. -// -// |instance| is expected to have been created with all extensions returned by -// |iree_hal_vulkan_get_extensions| and IREE_HAL_VULKAN_INSTANCE_REQUIRED using -// |options| enabled. -// -// |instance| must remain valid for the life of |out_driver| and |out_driver| -// itself must be released by the caller (see |iree_hal_driver_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_using_instance( - iree_hal_vulkan_driver_options_t options, iree_hal_vulkan_syms_t* syms, - VkInstance instance, iree_hal_driver_t** out_driver); - -// Creates the default Vulkan HAL device using |driver| that manages its own -// VkPhysicalDevice/VkDevice. -// -// |out_device| must be released by the caller (see |iree_hal_device_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL -iree_hal_vulkan_driver_create_default_device(iree_hal_driver_t* driver, - iree_hal_device_t** out_device); - -// Creates a Vulkan HAL device using |driver| that wraps an existing VkDevice. +// Creates a Vulkan HAL device that wraps an existing VkDevice. // // HAL devices created in this way may share Vulkan resources and synchronize // within the same physical VkPhysicalDevice and logical VkDevice directly. @@ -197,6 +176,9 @@ // IREE_HAL_VULKAN_DEVICE_REQUIRED using the features provided during driver // creation. // +// |instance_syms| must have at least the instance-specific functions resolved +// and device symbols will be queried from |logical_device| as needed. +// // The device will schedule commands against the queues in // |compute_queue_set| and (if set) |transfer_queue_set|. // @@ -210,14 +192,74 @@ // |compute_queue_set|, if they are available. // Similarly, dedicated transfer queues (no compute or graphics) are preferred // within |transfer_queue_set|. -// The queues may be the same. +// The queue sets can be the same. // // |out_device| must be released by the caller (see |iree_hal_device_release|). -IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_wrap_device( - iree_hal_driver_t* driver, VkPhysicalDevice physical_device, - VkDevice logical_device, iree_hal_vulkan_queue_set_t compute_queue_set, - iree_hal_vulkan_queue_set_t transfer_queue_set, - iree_hal_device_t** out_device); +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_wrap_device( + iree_string_view_t identifier, + const iree_hal_vulkan_device_options_t* options, + const iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + VkPhysicalDevice physical_device, VkDevice logical_device, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set, + iree_allocator_t host_allocator, iree_hal_device_t** out_device); + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_driver_t +//===----------------------------------------------------------------------===// + +// Vulkan driver creation options. +typedef struct { + // Vulkan version that will be requested, e.g. `VK_API_VERSION_1_0`. + // Driver creation will fail if the required version is not available. + uint32_t api_version; + + // IREE features used to configure the VkInstance and VkDevices created using + // it. These are used to populate the active Vulkan layers and extensions when + // the instance and its devices are created. + iree_hal_vulkan_features_t requested_features; + + // TODO(benvanik): remove this single setting - it would be nice instead to + // pass a list to force device enumeration/matrix expansion or omit entirely + // to have auto-discovered options based on capabilities. Right now this + // forces all devices - even if from different vendors - to have the same + // options. + // Options to use for all devices created by the driver. + iree_hal_vulkan_device_options_t device_options; + + // TODO(benvanik): change to something more canonically vulkan (like + // VkPhysicalDeviceProperties::deviceID). + // Index of the default Vulkan device to use within the list of available + // devices. Devices are discovered via vkEnumeratePhysicalDevices then + // considered "available" if compatible with the |requested_features|. + int default_device_index; +} iree_hal_vulkan_driver_options_t; + +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_driver_options_initialize( + iree_hal_vulkan_driver_options_t* out_options); + +// Creates a Vulkan HAL driver that manages its own VkInstance. +// +// |out_driver| must be released by the caller (see |iree_hal_driver_release|). +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* syms, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver); + +// Creates a Vulkan HAL driver that shares an existing VkInstance. +// +// |instance| is expected to have been created with all extensions returned by +// the instance-specific |iree_hal_vulkan_query_extensibility_set| queries. +// +// |instance| must remain valid for the life of |out_driver| and |out_driver| +// itself must be released by the caller (see |iree_hal_driver_release|). +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_driver_create_using_instance( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver); #ifdef __cplusplus } // extern "C"
diff --git a/iree/hal/vulkan/command_queue.h b/iree/hal/vulkan/command_queue.h new file mode 100644 index 0000000..47b3212 --- /dev/null +++ b/iree/hal/vulkan/command_queue.h
@@ -0,0 +1,77 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_COMMAND_QUEUE_H_ +#define IREE_HAL_VULKAN_COMMAND_QUEUE_H_ + +#include <string> + +#include "iree/base/arena.h" +#include "iree/base/status.h" +#include "iree/base/synchronization.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/handle_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class CommandQueue { + public: + virtual ~CommandQueue() { + IREE_TRACE_SCOPE0("CommandQueue::dtor"); + iree_slim_mutex_lock(&queue_mutex_); + syms()->vkQueueWaitIdle(queue_); + iree_slim_mutex_unlock(&queue_mutex_); + iree_slim_mutex_deinitialize(&queue_mutex_); + } + + const ref_ptr<DynamicSymbols>& syms() const { + return logical_device_->syms(); + } + + bool can_dispatch() const { + return iree_all_bits_set(supported_categories_, + IREE_HAL_COMMAND_CATEGORY_DISPATCH); + } + virtual iree_status_t Submit(iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) = 0; + + virtual iree_status_t WaitIdle(iree_time_t deadline_ns) = 0; + + protected: + CommandQueue(VkDeviceHandle* logical_device, std::string name, + iree_hal_command_category_t supported_categories, VkQueue queue) + : logical_device_(logical_device), + name_(std::move(name)), + supported_categories_(supported_categories), + queue_(queue) { + iree_slim_mutex_initialize(&queue_mutex_); + } + + VkDeviceHandle* logical_device_; + const std::string name_; + const iree_hal_command_category_t supported_categories_; + + // VkQueue needs to be externally synchronized. + iree_slim_mutex_t queue_mutex_; + VkQueue queue_ IREE_GUARDED_BY(queue_mutex_); +}; + +} // namespace vulkan +} // namespace hal +} // namespace iree + +#endif // IREE_HAL_VULKAN_COMMAND_QUEUE_H_
diff --git a/iree/hal/vulkan/debug_reporter.cc b/iree/hal/vulkan/debug_reporter.cc index 62f0a16..c600030 100644 --- a/iree/hal/vulkan/debug_reporter.cc +++ b/iree/hal/vulkan/debug_reporter.cc
@@ -17,21 +17,23 @@ #include "iree/base/tracing.h" #include "iree/hal/vulkan/status_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -namespace { +struct iree_hal_vulkan_debug_reporter_s { + iree_allocator_t host_allocator; + VkInstance instance; + iree::hal::vulkan::DynamicSymbols* syms; + const VkAllocationCallbacks* allocation_callbacks; + VkDebugUtilsMessengerEXT messenger; +}; // NOTE: |user_data| may be nullptr if we are being called during instance // creation. Otherwise it is a pointer to the DebugReporter instance. - +// // NOTE: this callback must be thread safe and must be careful not to reach too // far outside of the call - it is called in-context from arbitrary threads with // some amount of Vulkan state on the stack. Assume that creating or deleting // Vulkan objects, issuing most Vulkan commands, etc are off-limits. - -VKAPI_ATTR VkBool32 VKAPI_CALL DebugUtilsMessageCallback( +static VKAPI_ATTR VkBool32 VKAPI_CALL +iree_hal_vulkan_debug_utils_message_callback( VkDebugUtilsMessageSeverityFlagBitsEXT message_severity, VkDebugUtilsMessageTypeFlagsEXT message_type, const VkDebugUtilsMessengerCallbackDataEXT* callback_data, @@ -41,122 +43,89 @@ } else { IREE_VLOG(1) << callback_data->pMessage; } - return VK_FALSE; // VK_TRUE is reserved for future use. } -VKAPI_ATTR VkBool32 VKAPI_CALL DebugReportCallback( - VkDebugReportFlagsEXT flags, VkDebugReportObjectTypeEXT object_type, - uint64_t object, size_t location, int32_t message_code, - const char* layer_prefix, const char* message, void* user_data) { - IREE_VLOG(1) << message; - - return VK_FALSE; // VK_TRUE is reserved for future use. -} - -} // namespace - -// static -void DebugReporter::PopulateStaticCreateInfo( - VkDebugUtilsMessengerCreateInfoEXT* create_info) { - create_info->sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; - create_info->pNext = nullptr; - create_info->flags = 0; +// Populates |create_info| with an instance-agnostic callback. +// This can be used during instance creation by chaining the |create_info| to +// VkInstanceCreateInfo::pNext. +// +// Only use if VK_EXT_debug_utils is present. +static void iree_hal_vulkan_debug_reporter_populate_create_info( + VkDebugUtilsMessengerCreateInfoEXT* out_create_info) { + out_create_info->sType = + VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT; + out_create_info->pNext = nullptr; + out_create_info->flags = 0; // TODO(benvanik): only enable the severities that logging has enabled. - create_info->messageSeverity = + out_create_info->messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_VERBOSE_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT; // TODO(benvanik): allow filtering by category as a flag. - create_info->messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | - VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; + out_create_info->messageType = + VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | + VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT; - create_info->pfnUserCallback = DebugUtilsMessageCallback; - create_info->pUserData = nullptr; + out_create_info->pfnUserCallback = + iree_hal_vulkan_debug_utils_message_callback; + out_create_info->pUserData = nullptr; } -// static -void DebugReporter::PopulateStaticCreateInfo( - VkDebugReportCallbackCreateInfoEXT* create_info) { - create_info->sType = VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT; - create_info->pNext = nullptr; - create_info->flags = 0; +iree_status_t iree_hal_vulkan_debug_reporter_allocate( + VkInstance instance, iree::hal::vulkan::DynamicSymbols* syms, + const VkAllocationCallbacks* allocation_callbacks, + iree_allocator_t host_allocator, + iree_hal_vulkan_debug_reporter_t** out_reporter) { + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(syms); + IREE_ASSERT_ARGUMENT(out_reporter); + IREE_TRACE_ZONE_BEGIN(z0); - // TODO(benvanik): only enable the severities that logging has enabled. - create_info->flags |= - VK_DEBUG_REPORT_INFORMATION_BIT_EXT | VK_DEBUG_REPORT_WARNING_BIT_EXT | - VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | - VK_DEBUG_REPORT_ERROR_BIT_EXT | VK_DEBUG_REPORT_DEBUG_BIT_EXT; - - create_info->pfnCallback = DebugReportCallback; - create_info->pUserData = nullptr; -} - -// static -StatusOr<std::unique_ptr<DebugReporter>> -DebugReporter::CreateDebugUtilsMessenger( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) { - IREE_TRACE_SCOPE0("DebugReporter::CreateDebugUtilsMessenger"); - - auto debug_reporter = std::unique_ptr<DebugReporter>( - new DebugReporter(instance, syms, allocation_callbacks)); + // Allocate our struct first as we need to pass the pointer to the userdata + // of the messager instance when we create it. + iree_hal_vulkan_debug_reporter_t* reporter = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, sizeof(*reporter), + (void**)&reporter)); + reporter->host_allocator = host_allocator; + reporter->instance = instance; + reporter->syms = syms; + reporter->allocation_callbacks = allocation_callbacks; VkDebugUtilsMessengerCreateInfoEXT create_info; - PopulateStaticCreateInfo(&create_info); - create_info.pUserData = debug_reporter.get(); + iree_hal_vulkan_debug_reporter_populate_create_info(&create_info); + create_info.pUserData = reporter; + iree_status_t status = VK_RESULT_TO_STATUS( + syms->vkCreateDebugUtilsMessengerEXT( + instance, &create_info, allocation_callbacks, &reporter->messenger), + "vkCreateDebugUtilsMessengerEXT"); - VK_RETURN_IF_ERROR(syms->vkCreateDebugUtilsMessengerEXT( - instance, &create_info, allocation_callbacks, - &debug_reporter->messenger_)); - - return debug_reporter; -} - -// static -StatusOr<std::unique_ptr<DebugReporter>> -DebugReporter::CreateDebugReportCallback( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) { - IREE_TRACE_SCOPE0("DebugReporter::CreateDebugReportCallback"); - - auto debug_reporter = std::unique_ptr<DebugReporter>( - new DebugReporter(instance, syms, allocation_callbacks)); - - VkDebugReportCallbackCreateInfoEXT create_info; - PopulateStaticCreateInfo(&create_info); - create_info.pUserData = debug_reporter.get(); - - VK_RETURN_IF_ERROR(syms->vkCreateDebugReportCallbackEXT( - instance, &create_info, allocation_callbacks, - &debug_reporter->callback_)); - - return debug_reporter; -} - -DebugReporter::DebugReporter(VkInstance instance, - const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks) - : instance_(instance), - syms_(add_ref(syms)), - allocation_callbacks_(allocation_callbacks) {} - -DebugReporter::~DebugReporter() { - IREE_TRACE_SCOPE0("DebugReporter::dtor"); - if (messenger_ != VK_NULL_HANDLE) { - syms_->vkDestroyDebugUtilsMessengerEXT(instance_, messenger_, - allocation_callbacks_); + if (iree_status_is_ok(status)) { + *out_reporter = reporter; + } else { + iree_hal_vulkan_debug_reporter_free(reporter); } - if (callback_ != VK_NULL_HANDLE) { - syms_->vkDestroyDebugReportCallbackEXT(instance_, callback_, - allocation_callbacks_); - } + IREE_TRACE_ZONE_END(z0); + return status; } -} // namespace vulkan -} // namespace hal -} // namespace iree +void iree_hal_vulkan_debug_reporter_free( + iree_hal_vulkan_debug_reporter_t* reporter) { + if (!reporter) return; + iree_allocator_t host_allocator = reporter->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + if (reporter->messenger != VK_NULL_HANDLE) { + reporter->syms->vkDestroyDebugUtilsMessengerEXT( + reporter->instance, reporter->messenger, + reporter->allocation_callbacks); + } + iree_allocator_free(host_allocator, reporter); + + IREE_TRACE_ZONE_END(z0); +}
diff --git a/iree/hal/vulkan/debug_reporter.h b/iree/hal/vulkan/debug_reporter.h index 82dad6e..3c92d82 100644 --- a/iree/hal/vulkan/debug_reporter.h +++ b/iree/hal/vulkan/debug_reporter.h
@@ -15,22 +15,13 @@ #ifndef IREE_HAL_VULKAN_DEBUG_REPORTER_H_ #define IREE_HAL_VULKAN_DEBUG_REPORTER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/base/status.h" +#include "iree/base/api.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/dynamic_symbols.h" -namespace iree { -namespace hal { -namespace vulkan { - // A debug reporter that works with the VK_EXT_debug_utils extension. // One reporter should be created per VkInstance to receive callbacks from the -// API and route them to our logging systems. In general VK_EXT_debug_utils -// should be preferred if available as it provides a much cleaner interface and -// more plug-points than VK_EXT_debug_report. +// API and route them to our logging systems. // // Since creating a reporter requires a VkInstance it's not possible to report // on messages during instance creation. To work around this it's possible to @@ -38,52 +29,16 @@ // VkInstanceCreateInfo::pNext chain. The callback will only be used this way // during the creation call after which users can create the real // instance-specific reporter. -class DebugReporter final { - public: - // Populates |create_info| with an instance-agnostic callback. - // This can be used during instance creation by chaining the |create_info| to - // VkInstanceCreateInfo::pNext. - // - // Only use if VK_EXT_debug_utils is present. - static void PopulateStaticCreateInfo( - VkDebugUtilsMessengerCreateInfoEXT* create_info); +typedef struct iree_hal_vulkan_debug_reporter_s + iree_hal_vulkan_debug_reporter_t; - // Populates |create_info| with an instance-agnostic callback. - // This can be used during instance creation by chaining the |create_info| to - // VkInstanceCreateInfo::pNext. - // - // Only use if VK_EXT_debug_report is present. - static void PopulateStaticCreateInfo( - VkDebugReportCallbackCreateInfoEXT* create_info); +iree_status_t iree_hal_vulkan_debug_reporter_allocate( + VkInstance instance, iree::hal::vulkan::DynamicSymbols* syms, + const VkAllocationCallbacks* allocation_callbacks, + iree_allocator_t host_allocator, + iree_hal_vulkan_debug_reporter_t** out_reporter); - // Creates a debug messenger for the given Vulkan |instance| with - // VK_EXT_debug_utils enabled. - static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugUtilsMessenger( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks); - - // Creates a debug report callback for the given Vulkan |instance| with - // VK_EXT_debug_report enabled. - static StatusOr<std::unique_ptr<DebugReporter>> CreateDebugReportCallback( - VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks); - - ~DebugReporter(); - - private: - DebugReporter(VkInstance instance, const ref_ptr<DynamicSymbols>& syms, - const VkAllocationCallbacks* allocation_callbacks); - - VkInstance instance_ = VK_NULL_HANDLE; - ref_ptr<DynamicSymbols> syms_; - const VkAllocationCallbacks* allocation_callbacks_ = nullptr; - - VkDebugUtilsMessengerEXT messenger_ = VK_NULL_HANDLE; - VkDebugReportCallbackEXT callback_ = VK_NULL_HANDLE; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +void iree_hal_vulkan_debug_reporter_free( + iree_hal_vulkan_debug_reporter_t* reporter); #endif // IREE_HAL_VULKAN_DEBUG_REPORTER_H_
diff --git a/iree/hal/vulkan/descriptor_pool_cache.cc b/iree/hal/vulkan/descriptor_pool_cache.cc index 6feea16..3853796 100644 --- a/iree/hal/vulkan/descriptor_pool_cache.cc +++ b/iree/hal/vulkan/descriptor_pool_cache.cc
@@ -48,8 +48,8 @@ return OkStatus(); } -DescriptorPoolCache::DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} +DescriptorPoolCache::DescriptorPoolCache(VkDeviceHandle* logical_device) + : logical_device_(logical_device) {} StatusOr<DescriptorPool> DescriptorPoolCache::AcquireDescriptorPool( VkDescriptorType descriptor_type, int max_descriptor_count) { @@ -74,8 +74,9 @@ descriptor_pool.handle = VK_NULL_HANDLE; VK_RETURN_IF_ERROR(syms().vkCreateDescriptorPool( - *logical_device_, &create_info, logical_device_->allocator(), - &descriptor_pool.handle)); + *logical_device_, &create_info, + logical_device_->allocator(), &descriptor_pool.handle), + "vkCreateDescriptorPool"); return descriptor_pool; } @@ -89,7 +90,8 @@ // this leads to better errors when using the validation layers as we'll // throw if there are in-flight command buffers using the sets in the pool. VK_RETURN_IF_ERROR(syms().vkResetDescriptorPool(*logical_device_, - descriptor_pool.handle, 0)); + descriptor_pool.handle, 0), + "vkResetDescriptorPool"); // TODO(benvanik): release to cache. syms().vkDestroyDescriptorPool(*logical_device_, descriptor_pool.handle,
diff --git a/iree/hal/vulkan/descriptor_pool_cache.h b/iree/hal/vulkan/descriptor_pool_cache.h index bb4fa33..0e001c0 100644 --- a/iree/hal/vulkan/descriptor_pool_cache.h +++ b/iree/hal/vulkan/descriptor_pool_cache.h
@@ -16,7 +16,6 @@ #define IREE_HAL_VULKAN_DESCRIPTOR_POOL_CACHE_H_ #include "absl/container/inlined_vector.h" -#include "iree/base/ref_ptr.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/handle_util.h" @@ -43,9 +42,9 @@ class DescriptorSetGroup final { public: DescriptorSetGroup() = default; - DescriptorSetGroup(ref_ptr<DescriptorPoolCache> descriptor_pool_cache, + DescriptorSetGroup(DescriptorPoolCache* descriptor_pool_cache, absl::InlinedVector<DescriptorPool, 8> descriptor_pools) - : descriptor_pool_cache_(std::move(descriptor_pool_cache)), + : descriptor_pool_cache_(descriptor_pool_cache), descriptor_pools_(std::move(descriptor_pools)) {} DescriptorSetGroup(const DescriptorSetGroup&) = delete; DescriptorSetGroup& operator=(const DescriptorSetGroup&) = delete; @@ -62,7 +61,7 @@ Status Reset(); private: - ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; + DescriptorPoolCache* descriptor_pool_cache_; absl::InlinedVector<DescriptorPool, 8> descriptor_pools_; }; @@ -72,13 +71,11 @@ // resources. After the descriptors in the pool are no longer used (all // command buffers using descriptor sets allocated from the pool have retired) // the pool is returned here to be reused in the future. -class DescriptorPoolCache final : public RefObject<DescriptorPoolCache> { +class DescriptorPoolCache final { public: - explicit DescriptorPoolCache(ref_ptr<VkDeviceHandle> logical_device); + explicit DescriptorPoolCache(VkDeviceHandle* logical_device); - const ref_ptr<VkDeviceHandle>& logical_device() const { - return logical_device_; - } + VkDeviceHandle* logical_device() const { return logical_device_; } const DynamicSymbols& syms() const { return *logical_device_->syms(); } // Acquires a new descriptor pool for use by the caller. @@ -93,7 +90,7 @@ Status ReleaseDescriptorPools(absl::Span<DescriptorPool> descriptor_pools); private: - ref_ptr<VkDeviceHandle> logical_device_; + VkDeviceHandle* logical_device_; }; } // namespace vulkan
diff --git a/iree/hal/vulkan/descriptor_set_arena.cc b/iree/hal/vulkan/descriptor_set_arena.cc index f7b3744..7c86d8c 100644 --- a/iree/hal/vulkan/descriptor_set_arena.cc +++ b/iree/hal/vulkan/descriptor_set_arena.cc
@@ -17,6 +17,8 @@ #include "iree/base/alignment.h" #include "iree/base/math.h" #include "iree/base/tracing.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" +#include "iree/hal/vulkan/native_executable_layout.h" #include "iree/hal/vulkan/status_util.h" #include "iree/hal/vulkan/vma_buffer.h" @@ -26,26 +28,22 @@ namespace { -StatusOr<VmaBuffer*> CastBuffer(iree_hal_buffer_t* buffer) { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return reinterpret_cast<VmaBuffer*>(iree_hal_buffer_allocated_buffer(buffer)); -} - -StatusOr<absl::Span<VkWriteDescriptorSet>> PopulateDescriptorSetWriteInfos( - absl::Span<const iree_hal_descriptor_set_binding_t> bindings, - VkDescriptorSet dst_set, Arena* arena) { +static StatusOr<absl::Span<VkWriteDescriptorSet>> +PopulateDescriptorSetWriteInfos( + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, VkDescriptorSet dst_set, + Arena* arena) { arena->Reset(); auto buffer_infos = - arena->AllocateSpan<VkDescriptorBufferInfo>(bindings.size()); - auto write_infos = arena->AllocateSpan<VkWriteDescriptorSet>(bindings.size()); + arena->AllocateSpan<VkDescriptorBufferInfo>(binding_count); + auto write_infos = arena->AllocateSpan<VkWriteDescriptorSet>(binding_count); - for (int i = 0; i < bindings.size(); ++i) { + for (int i = 0; i < binding_count; ++i) { const auto& binding = bindings[i]; auto& buffer_info = buffer_infos[i]; - IREE_ASSIGN_OR_RETURN(auto buffer, CastBuffer(binding.buffer)); - buffer_info.buffer = buffer->handle(); + buffer_info.buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(binding.buffer)); buffer_info.offset = iree_hal_buffer_byte_offset(binding.buffer) + binding.offset; // Round up to a multiple of 32-bit. 32-bit is the most native bitwidth on @@ -86,15 +84,16 @@ return write_infos; } -VkDescriptorSetAllocateInfo PopulateDescriptorSetsAllocateInfo( +static VkDescriptorSetAllocateInfo PopulateDescriptorSetsAllocateInfo( const DescriptorPool& descriptor_pool, - NativeDescriptorSetLayout* set_layout) { + iree_hal_descriptor_set_layout_t* set_layout) { VkDescriptorSetAllocateInfo allocate_info; allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; allocate_info.pNext = nullptr; allocate_info.descriptorPool = descriptor_pool.handle; - VkDescriptorSetLayout set_layout_handle = set_layout->handle(); + VkDescriptorSetLayout set_layout_handle = + iree_hal_vulkan_native_descriptor_set_layout_handle(set_layout); allocate_info.descriptorSetCount = 1; allocate_info.pSetLayouts = &set_layout_handle; @@ -104,9 +103,9 @@ } // namespace DescriptorSetArena::DescriptorSetArena( - ref_ptr<DescriptorPoolCache> descriptor_pool_cache) - : logical_device_(add_ref(descriptor_pool_cache->logical_device())), - descriptor_pool_cache_(std::move(descriptor_pool_cache)) {} + DescriptorPoolCache* descriptor_pool_cache) + : logical_device_(descriptor_pool_cache->logical_device()), + descriptor_pool_cache_(descriptor_pool_cache) {} DescriptorSetArena::~DescriptorSetArena() { if (!used_descriptor_pools_.empty()) { @@ -118,21 +117,25 @@ } Status DescriptorSetArena::BindDescriptorSet( - VkCommandBuffer command_buffer, PipelineExecutableLayout* executable_layout, - int32_t set, absl::Span<const iree_hal_descriptor_set_binding_t> bindings) { + VkCommandBuffer command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { // Always prefer using push descriptors when available as we can avoid the // additional API overhead of updating/resetting pools. if (logical_device_->enabled_extensions().push_descriptors) { - return PushDescriptorSet(command_buffer, executable_layout, set, bindings); + return PushDescriptorSet(command_buffer, executable_layout, set, + binding_count, bindings); } IREE_TRACE_SCOPE0("DescriptorSetArena::BindDescriptorSet"); - auto* set_layout = executable_layout->set_layouts()[set].get(); + auto* set_layout = + iree_hal_vulkan_native_executable_layout_set(executable_layout, set); // Pick a bucket based on the number of descriptors required. // NOTE: right now we are 1:1 with bindings. - uint32_t required_descriptor_count = static_cast<int>(bindings.size() * 1); + uint32_t required_descriptor_count = static_cast<int>(binding_count * 1); uint32_t max_descriptor_count = std::max(8u, iree_math_round_up_to_pow2_u32(required_descriptor_count)); uint32_t bucket = @@ -156,7 +159,8 @@ allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; allocate_info.pNext = nullptr; allocate_info.descriptorPool = descriptor_pool.handle; - VkDescriptorSetLayout set_layout_handle = set_layout->handle(); + VkDescriptorSetLayout set_layout_handle = + iree_hal_vulkan_native_descriptor_set_layout_handle(set_layout); allocate_info.descriptorSetCount = 1; allocate_info.pSetLayouts = &set_layout_handle; @@ -178,18 +182,18 @@ allocate_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; allocate_info.pNext = nullptr; allocate_info.descriptorPool = descriptor_pool_buckets_[bucket].handle; - VkDescriptorSetLayout set_layout_handle = set_layout->handle(); allocate_info.descriptorSetCount = 1; allocate_info.pSetLayouts = &set_layout_handle; descriptor_set = VK_NULL_HANDLE; VK_RETURN_IF_ERROR(syms().vkAllocateDescriptorSets( - *logical_device_, &allocate_info, &descriptor_set)); + *logical_device_, &allocate_info, &descriptor_set), + "vkAllocateDescriptorSets"); } // Get a list of VkWriteDescriptorSet structs with all bound buffers. - IREE_ASSIGN_OR_RETURN(auto write_infos, - PopulateDescriptorSetWriteInfos( - bindings, descriptor_set, &scratch_arena_)); + IREE_ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( + binding_count, bindings, + descriptor_set, &scratch_arena_)); // This is the reason why push descriptor sets are good. // We can't batch these effectively as we don't know prior to recording what @@ -201,29 +205,33 @@ write_infos.data(), 0, nullptr); // Bind the descriptor set. - syms().vkCmdBindDescriptorSets(command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, - executable_layout->handle(), set, 1, - &descriptor_set, 0, nullptr); + syms().vkCmdBindDescriptorSets( + command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, + iree_hal_vulkan_native_executable_layout_handle(executable_layout), set, + 1, &descriptor_set, 0, nullptr); return OkStatus(); } Status DescriptorSetArena::PushDescriptorSet( - VkCommandBuffer command_buffer, PipelineExecutableLayout* executable_layout, - int32_t set, absl::Span<const iree_hal_descriptor_set_binding_t> bindings) { + VkCommandBuffer command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { IREE_TRACE_SCOPE0("DescriptorSetArena::PushDescriptorSet"); + VkPipelineLayout device_executable_layout = + iree_hal_vulkan_native_executable_layout_handle(executable_layout); // Get a list of VkWriteDescriptorSet structs with all bound buffers. - IREE_ASSIGN_OR_RETURN(auto write_infos, - PopulateDescriptorSetWriteInfos( - bindings, VK_NULL_HANDLE, &scratch_arena_)); + IREE_ASSIGN_OR_RETURN(auto write_infos, PopulateDescriptorSetWriteInfos( + binding_count, bindings, + VK_NULL_HANDLE, &scratch_arena_)); // Fast path using push descriptors. These are pooled internally by the // command buffer and prevent the need for our own pooling mechanisms. syms().vkCmdPushDescriptorSetKHR( - command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, - executable_layout->handle(), set, - static_cast<uint32_t>(write_infos.size()), write_infos.data()); + command_buffer, VK_PIPELINE_BIND_POINT_COMPUTE, device_executable_layout, + set, static_cast<uint32_t>(write_infos.size()), write_infos.data()); return OkStatus(); } @@ -239,7 +247,7 @@ for (auto& bucket : descriptor_pool_buckets_) { bucket = {}; } - return DescriptorSetGroup(add_ref(descriptor_pool_cache_), + return DescriptorSetGroup(descriptor_pool_cache_, std::move(used_descriptor_pools_)); }
diff --git a/iree/hal/vulkan/descriptor_set_arena.h b/iree/hal/vulkan/descriptor_set_arena.h index aa3a111..e34c35b 100644 --- a/iree/hal/vulkan/descriptor_set_arena.h +++ b/iree/hal/vulkan/descriptor_set_arena.h
@@ -20,9 +20,8 @@ #include "iree/base/arena.h" #include "iree/base/status.h" -#include "iree/hal/cc/command_buffer.h" #include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/pipeline_executable.h" +#include "iree/hal/vulkan/native_executable.h" namespace iree { namespace hal { @@ -31,17 +30,16 @@ // A reusable arena for allocating descriptor sets and batching updates. class DescriptorSetArena final { public: - explicit DescriptorSetArena( - ref_ptr<DescriptorPoolCache> descriptor_pool_cache); + explicit DescriptorSetArena(DescriptorPoolCache* descriptor_pool_cache); ~DescriptorSetArena(); // Allocates and binds a descriptor set from the arena. // The command buffer will have the descriptor set containing |bindings| bound // to it. - Status BindDescriptorSet( - VkCommandBuffer command_buffer, - PipelineExecutableLayout* executable_layout, int32_t set, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings); + Status BindDescriptorSet(VkCommandBuffer command_buffer, + iree_hal_executable_layout_t* executable_layout, + uint32_t set, iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings); // Flushes all pending writes to descriptor sets allocated from the arena and // returns a group that - when dropped - will release the descriptor sets @@ -52,13 +50,13 @@ const DynamicSymbols& syms() const { return *logical_device_->syms(); } // Pushes the descriptor set to the command buffer, if supported. - Status PushDescriptorSet( - VkCommandBuffer command_buffer, - PipelineExecutableLayout* executable_layout, int32_t set, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings); + Status PushDescriptorSet(VkCommandBuffer command_buffer, + iree_hal_executable_layout_t* executable_layout, + uint32_t set, iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings); - ref_ptr<VkDeviceHandle> logical_device_; - ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; + VkDeviceHandle* logical_device_; + DescriptorPoolCache* descriptor_pool_cache_; // Arena used for temporary binding information used during allocation. Arena scratch_arena_;
diff --git a/iree/hal/vulkan/direct_command_buffer.cc b/iree/hal/vulkan/direct_command_buffer.cc index b947a2f..f95914a 100644 --- a/iree/hal/vulkan/direct_command_buffer.cc +++ b/iree/hal/vulkan/direct_command_buffer.cc
@@ -14,21 +14,183 @@ #include "iree/hal/vulkan/direct_command_buffer.h" -#include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" #include "iree/base/math.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" +#include "iree/hal/vulkan/descriptor_set_arena.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/native_descriptor_set.h" +#include "iree/hal/vulkan/native_event.h" +#include "iree/hal/vulkan/native_executable_layout.h" #include "iree/hal/vulkan/status_util.h" +#include "iree/hal/vulkan/vma_buffer.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; -namespace { +// Command buffer implementation that directly maps to VkCommandBuffer. +// This records the commands on the calling thread without additional threading +// indirection. +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + iree_hal_command_buffer_mode_t mode; + iree_hal_command_category_t allowed_categories; -VkPipelineStageFlags ConvertPipelineStageFlags( + VkCommandPoolHandle* command_pool; + VkCommandBuffer handle; + + DynamicSymbols* syms; + + // TODO(benvanik): may grow large - should try to reclaim or reuse. + DescriptorSetArena descriptor_set_arena; + + // The current descriptor set group in use by the command buffer, if any. + // This must remain valid until all in-flight submissions of the command + // buffer complete. + DescriptorSetGroup descriptor_set_group; +} iree_hal_vulkan_direct_command_buffer_t; + +extern const iree_hal_command_buffer_vtable_t + iree_hal_vulkan_direct_command_buffer_vtable; + +static iree_hal_vulkan_direct_command_buffer_t* +iree_hal_vulkan_direct_command_buffer_cast( + iree_hal_command_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_direct_command_buffer_vtable); + return (iree_hal_vulkan_direct_command_buffer_t*)base_value; +} + +iree_status_t iree_hal_vulkan_direct_command_buffer_allocate( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree::hal::vulkan::VkCommandPoolHandle* command_pool, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree::hal::vulkan::DescriptorPoolCache* descriptor_pool_cache, + iree_hal_command_buffer_t** out_command_buffer) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(command_pool); + IREE_ASSERT_ARGUMENT(descriptor_pool_cache); + IREE_ASSERT_ARGUMENT(out_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + VkCommandBufferAllocateInfo allocate_info; + allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + allocate_info.pNext = NULL; + allocate_info.commandPool = *command_pool; + allocate_info.commandBufferCount = 1; + allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + + VkCommandBuffer handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_pool->Allocate(&allocate_info, &handle)); + + iree_hal_vulkan_direct_command_buffer_t* command_buffer = NULL; + iree_status_t status = + iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*command_buffer), (void**)&command_buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_direct_command_buffer_vtable, + &command_buffer->resource); + command_buffer->logical_device = logical_device; + command_buffer->mode = mode; + command_buffer->allowed_categories = command_categories; + command_buffer->command_pool = command_pool; + command_buffer->handle = handle; + command_buffer->syms = logical_device->syms().get(); + + new (&command_buffer->descriptor_set_arena) + DescriptorSetArena(descriptor_pool_cache); + new (&command_buffer->descriptor_set_group) DescriptorSetGroup(); + + *out_command_buffer = (iree_hal_command_buffer_t*)command_buffer; + } else { + command_pool->Free(handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_direct_command_buffer_reset( + iree_hal_vulkan_direct_command_buffer_t* command_buffer) { + // NOTE: we require that command buffers not be recorded while they are + // in-flight so this is safe. + IREE_IGNORE_ERROR(command_buffer->descriptor_set_group.Reset()); +} + +static void iree_hal_vulkan_direct_command_buffer_destroy( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + iree_allocator_t host_allocator = + command_buffer->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_direct_command_buffer_reset(command_buffer); + command_buffer->command_pool->Free(command_buffer->handle); + + command_buffer->descriptor_set_group.~DescriptorSetGroup(); + command_buffer->descriptor_set_arena.~DescriptorSetArena(); + + iree_allocator_free(host_allocator, command_buffer); + + IREE_TRACE_ZONE_END(z0); +} + +VkCommandBuffer iree_hal_vulkan_direct_command_buffer_handle( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + return command_buffer->handle; +} + +static iree_hal_command_category_t +iree_hal_vulkan_direct_command_buffer_allowed_categories( + const iree_hal_command_buffer_t* base_command_buffer) { + return ((const iree_hal_vulkan_direct_command_buffer_t*)base_command_buffer) + ->allowed_categories; +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_begin( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + iree_hal_vulkan_direct_command_buffer_reset(command_buffer); + + VkCommandBufferBeginInfo begin_info; + begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + begin_info.pNext = NULL; + begin_info.flags = iree_all_bits_set(command_buffer->mode, + IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) + ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT + : 0; + begin_info.pInheritanceInfo = NULL; + VK_RETURN_IF_ERROR(command_buffer->syms->vkBeginCommandBuffer( + command_buffer->handle, &begin_info), + "vkBeginCommandBuffer"); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_end( + iree_hal_command_buffer_t* base_command_buffer) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + VK_RETURN_IF_ERROR( + command_buffer->syms->vkEndCommandBuffer(command_buffer->handle), + "vkEndCommandBuffer"); + + // Flush all pending descriptor set writes (if any). + IREE_ASSIGN_OR_RETURN(command_buffer->descriptor_set_group, + command_buffer->descriptor_set_arena.Flush()); + + return iree_ok_status(); +} + +static VkPipelineStageFlags iree_hal_vulkan_convert_pipeline_stage_flags( iree_hal_execution_stage_t stage_mask) { VkPipelineStageFlags flags = 0; flags |= iree_any_bit_set(stage_mask, IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE) @@ -53,7 +215,8 @@ return flags; } -VkAccessFlags ConvertAccessMask(iree_hal_access_scope_t access_mask) { +static VkAccessFlags iree_hal_vulkan_convert_access_mask( + iree_hal_access_scope_t access_mask) { VkAccessFlags flags = 0; flags |= iree_any_bit_set(access_mask, IREE_HAL_ACCESS_SCOPE_INDIRECT_COMMAND_READ) @@ -89,8 +252,155 @@ return flags; } +static iree_status_t iree_hal_vulkan_direct_command_buffer_execution_barrier( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( + memory_barrier_count); + for (int i = 0; i < memory_barrier_count; ++i) { + const auto& memory_barrier = memory_barriers[i]; + auto& info = memory_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.target_scope); + } + + absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( + buffer_barrier_count); + for (int i = 0; i < buffer_barrier_count; ++i) { + const auto& buffer_barrier = buffer_barriers[i]; + auto& info = buffer_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.target_scope); + info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + info.buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(buffer_barrier.buffer)); + info.offset = buffer_barrier.offset; + info.size = buffer_barrier.length; + } + + command_buffer->syms->vkCmdPipelineBarrier( + command_buffer->handle, + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask), + iree_hal_vulkan_convert_pipeline_stage_flags(target_stage_mask), + /*dependencyFlags=*/0, static_cast<uint32_t>(memory_barrier_infos.size()), + memory_barrier_infos.data(), + static_cast<uint32_t>(buffer_barrier_infos.size()), + buffer_barrier_infos.data(), 0, NULL); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_signal_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + command_buffer->syms->vkCmdSetEvent( + command_buffer->handle, iree_hal_vulkan_native_event_handle(event), + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask)); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_reset_event( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event, + iree_hal_execution_stage_t source_stage_mask) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + command_buffer->syms->vkCmdResetEvent( + command_buffer->handle, iree_hal_vulkan_native_event_handle(event), + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask)); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_wait_events( + iree_hal_command_buffer_t* base_command_buffer, + iree_host_size_t event_count, const iree_hal_event_t** events, + iree_hal_execution_stage_t source_stage_mask, + iree_hal_execution_stage_t target_stage_mask, + iree_host_size_t memory_barrier_count, + const iree_hal_memory_barrier_t* memory_barriers, + iree_host_size_t buffer_barrier_count, + const iree_hal_buffer_barrier_t* buffer_barriers) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + absl::InlinedVector<VkEvent, 4> event_handles(event_count); + for (int i = 0; i < event_count; ++i) { + event_handles[i] = iree_hal_vulkan_native_event_handle(events[i]); + } + + absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( + memory_barrier_count); + for (int i = 0; i < memory_barrier_count; ++i) { + const auto& memory_barrier = memory_barriers[i]; + auto& info = memory_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(memory_barrier.target_scope); + } + + absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( + buffer_barrier_count); + for (int i = 0; i < buffer_barrier_count; ++i) { + const auto& buffer_barrier = buffer_barriers[i]; + auto& info = buffer_barrier_infos[i]; + info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; + info.pNext = NULL; + info.srcAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.source_scope); + info.dstAccessMask = + iree_hal_vulkan_convert_access_mask(buffer_barrier.target_scope); + info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + info.buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(buffer_barrier.buffer)); + info.offset = buffer_barrier.offset; + info.size = buffer_barrier.length; + } + + command_buffer->syms->vkCmdWaitEvents( + command_buffer->handle, (uint32_t)event_count, event_handles.data(), + iree_hal_vulkan_convert_pipeline_stage_flags(source_stage_mask), + iree_hal_vulkan_convert_pipeline_stage_flags(target_stage_mask), + (uint32_t)memory_barrier_count, memory_barrier_infos.data(), + (uint32_t)buffer_barrier_count, buffer_barrier_infos.data(), 0, NULL); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_discard_buffer( + iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) { + // NOTE: we could use this to prevent queue family transitions. + return iree_ok_status(); +} + // Splats a pattern value of 1, 2, or 4 bytes out to a 4 byte value. -uint32_t SplatPattern(const void* pattern, size_t pattern_length) { +static uint32_t iree_hal_vulkan_splat_pattern(const void* pattern, + size_t pattern_length) { switch (pattern_length) { case 1: { uint32_t pattern_value = *static_cast<const uint8_t*>(pattern); @@ -110,247 +420,36 @@ } } -} // namespace - -DirectCommandBuffer::DirectCommandBuffer( - iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, - ref_ptr<DescriptorPoolCache> descriptor_pool_cache, - ref_ptr<VkCommandPoolHandle> command_pool, VkCommandBuffer command_buffer) - : CommandBuffer(mode, command_categories), - command_pool_(std::move(command_pool)), - command_buffer_(command_buffer), - descriptor_set_arena_(std::move(descriptor_pool_cache)) {} - -DirectCommandBuffer::~DirectCommandBuffer() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor"); - descriptor_set_group_.Reset().IgnoreError(); - absl::MutexLock lock(command_pool_->mutex()); - syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_, - 1, &command_buffer_); -} - -StatusOr<NativeEvent*> DirectCommandBuffer::CastEvent(Event* event) const { - // TODO(benvanik): assert the event is valid. - return static_cast<NativeEvent*>(event); -} - -StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer(Buffer* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return static_cast<VmaBuffer*>(buffer->allocated_buffer()); -} - -StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer( - iree_hal_buffer_t* buffer) const { - // TODO(benvanik): assert that the buffer is from the right allocator and - // that it is compatible with our target queue family. - return reinterpret_cast<VmaBuffer*>(iree_hal_buffer_allocated_buffer(buffer)); -} - -StatusOr<NativeDescriptorSet*> DirectCommandBuffer::CastDescriptorSet( - DescriptorSet* descriptor_set) const { - // TODO(benvanik): assert the descriptor_set is valid. - return static_cast<NativeDescriptorSet*>(descriptor_set); -} - -StatusOr<PipelineExecutableLayout*> DirectCommandBuffer::CastExecutableLayout( - ExecutableLayout* executable_layout) const { - // TODO(benvanik): assert the executable_layout is valid. - return static_cast<PipelineExecutableLayout*>(executable_layout); -} - -StatusOr<PipelineExecutable*> DirectCommandBuffer::CastExecutable( - Executable* executable) const { - // TODO(benvanik): assert the executable is valid. - return static_cast<PipelineExecutable*>(executable); -} - -Status DirectCommandBuffer::Begin() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin"); - - is_recording_ = true; - - // NOTE: we require that command buffers not be recorded while they are - // in-flight so this is safe. - IREE_RETURN_IF_ERROR(descriptor_set_group_.Reset()); - - VkCommandBufferBeginInfo begin_info; - begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; - begin_info.pNext = nullptr; - begin_info.flags = - iree_all_bits_set(mode(), IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT) - ? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT - : 0; - begin_info.pInheritanceInfo = nullptr; - VK_RETURN_IF_ERROR( - syms()->vkBeginCommandBuffer(command_buffer_, &begin_info)); - - return OkStatus(); -} - -Status DirectCommandBuffer::End() { - IREE_TRACE_SCOPE0("DirectCommandBuffer::End"); - - VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_)); - - // Flush all pending descriptor set writes (if any). - IREE_ASSIGN_OR_RETURN(descriptor_set_group_, descriptor_set_arena_.Flush()); - - is_recording_ = false; - - return OkStatus(); -} - -Status DirectCommandBuffer::ExecutionBarrier( - iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - absl::Span<const iree_hal_memory_barrier_t> memory_barriers, - absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier"); - - absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( - memory_barriers.size()); - for (int i = 0; i < memory_barriers.size(); ++i) { - const auto& memory_barrier = memory_barriers[i]; - auto& info = memory_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); - } - - absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( - buffer_barriers.size()); - for (int i = 0; i < buffer_barriers.size(); ++i) { - const auto& buffer_barrier = buffer_barriers[i]; - auto& info = buffer_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); - info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - IREE_ASSIGN_OR_RETURN(auto* device_buffer, - CastBuffer(buffer_barrier.buffer)); - info.buffer = device_buffer->handle(); - info.offset = buffer_barrier.offset; - info.size = buffer_barrier.length; - } - - syms()->vkCmdPipelineBarrier( - command_buffer_, ConvertPipelineStageFlags(source_stage_mask), - ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0, - static_cast<uint32_t>(memory_barrier_infos.size()), - memory_barrier_infos.data(), - static_cast<uint32_t>(buffer_barrier_infos.size()), - buffer_barrier_infos.data(), 0, nullptr); - - return OkStatus(); -} - -Status DirectCommandBuffer::SignalEvent( - Event* event, iree_hal_execution_stage_t source_stage_mask) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent"); - IREE_ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); - syms()->vkCmdSetEvent(command_buffer_, device_event->handle(), - ConvertPipelineStageFlags(source_stage_mask)); - return OkStatus(); -} - -Status DirectCommandBuffer::ResetEvent( - Event* event, iree_hal_execution_stage_t source_stage_mask) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent"); - IREE_ASSIGN_OR_RETURN(auto* device_event, CastEvent(event)); - syms()->vkCmdResetEvent(command_buffer_, device_event->handle(), - ConvertPipelineStageFlags(source_stage_mask)); - return OkStatus(); -} - -Status DirectCommandBuffer::WaitEvents( - absl::Span<Event*> events, iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - absl::Span<const iree_hal_memory_barrier_t> memory_barriers, - absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents"); - - absl::InlinedVector<VkEvent, 4> event_handles(events.size()); - for (int i = 0; i < events.size(); ++i) { - IREE_ASSIGN_OR_RETURN(auto* device_event, CastEvent(events[i])); - event_handles[i] = device_event->handle(); - } - - absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos( - memory_barriers.size()); - for (int i = 0; i < memory_barriers.size(); ++i) { - const auto& memory_barrier = memory_barriers[i]; - auto& info = memory_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope); - } - - absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos( - buffer_barriers.size()); - for (int i = 0; i < buffer_barriers.size(); ++i) { - const auto& buffer_barrier = buffer_barriers[i]; - auto& info = buffer_barrier_infos[i]; - info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER; - info.pNext = nullptr; - info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope); - info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope); - info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; - IREE_ASSIGN_OR_RETURN(auto* device_buffer, - CastBuffer(buffer_barrier.buffer)); - info.buffer = device_buffer->handle(); - info.offset = buffer_barrier.offset; - info.size = buffer_barrier.length; - } - - syms()->vkCmdWaitEvents(command_buffer_, event_handles.size(), - event_handles.data(), - ConvertPipelineStageFlags(source_stage_mask), - ConvertPipelineStageFlags(target_stage_mask), - static_cast<uint32_t>(memory_barrier_infos.size()), - memory_barrier_infos.data(), - static_cast<uint32_t>(buffer_barrier_infos.size()), - buffer_barrier_infos.data(), 0, nullptr); - return OkStatus(); -} - -Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length, - const void* pattern, - size_t pattern_length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer"); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_fill_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(target_buffer)); // Note that fill only accepts 4-byte aligned values so we need to splat out // our variable-length pattern. - target_offset += target_buffer->byte_offset(); - uint32_t dword_pattern = SplatPattern(pattern, pattern_length); - syms()->vkCmdFillBuffer(command_buffer_, target_device_buffer->handle(), - target_offset, length, dword_pattern); + target_offset += iree_hal_buffer_byte_offset(target_buffer); + uint32_t dword_pattern = + iree_hal_vulkan_splat_pattern(pattern, pattern_length); + command_buffer->syms->vkCmdFillBuffer(command_buffer->handle, + target_device_buffer, target_offset, + length, dword_pattern); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer"); - // NOTE: we could use this to prevent queue family transitions. - return OkStatus(); -} - -Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer, - iree_device_size_t source_offset, - Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer"); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_update_buffer( + iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer, + iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer, + iree_device_size_t target_offset, iree_device_size_t length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(target_buffer)); // Vulkan only allows updates of <= 65536 because you really, really, really // shouldn't do large updates like this (as it wastes command buffer space and @@ -358,136 +457,176 @@ // recommendation in the spec for larger updates is to split the single update // into multiple updates over the entire desired range. const auto* source_buffer_ptr = static_cast<const uint8_t*>(source_buffer); - target_offset += target_buffer->byte_offset(); + target_offset += iree_hal_buffer_byte_offset(target_buffer); while (length > 0) { iree_device_size_t chunk_length = - std::min(static_cast<iree_device_size_t>(65536u), length); - syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(), - target_offset, chunk_length, source_buffer_ptr); + iree_min((iree_device_size_t)65536u, length); + command_buffer->syms->vkCmdUpdateBuffer(command_buffer->handle, + target_device_buffer, target_offset, + chunk_length, source_buffer_ptr); source_buffer_ptr += chunk_length; target_offset += chunk_length; length -= chunk_length; } - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer, - iree_device_size_t source_offset, - Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer"); - IREE_ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer)); - IREE_ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_copy_buffer( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset, + iree_device_size_t length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + VkBuffer source_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(source_buffer)); + VkBuffer target_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(target_buffer)); VkBufferCopy region; - region.srcOffset = source_buffer->byte_offset() + source_offset; - region.dstOffset = target_buffer->byte_offset() + target_offset; + region.srcOffset = iree_hal_buffer_byte_offset(source_buffer) + source_offset; + region.dstOffset = iree_hal_buffer_byte_offset(target_buffer) + target_offset; region.size = length; - syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(), - target_device_buffer->handle(), 1, ®ion); + command_buffer->syms->vkCmdCopyBuffer(command_buffer->handle, + source_device_buffer, + target_device_buffer, 1, ®ion); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::PushConstants(ExecutableLayout* executable_layout, - size_t offset, - absl::Span<const uint32_t> values) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::PushConstants"); - IREE_ASSIGN_OR_RETURN(auto* device_executable_layout, - CastExecutableLayout(executable_layout)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_push_constants( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset, + const void* values, iree_host_size_t values_length) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); - syms()->vkCmdPushConstants( - command_buffer_, device_executable_layout->handle(), - VK_SHADER_STAGE_COMPUTE_BIT, - static_cast<uint32_t>(offset * sizeof(uint32_t)), - static_cast<uint32_t>(values.size() * sizeof(uint32_t)), values.data()); + command_buffer->syms->vkCmdPushConstants( + command_buffer->handle, + iree_hal_vulkan_native_executable_layout_handle(executable_layout), + VK_SHADER_STAGE_COMPUTE_BIT, (uint32_t)(offset * sizeof(uint32_t)), + (uint32_t)(values_length * sizeof(uint32_t)), values); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::PushDescriptorSet"); - IREE_ASSIGN_OR_RETURN(auto* device_executable_layout, - CastExecutableLayout(executable_layout)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_push_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); // Either allocate, update, and bind a descriptor set or use push descriptor // sets to use the command buffer pool when supported. - return descriptor_set_arena_.BindDescriptorSet( - command_buffer_, device_executable_layout, set, bindings); + return command_buffer->descriptor_set_arena.BindDescriptorSet( + command_buffer->handle, executable_layout, set, binding_count, bindings); } -Status DirectCommandBuffer::BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span<const iree_device_size_t> dynamic_offsets) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::BindDescriptorSet"); - IREE_ASSIGN_OR_RETURN(auto* device_executable_layout, - CastExecutableLayout(executable_layout)); - IREE_ASSIGN_OR_RETURN(auto* device_descriptor_set, - CastDescriptorSet(descriptor_set)); +static iree_status_t iree_hal_vulkan_direct_command_buffer_bind_descriptor_set( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_layout_t* executable_layout, uint32_t set, + iree_hal_descriptor_set_t* descriptor_set, + iree_host_size_t dynamic_offset_count, + const iree_device_size_t* dynamic_offsets) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); // Vulkan takes uint32_t as the size here, unlike everywhere else. - absl::InlinedVector<uint32_t, 4> dynamic_offsets_i32(dynamic_offsets.size()); - for (int i = 0; i < dynamic_offsets.size(); ++i) { + absl::InlinedVector<uint32_t, 4> dynamic_offsets_i32(dynamic_offset_count); + for (int i = 0; i < dynamic_offset_count; ++i) { dynamic_offsets_i32[i] = static_cast<uint32_t>(dynamic_offsets[i]); } - std::array<VkDescriptorSet, 1> descriptor_sets = { - device_descriptor_set->handle()}; - syms()->vkCmdBindDescriptorSets( - command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - device_executable_layout->handle(), set, - static_cast<uint32_t>(descriptor_sets.size()), descriptor_sets.data(), + VkDescriptorSet descriptor_sets[1] = { + iree_hal_vulkan_native_descriptor_set_handle(descriptor_set), + }; + command_buffer->syms->vkCmdBindDescriptorSets( + command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, + iree_hal_vulkan_native_executable_layout_handle(executable_layout), set, + (uint32_t)IREE_ARRAYSIZE(descriptor_sets), descriptor_sets, static_cast<uint32_t>(dynamic_offsets_i32.size()), dynamic_offsets_i32.data()); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandBuffer::Dispatch(Executable* executable, - int32_t entry_point, - std::array<uint32_t, 3> workgroups) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch"); +static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); // Get the compiled and linked pipeline for the specified entry point and // bind it to the command buffer. - IREE_ASSIGN_OR_RETURN(auto* device_executable, CastExecutable(executable)); - IREE_ASSIGN_OR_RETURN( - VkPipeline pipeline, - device_executable->GetPipelineForEntryPoint(entry_point)); - syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline); + VkPipeline pipeline_handle = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_native_executable_pipeline_for_entry_point( + executable, entry_point, &pipeline_handle)); + command_buffer->syms->vkCmdBindPipeline( + command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle); - syms()->vkCmdDispatch(command_buffer_, workgroups[0], workgroups[1], - workgroups[2]); - return OkStatus(); + command_buffer->syms->vkCmdDispatch(command_buffer->handle, workgroup_x, + workgroup_y, workgroup_z); + + return iree_ok_status(); } -Status DirectCommandBuffer::DispatchIndirect( - Executable* executable, int32_t entry_point, Buffer* workgroups_buffer, +static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_t* workgroups_buffer, iree_device_size_t workgroups_offset) { - IREE_TRACE_SCOPE0("DirectCommandBuffer::DispatchIndirect"); + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); // Get the compiled and linked pipeline for the specified entry point and // bind it to the command buffer. - IREE_ASSIGN_OR_RETURN(auto* device_executable, CastExecutable(executable)); - IREE_ASSIGN_OR_RETURN( - VkPipeline pipeline, - device_executable->GetPipelineForEntryPoint(entry_point)); - syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, - pipeline); + VkPipeline pipeline_handle = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_native_executable_pipeline_for_entry_point( + executable, entry_point, &pipeline_handle)); + command_buffer->syms->vkCmdBindPipeline( + command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle); - IREE_ASSIGN_OR_RETURN(auto* workgroups_device_buffer, - CastBuffer(workgroups_buffer)); - syms()->vkCmdDispatchIndirect( - command_buffer_, workgroups_device_buffer->handle(), workgroups_offset); - return OkStatus(); + VkBuffer workgroups_device_buffer = iree_hal_vulkan_vma_buffer_handle( + iree_hal_buffer_allocated_buffer(workgroups_buffer)); + workgroups_offset += iree_hal_buffer_byte_offset(workgroups_buffer); + command_buffer->syms->vkCmdDispatchIndirect( + command_buffer->handle, workgroups_device_buffer, workgroups_offset); + + return iree_ok_status(); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_command_buffer_vtable_t + iree_hal_vulkan_direct_command_buffer_vtable = { + /*.destroy=*/iree_hal_vulkan_direct_command_buffer_destroy, + /*.allowed_categories=*/ + iree_hal_vulkan_direct_command_buffer_allowed_categories, + /*.begin=*/iree_hal_vulkan_direct_command_buffer_begin, + /*.end=*/iree_hal_vulkan_direct_command_buffer_end, + /*.execution_barrier=*/ + iree_hal_vulkan_direct_command_buffer_execution_barrier, + /*.signal_event=*/ + iree_hal_vulkan_direct_command_buffer_signal_event, + /*.reset_event=*/iree_hal_vulkan_direct_command_buffer_reset_event, + /*.wait_events=*/iree_hal_vulkan_direct_command_buffer_wait_events, + /*.discard_buffer=*/ + iree_hal_vulkan_direct_command_buffer_discard_buffer, + /*.fill_buffer=*/iree_hal_vulkan_direct_command_buffer_fill_buffer, + /*.update_buffer=*/ + iree_hal_vulkan_direct_command_buffer_update_buffer, + /*.copy_buffer=*/iree_hal_vulkan_direct_command_buffer_copy_buffer, + /*.push_constants=*/ + iree_hal_vulkan_direct_command_buffer_push_constants, + /*.push_descriptor_set=*/ + iree_hal_vulkan_direct_command_buffer_push_descriptor_set, + /*.bind_descriptor_set=*/ + iree_hal_vulkan_direct_command_buffer_bind_descriptor_set, + /*.dispatch=*/iree_hal_vulkan_direct_command_buffer_dispatch, + /*.dispatch_indirect=*/ + iree_hal_vulkan_direct_command_buffer_dispatch_indirect, +};
diff --git a/iree/hal/vulkan/direct_command_buffer.h b/iree/hal/vulkan/direct_command_buffer.h index 6735330..7046093 100644 --- a/iree/hal/vulkan/direct_command_buffer.h +++ b/iree/hal/vulkan/direct_command_buffer.h
@@ -15,115 +15,29 @@ #ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ #define IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/cc/command_buffer.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/descriptor_pool_cache.h" -#include "iree/hal/vulkan/descriptor_set_arena.h" -#include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/native_descriptor_set.h" -#include "iree/hal/vulkan/native_event.h" -#include "iree/hal/vulkan/pipeline_executable.h" -#include "iree/hal/vulkan/pipeline_executable_layout.h" -#include "iree/hal/vulkan/vma_buffer.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -// Command buffer implementation that directly maps to VkCommandBuffer. -// This records the commands on the calling thread without additional threading -// indirection. -class DirectCommandBuffer final : public CommandBuffer { - public: - DirectCommandBuffer(iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories, - ref_ptr<DescriptorPoolCache> descriptor_pool_cache, - ref_ptr<VkCommandPoolHandle> command_pool, - VkCommandBuffer command_buffer); - ~DirectCommandBuffer() override; +// Creates a command buffer that directly records into a VkCommandBuffer. +iree_status_t iree_hal_vulkan_direct_command_buffer_allocate( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree::hal::vulkan::VkCommandPoolHandle* command_pool, + iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree::hal::vulkan::DescriptorPoolCache* descriptor_pool_cache, + iree_hal_command_buffer_t** out_command_buffer); - VkCommandBuffer handle() const { return command_buffer_; } +// Returns the native Vulkan VkCommandBuffer handle. +VkCommandBuffer iree_hal_vulkan_direct_command_buffer_handle( + iree_hal_command_buffer_t* command_buffer); - bool is_recording() const override { return is_recording_; } - - Status Begin() override; - Status End() override; - - Status ExecutionBarrier( - iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - absl::Span<const iree_hal_memory_barrier_t> memory_barriers, - absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) override; - Status SignalEvent(Event* event, - iree_hal_execution_stage_t source_stage_mask) override; - Status ResetEvent(Event* event, - iree_hal_execution_stage_t source_stage_mask) override; - Status WaitEvents( - absl::Span<Event*> events, iree_hal_execution_stage_t source_stage_mask, - iree_hal_execution_stage_t target_stage_mask, - absl::Span<const iree_hal_memory_barrier_t> memory_barriers, - absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) override; - - Status FillBuffer(Buffer* target_buffer, iree_device_size_t target_offset, - iree_device_size_t length, const void* pattern, - size_t pattern_length) override; - Status DiscardBuffer(Buffer* buffer) override; - Status UpdateBuffer(const void* source_buffer, - iree_device_size_t source_offset, Buffer* target_buffer, - iree_device_size_t target_offset, - iree_device_size_t length) override; - Status CopyBuffer(Buffer* source_buffer, iree_device_size_t source_offset, - Buffer* target_buffer, iree_device_size_t target_offset, - iree_device_size_t length) override; - - Status PushConstants(ExecutableLayout* executable_layout, size_t offset, - absl::Span<const uint32_t> values) override; - - Status PushDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) override; - Status BindDescriptorSet( - ExecutableLayout* executable_layout, int32_t set, - DescriptorSet* descriptor_set, - absl::Span<const iree_device_size_t> dynamic_offsets) override; - - Status Dispatch(Executable* executable, int32_t entry_point, - std::array<uint32_t, 3> workgroups) override; - Status DispatchIndirect(Executable* executable, int32_t entry_point, - Buffer* workgroups_buffer, - iree_device_size_t workgroups_offset) override; - - private: - const ref_ptr<DynamicSymbols>& syms() const { return command_pool_->syms(); } - - StatusOr<NativeEvent*> CastEvent(Event* event) const; - StatusOr<VmaBuffer*> CastBuffer(Buffer* buffer) const; - StatusOr<VmaBuffer*> CastBuffer(iree_hal_buffer_t* buffer) const; - StatusOr<NativeDescriptorSet*> CastDescriptorSet( - DescriptorSet* descriptor_set) const; - StatusOr<PipelineExecutableLayout*> CastExecutableLayout( - ExecutableLayout* executable_layout) const; - StatusOr<PipelineExecutable*> CastExecutable(Executable* executable) const; - - bool is_recording_ = false; - ref_ptr<VkCommandPoolHandle> command_pool_; - VkCommandBuffer command_buffer_; - - // TODO(b/140026716): may grow large - should try to reclaim or reuse. - DescriptorSetArena descriptor_set_arena_; - - // The current descriptor set group in use by the command buffer, if any. - // This must remain valid until all in-flight submissions of the command - // buffer complete. - DescriptorSetGroup descriptor_set_group_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_DIRECT_COMMAND_BUFFER_H_
diff --git a/iree/hal/vulkan/direct_command_queue.cc b/iree/hal/vulkan/direct_command_queue.cc index 71ec1ac..461ce9a 100644 --- a/iree/hal/vulkan/direct_command_queue.cc +++ b/iree/hal/vulkan/direct_command_queue.cc
@@ -16,11 +16,9 @@ #include <cstdint> -#include "iree/base/memory.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" #include "iree/hal/vulkan/direct_command_buffer.h" -#include "iree/hal/vulkan/native_timeline_semaphore.h" +#include "iree/hal/vulkan/native_semaphore.h" #include "iree/hal/vulkan/status_util.h" namespace iree { @@ -28,20 +26,15 @@ namespace vulkan { DirectCommandQueue::DirectCommandQueue( - std::string name, iree_hal_command_category_t supported_categories, - const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue) - : CommandQueue(std::move(name), supported_categories), - logical_device_(add_ref(logical_device)), - queue_(queue) {} + VkDeviceHandle* logical_device, std::string name, + iree_hal_command_category_t supported_categories, VkQueue queue) + : CommandQueue(logical_device, std::move(name), supported_categories, + queue) {} -DirectCommandQueue::~DirectCommandQueue() { - IREE_TRACE_SCOPE0("DirectCommandQueue::dtor"); - absl::MutexLock lock(&queue_mutex_); - syms()->vkQueueWaitIdle(queue_); -} +DirectCommandQueue::~DirectCommandQueue() = default; -Status DirectCommandQueue::TranslateBatchInfo( - const SubmissionBatch& batch, VkSubmitInfo* submit_info, +iree_status_t DirectCommandQueue::TranslateBatchInfo( + const iree_hal_submission_batch_t* batch, VkSubmitInfo* submit_info, VkTimelineSemaphoreSubmitInfo* timeline_submit_info, Arena* arena) { // TODO(benvanik): see if we can go to finer-grained stages. // For example, if this was just queue ownership transfers then we can use @@ -50,39 +43,33 @@ VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; auto wait_semaphore_handles = - arena->AllocateSpan<VkSemaphore>(batch.wait_semaphores.size()); + arena->AllocateSpan<VkSemaphore>(batch->wait_semaphores.count); auto wait_semaphore_values = - arena->AllocateSpan<uint64_t>(batch.wait_semaphores.size()); + arena->AllocateSpan<uint64_t>(batch->wait_semaphores.count); auto wait_dst_stage_masks = - arena->AllocateSpan<VkPipelineStageFlags>(batch.wait_semaphores.size()); - for (int i = 0; i < batch.wait_semaphores.size(); ++i) { - const auto& wait_point = batch.wait_semaphores[i]; - const auto* semaphore = - static_cast<NativeTimelineSemaphore*>(wait_point.semaphore); - wait_semaphore_handles[i] = semaphore->handle(); - wait_semaphore_values[i] = wait_point.value; + arena->AllocateSpan<VkPipelineStageFlags>(batch->wait_semaphores.count); + for (iree_host_size_t i = 0; i < batch->wait_semaphores.count; ++i) { + wait_semaphore_handles[i] = iree_hal_vulkan_native_semaphore_handle( + batch->wait_semaphores.semaphores[i]); + wait_semaphore_values[i] = batch->wait_semaphores.payload_values[i]; wait_dst_stage_masks[i] = dst_stage_mask; } auto signal_semaphore_handles = - arena->AllocateSpan<VkSemaphore>(batch.signal_semaphores.size()); + arena->AllocateSpan<VkSemaphore>(batch->signal_semaphores.count); auto signal_semaphore_values = - arena->AllocateSpan<uint64_t>(batch.signal_semaphores.size()); - for (int i = 0; i < batch.signal_semaphores.size(); ++i) { - const auto& signal_point = batch.signal_semaphores[i]; - const auto* semaphore = - static_cast<NativeTimelineSemaphore*>(signal_point.semaphore); - signal_semaphore_handles[i] = semaphore->handle(); - signal_semaphore_values[i] = signal_point.value; + arena->AllocateSpan<uint64_t>(batch->signal_semaphores.count); + for (iree_host_size_t i = 0; i < batch->signal_semaphores.count; ++i) { + signal_semaphore_handles[i] = iree_hal_vulkan_native_semaphore_handle( + batch->signal_semaphores.semaphores[i]); + signal_semaphore_values[i] = batch->signal_semaphores.payload_values[i]; } auto command_buffer_handles = - arena->AllocateSpan<VkCommandBuffer>(batch.command_buffers.size()); - for (int i = 0; i < batch.command_buffers.size(); ++i) { - const auto& command_buffer = batch.command_buffers[i]; - auto* direct_command_buffer = - static_cast<DirectCommandBuffer*>(command_buffer->impl()); - command_buffer_handles[i] = direct_command_buffer->handle(); + arena->AllocateSpan<VkCommandBuffer>(batch->command_buffer_count); + for (iree_host_size_t i = 0; i < batch->command_buffer_count; ++i) { + command_buffer_handles[i] = + iree_hal_vulkan_direct_command_buffer_handle(batch->command_buffers[i]); } submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; @@ -111,39 +98,43 @@ return OkStatus(); } -Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches) { +iree_status_t DirectCommandQueue::Submit( + iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { IREE_TRACE_SCOPE0("DirectCommandQueue::Submit"); // Map the submission batches to VkSubmitInfos. // Note that we must keep all arrays referenced alive until submission // completes and since there are a bunch of them we use an arena. Arena arena(4 * 1024); - auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batches.size()); + auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batch_count); auto timeline_submit_infos = - arena.AllocateSpan<VkTimelineSemaphoreSubmitInfo>(batches.size()); - for (int i = 0; i < batches.size(); ++i) { - IREE_RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i], + arena.AllocateSpan<VkTimelineSemaphoreSubmitInfo>(batch_count); + for (int i = 0; i < batch_count; ++i) { + IREE_RETURN_IF_ERROR(TranslateBatchInfo(&batches[i], &submit_infos[i], &timeline_submit_infos[i], &arena)); } - { - absl::MutexLock lock(&queue_mutex_); - VK_RETURN_IF_ERROR(syms()->vkQueueSubmit( - queue_, static_cast<uint32_t>(submit_infos.size()), submit_infos.data(), - VK_NULL_HANDLE)); - } + iree_slim_mutex_lock(&queue_mutex_); + iree_status_t status = VK_RESULT_TO_STATUS( + syms()->vkQueueSubmit(queue_, static_cast<uint32_t>(submit_infos.size()), + submit_infos.data(), VK_NULL_HANDLE), + "vkQueueSubmit"); + iree_slim_mutex_unlock(&queue_mutex_); + IREE_RETURN_IF_ERROR(status); - return OkStatus(); + return iree_ok_status(); } -Status DirectCommandQueue::WaitIdle(Time deadline_ns) { - if (deadline_ns == InfiniteFuture()) { +iree_status_t DirectCommandQueue::WaitIdle(iree_time_t deadline_ns) { + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it // requires fewer calls into the driver). IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle"); - absl::MutexLock lock(&queue_mutex_); - VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); - return OkStatus(); + iree_slim_mutex_lock(&queue_mutex_); + iree_status_t status = + VK_RESULT_TO_STATUS(syms()->vkQueueWaitIdle(queue_), "vkQueueWaitIdle"); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence"); @@ -155,46 +146,52 @@ create_info.pNext = nullptr; create_info.flags = 0; VkFence fence = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateFence( - *logical_device_, &create_info, logical_device_->allocator(), &fence)); - auto fence_cleanup = MakeCleanup([this, fence]() { - syms()->vkDestroyFence(*logical_device_, fence, - logical_device_->allocator()); - }); + VK_RETURN_IF_ERROR( + syms()->vkCreateFence(*logical_device_, &create_info, + logical_device_->allocator(), &fence), + "vkCreateFence"); uint64_t timeout_ns; - if (deadline_ns == InfinitePast()) { + if (deadline_ns == IREE_TIME_INFINITE_PAST) { // Do not wait. timeout_ns = 0; - } else if (deadline_ns == InfiniteFuture()) { + } else if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { // Wait forever. timeout_ns = UINT64_MAX; } else { // Convert to relative time in nanoseconds. - // The implementation may not wait with this granularity (like, by 10000x). - Time now_ns = Now(); + // The implementation may not wait with this granularity (like by 10000x). + iree_time_t now_ns = iree_time_now(); if (deadline_ns < now_ns) { - return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline in the past"; + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); } - timeout_ns = static_cast<uint64_t>(deadline_ns - now_ns); + timeout_ns = (uint64_t)(deadline_ns - now_ns); } - { - absl::MutexLock lock(&queue_mutex_); - VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence)); + iree_slim_mutex_lock(&queue_mutex_); + iree_status_t status = VK_RESULT_TO_STATUS( + syms()->vkQueueSubmit(queue_, 0, nullptr, fence), "vkQueueSubmit"); + iree_slim_mutex_unlock(&queue_mutex_); + + if (iree_status_is_ok(status)) { + VkResult result = syms()->vkWaitForFences(*logical_device_, 1, &fence, + VK_TRUE, timeout_ns); + switch (result) { + case VK_SUCCESS: + status = iree_ok_status(); + break; + case VK_TIMEOUT: + status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + break; + default: + status = VK_RESULT_TO_STATUS(result, "vkWaitForFences"); + break; + } } - VkResult result = - syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout_ns); - switch (result) { - case VK_SUCCESS: - return OkStatus(); - case VK_TIMEOUT: - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for idle"; - default: - return VkResultToStatus(result, IREE_LOC); - } + syms()->vkDestroyFence(*logical_device_, fence, logical_device_->allocator()); + + return status; } } // namespace vulkan
diff --git a/iree/hal/vulkan/direct_command_queue.h b/iree/hal/vulkan/direct_command_queue.h index 059aa0c..9055233 100644 --- a/iree/hal/vulkan/direct_command_queue.h +++ b/iree/hal/vulkan/direct_command_queue.h
@@ -15,21 +15,8 @@ #ifndef IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_ #define IREE_HAL_VULKAN_DIRECT_COMMAND_QUEUE_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <cstdint> -#include <string> - -#include "absl/base/thread_annotations.h" -#include "absl/synchronization/mutex.h" #include "iree/base/arena.h" -#include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/cc/command_queue.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/handle_util.h" +#include "iree/hal/vulkan/command_queue.h" namespace iree { namespace hal { @@ -38,31 +25,20 @@ // Command queue implementation directly maps to VkQueue. class DirectCommandQueue final : public CommandQueue { public: - DirectCommandQueue(std::string name, + DirectCommandQueue(VkDeviceHandle* logical_device, std::string name, iree_hal_command_category_t supported_categories, - const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue); ~DirectCommandQueue() override; - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } + iree_status_t Submit(iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) override; - Status Submit(absl::Span<const SubmissionBatch> batches) override; - - Status WaitIdle(Time deadline_ns) override; + iree_status_t WaitIdle(iree_time_t deadline_ns) override; private: - Status TranslateBatchInfo(const SubmissionBatch& batch, - VkSubmitInfo* submit_info, - VkTimelineSemaphoreSubmitInfo* timeline_submit_info, - Arena* arena); - - ref_ptr<VkDeviceHandle> logical_device_; - - // VkQueue needs to be externally synchronized. - mutable absl::Mutex queue_mutex_; - VkQueue queue_ ABSL_GUARDED_BY(queue_mutex_); + iree_status_t TranslateBatchInfo( + const iree_hal_submission_batch_t* batch, VkSubmitInfo* submit_info, + VkTimelineSemaphoreSubmitInfo* timeline_submit_info, Arena* arena); }; } // namespace vulkan
diff --git a/iree/hal/vulkan/dynamic_symbol_tables.h b/iree/hal/vulkan/dynamic_symbol_tables.h index b709e57..05dcd59 100644 --- a/iree/hal/vulkan/dynamic_symbol_tables.h +++ b/iree/hal/vulkan/dynamic_symbol_tables.h
@@ -300,12 +300,12 @@ DEV_PFN(OPTIONAL, vkSignalSemaphore) \ DEV_PFN(OPTIONAL, vkSignalSemaphoreKHR) \ \ - INS_PFN(OPTIONAL, vkCreateDebugReportCallbackEXT) \ + INS_PFN(EXCLUDED, vkCreateDebugReportCallbackEXT) \ INS_PFN(OPTIONAL, vkCreateDebugUtilsMessengerEXT) \ INS_PFN(EXCLUDED, vkCreateDisplayPlaneSurfaceKHR) \ INS_PFN(EXCLUDED, vkCreateHeadlessSurfaceEXT) \ INS_PFN(EXCLUDED, vkDebugReportMessageEXT) \ - INS_PFN(OPTIONAL, vkDestroyDebugReportCallbackEXT) \ + INS_PFN(EXCLUDED, vkDestroyDebugReportCallbackEXT) \ INS_PFN(OPTIONAL, vkDestroyDebugUtilsMessengerEXT) \ INS_PFN(REQUIRED, vkDestroyInstance) \ INS_PFN(EXCLUDED, vkDestroySurfaceKHR) \
diff --git a/iree/hal/vulkan/dynamic_symbols_test.cc b/iree/hal/vulkan/dynamic_symbols_test.cc index c06e6a6..594673b 100644 --- a/iree/hal/vulkan/dynamic_symbols_test.cc +++ b/iree/hal/vulkan/dynamic_symbols_test.cc
@@ -14,7 +14,6 @@ #include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/status_util.h" #include "iree/testing/gtest.h" #include "iree/testing/status_matchers.h" @@ -58,8 +57,8 @@ VkApplicationInfo app_info = GetApplicationInfo(); VkInstanceCreateInfo create_info = GetInstanceCreateInfo(&app_info); VkInstance instance = VK_NULL_HANDLE; - VK_CHECK_OK( - syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)); + ASSERT_EQ(VK_SUCCESS, syms->vkCreateInstance( + &create_info, /*pAllocator=*/nullptr, &instance)); IREE_ASSERT_OK(syms->LoadFromInstance(instance));
diff --git a/iree/hal/vulkan/emulated_semaphore.cc b/iree/hal/vulkan/emulated_semaphore.cc new file mode 100644 index 0000000..b287d5e --- /dev/null +++ b/iree/hal/vulkan/emulated_semaphore.cc
@@ -0,0 +1,634 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/vulkan/emulated_semaphore.h" + +#include <inttypes.h> +#include <stdint.h> + +#include "absl/base/thread_annotations.h" +#include "absl/container/inlined_vector.h" +#include "absl/synchronization/mutex.h" +#include "iree/base/intrusive_list.h" +#include "iree/base/ref_ptr.h" +#include "iree/base/status.h" +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/vulkan/serializing_command_queue.h" +#include "iree/hal/vulkan/status_util.h" + +namespace iree { +namespace hal { +namespace vulkan { + +class EmulatedTimelineSemaphore final { + public: + EmulatedTimelineSemaphore(VkDeviceHandle* logical_device, + TimePointSemaphorePool* semaphore_pool, + iree_host_size_t command_queue_count, + iree::hal::vulkan::CommandQueue** command_queues, + uint64_t initial_value); + + ~EmulatedTimelineSemaphore(); + + iree_status_t Query(uint64_t* out_value); + + iree_status_t Signal(uint64_t value); + + iree_status_t Wait(uint64_t value, iree_time_t deadline_ns); + + void Fail(iree_status_t status); + + // Gets a binary semaphore for waiting on the timeline to advance to the given + // |value|. The semaphore returned won't be waited by anyone else. Returns + // VK_NULL_HANDLE if no available semaphores for the given |value|. + // |wait_fence| is the fence associated with the queue submission that waiting + // on this semaphore. + VkSemaphore GetWaitSemaphore(uint64_t value, + const ref_ptr<TimePointFence>& wait_fence); + + // Cancels the waiting attempt on the given binary |semaphore|. This allows + // the |semaphore| to be waited by others. + iree_status_t CancelWaitSemaphore(VkSemaphore semaphore); + + // Gets a binary semaphore for signaling the timeline to the given |value|. + // |value| must be smaller than the current timeline value. |signal_fence| is + // the fence associated with the queue submission that signals this semaphore. + iree_status_t GetSignalSemaphore(uint64_t value, + const ref_ptr<TimePointFence>& signal_fence, + VkSemaphore* out_handle); + + private: + // Tries to advance the timeline to the given |to_upper_value| without + // blocking and returns whether the |to_upper_value| is reached. + iree_status_t TryToAdvanceTimeline(uint64_t to_upper_value, + bool* out_reached_upper_value) + ABSL_LOCKS_EXCLUDED(mutex_); + // Similar to the above, but also returns the fences that are known to have + // already signaled via |signaled_fences|. + iree_status_t TryToAdvanceTimeline( + uint64_t to_upper_value, bool* out_reached_upper_value, + absl::InlinedVector<VkFence, 4>* out_signaled_fences) + ABSL_LOCKS_EXCLUDED(mutex_); + + std::atomic<uint64_t> signaled_value_; + + VkDeviceHandle* logical_device_; + TimePointSemaphorePool* semaphore_pool_; + + iree_host_size_t command_queue_count_; + CommandQueue** command_queues_; + + mutable absl::Mutex mutex_; + + // A list of outstanding semaphores used to emulate time points. + // + // The life time of each semaphore is in one of the following state: + // + // * Unused state: value = UINT64_MAX, signal/wait fence = nullptr. This is + // the state of the semaphore when it's initially acquired from the pool and + // not put in the queue for emulating a time point yet. + // * Pending state: signaled value < value < UINT64_MAX, signal fence = + // <some-fence>, wait fence == nullptr. This is the state of the semaphore + // when it's put into the GPU queue for emulating a time point. + // * Pending and waiting state: signaled value < value < UINT64_MAX, signal + // fence = <some-fence>, wait fence == <some-fence>. This is the state of + // the semaphore when it's put into the GPU queue for emulating a time + // point and there is another queue submission waiting on it in GPU. + // * Signaled and not ever waited state: value <= signaled value, singal/wait + // fence = nullptr. This is the state of the semaphore when we know it's + // already signaled on GPU and there is no waiters for it. + // * Signaled and waiting state: value <= signaled value, signal fence = + // nullptr, wait fence = <some-fence>. This is the state of the semaphore + // when we know it's already signaled on GPU and there is still one queue + // submission on GPU is waiting for it. + IntrusiveList<TimePointSemaphore> outstanding_semaphores_ + ABSL_GUARDED_BY(mutex_); + + // NOTE: We only need to access this status (and thus take the lock) when we + // want to either signal failure or query the status in the case of the + // semaphore being set to UINT64_MAX. + iree_status_t status_ ABSL_GUARDED_BY(mutex_) = iree_ok_status(); +}; + +EmulatedTimelineSemaphore::EmulatedTimelineSemaphore( + VkDeviceHandle* logical_device, TimePointSemaphorePool* semaphore_pool, + iree_host_size_t command_queue_count, CommandQueue** command_queues, + uint64_t initial_value) + : signaled_value_(initial_value), + logical_device_(logical_device), + semaphore_pool_(semaphore_pool), + command_queue_count_(command_queue_count), + command_queues_(command_queues) {} + +EmulatedTimelineSemaphore::~EmulatedTimelineSemaphore() { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::dtor"); + IREE_CHECK_OK( + TryToAdvanceTimeline(UINT64_MAX, /*out_reached_upper_value=*/NULL)); + absl::MutexLock lock(&mutex_); + IREE_CHECK(outstanding_semaphores_.empty()) + << "Destroying an emulated timeline semaphore without first waiting on " + "outstanding signals"; + iree_status_free(status_); +} + +iree_status_t EmulatedTimelineSemaphore::Query(uint64_t* out_value) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Query"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Query"; + IREE_RETURN_IF_ERROR( + TryToAdvanceTimeline(UINT64_MAX, /*out_reached_upper_value=*/NULL)); + uint64_t value = signaled_value_.load(); + IREE_DVLOG(2) << "Current timeline value: " << value; + if (value == UINT64_MAX) { + absl::MutexLock lock(&mutex_); + return iree_status_clone(status_); + } + *out_value = value; + return iree_ok_status(); +} + +iree_status_t EmulatedTimelineSemaphore::Signal(uint64_t value) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Signal"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Signal"; + auto signaled_value = signaled_value_.exchange(value); + IREE_DVLOG(2) << "Previous value: " << signaled_value + << "; new value: " << value; + // Make sure the previous signaled value is smaller than the new value. + IREE_CHECK(signaled_value < value) + << "Attempting to signal a timeline value out of order; trying " << value + << " but " << signaled_value << " already signaled"; + + // Inform the device to make progress given we have a new value signaled now. + for (iree_host_size_t i = 0; i < command_queue_count_; ++i) { + IREE_RETURN_IF_ERROR(((SerializingCommandQueue*)command_queues_[i]) + ->AdvanceQueueSubmission()); + } + + return iree_ok_status(); +} + +iree_status_t EmulatedTimelineSemaphore::Wait(uint64_t value, + iree_time_t deadline_ns) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Wait"; + + VkFence fence = VK_NULL_HANDLE; + do { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait#loop"); + // First try to advance the timeline without blocking to see whether we've + // already reached the desired value. + bool reached_desired_value = false; + IREE_RETURN_IF_ERROR(TryToAdvanceTimeline(value, &reached_desired_value)); + if (reached_desired_value) return iree_ok_status(); + + // We must wait now. Find the first emulated time point that has a value >= + // the desired value so we can wait on its associated signal fence to make + // sure the timeline is advanced to the desired value. + absl::MutexLock lock(&mutex_); + auto semaphore = outstanding_semaphores_.begin(); + for (; semaphore != outstanding_semaphores_.end(); ++semaphore) { + if ((*semaphore)->value >= value) break; + } + if (semaphore != outstanding_semaphores_.end()) { + if (!(*semaphore)->signal_fence) { + return iree_make_status(IREE_STATUS_INTERNAL, + "timeline should have a signal fence for the " + "first time point beyond the signaled value"); + } + IREE_DVLOG(2) << "Found timepoint semaphore " << *semaphore + << " (value: " << (*semaphore)->value + << ") to wait for desired timeline value: " << value; + fence = (*semaphore)->signal_fence->value(); + // Found; we can break the loop and proceed to waiting now. + break; + } + // TODO(antiagainst): figure out a better way instead of the busy loop here. + } while (iree_time_now() < deadline_ns); + + if (fence == VK_NULL_HANDLE) { + // NOTE: not an error; it may be expected that the semaphore is not ready. + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + + uint64_t timeout_ns = + static_cast<uint64_t>(iree_absolute_deadline_to_timeout_ns(deadline_ns)); + VK_RETURN_IF_ERROR(logical_device_->syms()->vkWaitForFences( + *logical_device_, /*fenceCount=*/1, &fence, + /*waitAll=*/true, timeout_ns), + "vkWaitForFences"); + + return TryToAdvanceTimeline(value, /*out_reached_upper_value=*/NULL); +} + +void EmulatedTimelineSemaphore::Fail(iree_status_t status) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Fail"); + absl::MutexLock lock(&mutex_); + if (status_) return; + status_ = status; + signaled_value_.store(UINT64_MAX); +} + +VkSemaphore EmulatedTimelineSemaphore::GetWaitSemaphore( + uint64_t value, const ref_ptr<TimePointFence>& wait_fence) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetWaitSemaphore"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetWaitSemaphore"; + + absl::MutexLock lock(&mutex_); + + VkSemaphore semaphore = VK_NULL_HANDLE; + for (TimePointSemaphore* point : outstanding_semaphores_) { + if (point->value > value && point->wait_fence) { + point->wait_fence = add_ref(wait_fence); + semaphore = point->semaphore; + break; + } + } + + IREE_DVLOG(2) << "Binary VkSemaphore to wait on for timeline value (" << value + << ") and wait fence (" << wait_fence.get() + << "): " << semaphore; + + return semaphore; +} + +iree_status_t EmulatedTimelineSemaphore::CancelWaitSemaphore( + VkSemaphore semaphore) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::CancelWaitSemaphore"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::CancelWaitSemaphore"; + + absl::MutexLock lock(&mutex_); + for (TimePointSemaphore* point : outstanding_semaphores_) { + if (point->semaphore != semaphore) continue; + + if (!point->wait_fence) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "time point wasn't waited before"); + } + point->wait_fence = nullptr; + IREE_DVLOG(2) << "Cancelled waiting on binary VkSemaphore: " << semaphore; + return iree_ok_status(); + } + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "no time point for the given semaphore"); +} + +iree_status_t EmulatedTimelineSemaphore::GetSignalSemaphore( + uint64_t value, const ref_ptr<TimePointFence>& signal_fence, + VkSemaphore* out_handle) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetSignalSemaphore"); + IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetSignalSemaphore"; + + if (signaled_value_.load() >= value) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "timeline semaphore already signaled past %" PRIu64, + value); + } + + absl::MutexLock lock(&mutex_); + + auto insertion_point = outstanding_semaphores_.begin(); + while (insertion_point != outstanding_semaphores_.end()) { + if ((*insertion_point)->value > value) break; + } + + IREE_ASSIGN_OR_RETURN(TimePointSemaphore * semaphore, + semaphore_pool_->Acquire()); + semaphore->value = value; + semaphore->signal_fence = add_ref(signal_fence); + if (semaphore->wait_fence) { + return iree_make_status( + IREE_STATUS_INTERNAL, + "newly acquired time point semaphore should not have waiters"); + } + outstanding_semaphores_.insert(insertion_point, semaphore); + IREE_DVLOG(2) << "Timepoint semaphore to signal for timeline value (" << value + << ") and wait fence (" << signal_fence.get() + << "): " << semaphore + << " (binary VkSemaphore: " << semaphore->semaphore << ")"; + + *out_handle = semaphore->semaphore; + return iree_ok_status(); +} + +iree_status_t EmulatedTimelineSemaphore::TryToAdvanceTimeline( + uint64_t to_upper_value, bool* out_reached_upper_value) { + absl::InlinedVector<VkFence, 4> signaled_fences; + iree_status_t status = TryToAdvanceTimeline( + to_upper_value, out_reached_upper_value, &signaled_fences); + // Inform the queue that some fences are known to have signaled. This should + // happen here instead of inside the other TryToAdvanceTimeline to avoid + // potential mutex deadlock, given here we are not holding a mutex anymore. + if (!signaled_fences.empty()) { + for (iree_host_size_t i = 0; i < command_queue_count_; ++i) { + ((SerializingCommandQueue*)command_queues_[i]) + ->SignalFences(absl::MakeSpan(signaled_fences)); + } + } + return status; +} + +iree_status_t EmulatedTimelineSemaphore::TryToAdvanceTimeline( + uint64_t to_upper_value, bool* out_reached_upper_value, + absl::InlinedVector<VkFence, 4>* out_signaled_fences) { + IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::TryToAdvanceTimeline"); + IREE_DVLOG(3) << "EmulatedTimelineSemaphore::TryToAdvanceTimeline"; + if (out_reached_upper_value) *out_reached_upper_value = false; + + uint64_t past_value = signaled_value_.load(); + IREE_DVLOG(3) << "Current timeline value: " << past_value + << "; desired timeline value: " << to_upper_value; + + // Fast path for when already signaled past the desired value. + if (past_value >= to_upper_value) { + if (out_reached_upper_value) *out_reached_upper_value = true; + return iree_ok_status(); + } + + // We hold the lock during the entire resolve process so that we can resolve + // to the furthest possible value. + absl::MutexLock lock(&mutex_); + + IREE_DVLOG(3) << "# outstanding semaphores: " + << outstanding_semaphores_.size(); + + // The timeline has not signaled past the desired value and there is no + // binary semaphore pending on GPU yet: certainly the timeline cannot + // advance to the desired value. + if (outstanding_semaphores_.empty()) return iree_ok_status(); + + IntrusiveList<TimePointSemaphore> resolved_semaphores; + + auto clear_signal_fence = + [&out_signaled_fences](ref_ptr<TimePointFence>& fence) { + if (fence) { + if (out_signaled_fences) + out_signaled_fences->push_back(fence->value()); + fence.reset(); + } + }; + + bool keep_resolving = true; + bool reached_desired_value = false; + while (keep_resolving && !outstanding_semaphores_.empty()) { + auto* semaphore = outstanding_semaphores_.front(); + IREE_DVLOG(3) << "Looking at timepoint semaphore " << semaphore << ".."; + IREE_DVLOG(3) << " value: " << semaphore->value; + IREE_DVLOG(3) << " VkSemaphore: " << semaphore->semaphore; + IREE_DVLOG(3) << " signal fence: " << semaphore->signal_fence.get(); + IREE_DVLOG(3) << " wait fence: " << semaphore->wait_fence.get(); + + // If the current semaphore is for a value beyond our upper limit, then + // early exit so that we don't spend time dealing with signals we don't yet + // care about. This can prevent live lock where one thread is signaling + // fences as fast/faster than another thread can consume them. + if (semaphore->value > to_upper_value) { + keep_resolving = false; + reached_desired_value = true; + break; + } + + // If the current semaphore is for a value not greater than the past + // signaled value, then we know it was signaled previously. But there might + // be a waiter on it on GPU. + if (semaphore->value <= past_value) { + if (semaphore->signal_fence) { + return iree_make_status(IREE_STATUS_INTERNAL, + "timeline should already signaled past this " + "time point and cleared the signal fence"); + } + + // If ther is no waiters, we can recycle this semaphore now. If there + // exists one waiter, then query its status and recycle on success. We + // only handle success status here. Others will be handled when the fence + // is checked for other semaphores' signaling status for the same queue + // submission. + if (!semaphore->wait_fence || + semaphore->wait_fence->GetStatus() == VK_SUCCESS) { + clear_signal_fence(semaphore->signal_fence); + semaphore->wait_fence = nullptr; + outstanding_semaphores_.erase(semaphore); + resolved_semaphores.push_back(semaphore); + IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore; + } + + continue; + } + + // This semaphore represents a value gerater than the known previously + // signaled value. We don't know its status so we need to really query now. + + if (!semaphore->signal_fence) { + return iree_make_status(IREE_STATUS_INTERNAL, + "status of this time point in the timeline " + "should still be pending with a singal fence"); + } + VkResult signal_status = semaphore->signal_fence->GetStatus(); + + switch (signal_status) { + case VK_SUCCESS: + IREE_DVLOG(3) << "..semaphore signaled"; + signaled_value_.store(semaphore->value); + clear_signal_fence(semaphore->signal_fence); + // If no waiters, we can recycle this semaphore now. + if (!semaphore->wait_fence) { + semaphore->wait_fence = nullptr; + outstanding_semaphores_.erase(semaphore); + resolved_semaphores.push_back(semaphore); + IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore; + } + break; + case VK_NOT_READY: + // The fence has not been signaled yet so this is the furthest time + // point we can go in this timeline. + keep_resolving = false; + IREE_DVLOG(3) << "..semaphore not yet signaled"; + break; + default: + // Fence indicates an error (device lost, out of memory, etc). + // Propagate this back to our status (and thus any waiters). + // Since we only take the first error we find we skip all remaining + // fences. + keep_resolving = false; + clear_signal_fence(semaphore->signal_fence); + status_ = VK_RESULT_TO_STATUS(signal_status, "signal status"); + signaled_value_.store(UINT64_MAX); + break; + } + } + + IREE_DVLOG(3) << "Releasing " << resolved_semaphores.size() + << " resolved semaphores; " << outstanding_semaphores_.size() + << " still outstanding"; + semaphore_pool_->ReleaseResolved(&resolved_semaphores); + if (!iree_status_is_ok(status_)) { + for (iree_host_size_t i = 0; i < command_queue_count_; ++i) { + ((SerializingCommandQueue*)command_queues_[i])->AbortQueueSubmission(); + } + semaphore_pool_->ReleaseUnresolved(&outstanding_semaphores_); + return status_; + } + + if (out_reached_upper_value) *out_reached_upper_value = reached_desired_value; + return iree_ok_status(); +} + +} // namespace vulkan +} // namespace hal +} // namespace iree + +using namespace iree::hal::vulkan; + +// Wrap the C++ type above so that we have a somewhat normal C interface. +// Porting the above to C is ideal but since this is just a fallback layer I'm +// not sure it's worth it (given that we may require Vulkan 1.2 with timeline +// semaphores built in at some point soon). +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + EmulatedTimelineSemaphore* handle; +} iree_hal_vulkan_emulated_semaphore_t; + +extern const iree_hal_semaphore_vtable_t + iree_hal_vulkan_emulated_semaphore_vtable; + +static EmulatedTimelineSemaphore* iree_hal_vulkan_emulated_semaphore_cast( + iree_hal_semaphore_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_emulated_semaphore_vtable); + return ((iree_hal_vulkan_emulated_semaphore_t*)base_value)->handle; +} + +iree_status_t iree_hal_vulkan_emulated_semaphore_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree::hal::vulkan::TimePointSemaphorePool* semaphore_pool, + iree_host_size_t command_queue_count, + iree::hal::vulkan::CommandQueue** command_queues, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + iree_hal_vulkan_emulated_semaphore_t* semaphore = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*semaphore), + (void**)&semaphore)); + iree_hal_resource_initialize(&iree_hal_vulkan_emulated_semaphore_vtable, + &semaphore->resource); + semaphore->host_allocator = logical_device->host_allocator(); + semaphore->handle = new EmulatedTimelineSemaphore( + logical_device, semaphore_pool, command_queue_count, command_queues, + initial_value); + + *out_semaphore = (iree_hal_semaphore_t*)semaphore; + return iree_ok_status(); +} + +static void iree_hal_vulkan_emulated_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_vulkan_emulated_semaphore_t* semaphore = + (iree_hal_vulkan_emulated_semaphore_t*)base_semaphore; + iree_allocator_t host_allocator = semaphore->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + delete semaphore->handle; + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_wait_handle( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& wait_fence, + VkSemaphore* out_handle) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + *out_handle = semaphore->GetWaitSemaphore(value, wait_fence); + return iree_ok_status(); +} + +iree_status_t iree_hal_vulkan_emulated_semaphore_cancel_wait_handle( + iree_hal_semaphore_t* base_semaphore, VkSemaphore handle) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + return semaphore->CancelWaitSemaphore(handle); +} + +iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_signal_handle( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& signal_fence, + VkSemaphore* out_handle) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + return semaphore->GetSignalSemaphore(value, signal_fence, out_handle); +} + +static iree_status_t iree_hal_vulkan_emulated_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + return semaphore->Query(out_value); +} + +static iree_status_t iree_hal_vulkan_emulated_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + return semaphore->Signal(new_value); +} + +static void iree_hal_vulkan_emulated_semaphore_fail( + iree_hal_semaphore_t* base_semaphore, iree_status_t status) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + semaphore->Fail(status); +} + +static iree_status_t iree_hal_vulkan_emulated_semaphore_wait_with_deadline( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_time_t deadline_ns) { + EmulatedTimelineSemaphore* semaphore = + iree_hal_vulkan_emulated_semaphore_cast(base_semaphore); + return semaphore->Wait(value, deadline_ns); +} + +static iree_status_t iree_hal_vulkan_emulated_semaphore_wait_with_timeout( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_duration_t timeout_ns) { + return iree_hal_vulkan_emulated_semaphore_wait_with_deadline( + base_semaphore, value, iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +iree_status_t iree_hal_vulkan_emulated_semaphore_multi_wait( + iree::hal::vulkan::VkDeviceHandle* logical_device, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + VkSemaphoreWaitFlags wait_flags) { + // TODO(antiagainst): We actually should get the fences associated with the + // emulated timeline semaphores so that we can wait them in a bunch. This + // implementation is problematic if we wait to wait any and we have the + // first semaphore taking extra long time but the following ones signal + // quickly. + for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_emulated_semaphore_wait_with_deadline( + semaphore_list->semaphores[i], semaphore_list->payload_values[i], + deadline_ns)); + if (wait_flags & VK_SEMAPHORE_WAIT_ANY_BIT) return iree_ok_status(); + } + return iree_ok_status(); +} + +const iree_hal_semaphore_vtable_t iree_hal_vulkan_emulated_semaphore_vtable = { + /*.destroy=*/iree_hal_vulkan_emulated_semaphore_destroy, + /*.query=*/iree_hal_vulkan_emulated_semaphore_query, + /*.signal=*/iree_hal_vulkan_emulated_semaphore_signal, + /*.fail=*/iree_hal_vulkan_emulated_semaphore_fail, + /*.wait_with_deadline=*/ + iree_hal_vulkan_emulated_semaphore_wait_with_deadline, + /*.wait_with_timeout=*/ + iree_hal_vulkan_emulated_semaphore_wait_with_timeout, +};
diff --git a/iree/hal/vulkan/emulated_semaphore.h b/iree/hal/vulkan/emulated_semaphore.h new file mode 100644 index 0000000..28af0c6 --- /dev/null +++ b/iree/hal/vulkan/emulated_semaphore.h
@@ -0,0 +1,165 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_ENUMLATED_SEMAPHORE_H_ +#define IREE_HAL_VULKAN_ENUMLATED_SEMAPHORE_H_ + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/command_queue.h" +#include "iree/hal/vulkan/handle_util.h" +#include "iree/hal/vulkan/timepoint_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a timeline semaphore emulated via `VkFence`s and binary +// `VkSemaphore`s. +// +// Vulkan provides several explicit synchronization primitives: fences, +// (binary/timeline) semaphores, events, pipeline barriers, and render passes. +// See "6. Synchronization and Cache Control" of the Vulkan specification +// for the details. +// +// Render passes are for graphics pipelines so IREE does not care about them. +// Pipeline barriers synchronize control within a command buffer at a single +// point. Fences, (binary/timeline) semaphores, and events are synchronization +// primitives that have separate signal and wait operations. Events are more +// fine-grained compared to fences and semaphores given that they can be +// signaled or waited within a command buffer while fences and semaphores are +// at queue submissions. Each of them have its usage requirements: +// +// * Fences must be signaled on GPU and waited on CPU. Fences must be reset +// before reuse. +// * Binary semaphores must be signaled on GPU and waited on GPU. They do not +// support wait-before-signal submission order. More importantly, binary +// semaphore wait also unsignals the semaphore. So binary semaphore signals +// and waits should occur in discrete 1:1 pairs. +// * Timeline semaphores can be signaled on CPU or GPU and waited on CPU or GPU. +// They support wait-before-signal submission order. Timeline semaphores do +// not need to be reset. +// +// It's clear that timeline semaphore is more flexible than fences and binary +// semaphores: it unifies GPU and CPU synchronization with a single primitive. +// But it's not always available: it requires the VK_KHR_timeline_semaphore +// or Vulkan 1.2. When it's not available, it can be emulated via `VkFence`s +// and binary `VkSemaphore`s. The emulation need to provide the functionality of +// timeline semaphores and also not violate the usage requirements of `VkFence`s +// and binary `VkSemaphore`s. +// +// The basic idea is to create a timeline object with time points to emulate the +// timeline semaphore, which consists of a monotonically increasing 64-bit +// integer value. Each time point represents a specific signaled/waited integer +// value of the timeline semaphore; each time point can associate with binary +// `VkSemaphore`s and/or `VkFence`s for emulating the synchronization. +// +// Concretely, for each of the possible signal -> wait scenarios timeline +// semaphore supports: +// +// ### GPU -> GPU (via `vkQueueSubmit`) +// +// Each `vkQueueSubmit` can attach a `VkTimelineSemaphoreSubmitInfo` to describe +// the timeline semaphore values signaled and waited. Each of the signaled value +// will be a time point and emulated by a binary `VkSemaphore`. We submit the +// binary `VkSemahpore`s to the GPU under the hood. For the waited values, the +// situation is more complicated because of the differences between binary and +// timeline semaphores: +// +// * Binary semaphore signal-wait relationship is strictly 1:1, unlike timeline +// semaphore where we can have 1:N cases. This means for a specific binary +// `VkSemaphore` used to emulate a signaled time point, we can have at most +// one subsequent `vkQueueSubmit` waits on it. We need other mechanisms for +// additional waits. A simple way is to involve the CPU and don't sumbit +// the additional work to queue until the desired value is already signaled +// past. This requires `VkFence`s for letting the CPU know the status of +// GPU progress, but `VkFence` is needed anyway because of GPU -> CPU +// synchronization. +// * Binary semaphores does not support wait-before-signal submission order. +// This means we need to put the submission into a self-managed queue if the +// binary semaphores used to emulate the time points waited by the submission +// are not submitted to GPU yet. +// +// ### GPU -> CPU (via `vkWaitSemaphores`) +// +// Without timeline semaphore, we need to use fences to let CPU wait on GPU +// progress. So this direction can be emulated by `vkWaitFences`. It means we +// need to associate a `VkFence` with the given waited timeline semaphores. +// Because we don't know whether a particular `vkQueueSubmit` with timeline +// semaphores will be later waited on by CPU beforehand, we need to bundle each +// of them with a `VkFence` just in case they will be waited on later. +// +// ### CPU -> GPU (via `vkSignalSemaphore`) +// +// This direction can be handled by bumping the signaled timeline value and +// scan the self-managed queue to submit more work to GPU if possible. +// +// ### CPU -> CPU (via `vkWaitSemaphores`) +// +// This is similar to CPU -> GPU direction; we just need to enable other threads +// on CPU side and let them progress. +// +// The implementation is inspired by the Vulkan-ExtensionLayer project: +// https://github.com/KhronosGroup/Vulkan-ExtensionLayer. We don't handle all +// the aspects of the full spec though given that IREE only uses a subset of +// synchronization primitives. So this should not be treated as a full +// emulation of the Vulkan spec and thus does not substitute +// Vulkan-ExtensionLayer. +iree_status_t iree_hal_vulkan_emulated_semaphore_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree::hal::vulkan::TimePointSemaphorePool* semaphore_pool, + iree_host_size_t command_queue_count, + iree::hal::vulkan::CommandQueue** command_queues, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + +// Acquires a binary semaphore for waiting on the timeline to advance to the +// given |value|. The semaphore returned won't be waited by anyone else. +// |wait_fence| is the fence associated with the queue submission that waiting +// on this semaphore. +// +// Returns VK_NULL_HANDLE if there are no available semaphores for the given +// |value|. +iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_wait_handle( + iree_hal_semaphore_t* semaphore, uint64_t value, + const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& wait_fence, + VkSemaphore* out_handle); + +// Cancels the waiting attempt on the given binary |semaphore|. This allows +// the |semaphore| to be waited by others. +iree_status_t iree_hal_vulkan_emulated_semaphore_cancel_wait_handle( + iree_hal_semaphore_t* semaphore, VkSemaphore handle); + +// Acquires a binary semaphore for signaling the timeline to the given |value|. +// |value| must be smaller than the current timeline value. |signal_fence| is +// the fence associated with the queue submission that signals this semaphore. +iree_status_t iree_hal_vulkan_emulated_semaphore_acquire_signal_handle( + iree_hal_semaphore_t* semaphore, uint64_t value, + const iree::ref_ptr<iree::hal::vulkan::TimePointFence>& signal_fence, + VkSemaphore* out_handle); + +// Performs a multi-wait on one or more semaphores. +// By default this is an all-wait but |wait_flags| may contain +// VK_SEMAPHORE_WAIT_ANY_BIT to change to an any-wait. +// +// Returns IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before +// |deadline_ns| elapses. +iree_status_t iree_hal_vulkan_emulated_semaphore_multi_wait( + iree::hal::vulkan::VkDeviceHandle* logical_device, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + VkSemaphoreWaitFlags wait_flags); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_ENUMLATED_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.cc b/iree/hal/vulkan/emulated_timeline_semaphore.cc deleted file mode 100644 index 475aa33..0000000 --- a/iree/hal/vulkan/emulated_timeline_semaphore.cc +++ /dev/null
@@ -1,380 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" - -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/time.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// static -StatusOr<ref_ptr<Semaphore>> EmulatedTimelineSemaphore::Create( - ref_ptr<VkDeviceHandle> logical_device, - std::function<Status(Semaphore*)> on_semaphore_signal, - std::function<void(Semaphore*)> on_semaphore_failure, - std::function<void(absl::Span<VkFence>)> on_fence_signal, - ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Create"); - return make_ref<EmulatedTimelineSemaphore>( - std::move(logical_device), std::move(on_semaphore_signal), - std::move(on_semaphore_failure), std::move(on_fence_signal), - std::move(semaphore_pool), initial_value); -} - -EmulatedTimelineSemaphore::EmulatedTimelineSemaphore( - ref_ptr<VkDeviceHandle> logical_device, - std::function<Status(Semaphore*)> on_semaphore_signal, - std::function<void(Semaphore*)> on_semaphore_failure, - std::function<void(absl::Span<VkFence>)> on_fence_signal, - ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value) - : signaled_value_(initial_value), - logical_device_(std::move(logical_device)), - on_semaphore_signal_(std::move(on_semaphore_signal)), - on_semaphore_failure_(std::move(on_semaphore_failure)), - on_fence_signal_(std::move(on_fence_signal)), - semaphore_pool_(std::move(semaphore_pool)) {} - -EmulatedTimelineSemaphore::~EmulatedTimelineSemaphore() { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::dtor"); - IREE_CHECK_OK(TryToAdvanceTimeline(UINT64_MAX).status()); - absl::MutexLock lock(&mutex_); - IREE_CHECK(outstanding_semaphores_.empty()) - << "Destroying an emulated timeline semaphore without first waiting on " - "outstanding signals"; -} - -StatusOr<uint64_t> EmulatedTimelineSemaphore::Query() { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Query"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Query"; - IREE_ASSIGN_OR_RETURN(bool signaled, TryToAdvanceTimeline(UINT64_MAX)); - (void)signaled; - uint64_t value = signaled_value_.load(); - IREE_DVLOG(2) << "Current timeline value: " << value; - if (value == UINT64_MAX) { - absl::MutexLock lock(&mutex_); - return status_; - } - return value; -} - -Status EmulatedTimelineSemaphore::Signal(uint64_t value) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Signal"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Signal"; - auto signaled_value = signaled_value_.exchange(value); - IREE_DVLOG(2) << "Previous value: " << signaled_value - << "; new value: " << value; - // Make sure the previous signaled value is smaller than the new value. - IREE_CHECK(signaled_value < value) - << "Attempting to signal a timeline value out of order; trying " << value - << " but " << signaled_value << " already signaled"; - - // Inform the device to make progress given we have a new value signaled now. - IREE_RETURN_IF_ERROR(on_semaphore_signal_(this)); - - return OkStatus(); -} - -Status EmulatedTimelineSemaphore::Wait(uint64_t value, Time deadline_ns) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::Wait"; - - VkFence fence = VK_NULL_HANDLE; - do { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Wait#loop"); - // First try to advance the timeline without blocking to see whether we've - // already reached the desired value. - IREE_ASSIGN_OR_RETURN(bool reached_desired_value, - TryToAdvanceTimeline(value)); - if (reached_desired_value) return OkStatus(); - - // We must wait now. Find the first emulated time point that has a value >= - // the desired value so we can wait on its associated signal fence to make - // sure the timeline is advanced to the desired value. - absl::MutexLock lock(&mutex_); - auto semaphore = outstanding_semaphores_.begin(); - for (; semaphore != outstanding_semaphores_.end(); ++semaphore) { - if ((*semaphore)->value >= value) break; - } - if (semaphore != outstanding_semaphores_.end()) { - if (!(*semaphore)->signal_fence) { - return InternalErrorBuilder(IREE_LOC) - << "Timeline should have a signal fence for the first time " - "point beyond the signaled value"; - } - IREE_DVLOG(2) << "Found timepoint semaphore " << *semaphore - << " (value: " << (*semaphore)->value - << ") to wait for desired timeline value: " << value; - fence = (*semaphore)->signal_fence->value(); - // Found; we can break the loop and proceed to waiting now. - break; - } - // TODO(antiagainst): figure out a better way instead of the busy loop here. - } while (Now() < deadline_ns); - - if (fence == VK_NULL_HANDLE) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline reached when waiting timeline semaphore"; - } - - uint64_t timeout_ns = - static_cast<uint64_t>(DeadlineToRelativeTimeoutNanos(deadline_ns)); - VK_RETURN_IF_ERROR(logical_device_->syms()->vkWaitForFences( - *logical_device_, /*fenceCount=*/1, &fence, /*waitAll=*/true, - timeout_ns)); - - return TryToAdvanceTimeline(value).status(); -} - -void EmulatedTimelineSemaphore::Fail(Status status) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::Fail"); - absl::MutexLock lock(&mutex_); - status_ = std::move(status); - signaled_value_.store(UINT64_MAX); -} - -VkSemaphore EmulatedTimelineSemaphore::GetWaitSemaphore( - uint64_t value, const ref_ptr<TimePointFence>& wait_fence) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetWaitSemaphore"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetWaitSemaphore"; - - absl::MutexLock lock(&mutex_); - - VkSemaphore semaphore = VK_NULL_HANDLE; - for (TimePointSemaphore* point : outstanding_semaphores_) { - if (point->value > value && point->wait_fence) { - point->wait_fence = add_ref(wait_fence); - semaphore = point->semaphore; - break; - } - } - - IREE_DVLOG(2) << "Binary VkSemaphore to wait on for timeline value (" << value - << ") and wait fence (" << wait_fence.get() - << "): " << semaphore; - - return semaphore; -} - -Status EmulatedTimelineSemaphore::CancelWaitSemaphore(VkSemaphore semaphore) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::CancelWaitSemaphore"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::CancelWaitSemaphore"; - - absl::MutexLock lock(&mutex_); - for (TimePointSemaphore* point : outstanding_semaphores_) { - if (point->semaphore != semaphore) continue; - - if (!point->wait_fence) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Time point wasn't waited before"; - } - point->wait_fence = nullptr; - IREE_DVLOG(2) << "Cancelled waiting on binary VkSemaphore: " << semaphore; - return OkStatus(); - } - return InvalidArgumentErrorBuilder(IREE_LOC) - << "No time point for the given semaphore"; -} - -StatusOr<VkSemaphore> EmulatedTimelineSemaphore::GetSignalSemaphore( - uint64_t value, const ref_ptr<TimePointFence>& signal_fence) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::GetSignalSemaphore"); - IREE_DVLOG(2) << "EmulatedTimelineSemaphore::GetSignalSemaphore"; - - if (signaled_value_.load() >= value) { - return FailedPreconditionErrorBuilder(IREE_LOC) - << "Timeline semaphore already signaled past " << value; - } - - absl::MutexLock lock(&mutex_); - - auto insertion_point = outstanding_semaphores_.begin(); - while (insertion_point != outstanding_semaphores_.end()) { - if ((*insertion_point)->value > value) break; - } - - IREE_ASSIGN_OR_RETURN(TimePointSemaphore * semaphore, - semaphore_pool_->Acquire()); - semaphore->value = value; - semaphore->signal_fence = add_ref(signal_fence); - if (semaphore->wait_fence) { - return InternalErrorBuilder(IREE_LOC) - << "Newly acquired time point semaphore should not have waiters"; - } - outstanding_semaphores_.insert(insertion_point, semaphore); - IREE_DVLOG(2) << "Timepoint semaphore to signal for timeline value (" << value - << ") and wait fence (" << signal_fence.get() - << "): " << semaphore - << " (binary VkSemaphore: " << semaphore->semaphore << ")"; - - return semaphore->semaphore; -} - -StatusOr<bool> EmulatedTimelineSemaphore::TryToAdvanceTimeline( - uint64_t to_upper_value) { - absl::InlinedVector<VkFence, 4> signaled_fences; - auto status = TryToAdvanceTimeline(to_upper_value, &signaled_fences); - // Inform the queue that some fences are known to have signaled. This should - // happen here instead of inside the other TryToAdvanceTimeline to avoid - // potential mutex deadlock, given here we are not holding a mutex anymore. - if (!signaled_fences.empty()) { - on_fence_signal_(absl::MakeSpan(signaled_fences)); - } - return status; -} - -StatusOr<bool> EmulatedTimelineSemaphore::TryToAdvanceTimeline( - uint64_t to_upper_value, absl::InlinedVector<VkFence, 4>* signaled_fences) { - IREE_TRACE_SCOPE0("EmulatedTimelineSemaphore::TryToAdvanceTimeline"); - IREE_DVLOG(3) << "EmulatedTimelineSemaphore::TryToAdvanceTimeline"; - - uint64_t past_value = signaled_value_.load(); - IREE_DVLOG(3) << "Current timeline value: " << past_value - << "; desired timeline value: " << to_upper_value; - - // Fast path for when already signaled past the desired value. - if (past_value >= to_upper_value) return true; - - // We hold the lock during the entire resolve process so that we can resolve - // to the furthest possible value. - absl::MutexLock lock(&mutex_); - - IREE_DVLOG(3) << "# outstanding semaphores: " - << outstanding_semaphores_.size(); - - // The timeline has not signaled past the desired value and there is no - // binary semaphore pending on GPU yet: certainly the timeline cannot - // advance to the desired value. - if (outstanding_semaphores_.empty()) return false; - - IntrusiveList<TimePointSemaphore> resolved_semaphores; - - auto clear_signal_fence = [&signaled_fences](ref_ptr<TimePointFence>& fence) { - if (fence) { - if (signaled_fences) signaled_fences->push_back(fence->value()); - fence = nullptr; - } - }; - - bool keep_resolving = true; - bool reached_desired_value = false; - while (keep_resolving && !outstanding_semaphores_.empty()) { - auto* semaphore = outstanding_semaphores_.front(); - IREE_DVLOG(3) << "Looking at timepoint semaphore " << semaphore << ".."; - IREE_DVLOG(3) << " value: " << semaphore->value; - IREE_DVLOG(3) << " VkSemaphore: " << semaphore->semaphore; - IREE_DVLOG(3) << " signal fence: " << semaphore->signal_fence.get(); - IREE_DVLOG(3) << " wait fence: " << semaphore->wait_fence.get(); - - // If the current semaphore is for a value beyond our upper limit, then - // early exit so that we don't spend time dealing with signals we don't yet - // care about. This can prevent live lock where one thread is signaling - // fences as fast/faster than another thread can consume them. - if (semaphore->value > to_upper_value) { - keep_resolving = false; - reached_desired_value = true; - break; - } - - // If the current semaphore is for a value not greater than the past - // signaled value, then we know it was signaled previously. But there might - // be a waiter on it on GPU. - if (semaphore->value <= past_value) { - if (semaphore->signal_fence) { - return InternalErrorBuilder(IREE_LOC) - << "Timeline should already signaled past this time point and " - "cleared the signal fence"; - } - - // If ther is no waiters, we can recycle this semaphore now. If there - // exists one waiter, then query its status and recycle on success. We - // only handle success status here. Others will be handled when the fence - // is checked for other semaphores' signaling status for the same queue - // submission. - if (!semaphore->wait_fence || - semaphore->wait_fence->GetStatus() == VK_SUCCESS) { - clear_signal_fence(semaphore->signal_fence); - semaphore->wait_fence = nullptr; - outstanding_semaphores_.erase(semaphore); - resolved_semaphores.push_back(semaphore); - IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore; - } - - continue; - } - - // This semaphore represents a value gerater than the known previously - // signaled value. We don't know its status so we need to really query now. - - if (!semaphore->signal_fence) { - return InternalErrorBuilder(IREE_LOC) - << "The status of this time point in the timeline should still be " - "pending with a singal fence"; - } - VkResult signal_status = semaphore->signal_fence->GetStatus(); - - switch (signal_status) { - case VK_SUCCESS: - IREE_DVLOG(3) << "..semaphore signaled"; - signaled_value_.store(semaphore->value); - clear_signal_fence(semaphore->signal_fence); - // If no waiters, we can recycle this semaphore now. - if (!semaphore->wait_fence) { - semaphore->wait_fence = nullptr; - outstanding_semaphores_.erase(semaphore); - resolved_semaphores.push_back(semaphore); - IREE_DVLOG(3) << "Resolved and recycling semaphore " << semaphore; - } - break; - case VK_NOT_READY: - // The fence has not been signaled yet so this is the furthest time - // point we can go in this timeline. - keep_resolving = false; - IREE_DVLOG(3) << "..semaphore not yet signaled"; - break; - default: - // Fence indicates an error (device lost, out of memory, etc). - // Propagate this back to our status (and thus any waiters). - // Since we only take the first error we find we skip all remaining - // fences. - keep_resolving = false; - clear_signal_fence(semaphore->signal_fence); - status_ = VkResultToStatus(signal_status, IREE_LOC); - signaled_value_.store(UINT64_MAX); - break; - } - } - - IREE_DVLOG(3) << "Releasing " << resolved_semaphores.size() - << " resolved semaphores; " << outstanding_semaphores_.size() - << " still outstanding"; - semaphore_pool_->ReleaseResolved(&resolved_semaphores); - if (!status_.ok()) { - on_semaphore_failure_(this); - semaphore_pool_->ReleaseUnresolved(&outstanding_semaphores_); - return status_; - } - - return reached_desired_value; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/emulated_timeline_semaphore.h b/iree/hal/vulkan/emulated_timeline_semaphore.h deleted file mode 100644 index 39cd569..0000000 --- a/iree/hal/vulkan/emulated_timeline_semaphore.h +++ /dev/null
@@ -1,236 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_ -#define IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <atomic> -#include <vector> - -#include "absl/base/thread_annotations.h" -#include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" -#include "iree/base/intrusive_list.h" -#include "iree/base/ref_ptr.h" -#include "iree/base/status.h" -#include "iree/hal/cc/semaphore.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/timepoint_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A timeline semaphore emulated via `VkFence`s and binary `VkSemaphore`s. -// -// Vulkan provides several explicit synchronization primitives: fences, -// (binary/timeline) semaphores, events, pipeline barriers, and render passes. -// See "6. Synchronization and Cache Control" of the Vulkan specification -// for the details. -// -// Render passes are for graphics pipelines so IREE does not care about them. -// Pipeline barriers synchronize control within a command buffer at a single -// point. Fences, (binary/timeline) semaphores, and events are synchronization -// primitives that have separate signal and wait operations. Events are more -// fine-grained compared to fences and semaphores given that they can be -// signaled or waited within a command buffer while fences and semaphores are -// at queue submissions. Each of them have its usage requirements: -// -// * Fences must be signaled on GPU and waited on CPU. Fences must be reset -// before reuse. -// * Binary semaphores must be signaled on GPU and waited on GPU. They do not -// support wait-before-signal submission order. More importantly, binary -// semaphore wait also unsignals the semaphore. So binary semaphore signals -// and waits should occur in discrete 1:1 pairs. -// * Timeline semaphores can be signaled on CPU or GPU and waited on CPU or GPU. -// They support wait-before-signal submission order. Timeline semaphores do -// not need to be reset. -// -// It's clear that timeline semaphore is more flexible than fences and binary -// semaphores: it unifies GPU and CPU synchronization with a single primitive. -// But it's not always available: it requires the VK_KHR_timeline_semaphore -// or Vulkan 1.2. When it's not available, it can be emulated via `VkFence`s -// and binary `VkSemaphore`s. The emulation need to provide the functionality of -// timeline semaphores and also not violate the usage requirements of `VkFence`s -// and binary `VkSemaphore`s. -// -// The basic idea is to create a timeline object with time points to emulate the -// timeline semaphore, which consists of a monotonically increasing 64-bit -// integer value. Each time point represents a specific signaled/waited integer -// value of the timeline semaphore; each time point can associate with binary -// `VkSemaphore`s and/or `VkFence`s for emulating the synchronization. -// -// Concretely, for each of the possible signal -> wait scenarios timeline -// semaphore supports: -// -// ### GPU -> GPU (via `vkQueueSubmit`) -// -// Each `vkQueueSubmit` can attach a `VkTimelineSemaphoreSubmitInfo` to describe -// the timeline semaphore values signaled and waited. Each of the signaled value -// will be a time point and emulated by a binary `VkSemaphore`. We submit the -// binary `VkSemahpore`s to the GPU under the hood. For the waited values, the -// situation is more complicated because of the differences between binary and -// timeline semaphores: -// -// * Binary semaphore signal-wait relationship is strictly 1:1, unlike timeline -// semaphore where we can have 1:N cases. This means for a specific binary -// `VkSemaphore` used to emulate a signaled time point, we can have at most -// one subsequent `vkQueueSubmit` waits on it. We need other mechanisms for -// additional waits. A simple way is to involve the CPU and don't sumbit -// the additional work to queue until the desired value is already signaled -// past. This requires `VkFence`s for letting the CPU know the status of -// GPU progress, but `VkFence` is needed anyway because of GPU -> CPU -// synchronization. -// * Binary semaphores does not support wait-before-signal submission order. -// This means we need to put the submission into a self-managed queue if the -// binary semaphores used to emulate the time points waited by the submission -// are not submitted to GPU yet. -// -// ### GPU -> CPU (via `vkWaitSemaphores`) -// -// Without timeline semaphore, we need to use fences to let CPU wait on GPU -// progress. So this direction can be emulated by `vkWaitFences`. It means we -// need to associate a `VkFence` with the given waited timeline semaphores. -// Because we don't know whether a particular `vkQueueSubmit` with timeline -// semaphores will be later waited on by CPU beforehand, we need to bundle each -// of them with a `VkFence` just in case they will be waited on later. -// -// ### CPU -> GPU (via `vkSignalSemaphore`) -// -// This direction can be handled by bumping the signaled timeline value and -// scan the self-managed queue to submit more work to GPU if possible. -// -// ### CPU -> CPU (via `vkWaitSemaphores`) -// -// This is similar to CPU -> GPU direction; we just need to enable other threads -// on CPU side and let them progress. -// -// The implementation is inspired by the Vulkan-ExtensionLayer project: -// https://github.com/KhronosGroup/Vulkan-ExtensionLayer. We don't handle all -// the aspects of the full spec though given that IREE only uses a subset of -// synchronization primitives. So this should not be treated as a full -// emulation of the Vulkan spec and thus does not substitute -// Vulkan-ExtensionLayer. -class EmulatedTimelineSemaphore final : public Semaphore { - public: - // Creates a timeline semaphore with the given |initial_value|. - static StatusOr<ref_ptr<Semaphore>> Create( - ref_ptr<VkDeviceHandle> logical_device, - std::function<Status(Semaphore*)> on_semaphore_signal, - std::function<void(Semaphore*)> on_semaphore_failure, - std::function<void(absl::Span<VkFence>)> on_fence_signal, - ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value); - - EmulatedTimelineSemaphore( - ref_ptr<VkDeviceHandle> logical_device, - std::function<Status(Semaphore*)> on_semaphore_signal, - std::function<void(Semaphore*)> on_semaphore_failure, - std::function<void(absl::Span<VkFence>)> on_fence_signal, - ref_ptr<TimePointSemaphorePool> semaphore_pool, uint64_t initial_value); - - ~EmulatedTimelineSemaphore() override; - - StatusOr<uint64_t> Query() override; - - Status Signal(uint64_t value) override; - - Status Wait(uint64_t value, Time deadline_ns) override; - - void Fail(Status status) override; - - // Gets a binary semaphore for waiting on the timeline to advance to the given - // |value|. The semaphore returned won't be waited by anyone else. Returns - // VK_NULL_HANDLE if no available semaphores for the given |value|. - // |wait_fence| is the fence associated with the queue submission that waiting - // on this semaphore. - VkSemaphore GetWaitSemaphore(uint64_t value, - const ref_ptr<TimePointFence>& wait_fence); - - // Cancels the waiting attempt on the given binary |semaphore|. This allows - // the |semaphore| to be waited by others. - Status CancelWaitSemaphore(VkSemaphore semaphore); - - // Gets a binary semaphore for signaling the timeline to the given |value|. - // |value| must be smaller than the current timeline value. |signal_fence| is - // the fence associated with the queue submission that signals this semaphore. - StatusOr<VkSemaphore> GetSignalSemaphore( - uint64_t value, const ref_ptr<TimePointFence>& signal_fence); - - private: - // Tries to advance the timeline to the given |to_upper_value| without - // blocking and returns whether the |to_upper_value| is reached. - StatusOr<bool> TryToAdvanceTimeline(uint64_t to_upper_value) - ABSL_LOCKS_EXCLUDED(mutex_); - // Similar to the above, but also returns the fences that are known to have - // already signaled via |signaled_fences|. - StatusOr<bool> TryToAdvanceTimeline( - uint64_t to_upper_value, absl::InlinedVector<VkFence, 4>* signaled_fences) - ABSL_LOCKS_EXCLUDED(mutex_); - - std::atomic<uint64_t> signaled_value_; - - ref_ptr<VkDeviceHandle> logical_device_; - - // Callback to inform that this timeline semaphore has signaled a new value. - std::function<Status(Semaphore*)> on_semaphore_signal_; - - // Callback to inform that this timeline semaphore has encountered a failure. - std::function<void(Semaphore*)> on_semaphore_failure_; - - // Callback to inform that the given fences have signaled. - std::function<void(absl::Span<VkFence>)> on_fence_signal_; - - ref_ptr<TimePointSemaphorePool> semaphore_pool_; - - mutable absl::Mutex mutex_; - - // A list of outstanding semaphores used to emulate time points. - // - // The life time of each semaphore is in one of the following state: - // - // * Unused state: value = UINT64_MAX, signal/wait fence = nullptr. This is - // the state of the semaphore when it's initially acquired from the pool and - // not put in the queue for emulating a time point yet. - // * Pending state: signaled value < value < UINT64_MAX, signal fence = - // <some-fence>, wait fence == nullptr. This is the state of the semaphore - // when it's put into the GPU queue for emulating a time point. - // * Pending and waiting state: signaled value < value < UINT64_MAX, signal - // fence = <some-fence>, wait fence == <some-fence>. This is the state of - // the semaphore when it's put into the GPU queue for emulating a time - // point and there is another queue submission waiting on it in GPU. - // * Signaled and not ever waited state: value <= signaled value, singal/wait - // fence = nullptr. This is the state of the semaphore when we know it's - // already signaled on GPU and there is no waiters for it. - // * Signaled and waiting state: value <= signaled value, signal fence = - // nullptr, wait fence = <some-fence>. This is the state of the semaphore - // when we know it's already signaled on GPU and there is still one queue - // submission on GPU is waiting for it. - IntrusiveList<TimePointSemaphore> outstanding_semaphores_ - ABSL_GUARDED_BY(mutex_); - - // NOTE: We only need to access this status (and thus take the lock) when we - // want to either signal failure or query the status in the case of the - // semaphore being set to UINT64_MAX. - Status status_ ABSL_GUARDED_BY(mutex_); -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_ENUMLATED_TIMELINE_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/extensibility_util.cc b/iree/hal/vulkan/extensibility_util.cc index d78892b..7320cd3 100644 --- a/iree/hal/vulkan/extensibility_util.cc +++ b/iree/hal/vulkan/extensibility_util.cc
@@ -14,195 +14,213 @@ #include "iree/hal/vulkan/extensibility_util.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" #include "iree/hal/vulkan/status_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -StatusOr<std::vector<const char*>> MatchAvailableLayers( - absl::Span<const char* const> required_layers, - absl::Span<const char* const> optional_layers, - absl::Span<const VkLayerProperties> properties) { - IREE_TRACE_SCOPE0("MatchAvailableLayers"); - - std::vector<const char*> enabled_layers; - enabled_layers.reserve(required_layers.size() + optional_layers.size()); - - for (const char* layer_name : required_layers) { - bool found = false; - for (const auto& layer_properties : properties) { - if (std::strcmp(layer_name, layer_properties.layerName) == 0) { - IREE_VLOG(1) << "Enabling required layer: " << layer_name; - found = true; - enabled_layers.push_back(layer_name); - break; - } - } - if (!found) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required layer " << layer_name << " not available"; +// Returns true if |layers| contains a layer matching |layer_name|. +static bool iree_hal_vulkan_layer_list_contains(uint32_t layer_count, + const VkLayerProperties* layers, + const char* layer_name) { + for (uint32_t i = 0; i < layer_count; ++i) { + if (strcmp(layer_name, layers[i].layerName) == 0) { + return true; } } - - for (const char* layer_name : optional_layers) { - bool found = false; - for (const auto& layer_properties : properties) { - if (std::strcmp(layer_name, layer_properties.layerName) == 0) { - IREE_VLOG(1) << "Enabling optional layer: " << layer_name; - found = true; - enabled_layers.push_back(layer_name); - break; - } - } - if (!found) { - IREE_VLOG(1) << "Optional layer " << layer_name << " not available"; - } - } - - return enabled_layers; + return false; } -StatusOr<std::vector<const char*>> MatchAvailableExtensions( - absl::Span<const char* const> required_extensions, - absl::Span<const char* const> optional_extensions, - absl::Span<const VkExtensionProperties> properties) { - IREE_TRACE_SCOPE0("MatchAvailableExtensions"); +static iree_status_t iree_hal_vulkan_match_available_layers( + iree_host_size_t available_layers_count, + const VkLayerProperties* available_layers, + const iree_hal_vulkan_string_list_t* required_layers, + const iree_hal_vulkan_string_list_t* optional_layers, + iree_hal_vulkan_string_list_t* out_enabled_layers) { + memset(out_enabled_layers->values, 0, + (required_layers->count + optional_layers->count) * + sizeof(out_enabled_layers->values[0])); - std::vector<const char*> enabled_extensions; - enabled_extensions.reserve(required_extensions.size() + - optional_extensions.size()); - - for (const char* extension_name : required_extensions) { - bool found = false; - for (const auto& extension_properties : properties) { - if (std::strcmp(extension_name, extension_properties.extensionName) == - 0) { - IREE_VLOG(1) << "Enabling required extension: " << extension_name; - found = true; - enabled_extensions.push_back(extension_name); - break; - } + for (iree_host_size_t i = 0; i < required_layers->count; ++i) { + const char* layer_name = required_layers->values[i]; + if (!iree_hal_vulkan_layer_list_contains(available_layers_count, + available_layers, layer_name)) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "required layer %s not available", layer_name); } - if (!found) { - return UnavailableErrorBuilder(IREE_LOC) - << "Required extension " << extension_name << " not available"; + out_enabled_layers->values[out_enabled_layers->count++] = layer_name; + } + + for (iree_host_size_t i = 0; i < optional_layers->count; ++i) { + const char* layer_name = optional_layers->values[i]; + if (iree_hal_vulkan_layer_list_contains(available_layers_count, + available_layers, layer_name)) { + out_enabled_layers->values[out_enabled_layers->count++] = layer_name; } } - for (const char* extension_name : optional_extensions) { - bool found = false; - for (const auto& extension_properties : properties) { - if (std::strcmp(extension_name, extension_properties.extensionName) == - 0) { - IREE_VLOG(1) << "Enabling optional extension: " << extension_name; - found = true; - enabled_extensions.push_back(extension_name); - break; - } - } - if (!found) { - IREE_VLOG(1) << "Optional extension " << extension_name - << " not available"; - } - } - - return enabled_extensions; + return iree_ok_status(); } -} // namespace - -StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { +iree_status_t iree_hal_vulkan_match_available_instance_layers( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_layers, + const iree_hal_vulkan_string_list_t* optional_layers, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_enabled_layers) { uint32_t layer_property_count = 0; VK_RETURN_IF_ERROR( - syms.vkEnumerateInstanceLayerProperties(&layer_property_count, nullptr)); - std::vector<VkLayerProperties> layer_properties(layer_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceLayerProperties( - &layer_property_count, layer_properties.data())); - IREE_ASSIGN_OR_RETURN(auto enabled_layers, - MatchAvailableLayers(extensibility_spec.required_layers, - extensibility_spec.optional_layers, - layer_properties), - _ << "Unable to find all required instance layers"); - return enabled_layers; + syms->vkEnumerateInstanceLayerProperties(&layer_property_count, NULL), + "vkEnumerateInstanceLayerProperties"); + VkLayerProperties* layer_properties = + (VkLayerProperties*)arena->AllocateBytes(layer_property_count * + sizeof(VkLayerProperties)); + VK_RETURN_IF_ERROR(syms->vkEnumerateInstanceLayerProperties( + &layer_property_count, layer_properties), + "vkEnumerateInstanceLayerProperties"); + out_enabled_layers->count = 0; + out_enabled_layers->values = (const char**)arena->AllocateBytes( + (required_layers->count + optional_layers->count) * + sizeof(out_enabled_layers->values[0])); + return iree_hal_vulkan_match_available_layers( + layer_property_count, layer_properties, required_layers, optional_layers, + out_enabled_layers); } -StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { +// Returns true if |extensions| contains a layer matching |extension_name|. +static bool iree_hal_vulkan_extension_list_contains( + uint32_t extension_count, const VkExtensionProperties* extensions, + const char* extension_name) { + for (uint32_t i = 0; i < extension_count; ++i) { + if (strcmp(extension_name, extensions[i].extensionName) == 0) { + return true; + } + } + return false; +} + +static iree_status_t iree_hal_vulkan_match_available_extensions( + iree_host_size_t available_extension_count, + const VkExtensionProperties* available_extensions, + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree_hal_vulkan_string_list_t* out_enabled_extensions) { + memset(out_enabled_extensions->values, 0, + (required_extensions->count + optional_extensions->count) * + sizeof(out_enabled_extensions->values[0])); + + for (iree_host_size_t i = 0; i < required_extensions->count; ++i) { + const char* extension_name = required_extensions->values[i]; + if (!iree_hal_vulkan_extension_list_contains( + available_extension_count, available_extensions, extension_name)) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "required extension %s not available", + extension_name); + } + out_enabled_extensions->values[out_enabled_extensions->count++] = + extension_name; + } + + for (iree_host_size_t i = 0; i < optional_extensions->count; ++i) { + const char* extension_name = optional_extensions->values[i]; + if (iree_hal_vulkan_extension_list_contains( + available_extension_count, available_extensions, extension_name)) { + out_enabled_extensions->values[out_enabled_extensions->count++] = + extension_name; + } + } + + return iree_ok_status(); +} + +iree_status_t iree_hal_vulkan_match_available_instance_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions) { uint32_t extension_property_count = 0; - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( - nullptr, &extension_property_count, nullptr)); - std::vector<VkExtensionProperties> extension_properties( - extension_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateInstanceExtensionProperties( - nullptr, &extension_property_count, extension_properties.data())); - IREE_ASSIGN_OR_RETURN( - auto enabled_extensions, - MatchAvailableExtensions(extensibility_spec.required_extensions, - extensibility_spec.optional_extensions, - extension_properties), - _ << "Unable to find all required instance extensions"); - IREE_ENABLE_LEAK_CHECKS(); - return enabled_extensions; + VK_RETURN_IF_ERROR(syms->vkEnumerateInstanceExtensionProperties( + NULL, &extension_property_count, NULL), + "vkEnumerateInstanceExtensionProperties"); + VkExtensionProperties* extension_properties = + (VkExtensionProperties*)arena->AllocateBytes( + extension_property_count * sizeof(VkExtensionProperties)); + VK_RETURN_IF_ERROR(syms->vkEnumerateInstanceExtensionProperties( + NULL, &extension_property_count, extension_properties), + "vkEnumerateInstanceExtensionProperties"); + out_enabled_extensions->count = 0; + out_enabled_extensions->values = (const char**)arena->AllocateBytes( + (required_extensions->count + optional_extensions->count) * + sizeof(out_enabled_extensions->values[0])); + return iree_hal_vulkan_match_available_extensions( + extension_property_count, extension_properties, required_extensions, + optional_extensions, out_enabled_extensions); } -StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions( +iree_status_t iree_hal_vulkan_match_available_device_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms) { + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions) { uint32_t extension_property_count = 0; - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( - physical_device, nullptr, &extension_property_count, nullptr)); - std::vector<VkExtensionProperties> extension_properties( - extension_property_count); - VK_RETURN_IF_ERROR(syms.vkEnumerateDeviceExtensionProperties( - physical_device, nullptr, &extension_property_count, - extension_properties.data())); - IREE_ASSIGN_OR_RETURN( - auto enabled_extensions, - MatchAvailableExtensions(extensibility_spec.required_extensions, - extensibility_spec.optional_extensions, - extension_properties), - _ << "Unable to find all required device extensions"); - return enabled_extensions; + VK_RETURN_IF_ERROR( + syms->vkEnumerateDeviceExtensionProperties( + physical_device, NULL, &extension_property_count, NULL), + "vkEnumerateDeviceExtensionProperties"); + VkExtensionProperties* extension_properties = + (VkExtensionProperties*)arena->AllocateBytes( + extension_property_count * sizeof(VkExtensionProperties)); + VK_RETURN_IF_ERROR(syms->vkEnumerateDeviceExtensionProperties( + physical_device, NULL, &extension_property_count, + extension_properties), + "vkEnumerateDeviceExtensionProperties"); + out_enabled_extensions->count = 0; + out_enabled_extensions->values = (const char**)arena->AllocateBytes( + (required_extensions->count + optional_extensions->count) * + sizeof(out_enabled_extensions->values[0])); + return iree_hal_vulkan_match_available_extensions( + extension_property_count, extension_properties, required_extensions, + optional_extensions, out_enabled_extensions); } -InstanceExtensions PopulateEnabledInstanceExtensions( - absl::Span<const char* const> extension_names) { - InstanceExtensions extensions = {0}; - for (const char* extension_name : extension_names) { - if (std::strcmp(extension_name, VK_EXT_DEBUG_REPORT_EXTENSION_NAME) == 0) { - extensions.debug_report = true; - } else if (std::strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == - 0) { +iree_hal_vulkan_instance_extensions_t +iree_hal_vulkan_populate_enabled_instance_extensions( + const iree_hal_vulkan_string_list_t* enabled_extensions) { + iree_hal_vulkan_instance_extensions_t extensions; + memset(&extensions, 0, sizeof(extensions)); + for (iree_host_size_t i = 0; i < enabled_extensions->count; ++i) { + const char* extension_name = enabled_extensions->values[i]; + if (strcmp(extension_name, VK_EXT_DEBUG_UTILS_EXTENSION_NAME) == 0) { extensions.debug_utils = true; } } return extensions; } -DeviceExtensions PopulateEnabledDeviceExtensions( - absl::Span<const char* const> extension_names) { - DeviceExtensions extensions = {0}; - for (const char* extension_name : extension_names) { - if (std::strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == - 0) { +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_populate_enabled_device_extensions( + const iree_hal_vulkan_string_list_t* enabled_extensions) { + iree_hal_vulkan_device_extensions_t extensions; + memset(&extensions, 0, sizeof(extensions)); + for (iree_host_size_t i = 0; i < enabled_extensions->count; ++i) { + const char* extension_name = enabled_extensions->values[i]; + if (strcmp(extension_name, VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME) == 0) { extensions.push_descriptors = true; - } else if (std::strcmp(extension_name, - VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME) == 0) { + } else if (strcmp(extension_name, + VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME) == 0) { extensions.timeline_semaphore = true; } } return extensions; } -} // namespace vulkan -} // namespace hal -} // namespace iree +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_infer_enabled_device_extensions( + const iree::hal::vulkan::DynamicSymbols* device_syms) { + iree_hal_vulkan_device_extensions_t extensions; + memset(&extensions, 0, sizeof(extensions)); + if (device_syms->vkCmdPushDescriptorSetKHR) { + extensions.push_descriptors = true; + } + if (device_syms->vkSignalSemaphore || device_syms->vkSignalSemaphoreKHR) { + extensions.timeline_semaphore = true; + } + return extensions; +}
diff --git a/iree/hal/vulkan/extensibility_util.h b/iree/hal/vulkan/extensibility_util.h index 3d9435b..c0c8ff8 100644 --- a/iree/hal/vulkan/extensibility_util.h +++ b/iree/hal/vulkan/extensibility_util.h
@@ -12,89 +12,89 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Utilities for working with layers and extensions. - #ifndef IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ #define IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <vector> - -#include "absl/types/span.h" -#include "iree/base/status.h" +#include "iree/base/arena.h" +#include "iree/hal/vulkan/api.h" #include "iree/hal/vulkan/dynamic_symbols.h" -namespace iree { -namespace hal { -namespace vulkan { +// A list of NUL-terminated strings (so they can be passed directly to Vulkan). +typedef struct { + iree_host_size_t count; + const char** values; +} iree_hal_vulkan_string_list_t; -// Describes required and optional extensibility points. -struct ExtensibilitySpec { - // A list of required and optional layers. - std::vector<const char*> required_layers; - std::vector<const char*> optional_layers; +// Populates |out_enabled_layers| with all layers that are both available in the +// implementation and |required_layers| and |optional_layers| lists. +// |out_enabled_layers| must have capacity at least the sum of +// |required_layers|.count and |optional_layer|.count. +// Returns failure if any |required_layers| are unavailable. +iree_status_t iree_hal_vulkan_match_available_instance_layers( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_layers, + const iree_hal_vulkan_string_list_t* optional_layers, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_enabled_layers); - // A list of required and optional extensions. - // Prefer using the _EXTENSION_NAME macros to make tracking easier (such as - // 'VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME'). - std::vector<const char*> required_extensions; - std::vector<const char*> optional_extensions; -}; +// Populates |out_enabled_extensions| with all extensions that are both +// available in the implementation and |required_extensions| and +// |optional_extensions| lists. |out_enabled_extensions| must have capacity at +// least the sum of |required_extensions|.count and |optional_extensions|.count. +// Returns failure if any |required_extensions| are unavailable. +iree_status_t iree_hal_vulkan_match_available_instance_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions); -// Returns a list of layer names available for instances. -// Fails if any required_layers are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableInstanceLayers( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of extension names available for instances. -// Fails if any required_extensions are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableInstanceExtensions( - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); - -// Returns a list of extension names available for the given |physical_device|. -// Fails if any required_extensions are unavailable. -StatusOr<std::vector<const char*>> MatchAvailableDeviceExtensions( +// Populates |out_enabled_extensions| with all extensions that are both +// available in the implementation and |required_extensions| and +// |optional_extensions| lists. |out_enabled_extensions| must have capacity at +// least the sum of |required_extensions|.count and |optional_extensions|.count. +// Returns failure if any |required_extensions| are unavailable. +iree_status_t iree_hal_vulkan_match_available_device_extensions( + const iree::hal::vulkan::DynamicSymbols* syms, VkPhysicalDevice physical_device, - const ExtensibilitySpec& extensibility_spec, const DynamicSymbols& syms); + const iree_hal_vulkan_string_list_t* required_extensions, + const iree_hal_vulkan_string_list_t* optional_extensions, + iree::Arena* arena, iree_hal_vulkan_string_list_t* out_enabled_extensions); // Bits for enabled instance extensions. // We must use this to query support instead of just detecting symbol names as // ICDs will resolve the functions sometimes even if they don't support the // extension (or we didn't ask for it to be enabled). -struct InstanceExtensions { - // VK_EXT_debug_report is enabled and a callback is regsitered. - // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_report - bool debug_report : 1; - +typedef struct { // VK_EXT_debug_utils is enabled and a debug messenger is registered. // https://www.khronos.org/registry/vulkan/specs/1.1-extensions/html/chap44.html#VK_EXT_debug_utils bool debug_utils : 1; -}; +} iree_hal_vulkan_instance_extensions_t; // Returns a bitfield with all of the provided extension names. -InstanceExtensions PopulateEnabledInstanceExtensions( - absl::Span<const char* const> extension_names); +iree_hal_vulkan_instance_extensions_t +iree_hal_vulkan_populate_enabled_instance_extensions( + const iree_hal_vulkan_string_list_t* enabled_extension); // Bits for enabled device extensions. // We must use this to query support instead of just detecting symbol names as // ICDs will resolve the functions sometimes even if they don't support the // extension (or we didn't ask for it to be enabled). -struct DeviceExtensions { +typedef struct { // VK_KHR_push_descriptor is enabled and vkCmdPushDescriptorSetKHR is valid. bool push_descriptors : 1; // VK_KHR_timeline_semaphore is enabled. bool timeline_semaphore : 1; -}; +} iree_hal_vulkan_device_extensions_t; // Returns a bitfield with all of the provided extension names. -DeviceExtensions PopulateEnabledDeviceExtensions( - absl::Span<const char* const> extension_names); +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_populate_enabled_device_extensions( + const iree_hal_vulkan_string_list_t* enabled_extension); -} // namespace vulkan -} // namespace hal -} // namespace iree +// Returns a bitfield with the extensions that are (likely) available on the +// device symbols. This is less reliable than setting the bits directly when +// the known set of extensions is available. +iree_hal_vulkan_device_extensions_t +iree_hal_vulkan_infer_enabled_device_extensions( + const iree::hal::vulkan::DynamicSymbols* device_syms); #endif // IREE_HAL_VULKAN_EXTENSIBILITY_UTIL_H_
diff --git a/iree/hal/vulkan/handle_util.h b/iree/hal/vulkan/handle_util.h index 2cd3f64..7df7402 100644 --- a/iree/hal/vulkan/handle_util.h +++ b/iree/hal/vulkan/handle_util.h
@@ -28,11 +28,12 @@ #include "iree/hal/vulkan/vulkan_headers.h" // clang-format on -#include "absl/synchronization/mutex.h" #include "iree/base/ref_ptr.h" #include "iree/base/status.h" +#include "iree/base/synchronization.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/extensibility_util.h" +#include "iree/hal/vulkan/status_util.h" namespace iree { namespace hal { @@ -40,13 +41,15 @@ class VkDeviceHandle : public RefObject<VkDeviceHandle> { public: - VkDeviceHandle(const ref_ptr<DynamicSymbols>& syms, - DeviceExtensions enabled_extensions, bool owns_device, + VkDeviceHandle(DynamicSymbols* syms, + iree_hal_vulkan_device_extensions_t enabled_extensions, + bool owns_device, iree_allocator_t host_allocator, const VkAllocationCallbacks* allocator = nullptr) : syms_(add_ref(syms)), enabled_extensions_(enabled_extensions), owns_device_(owns_device), - allocator_(allocator) {} + allocator_(allocator), + host_allocator_(host_allocator) {} ~VkDeviceHandle() { reset(); } VkDeviceHandle(const VkDeviceHandle&) = delete; @@ -57,7 +60,8 @@ syms_(std::move(other.syms_)), enabled_extensions_(other.enabled_extensions_), owns_device_(other.owns_device_), - allocator_(other.allocator_) {} + allocator_(other.allocator_), + host_allocator_(other.host_allocator_) {} void reset() { if (value_ == VK_NULL_HANDLE) return; @@ -73,24 +77,31 @@ const ref_ptr<DynamicSymbols>& syms() const noexcept { return syms_; } const VkAllocationCallbacks* allocator() const noexcept { return allocator_; } + iree_allocator_t host_allocator() const noexcept { return host_allocator_; } - const DeviceExtensions& enabled_extensions() const { + const iree_hal_vulkan_device_extensions_t& enabled_extensions() const { return enabled_extensions_; } private: VkDevice value_ = VK_NULL_HANDLE; ref_ptr<DynamicSymbols> syms_; - DeviceExtensions enabled_extensions_; + iree_hal_vulkan_device_extensions_t enabled_extensions_; bool owns_device_; const VkAllocationCallbacks* allocator_ = nullptr; + iree_allocator_t host_allocator_; }; -class VkCommandPoolHandle : public RefObject<VkCommandPoolHandle> { +class VkCommandPoolHandle { public: - explicit VkCommandPoolHandle(const ref_ptr<VkDeviceHandle>& logical_device) - : logical_device_(add_ref(logical_device)) {} - ~VkCommandPoolHandle() { reset(); } + explicit VkCommandPoolHandle(VkDeviceHandle* logical_device) + : logical_device_(logical_device) { + iree_slim_mutex_initialize(&mutex_); + } + ~VkCommandPoolHandle() { + reset(); + iree_slim_mutex_deinitialize(&mutex_); + } VkCommandPoolHandle(const VkCommandPoolHandle&) = delete; VkCommandPoolHandle& operator=(const VkCommandPoolHandle&) = delete; @@ -114,7 +125,7 @@ VkCommandPool* mutable_value() noexcept { return &value_; } operator VkCommandPool() const noexcept { return value_; } - const ref_ptr<VkDeviceHandle>& logical_device() const noexcept { + const VkDeviceHandle* logical_device() const noexcept { return logical_device_; } const ref_ptr<DynamicSymbols>& syms() const noexcept { @@ -124,16 +135,31 @@ return logical_device_->allocator(); } - absl::Mutex* mutex() const { return &mutex_; } + iree_status_t Allocate(const VkCommandBufferAllocateInfo* allocate_info, + VkCommandBuffer* out_handle) { + iree_slim_mutex_lock(&mutex_); + iree_status_t status = + VK_RESULT_TO_STATUS(syms()->vkAllocateCommandBuffers( + *logical_device_, allocate_info, out_handle), + "vkAllocateCommandBuffers"); + iree_slim_mutex_unlock(&mutex_); + return status; + } + + void Free(VkCommandBuffer handle) { + iree_slim_mutex_lock(&mutex_); + syms()->vkFreeCommandBuffers(*logical_device_, value_, 1, &handle); + iree_slim_mutex_unlock(&mutex_); + } private: - ref_ptr<VkDeviceHandle> logical_device_; + VkDeviceHandle* logical_device_; VkCommandPool value_ = VK_NULL_HANDLE; // Vulkan command pools are not thread safe and require external // synchronization. Since we allow arbitrary threads to allocate and // deallocate the HAL command buffers we need to externally synchronize. - mutable absl::Mutex mutex_; + iree_slim_mutex_t mutex_; }; } // namespace vulkan
diff --git a/iree/hal/vulkan/native_descriptor_set.cc b/iree/hal/vulkan/native_descriptor_set.cc index f000c1c..a047d31 100644 --- a/iree/hal/vulkan/native_descriptor_set.cc +++ b/iree/hal/vulkan/native_descriptor_set.cc
@@ -14,23 +14,80 @@ #include "iree/hal/vulkan/native_descriptor_set.h" -namespace iree { -namespace hal { -namespace vulkan { +#include "iree/base/tracing.h" -NativeDescriptorSet::NativeDescriptorSet(ref_ptr<VkDeviceHandle> logical_device, - VkDescriptorSet handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} +using namespace iree::hal::vulkan; -NativeDescriptorSet::~NativeDescriptorSet() { +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkDescriptorSet handle; +} iree_hal_vulkan_native_descriptor_set_t; + +extern const iree_hal_descriptor_set_vtable_t + iree_hal_vulkan_native_descriptor_set_vtable; + +static iree_hal_vulkan_native_descriptor_set_t* +iree_hal_vulkan_native_descriptor_set_cast( + iree_hal_descriptor_set_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_native_descriptor_set_vtable); + return (iree_hal_vulkan_native_descriptor_set_t*)base_value; +} + +iree_status_t iree_hal_vulkan_native_descriptor_set_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, VkDescriptorSet handle, + iree_hal_descriptor_set_t** out_descriptor_set) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(handle); + IREE_ASSERT_ARGUMENT(out_descriptor_set); + *out_descriptor_set = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_native_descriptor_set_t* descriptor_set = NULL; + iree_status_t status = + iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*descriptor_set), (void**)&descriptor_set); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_descriptor_set_vtable, + &descriptor_set->resource); + descriptor_set->logical_device = logical_device; + descriptor_set->handle = handle; + *out_descriptor_set = (iree_hal_descriptor_set_t*)descriptor_set; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_descriptor_set_destroy( + iree_hal_descriptor_set_t* base_descriptor_set) { + iree_hal_vulkan_native_descriptor_set_t* descriptor_set = + iree_hal_vulkan_native_descriptor_set_cast(base_descriptor_set); + iree_allocator_t host_allocator = + descriptor_set->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + // TODO(benvanik): return to pool. For now we rely on the descriptor cache to // reset entire pools at once via via vkResetDescriptorPool so we don't need // to do anything here (the VkDescriptorSet handle will just be invalidated). // In the future if we want to have generational collection/defragmentation // of the descriptor cache we'll want to allow both pooled and unpooled // descriptors and clean them up here appropriately. + + iree_allocator_free(host_allocator, descriptor_set); + + IREE_TRACE_ZONE_END(z0); } -} // namespace vulkan -} // namespace hal -} // namespace iree +VkDescriptorSet iree_hal_vulkan_native_descriptor_set_handle( + iree_hal_descriptor_set_t* base_descriptor_set) { + iree_hal_vulkan_native_descriptor_set_t* descriptor_set = + iree_hal_vulkan_native_descriptor_set_cast(base_descriptor_set); + return descriptor_set->handle; +} + +const iree_hal_descriptor_set_vtable_t + iree_hal_vulkan_native_descriptor_set_vtable = { + /*.destroy=*/iree_hal_vulkan_native_descriptor_set_destroy, +};
diff --git a/iree/hal/vulkan/native_descriptor_set.h b/iree/hal/vulkan/native_descriptor_set.h index f83662f..cf9379e 100644 --- a/iree/hal/vulkan/native_descriptor_set.h +++ b/iree/hal/vulkan/native_descriptor_set.h
@@ -15,33 +15,24 @@ #ifndef IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_ #define IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/cc/descriptor_set.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/handle_util.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -// A DescriptorSet implemented with the native VkDescriptorSet type. -class NativeDescriptorSet final : public DescriptorSet { - public: - NativeDescriptorSet(ref_ptr<VkDeviceHandle> logical_device, - VkDescriptorSet handle); - ~NativeDescriptorSet() override; +// Creates a native Vulkan VkDescriptorSet object. +iree_status_t iree_hal_vulkan_native_descriptor_set_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, VkDescriptorSet handle, + iree_hal_descriptor_set_t** out_descriptor_set); - VkDescriptorSet handle() const { return handle_; } +// Returns the native Vulkan VkDescriptorSet handle. +VkDescriptorSet iree_hal_vulkan_native_descriptor_set_handle( + iree_hal_descriptor_set_t* base_descriptor_set); - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkDescriptorSet handle_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_H_
diff --git a/iree/hal/vulkan/native_descriptor_set_layout.cc b/iree/hal/vulkan/native_descriptor_set_layout.cc new file mode 100644 index 0000000..d744b41 --- /dev/null +++ b/iree/hal/vulkan/native_descriptor_set_layout.cc
@@ -0,0 +1,156 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/vulkan/native_descriptor_set_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/status_util.h" + +using namespace iree::hal::vulkan; + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkDescriptorSetLayout handle; +} iree_hal_vulkan_native_descriptor_set_layout_t; + +extern const iree_hal_descriptor_set_layout_vtable_t + iree_hal_vulkan_native_descriptor_set_layout_vtable; + +static iree_hal_vulkan_native_descriptor_set_layout_t* +iree_hal_vulkan_native_descriptor_set_layout_cast( + iree_hal_descriptor_set_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_native_descriptor_set_layout_vtable); + return (iree_hal_vulkan_native_descriptor_set_layout_t*)base_value; +} + +static iree_status_t iree_hal_vulkan_create_descriptor_set_layout( + VkDeviceHandle* logical_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + VkDescriptorSetLayout* out_handle) { + VkDescriptorSetLayoutCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + if (usage_type == IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY && + logical_device->enabled_extensions().push_descriptors) { + // Note that we can *only* use push descriptor sets if we set this create + // flag. If push descriptors aren't supported we emulate them with normal + // descriptors so it's fine to have kPushOnly without support. + create_info.flags |= + VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; + } + + VkDescriptorSetLayoutBinding* native_bindings = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + logical_device->host_allocator(), + binding_count * sizeof(VkDescriptorSetLayoutBinding), + (void**)&native_bindings)); + for (iree_host_size_t i = 0; i < binding_count; ++i) { + VkDescriptorSetLayoutBinding* native_binding = &native_bindings[i]; + native_binding->binding = bindings[i].binding; + native_binding->descriptorType = + static_cast<VkDescriptorType>(bindings[i].type); + native_binding->descriptorCount = 1; + native_binding->stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + native_binding->pImmutableSamplers = NULL; + } + create_info.bindingCount = (uint32_t)binding_count; + create_info.pBindings = native_bindings; + + iree_status_t status = + VK_RESULT_TO_STATUS(logical_device->syms()->vkCreateDescriptorSetLayout( + *logical_device, &create_info, + logical_device->allocator(), out_handle), + "vkCreateDescriptorSetLayout"); + + iree_allocator_free(logical_device->host_allocator(), native_bindings); + return status; +} + +static void iree_hal_vulkan_destroy_descriptor_set_layout( + VkDeviceHandle* logical_device, VkDescriptorSetLayout handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyDescriptorSetLayout( + *logical_device, handle, logical_device->allocator()); +} + +iree_status_t iree_hal_vulkan_native_descriptor_set_layout_create( + VkDeviceHandle* logical_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(!binding_count || bindings); + IREE_ASSERT_ARGUMENT(out_descriptor_set_layout); + *out_descriptor_set_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + VkDescriptorSetLayout handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_descriptor_set_layout( + logical_device, usage_type, binding_count, bindings, &handle)); + + iree_hal_vulkan_native_descriptor_set_layout_t* descriptor_set_layout = NULL; + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*descriptor_set_layout), + (void**)&descriptor_set_layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize( + &iree_hal_vulkan_native_descriptor_set_layout_vtable, + &descriptor_set_layout->resource); + descriptor_set_layout->logical_device = logical_device; + descriptor_set_layout->handle = handle; + *out_descriptor_set_layout = + (iree_hal_descriptor_set_layout_t*)descriptor_set_layout; + } else { + iree_hal_vulkan_destroy_descriptor_set_layout(logical_device, handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_descriptor_set_layout_destroy( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_vulkan_native_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_vulkan_native_descriptor_set_layout_cast( + base_descriptor_set_layout); + iree_allocator_t host_allocator = + descriptor_set_layout->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_destroy_descriptor_set_layout( + descriptor_set_layout->logical_device, descriptor_set_layout->handle); + iree_allocator_free(host_allocator, descriptor_set_layout); + + IREE_TRACE_ZONE_END(z0); +} + +VkDescriptorSetLayout iree_hal_vulkan_native_descriptor_set_layout_handle( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) { + iree_hal_vulkan_native_descriptor_set_layout_t* descriptor_set_layout = + iree_hal_vulkan_native_descriptor_set_layout_cast( + base_descriptor_set_layout); + return descriptor_set_layout->handle; +} + +const iree_hal_descriptor_set_layout_vtable_t + iree_hal_vulkan_native_descriptor_set_layout_vtable = { + /*.destroy=*/iree_hal_vulkan_native_descriptor_set_layout_destroy, +};
diff --git a/iree/hal/vulkan/native_descriptor_set_layout.h b/iree/hal/vulkan/native_descriptor_set_layout.h new file mode 100644 index 0000000..d7fc86b --- /dev/null +++ b/iree/hal/vulkan/native_descriptor_set_layout.h
@@ -0,0 +1,41 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_LAYOUT_H_ +#define IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_LAYOUT_H_ + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a native Vulkan VkDescriptorSetLayout object. +iree_status_t iree_hal_vulkan_native_descriptor_set_layout_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout); + +// Returns the native Vulkan VkDescriptorSetLayout handle. +VkDescriptorSetLayout iree_hal_vulkan_native_descriptor_set_layout_handle( + iree_hal_descriptor_set_layout_t* base_descriptor_set_layout); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NATIVE_DESCRIPTOR_SET_LAYOUT_H_
diff --git a/iree/hal/vulkan/native_event.cc b/iree/hal/vulkan/native_event.cc index 28dbc56..c9a7b13 100644 --- a/iree/hal/vulkan/native_event.cc +++ b/iree/hal/vulkan/native_event.cc
@@ -14,18 +14,89 @@ #include "iree/hal/vulkan/native_event.h" -namespace iree { -namespace hal { -namespace vulkan { +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/status_util.h" -NativeEvent::NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} +using namespace iree::hal::vulkan; -NativeEvent::~NativeEvent() { - logical_device_->syms()->vkDestroyEvent(*logical_device_, handle_, - logical_device_->allocator()); +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkEvent handle; +} iree_hal_vulkan_native_event_t; + +extern const iree_hal_event_vtable_t iree_hal_vulkan_native_event_vtable; + +static iree_hal_vulkan_native_event_t* iree_hal_vulkan_native_event_cast( + iree_hal_event_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_event_vtable); + return (iree_hal_vulkan_native_event_t*)base_value; } -} // namespace vulkan -} // namespace hal -} // namespace iree +static iree_status_t iree_hal_vulkan_create_event( + VkDeviceHandle* logical_device, VkEvent* out_handle) { + VkEventCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + return VK_RESULT_TO_STATUS(logical_device->syms()->vkCreateEvent( + *logical_device, &create_info, + logical_device->allocator(), out_handle), + "vkCreateEvent"); +} + +static void iree_hal_vulkan_destroy_event(VkDeviceHandle* logical_device, + VkEvent handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyEvent(*logical_device, handle, + logical_device->allocator()); +} + +iree_status_t iree_hal_vulkan_native_event_create( + VkDeviceHandle* logical_device, iree_hal_event_t** out_event) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_event); + *out_event = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + VkEvent handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_event(logical_device, &handle)); + + iree_hal_vulkan_native_event_t* event = NULL; + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*event), (void**)&event); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_event_vtable, + &event->resource); + event->logical_device = logical_device; + event->handle = handle; + *out_event = (iree_hal_event_t*)event; + } else { + iree_hal_vulkan_destroy_event(logical_device, handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_event_destroy(iree_hal_event_t* base_event) { + iree_hal_vulkan_native_event_t* event = + iree_hal_vulkan_native_event_cast(base_event); + iree_allocator_t host_allocator = event->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_destroy_event(event->logical_device, event->handle); + iree_allocator_free(host_allocator, event); + + IREE_TRACE_ZONE_END(z0); +} + +VkEvent iree_hal_vulkan_native_event_handle( + const iree_hal_event_t* base_event) { + return ((const iree_hal_vulkan_native_event_t*)base_event)->handle; +} + +const iree_hal_event_vtable_t iree_hal_vulkan_native_event_vtable = { + /*.destroy=*/iree_hal_vulkan_native_event_destroy, +};
diff --git a/iree/hal/vulkan/native_event.h b/iree/hal/vulkan/native_event.h index 28923cd..83c7919 100644 --- a/iree/hal/vulkan/native_event.h +++ b/iree/hal/vulkan/native_event.h
@@ -15,32 +15,23 @@ #ifndef IREE_HAL_VULKAN_NATIVE_EVENT_H_ #define IREE_HAL_VULKAN_NATIVE_EVENT_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/cc/event.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/handle_util.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -// An event implemented with the native VkEvent type. -class NativeEvent final : public Event { - public: - NativeEvent(ref_ptr<VkDeviceHandle> logical_device, VkEvent handle); - ~NativeEvent() override; +// Creates a native Vulkan VkEvent object. +iree_status_t iree_hal_vulkan_native_event_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_hal_event_t** out_event); - VkEvent handle() const { return handle_; } +// Returns Vulkan event handle. +VkEvent iree_hal_vulkan_native_event_handle(const iree_hal_event_t* event); - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkEvent handle_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_NATIVE_EVENT_H_
diff --git a/iree/hal/vulkan/native_executable.cc b/iree/hal/vulkan/native_executable.cc new file mode 100644 index 0000000..ea80c96 --- /dev/null +++ b/iree/hal/vulkan/native_executable.cc
@@ -0,0 +1,286 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/vulkan/native_executable.h" + +#include "iree/base/memory.h" +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/handle_util.h" +#include "iree/hal/vulkan/native_executable_layout.h" +#include "iree/hal/vulkan/status_util.h" + +// flatcc schemas: +#include "iree/base/flatcc.h" +#include "iree/schemas/spirv_executable_def_reader.h" +#include "iree/schemas/spirv_executable_def_verifier.h" + +using namespace iree::hal::vulkan; + +static iree_status_t iree_hal_vulkan_create_shader_module( + VkDeviceHandle* logical_device, iree_const_byte_span_t code, + VkShaderModule* out_shader_module) { + VkShaderModuleCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + create_info.codeSize = code.data_length; + create_info.pCode = (const uint32_t*)code.data; + VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateShaderModule( + *logical_device, &create_info, + logical_device->allocator(), out_shader_module), + "vkCreateShaderModule"); + return iree_ok_status(); +} + +static void iree_hal_vulkan_destroy_shader_module( + VkDeviceHandle* logical_device, VkShaderModule handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyShaderModule(*logical_device, handle, + logical_device->allocator()); +} + +static iree_status_t iree_hal_vulkan_create_pipelines( + VkDeviceHandle* logical_device, VkPipelineCache pipeline_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_SpirVExecutableDef_table_t executable_def, + VkShaderModule shader_module, iree_host_size_t pipeline_count, + VkPipeline* out_pipelines) { + VkComputePipelineCreateInfo* create_infos = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + logical_device->host_allocator(), sizeof(VkComputePipelineCreateInfo), + (void**)&create_infos)); + + flatbuffers_string_vec_t entry_points_vec = + iree_SpirVExecutableDef_entry_points_get(executable_def); + for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count; + ++entry_ordinal) { + VkComputePipelineCreateInfo* create_info = &create_infos[entry_ordinal]; + create_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + create_info->pNext = NULL; + create_info->flags = 0; + if (!iree_all_bits_set( + caching_mode, + IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION)) { + create_info->flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; + } + if (entry_ordinal == 0) { + create_info->flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; + } else { + create_info->flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; + } + create_info->layout = + iree_hal_vulkan_native_executable_layout_handle(executable_layout); + create_info->basePipelineHandle = VK_NULL_HANDLE; + create_info->basePipelineIndex = 0; + VkPipelineShaderStageCreateInfo* stage_create_info = &create_info->stage; + stage_create_info->sType = + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + stage_create_info->pNext = NULL; + stage_create_info->flags = 0; + stage_create_info->stage = VK_SHADER_STAGE_COMPUTE_BIT; + stage_create_info->module = shader_module; + stage_create_info->pName = + flatbuffers_string_vec_at(entry_points_vec, entry_ordinal); + stage_create_info->pSpecializationInfo = NULL; + } + + iree_status_t status = VK_RESULT_TO_STATUS( + logical_device->syms()->vkCreateComputePipelines( + *logical_device, pipeline_cache, (uint32_t)pipeline_count, + create_infos, logical_device->allocator(), out_pipelines), + "vkCreateComputePipelines"); + + iree_allocator_free(logical_device->host_allocator(), create_infos); + return status; +} + +static void iree_hal_vulkan_destroy_pipeline(VkDeviceHandle* logical_device, + VkPipeline handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyPipeline(*logical_device, handle, + logical_device->allocator()); +} + +// Verifies the structure of the flatbuffer so that we can avoid doing so during +// runtime. There are still some conditions we must be aware of (such as omitted +// names on functions with internal linkage), however we shouldn't need to +// bounds check anything within the flatbuffer after this succeeds. +static iree_status_t iree_hal_spirv_executable_flatbuffer_verify( + iree_const_byte_span_t flatbuffer_data) { + if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "flatbuffer data is not present or less than 16 bytes (%zu total)", + flatbuffer_data.data_length); + } + + // Run flatcc generated verification. This ensures all pointers are in-bounds + // and that we can safely walk the file, but not that the actual contents of + // the flatbuffer meet our expectations. + int verify_ret = iree_SpirVExecutableDef_verify_as_root( + flatbuffer_data.data, flatbuffer_data.data_length); + if (verify_ret != flatcc_verify_ok) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "flatbuffer verification failed: %s", + flatcc_verify_error_string(verify_ret)); + } + + iree_SpirVExecutableDef_table_t executable_def = + iree_SpirVExecutableDef_as_root(flatbuffer_data.data); + + flatbuffers_string_vec_t entry_points_vec = + iree_SpirVExecutableDef_entry_points_get(executable_def); + size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec); + for (size_t i = 0; i < entry_point_count; ++i) { + if (!flatbuffers_string_len( + flatbuffers_string_vec_at(entry_points_vec, i))) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "executable entry point %zu has no name", i); + } + } + + if (flatbuffers_uint32_vec_len( + iree_SpirVExecutableDef_code_get(executable_def)) < 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "executable SPIR-V code is missing/empty"); + } + + // TODO(benvanik): pull PopulateSpecializationInfo from history and update. + // For now the compiler isn't generating them, and we don't use them. + if (iree_SpirVExecutableDef_specialization_info_is_present(executable_def)) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "executable uses SPIR-V specialization constants; " + "they need to be revived"); + } + + return iree_ok_status(); +} + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + iree_host_size_t pipeline_count; + VkPipeline pipelines[]; +} iree_hal_vulkan_native_executable_t; + +extern const iree_hal_executable_vtable_t + iree_hal_vulkan_native_executable_vtable; + +static iree_hal_vulkan_native_executable_t* +iree_hal_vulkan_native_executable_cast(iree_hal_executable_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_executable_vtable); + return (iree_hal_vulkan_native_executable_t*)base_value; +} + +iree_status_t iree_hal_vulkan_native_executable_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + VkPipelineCache pipeline_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Verify and fetch the executable flatbuffer wrapper. + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_spirv_executable_flatbuffer_verify(executable_data)); + iree_SpirVExecutableDef_table_t executable_def = + iree_SpirVExecutableDef_as_root(executable_data.data); + + // Create the shader module. + flatbuffers_uint32_vec_t code_vec = + iree_SpirVExecutableDef_code_get(executable_def); + VkShaderModule shader_module = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_shader_module( + logical_device, + iree_make_const_byte_span( + code_vec, + flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t)), + &shader_module)); + + // Create pipelines for each entry point. + flatbuffers_string_vec_t entry_points_vec = + iree_SpirVExecutableDef_entry_points_get(executable_def); + iree_host_size_t pipeline_count = + flatbuffers_string_vec_len(entry_points_vec); + + iree_hal_vulkan_native_executable_t* executable = NULL; + iree_host_size_t total_size = + sizeof(*executable) + pipeline_count * sizeof(*executable->pipelines); + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + total_size, (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_executable_vtable, + &executable->resource); + executable->logical_device = logical_device; + executable->pipeline_count = pipeline_count; + memset(executable->pipelines, 0, + pipeline_count * sizeof(*executable->pipelines)); + } + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_create_pipelines( + logical_device, pipeline_cache, executable_layout, caching_mode, + executable_def, shader_module, executable->pipeline_count, + executable->pipelines); + } + iree_hal_vulkan_destroy_shader_module(logical_device, shader_module); + + if (iree_status_is_ok(status)) { + *out_executable = (iree_hal_executable_t*)executable; + } else { + iree_hal_executable_destroy((iree_hal_executable_t*)executable); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_vulkan_native_executable_t* executable = + iree_hal_vulkan_native_executable_cast(base_executable); + iree_allocator_t host_allocator = + executable->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + for (iree_host_size_t i = 0; i < executable->pipeline_count; ++i) { + iree_hal_vulkan_destroy_pipeline(executable->logical_device, + executable->pipelines[i]); + } + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point( + iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal, + VkPipeline* out_pipeline_handle) { + iree_hal_vulkan_native_executable_t* executable = + iree_hal_vulkan_native_executable_cast(base_executable); + if (entry_ordinal >= executable->pipeline_count) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "invalid entry point ordinal %zu", entry_ordinal); + } + *out_pipeline_handle = executable->pipelines[entry_ordinal]; + return iree_ok_status(); +} + +const iree_hal_executable_vtable_t iree_hal_vulkan_native_executable_vtable = { + /*.destroy=*/iree_hal_vulkan_native_executable_destroy, +};
diff --git a/iree/hal/vulkan/native_executable.h b/iree/hal/vulkan/native_executable.h new file mode 100644 index 0000000..d8372c7 --- /dev/null +++ b/iree/hal/vulkan/native_executable.h
@@ -0,0 +1,49 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_EXECUTABLE_H_ +#define IREE_HAL_VULKAN_NATIVE_EXECUTABLE_H_ + +// clang-format off: Must be included before all other headers: +#include "iree/hal/vulkan/vulkan_headers.h" +// clang-format on + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a wrapper for one or more VkPipelines that are sourced from the same +// IREE executable. Each of the pipelines will share the same shader module +// and just differs by the entry point into the shader module they reference. +iree_status_t iree_hal_vulkan_native_executable_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + VkPipelineCache pipeline_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable); + +// Returns the cached VkPipeline for the given executable |entry_ordinal|. +iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point( + iree_hal_executable_t* executable, iree_host_size_t entry_ordinal, + VkPipeline* out_pipeline_handle); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NATIVE_EXECUTABLE_H_
diff --git a/iree/hal/vulkan/native_executable_layout.cc b/iree/hal/vulkan/native_executable_layout.cc new file mode 100644 index 0000000..6a25069 --- /dev/null +++ b/iree/hal/vulkan/native_executable_layout.cc
@@ -0,0 +1,173 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/vulkan/native_executable_layout.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" +#include "iree/hal/vulkan/status_util.h" + +using namespace iree::hal::vulkan; + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkPipelineLayout handle; + iree_host_size_t set_layout_count; + iree_hal_descriptor_set_layout_t* set_layouts[]; +} iree_hal_vulkan_native_executable_layout_t; + +extern const iree_hal_executable_layout_vtable_t + iree_hal_vulkan_native_executable_layout_vtable; + +static iree_hal_vulkan_native_executable_layout_t* +iree_hal_vulkan_native_executable_layout_cast( + iree_hal_executable_layout_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_native_executable_layout_vtable); + return (iree_hal_vulkan_native_executable_layout_t*)base_value; +} + +static iree_status_t iree_hal_vulkan_create_pipeline_layout( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constant_count, VkPipelineLayout* out_handle) { + VkDescriptorSetLayout* set_layout_handles = + (VkDescriptorSetLayout*)iree_alloca(set_layout_count * + sizeof(VkDescriptorSetLayout)); + for (iree_host_size_t i = 0; i < set_layout_count; ++i) { + set_layout_handles[i] = + iree_hal_vulkan_native_descriptor_set_layout_handle(set_layouts[i]); + } + + VkPushConstantRange push_constant_ranges[1]; + push_constant_ranges[0].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + push_constant_ranges[0].offset = 0; + push_constant_ranges[0].size = + (uint32_t)(push_constant_count * sizeof(uint32_t)); + + VkPipelineLayoutCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + create_info.pNext = nullptr; + create_info.flags = 0; + create_info.setLayoutCount = (uint32_t)set_layout_count; + create_info.pSetLayouts = set_layout_handles; + create_info.pushConstantRangeCount = push_constant_count > 0 ? 1 : 0; + create_info.pPushConstantRanges = push_constant_ranges; + + return VK_RESULT_TO_STATUS(logical_device->syms()->vkCreatePipelineLayout( + *logical_device, &create_info, + logical_device->allocator(), out_handle), + "vkCreatePipelineLayout"); +} + +static void iree_hal_vulkan_destroy_pipeline_layout( + VkDeviceHandle* logical_device, VkPipelineLayout handle) { + if (handle == VK_NULL_HANDLE) return; + logical_device->syms()->vkDestroyPipelineLayout(*logical_device, handle, + logical_device->allocator()); +} + +iree_status_t iree_hal_vulkan_native_executable_layout_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constant_count, + iree_hal_executable_layout_t** out_executable_layout) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts); + IREE_ASSERT_ARGUMENT(out_executable_layout); + *out_executable_layout = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + VkPipelineLayout handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_vulkan_create_pipeline_layout(logical_device, + set_layout_count, set_layouts, + push_constant_count, &handle)); + + iree_hal_vulkan_native_executable_layout_t* executable_layout = NULL; + iree_host_size_t total_size = + sizeof(*executable_layout) + + set_layout_count * sizeof(*executable_layout->set_layouts); + iree_status_t status = iree_allocator_malloc( + logical_device->host_allocator(), total_size, (void**)&executable_layout); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize( + &iree_hal_vulkan_native_executable_layout_vtable, + &executable_layout->resource); + executable_layout->logical_device = logical_device; + executable_layout->handle = handle; + executable_layout->set_layout_count = set_layout_count; + for (iree_host_size_t i = 0; i < set_layout_count; ++i) { + executable_layout->set_layouts[i] = set_layouts[i]; + iree_hal_descriptor_set_layout_retain(set_layouts[i]); + } + *out_executable_layout = (iree_hal_executable_layout_t*)executable_layout; + } else { + iree_hal_vulkan_destroy_pipeline_layout(logical_device, handle); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_executable_layout_destroy( + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + iree_allocator_t host_allocator = + executable_layout->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_destroy_pipeline_layout(executable_layout->logical_device, + executable_layout->handle); + for (iree_host_size_t i = 0; i < executable_layout->set_layout_count; ++i) { + iree_hal_descriptor_set_layout_release(executable_layout->set_layouts[i]); + } + iree_allocator_free(host_allocator, executable_layout); + + IREE_TRACE_ZONE_END(z0); +} + +VkPipelineLayout iree_hal_vulkan_native_executable_layout_handle( + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + return executable_layout->handle; +} + +iree_host_size_t iree_hal_vulkan_native_executable_layout_set_count( + iree_hal_executable_layout_t* base_executable_layout) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + return executable_layout->set_layout_count; +} + +iree_hal_descriptor_set_layout_t* iree_hal_vulkan_native_executable_layout_set( + iree_hal_executable_layout_t* base_executable_layout, + iree_host_size_t set_index) { + iree_hal_vulkan_native_executable_layout_t* executable_layout = + iree_hal_vulkan_native_executable_layout_cast(base_executable_layout); + if (IREE_UNLIKELY(set_index >= executable_layout->set_layout_count)) { + return NULL; + } + return executable_layout->set_layouts[set_index]; +} + +const iree_hal_executable_layout_vtable_t + iree_hal_vulkan_native_executable_layout_vtable = { + /*.destroy=*/iree_hal_vulkan_native_executable_layout_destroy, +};
diff --git a/iree/hal/vulkan/native_executable_layout.h b/iree/hal/vulkan/native_executable_layout.h new file mode 100644 index 0000000..58500fa --- /dev/null +++ b/iree/hal/vulkan/native_executable_layout.h
@@ -0,0 +1,55 @@ +// Copyright 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_EXECUTABLE_LAYOUT_H_ +#define IREE_HAL_VULKAN_NATIVE_EXECUTABLE_LAYOUT_H_ + +// clang-format off: Must be included before all other headers: +#include "iree/hal/vulkan/vulkan_headers.h" +// clang-format on + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a VkPipelineLayout-based executable layout composed of one or more +// descriptor set layouts. +iree_status_t iree_hal_vulkan_native_executable_layout_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constant_count, + iree_hal_executable_layout_t** out_executable_layout); + +// Returns the native VkPipelineLayout handle for the executable layout. +VkPipelineLayout iree_hal_vulkan_native_executable_layout_handle( + iree_hal_executable_layout_t* executable_layout); + +// Returns the total number of descriptor sets within the layout. +iree_host_size_t iree_hal_vulkan_native_executable_layout_set_count( + iree_hal_executable_layout_t* executable_layout); + +// Returns the descriptor set layout with the given |set_index|. +iree_hal_descriptor_set_layout_t* iree_hal_vulkan_native_executable_layout_set( + iree_hal_executable_layout_t* executable_layout, + iree_host_size_t set_index); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NATIVE_EXECUTABLE_LAYOUT_H_
diff --git a/iree/hal/vulkan/native_semaphore.cc b/iree/hal/vulkan/native_semaphore.cc new file mode 100644 index 0000000..d988d6f --- /dev/null +++ b/iree/hal/vulkan/native_semaphore.cc
@@ -0,0 +1,284 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/vulkan/native_semaphore.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/status_util.h" + +// The maximum valid payload value of an iree_hal_semaphore_t. +// Payload values larger than this indicate that the semaphore has failed. +// +// This originates from Vulkan having a lower-bound of INT_MAX for +// maxTimelineSemaphoreValueDifference and many Android devices only supporting +// that lower-bound. At ~100 signals per second it'll take 1.5+ years to +// saturate. We may increase this value at some point but so long as there are +// some devices in the wild that may have this limitation we can ensure better +// consistency across the backends by observing this. +// +// The major mitigation here is that in proper usage of IREE there are no +// semaphores that are implicitly referenced by multiple VMs (each creates their +// own internally) and in a multitenant system each session should have its own +// semaphores - so even if the process lives for years it's highly unlikely any +// particular session does. Whatever, 640K is enough for anyone. +// +// See: +// https://vulkan.gpuinfo.org/displayextensionproperty.php?name=maxTimelineSemaphoreValueDifference +#define IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE (2147483647ull - 1) + +using namespace iree::hal::vulkan; + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; + VkSemaphore handle; + iree_atomic_ptr_t failure_status; +} iree_hal_vulkan_native_semaphore_t; + +extern const iree_hal_semaphore_vtable_t + iree_hal_vulkan_native_semaphore_vtable; + +static iree_hal_vulkan_native_semaphore_t* +iree_hal_vulkan_native_semaphore_cast(iree_hal_semaphore_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_semaphore_vtable); + return (iree_hal_vulkan_native_semaphore_t*)base_value; +} + +iree_status_t iree_hal_vulkan_native_semaphore_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_semaphore); + *out_semaphore = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + VkSemaphoreTypeCreateInfo timeline_create_info; + timeline_create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_TYPE_CREATE_INFO; + timeline_create_info.pNext = NULL; + timeline_create_info.semaphoreType = VK_SEMAPHORE_TYPE_TIMELINE; + timeline_create_info.initialValue = initial_value; + + VkSemaphoreCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO; + create_info.pNext = &timeline_create_info; + create_info.flags = 0; + VkSemaphore handle = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, VK_RESULT_TO_STATUS(logical_device->syms()->vkCreateSemaphore( + *logical_device, &create_info, + logical_device->allocator(), &handle), + "vkCreateSemaphore")); + + iree_hal_vulkan_native_semaphore_t* semaphore = NULL; + iree_status_t status = iree_allocator_malloc( + logical_device->host_allocator(), sizeof(*semaphore), (void**)&semaphore); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_native_semaphore_vtable, + &semaphore->resource); + semaphore->logical_device = logical_device; + semaphore->handle = handle; + iree_atomic_store_ptr(&semaphore->failure_status, 0, + iree_memory_order_release); + *out_semaphore = (iree_hal_semaphore_t*)semaphore; + } else { + logical_device->syms()->vkDestroySemaphore(*logical_device, handle, + logical_device->allocator()); + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_native_semaphore_destroy( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + iree_allocator_t host_allocator = semaphore->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_status_free((iree_status_t)iree_atomic_load_ptr( + &semaphore->failure_status, iree_memory_order_acquire)); + semaphore->logical_device->syms()->vkDestroySemaphore( + *semaphore->logical_device, semaphore->handle, + semaphore->logical_device->allocator()); + iree_allocator_free(host_allocator, semaphore); + + IREE_TRACE_ZONE_END(z0); +} + +VkSemaphore iree_hal_vulkan_native_semaphore_handle( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + return semaphore->handle; +} + +static iree_status_t iree_hal_vulkan_native_semaphore_query( + iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + *out_value = 0; + + uint64_t value = 0; + IREE_RETURN_IF_ERROR(VK_RESULT_TO_STATUS( + semaphore->logical_device->syms()->vkGetSemaphoreCounterValue( + *semaphore->logical_device, semaphore->handle, &value), + "vkGetSemaphoreCounterValue")); + + if (value > IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE) { + iree_status_t failure_status = (iree_status_t)iree_atomic_load_ptr( + &semaphore->failure_status, iree_memory_order_acquire); + if (iree_status_is_ok(failure_status)) { + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "overflowed timeline semaphore max value"); + } + return iree_status_clone(failure_status); + } + + *out_value = value; + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_native_semaphore_signal( + iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + + VkSemaphoreSignalInfo signal_info; + signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; + signal_info.pNext = NULL; + signal_info.semaphore = semaphore->handle; + signal_info.value = new_value; + return VK_RESULT_TO_STATUS( + semaphore->logical_device->syms()->vkSignalSemaphore( + *semaphore->logical_device, &signal_info), + "vkSignalSemaphore"); +} + +static void iree_hal_vulkan_native_semaphore_fail( + iree_hal_semaphore_t* base_semaphore, iree_status_t status) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + + // Try to set our local status - we only preserve the first failure so only + // do this if we are going from a valid semaphore to a failed one. + iree_status_t old_status = iree_ok_status(); + if (!iree_atomic_compare_exchange_strong_ptr( + &semaphore->failure_status, (uintptr_t*)&old_status, + (uintptr_t)status, iree_memory_order_seq_cst, + iree_memory_order_seq_cst)) { + // Previous status was not OK; drop our new status. + IREE_IGNORE_ERROR(status); + return; + } + + VkSemaphoreSignalInfo signal_info; + signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; + signal_info.pNext = NULL; + signal_info.semaphore = semaphore->handle; + signal_info.value = IREE_HAL_VULKAN_SEMAPHORE_MAX_VALUE + 1; + // NOTE: we don't care about the result in case of failures as we are + // failing and the caller will likely be tearing everything down anyway. + semaphore->logical_device->syms()->vkSignalSemaphore( + *semaphore->logical_device, &signal_info); +} + +iree_status_t iree_hal_vulkan_native_semaphore_multi_wait( + iree::hal::vulkan::VkDeviceHandle* logical_device, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + VkSemaphoreWaitFlags wait_flags) { + if (semaphore_list->count == 0) return iree_ok_status(); + + uint64_t timeout_ns; + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { + timeout_ns = UINT64_MAX; + } else if (deadline_ns == IREE_TIME_INFINITE_PAST) { + timeout_ns = 0; + } else { + iree_time_t now_ns = iree_time_now(); + if (deadline_ns < now_ns) { + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + timeout_ns = (uint64_t)(deadline_ns - now_ns); + } + + IREE_TRACE_ZONE_BEGIN(z0); + + VkSemaphore* semaphore_handles = + (VkSemaphore*)iree_alloca(semaphore_list->count * sizeof(VkSemaphore)); + for (iree_host_size_t i = 0; i < semaphore_list->count; ++i) { + semaphore_handles[i] = + iree_hal_vulkan_native_semaphore_handle(semaphore_list->semaphores[i]); + } + + VkSemaphoreWaitInfo wait_info; + wait_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; + wait_info.pNext = nullptr; + wait_info.flags = wait_flags; + wait_info.semaphoreCount = semaphore_list->count; + wait_info.pSemaphores = semaphore_handles; + wait_info.pValues = semaphore_list->payload_values; + static_assert( + sizeof(wait_info.pValues[0]) == sizeof(semaphore_list->payload_values[0]), + "payload value type must match vulkan expected size"); + + // NOTE: this may fail with a timeout (VK_TIMEOUT) or in the case of a + // device loss event may return either VK_SUCCESS *or* VK_ERROR_DEVICE_LOST. + // We may want to explicitly query for device loss after a successful wait + // to ensure we consistently return errors. + VkResult result = logical_device->syms()->vkWaitSemaphores( + *logical_device, &wait_info, timeout_ns); + + IREE_TRACE_ZONE_END(z0); + + if (result == VK_SUCCESS) { + return iree_ok_status(); + } else if (result == VK_ERROR_DEVICE_LOST) { + // Nothing we do now matters. + return VK_RESULT_TO_STATUS(result, "vkWaitSemaphores"); + } else if (result == VK_TIMEOUT) { + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + } + return VK_RESULT_TO_STATUS(result, "vkWaitSemaphores"); +} + +static iree_status_t iree_hal_vulkan_native_semaphore_wait_with_deadline( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_time_t deadline_ns) { + iree_hal_vulkan_native_semaphore_t* semaphore = + iree_hal_vulkan_native_semaphore_cast(base_semaphore); + iree_hal_semaphore_list_t semaphore_list = { + /*.count=*/1, + /*.semaphores=*/&base_semaphore, + /*.payload_values=*/&value, + }; + return iree_hal_vulkan_native_semaphore_multi_wait( + semaphore->logical_device, &semaphore_list, deadline_ns, 0); +} + +static iree_status_t iree_hal_vulkan_native_semaphore_wait_with_timeout( + iree_hal_semaphore_t* base_semaphore, uint64_t value, + iree_duration_t timeout_ns) { + return iree_hal_vulkan_native_semaphore_wait_with_deadline( + base_semaphore, value, iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +const iree_hal_semaphore_vtable_t iree_hal_vulkan_native_semaphore_vtable = { + /*.destroy=*/iree_hal_vulkan_native_semaphore_destroy, + /*.query=*/iree_hal_vulkan_native_semaphore_query, + /*.signal=*/iree_hal_vulkan_native_semaphore_signal, + /*.fail=*/iree_hal_vulkan_native_semaphore_fail, + /*.wait_with_deadline=*/iree_hal_vulkan_native_semaphore_wait_with_deadline, + /*.wait_with_timeout=*/iree_hal_vulkan_native_semaphore_wait_with_timeout, +};
diff --git a/iree/hal/vulkan/native_semaphore.h b/iree/hal/vulkan/native_semaphore.h new file mode 100644 index 0000000..31f2611 --- /dev/null +++ b/iree/hal/vulkan/native_semaphore.h
@@ -0,0 +1,51 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NATIVE_SEMAPHORE_H_ +#define IREE_HAL_VULKAN_NATIVE_SEMAPHORE_H_ + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a timeline semaphore implemented using the native VkSemaphore type. +// This may require emulation pre-Vulkan 1.2 when timeline semaphores were only +// an extension. +iree_status_t iree_hal_vulkan_native_semaphore_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore); + +// Returns the Vulkan timeline semaphore handle. +VkSemaphore iree_hal_vulkan_native_semaphore_handle( + iree_hal_semaphore_t* semaphore); + +// Performs a multi-wait on one or more semaphores. +// By default this is an all-wait but |wait_flags| may contain +// VK_SEMAPHORE_WAIT_ANY_BIT to change to an any-wait. +// +// Returns IREE_STATUS_DEADLINE_EXCEEDED if the wait does not complete before +// |deadline_ns| elapses. +iree_status_t iree_hal_vulkan_native_semaphore_multi_wait( + iree::hal::vulkan::VkDeviceHandle* logical_device, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns, + VkSemaphoreWaitFlags wait_flags); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NATIVE_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/native_timeline_semaphore.cc b/iree/hal/vulkan/native_timeline_semaphore.cc deleted file mode 100644 index ecc340b..0000000 --- a/iree/hal/vulkan/native_timeline_semaphore.cc +++ /dev/null
@@ -1,145 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/native_timeline_semaphore.h" - -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// static -StatusOr<ref_ptr<Semaphore>> NativeTimelineSemaphore::Create( - ref_ptr<VkDeviceHandle> logical_device, uint64_t initial_value) { - IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Create"); - - VkSemaphoreTypeCreateInfo timeline_create_info; - timeline_create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_TYPE_CREATE_INFO; - timeline_create_info.pNext = nullptr; - timeline_create_info.semaphoreType = VK_SEMAPHORE_TYPE_TIMELINE; - timeline_create_info.initialValue = initial_value; - - VkSemaphoreCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_CREATE_INFO; - create_info.pNext = &timeline_create_info; - create_info.flags = 0; - VkSemaphore semaphore_handle = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateSemaphore( - *logical_device, &create_info, logical_device->allocator(), - &semaphore_handle)); - - return make_ref<NativeTimelineSemaphore>(std::move(logical_device), - semaphore_handle, initial_value); -} - -NativeTimelineSemaphore::NativeTimelineSemaphore( - ref_ptr<VkDeviceHandle> logical_device, VkSemaphore handle, - uint64_t initial_value) - : logical_device_(std::move(logical_device)), handle_(handle) {} - -NativeTimelineSemaphore::~NativeTimelineSemaphore() { - IREE_TRACE_SCOPE0("NativeTimelineSemaphore::dtor"); - logical_device_->syms()->vkDestroySemaphore(*logical_device_, handle_, - logical_device_->allocator()); -} - -StatusOr<uint64_t> NativeTimelineSemaphore::Query() { - uint64_t value = 0; - VK_RETURN_IF_ERROR(logical_device_->syms()->vkGetSemaphoreCounterValue( - *logical_device_, handle_, &value)); - if (value == UINT64_MAX) { - absl::MutexLock lock(&status_mutex_); - return status_; - } - return value; -} - -Status NativeTimelineSemaphore::Signal(uint64_t value) { - IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Signal"); - - VkSemaphoreSignalInfo signal_info; - signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; - signal_info.pNext = nullptr; - signal_info.semaphore = handle_; - signal_info.value = value; - return VkResultToStatus(logical_device_->syms()->vkSignalSemaphore( - *logical_device_, &signal_info), - IREE_LOC); -} - -void NativeTimelineSemaphore::Fail(Status status) { - IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Fail"); - - // NOTE: we hold the lock here as the vkSignalSemaphore may wake a waiter and - // we want to be able to immediately give them the status. - absl::MutexLock lock(&status_mutex_); - status_ = std::move(status); - - VkSemaphoreSignalInfo signal_info; - signal_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_SIGNAL_INFO; - signal_info.pNext = nullptr; - signal_info.semaphore = handle_; - signal_info.value = UINT64_MAX; - // NOTE: we don't care about the result in case of failures as we are - // failing and the caller will likely be tearing everything down anyway. - logical_device_->syms()->vkSignalSemaphore(*logical_device_, &signal_info); -} - -Status NativeTimelineSemaphore::Wait(uint64_t value, Time deadline_ns) { - IREE_TRACE_SCOPE0("NativeTimelineSemaphore::Wait"); - - VkSemaphoreWaitInfo wait_info; - wait_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; - wait_info.pNext = nullptr; - wait_info.flags = 0; - wait_info.semaphoreCount = 1; - wait_info.pSemaphores = &handle_; - wait_info.pValues = &value; - - uint64_t timeout_ns; - if (deadline_ns == InfiniteFuture()) { - timeout_ns = UINT64_MAX; - } else if (deadline_ns == InfinitePast()) { - timeout_ns = 0; - } else { - Duration relative_ns = deadline_ns - Now(); - timeout_ns = static_cast<int64_t>( - relative_ns < ZeroDuration() ? ZeroDuration() : relative_ns); - } - - // NOTE: this may fail with a timeout (VK_TIMEOUT) or in the case of a - // device loss event may return either VK_SUCCESS *or* VK_ERROR_DEVICE_LOST. - // We may want to explicitly query for device loss after a successful wait - // to ensure we consistently return errors. - if (!logical_device_->syms()->vkWaitSemaphores) { - return UnknownErrorBuilder(IREE_LOC) << "vkWaitSemaphores not defined"; - } - VkResult result = logical_device_->syms()->vkWaitSemaphores( - *logical_device_, &wait_info, timeout_ns); - if (result == VK_ERROR_DEVICE_LOST) { - // Nothing we do now matters. - return VkResultToStatus(result, IREE_LOC); - } else if (result == VK_TIMEOUT) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for semaphore"; - } - - return VkResultToStatus(result, IREE_LOC); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/native_timeline_semaphore.h b/iree/hal/vulkan/native_timeline_semaphore.h deleted file mode 100644 index d7eb0b1..0000000 --- a/iree/hal/vulkan/native_timeline_semaphore.h +++ /dev/null
@@ -1,66 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_ -#define IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "absl/synchronization/mutex.h" -#include "iree/hal/cc/semaphore.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A timeline semaphore implemented using the native VkSemaphore type. -// This may require emulation pre-Vulkan 1.2 when timeline semaphores were only -// an extension. -class NativeTimelineSemaphore final : public Semaphore { - public: - // Creates a timeline semaphore with the given |initial_value|. - static StatusOr<ref_ptr<Semaphore>> Create( - ref_ptr<VkDeviceHandle> logical_device, uint64_t initial_value); - - NativeTimelineSemaphore(ref_ptr<VkDeviceHandle> logical_device, - VkSemaphore handle, uint64_t initial_value); - ~NativeTimelineSemaphore() override; - - VkSemaphore handle() const { return handle_; } - - StatusOr<uint64_t> Query() override; - - Status Signal(uint64_t value) override; - void Fail(Status status) override; - Status Wait(uint64_t value, Time deadline_ns) override; - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkSemaphore handle_; - - // NOTE: the Vulkan semaphore is the source of truth. We only need to access - // this status (and thus take the lock) when we want to either signal failure - // or query the status in the case of the semaphore being set to UINT64_MAX. - mutable absl::Mutex status_mutex_; - Status status_ ABSL_GUARDED_BY(status_mutex_); -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_NATIVE_TIMELINE_SEMAPHORE_H_
diff --git a/iree/hal/vulkan/nop_executable_cache.cc b/iree/hal/vulkan/nop_executable_cache.cc new file mode 100644 index 0000000..5aa6bc0 --- /dev/null +++ b/iree/hal/vulkan/nop_executable_cache.cc
@@ -0,0 +1,105 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/vulkan/nop_executable_cache.h" + +#include "iree/base/tracing.h" +#include "iree/hal/vulkan/native_executable.h" + +using namespace iree::hal::vulkan; + +static const iree_hal_executable_format_t kExecutableFormatSpirV = + iree_hal_make_executable_format("SPVE"); + +typedef struct { + iree_hal_resource_t resource; + VkDeviceHandle* logical_device; +} iree_hal_vulkan_nop_executable_cache_t; + +extern const iree_hal_executable_cache_vtable_t + iree_hal_vulkan_nop_executable_cache_vtable; + +static iree_hal_vulkan_nop_executable_cache_t* +iree_hal_vulkan_nop_executable_cache_cast( + iree_hal_executable_cache_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, + &iree_hal_vulkan_nop_executable_cache_vtable); + return (iree_hal_vulkan_nop_executable_cache_t*)base_value; +} + +iree_status_t iree_hal_vulkan_nop_executable_cache_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + IREE_ASSERT_ARGUMENT(out_executable_cache); + *out_executable_cache = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_nop_executable_cache_t* executable_cache = NULL; + iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(), + sizeof(*executable_cache), + (void**)&executable_cache); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_nop_executable_cache_vtable, + &executable_cache->resource); + executable_cache->logical_device = logical_device; + + *out_executable_cache = (iree_hal_executable_cache_t*)executable_cache; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_vulkan_nop_executable_cache_destroy( + iree_hal_executable_cache_t* base_executable_cache) { + iree_hal_vulkan_nop_executable_cache_t* executable_cache = + iree_hal_vulkan_nop_executable_cache_cast(base_executable_cache); + iree_allocator_t host_allocator = + executable_cache->logical_device->host_allocator(); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_cache); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_vulkan_nop_executable_cache_can_prepare_format( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_format_t format) { + return format == kExecutableFormatSpirV; +} + +static iree_status_t iree_hal_vulkan_nop_executable_cache_prepare_executable( + iree_hal_executable_cache_t* base_executable_cache, + iree_hal_executable_layout_t* executable_layout, + iree_hal_executable_caching_mode_t caching_mode, + iree_const_byte_span_t executable_data, + iree_hal_executable_t** out_executable) { + iree_hal_vulkan_nop_executable_cache_t* executable_cache = + iree_hal_vulkan_nop_executable_cache_cast(base_executable_cache); + return iree_hal_vulkan_native_executable_create( + executable_cache->logical_device, + /*pipeline_cache=*/VK_NULL_HANDLE, executable_layout, caching_mode, + executable_data, out_executable); +} + +const iree_hal_executable_cache_vtable_t + iree_hal_vulkan_nop_executable_cache_vtable = { + /*.destroy=*/iree_hal_vulkan_nop_executable_cache_destroy, + /*.can_prepare_format=*/ + iree_hal_vulkan_nop_executable_cache_can_prepare_format, + /*.prepare_executable=*/ + iree_hal_vulkan_nop_executable_cache_prepare_executable, +};
diff --git a/iree/hal/vulkan/nop_executable_cache.h b/iree/hal/vulkan/nop_executable_cache.h new file mode 100644 index 0000000..b6ed2b6 --- /dev/null +++ b/iree/hal/vulkan/nop_executable_cache.h
@@ -0,0 +1,37 @@ +// Copyright 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_VULKAN_NOP_EXECUTABLE_CACHE_H_ +#define IREE_HAL_VULKAN_NOP_EXECUTABLE_CACHE_H_ + +#include "iree/hal/api.h" +#include "iree/hal/vulkan/handle_util.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a no-op executable cache that does not cache at all. +// This is useful to isolate pipeline caching behavior and verify compilation +// behavior. +iree_status_t iree_hal_vulkan_nop_executable_cache_create( + iree::hal::vulkan::VkDeviceHandle* logical_device, + iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_VULKAN_NOP_EXECUTABLE_CACHE_H_
diff --git a/iree/hal/vulkan/pipeline_cache.cc b/iree/hal/vulkan/pipeline_cache.cc deleted file mode 100644 index 92f114c..0000000 --- a/iree/hal/vulkan/pipeline_cache.cc +++ /dev/null
@@ -1,54 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/pipeline_cache.h" - -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/cc/executable_format.h" -#include "iree/hal/vulkan/status_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -PipelineCache::PipelineCache(ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} - -PipelineCache::~PipelineCache() = default; - -bool PipelineCache::CanPrepareFormat(ExecutableFormat format) const { - return format == kExecutableFormatSpirV; -} - -StatusOr<ref_ptr<Executable>> PipelineCache::PrepareExecutable( - ExecutableLayout* executable_layout, - iree_hal_executable_caching_mode_t mode, - iree_const_byte_span_t executable_data) { - IREE_TRACE_SCOPE0("PipelineCache::PrepareExecutable"); - - // Create the executable (which may itself own many pipelines). - IREE_ASSIGN_OR_RETURN( - auto executable, - PipelineExecutable::Create( - add_ref(logical_device_), - /*pipeline_cache=*/VK_NULL_HANDLE, - static_cast<PipelineExecutableLayout*>(executable_layout), mode, - executable_data)); - return executable; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_cache.h b/iree/hal/vulkan/pipeline_cache.h deleted file mode 100644 index 186f682..0000000 --- a/iree/hal/vulkan/pipeline_cache.h +++ /dev/null
@@ -1,56 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_CACHE_H_ -#define IREE_HAL_VULKAN_PIPELINE_CACHE_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "absl/container/inlined_vector.h" -#include "iree/hal/cc/executable.h" -#include "iree/hal/cc/executable_cache.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/pipeline_executable.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class PipelineCache final : public ExecutableCache { - public: - explicit PipelineCache(ref_ptr<VkDeviceHandle> logical_device); - ~PipelineCache() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - bool CanPrepareFormat(ExecutableFormat format) const override; - - StatusOr<ref_ptr<Executable>> PrepareExecutable( - ExecutableLayout* executable_layout, - iree_hal_executable_caching_mode_t mode, - iree_const_byte_span_t executable_data) override; - - private: - ref_ptr<VkDeviceHandle> logical_device_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_CACHE_H_
diff --git a/iree/hal/vulkan/pipeline_executable.cc b/iree/hal/vulkan/pipeline_executable.cc deleted file mode 100644 index 81c4889..0000000 --- a/iree/hal/vulkan/pipeline_executable.cc +++ /dev/null
@@ -1,234 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/pipeline_executable.h" - -#include "absl/container/inlined_vector.h" -#include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/tracing.h" -#include "iree/hal/vulkan/status_util.h" - -// flatcc schemas: -#include "iree/base/flatcc.h" -#include "iree/schemas/spirv_executable_def_reader.h" -#include "iree/schemas/spirv_executable_def_verifier.h" - -// NOTE: starting to port this to C. - -// Verifies the structure of the flatbuffer so that we can avoid doing so during -// runtime. There are still some conditions we must be aware of (such as omitted -// names on functions with internal linkage), however we shouldn't need to -// bounds check anything within the flatbuffer after this succeeds. -static iree_status_t iree_hal_spirv_executable_flatbuffer_verify( - iree_const_byte_span_t flatbuffer_data) { - if (!flatbuffer_data.data || flatbuffer_data.data_length < 16) { - return iree_make_status( - IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer data is not present or less than 16 bytes (%zu total)", - flatbuffer_data.data_length); - } - - // Run flatcc generated verification. This ensures all pointers are in-bounds - // and that we can safely walk the file, but not that the actual contents of - // the flatbuffer meet our expectations. - int verify_ret = iree_SpirVExecutableDef_verify_as_root( - flatbuffer_data.data, flatbuffer_data.data_length); - if (verify_ret != flatcc_verify_ok) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "flatbuffer verification failed: %s", - flatcc_verify_error_string(verify_ret)); - } - - iree_SpirVExecutableDef_table_t executable_def = - iree_SpirVExecutableDef_as_root(flatbuffer_data.data); - - flatbuffers_string_vec_t entry_points_vec = - iree_SpirVExecutableDef_entry_points_get(executable_def); - size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec); - for (size_t i = 0; i < entry_point_count; ++i) { - if (!flatbuffers_string_len( - flatbuffers_string_vec_at(entry_points_vec, i))) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable entry point %zu has no name", i); - } - } - - if (flatbuffers_uint32_vec_len( - iree_SpirVExecutableDef_code_get(executable_def)) < 0) { - return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, - "executable SPIR-V code is missing/empty"); - } - - // TODO(benvanik): pull PopulateSpecializationInfo from history and update. - // For now the compiler isn't generating them, and we don't use them. - if (iree_SpirVExecutableDef_specialization_info_is_present(executable_def)) { - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "executable uses SPIR-V specialization constants; " - "they need to be revived"); - } - - return iree_ok_status(); -} - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -class VkShaderModuleHandle : public RefObject<VkShaderModuleHandle> { - public: - explicit VkShaderModuleHandle(const ref_ptr<VkDeviceHandle>& logical_device) - : logical_device_(add_ref(logical_device)) {} - ~VkShaderModuleHandle() { reset(); } - - VkShaderModuleHandle(const VkShaderModuleHandle&) = delete; - VkShaderModuleHandle& operator=(const VkShaderModuleHandle&) = delete; - VkShaderModuleHandle(VkShaderModuleHandle&& other) noexcept - : logical_device_(std::move(other.logical_device_)), - value_(absl::exchange(other.value_, - static_cast<VkShaderModule>(VK_NULL_HANDLE))) {} - VkShaderModuleHandle& operator=(VkShaderModuleHandle&& other) { - std::swap(logical_device_, other.logical_device_); - std::swap(value_, other.value_); - return *this; - } - - void reset() { - if (value_ == VK_NULL_HANDLE) return; - logical_device_->syms()->vkDestroyShaderModule( - *logical_device_, value_, logical_device_->allocator()); - value_ = VK_NULL_HANDLE; - } - - VkShaderModule value() const noexcept { return value_; } - VkShaderModule* mutable_value() noexcept { return &value_; } - operator VkShaderModule() const noexcept { return value_; } - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkShaderModule value_ = VK_NULL_HANDLE; -}; - -} // namespace - -// static -StatusOr<ref_ptr<PipelineExecutable>> PipelineExecutable::Create( - ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache, - PipelineExecutableLayout* executable_layout, - iree_hal_executable_caching_mode_t mode, - iree_const_byte_span_t executable_data) { - IREE_TRACE_SCOPE0("PipelineExecutable::Create"); - const auto& syms = logical_device->syms(); - - // Verify and fetch the executable flatbuffer wrapper. - IREE_RETURN_IF_ERROR( - iree_hal_spirv_executable_flatbuffer_verify(executable_data)); - iree_SpirVExecutableDef_table_t executable_def = - iree_SpirVExecutableDef_as_root(executable_data.data); - - // Create the shader module. - VkShaderModuleCreateInfo shader_module_create_info; - shader_module_create_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; - shader_module_create_info.pNext = nullptr; - shader_module_create_info.flags = 0; - flatbuffers_uint32_vec_t code_vec = - iree_SpirVExecutableDef_code_get(executable_def); - shader_module_create_info.codeSize = - flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t); - shader_module_create_info.pCode = code_vec; - VkShaderModuleHandle shader_module(add_ref(logical_device)); - VK_RETURN_IF_ERROR(syms->vkCreateShaderModule( - *logical_device, &shader_module_create_info, logical_device->allocator(), - shader_module.mutable_value())); - - // Create pipelines for each entry point. - flatbuffers_string_vec_t entry_points_vec = - iree_SpirVExecutableDef_entry_points_get(executable_def); - absl::InlinedVector<VkComputePipelineCreateInfo, 1> pipeline_create_infos; - pipeline_create_infos.resize(flatbuffers_string_vec_len(entry_points_vec)); - for (size_t entry_ordinal = 0; - entry_ordinal < flatbuffers_string_vec_len(entry_points_vec); - ++entry_ordinal) { - auto& pipeline_create_info = pipeline_create_infos[entry_ordinal]; - pipeline_create_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; - pipeline_create_info.pNext = nullptr; - pipeline_create_info.flags = 0; - if (!iree_all_bits_set( - mode, IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION)) { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT; - } - if (entry_ordinal == 0) { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT; - } else { - pipeline_create_info.flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; - } - pipeline_create_info.layout = executable_layout->handle(); - pipeline_create_info.basePipelineHandle = VK_NULL_HANDLE; - pipeline_create_info.basePipelineIndex = 0; - auto& stage_create_info = pipeline_create_info.stage; - stage_create_info.sType = - VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; - stage_create_info.pNext = nullptr; - stage_create_info.flags = 0; - stage_create_info.stage = VK_SHADER_STAGE_COMPUTE_BIT; - stage_create_info.module = shader_module; - stage_create_info.pName = - flatbuffers_string_vec_at(entry_points_vec, entry_ordinal); - stage_create_info.pSpecializationInfo = NULL; - } - absl::InlinedVector<VkPipeline, 1> pipelines; - pipelines.resize(flatbuffers_string_vec_len(entry_points_vec)); - - // Some ICDs appear to leak in here, out of our control. - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms->vkCreateComputePipelines( - *logical_device, pipeline_cache, - static_cast<uint32_t>(pipeline_create_infos.size()), - pipeline_create_infos.data(), logical_device->allocator(), - pipelines.data())); - IREE_ENABLE_LEAK_CHECKS(); - - return make_ref<PipelineExecutable>(std::move(logical_device), - std::move(pipelines)); -} - -PipelineExecutable::PipelineExecutable( - ref_ptr<VkDeviceHandle> logical_device, - absl::InlinedVector<VkPipeline, 1> pipelines) - : logical_device_(std::move(logical_device)), - pipelines_(std::move(pipelines)) {} - -PipelineExecutable::~PipelineExecutable() { - IREE_TRACE_SCOPE0("PipelineExecutable::dtor"); - for (auto pipeline : pipelines_) { - syms()->vkDestroyPipeline(*logical_device_, pipeline, - logical_device_->allocator()); - } - pipelines_.clear(); -} - -StatusOr<VkPipeline> PipelineExecutable::GetPipelineForEntryPoint( - int entry_ordinal) const { - if (entry_ordinal < 0 || entry_ordinal >= pipelines_.size()) { - return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal"; - } - return pipelines_[entry_ordinal]; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_executable.h b/iree/hal/vulkan/pipeline_executable.h deleted file mode 100644 index c9811c5..0000000 --- a/iree/hal/vulkan/pipeline_executable.h +++ /dev/null
@@ -1,66 +0,0 @@ -// Copyright 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ -#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <vector> - -#include "absl/container/inlined_vector.h" -#include "iree/base/status.h" -#include "iree/hal/cc/executable.h" -#include "iree/hal/cc/executable_cache.h" -#include "iree/hal/cc/executable_layout.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/native_descriptor_set.h" -#include "iree/hal/vulkan/pipeline_executable_layout.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class PipelineExecutable final : public Executable { - public: - static StatusOr<ref_ptr<PipelineExecutable>> Create( - ref_ptr<VkDeviceHandle> logical_device, VkPipelineCache pipeline_cache, - PipelineExecutableLayout* executable_layout, - iree_hal_executable_caching_mode_t mode, - iree_const_byte_span_t executable_data); - - PipelineExecutable(ref_ptr<VkDeviceHandle> logical_device, - absl::InlinedVector<VkPipeline, 1> pipelines); - ~PipelineExecutable() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - StatusOr<VkPipeline> GetPipelineForEntryPoint(int entry_ordinal) const; - - private: - ref_ptr<VkDeviceHandle> logical_device_; - - // One pipeline per entry point. - absl::InlinedVector<VkPipeline, 1> pipelines_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_H_
diff --git a/iree/hal/vulkan/pipeline_executable_layout.cc b/iree/hal/vulkan/pipeline_executable_layout.cc deleted file mode 100644 index 3628b64..0000000 --- a/iree/hal/vulkan/pipeline_executable_layout.cc +++ /dev/null
@@ -1,44 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/pipeline_executable_layout.h" - -namespace iree { -namespace hal { -namespace vulkan { - -NativeDescriptorSetLayout::NativeDescriptorSetLayout( - ref_ptr<VkDeviceHandle> logical_device, VkDescriptorSetLayout handle) - : logical_device_(std::move(logical_device)), handle_(handle) {} - -NativeDescriptorSetLayout::~NativeDescriptorSetLayout() { - logical_device_->syms()->vkDestroyDescriptorSetLayout( - *logical_device_, handle_, logical_device_->allocator()); -} - -PipelineExecutableLayout::PipelineExecutableLayout( - ref_ptr<VkDeviceHandle> logical_device, VkPipelineLayout handle, - absl::InlinedVector<ref_ptr<NativeDescriptorSetLayout>, 2> set_layouts) - : logical_device_(std::move(logical_device)), - handle_(handle), - set_layouts_(std::move(set_layouts)) {} - -PipelineExecutableLayout::~PipelineExecutableLayout() { - logical_device_->syms()->vkDestroyPipelineLayout( - *logical_device_, handle_, logical_device_->allocator()); -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/pipeline_executable_layout.h b/iree/hal/vulkan/pipeline_executable_layout.h deleted file mode 100644 index b72bda8..0000000 --- a/iree/hal/vulkan/pipeline_executable_layout.h +++ /dev/null
@@ -1,69 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_ -#define IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_ - -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/hal/cc/descriptor_set_layout.h" -#include "iree/hal/cc/executable_layout.h" -#include "iree/hal/vulkan/handle_util.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// A DescriptorSetLayout implemented with the native VkDescriptorSetLayout type. -class NativeDescriptorSetLayout final : public DescriptorSetLayout { - public: - NativeDescriptorSetLayout(ref_ptr<VkDeviceHandle> logical_device, - VkDescriptorSetLayout handle); - ~NativeDescriptorSetLayout() override; - - VkDescriptorSetLayout handle() const { return handle_; } - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkDescriptorSetLayout handle_; -}; - -class PipelineExecutableLayout final : public ExecutableLayout { - public: - PipelineExecutableLayout( - ref_ptr<VkDeviceHandle> logical_device, VkPipelineLayout handle, - absl::InlinedVector<ref_ptr<NativeDescriptorSetLayout>, 2> set_layouts); - ~PipelineExecutableLayout() override; - - VkPipelineLayout handle() const { return handle_; } - - absl::Span<const ref_ptr<NativeDescriptorSetLayout>> set_layouts() const { - return set_layouts_; - } - - private: - ref_ptr<VkDeviceHandle> logical_device_; - VkPipelineLayout handle_; - absl::InlinedVector<ref_ptr<NativeDescriptorSetLayout>, 2> set_layouts_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_PIPELINE_EXECUTABLE_LAYOUT_H_
diff --git a/iree/hal/vulkan/registration/BUILD b/iree/hal/vulkan/registration/BUILD index e94df49..56f3026 100644 --- a/iree/hal/vulkan/registration/BUILD +++ b/iree/hal/vulkan/registration/BUILD
@@ -35,12 +35,12 @@ "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1", ], deps = [ + "//iree/base:core_headers", "//iree/base:flags", "//iree/base:status", "//iree/base:tracing", "//iree/hal:api", "//iree/hal/vulkan", - "//iree/hal/vulkan:utils", "@com_google_absl//absl/flags:flag", ], )
diff --git a/iree/hal/vulkan/registration/CMakeLists.txt b/iree/hal/vulkan/registration/CMakeLists.txt index 858710f..7ae1dbe 100644 --- a/iree/hal/vulkan/registration/CMakeLists.txt +++ b/iree/hal/vulkan/registration/CMakeLists.txt
@@ -25,12 +25,12 @@ "driver_module.cc" DEPS absl::flags + iree::base::core_headers iree::base::flags iree::base::status iree::base::tracing iree::hal::api iree::hal::vulkan - iree::hal::vulkan::utils DEFINES "IREE_HAL_HAVE_VULKAN_DRIVER_MODULE=1" PUBLIC
diff --git a/iree/hal/vulkan/registration/driver_module.cc b/iree/hal/vulkan/registration/driver_module.cc index 34565e6..cf396f2 100644 --- a/iree/hal/vulkan/registration/driver_module.cc +++ b/iree/hal/vulkan/registration/driver_module.cc
@@ -14,124 +14,77 @@ #include "iree/hal/vulkan/registration/driver_module.h" +#include <inttypes.h> + #include "absl/flags/flag.h" #include "iree/base/flags.h" #include "iree/base/status.h" +#include "iree/base/target_platform.h" #include "iree/base/tracing.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/vulkan_driver.h" +#include "iree/hal/vulkan/api.h" + +#define IREE_HAL_VULKAN_1_X_DRIVER_ID 0x564C4B31u // VLK1 ABSL_FLAG(bool, vulkan_validation_layers, true, "Enables standard Vulkan validation layers."); ABSL_FLAG(bool, vulkan_debug_utils, true, "Enables VK_EXT_debug_utils, records markers, and logs errors."); -ABSL_FLAG(bool, vulkan_debug_report, false, - "Enables VK_EXT_debug_report and logs errors."); -ABSL_FLAG(bool, vulkan_push_descriptors, true, - "Enables use of vkCmdPushDescriptorSetKHR, if available."); + ABSL_FLAG(int, vulkan_default_index, 0, "Index of the default Vulkan device."); -ABSL_FLAG(bool, vulkan_renderdoc, false, "Enables RenderDoc API integration."); + ABSL_FLAG(bool, vulkan_force_timeline_semaphore_emulation, false, "Uses timeline semaphore emulation even if native support exists."); -// Vulkan Memory Allocator (VMA) flags -#if VMA_RECORDING_ENABLED -ABSL_FLAG(std::string, vma_recording_file, "", - "File path to write a CSV containing the VMA recording."); -ABSL_FLAG(bool, vma_recording_flush_after_call, false, - "Flush the VMA recording file after every call (useful if " - "crashing/not exiting cleanly)."); -#endif // VMA_RECORDING_ENABLED - -namespace iree { -namespace hal { -namespace vulkan { -namespace { - -StatusOr<ref_ptr<Driver>> CreateVulkanDriver() { - IREE_TRACE_SCOPE0("CreateVulkanDriver"); - - // Load the Vulkan library. This will fail if the library cannot be found or - // does not have the expected functions. - IREE_ASSIGN_OR_RETURN(auto syms, DynamicSymbols::CreateFromSystemLoader()); +static iree_status_t iree_hal_vulkan_create_driver_with_flags( + iree_string_view_t identifier, iree_allocator_t allocator, + iree_hal_driver_t** out_driver) { + IREE_TRACE_SCOPE(); // Setup driver options from flags. We do this here as we want to enable other // consumers that may not be using modules/command line flags to be able to // set their options however they want. - VulkanDriver::Options options; + iree_hal_vulkan_driver_options_t driver_options; + iree_hal_vulkan_driver_options_initialize(&driver_options); - // TODO: validation layers have bugs when using VK_EXT_debug_report, so if the - // user requested that we force them off with a warning. Prefer using - // VK_EXT_debug_utils when available. - if (absl::GetFlag(FLAGS_vulkan_debug_report) && - absl::GetFlag(FLAGS_vulkan_validation_layers)) { - IREE_LOG(WARNING) - << "VK_EXT_debug_report has issues with modern validation " - "layers; disabling validation"; - absl::SetFlag(&FLAGS_vulkan_validation_layers, false); - } - - // REQUIRED: these are required extensions that must be present for IREE to - // work (such as those relied upon by SPIR-V kernels, etc). - options.device_options.extensibility_spec.required_extensions.push_back( - VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); - // Multiple extensions depend on VK_KHR_get_physical_device_properties2. - // This extension was deprecated in Vulkan 1.1 as its functionality was - // promoted to core, so we list it as optional even though we require it. - options.instance_extensibility.optional_extensions.push_back( - VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); - - // Timeline semaphore support is optional and will be emulated if necessary. - options.device_options.extensibility_spec.optional_extensions.push_back( - VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME); - // Polyfill layer - enable if present (instead of our custom emulation). - options.instance_extensibility.optional_layers.push_back( - "VK_LAYER_KHRONOS_timeline_semaphore"); +// TODO(benvanik): make this a flag - it's useful for testing the same binary +// against multiple versions of Vulkan. +#if defined(IREE_PLATFORM_ANDROID) + // TODO(#4494): let's see when we can always enable timeline semaphores. + driver_options.api_version = VK_API_VERSION_1_1; +#else + driver_options.api_version = VK_API_VERSION_1_2; +#endif // IREE_PLATFORM_ANDROID if (absl::GetFlag(FLAGS_vulkan_validation_layers)) { - options.instance_extensibility.optional_layers.push_back( - "VK_LAYER_KHRONOS_validation"); - } - - if (absl::GetFlag(FLAGS_vulkan_debug_report)) { - options.instance_extensibility.optional_extensions.push_back( - VK_EXT_DEBUG_REPORT_EXTENSION_NAME); + driver_options.requested_features |= + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS; } if (absl::GetFlag(FLAGS_vulkan_debug_utils)) { - options.instance_extensibility.optional_extensions.push_back( - VK_EXT_DEBUG_UTILS_EXTENSION_NAME); + driver_options.requested_features |= + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS; } - if (absl::GetFlag(FLAGS_vulkan_push_descriptors)) { - options.device_options.extensibility_spec.optional_extensions.push_back( - VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); + driver_options.default_device_index = + absl::GetFlag(FLAGS_vulkan_default_index); + + if (absl::GetFlag(FLAGS_vulkan_force_timeline_semaphore_emulation)) { + driver_options.device_options.flags |= + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION; } - options.default_device_index = absl::GetFlag(FLAGS_vulkan_default_index); - options.enable_renderdoc = absl::GetFlag(FLAGS_vulkan_renderdoc); - options.device_options.force_timeline_semaphore_emulation = - absl::GetFlag(FLAGS_vulkan_force_timeline_semaphore_emulation); + // Load the Vulkan library. This will fail if the library cannot be found or + // does not have the expected functions. + iree_hal_vulkan_syms_t* syms = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_syms_create_from_system_loader(allocator, &syms)); -#if VMA_RECORDING_ENABLED - options.device_options.vma_options.recording_file = - absl::GetFlag(FLAGS_vma_recording_file); - options.device_options.vma_options.recording_flush_after_call = - absl::GetFlag(FLAGS_vma_recording_flush_after_call); -#endif // VMA_RECORDING_ENABLED + iree_status_t status = iree_hal_vulkan_driver_create( + identifier, &driver_options, syms, allocator, out_driver); - // Create the driver and VkInstance. - return VulkanDriver::Create(options, std::move(syms)); + iree_hal_vulkan_syms_release(syms); + return status; } -} // namespace -} // namespace vulkan -} // namespace hal -} // namespace iree - -#include <inttypes.h> - -#define IREE_HAL_VULKAN_1_X_DRIVER_ID 0x564C4B31u // VLK1 - static iree_status_t iree_hal_vulkan_driver_factory_enumerate( void* self, const iree_hal_driver_info_t** out_driver_infos, iree_host_size_t* out_driver_info_count) { @@ -155,9 +108,13 @@ " is provided by this factory", driver_id); } - IREE_ASSIGN_OR_RETURN(auto driver, iree::hal::vulkan::CreateVulkanDriver()); - *out_driver = reinterpret_cast<iree_hal_driver_t*>(driver.release()); - return iree_ok_status(); + + // When we expose more than one driver (different vulkan versions, etc) we + // can name them here: + iree_string_view_t identifier = iree_make_cstring_view("vulkan"); + + return iree_hal_vulkan_create_driver_with_flags(identifier, allocator, + out_driver); } IREE_API_EXPORT iree_status_t IREE_API_CALL
diff --git a/iree/hal/vulkan/renderdoc_capture_manager.cc b/iree/hal/vulkan/renderdoc_capture_manager.cc deleted file mode 100644 index b8b06ce..0000000 --- a/iree/hal/vulkan/renderdoc_capture_manager.cc +++ /dev/null
@@ -1,121 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "iree/hal/vulkan/renderdoc_capture_manager.h" - -#include "absl/types/span.h" -#include "iree/base/logging.h" -#include "iree/base/target_platform.h" -#include "iree/base/tracing.h" - -#if !defined(IREE_PLATFORM_WINDOWS) -#include <dlfcn.h> -#endif // IREE_PLATFORM_WINDOWS - -namespace iree { -namespace hal { -namespace vulkan { - -namespace { - -static const char* kRenderDocSearchNames[] = { -#if defined(IREE_PLATFORM_WINDOWS) - "renderdoc.dll", - "C:/Program Files/RenderDoc/renderdoc.dll", -#else - "librenderdoc.so", -#endif // IREE_PLATFORM_WINDOWS -}; - -} // namespace - -RenderDocCaptureManager::RenderDocCaptureManager() {} - -RenderDocCaptureManager::~RenderDocCaptureManager() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::dtor"); - Disconnect(); -} - -Status RenderDocCaptureManager::Connect() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::Connect"); - - if (renderdoc_library_ != nullptr) { - return OkStatus(); - } - - IREE_ASSIGN_OR_RETURN( - renderdoc_library_, - DynamicLibrary::Load(absl::MakeSpan(kRenderDocSearchNames))); - - auto renderdoc_get_api_fn = - renderdoc_library_->GetSymbol<pRENDERDOC_GetAPI>("RENDERDOC_GetAPI"); - int ret = renderdoc_get_api_fn(eRENDERDOC_API_Version_1_4_0, - (void**)&renderdoc_api_); - if (ret != 1) { - renderdoc_api_ = nullptr; - return InternalErrorBuilder(IREE_LOC) - << "Failed to get RenderDoc API object"; - } - - IREE_LOG(INFO) << "Connected to RenderDoc's API; writing captures to " - << renderdoc_api_->GetCaptureFilePathTemplate(); - - return OkStatus(); -} - -void RenderDocCaptureManager::Disconnect() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::Disconnect"); - - if (renderdoc_library_ == nullptr) { - return; - } - - if (is_capturing()) { - StopCapture(); - } - - renderdoc_api_ = nullptr; - renderdoc_library_.reset(); -} - -void RenderDocCaptureManager::StartCapture() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::StartCapture"); - - IREE_CHECK(is_connected()) << "Can't start capture when not connected"; - IREE_CHECK(!is_capturing()) << "Capture is already started"; - - IREE_LOG(INFO) << "Starting RenderDoc capture"; - renderdoc_api_->StartFrameCapture(NULL, NULL); -} - -void RenderDocCaptureManager::StopCapture() { - IREE_TRACE_SCOPE0("RenderDocCaptureManager::StopCapture"); - - IREE_CHECK(is_capturing()) << "Can't stop capture when not capturing"; - - IREE_LOG(INFO) << "Ending RenderDoc capture"; - renderdoc_api_->EndFrameCapture(NULL, NULL); -} - -bool RenderDocCaptureManager::is_capturing() const { - if (!is_connected()) { - return false; - } - - return renderdoc_api_->IsFrameCapturing() == 1; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/renderdoc_capture_manager.h b/iree/hal/vulkan/renderdoc_capture_manager.h deleted file mode 100644 index 1aac73a..0000000 --- a/iree/hal/vulkan/renderdoc_capture_manager.h +++ /dev/null
@@ -1,57 +0,0 @@ -// Copyright 2020 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef IREE_HAL_VULKAN_RENDERDOC_CAPTURE_MANAGER_H_ -#define IREE_HAL_VULKAN_RENDERDOC_CAPTURE_MANAGER_H_ - -#include "iree/base/dynamic_library.h" -#include "iree/base/status.h" -#include "iree/hal/cc/debug_capture_manager.h" -#include "third_party/renderdoc_api/app/renderdoc_app.h" - -namespace iree { -namespace hal { -namespace vulkan { - -// Capture manager using RenderDoc to record Vulkan commands. -// See https://renderdoc.org/ and https://github.com/baldurk/renderdoc. -class RenderDocCaptureManager final : public DebugCaptureManager { - public: - RenderDocCaptureManager(); - ~RenderDocCaptureManager() override; - - // Note: Connect() must be called *before* creating a VkInstance. - Status Connect() override; - - void Disconnect() override; - - bool is_connected() const override { return renderdoc_api_ != nullptr; } - - // Note: StartCapture() must be called *after* creating a VkDevice. - void StartCapture() override; - - void StopCapture() override; - - bool is_capturing() const override; - - private: - std::unique_ptr<DynamicLibrary> renderdoc_library_; - RENDERDOC_API_1_4_0* renderdoc_api_ = nullptr; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree - -#endif // IREE_HAL_VULKAN_RENDERDOC_CAPTURE_MANAGER_H_
diff --git a/iree/hal/vulkan/serializing_command_queue.cc b/iree/hal/vulkan/serializing_command_queue.cc index 8fb8bf1..b524f7b 100644 --- a/iree/hal/vulkan/serializing_command_queue.cc +++ b/iree/hal/vulkan/serializing_command_queue.cc
@@ -20,11 +20,9 @@ #include "iree/base/api.h" #include "iree/base/memory.h" #include "iree/base/tracing.h" -#include "iree/hal/cc/command_buffer.h" -#include "iree/hal/cc/command_queue.h" -#include "iree/hal/cc/semaphore.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/direct_command_buffer.h" -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" +#include "iree/hal/vulkan/emulated_semaphore.h" #include "iree/hal/vulkan/status_util.h" namespace iree { @@ -39,199 +37,195 @@ // batch is ready to be submitted to GPU. // |wait_semaphores| and |signal_semaphores| will be filled with the binary // `VkSemaphores` on success. -StatusOr<bool> TryToPrepareSemaphores( +iree_status_t TryToPrepareSemaphores( const absl::InlinedVector<SemaphoreValue, 4>& batch_wait_semaphores, const absl::InlinedVector<SemaphoreValue, 4>& batch_signal_semaphores, const ref_ptr<TimePointFence>& batch_fence, absl::InlinedVector<VkSemaphore, 4>* wait_semaphores, - absl::InlinedVector<VkSemaphore, 4>* signal_semaphores) { + absl::InlinedVector<VkSemaphore, 4>* signal_semaphores, + bool* out_ready_to_submit) { IREE_TRACE_SCOPE0("TryToPrepareSemaphores"); - IREE_DVLOG(3) << "TryToPrepareSemaphores"; + *out_ready_to_submit = false; wait_semaphores->clear(); for (const auto& timeline_semaphore : batch_wait_semaphores) { - IREE_DVLOG(3) << "Preparing binary VkSemaphore for timeline semaphore " - << timeline_semaphore.semaphore << ".."; // Query first to progress this timeline semaphore to the furthest. - IREE_ASSIGN_OR_RETURN(auto signaled_value, - timeline_semaphore.semaphore->Query()); + uint64_t signaled_value = 0; + IREE_RETURN_IF_ERROR( + iree_hal_semaphore_query(timeline_semaphore.first, &signaled_value)); // If it's already signaled to a value greater than we require here, // we can just ignore this semaphore now. - if (signaled_value >= timeline_semaphore.value) { - IREE_DVLOG(3) << "..already signaled past; ignoring"; + if (signaled_value >= timeline_semaphore.second) { continue; } - // SerializingCommandQueue only works with EmulatedTimelineSemaphore. - auto* emulated_semaphore = - static_cast<EmulatedTimelineSemaphore*>(timeline_semaphore.semaphore); - // Otherwise try to get a binary semaphore for this time point so that // we can wait on. - VkSemaphore binary_semaphore = emulated_semaphore->GetWaitSemaphore( - timeline_semaphore.value, batch_fence); + // TODO(antiagainst): if this fails we need to cancel. + VkSemaphore wait_semaphore = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_emulated_semaphore_acquire_wait_handle( + timeline_semaphore.first, timeline_semaphore.second, batch_fence, + &wait_semaphore)); + wait_semaphores->push_back(wait_semaphore); - if (binary_semaphore == VK_NULL_HANDLE) { + if (wait_semaphore == VK_NULL_HANDLE) { // We cannot wait on this time point yet: there are no previous semaphores - // submitted to the GPU that can signal a value greater than what's + // submitted to the GPU that can signal a value greater than what's // desired here. // Cancel the wait so others may make progress. - for (VkSemaphore semaphore : *wait_semaphores) { + // TODO(antiagainst): if any of these fail we need to cancel. + for (iree_host_size_t i = 0; i < batch_wait_semaphores.size(); ++i) { + if (!wait_semaphores->at(i)) break; IREE_RETURN_IF_ERROR( - emulated_semaphore->CancelWaitSemaphore(semaphore)); + iree_hal_vulkan_emulated_semaphore_cancel_wait_handle( + batch_wait_semaphores[i].first, wait_semaphores->at(i))); } // This batch cannot be submitted to GPU yet. - return false; + return iree_ok_status(); } - IREE_DVLOG(3) << "..acqiured binary VkSemaphore " << binary_semaphore; - - wait_semaphores->push_back(binary_semaphore); } // We've collected all necessary binary semaphores for each timeline we need // to wait on. Now prepare binary semaphores for signaling. signal_semaphores->clear(); for (const auto& timeline_semaphore : batch_signal_semaphores) { - IREE_DVLOG(3) << "Preparing binary VkSemaphore for timeline semaphore " - << timeline_semaphore.semaphore << ".."; // SerializingCommandQueue only works with EmulatedTimelineSemaphore. - auto* emulated_semaphore = - static_cast<EmulatedTimelineSemaphore*>(timeline_semaphore.semaphore); - - IREE_ASSIGN_OR_RETURN(auto binary_semaphore, - emulated_semaphore->GetSignalSemaphore( - timeline_semaphore.value, batch_fence)); - signal_semaphores->push_back(binary_semaphore); - IREE_DVLOG(3) << "..acqiured binary VkSemaphore " << binary_semaphore; + VkSemaphore signal_semaphore = VK_NULL_HANDLE; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_emulated_semaphore_acquire_signal_handle( + timeline_semaphore.first, timeline_semaphore.second, batch_fence, + &signal_semaphore)); + signal_semaphores->push_back(signal_semaphore); } // Good to submit! - IREE_DVLOG(3) << "Succeeded in preparing binary VkSemaphores for submission"; - return true; + *out_ready_to_submit = true; + return iree_ok_status(); } // Prepares `VkSubmitInfo` to submit the given list of |command_buffers| that // waiting on |wait_semaphores| and signalling |signal_semaphores|. Necessary // structures are allocated from |arena| and the result `VkSubmitInfo` is // written to |submit_info|. -void PrepareSubmitInfo( - const absl::InlinedVector<VkSemaphore, 4>& wait_semaphores, - absl::Span<CommandBuffer* const> command_buffers, - const absl::InlinedVector<VkSemaphore, 4>& signal_semaphores, - VkSubmitInfo* submit_info, Arena* arena) { - IREE_TRACE_SCOPE0("PrepareSubmitInfo"); - +void PrepareSubmitInfo(absl::Span<const VkSemaphore> wait_semaphore_handles, + absl::Span<const VkCommandBuffer> command_buffer_handles, + absl::Span<const VkSemaphore> signal_semaphore_handles, + VkSubmitInfo* submit_info, Arena* arena) { // TODO(benvanik): see if we can go to finer-grained stages. // For example, if this was just queue ownership transfers then we can use // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT. - VkPipelineStageFlags dst_stage_mask = - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; - - auto wait_semaphore_handles = - arena->AllocateSpan<VkSemaphore>(wait_semaphores.size()); auto wait_dst_stage_masks = - arena->AllocateSpan<VkPipelineStageFlags>(wait_semaphores.size()); - for (size_t i = 0, e = wait_semaphores.size(); i < e; ++i) { - wait_semaphore_handles[i] = wait_semaphores[i]; - wait_dst_stage_masks[i] = dst_stage_mask; + arena->AllocateSpan<VkPipelineStageFlags>(wait_semaphore_handles.size()); + for (size_t i = 0, e = wait_semaphore_handles.size(); i < e; ++i) { + wait_dst_stage_masks[i] = + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT; } - auto signal_semaphore_handles = - arena->AllocateSpan<VkSemaphore>(signal_semaphores.size()); - for (size_t i = 0, e = signal_semaphores.size(); i < e; ++i) { - signal_semaphore_handles[i] = signal_semaphores[i]; + // NOTE: this code does some very weird things - the handles we take in as + // args are mutated in-place after this function is called so we can't + // reference them here. If we were going to preserve this code post-Vulkan 1.2 + // then we'd really want to rework all of this to properly use the arena from + // the start instead of all this InlinedVector tomfoolery. + auto wait_semaphores = + arena->AllocateSpan<VkSemaphore>(wait_semaphore_handles.size()); + for (size_t i = 0, e = wait_semaphore_handles.size(); i < e; ++i) { + wait_semaphores[i] = wait_semaphore_handles[i]; } - - auto command_buffer_handles = - arena->AllocateSpan<VkCommandBuffer>(command_buffers.size()); - for (size_t i = 0, e = command_buffers.size(); i < e; ++i) { - const auto& command_buffer = command_buffers[i]; - auto* direct_command_buffer = - static_cast<DirectCommandBuffer*>(command_buffer->impl()); - command_buffer_handles[i] = direct_command_buffer->handle(); + auto command_buffers = + arena->AllocateSpan<VkCommandBuffer>(command_buffer_handles.size()); + for (size_t i = 0, e = command_buffer_handles.size(); i < e; ++i) { + command_buffers[i] = command_buffer_handles[i]; + } + auto signal_semaphores = + arena->AllocateSpan<VkSemaphore>(signal_semaphore_handles.size()); + for (size_t i = 0, e = signal_semaphore_handles.size(); i < e; ++i) { + signal_semaphores[i] = signal_semaphore_handles[i]; } submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; submit_info->pNext = nullptr; submit_info->waitSemaphoreCount = - static_cast<uint32_t>(wait_semaphore_handles.size()); - submit_info->pWaitSemaphores = wait_semaphore_handles.data(); + static_cast<uint32_t>(wait_semaphores.size()); + submit_info->pWaitSemaphores = wait_semaphores.data(); submit_info->pWaitDstStageMask = wait_dst_stage_masks.data(); submit_info->commandBufferCount = - static_cast<uint32_t>(command_buffer_handles.size()); - submit_info->pCommandBuffers = command_buffer_handles.data(); + static_cast<uint32_t>(command_buffers.size()); + submit_info->pCommandBuffers = command_buffers.data(); submit_info->signalSemaphoreCount = - static_cast<uint32_t>(signal_semaphore_handles.size()); - submit_info->pSignalSemaphores = signal_semaphore_handles.data(); + static_cast<uint32_t>(signal_semaphores.size()); + submit_info->pSignalSemaphores = signal_semaphores.data(); } } // namespace SerializingCommandQueue::SerializingCommandQueue( - std::string name, iree_hal_command_category_t supported_categories, - const ref_ptr<VkDeviceHandle>& logical_device, - const ref_ptr<TimePointFencePool>& fence_pool, VkQueue queue) - : CommandQueue(std::move(name), supported_categories), - logical_device_(add_ref(logical_device)), - fence_pool_(add_ref(fence_pool)), - queue_(queue) {} + VkDeviceHandle* logical_device, std::string name, + iree_hal_command_category_t supported_categories, VkQueue queue, + TimePointFencePool* fence_pool) + : CommandQueue(logical_device, std::move(name), supported_categories, + queue), + fence_pool_(fence_pool) {} -SerializingCommandQueue::~SerializingCommandQueue() { - IREE_TRACE_SCOPE0("SerializingCommandQueue::dtor"); - absl::MutexLock lock(&mutex_); - syms()->vkQueueWaitIdle(queue_); -} +SerializingCommandQueue::~SerializingCommandQueue() = default; -Status SerializingCommandQueue::Submit( - absl::Span<const SubmissionBatch> batches) { +iree_status_t SerializingCommandQueue::Submit( + iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { IREE_TRACE_SCOPE0("SerializingCommandQueue::Submit"); - IREE_DVLOG(2) << "SerializingCommandQueue::Submit"; - absl::MutexLock lock(&mutex_); - for (size_t i = 0; i < batches.size(); ++i) { + IntrusiveList<std::unique_ptr<FencedSubmission>> new_submissions; + for (iree_host_size_t i = 0; i < batch_count; ++i) { + const iree_hal_submission_batch_t* batch = &batches[i]; + // Grab a fence for this submission first. This will be used to check the // progress of emulated timeline semaphores later. - IREE_ASSIGN_OR_RETURN(auto fence, fence_pool_->Acquire()); auto submission = std::make_unique<FencedSubmission>(); - submission->batch = PendingBatch{ - {batches[i].wait_semaphores.begin(), batches[i].wait_semaphores.end()}, - {batches[i].command_buffers.begin(), batches[i].command_buffers.end()}, - {batches[i].signal_semaphores.begin(), - batches[i].signal_semaphores.end()}}; - submission->fence = std::move(fence); - deferred_submissions_.push_back(std::move(submission)); + IREE_ASSIGN_OR_RETURN(submission->fence, fence_pool_->Acquire()); + + submission->wait_semaphores.resize(batch->wait_semaphores.count); + for (iree_host_size_t j = 0; j < batch->wait_semaphores.count; ++j) { + submission->wait_semaphores[j] = { + batch->wait_semaphores.semaphores[j], + batch->wait_semaphores.payload_values[j]}; + } + + submission->command_buffers.resize(batch->command_buffer_count); + for (iree_host_size_t j = 0; j < batch->command_buffer_count; ++j) { + submission->command_buffers[j] = + iree_hal_vulkan_direct_command_buffer_handle( + batch->command_buffers[j]); + } + + submission->signal_semaphores.resize(batch->signal_semaphores.count); + for (iree_host_size_t j = 0; j < batch->signal_semaphores.count; ++j) { + submission->signal_semaphores[j] = { + batch->signal_semaphores.semaphores[j], + batch->signal_semaphores.payload_values[j]}; + } + + new_submissions.push_back(std::move(submission)); } - return ProcessDeferredSubmissions().status(); + iree_slim_mutex_lock(&queue_mutex_); + deferred_submissions_.merge_from(&new_submissions); + iree_status_t status = ProcessDeferredSubmissions(); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } -StatusOr<bool> SerializingCommandQueue::ProcessDeferredSubmissions() { +iree_status_t SerializingCommandQueue::ProcessDeferredSubmissions( + bool* out_work_submitted) { IREE_TRACE_SCOPE0("SerializingCommandQueue::ProcessDeferredSubmissions"); - IREE_DVLOG(2) << "SerializingCommandQueue::ProcessDeferredSubmissions"; - - // Prepare `VkSubmitInfo`s for all submissions we are able to submit. - - // Note that we must keep all arrays referenced alive until submission - // completes and since there are a bunch of them we use an arena. - Arena arena(4 * 1024); - - absl::InlinedVector<VkSubmitInfo, 4> submit_infos; - absl::InlinedVector<VkFence, 4> submit_fences; - - absl::InlinedVector<VkSemaphore, 4> wait_semaphores; - absl::InlinedVector<VkSemaphore, 4> signal_semaphores; - - // A list of submissions that still needs to be deferred. - IntrusiveList<std::unique_ptr<FencedSubmission>> remaining_submissions; + if (out_work_submitted) *out_work_submitted = false; // We need to return all remaining submissions back to the queue to avoid // dropping work. + IntrusiveList<std::unique_ptr<FencedSubmission>> remaining_submissions; auto submission_cleanup = MakeCleanup([this, &remaining_submissions]() { // Disable thread-safety-analysis as it doesn't understand this lambda. -// - This entire function is ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) +// - This entire function is ABSL_EXCLUSIVE_LOCKS_REQUIRED(queue_mutex_) // - This Cleanup object is destroyed when it drops out of scope // - The mutex is always held when executing this function #ifdef __clang__ @@ -242,122 +236,125 @@ deferred_submissions_.push_back( remaining_submissions.take(remaining_submissions.front())); } - - IREE_DVLOG(2) << deferred_submissions_.size() - << " deferred submissions still remaining"; #ifdef __clang__ #pragma clang diagnostic pop #endif }); + Arena arena(4 * 1024); + absl::InlinedVector<VkSubmitInfo, 4> submit_infos; + absl::InlinedVector<VkFence, 4> submit_fences; while (!deferred_submissions_.empty()) { - IREE_DVLOG(2) << "Looking at deferred submission with timepoint fence " - << deferred_submissions_.front()->fence.get() << ".."; - - wait_semaphores.clear(); - signal_semaphores.clear(); - FencedSubmission* submission = deferred_submissions_.front(); - const PendingBatch& batch = submission->batch; ref_ptr<TimePointFence>& fence = submission->fence; - IREE_ASSIGN_OR_RETURN( - bool ready_to_submit, - TryToPrepareSemaphores(batch.wait_semaphores, batch.signal_semaphores, - fence, &wait_semaphores, &signal_semaphores)); - + absl::InlinedVector<VkSemaphore, 4> wait_semaphores; + absl::InlinedVector<VkSemaphore, 4> signal_semaphores; + bool ready_to_submit = false; + IREE_RETURN_IF_ERROR(TryToPrepareSemaphores( + submission->wait_semaphores, submission->signal_semaphores, fence, + &wait_semaphores, &signal_semaphores, &ready_to_submit)); if (ready_to_submit) { submit_infos.emplace_back(); - PrepareSubmitInfo(wait_semaphores, batch.command_buffers, + PrepareSubmitInfo(wait_semaphores, submission->command_buffers, signal_semaphores, &submit_infos.back(), &arena); submit_fences.push_back(fence->value()); pending_fences_.emplace_back(std::move(fence)); deferred_submissions_.pop_front(); - IREE_DVLOG(2) << "..ready to submit"; } else { // We need to defer the submission until later. remaining_submissions.push_back(deferred_submissions_.take(submission)); - IREE_DVLOG(2) << "..not ready to submit"; } } - - if (submit_infos.empty()) return false; - - auto infos = arena.AllocateSpan<VkSubmitInfo>(submit_infos.size()); - for (size_t i = 0, e = submit_infos.size(); i < e; ++i) { - infos[i] = submit_infos[i]; + if (submit_infos.empty()) { + if (out_work_submitted) *out_work_submitted = false; + return iree_ok_status(); } // Note: We might be able to batch the submission but it involves non-trivial // fence handling. We can handle that if really needed. for (size_t i = 0, e = submit_infos.size(); i < e; ++i) { - VK_RETURN_IF_ERROR(syms()->vkQueueSubmit( - queue_, /*submitCount=*/1, &submit_infos[i], submit_fences[i])); + VK_RETURN_IF_ERROR( + syms()->vkQueueSubmit(queue_, /*submitCount=*/1, &submit_infos[i], + submit_fences[i]), + "vkQueueSubmit"); } - IREE_DVLOG(2) << "Released " << submit_infos.size() - << " deferred submissions"; - - return true; + if (out_work_submitted) *out_work_submitted = true; + return iree_ok_status(); } -Status SerializingCommandQueue::WaitIdle(Time deadline_ns) { - absl::MutexLock lock(&mutex_); - IREE_DVLOG(2) << "SerializingCommandQueue::WaitIdle"; +iree_status_t SerializingCommandQueue::WaitIdle(iree_time_t deadline_ns) { + iree_status_t status = iree_ok_status(); - if (deadline_ns == InfiniteFuture()) { + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { IREE_TRACE_SCOPE0("SerializingCommandQueue::WaitIdle#vkQueueWaitIdle"); // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it // requires fewer calls into the driver). + iree_slim_mutex_lock(&queue_mutex_); + // Complete all pending work on the queue. - VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); + status = + VK_RESULT_TO_STATUS(syms()->vkQueueWaitIdle(queue_), "vkQueueWaitIdle"); + if (!iree_status_is_ok(status)) { + iree_slim_mutex_unlock(&queue_mutex_); + return status; + } pending_fences_.clear(); // Submit and complete all deferred work. while (!deferred_submissions_.empty()) { - IREE_ASSIGN_OR_RETURN(bool work_submitted, ProcessDeferredSubmissions()); + bool work_submitted = false; + status = ProcessDeferredSubmissions(&work_submitted); + if (!iree_status_is_ok(status)) break; if (work_submitted) { - VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_)); + status = VK_RESULT_TO_STATUS(syms()->vkQueueWaitIdle(queue_), + "vkQueueWaitIdle"); + if (!iree_status_is_ok(status)) break; pending_fences_.clear(); } } - return OkStatus(); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } IREE_TRACE_SCOPE0("SerializingCommandQueue::WaitIdle#Fence"); // Keep trying to submit more workload to the GPU until reaching the deadline. + iree_slim_mutex_lock(&queue_mutex_); do { - IREE_RETURN_IF_ERROR(ProcessDeferredSubmissions().status()); + status = ProcessDeferredSubmissions(); + bool has_deferred_submissions = !deferred_submissions_.empty(); + absl::InlinedVector<VkFence, 8> fence_handles(pending_fences_.size()); + for (size_t i = 0; i < pending_fences_.size(); ++i) { + fence_handles[i] = pending_fences_[i]->value(); + } + if (!iree_status_is_ok(status)) { + break; // unable to process submissions + } else if (!has_deferred_submissions && fence_handles.empty()) { + break; // no more work - idle achieved + } uint64_t timeout_ns; - if (deadline_ns == InfiniteFuture()) { + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { timeout_ns = UINT64_MAX; - } else if (deadline_ns == InfinitePast()) { + } else if (deadline_ns == IREE_TIME_INFINITE_PAST) { timeout_ns = 0; } else { // Convert to relative time in nanoseconds. - // The implementation may not wait with this granularity (like, by - // 10000x). - Duration relative_ns = deadline_ns - Now(); - if (relative_ns < ZeroDuration()) { - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for idle"; + // The implementation may not wait with this granularity (like by 10000x). + iree_time_t now_ns = iree_time_now(); + if (deadline_ns < now_ns) { + return iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); } - timeout_ns = static_cast<uint64_t>(relative_ns); + timeout_ns = (uint64_t)(deadline_ns - now_ns); } - - if (pending_fences_.empty()) continue; - - std::vector<VkFence> fences; - fences.reserve(pending_fences_.size()); - for (const auto& fence : pending_fences_) fences.push_back(fence->value()); - VkResult result = syms()->vkWaitForFences( - *logical_device_, static_cast<uint32_t>(fences.size()), fences.data(), + *logical_device_, static_cast<uint32_t>(fence_handles.size()), + fence_handles.data(), /*waitAll=*/VK_TRUE, timeout_ns); switch (result) { @@ -365,56 +362,62 @@ pending_fences_.clear(); break; case VK_TIMEOUT: - return DeadlineExceededErrorBuilder(IREE_LOC) - << "Deadline exceeded waiting for idle"; + status = iree_status_from_code(IREE_STATUS_DEADLINE_EXCEEDED); + break; default: - return VkResultToStatus(result, IREE_LOC); + status = VK_RESULT_TO_STATUS(result, "vkWaitForFences"); + break; } // As long as there is submitted or deferred work still pending. - } while (!pending_fences_.empty() || !deferred_submissions_.empty()); - - return OkStatus(); + } while (iree_status_is_ok(status)); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } -Status SerializingCommandQueue::AdvanceQueueSubmission() { - absl::MutexLock lock(&mutex_); +iree_status_t SerializingCommandQueue::AdvanceQueueSubmission() { // The returned value just indicates whether there were newly ready // submissions gotten submitted to the GPU. Other callers might be // interested in that information but for this API we just want to advance // queue submisison if possible. So we ignore it here. - IREE_ASSIGN_OR_RETURN(std::ignore, ProcessDeferredSubmissions()); - return OkStatus(); + iree_slim_mutex_lock(&queue_mutex_); + iree_status_t status = ProcessDeferredSubmissions(); + iree_slim_mutex_unlock(&queue_mutex_); + return status; } void SerializingCommandQueue::AbortQueueSubmission() { - absl::MutexLock lock(&mutex_); + iree_slim_mutex_lock(&queue_mutex_); // We have fences in deferred_submissions_ but they are not submitted to GPU // yet so we don't need to reset. deferred_submissions_.clear(); - std::vector<VkFence> fences; - fences.reserve(pending_fences_.size()); - for (const auto& fence : pending_fences_) fences.push_back(fence->value()); + absl::InlinedVector<VkFence, 8> fence_handles(pending_fences_.size()); + for (size_t i = 0; i < pending_fences_.size(); ++i) { + fence_handles[i] = pending_fences_[i]->value(); + } syms()->vkWaitForFences(*logical_device_, - static_cast<uint32_t>(fences.size()), fences.data(), + static_cast<uint32_t>(fence_handles.size()), + fence_handles.data(), /*waitAll=*/VK_TRUE, /*timeout=*/UINT64_MAX); + // Clear the list. Fences will be automatically returned back to the queue // after refcount reaches 0. pending_fences_.clear(); + + iree_slim_mutex_unlock(&queue_mutex_); } void SerializingCommandQueue::SignalFences(absl::Span<VkFence> fences) { - auto span_contains = [&fences](VkFence fence) { + const auto span_contains = [fences](VkFence fence) { for (VkFence f : fences) { if (f == fence) return true; } return false; }; - absl::MutexLock lock(&mutex_); - + iree_slim_mutex_lock(&queue_mutex_); auto it = pending_fences_.begin(); while (it != pending_fences_.end()) { if (span_contains((*it)->value())) { @@ -423,6 +426,7 @@ ++it; } } + iree_slim_mutex_unlock(&queue_mutex_); } } // namespace vulkan
diff --git a/iree/hal/vulkan/serializing_command_queue.h b/iree/hal/vulkan/serializing_command_queue.h index 0eaf663..6b413bd 100644 --- a/iree/hal/vulkan/serializing_command_queue.h +++ b/iree/hal/vulkan/serializing_command_queue.h
@@ -24,13 +24,11 @@ #include "absl/base/thread_annotations.h" #include "absl/container/inlined_vector.h" -#include "absl/synchronization/mutex.h" #include "iree/base/intrusive_list.h" #include "iree/base/ref_ptr.h" #include "iree/base/status.h" -#include "iree/base/time.h" -#include "iree/hal/cc/command_buffer.h" -#include "iree/hal/cc/command_queue.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/command_queue.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/timepoint_util.h" @@ -39,6 +37,8 @@ namespace hal { namespace vulkan { +using SemaphoreValue = std::pair<iree_hal_semaphore_t*, uint64_t>; + // A command queue that potentially defers and serializes command buffer // submission to the GPU. // @@ -52,23 +52,22 @@ // the GPU. class SerializingCommandQueue final : public CommandQueue { public: - SerializingCommandQueue(std::string name, + SerializingCommandQueue(VkDeviceHandle* logical_device, std::string name, iree_hal_command_category_t supported_categories, - const ref_ptr<VkDeviceHandle>& logical_device, - const ref_ptr<TimePointFencePool>& fence_pool, - VkQueue queue); + VkQueue queue, TimePointFencePool* fence_pool); ~SerializingCommandQueue() override; const ref_ptr<DynamicSymbols>& syms() const { return logical_device_->syms(); } - Status Submit(absl::Span<const SubmissionBatch> batches) override; + iree_status_t Submit(iree_host_size_t batch_count, + const iree_hal_submission_batch_t* batches) override; - Status WaitIdle(Time deadline_ns) override; + iree_status_t WaitIdle(iree_time_t deadline_ns) override; // Releases all deferred submissions ready to submit to the GPU. - Status AdvanceQueueSubmission(); + iree_status_t AdvanceQueueSubmission(); // Aborts all deferred submissions and waits for submitted work to complete. void AbortQueueSubmission(); @@ -77,37 +76,26 @@ void SignalFences(absl::Span<VkFence> fences); private: - struct PendingBatch { - absl::InlinedVector<SemaphoreValue, 4> wait_semaphores; - absl::InlinedVector<CommandBuffer*, 4> command_buffers; - absl::InlinedVector<SemaphoreValue, 4> signal_semaphores; - }; // A submission batch together with the fence to singal its status. struct FencedSubmission : public IntrusiveLinkBase<void> { - PendingBatch batch; + absl::InlinedVector<SemaphoreValue, 4> wait_semaphores; + absl::InlinedVector<VkCommandBuffer, 4> command_buffers; + absl::InlinedVector<SemaphoreValue, 4> signal_semaphores; ref_ptr<TimePointFence> fence; }; // Processes deferred submissions in this queue and returns whether there are // new workload submitted to the GPU if no errors happen. - StatusOr<bool> ProcessDeferredSubmissions() - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + iree_status_t ProcessDeferredSubmissions(bool* out_work_submitted = NULL); - ref_ptr<VkDeviceHandle> logical_device_; - - ref_ptr<TimePointFencePool> fence_pool_; - - mutable absl::Mutex mutex_; + TimePointFencePool* fence_pool_; // A list of fences that are submitted to GPU. absl::InlinedVector<ref_ptr<TimePointFence>, 4> pending_fences_ - ABSL_GUARDED_BY(mutex_); + IREE_GUARDED_BY(mutex_); // A list of deferred submissions that haven't been submitted to GPU. IntrusiveList<std::unique_ptr<FencedSubmission>> deferred_submissions_ - ABSL_GUARDED_BY(mutex_); - - // VkQueue needs to be externally synchronized. - VkQueue queue_ ABSL_GUARDED_BY(mutex_); + IREE_GUARDED_BY(mutex_); }; } // namespace vulkan
diff --git a/iree/hal/vulkan/status_util.cc b/iree/hal/vulkan/status_util.c similarity index 72% rename from iree/hal/vulkan/status_util.cc rename to iree/hal/vulkan/status_util.c index 8231db1..6117caa 100644 --- a/iree/hal/vulkan/status_util.cc +++ b/iree/hal/vulkan/status_util.c
@@ -14,49 +14,48 @@ #include "iree/hal/vulkan/status_util.h" -namespace iree { -namespace hal { -namespace vulkan { - -Status VkResultToStatus(VkResult result, SourceLocation loc) { +iree_status_t iree_hal_vulkan_result_to_status(VkResult result, + const char* file, + uint32_t line) { switch (result) { // Success codes. case VK_SUCCESS: // Command successfully completed. - return OkStatus(); + return iree_ok_status(); case VK_NOT_READY: // A fence or query has not yet completed. - return OkStatus(); + return iree_ok_status(); case VK_TIMEOUT: // A wait operation has not completed in the specified time. - return OkStatus(); + return iree_ok_status(); case VK_EVENT_SET: // An event is signaled. - return OkStatus(); + return iree_ok_status(); case VK_EVENT_RESET: // An event is unsignaled. - return OkStatus(); + return iree_ok_status(); case VK_INCOMPLETE: // A return array was too small for the result. - return OkStatus(); + return iree_ok_status(); case VK_SUBOPTIMAL_KHR: // A swapchain no longer matches the surface properties exactly, but can // still be used to present to the surface successfully. - return OkStatus(); + return iree_ok_status(); // Error codes. case VK_ERROR_OUT_OF_HOST_MEMORY: // A host memory allocation has failed. - return ResourceExhaustedErrorBuilder(loc) - << "VK_ERROR_OUT_OF_HOST_MEMORY"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_OUT_OF_HOST_MEMORY"); case VK_ERROR_OUT_OF_DEVICE_MEMORY: // A device memory allocation has failed. - return ResourceExhaustedErrorBuilder(loc) - << "VK_ERROR_OUT_OF_DEVICE_MEMORY"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_OUT_OF_DEVICE_MEMORY"); case VK_ERROR_INITIALIZATION_FAILED: // Initialization of an object could not be completed for // implementation-specific reasons. - return InternalErrorBuilder(loc) << "VK_ERROR_INITIALIZATION_FAILED"; + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "VK_ERROR_INITIALIZATION_FAILED"); case VK_ERROR_DEVICE_LOST: // The logical or physical device has been lost. // @@ -125,77 +124,87 @@ // command buffer is in the pending state, or whether resources are // considered in-use by the device, a return value of // VK_ERROR_DEVICE_LOST is equivalent to VK_SUCCESS. - return InternalErrorBuilder(loc) << "VK_ERROR_DEVICE_LOST"; + return iree_make_status(IREE_STATUS_INTERNAL, "VK_ERROR_DEVICE_LOST"); case VK_ERROR_MEMORY_MAP_FAILED: // Mapping of a memory object has failed. - return InternalErrorBuilder(loc) << "VK_ERROR_MEMORY_MAP_FAILED"; + return iree_make_status(IREE_STATUS_INTERNAL, + "VK_ERROR_MEMORY_MAP_FAILED"); case VK_ERROR_LAYER_NOT_PRESENT: // A requested layer is not present or could not be loaded. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_LAYER_NOT_PRESENT"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_LAYER_NOT_PRESENT"); case VK_ERROR_EXTENSION_NOT_PRESENT: // A requested extension is not supported. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_EXTENSION_NOT_PRESENT"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_EXTENSION_NOT_PRESENT"); case VK_ERROR_FEATURE_NOT_PRESENT: // A requested feature is not supported. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_FEATURE_NOT_PRESENT"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_FEATURE_NOT_PRESENT"); case VK_ERROR_INCOMPATIBLE_DRIVER: // The requested version of Vulkan is not supported by the driver or is // otherwise incompatible for implementation-specific reasons. - return FailedPreconditionErrorBuilder(loc) - << "VK_ERROR_INCOMPATIBLE_DRIVER"; + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "VK_ERROR_INCOMPATIBLE_DRIVER"); case VK_ERROR_TOO_MANY_OBJECTS: // Too many objects of the type have already been created. - return ResourceExhaustedErrorBuilder(loc) << "VK_ERROR_TOO_MANY_OBJECTS"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_TOO_MANY_OBJECTS"); case VK_ERROR_FORMAT_NOT_SUPPORTED: // A requested format is not supported on this device. - return UnimplementedErrorBuilder(loc) << "VK_ERROR_FORMAT_NOT_SUPPORTED"; + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "VK_ERROR_FORMAT_NOT_SUPPORTED"); case VK_ERROR_FRAGMENTED_POOL: - // A pool allocation has failed due to fragmentation of the pool’s memory. - // This must only be returned if no attempt to allocate host or device - // memory was made to accommodate the new allocation. - return ResourceExhaustedErrorBuilder(loc) << "VK_ERROR_FRAGMENTED_POOL"; + // A pool allocation has failed due to fragmentation of the pool’s + // memory. This must only be returned if no attempt to allocate host + // or device memory was made to accommodate the new allocation. + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_FRAGMENTED_POOL"); case VK_ERROR_OUT_OF_POOL_MEMORY: // A pool memory allocation has failed. This must only be returned if no // attempt to allocate host or device memory was made to accommodate the // new allocation. If the failure was definitely due to fragmentation of // the pool, VK_ERROR_FRAGMENTED_POOL should be returned instead. - return ResourceExhaustedErrorBuilder(loc) - << "VK_ERROR_OUT_OF_POOL_MEMORY"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_OUT_OF_POOL_MEMORY"); case VK_ERROR_INVALID_EXTERNAL_HANDLE: // An external handle is not a valid handle of the specified type. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_INVALID_EXTERNAL_HANDLE"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INVALID_EXTERNAL_HANDLE"); case VK_ERROR_SURFACE_LOST_KHR: // A surface is no longer available. - return UnavailableErrorBuilder(loc) << "VK_ERROR_SURFACE_LOST_KHR"; + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "VK_ERROR_SURFACE_LOST_KHR"); case VK_ERROR_NATIVE_WINDOW_IN_USE_KHR: // The requested window is already in use by Vulkan or another API in a // manner which prevents it from being used again. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_NATIVE_WINDOW_IN_USE_KHR"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_NATIVE_WINDOW_IN_USE_KHR"); case VK_ERROR_OUT_OF_DATE_KHR: // A surface has changed in such a way that it is no longer compatible // with the swapchain, and further presentation requests using the // swapchain will fail. Applications must query the new surface properties // and recreate their swapchain if they wish to continue presenting to the // surface. - return FailedPreconditionErrorBuilder(loc) << "VK_ERROR_OUT_OF_DATE_KHR"; + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "VK_ERROR_OUT_OF_DATE_KHR"); case VK_ERROR_INCOMPATIBLE_DISPLAY_KHR: // The display used by a swapchain does not use the same presentable image // layout, or is incompatible in a way that prevents sharing an image. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_INCOMPATIBLE_DISPLAY_KHR"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INCOMPATIBLE_DISPLAY_KHR"); case VK_ERROR_VALIDATION_FAILED_EXT: // Validation layer testing failed. It is not expected that an // application would see this this error code during normal use of the // validation layers. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_VALIDATION_FAILED_EXT"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_VALIDATION_FAILED_EXT"); case VK_ERROR_INVALID_SHADER_NV: // One or more shaders failed to compile or link. More details are // reported back to the application when the validation layer is enabled // using the extension VK_EXT_debug_report. - return InvalidArgumentErrorBuilder(loc) << "VK_ERROR_INVALID_SHADER_NV"; + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INVALID_SHADER_NV"); case VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT: // When creating an image with // VkImageDrmFormatModifierExplicitCreateInfoEXT, it is the application’s @@ -207,33 +216,33 @@ // outside the scope of Vulkan, and therefore not described by Valid Usage // requirements). If this validation fails, then vkCreateImage returns // VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT. - return InvalidArgumentErrorBuilder(loc) - << "VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT"; + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "VK_ERROR_INVALID_DRM_FORMAT_MODIFIER_PLANE_LAYOUT_EXT"); case VK_ERROR_FRAGMENTATION_EXT: // A descriptor pool creation has failed due to fragmentation. - return ResourceExhaustedErrorBuilder(loc) << "VK_ERROR_FRAGMENTATION_EXT"; + return iree_make_status(IREE_STATUS_RESOURCE_EXHAUSTED, + "VK_ERROR_FRAGMENTATION_EXT"); case VK_ERROR_NOT_PERMITTED_EXT: // When creating a queue, the caller does not have sufficient privileges // to request to acquire a priority above the default priority // (VK_QUEUE_GLOBAL_PRIORITY_MEDIUM_EXT). - return PermissionDeniedErrorBuilder(loc) << "VK_ERROR_NOT_PERMITTED_EXT"; + return iree_make_status(IREE_STATUS_PERMISSION_DENIED, + "VK_ERROR_NOT_PERMITTED_EXT"); case VK_ERROR_INVALID_DEVICE_ADDRESS_EXT: // A buffer creation failed because the requested address is not // available. - return OutOfRangeErrorBuilder(loc) - << "VK_ERROR_INVALID_DEVICE_ADDRESS_EXT"; + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "VK_ERROR_INVALID_DEVICE_ADDRESS_EXT"); case VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT: // An operation on a swapchain created with // VK_FULL_SCREEN_EXCLUSIVE_APPLICATION_CONTROLLED_EXT failed as it did // not have exlusive full-screen access. This may occur due to // implementation-dependent reasons, outside of the application’s control. - return UnavailableErrorBuilder(loc) - << "VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT"; + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT"); default: - return UnknownErrorBuilder(loc) << result; + return iree_make_status(IREE_STATUS_UNKNOWN, "VkResult=%u", + (uint32_t)result); } } - -} // namespace vulkan -} // namespace hal -} // namespace iree
diff --git a/iree/hal/vulkan/status_util.h b/iree/hal/vulkan/status_util.h index d1497fd..f225ec5 100644 --- a/iree/hal/vulkan/status_util.h +++ b/iree/hal/vulkan/status_util.h
@@ -19,19 +19,27 @@ #include "iree/hal/vulkan/vulkan_headers.h" // clang-format on -#include "iree/base/status.h" +#include "iree/base/api.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Converts a VkResult to an iree_status_t. +// +// Usage: +// iree_status_t status = VK_RESULT_TO_STATUS(vkDoThing(...)); +#define VK_RESULT_TO_STATUS(expr, ...) \ + iree_hal_vulkan_result_to_status((expr), __FILE__, __LINE__) // IREE_RETURN_IF_ERROR but implicitly converts the VkResult return value to // a Status. // // Usage: -// VK_RETURN_IF_ERROR(vkDoThing(...)); -#define VK_RETURN_IF_ERROR(expr) \ - IREE_RETURN_IF_ERROR(::iree::hal::vulkan::VkResultToStatus(expr, IREE_LOC)) +// VK_RETURN_IF_ERROR(vkDoThing(...), "message"); +#define VK_RETURN_IF_ERROR(expr, ...) \ + IREE_RETURN_IF_ERROR( \ + iree_hal_vulkan_result_to_status(expr, __FILE__, __LINE__), __VA_ARGS__) // IREE_CHECK_OK but implicitly converts the VkResults return value to a // Status and checks that it is OkStatus. @@ -39,7 +47,7 @@ // Usage: // VK_CHECK_OK(vkDoThing(...)); #define VK_CHECK_OK(expr) \ - IREE_CHECK_OK(::iree::hal::vulkan::VkResultToStatus(expr, IREE_LOC)) + IREE_CHECK_OK(iree_hal_vulkan_result_to_status(expr, __FILE__, __LINE__)) // Converts a VkResult to a Status object. // @@ -81,10 +89,11 @@ // - VK_ERROR_NOT_PERMITTED_EXT -> PermissionDeniedError("VK...") // - VK_ERROR_INVALID_DEVICE_ADDRESS_EXT -> OutOfRangeError("VK...") // - VK_ERROR_FULL_SCREEN_EXCLUSIVE_MODE_LOST_EXT -> InternalError("VK...") -Status VkResultToStatus(VkResult result, SourceLocation loc); +iree_status_t iree_hal_vulkan_result_to_status(VkResult result, + const char* file, uint32_t line); -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_STATUS_UTIL_H_
diff --git a/iree/hal/vulkan/timepoint_util.cc b/iree/hal/vulkan/timepoint_util.cc index c14ea26..a8080d0 100644 --- a/iree/hal/vulkan/timepoint_util.cc +++ b/iree/hal/vulkan/timepoint_util.cc
@@ -17,7 +17,6 @@ #include <memory> #include "absl/synchronization/mutex.h" -#include "iree/base/time.h" #include "iree/base/tracing.h" #include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/status_util.h" @@ -47,13 +46,13 @@ } // static -StatusOr<ref_ptr<TimePointFencePool>> TimePointFencePool::Create( - ref_ptr<VkDeviceHandle> logical_device) { +iree_status_t TimePointFencePool::Create(VkDeviceHandle* logical_device, + TimePointFencePool** out_pool) { IREE_TRACE_SCOPE0("TimePointFencePool::Create"); - ref_ptr<TimePointFencePool> pool( - new TimePointFencePool(std::move(logical_device))); + ref_ptr<TimePointFencePool> pool(new TimePointFencePool(logical_device)); IREE_RETURN_IF_ERROR(pool->PreallocateFences()); - return pool; + *out_pool = pool.release(); + return iree_ok_status(); } TimePointFencePool::~TimePointFencePool() { @@ -100,8 +99,8 @@ free_fences_.push_back(std::unique_ptr<TimePointFence>(fence)); } -TimePointFencePool::TimePointFencePool(ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} +TimePointFencePool::TimePointFencePool(VkDeviceHandle* logical_device) + : logical_device_(logical_device) {} const ref_ptr<DynamicSymbols>& TimePointFencePool::syms() const { return logical_device_->syms(); @@ -120,9 +119,10 @@ absl::MutexLock lock(&mutex_); for (int i = 0; i < fences.size(); ++i) { VkFence fence = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateFence(*logical_device_, &create_info, - logical_device_->allocator(), - &fence)); + VK_RETURN_IF_ERROR( + syms()->vkCreateFence(*logical_device_, &create_info, + logical_device_->allocator(), &fence), + "vkCreateFence"); fences[i] = std::make_unique<TimePointFence>(this, fence); } } @@ -142,13 +142,14 @@ } // static -StatusOr<ref_ptr<TimePointSemaphorePool>> TimePointSemaphorePool::Create( - ref_ptr<VkDeviceHandle> logical_device) { +iree_status_t TimePointSemaphorePool::Create( + VkDeviceHandle* logical_device, TimePointSemaphorePool** out_pool) { IREE_TRACE_SCOPE0("TimePointSemaphorePool::Create"); ref_ptr<TimePointSemaphorePool> pool( - new TimePointSemaphorePool(std::move(logical_device))); + new TimePointSemaphorePool(logical_device)); IREE_RETURN_IF_ERROR(pool->PreallocateSemaphores()); - return pool; + *out_pool = pool.release(); + return iree_ok_status(); } TimePointSemaphorePool::~TimePointSemaphorePool() { @@ -206,9 +207,8 @@ free_semaphores_.merge_from(semaphores); } -TimePointSemaphorePool::TimePointSemaphorePool( - ref_ptr<VkDeviceHandle> logical_device) - : logical_device_(std::move(logical_device)) {} +TimePointSemaphorePool::TimePointSemaphorePool(VkDeviceHandle* logical_device) + : logical_device_(logical_device) {} const ref_ptr<DynamicSymbols>& TimePointSemaphorePool::syms() const { return logical_device_->syms(); @@ -227,7 +227,8 @@ auto* semaphore = &storage_[i]; VK_RETURN_IF_ERROR(syms()->vkCreateSemaphore(*logical_device_, &create_info, logical_device_->allocator(), - &semaphore->semaphore)); + &semaphore->semaphore), + "vkCreateSemaphore"); free_semaphores_.push_back(semaphore); }
diff --git a/iree/hal/vulkan/timepoint_util.h b/iree/hal/vulkan/timepoint_util.h index 7d10129..1f127e2 100644 --- a/iree/hal/vulkan/timepoint_util.h +++ b/iree/hal/vulkan/timepoint_util.h
@@ -125,8 +125,8 @@ static constexpr int kMaxInFlightFenceCount = 64; // Creates a new pool and pre-allocates `kMaxInFlightFenceCount` fences. - static StatusOr<ref_ptr<TimePointFencePool>> Create( - ref_ptr<VkDeviceHandle> logical_device); + static iree_status_t Create(VkDeviceHandle* logical_device, + TimePointFencePool** out_pool); ~TimePointFencePool(); @@ -143,18 +143,16 @@ // not be in flight on GPU. void ReleaseResolved(TimePointFence* fence); - const ref_ptr<VkDeviceHandle>& logical_device() const { - return logical_device_; - } + VkDeviceHandle* logical_device() const { return logical_device_; } private: - explicit TimePointFencePool(ref_ptr<VkDeviceHandle> logical_device); + explicit TimePointFencePool(VkDeviceHandle* logical_device); const ref_ptr<DynamicSymbols>& syms() const; Status PreallocateFences() ABSL_LOCKS_EXCLUDED(mutex_); - ref_ptr<VkDeviceHandle> logical_device_; + VkDeviceHandle* logical_device_; absl::Mutex mutex_; @@ -171,8 +169,8 @@ // Creates a new pool and pre-allocates `kMaxInFlightSemaphoreCount` binary // semaphores. - static StatusOr<ref_ptr<TimePointSemaphorePool>> Create( - ref_ptr<VkDeviceHandle> logical_device); + static iree_status_t Create(VkDeviceHandle* logical_device, + TimePointSemaphorePool** out_pool); ~TimePointSemaphorePool(); @@ -195,13 +193,13 @@ void ReleaseUnresolved(IntrusiveList<TimePointSemaphore>* semaphores); private: - explicit TimePointSemaphorePool(ref_ptr<VkDeviceHandle> logical_device); + explicit TimePointSemaphorePool(VkDeviceHandle* logical_device); const ref_ptr<DynamicSymbols>& syms() const; Status PreallocateSemaphores() ABSL_LOCKS_EXCLUDED(mutex_); - ref_ptr<VkDeviceHandle> logical_device_; + VkDeviceHandle* logical_device_; absl::Mutex mutex_;
diff --git a/iree/hal/vulkan/vma_allocator.cc b/iree/hal/vulkan/vma_allocator.cc index 5b1ed56..635a94d 100644 --- a/iree/hal/vulkan/vma_allocator.cc +++ b/iree/hal/vulkan/vma_allocator.cc
@@ -14,26 +14,39 @@ #include "iree/hal/vulkan/vma_allocator.h" -#include "absl/memory/memory.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" -#include "iree/hal/cc/buffer.h" #include "iree/hal/vulkan/status_util.h" #include "iree/hal/vulkan/vma_buffer.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; -// static -StatusOr<std::unique_ptr<VmaAllocator>> VmaAllocator::Create( - VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, VkInstance instance, - Options options) { - IREE_TRACE_SCOPE0("VmaAllocator::Create"); +typedef struct iree_hal_vulkan_vma_allocator_s { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + VmaAllocator vma; +} iree_hal_vulkan_vma_allocator_t; + +extern const iree_hal_allocator_vtable_t iree_hal_vulkan_vma_allocator_vtable; + +static iree_hal_vulkan_vma_allocator_t* iree_hal_vulkan_vma_allocator_cast( + iree_hal_allocator_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_vma_allocator_vtable); + return (iree_hal_vulkan_vma_allocator_t*)base_value; +} + +iree_status_t iree_hal_vulkan_vma_allocator_create( + VkInstance instance, VkPhysicalDevice physical_device, + VkDeviceHandle* logical_device, VmaRecordSettings record_settings, + iree_hal_allocator_t** out_allocator) { + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(physical_device); + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_allocator); + IREE_TRACE_ZONE_BEGIN(z0); const auto& syms = logical_device->syms(); VmaVulkanFunctions vulkan_fns; + memset(&vulkan_fns, 0, sizeof(vulkan_fns)); vulkan_fns.vkGetPhysicalDeviceProperties = syms->vkGetPhysicalDeviceProperties; vulkan_fns.vkGetPhysicalDeviceMemoryProperties = @@ -56,77 +69,110 @@ vulkan_fns.vkDestroyImage = syms->vkDestroyImage; vulkan_fns.vkCmdCopyBuffer = syms->vkCmdCopyBuffer; - VmaRecordSettings record_settings; -#if VMA_RECORDING_ENABLED - record_settings.flags = - options.recording_flush_after_call ? VMA_RECORD_FLUSH_AFTER_CALL_BIT : 0; - record_settings.pFilePath = options.recording_file.c_str(); -#else - record_settings.flags = 0; - record_settings.pFilePath = nullptr; -#endif // VMA_RECORDING_ENABLED - - VmaAllocatorCreateInfo create_info{}; + VmaAllocatorCreateInfo create_info; + memset(&create_info, 0, sizeof(create_info)); create_info.flags = 0; create_info.physicalDevice = physical_device; create_info.device = *logical_device; create_info.instance = instance; create_info.preferredLargeHeapBlockSize = 64 * 1024 * 1024; create_info.pAllocationCallbacks = logical_device->allocator(); - create_info.pDeviceMemoryCallbacks = nullptr; + create_info.pDeviceMemoryCallbacks = NULL; create_info.frameInUseCount = 0; - create_info.pHeapSizeLimit = nullptr; + create_info.pHeapSizeLimit = NULL; create_info.pVulkanFunctions = &vulkan_fns; create_info.pRecordSettings = &record_settings; - ::VmaAllocator vma = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(vmaCreateAllocator(&create_info, &vma)); + VmaAllocator vma = VK_NULL_HANDLE; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, VK_RESULT_TO_STATUS(vmaCreateAllocator(&create_info, &vma), + "vmaCreateAllocator")); - auto allocator = - absl::WrapUnique(new VmaAllocator(physical_device, logical_device, vma)); - // TODO(benvanik): query memory properties/types. - return allocator; + iree_allocator_t host_allocator = logical_device->host_allocator(); + iree_hal_vulkan_vma_allocator_t* allocator = NULL; + iree_status_t status = iree_allocator_malloc( + host_allocator, sizeof(*allocator), (void**)&allocator); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_vma_allocator_vtable, + &allocator->resource); + allocator->host_allocator = host_allocator; + allocator->vma = vma; + *out_allocator = (iree_hal_allocator_t*)allocator; + } else { + vmaDestroyAllocator(vma); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); } -VmaAllocator::VmaAllocator(VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, - ::VmaAllocator vma) - : physical_device_(physical_device), - logical_device_(add_ref(logical_device)), - vma_(vma) {} +static void iree_hal_vulkan_vma_allocator_destroy( + iree_hal_allocator_t* base_allocator) { + iree_hal_vulkan_vma_allocator_t* allocator = + iree_hal_vulkan_vma_allocator_cast(base_allocator); + iree_allocator_t host_allocator = allocator->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); -VmaAllocator::~VmaAllocator() { - IREE_TRACE_SCOPE0("VmaAllocator::dtor"); - vmaDestroyAllocator(vma_); + vmaDestroyAllocator(allocator->vma); + iree_allocator_free(host_allocator, allocator); + + IREE_TRACE_ZONE_END(z0); } -bool VmaAllocator::CanUseBufferLike( - Allocator* source_allocator, iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - iree_hal_buffer_usage_t intended_usage) const { - // TODO(benvanik): ensure there is a memory type that can satisfy the request. - return source_allocator == this; +static iree_allocator_t iree_hal_vulkan_vma_allocator_host_allocator( + const iree_hal_allocator_t* base_allocator) { + iree_hal_vulkan_vma_allocator_t* allocator = + (iree_hal_vulkan_vma_allocator_t*)base_allocator; + return allocator->host_allocator; } -bool VmaAllocator::CanAllocate(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) const { - // TODO(benvnik): ensure there is a memory type that can satisfy the request. - return true; +static iree_hal_buffer_compatibility_t +iree_hal_vulkan_vma_allocator_query_buffer_compatibility( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, + iree_hal_buffer_usage_t intended_usage, + iree_device_size_t allocation_size) { + // TODO(benvanik): check to ensure the allocator can serve the memory type. + + // Disallow usage not permitted by the buffer itself. Since we then use this + // to determine compatibility below we'll naturally set the right compat flags + // based on what's both allowed and intended. + intended_usage &= allowed_usage; + + // All buffers can be allocated on the heap. + iree_hal_buffer_compatibility_t compatibility = + IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE; + + // Buffers can only be used on the queue if they are device visible. + if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) { + if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER; + } + if (iree_all_bits_set(intended_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) { + compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH; + } + } + + return compatibility; } -Status VmaAllocator::MakeCompatible( +static iree_status_t iree_hal_vulkan_vma_allocator_make_compatible( iree_hal_memory_type_t* memory_type, - iree_hal_buffer_usage_t* buffer_usage) const { - // TODO(benvanik): mutate to match supported memory types. - return OkStatus(); + iree_hal_memory_access_t* allowed_access, + iree_hal_buffer_usage_t* allowed_usage) { + // TODO(benvanik): remove this entirely! + // Host currently uses mapping to copy buffers, which is done a lot. + // We could probably remove this mutation by preventing copies in those cases + // or issuing small copy command buffers. + *allowed_usage |= + IREE_HAL_MEMORY_TYPE_HOST_VISIBLE | IREE_HAL_BUFFER_USAGE_MAPPING; + return iree_ok_status(); } -StatusOr<ref_ptr<VmaBuffer>> VmaAllocator::AllocateInternal( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage, +static iree_status_t iree_hal_vulkan_vma_allocator_allocate_internal( + iree_hal_vulkan_vma_allocator_t* allocator, + iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t allowed_usage, iree_hal_memory_access_t allowed_access, size_t allocation_size, - VmaAllocationCreateFlags flags) { - IREE_TRACE_SCOPE0("VmaAllocator::AllocateInternal"); - + VmaAllocationCreateFlags flags, iree_hal_buffer_t** out_buffer) { // Guard against the corner case where the requested buffer size is 0. The // application is unlikely to do anything when requesting a 0-byte buffer; but // it can happen in real world use cases. So we should at least not crash. @@ -134,22 +180,22 @@ VkBufferCreateInfo buffer_create_info; buffer_create_info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - buffer_create_info.pNext = nullptr; + buffer_create_info.pNext = NULL; buffer_create_info.flags = 0; buffer_create_info.size = allocation_size; buffer_create_info.usage = 0; - if (iree_all_bits_set(buffer_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { + if (iree_all_bits_set(allowed_usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) { buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_SRC_BIT; buffer_create_info.usage |= VK_BUFFER_USAGE_TRANSFER_DST_BIT; } - if (iree_all_bits_set(buffer_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) { + if (iree_all_bits_set(allowed_usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) { buffer_create_info.usage |= VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; buffer_create_info.usage |= VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; buffer_create_info.usage |= VK_BUFFER_USAGE_INDIRECT_BUFFER_BIT; } buffer_create_info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; buffer_create_info.queueFamilyIndexCount = 0; - buffer_create_info.pQueueFamilyIndices = nullptr; + buffer_create_info.pQueueFamilyIndices = NULL; VmaAllocationCreateInfo allocation_create_info; allocation_create_info.flags = flags; @@ -158,7 +204,7 @@ allocation_create_info.preferredFlags = 0; allocation_create_info.memoryTypeBits = 0; // Automatic selection. allocation_create_info.pool = VK_NULL_HANDLE; - allocation_create_info.pUserData = nullptr; + allocation_create_info.pUserData = NULL; if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL)) { if (iree_all_bits_set(memory_type, IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) { // Device-local, host-visible. @@ -191,39 +237,57 @@ allocation_create_info.preferredFlags |= VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; } - if (iree_all_bits_set(buffer_usage, IREE_HAL_BUFFER_USAGE_MAPPING)) { + if (iree_all_bits_set(allowed_usage, IREE_HAL_BUFFER_USAGE_MAPPING)) { allocation_create_info.requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; } - VkBuffer buffer = VK_NULL_HANDLE; + VkBuffer handle = VK_NULL_HANDLE; VmaAllocation allocation = VK_NULL_HANDLE; VmaAllocationInfo allocation_info; - VK_RETURN_IF_ERROR(vmaCreateBuffer(vma_, &buffer_create_info, - &allocation_create_info, &buffer, - &allocation, &allocation_info)); + VK_RETURN_IF_ERROR(vmaCreateBuffer(allocator->vma, &buffer_create_info, + &allocation_create_info, &handle, + &allocation, &allocation_info), + "vmaCreateBuffer"); - return make_ref<VmaBuffer>(this, memory_type, allowed_access, buffer_usage, - allocation_size, 0, allocation_size, buffer, - allocation, allocation_info); + return iree_hal_vulkan_vma_buffer_wrap( + (iree_hal_allocator_t*)allocator, memory_type, allowed_access, + allowed_usage, allocation_size, + /*byte_offset=*/0, + /*byte_length=*/allocation_size, allocator->vma, handle, allocation, + allocation_info, out_buffer); } -StatusOr<ref_ptr<Buffer>> VmaAllocator::Allocate( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) { - IREE_TRACE_SCOPE0("VmaAllocator::Allocate"); - return AllocateInternal(memory_type, buffer_usage, IREE_HAL_MEMORY_ACCESS_ALL, - allocation_size, /*flags=*/0); +static iree_status_t iree_hal_vulkan_vma_allocator_allocate_buffer( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_buffer_usage_t allowed_usage, iree_host_size_t allocation_size, + iree_hal_buffer_t** out_buffer) { + iree_hal_vulkan_vma_allocator_t* allocator = + iree_hal_vulkan_vma_allocator_cast(base_allocator); + + // Coerce options into those required for use by VMA. + iree_hal_memory_access_t allowed_access = IREE_HAL_MEMORY_ACCESS_ALL; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_vma_allocator_make_compatible( + &memory_type, &allowed_access, &allowed_usage)); + + return iree_hal_vulkan_vma_allocator_allocate_internal( + allocator, memory_type, allowed_usage, allowed_access, allocation_size, + /*flags=*/0, out_buffer); } -StatusOr<ref_ptr<Buffer>> VmaAllocator::WrapMutable( - iree_hal_memory_type_t memory_type, iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, void* data, size_t data_length) { - IREE_TRACE_SCOPE0("VmaAllocator::WrapMutable"); - // TODO(benvanik): import memory. - return UnimplementedErrorBuilder(IREE_LOC) - << "Wrapping host memory is not yet implemented"; +static iree_status_t iree_hal_vulkan_vma_allocator_wrap_buffer( + iree_hal_allocator_t* base_allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_byte_span_t data, + iree_allocator_t data_allocator, iree_hal_buffer_t** out_buffer) { + return iree_make_status(IREE_STATUS_UNAVAILABLE, + "wrapping of external buffers not supported"); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_allocator_vtable_t iree_hal_vulkan_vma_allocator_vtable = { + /*.destroy=*/iree_hal_vulkan_vma_allocator_destroy, + /*.host_allocator=*/iree_hal_vulkan_vma_allocator_host_allocator, + /*.query_buffer_compatibility = */ + iree_hal_vulkan_vma_allocator_query_buffer_compatibility, + /*.allocate_buffer=*/iree_hal_vulkan_vma_allocator_allocate_buffer, + /*.wrap_buffer=*/iree_hal_vulkan_vma_allocator_wrap_buffer, +};
diff --git a/iree/hal/vulkan/vma_allocator.h b/iree/hal/vulkan/vma_allocator.h index 883f8dd..bfa1f55 100644 --- a/iree/hal/vulkan/vma_allocator.h +++ b/iree/hal/vulkan/vma_allocator.h
@@ -15,25 +15,18 @@ #ifndef IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ #define IREE_HAL_VULKAN_VMA_ALLOCATOR_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <memory> - -#include "iree/base/status.h" -#include "iree/hal/cc/allocator.h" -#include "iree/hal/vulkan/dynamic_symbols.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/internal_vk_mem_alloc.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -class VmaBuffer; - -// A HAL allocator using the Vulkan Memory Allocator (VMA) to manage memory. +// Creates a VMA-based allocator that performs internal suballocation and a +// bunch of other fancy things. +// +// This uses the Vulkan Memory Allocator (VMA) to manage memory. // VMA (//third_party/vulkan_memory_allocator) provides dlmalloc-like behavior // with suballocations made with various policies (best fit, first fit, etc). // This reduces the number of allocations we need from the Vulkan implementation @@ -47,75 +40,13 @@ // More information: // https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator // https://gpuopen-librariesandsdks.github.io/VulkanMemoryAllocator/html/ -class VmaAllocator final : public Allocator { - public: - struct Options { -#if VMA_RECORDING_ENABLED - // File path to write a CSV containing the VMA recording. - std::string recording_file = ""; +iree_status_t iree_hal_vulkan_vma_allocator_create( + VkInstance instance, VkPhysicalDevice physical_device, + iree::hal::vulkan::VkDeviceHandle* logical_device, + VmaRecordSettings record_settings, iree_hal_allocator_t** out_allocator); - // Flush the VMA recording file after every call (useful if crashing or - // not exiting cleanly). - bool recording_flush_after_call = false; -#endif // VMA_RECORDING_ENABLED - }; - - static StatusOr<std::unique_ptr<VmaAllocator>> Create( - VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, VkInstance instance, - Options options); - - ~VmaAllocator() override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - ::VmaAllocator vma() const { return vma_; } - - bool CanUseBufferLike(Allocator* source_allocator, - iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - iree_hal_buffer_usage_t intended_usage) const override; - - bool CanAllocate(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) const override; - - Status MakeCompatible(iree_hal_memory_type_t* memory_type, - iree_hal_buffer_usage_t* buffer_usage) const override; - - StatusOr<ref_ptr<Buffer>> Allocate(iree_hal_memory_type_t memory_type, - iree_hal_buffer_usage_t buffer_usage, - size_t allocation_size) override; - - StatusOr<ref_ptr<Buffer>> WrapMutable(iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t buffer_usage, - void* data, - size_t data_length) override; - - private: - VmaAllocator(VkPhysicalDevice physical_device, - const ref_ptr<VkDeviceHandle>& logical_device, - ::VmaAllocator vma); - - StatusOr<ref_ptr<VmaBuffer>> AllocateInternal( - iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage, - iree_hal_memory_access_t allowed_access, size_t allocation_size, - VmaAllocationCreateFlags flags); - - VkPhysicalDevice physical_device_; - ref_ptr<VkDeviceHandle> logical_device_; - - // Internally synchronized. We could externally synchronize if we thought it - // was worth it, however I'm not sure we'd be able to do much better with the - // current Allocator API. - ::VmaAllocator vma_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_VMA_ALLOCATOR_H_
diff --git a/iree/hal/vulkan/vma_buffer.cc b/iree/hal/vulkan/vma_buffer.cc index 5e23bcc..0bcab31 100644 --- a/iree/hal/vulkan/vma_buffer.cc +++ b/iree/hal/vulkan/vma_buffer.cc
@@ -14,117 +14,196 @@ #include "iree/hal/vulkan/vma_buffer.h" -#include "iree/base/status.h" #include "iree/base/tracing.h" #include "iree/hal/vulkan/status_util.h" -#include "iree/hal/vulkan/vma_allocator.h" -namespace iree { -namespace hal { -namespace vulkan { +typedef struct iree_hal_vulkan_vma_buffer_s { + iree_hal_buffer_t base; -VmaBuffer::VmaBuffer( - VmaAllocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, iree_hal_buffer_usage_t usage, - iree_device_size_t allocation_size, iree_device_size_t byte_offset, - iree_device_size_t byte_length, VkBuffer buffer, VmaAllocation allocation, - VmaAllocationInfo allocation_info) - : Buffer(allocator, memory_type, allowed_access, usage, allocation_size, - byte_offset, byte_length), - vma_(allocator->vma()), - buffer_(buffer), - allocation_(allocation), - allocation_info_(allocation_info) { - // TODO(benvanik): set debug name instead and use the - // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag. - vmaSetAllocationUserData(vma_, allocation_, this); + VmaAllocator vma; + VkBuffer handle; + VmaAllocation allocation; + VmaAllocationInfo allocation_info; +} iree_hal_vulkan_vma_buffer_t; + +extern const iree_hal_buffer_vtable_t iree_hal_vulkan_vma_buffer_vtable; + +static iree_hal_vulkan_vma_buffer_t* iree_hal_vulkan_vma_buffer_cast( + iree_hal_buffer_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_vma_buffer_vtable); + return (iree_hal_vulkan_vma_buffer_t*)base_value; } -VmaBuffer::~VmaBuffer() { - IREE_TRACE_SCOPE0("VmaBuffer::dtor"); - vmaDestroyBuffer(vma_, buffer_, allocation_); +iree_status_t iree_hal_vulkan_vma_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + VmaAllocator vma, VkBuffer handle, VmaAllocation allocation, + VmaAllocationInfo allocation_info, iree_hal_buffer_t** out_buffer) { + IREE_ASSERT_ARGUMENT(allocator); + IREE_ASSERT_ARGUMENT(vma); + IREE_ASSERT_ARGUMENT(handle); + IREE_ASSERT_ARGUMENT(allocation); + IREE_ASSERT_ARGUMENT(out_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_vma_buffer_t* buffer = NULL; + iree_status_t status = + iree_allocator_malloc(iree_hal_allocator_host_allocator(allocator), + sizeof(*buffer), (void**)&buffer); + if (iree_status_is_ok(status)) { + iree_hal_resource_initialize(&iree_hal_vulkan_vma_buffer_vtable, + &buffer->base.resource); + buffer->base.allocator = allocator; + buffer->base.allocated_buffer = &buffer->base; + buffer->base.allocation_size = allocation_size; + buffer->base.byte_offset = byte_offset; + buffer->base.byte_length = byte_length; + buffer->base.memory_type = memory_type; + buffer->base.allowed_access = allowed_access; + buffer->base.allowed_usage = allowed_usage; + buffer->vma = vma; + buffer->handle = handle; + buffer->allocation = allocation; + buffer->allocation_info = allocation_info; + + // TODO(benvanik): set debug name instead and use the + // VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag. + vmaSetAllocationUserData(buffer->vma, buffer->allocation, buffer); + + *out_buffer = &buffer->base; + } else { + vmaDestroyBuffer(vma, handle, allocation); + } + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); } -Status VmaBuffer::FillImpl(iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_device_size_t pattern_length) { - IREE_ASSIGN_OR_RETURN(auto mapping, - MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, - byte_offset, byte_length)); - void* data_ptr = static_cast<void*>(mapping.mutable_data()); +static void iree_hal_vulkan_vma_buffer_destroy(iree_hal_buffer_t* base_buffer) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + iree_allocator_t host_allocator = + iree_hal_allocator_host_allocator(iree_hal_buffer_allocator(base_buffer)); + IREE_TRACE_ZONE_BEGIN(z0); + + vmaDestroyBuffer(buffer->vma, buffer->handle, buffer->allocation); + iree_allocator_free(host_allocator, buffer); + + IREE_TRACE_ZONE_END(z0); +} + +VkBuffer iree_hal_vulkan_vma_buffer_handle(iree_hal_buffer_t* base_buffer) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + return buffer->handle; +} + +static iree_status_t iree_hal_vulkan_vma_buffer_fill( + iree_hal_buffer_t* base_buffer, iree_device_size_t byte_offset, + iree_device_size_t byte_length, const void* pattern, + iree_host_size_t pattern_length) { + iree_hal_buffer_mapping_t target_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + base_buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, byte_offset, + byte_length, &target_mapping)); + iree_status_t status = iree_ok_status(); + void* data_ptr = target_mapping.contents.data; switch (pattern_length) { case 1: { - uint8_t* data = static_cast<uint8_t*>(data_ptr); - uint8_t value_bits = *static_cast<const uint8_t*>(pattern); - std::fill_n(data, byte_length, value_bits); + uint8_t* data = (uint8_t*)data_ptr; + uint8_t value_bits = *(const uint8_t*)(pattern); + memset(data, value_bits, byte_length); break; } case 2: { - uint16_t* data = static_cast<uint16_t*>(data_ptr); - uint16_t value_bits = *static_cast<const uint16_t*>(pattern); - std::fill_n(data, byte_length / sizeof(uint16_t), value_bits); + uint16_t* data = (uint16_t*)data_ptr; + uint16_t value_bits = *(const uint16_t*)(pattern); + for (iree_device_size_t i = 0; i < byte_length / sizeof(uint16_t); ++i) { + data[i] = value_bits; + } break; } case 4: { - uint32_t* data = static_cast<uint32_t*>(data_ptr); - uint32_t value_bits = *static_cast<const uint32_t*>(pattern); - std::fill_n(data, byte_length / sizeof(uint32_t), value_bits); + uint32_t* data = (uint32_t*)data_ptr; + uint32_t value_bits = *(const uint32_t*)(pattern); + for (iree_device_size_t i = 0; i < byte_length / sizeof(uint32_t); ++i) { + data[i] = value_bits; + } break; } default: - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Unsupported scalar data size: " << pattern_length; + status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "unsupported fill pattern length: %zu", + pattern_length); + break; } - return OkStatus(); + iree_hal_buffer_flush_range(&target_mapping, byte_offset, byte_length); + iree_status_ignore(iree_hal_buffer_unmap_range(&target_mapping)); + return status; } -Status VmaBuffer::ReadDataImpl(iree_device_size_t source_offset, void* data, - iree_device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto mapping, - MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_READ, - source_offset, data_length)); - std::memcpy(data, mapping.data(), mapping.byte_length()); - return OkStatus(); +static iree_status_t iree_hal_vulkan_vma_buffer_read_data( + iree_hal_buffer_t* base_buffer, iree_device_size_t source_offset, + void* target_buffer, iree_device_size_t data_length) { + iree_hal_buffer_mapping_t source_mapping; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_map_range(base_buffer, IREE_HAL_MEMORY_ACCESS_READ, + source_offset, data_length, &source_mapping)); + memcpy(target_buffer, source_mapping.contents.data, data_length); + return iree_hal_buffer_unmap_range(&source_mapping); } -Status VmaBuffer::WriteDataImpl(iree_device_size_t target_offset, - const void* data, - iree_device_size_t data_length) { - IREE_ASSIGN_OR_RETURN(auto mapping, - MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, - target_offset, data_length)); - std::memcpy(mapping.mutable_data(), data, mapping.byte_length()); - return OkStatus(); +static iree_status_t iree_hal_vulkan_vma_buffer_write_data( + iree_hal_buffer_t* base_buffer, iree_device_size_t target_offset, + const void* source_buffer, iree_device_size_t data_length) { + iree_hal_buffer_mapping_t target_mapping; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + base_buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, target_offset, + data_length, &target_mapping)); + memcpy(target_mapping.contents.data, source_buffer, data_length); + return iree_hal_buffer_unmap_range(&target_mapping); } -Status VmaBuffer::CopyDataImpl(iree_device_size_t target_offset, - Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length) { - // This is pretty terrible. Let's not do this. - // TODO(benvanik): a way for allocators to indicate transfer compat. - IREE_ASSIGN_OR_RETURN(auto source_mapping, source_buffer->MapMemory<uint8_t>( - IREE_HAL_MEMORY_ACCESS_READ, - source_offset, data_length)); - IREE_CHECK_EQ(data_length, source_mapping.size()); - IREE_ASSIGN_OR_RETURN(auto target_mapping, - MapMemory<uint8_t>(IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, - target_offset, data_length)); - IREE_CHECK_EQ(data_length, target_mapping.size()); - std::memcpy(target_mapping.mutable_data(), source_mapping.data(), - data_length); - return OkStatus(); +static iree_status_t iree_hal_vulkan_vma_buffer_copy_data( + iree_hal_buffer_t* base_source_buffer, iree_device_size_t source_offset, + iree_hal_buffer_t* base_target_buffer, iree_device_size_t target_offset, + iree_device_size_t data_length) { + iree_hal_buffer_mapping_t source_mapping; + IREE_RETURN_IF_ERROR( + iree_hal_buffer_map_range(base_source_buffer, IREE_HAL_MEMORY_ACCESS_READ, + source_offset, data_length, &source_mapping)); + iree_hal_buffer_mapping_t target_mapping; + iree_status_t status = iree_hal_buffer_map_range( + base_target_buffer, IREE_HAL_MEMORY_ACCESS_DISCARD_WRITE, target_offset, + data_length, &target_mapping); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + IREE_IGNORE_ERROR(iree_hal_buffer_unmap_range(&source_mapping)); + return status; + } + + memcpy(target_mapping.contents.data, source_mapping.contents.data, + data_length); + + IREE_IGNORE_ERROR(iree_hal_buffer_unmap_range(&source_mapping)); + IREE_IGNORE_ERROR(iree_hal_buffer_unmap_range(&target_mapping)); + return iree_ok_status(); } -Status VmaBuffer::MapMemoryImpl(MappingMode mapping_mode, - iree_hal_memory_access_t memory_access, - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void** out_data) { +static iree_status_t iree_hal_vulkan_vma_buffer_map_range( + iree_hal_buffer_t* base_buffer, iree_hal_mapping_mode_t mapping_mode, + iree_hal_memory_access_t memory_access, + iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length, + void** out_data_ptr) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + uint8_t* data_ptr = nullptr; VK_RETURN_IF_ERROR( - vmaMapMemory(vma_, allocation_, reinterpret_cast<void**>(&data_ptr))); - *out_data = data_ptr + local_byte_offset; + vmaMapMemory(buffer->vma, buffer->allocation, (void**)&data_ptr), + "vmaMapMemory"); + *out_data_ptr = data_ptr + local_byte_offset; // If we mapped for discard scribble over the bytes. This is not a mandated // behavior but it will make debugging issues easier. Alternatively for @@ -132,34 +211,52 @@ // would only work if the entire buffer was discarded. #ifndef NDEBUG if (iree_any_bit_set(memory_access, IREE_HAL_MEMORY_ACCESS_DISCARD)) { - std::memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); + memset(data_ptr + local_byte_offset, 0xCD, local_byte_length); } #endif // !NDEBUG - return OkStatus(); + return iree_ok_status(); } -Status VmaBuffer::UnmapMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void* data) { - vmaUnmapMemory(vma_, allocation_); - return OkStatus(); +static void iree_hal_vulkan_vma_buffer_unmap_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length, void* data_ptr) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + vmaUnmapMemory(buffer->vma, buffer->allocation); } -Status VmaBuffer::InvalidateMappedMemoryImpl( - iree_device_size_t local_byte_offset, +static iree_status_t iree_hal_vulkan_vma_buffer_invalidate_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, iree_device_size_t local_byte_length) { - vmaInvalidateAllocation(vma_, allocation_, local_byte_offset, - local_byte_length); - return OkStatus(); + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + VK_RETURN_IF_ERROR( + vmaInvalidateAllocation(buffer->vma, buffer->allocation, + local_byte_offset, local_byte_length), + "vmaInvalidateAllocation"); + return iree_ok_status(); } -Status VmaBuffer::FlushMappedMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) { - vmaFlushAllocation(vma_, allocation_, local_byte_offset, local_byte_length); - return OkStatus(); +static iree_status_t iree_hal_vulkan_vma_buffer_flush_range( + iree_hal_buffer_t* base_buffer, iree_device_size_t local_byte_offset, + iree_device_size_t local_byte_length) { + iree_hal_vulkan_vma_buffer_t* buffer = + iree_hal_vulkan_vma_buffer_cast(base_buffer); + VK_RETURN_IF_ERROR(vmaFlushAllocation(buffer->vma, buffer->allocation, + local_byte_offset, local_byte_length), + "vmaFlushAllocation"); + return iree_ok_status(); } -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_buffer_vtable_t iree_hal_vulkan_vma_buffer_vtable = { + /*.destroy=*/iree_hal_vulkan_vma_buffer_destroy, + /*.fill=*/iree_hal_vulkan_vma_buffer_fill, + /*.read_data=*/iree_hal_vulkan_vma_buffer_read_data, + /*.write_data=*/iree_hal_vulkan_vma_buffer_write_data, + /*.copy_data=*/iree_hal_vulkan_vma_buffer_copy_data, + /*.map_range=*/iree_hal_vulkan_vma_buffer_map_range, + /*.unmap_range=*/iree_hal_vulkan_vma_buffer_unmap_range, + /*.invalidate_range=*/iree_hal_vulkan_vma_buffer_invalidate_range, + /*.flush_range=*/iree_hal_vulkan_vma_buffer_flush_range, +};
diff --git a/iree/hal/vulkan/vma_buffer.h b/iree/hal/vulkan/vma_buffer.h index eb6cebb..0771e65 100644 --- a/iree/hal/vulkan/vma_buffer.h +++ b/iree/hal/vulkan/vma_buffer.h
@@ -15,71 +15,30 @@ #ifndef IREE_HAL_VULKAN_VMA_BUFFER_H_ #define IREE_HAL_VULKAN_VMA_BUFFER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include "iree/hal/cc/buffer.h" +#include "iree/hal/api.h" #include "iree/hal/vulkan/internal_vk_mem_alloc.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -class VmaAllocator; +// Wraps a VMA allocation in an iree_hal_buffer_t. +// The allocation will be released back to VMA when the buffer is released. +iree_status_t iree_hal_vulkan_vma_buffer_wrap( + iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type, + iree_hal_memory_access_t allowed_access, + iree_hal_buffer_usage_t allowed_usage, iree_device_size_t allocation_size, + iree_device_size_t byte_offset, iree_device_size_t byte_length, + VmaAllocator vma, VkBuffer handle, VmaAllocation allocation, + VmaAllocationInfo allocation_info, iree_hal_buffer_t** out_buffer); -// A buffer implementation representing an allocation made from within a pool of -// a Vulkan Memory Allocator instance. See VmaAllocator for more information. -class VmaBuffer final : public Buffer { - public: - VmaBuffer(VmaAllocator* allocator, iree_hal_memory_type_t memory_type, - iree_hal_memory_access_t allowed_access, - iree_hal_buffer_usage_t usage, iree_device_size_t allocation_size, - iree_device_size_t byte_offset, iree_device_size_t byte_length, - VkBuffer buffer, VmaAllocation allocation, - VmaAllocationInfo allocation_info); - ~VmaBuffer() override; +// Returns the Vulkan handle backing the given |buffer|. +// This is the entire allocated_buffer and must be offset by the buffer +// byte_offset and byte_length when used. +VkBuffer iree_hal_vulkan_vma_buffer_handle(iree_hal_buffer_t* buffer); - VkBuffer handle() const { return buffer_; } - VmaAllocation allocation() const { return allocation_; } - const VmaAllocationInfo& allocation_info() const { return allocation_info_; } - - // Exposed so that VmaAllocator can reset access after initial mapping. - using Buffer::set_allowed_access; - - private: - Status FillImpl(iree_device_size_t byte_offset, - iree_device_size_t byte_length, const void* pattern, - iree_device_size_t pattern_length) override; - Status ReadDataImpl(iree_device_size_t source_offset, void* data, - iree_device_size_t data_length) override; - Status WriteDataImpl(iree_device_size_t target_offset, const void* data, - iree_device_size_t data_length) override; - Status CopyDataImpl(iree_device_size_t target_offset, Buffer* source_buffer, - iree_device_size_t source_offset, - iree_device_size_t data_length) override; - Status MapMemoryImpl(MappingMode mapping_mode, - iree_hal_memory_access_t memory_access, - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void** out_data) override; - Status UnmapMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length, - void* data) override; - Status InvalidateMappedMemoryImpl( - iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) override; - Status FlushMappedMemoryImpl(iree_device_size_t local_byte_offset, - iree_device_size_t local_byte_length) override; - - ::VmaAllocator vma_; - VkBuffer buffer_; - VmaAllocation allocation_; - VmaAllocationInfo allocation_info_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_VMA_BUFFER_H_
diff --git a/iree/hal/vulkan/vulkan_device.cc b/iree/hal/vulkan/vulkan_device.cc index f6cef18..8f3cb0e 100644 --- a/iree/hal/vulkan/vulkan_device.cc +++ b/iree/hal/vulkan/vulkan_device.cc
@@ -20,93 +20,205 @@ #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" #include "iree/base/math.h" +#include "iree/base/memory.h" #include "iree/base/status.h" -#include "iree/base/time.h" #include "iree/base/tracing.h" -#include "iree/hal/cc/command_queue.h" -#include "iree/hal/cc/semaphore.h" +#include "iree/hal/vulkan/api.h" +#include "iree/hal/vulkan/command_queue.h" +#include "iree/hal/vulkan/descriptor_pool_cache.h" #include "iree/hal/vulkan/direct_command_buffer.h" #include "iree/hal/vulkan/direct_command_queue.h" #include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" +#include "iree/hal/vulkan/emulated_semaphore.h" #include "iree/hal/vulkan/extensibility_util.h" +#include "iree/hal/vulkan/handle_util.h" #include "iree/hal/vulkan/native_descriptor_set.h" +#include "iree/hal/vulkan/native_descriptor_set_layout.h" #include "iree/hal/vulkan/native_event.h" -#include "iree/hal/vulkan/native_timeline_semaphore.h" -#include "iree/hal/vulkan/pipeline_cache.h" -#include "iree/hal/vulkan/pipeline_executable_layout.h" +#include "iree/hal/vulkan/native_executable_layout.h" +#include "iree/hal/vulkan/native_semaphore.h" +#include "iree/hal/vulkan/nop_executable_cache.h" #include "iree/hal/vulkan/serializing_command_queue.h" #include "iree/hal/vulkan/status_util.h" #include "iree/hal/vulkan/vma_allocator.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; -namespace { +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_device_t extensibility util +//===----------------------------------------------------------------------===// -constexpr uint32_t kInvalidQueueFamilyIndex = -1; +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree_host_size_t string_capacity, + const char** out_string_values, iree_host_size_t* out_string_count) { + *out_string_count = 0; -struct QueueFamilyInfo { - uint32_t dispatch_index = kInvalidQueueFamilyIndex; - uint32_t dispatch_queue_count = 0; - uint32_t transfer_index = kInvalidQueueFamilyIndex; - uint32_t transfer_queue_count = 0; -}; + iree_status_t status = iree_ok_status(); + iree_host_size_t string_count = 0; +#define ADD_EXT(target_set, name_literal) \ + if (iree_status_is_ok(status) && set == (target_set)) { \ + if (string_count >= string_capacity && out_string_values) { \ + status = iree_status_from_code(IREE_STATUS_OUT_OF_RANGE); \ + } else if (out_string_values) { \ + out_string_values[string_count] = (name_literal); \ + } \ + ++string_count; \ + } + + //===--------------------------------------------------------------------===// + // Baseline IREE requirements + //===--------------------------------------------------------------------===// + // Using IREE at all requires these extensions unconditionally. Adding things + // here changes our minimum requirements and should be done carefully. + // Optional extensions here are feature detected by the runtime. + + // VK_KHR_storage_buffer_storage_class: + // Our generated SPIR-V kernels use storage buffers for all their data access. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED, + VK_KHR_STORAGE_BUFFER_STORAGE_CLASS_EXTENSION_NAME); + + // VK_KHR_get_physical_device_properties2: + // Multiple extensions depend on VK_KHR_get_physical_device_properties2. + // This extension was deprecated in Vulkan 1.1 as its functionality was + // promoted to core so we list it as optional even though we require it. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, + VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME); + + // VK_KHR_push_descriptor: + // We can avoid a lot of additional Vulkan descriptor set manipulation + // overhead when this extension is present. Android is a holdout, though, and + // we have a fallback for when it's not available. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + VK_KHR_PUSH_DESCRIPTOR_EXTENSION_NAME); + + //===--------------------------------------------------------------------===// + // Vulkan forward-compatibility shims + //===--------------------------------------------------------------------===// + // These are shims or extensions that are made core later in the spec and can + // be removed once we require the core version that contains them. + + // VK_KHR_timeline_semaphore: + // timeline semaphore support is optional and will be emulated if necessary. + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, + VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME); + + // VK_LAYER_KHRONOS_timeline_semaphore: + // polyfill layer - enable if present instead of our custom emulation. Ignored + // if timeline semaphores are supported natively (Vulkan 1.2+). + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, + "VK_LAYER_KHRONOS_timeline_semaphore"); + + //===--------------------------------------------------------------------===// + // Optional debugging features + //===--------------------------------------------------------------------===// + // Used only when explicitly requested as they drastically change the + // performance behavior of Vulkan. + + // VK_LAYER_KHRONOS_validation: + // only enabled if validation is desired. Since validation in Vulkan is just a + // API correctness check it can't be used as a security mechanism and is fine + // to ignore. + if (iree_all_bits_set(requested_features, + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS)) { + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, + "VK_LAYER_KHRONOS_validation"); + } + + // VK_EXT_debug_utils: + // only enabled if debugging is desired to route Vulkan debug messages through + // our logging sinks. Note that this adds a non-trivial runtime overhead and + // we may want to disable it even in debug builds. + if (iree_all_bits_set(requested_features, + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS)) { + ADD_EXT(IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, + VK_EXT_DEBUG_UTILS_EXTENSION_NAME); + } + + *out_string_count = string_count; + return status; +} + +//===----------------------------------------------------------------------===// +// Queue selection +//===----------------------------------------------------------------------===// + +#define IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX (-1) + +typedef struct { + uint32_t dispatch_index; + iree_host_size_t dispatch_queue_count; + uint32_t transfer_index; + iree_host_size_t transfer_queue_count; +} iree_hal_vulkan_queue_family_info_t; // Finds the first queue in the listing (which is usually the // driver-preferred) that has all of the |required_queue_flags| and none of -// the |excluded_queue_flags|. Returns kInvalidQueueFamilyIndex if no matching -// queue is found. -uint32_t FindFirstQueueFamilyWithFlags( - absl::Span<const VkQueueFamilyProperties> queue_family_properties, - uint32_t required_queue_flags, uint32_t excluded_queue_flags) { - for (int queue_family_index = 0; - queue_family_index < queue_family_properties.size(); +// the |excluded_queue_flags|. +// Returns IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX if no matching queue is +// found. +static uint32_t iree_hal_vulkan_find_first_queue_family_with_flags( + uint32_t queue_family_count, + const VkQueueFamilyProperties* queue_family_properties, + VkQueueFlags required_queue_flags, VkQueueFlags excluded_queue_flags) { + for (uint32_t queue_family_index = 0; queue_family_index < queue_family_count; ++queue_family_index) { - const auto& properties = queue_family_properties[queue_family_index]; - if ((properties.queueFlags & required_queue_flags) == - required_queue_flags && - (properties.queueFlags & excluded_queue_flags) == 0) { + const VkQueueFamilyProperties* properties = + &queue_family_properties[queue_family_index]; + if (iree_all_bits_set(properties->queueFlags, required_queue_flags) && + !iree_any_bit_set(properties->queueFlags, excluded_queue_flags)) { return queue_family_index; } } - return kInvalidQueueFamilyIndex; + return IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX; } // Selects queue family indices for compute and transfer queues. // Note that both queue families may be the same if there is only one family // available. -StatusOr<QueueFamilyInfo> SelectQueueFamilies( - VkPhysicalDevice physical_device, const ref_ptr<DynamicSymbols>& syms) { +static iree_status_t iree_hal_vulkan_select_queue_families( + VkPhysicalDevice physical_device, iree::hal::vulkan::DynamicSymbols* syms, + iree_hal_vulkan_queue_family_info_t* out_family_info) { // Enumerate queue families available on the device. uint32_t queue_family_count = 0; syms->vkGetPhysicalDeviceQueueFamilyProperties(physical_device, - &queue_family_count, nullptr); - absl::InlinedVector<VkQueueFamilyProperties, 4> queue_family_properties( - queue_family_count); + &queue_family_count, NULL); + VkQueueFamilyProperties* queue_family_properties = + (VkQueueFamilyProperties*)iree_alloca(queue_family_count * + sizeof(VkQueueFamilyProperties)); syms->vkGetPhysicalDeviceQueueFamilyProperties( - physical_device, &queue_family_count, queue_family_properties.data()); + physical_device, &queue_family_count, queue_family_properties); - QueueFamilyInfo queue_family_info; + memset(out_family_info, 0, sizeof(*out_family_info)); + out_family_info->dispatch_index = IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX; + out_family_info->dispatch_queue_count = 0; + out_family_info->transfer_index = IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX; + out_family_info->transfer_queue_count = 0; // Try to find a dedicated compute queue (no graphics caps). // Some may support both transfer and compute. If that fails then fallback // to any queue that supports compute. - queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_COMPUTE_BIT, VK_QUEUE_GRAPHICS_BIT); - if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { - queue_family_info.dispatch_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_COMPUTE_BIT, 0); + out_family_info->dispatch_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_COMPUTE_BIT, + VK_QUEUE_GRAPHICS_BIT); + if (out_family_info->dispatch_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->dispatch_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_COMPUTE_BIT, + 0); } - if (queue_family_info.dispatch_index == kInvalidQueueFamilyIndex) { - return NotFoundErrorBuilder(IREE_LOC) - << "Unable to find any queue family support compute operations"; + if (out_family_info->dispatch_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + return iree_make_status( + IREE_STATUS_NOT_FOUND, + "unable to find any queue family support compute operations"); } - queue_family_info.dispatch_queue_count = - queue_family_properties[queue_family_info.dispatch_index].queueCount; + out_family_info->dispatch_queue_count = + queue_family_properties[out_family_info->dispatch_index].queueCount; // Try to find a dedicated transfer queue (no compute or graphics caps). // Not all devices have one, and some have only a queue family for @@ -114,145 +226,430 @@ // fails then fallback to any queue that supports transfer. Finally, if // /that/ fails then we just won't create a transfer queue and instead use // the compute queue for all operations. - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, - VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT); - if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, VK_QUEUE_GRAPHICS_BIT); + out_family_info->transfer_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_TRANSFER_BIT, + VK_QUEUE_COMPUTE_BIT | VK_QUEUE_GRAPHICS_BIT); + if (out_family_info->transfer_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->transfer_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_TRANSFER_BIT, + VK_QUEUE_GRAPHICS_BIT); } - if (queue_family_info.transfer_index == kInvalidQueueFamilyIndex) { - queue_family_info.transfer_index = FindFirstQueueFamilyWithFlags( - queue_family_properties, VK_QUEUE_TRANSFER_BIT, 0); + if (out_family_info->transfer_index == + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->transfer_index = + iree_hal_vulkan_find_first_queue_family_with_flags( + queue_family_count, queue_family_properties, VK_QUEUE_TRANSFER_BIT, + 0); } - if (queue_family_info.transfer_index != kInvalidQueueFamilyIndex) { - queue_family_info.transfer_queue_count = - queue_family_properties[queue_family_info.transfer_index].queueCount; + if (out_family_info->transfer_index != + IREE_HAL_VULKAN_INVALID_QUEUE_FAMILY_INDEX) { + out_family_info->transfer_queue_count = + queue_family_properties[out_family_info->transfer_index].queueCount; } // Ensure that we don't share the dispatch queues with transfer queues if // that would put us over the queue count. - if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - queue_family_info.transfer_queue_count = std::min( - queue_family_properties[queue_family_info.dispatch_index].queueCount - - queue_family_info.dispatch_queue_count, - queue_family_info.transfer_queue_count); + if (out_family_info->dispatch_index == out_family_info->transfer_index) { + out_family_info->transfer_queue_count = iree_min( + queue_family_properties[out_family_info->dispatch_index].queueCount - + out_family_info->dispatch_queue_count, + out_family_info->transfer_queue_count); } - return queue_family_info; + // Limit the number of queues we create (for now). + // We may want to allow this to grow, but each queue adds overhead and we + // need to measure to make sure we can effectively use them all. + out_family_info->dispatch_queue_count = + iree_min(2u, out_family_info->dispatch_queue_count); + out_family_info->transfer_queue_count = + iree_min(1u, out_family_info->transfer_queue_count); + + return iree_ok_status(); +} + +// Builds a set of compute and transfer queues based on the queues available on +// the device and some magic heuristical goo. +static iree_status_t iree_hal_vulkan_build_queue_sets( + VkPhysicalDevice physical_device, iree::hal::vulkan::DynamicSymbols* syms, + iree_hal_vulkan_queue_set_t* out_compute_queue_set, + iree_hal_vulkan_queue_set_t* out_transfer_queue_set) { + // Select which queues to use (and fail the implementation can't handle them). + iree_hal_vulkan_queue_family_info_t queue_family_info; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_select_queue_families( + physical_device, syms, &queue_family_info)); + + // Build queue indices for the selected queue families. + memset(out_compute_queue_set, 0, sizeof(*out_compute_queue_set)); + out_compute_queue_set->queue_family_index = queue_family_info.dispatch_index; + for (iree_host_size_t i = 0; i < queue_family_info.dispatch_queue_count; + ++i) { + out_compute_queue_set->queue_indices |= 1ull << i; + } + + memset(out_transfer_queue_set, 0, sizeof(*out_transfer_queue_set)); + out_transfer_queue_set->queue_family_index = queue_family_info.transfer_index; + uint32_t base_queue_index = 0; + if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { + // Sharing a family, so transfer queues follow compute queues. + base_queue_index = queue_family_info.dispatch_index; + } + for (iree_host_size_t i = 0; i < queue_family_info.transfer_queue_count; + ++i) { + out_transfer_queue_set->queue_indices |= 1ull << (i + base_queue_index); + } + + return iree_ok_status(); +} + +//===----------------------------------------------------------------------===// +// iree_hal_vulkan_device_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_resource_t resource; + iree_string_view_t identifier; + + // Optional driver that owns the instance. We retain it for our lifetime to + // ensure the instance remains valid. + iree_hal_driver_t* driver; + + // Flags overriding default device behavior. + iree_hal_vulkan_device_flags_t flags; + // Which optional extensions are active and available on the device. + iree_hal_vulkan_device_extensions_t device_extensions; + + VkInstance instance; + VkPhysicalDevice physical_device; + VkDeviceHandle* logical_device; + + iree_allocator_t host_allocator; + iree_hal_allocator_t* device_allocator; + + // All queues available on the device; the device owns these. + iree_host_size_t queue_count; + CommandQueue** queues; + // The subset of queues that support dispatch operations. May overlap with + // transfer_queues. + iree_host_size_t dispatch_queue_count; + CommandQueue** dispatch_queues; + // The subset of queues that support transfer operations. May overlap with + // dispatch_queues. + iree_host_size_t transfer_queue_count; + CommandQueue** transfer_queues; + + DescriptorPoolCache* descriptor_pool_cache; + + VkCommandPoolHandle* dispatch_command_pool; + VkCommandPoolHandle* transfer_command_pool; + + // Used only for emulated timeline semaphores. + TimePointSemaphorePool* semaphore_pool; + TimePointFencePool* fence_pool; +} iree_hal_vulkan_device_t; + +extern const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable; + +static iree_hal_vulkan_device_t* iree_hal_vulkan_device_cast( + iree_hal_device_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_device_vtable); + return (iree_hal_vulkan_device_t*)base_value; +} + +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_device_options_initialize( + iree_hal_vulkan_device_options_t* out_options) { + memset(out_options, 0, sizeof(*out_options)); + out_options->flags = 0; } // Creates a transient command pool for the given queue family. // Command buffers allocated from the pool must only be issued on queues // belonging to the specified family. -StatusOr<ref_ptr<VkCommandPoolHandle>> CreateTransientCommandPool( - const ref_ptr<VkDeviceHandle>& logical_device, - uint32_t queue_family_index) { +static iree_status_t iree_hal_vulkan_create_transient_command_pool( + VkDeviceHandle* logical_device, uint32_t queue_family_index, + VkCommandPoolHandle** out_handle) { VkCommandPoolCreateInfo create_info; create_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; - create_info.pNext = nullptr; + create_info.pNext = NULL; create_info.flags = VK_COMMAND_POOL_CREATE_TRANSIENT_BIT; create_info.queueFamilyIndex = queue_family_index; - - auto command_pool = make_ref<VkCommandPoolHandle>(logical_device); - VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateCommandPool( - *logical_device, &create_info, logical_device->allocator(), - command_pool->mutable_value())); - return command_pool; + VkCommandPoolHandle* command_pool = new VkCommandPoolHandle(logical_device); + iree_status_t status = VK_RESULT_TO_STATUS( + logical_device->syms()->vkCreateCommandPool( + *logical_device, &create_info, logical_device->allocator(), + command_pool->mutable_value()), + "vkCreateCommandPool"); + if (iree_status_is_ok(status)) { + *out_handle = command_pool; + } else { + delete command_pool; + } + return status; } -// Creates command queues for the given sets of queues. -absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> CreateCommandQueues( - const DeviceInfo& device_info, - const ref_ptr<VkDeviceHandle>& logical_device, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set, - const ref_ptr<TimePointFencePool>& fence_pool, - const ref_ptr<DynamicSymbols>& syms) { - absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues; +// Creates a command queue of the given queue family. +static CommandQueue* iree_hal_vulkan_device_create_queue( + VkDeviceHandle* logical_device, + iree_hal_command_category_t command_category, uint32_t queue_family_index, + uint32_t queue_index, TimePointFencePool* fence_pool) { + VkQueue queue = VK_NULL_HANDLE; + logical_device->syms()->vkGetDeviceQueue(*logical_device, queue_family_index, + queue_index, &queue); + std::string queue_name; + if (!iree_all_bits_set(command_category, + IREE_HAL_COMMAND_CATEGORY_DISPATCH)) { + queue_name = "q(t):"; + } else { + queue_name = "q(d):"; + } + queue_name += std::to_string(queue_index); + + // When emulating timeline semaphores we use a special queue that allows us to + // sequence the semaphores correctly. + if (fence_pool != NULL) { + return new SerializingCommandQueue(logical_device, std::move(queue_name), + command_category, queue, fence_pool); + } + + return new DirectCommandQueue(logical_device, std::move(queue_name), + command_category, queue); +} + +// Creates command queues for the given sets of queues and populates the +// device queue lists. +static void iree_hal_vulkan_device_initialize_command_queues( + iree_hal_vulkan_device_t* device, iree_string_view_t queue_prefix, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set) { + device->queue_count = 0; + device->dispatch_queue_count = 0; + device->transfer_queue_count = 0; uint64_t compute_queue_count = - iree_math_count_ones_u64(compute_queue_set.queue_indices); - for (uint32_t i = 0; i < compute_queue_count; ++i) { - if (!(compute_queue_set.queue_indices & (1ull << i))) continue; - - VkQueue queue = VK_NULL_HANDLE; - syms->vkGetDeviceQueue(*logical_device, - compute_queue_set.queue_family_index, i, &queue); - std::string queue_name = absl::StrCat(device_info.name(), ":d", i); - - if (fence_pool != nullptr) { - command_queues.push_back(absl::make_unique<SerializingCommandQueue>( - std::move(queue_name), IREE_HAL_COMMAND_CATEGORY_ANY, logical_device, - fence_pool, queue)); - } else { - command_queues.push_back(absl::make_unique<DirectCommandQueue>( - std::move(queue_name), IREE_HAL_COMMAND_CATEGORY_ANY, logical_device, - queue)); - } - } - + iree_math_count_ones_u64(compute_queue_set->queue_indices); uint64_t transfer_queue_count = - iree_math_count_ones_u64(transfer_queue_set.queue_indices); - for (uint32_t i = 0; i < transfer_queue_count; ++i) { - if (!(transfer_queue_set.queue_indices & (1ull << i))) continue; - - VkQueue queue = VK_NULL_HANDLE; - syms->vkGetDeviceQueue(*logical_device, - transfer_queue_set.queue_family_index, i, &queue); - std::string queue_name = absl::StrCat(device_info.name(), ":t", i); - if (fence_pool != nullptr) { - command_queues.push_back(absl::make_unique<SerializingCommandQueue>( - std::move(queue_name), IREE_HAL_COMMAND_CATEGORY_TRANSFER, - logical_device, fence_pool, queue)); - } else { - command_queues.push_back(absl::make_unique<DirectCommandQueue>( - std::move(queue_name), IREE_HAL_COMMAND_CATEGORY_TRANSFER, - logical_device, queue)); + iree_math_count_ones_u64(transfer_queue_set->queue_indices); + for (iree_host_size_t i = 0; i < compute_queue_count; ++i) { + if (!(compute_queue_set->queue_indices & (1ull << i))) continue; + CommandQueue* queue = iree_hal_vulkan_device_create_queue( + device->logical_device, IREE_HAL_COMMAND_CATEGORY_ANY, + compute_queue_set->queue_family_index, i, device->fence_pool); + device->queues[device->queue_count++] = queue; + device->dispatch_queues[device->dispatch_queue_count++] = queue; + if (!transfer_queue_count) { + // If we don't have any dedicated transfer queues then use all dispatch + // queues as transfer queues. + device->transfer_queues[device->transfer_queue_count++] = queue; } } - - return command_queues; + for (iree_host_size_t i = 0; i < transfer_queue_count; ++i) { + if (!(transfer_queue_set->queue_indices & (1ull << i))) continue; + CommandQueue* queue = iree_hal_vulkan_device_create_queue( + device->logical_device, IREE_HAL_COMMAND_CATEGORY_TRANSFER, + compute_queue_set->queue_family_index, i, device->fence_pool); + device->queues[device->queue_count++] = queue; + device->transfer_queues[device->transfer_queue_count++] = queue; + } } -} // namespace +static iree_status_t iree_hal_vulkan_device_create_internal( + iree_hal_driver_t* driver, iree_string_view_t identifier, + const iree_hal_vulkan_device_options_t* options, VkInstance instance, + VkPhysicalDevice physical_device, VkDeviceHandle* logical_device, + const iree_hal_vulkan_device_extensions_t* device_extensions, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + auto& device_syms = logical_device->syms(); -// static -StatusOr<ref_ptr<VulkanDevice>> VulkanDevice::Create( - ref_ptr<Driver> driver, VkInstance instance, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, Options options, - const ref_ptr<DynamicSymbols>& syms, - DebugCaptureManager* debug_capture_manager) { - IREE_TRACE_SCOPE0("VulkanDevice::Create"); + iree_host_size_t compute_queue_count = + iree_math_count_ones_u64(compute_queue_set->queue_indices); + iree_host_size_t transfer_queue_count = + iree_math_count_ones_u64(transfer_queue_set->queue_indices); + iree_host_size_t total_queue_count = + compute_queue_count + transfer_queue_count; - if (!options.extensibility_spec.optional_layers.empty() || - !options.extensibility_spec.required_layers.empty()) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "Device layers are deprecated and unsupported by IREE"; + iree_hal_vulkan_device_t* device = NULL; + iree_host_size_t total_size = + sizeof(*device) + identifier.size + + total_queue_count * sizeof(device->queues[0]) + + total_queue_count * sizeof(device->dispatch_queues[0]) + + total_queue_count * sizeof(device->transfer_queues[0]); + IREE_RETURN_IF_ERROR( + iree_allocator_malloc(host_allocator, total_size, (void**)&device)); + memset(device, 0, total_size); + iree_hal_resource_initialize(&iree_hal_vulkan_device_vtable, + &device->resource); + device->host_allocator = host_allocator; + device->driver = driver; + iree_hal_driver_retain(device->driver); + uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device); + buffer_ptr += iree_string_view_append_to_buffer( + identifier, &device->identifier, (char*)buffer_ptr); + device->flags = options->flags; + + device->device_extensions = *device_extensions; + device->instance = instance; + device->physical_device = physical_device; + device->logical_device = logical_device; + device->logical_device->AddReference(); + + // Point the queue storage into the new device allocation. The queues + // themselves are populated + device->queues = (CommandQueue**)buffer_ptr; + buffer_ptr += total_queue_count * sizeof(device->queues[0]); + device->dispatch_queues = (CommandQueue**)buffer_ptr; + buffer_ptr += total_queue_count * sizeof(device->dispatch_queues[0]); + device->transfer_queues = (CommandQueue**)buffer_ptr; + buffer_ptr += total_queue_count * sizeof(device->transfer_queues[0]); + + device->descriptor_pool_cache = + new DescriptorPoolCache(device->logical_device); + + // Create the device memory allocator that will service all buffer + // allocation requests. + VmaRecordSettings vma_record_settings; + memset(&vma_record_settings, 0, sizeof(vma_record_settings)); + iree_status_t status = iree_hal_vulkan_vma_allocator_create( + instance, physical_device, logical_device, vma_record_settings, + &device->device_allocator); + + // Create command pools for each queue family. If we don't have a transfer + // queue then we'll ignore that one and just use the dispatch pool. + // If we wanted to expose the pools through the HAL to allow the VM to more + // effectively manage them (pool per fiber, etc) we could, however I doubt + // the overhead of locking the pool will be even a blip. + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_create_transient_command_pool( + device->logical_device, compute_queue_set->queue_family_index, + &device->dispatch_command_pool); } + if (transfer_queue_set->queue_indices != 0 && iree_status_is_ok(status)) { + status = iree_hal_vulkan_create_transient_command_pool( + device->logical_device, transfer_queue_set->queue_family_index, + &device->transfer_command_pool); + } + + // Emulate timeline semaphores when the extension is not available and we are + // ony Vulkan versions prior to 1.2 when they were made core. + bool emulate_timeline_semaphores = + device_syms->vkGetSemaphoreCounterValue == NULL || + iree_all_bits_set( + options->flags, + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION); + if (emulate_timeline_semaphores && iree_status_is_ok(status)) { + status = TimePointSemaphorePool::Create(device->logical_device, + &device->semaphore_pool); + } + if (emulate_timeline_semaphores && iree_status_is_ok(status)) { + status = + TimePointFencePool::Create(device->logical_device, &device->fence_pool); + } + + // Initialize queues now that we've completed the rest of the device + // initialization; this happens last as the queues require the pools allocated + // above. + if (iree_status_is_ok(status)) { + iree_hal_vulkan_device_initialize_command_queues( + device, identifier, compute_queue_set, transfer_queue_set); + } + + if (iree_status_is_ok(status)) { + *out_device = (iree_hal_device_t*)device; + } else { + iree_hal_device_destroy((iree_hal_device_t*)device); + } + return status; +} + +static void iree_hal_vulkan_device_destroy(iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); + IREE_TRACE_ZONE_BEGIN(z0); + + // Drop all command queues. These may wait until idle in their destructor. + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + delete device->queues[i]; + } + + // Drop command pools now that we know there are no more outstanding command + // buffers. + delete device->dispatch_command_pool; + delete device->transfer_command_pool; + + // Now that no commands are outstanding we can release all resources that may + // have been in use. + delete device->descriptor_pool_cache; + delete device->semaphore_pool; + delete device->fence_pool; + + // There should be no more buffers live that use the allocator. + iree_hal_allocator_release(device->device_allocator); + + // Finally, destroy the device. + device->logical_device->ReleaseReference(); + iree_hal_driver_release(device->driver); + + iree_allocator_free(host_allocator, device); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_vulkan_device_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_string_list) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, 0, NULL, &out_string_list->count)); + out_string_list->values = (const char**)arena->AllocateBytes( + out_string_list->count * sizeof(out_string_list->values[0])); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, out_string_list->count, out_string_list->values, + &out_string_list->count)); + return iree_ok_status(); +} + +iree_status_t iree_hal_vulkan_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + iree_hal_vulkan_features_t enabled_features, + const iree_hal_vulkan_device_options_t* options, + iree_hal_vulkan_syms_t* opaque_syms, VkInstance instance, + VkPhysicalDevice physical_device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device) { + DynamicSymbols* instance_syms = (DynamicSymbols*)opaque_syms; // Find the extensions we need (or want) that are also available // on the device. This will fail when required ones are not present. - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableDeviceExtensions(physical_device, - options.extensibility_spec, *syms)); - auto enabled_device_extensions = - PopulateEnabledDeviceExtensions(enabled_extension_names); + // TODO(benvanik): replace with a real arena. + iree::Arena arena(128 * 1024); + iree_hal_vulkan_string_list_t required_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_device_query_extensibility_set( + enabled_features, + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_REQUIRED, &arena, + &required_extensions)); + iree_hal_vulkan_string_list_t optional_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_device_query_extensibility_set( + enabled_features, + IREE_HAL_VULKAN_EXTENSIBILITY_DEVICE_EXTENSIONS_OPTIONAL, &arena, + &optional_extensions)); + iree_hal_vulkan_string_list_t enabled_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_match_available_device_extensions( + instance_syms, physical_device, &required_extensions, + &optional_extensions, &arena, &enabled_extensions)); + iree_hal_vulkan_device_extensions_t enabled_device_extensions = + iree_hal_vulkan_populate_enabled_device_extensions(&enabled_extensions); // Find queue families we will expose as HAL queues. - IREE_ASSIGN_OR_RETURN(auto queue_family_info, - SelectQueueFamilies(physical_device, syms)); + iree_hal_vulkan_queue_family_info_t queue_family_info; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_select_queue_families( + physical_device, instance_syms, &queue_family_info)); - // Limit the number of queues we create (for now). - // We may want to allow this to grow, but each queue adds overhead and we - // need to measure to make sure we can effectively use them all. - queue_family_info.dispatch_queue_count = - std::min(2u, queue_family_info.dispatch_queue_count); - queue_family_info.transfer_queue_count = - std::min(1u, queue_family_info.transfer_queue_count); bool has_dedicated_transfer_queues = queue_family_info.transfer_queue_count > 0; + // TODO(benvanik): convert to using the arena. // Setup the queue info we'll be using. // Each queue here (created from within a family) will map to a HAL queue. // @@ -260,34 +657,24 @@ // are of the same queue family as the dispatch queues: Vulkan requires that // all queues created from the same family are done in the same // VkDeviceQueueCreateInfo struct. - IREE_DVLOG(1) << "Creating " << queue_family_info.dispatch_queue_count - << " dispatch queue(s) in queue family " - << queue_family_info.dispatch_index; absl::InlinedVector<VkDeviceQueueCreateInfo, 2> queue_create_info; absl::InlinedVector<float, 4> dispatch_queue_priorities; absl::InlinedVector<float, 4> transfer_queue_priorities; queue_create_info.push_back({}); auto& dispatch_queue_info = queue_create_info.back(); dispatch_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - dispatch_queue_info.pNext = nullptr; + dispatch_queue_info.pNext = NULL; dispatch_queue_info.flags = 0; dispatch_queue_info.queueFamilyIndex = queue_family_info.dispatch_index; dispatch_queue_info.queueCount = queue_family_info.dispatch_queue_count; if (has_dedicated_transfer_queues) { if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - IREE_DVLOG(1) << "Creating " << queue_family_info.transfer_queue_count - << " dedicated transfer queue(s) in shared queue family " - << queue_family_info.transfer_index; dispatch_queue_info.queueCount += queue_family_info.transfer_queue_count; } else { - IREE_DVLOG(1) - << "Creating " << queue_family_info.transfer_queue_count - << " dedicated transfer queue(s) in independent queue family " - << queue_family_info.transfer_index; queue_create_info.push_back({}); auto& transfer_queue_info = queue_create_info.back(); transfer_queue_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - transfer_queue_info.pNext = nullptr; + transfer_queue_info.pNext = NULL; transfer_queue_info.queueFamilyIndex = queue_family_info.transfer_index; transfer_queue_info.queueCount = queue_family_info.transfer_queue_count; transfer_queue_info.flags = 0; @@ -299,547 +686,316 @@ dispatch_queue_info.pQueuePriorities = dispatch_queue_priorities.data(); // Create device and its queues. - VkDeviceCreateInfo device_create_info = {}; + VkDeviceCreateInfo device_create_info; + memset(&device_create_info, 0, sizeof(device_create_info)); device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = enabled_extension_names.size(); - device_create_info.ppEnabledExtensionNames = enabled_extension_names.data(); + device_create_info.ppEnabledLayerNames = NULL; + device_create_info.enabledExtensionCount = enabled_extensions.count; + device_create_info.ppEnabledExtensionNames = enabled_extensions.values; device_create_info.queueCreateInfoCount = queue_create_info.size(); device_create_info.pQueueCreateInfos = queue_create_info.data(); - device_create_info.pEnabledFeatures = nullptr; + device_create_info.pEnabledFeatures = NULL; + + VkPhysicalDeviceFeatures2 features2; + memset(&features2, 0, sizeof(features2)); + features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_create_info.pNext = &features2; VkPhysicalDeviceTimelineSemaphoreFeatures semaphore_features; - std::memset(&semaphore_features, 0, sizeof(semaphore_features)); - semaphore_features.sType = - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES; - semaphore_features.timelineSemaphore = VK_TRUE; - VkPhysicalDeviceFeatures2 features2; - std::memset(&features2, 0, sizeof(features2)); - features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; - features2.pNext = &semaphore_features; - - if (!enabled_device_extensions.timeline_semaphore || - options.force_timeline_semaphore_emulation) { - device_create_info.pNext = nullptr; - } else { - device_create_info.pNext = &features2; + bool emulate_timeline_semaphores = + !enabled_device_extensions.timeline_semaphore || + iree_all_bits_set( + options->flags, + IREE_HAL_VULKAN_DEVICE_FORCE_TIMELINE_SEMAPHORE_EMULATION); + if (!emulate_timeline_semaphores) { + memset(&semaphore_features, 0, sizeof(semaphore_features)); + semaphore_features.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_TIMELINE_SEMAPHORE_FEATURES; + semaphore_features.pNext = features2.pNext; + features2.pNext = &semaphore_features; + semaphore_features.timelineSemaphore = VK_TRUE; } - auto logical_device = - make_ref<VkDeviceHandle>(syms, enabled_device_extensions, - /*owns_device=*/true, /*allocator=*/nullptr); - // The Vulkan loader can leak here, depending on which features are enabled. - // This is out of our control, so disable leak checks. - IREE_DISABLE_LEAK_CHECKS(); - VK_RETURN_IF_ERROR(syms->vkCreateDevice(physical_device, &device_create_info, - logical_device->allocator(), - logical_device->mutable_value())); - IREE_RETURN_IF_ERROR(logical_device->syms()->LoadFromDevice( - instance, logical_device->value())); - IREE_ENABLE_LEAK_CHECKS(); + auto logical_device = new VkDeviceHandle( + instance_syms, enabled_device_extensions, + /*owns_device=*/true, host_allocator, /*allocator=*/NULL); - // Create the device memory allocator. - // TODO(benvanik): allow other types to be plugged in. - IREE_ASSIGN_OR_RETURN( - auto allocator, - VmaAllocator::Create(physical_device, logical_device, instance, - std::move(options.vma_options))); - - // Create command pools for each queue family. If we don't have a transfer - // queue then we'll ignore that one and just use the dispatch pool. - // If we wanted to expose the pools through the HAL to allow the VM to more - // effectively manage them (pool per fiber, etc) we could, however I doubt - // the overhead of locking the pool will be even a blip. - IREE_ASSIGN_OR_RETURN(auto dispatch_command_pool, - CreateTransientCommandPool( - logical_device, queue_family_info.dispatch_index)); - ref_ptr<VkCommandPoolHandle> transfer_command_pool; - if (has_dedicated_transfer_queues) { - IREE_ASSIGN_OR_RETURN( - transfer_command_pool, - CreateTransientCommandPool(logical_device, - queue_family_info.transfer_index)); + iree_status_t status = VK_RESULT_TO_STATUS( + instance_syms->vkCreateDevice(physical_device, &device_create_info, + logical_device->allocator(), + logical_device->mutable_value()), + "vkCreateDevice"); + if (iree_status_is_ok(status)) { + status = logical_device->syms()->LoadFromDevice(instance, + logical_device->value()); } // Select queue indices and create command queues with them. - QueueSet compute_queue_set = {}; - compute_queue_set.queue_family_index = queue_family_info.dispatch_index; - for (uint32_t i = 0; i < queue_family_info.dispatch_queue_count; ++i) { - compute_queue_set.queue_indices |= 1ull << i; - } - QueueSet transfer_queue_set = {}; - transfer_queue_set.queue_family_index = queue_family_info.transfer_index; - uint32_t base_queue_index = 0; - if (queue_family_info.dispatch_index == queue_family_info.transfer_index) { - // Sharing a family, so transfer queues follow compute queues. - base_queue_index = queue_family_info.dispatch_index; - } - for (uint32_t i = 0; i < queue_family_info.transfer_queue_count; ++i) { - transfer_queue_set.queue_indices |= 1ull << (i + base_queue_index); + iree_hal_vulkan_queue_set_t compute_queue_set; + iree_hal_vulkan_queue_set_t transfer_queue_set; + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_build_queue_sets( + physical_device, logical_device->syms().get(), &compute_queue_set, + &transfer_queue_set); } - // Emulate timeline semaphores if associated functions are not defined. - ref_ptr<TimePointSemaphorePool> semaphore_pool = nullptr; - ref_ptr<TimePointFencePool> fence_pool = nullptr; - if (syms->vkGetSemaphoreCounterValue == nullptr || - options.force_timeline_semaphore_emulation) { - IREE_ASSIGN_OR_RETURN(semaphore_pool, TimePointSemaphorePool::Create( - add_ref(logical_device))); - IREE_ASSIGN_OR_RETURN(fence_pool, - TimePointFencePool::Create(add_ref(logical_device))); + // Allocate and initialize the device. + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_device_create_internal( + driver, identifier, options, instance, physical_device, logical_device, + &enabled_device_extensions, &compute_queue_set, &transfer_queue_set, + host_allocator, out_device); } - auto command_queues = - CreateCommandQueues(device_info, logical_device, compute_queue_set, - transfer_queue_set, fence_pool, syms); - - return assign_ref(new VulkanDevice( - std::move(driver), device_info, physical_device, - std::move(logical_device), std::move(allocator), - std::move(command_queues), std::move(dispatch_command_pool), - std::move(transfer_command_pool), std::move(semaphore_pool), - std::move(fence_pool), debug_capture_manager)); + logical_device->ReleaseReference(); + return status; } -// static -StatusOr<ref_ptr<VulkanDevice>> VulkanDevice::Wrap( - ref_ptr<Driver> driver, VkInstance instance, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, VkDevice logical_device, Options options, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set, - const ref_ptr<DynamicSymbols>& syms) { - IREE_TRACE_SCOPE0("VulkanDevice::Wrap"); +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_wrap_device( + iree_string_view_t identifier, + const iree_hal_vulkan_device_options_t* options, + const iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + VkPhysicalDevice physical_device, VkDevice logical_device, + const iree_hal_vulkan_queue_set_t* compute_queue_set, + const iree_hal_vulkan_queue_set_t* transfer_queue_set, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + IREE_ASSERT_ARGUMENT(instance_syms); + IREE_ASSERT_ARGUMENT(instance); + IREE_ASSERT_ARGUMENT(physical_device); + IREE_ASSERT_ARGUMENT(logical_device); + IREE_ASSERT_ARGUMENT(out_device); - uint64_t compute_queue_count = - iree_math_count_ones_u64(compute_queue_set.queue_indices); - uint64_t transfer_queue_count = - iree_math_count_ones_u64(transfer_queue_set.queue_indices); - - if (compute_queue_count == 0) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "At least one compute queue is required"; + if (iree_math_count_ones_u64(compute_queue_set->queue_indices) == 0) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "at least one compute queue is required"); } - // Find the extensions we need (or want) that are also available on the - // device. This will fail when required ones are not present. - // + // Grab symbols from the device. + auto device_syms = iree::make_ref<DynamicSymbols>(); + device_syms->vkGetInstanceProcAddr = + ((const DynamicSymbols*)instance_syms)->vkGetInstanceProcAddr; + IREE_RETURN_IF_ERROR(device_syms->LoadFromDevice(instance, logical_device)); + // Since the device is already created, we can't actually enable any // extensions or query if they are really enabled - we just have to trust - // that the caller already enabled them for us (or we may fail later). - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableDeviceExtensions(physical_device, - options.extensibility_spec, *syms)); - auto enabled_device_extensions = - PopulateEnabledDeviceExtensions(enabled_extension_names); + // that the caller already enabled them for us or we may fail later. For the + // optional extensions we check for the symbols but this is not always + // guaranteed to work. + iree_hal_vulkan_device_extensions_t enabled_device_extensions = + iree_hal_vulkan_infer_enabled_device_extensions(device_syms.get()); // Wrap the provided VkDevice with a VkDeviceHandle for use within the HAL. - auto device_handle = - make_ref<VkDeviceHandle>(syms, enabled_device_extensions, - /*owns_device=*/false, /*allocator=*/nullptr); - *device_handle->mutable_value() = logical_device; + auto logical_device_handle = new VkDeviceHandle( + device_syms.get(), enabled_device_extensions, + /*owns_device=*/false, host_allocator, /*allocator=*/NULL); + *logical_device_handle->mutable_value() = logical_device; - // Create the device memory allocator. - // TODO(benvanik): allow other types to be plugged in. - IREE_ASSIGN_OR_RETURN( - auto allocator, - VmaAllocator::Create(physical_device, device_handle, instance, - std::move(options.vma_options))); + // Allocate and initialize the device. + iree_status_t status = iree_hal_vulkan_device_create_internal( + /*driver=*/NULL, identifier, options, instance, physical_device, + logical_device_handle, &enabled_device_extensions, compute_queue_set, + transfer_queue_set, host_allocator, out_device); - bool has_dedicated_transfer_queues = transfer_queue_count > 0; - - // Create command pools for each queue family. If we don't have a transfer - // queue then we'll ignore that one and just use the dispatch pool. - // If we wanted to expose the pools through the HAL to allow the VM to more - // effectively manage them (pool per fiber, etc) we could, however I doubt - // the overhead of locking the pool will be even a blip. - IREE_ASSIGN_OR_RETURN( - auto dispatch_command_pool, - CreateTransientCommandPool(device_handle, - compute_queue_set.queue_family_index)); - ref_ptr<VkCommandPoolHandle> transfer_command_pool; - if (has_dedicated_transfer_queues) { - IREE_ASSIGN_OR_RETURN( - transfer_command_pool, - CreateTransientCommandPool(device_handle, - transfer_queue_set.queue_family_index)); - } - - // Emulate timeline semaphores if associated functions are not defined. - ref_ptr<TimePointSemaphorePool> semaphore_pool = nullptr; - ref_ptr<TimePointFencePool> fence_pool = nullptr; - if (syms->vkGetSemaphoreCounterValue == nullptr || - options.force_timeline_semaphore_emulation) { - IREE_ASSIGN_OR_RETURN( - semaphore_pool, TimePointSemaphorePool::Create(add_ref(device_handle))); - IREE_ASSIGN_OR_RETURN(fence_pool, - TimePointFencePool::Create(add_ref(device_handle))); - } - - auto command_queues = - CreateCommandQueues(device_info, device_handle, compute_queue_set, - transfer_queue_set, fence_pool, syms); - - return assign_ref(new VulkanDevice( - std::move(driver), device_info, physical_device, std::move(device_handle), - std::move(allocator), std::move(command_queues), - std::move(dispatch_command_pool), std::move(transfer_command_pool), - std::move(semaphore_pool), std::move(fence_pool), - /*debug_capture_manager=*/nullptr)); + logical_device_handle->ReleaseReference(); + return status; } -VulkanDevice::VulkanDevice( - ref_ptr<Driver> driver, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device, - std::unique_ptr<Allocator> allocator, - absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues, - ref_ptr<VkCommandPoolHandle> dispatch_command_pool, - ref_ptr<VkCommandPoolHandle> transfer_command_pool, - ref_ptr<TimePointSemaphorePool> semaphore_pool, - ref_ptr<TimePointFencePool> fence_pool, - DebugCaptureManager* debug_capture_manager) - : Device(device_info), - driver_(std::move(driver)), - physical_device_(physical_device), - logical_device_(std::move(logical_device)), - allocator_(std::move(allocator)), - command_queues_(std::move(command_queues)), - descriptor_pool_cache_( - make_ref<DescriptorPoolCache>(add_ref(logical_device_))), - dispatch_command_pool_(std::move(dispatch_command_pool)), - transfer_command_pool_(std::move(transfer_command_pool)), - semaphore_pool_(std::move(semaphore_pool)), - fence_pool_(std::move(fence_pool)), - debug_capture_manager_(debug_capture_manager) { - // Populate the queue lists based on queue capabilities. - for (auto& command_queue : command_queues_) { - if (command_queue->can_dispatch()) { - dispatch_queues_.push_back(command_queue.get()); - if (transfer_command_pool_ == VK_NULL_HANDLE) { - transfer_queues_.push_back(command_queue.get()); - } - } else { - transfer_queues_.push_back(command_queue.get()); - } - } - - if (debug_capture_manager_ && debug_capture_manager_->is_connected()) { - // Record a capture covering the duration of this VkDevice's lifetime. - debug_capture_manager_->StartCapture(); - } +static iree_string_view_t iree_hal_vulkan_device_id( + iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return device->identifier; } -VulkanDevice::~VulkanDevice() { - IREE_TRACE_SCOPE0("VulkanDevice::dtor"); - if (debug_capture_manager_ && debug_capture_manager_->is_capturing()) { - debug_capture_manager_->StopCapture(); - } - - // Drop all command queues. These may wait until idle. - command_queues_.clear(); - dispatch_queues_.clear(); - transfer_queues_.clear(); - - // Drop command pools now that we know there are no more outstanding command - // buffers. - dispatch_command_pool_.reset(); - transfer_command_pool_.reset(); - - // Now that no commands are outstanding we can release all descriptor sets. - descriptor_pool_cache_.reset(); - - // Finally, destroy the device. - logical_device_.reset(); +static iree_allocator_t iree_hal_vulkan_device_host_allocator( + iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return device->host_allocator; } -std::string VulkanDevice::DebugString() const { - return absl::StrCat(Device::DebugString(), // - "\n[VulkanDevice]", // - "\n Command Queues: ", command_queues_.size(), // - "\n - Dispatch Queues: ", dispatch_queues_.size(), // - "\n - Transfer Queues: ", transfer_queues_.size()); +static iree_hal_allocator_t* iree_hal_vulkan_device_allocator( + iree_hal_device_t* base_device) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return device->device_allocator; } -ref_ptr<ExecutableCache> VulkanDevice::CreateExecutableCache() { - IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableCache"); - return make_ref<PipelineCache>(add_ref(logical_device_)); -} - -StatusOr<ref_ptr<DescriptorSetLayout>> VulkanDevice::CreateDescriptorSetLayout( - iree_hal_descriptor_set_layout_usage_type_t usage_type, - absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateDescriptorSetLayout"); - - absl::InlinedVector<VkDescriptorSetLayoutBinding, 4> native_bindings( - bindings.size()); - for (int i = 0; i < bindings.size(); ++i) { - auto& native_binding = native_bindings[i]; - native_binding.binding = bindings[i].binding; - native_binding.descriptorType = - static_cast<VkDescriptorType>(bindings[i].type); - native_binding.descriptorCount = 1; - native_binding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - native_binding.pImmutableSamplers = nullptr; - } - - VkDescriptorSetLayoutCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - if (usage_type == IREE_HAL_DESCRIPTOR_SET_LAYOUT_USAGE_TYPE_PUSH_ONLY && - logical_device_->enabled_extensions().push_descriptors) { - // Note that we can *only* use push descriptor sets if we set this create - // flag. If push descriptors aren't supported we emulate them with normal - // descriptors so it's fine to have kPushOnly without support. - create_info.flags |= - VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; - } - create_info.bindingCount = native_bindings.size(); - create_info.pBindings = native_bindings.data(); - - // Create and insert into the cache. - VkDescriptorSetLayout descriptor_set_layout = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateDescriptorSetLayout( - *logical_device_, &create_info, logical_device_->allocator(), - &descriptor_set_layout)); - - return make_ref<NativeDescriptorSetLayout>(add_ref(logical_device_), - descriptor_set_layout); -} - -StatusOr<ref_ptr<ExecutableLayout>> VulkanDevice::CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateExecutableLayout"); - - absl::InlinedVector<ref_ptr<NativeDescriptorSetLayout>, 2> typed_set_layouts( - set_layouts.size()); - absl::InlinedVector<VkDescriptorSetLayout, 2> set_layout_handles( - set_layouts.size()); - for (int i = 0; i < set_layouts.size(); ++i) { - typed_set_layouts[i] = - add_ref(static_cast<NativeDescriptorSetLayout*>(set_layouts[i])); - set_layout_handles[i] = typed_set_layouts[i]->handle(); - } - - absl::InlinedVector<VkPushConstantRange, 1> push_constant_ranges; - if (push_constants > 0) { - push_constant_ranges.push_back(VkPushConstantRange{ - VK_SHADER_STAGE_COMPUTE_BIT, 0, - static_cast<uint32_t>(sizeof(uint32_t) * push_constants)}); - } - - VkPipelineLayoutCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.setLayoutCount = set_layout_handles.size(); - create_info.pSetLayouts = set_layout_handles.data(); - create_info.pushConstantRangeCount = push_constant_ranges.size(); - create_info.pPushConstantRanges = push_constant_ranges.data(); - - // Create and insert into the cache. - VkPipelineLayout pipeline_layout = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreatePipelineLayout( - *logical_device_, &create_info, logical_device_->allocator(), - &pipeline_layout)); - - return make_ref<PipelineExecutableLayout>( - add_ref(logical_device_), pipeline_layout, std::move(typed_set_layouts)); -} - -StatusOr<ref_ptr<DescriptorSet>> VulkanDevice::CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateDescriptorSet"); - return UnimplementedErrorBuilder(IREE_LOC) - << "CreateDescriptorSet not yet implemented (needs timeline)"; -} - -StatusOr<ref_ptr<CommandBuffer>> VulkanDevice::CreateCommandBuffer( - iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateCommandBuffer"); +static iree_status_t iree_hal_vulkan_device_create_command_buffer( + iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode, + iree_hal_command_category_t command_categories, + iree_hal_command_buffer_t** out_command_buffer) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); // Select the command pool to used based on the types of commands used. // Note that we may not have a dedicated transfer command pool if there are // no dedicated transfer queues. - ref_ptr<VkCommandPoolHandle> command_pool; - if (transfer_command_pool_ && + VkCommandPoolHandle* command_pool = NULL; + if (device->transfer_command_pool && !iree_all_bits_set(command_categories, IREE_HAL_COMMAND_CATEGORY_DISPATCH)) { - command_pool = add_ref(transfer_command_pool_); + command_pool = device->transfer_command_pool; } else { - command_pool = add_ref(dispatch_command_pool_); + command_pool = device->dispatch_command_pool; } - VkCommandBufferAllocateInfo allocate_info; - allocate_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; - allocate_info.pNext = nullptr; - allocate_info.commandPool = *command_pool; - allocate_info.commandBufferCount = 1; - allocate_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + return iree_hal_vulkan_direct_command_buffer_allocate( + device->logical_device, command_pool, mode, command_categories, + device->descriptor_pool_cache, out_command_buffer); +} - VkCommandBuffer command_buffer = VK_NULL_HANDLE; - { - absl::MutexLock lock(command_pool->mutex()); - VK_RETURN_IF_ERROR(syms()->vkAllocateCommandBuffers( - *logical_device_, &allocate_info, &command_buffer)); +static iree_status_t iree_hal_vulkan_device_create_descriptor_set( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_t* set_layout, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_binding_t* bindings, + iree_hal_descriptor_set_t** out_descriptor_set) { + // TODO(benvanik): rework the create fn to take the bindings. + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "non-push descriptor sets still need work"); +} + +static iree_status_t iree_hal_vulkan_device_create_descriptor_set_layout( + iree_hal_device_t* base_device, + iree_hal_descriptor_set_layout_usage_type_t usage_type, + iree_host_size_t binding_count, + const iree_hal_descriptor_set_layout_binding_t* bindings, + iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_native_descriptor_set_layout_create( + device->logical_device, usage_type, binding_count, bindings, + out_descriptor_set_layout); +} + +static iree_status_t iree_hal_vulkan_device_create_event( + iree_hal_device_t* base_device, iree_hal_event_t** out_event) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_native_event_create(device->logical_device, out_event); +} + +static iree_status_t iree_hal_vulkan_device_create_executable_cache( + iree_hal_device_t* base_device, iree_string_view_t identifier, + iree_hal_executable_cache_t** out_executable_cache) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_nop_executable_cache_create( + device->logical_device, identifier, out_executable_cache); +} + +static iree_status_t iree_hal_vulkan_device_create_executable_layout( + iree_hal_device_t* base_device, iree_host_size_t set_layout_count, + iree_hal_descriptor_set_layout_t** set_layouts, + iree_host_size_t push_constants, + iree_hal_executable_layout_t** out_executable_layout) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + return iree_hal_vulkan_native_executable_layout_create( + device->logical_device, set_layout_count, set_layouts, push_constants, + out_executable_layout); +} + +static iree_status_t iree_hal_vulkan_device_create_semaphore( + iree_hal_device_t* base_device, uint64_t initial_value, + iree_hal_semaphore_t** out_semaphore) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + if (device->semaphore_pool != NULL) { + return iree_hal_vulkan_emulated_semaphore_create( + device->logical_device, device->semaphore_pool, device->queue_count, + device->queues, initial_value, out_semaphore); } - - return make_ref<DirectCommandBuffer>(mode, command_categories, - add_ref(descriptor_pool_cache_), - add_ref(command_pool), command_buffer); + return iree_hal_vulkan_native_semaphore_create(device->logical_device, + initial_value, out_semaphore); } -StatusOr<ref_ptr<Event>> VulkanDevice::CreateEvent() { - IREE_TRACE_SCOPE0("VulkanDevice::CreateEvent"); - - // TODO(b/138729892): pool events. - VkEventCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_EVENT_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - VkEvent event_handle = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR(syms()->vkCreateEvent(*logical_device_, &create_info, - logical_device_->allocator(), - &event_handle)); - - return make_ref<NativeEvent>(add_ref(logical_device_), event_handle); -} - -StatusOr<ref_ptr<Semaphore>> VulkanDevice::CreateSemaphore( - uint64_t initial_value) { - IREE_TRACE_SCOPE0("VulkanDevice::CreateSemaphore"); - - if (emulating_timeline_semaphores()) { - return EmulatedTimelineSemaphore::Create( - add_ref(logical_device_), - // Triggers necessary processing on all queues due to new values gotten - // signaled for the given timeline |semaphore|. - // Different clang-format versions disagree about asterisk placement. - // clang-format off - [this](Semaphore* /*semaphore*/) -> Status { - // clang-format on - IREE_TRACE_SCOPE0("<lambda>::OnSemaphoreSignal"); - for (const auto& queue : command_queues_) { - IREE_RETURN_IF_ERROR( - static_cast<SerializingCommandQueue*>(queue.get()) - ->AdvanceQueueSubmission()); - } - return OkStatus(); - }, - // Triggers necessary processing on all queues due to failures for the - // given timeline |semaphore|. - [this](Semaphore* /*semaphore*/) { - IREE_TRACE_SCOPE0("<lambda>::OnSemaphoreFailure"); - for (const auto& queue : command_queues_) { - static_cast<SerializingCommandQueue*>(queue.get()) - ->AbortQueueSubmission(); - } - }, - // Triggers necessary processing on all queues due to the given |fence| - // being signaled. This allows the queue to drop the fence ref it holds - // even when we are not waiting on the queue directly. - [this](absl::Span<VkFence> fences) { - IREE_TRACE_SCOPE0("<lambda>::OnFenceSignal"); - for (const auto& queue : command_queues_) { - static_cast<SerializingCommandQueue*>(queue.get()) - ->SignalFences(fences); - } - }, - add_ref(semaphore_pool_), initial_value); +// Returns the queue to submit work to based on the |queue_affinity|. +static CommandQueue* iree_hal_vulkan_device_select_queue( + iree_hal_vulkan_device_t* device, + iree_hal_command_category_t command_categories, uint64_t queue_affinity) { + // TODO(benvanik): meaningful heuristics for affinity. We don't generate + // anything from the compiler that uses multiple queues and until we do it's + // best not to do anything too clever here. + if (command_categories == IREE_HAL_COMMAND_CATEGORY_TRANSFER) { + return device + ->transfer_queues[queue_affinity % device->transfer_queue_count]; } - - return NativeTimelineSemaphore::Create(add_ref(logical_device_), - initial_value); + return device->dispatch_queues[queue_affinity % device->dispatch_queue_count]; } -Status VulkanDevice::WaitAllSemaphores( - absl::Span<const SemaphoreValue> semaphores, Time deadline_ns) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitAllSemaphores"); - return WaitSemaphores(semaphores, deadline_ns, /*wait_flags=*/0); +static iree_status_t iree_hal_vulkan_device_queue_submit( + iree_hal_device_t* base_device, + iree_hal_command_category_t command_categories, uint64_t queue_affinity, + iree_host_size_t batch_count, const iree_hal_submission_batch_t* batches) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + CommandQueue* queue = iree_hal_vulkan_device_select_queue( + device, command_categories, queue_affinity); + return queue->Submit(batch_count, batches); } -StatusOr<int> VulkanDevice::WaitAnySemaphore( - absl::Span<const SemaphoreValue> semaphores, Time deadline_ns) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitAnySemaphore"); - return WaitSemaphores(semaphores, deadline_ns, - /*wait_flags=*/VK_SEMAPHORE_WAIT_ANY_BIT); -} - -Status VulkanDevice::WaitSemaphores(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns, - VkSemaphoreWaitFlags wait_flags) { - IREE_TRACE_SCOPE0("VulkanDevice::WaitSemaphores"); - - if (emulating_timeline_semaphores()) { - // TODO(antiagainst): We actually should get the fences associated with the - // emulated timeline semaphores so that we can wait them in a bunch. This - // implementation is problematic if we wait to wait any and we have the - // first semaphore taking extra long time but the following ones signal - // quickly. - for (int i = 0; i < semaphores.size(); ++i) { - auto* semaphore = - static_cast<EmulatedTimelineSemaphore*>(semaphores[i].semaphore); - IREE_RETURN_IF_ERROR(semaphore->Wait(semaphores[i].value, deadline_ns)); - if (wait_flags & VK_SEMAPHORE_WAIT_ANY_BIT) return OkStatus(); - } - - return OkStatus(); +static iree_status_t iree_hal_vulkan_device_wait_semaphores_with_deadline( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + VkSemaphoreWaitFlags wait_flags = 0; + if (wait_mode == IREE_HAL_WAIT_MODE_ANY) { + wait_flags |= VK_SEMAPHORE_WAIT_ANY_BIT; } - - absl::InlinedVector<VkSemaphore, 4> semaphore_handles(semaphores.size()); - absl::InlinedVector<uint64_t, 4> semaphore_values(semaphores.size()); - for (int i = 0; i < semaphores.size(); ++i) { - semaphore_handles[i] = - static_cast<NativeTimelineSemaphore*>(semaphores[i].semaphore) - ->handle(); - semaphore_values[i] = semaphores[i].value; + if (device->semaphore_pool != NULL) { + return iree_hal_vulkan_emulated_semaphore_multi_wait( + device->logical_device, semaphore_list, deadline_ns, wait_flags); } - - VkSemaphoreWaitInfo wait_info; - wait_info.sType = VK_STRUCTURE_TYPE_SEMAPHORE_WAIT_INFO; - wait_info.pNext = nullptr; - wait_info.flags = wait_flags; - wait_info.semaphoreCount = semaphore_handles.size(); - wait_info.pSemaphores = semaphore_handles.data(); - wait_info.pValues = semaphore_values.data(); - - // NOTE: this may fail with a timeout (VK_TIMEOUT) or in the case of a - // device loss event may return either VK_SUCCESS *or* VK_ERROR_DEVICE_LOST. - // We may want to explicitly query for device loss after a successful wait - // to ensure we consistently return errors. - uint64_t timeout_ns = - static_cast<uint64_t>(DeadlineToRelativeTimeoutNanos(deadline_ns)); - VkResult result = - syms()->vkWaitSemaphores(*logical_device_, &wait_info, timeout_ns); - if (result == VK_ERROR_DEVICE_LOST) { - // Nothing we do now matters. - return VkResultToStatus(result, IREE_LOC); - } - - // TODO(benvanik): notify the resource timeline that it should check for the - // semaphores we waited on (including those already expired above). - - return OkStatus(); + return iree_hal_vulkan_native_semaphore_multi_wait( + device->logical_device, semaphore_list, deadline_ns, wait_flags); } -Status VulkanDevice::WaitIdle(Time deadline_ns) { - if (deadline_ns == InfiniteFuture()) { +static iree_status_t iree_hal_vulkan_device_wait_semaphores_with_timeout( + iree_hal_device_t* base_device, iree_hal_wait_mode_t wait_mode, + const iree_hal_semaphore_list_t* semaphore_list, + iree_duration_t timeout_ns) { + return iree_hal_vulkan_device_wait_semaphores_with_deadline( + base_device, wait_mode, semaphore_list, + iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +static iree_status_t iree_hal_vulkan_device_wait_idle_with_deadline( + iree_hal_device_t* base_device, iree_time_t deadline_ns) { + iree_hal_vulkan_device_t* device = iree_hal_vulkan_device_cast(base_device); + if (deadline_ns == IREE_TIME_INFINITE_FUTURE) { // Fast path for using vkDeviceWaitIdle, which is usually cheaper (as it // requires fewer calls into the driver). - IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#vkDeviceWaitIdle"); - VK_RETURN_IF_ERROR(syms()->vkDeviceWaitIdle(*logical_device_)); - return OkStatus(); + return VK_RESULT_TO_STATUS(device->logical_device->syms()->vkDeviceWaitIdle( + *device->logical_device), + "vkDeviceWaitIdle"); } - - IREE_TRACE_SCOPE0("VulkanDevice::WaitIdle#Semaphores"); - for (auto& command_queue : command_queues_) { - IREE_RETURN_IF_ERROR(command_queue->WaitIdle(deadline_ns)); + for (iree_host_size_t i = 0; i < device->queue_count; ++i) { + IREE_RETURN_IF_ERROR(device->queues[i]->WaitIdle(deadline_ns)); } - return OkStatus(); + return iree_ok_status(); } -} // namespace vulkan -} // namespace hal -} // namespace iree +static iree_status_t iree_hal_vulkan_device_wait_idle_with_timeout( + iree_hal_device_t* base_device, iree_duration_t timeout_ns) { + return iree_hal_vulkan_device_wait_idle_with_deadline( + base_device, iree_relative_timeout_to_deadline_ns(timeout_ns)); +} + +const iree_hal_device_vtable_t iree_hal_vulkan_device_vtable = { + /*.destroy=*/iree_hal_vulkan_device_destroy, + /*.id=*/iree_hal_vulkan_device_id, + /*.host_allocator=*/iree_hal_vulkan_device_host_allocator, + /*.device_allocator=*/iree_hal_vulkan_device_allocator, + /*.create_command_buffer=*/iree_hal_vulkan_device_create_command_buffer, + /*.create_descriptor_set=*/iree_hal_vulkan_device_create_descriptor_set, + /*.create_descriptor_set_layout=*/ + iree_hal_vulkan_device_create_descriptor_set_layout, + /*.create_event=*/iree_hal_vulkan_device_create_event, + /*.create_executable_cache=*/ + iree_hal_vulkan_device_create_executable_cache, + /*.create_executable_layout=*/ + iree_hal_vulkan_device_create_executable_layout, + /*.create_semaphore=*/iree_hal_vulkan_device_create_semaphore, + /*.queue_submit=*/iree_hal_vulkan_device_queue_submit, + /*.wait_semaphores_with_deadline=*/ + iree_hal_vulkan_device_wait_semaphores_with_deadline, + /*.wait_semaphores_with_timeout=*/ + iree_hal_vulkan_device_wait_semaphores_with_timeout, + /*.wait_idle_with_deadline=*/ + iree_hal_vulkan_device_wait_idle_with_deadline, + /*.wait_idle_with_timeout=*/ + iree_hal_vulkan_device_wait_idle_with_timeout, +};
diff --git a/iree/hal/vulkan/vulkan_device.h b/iree/hal/vulkan/vulkan_device.h index cb897f2..c34ad1c 100644 --- a/iree/hal/vulkan/vulkan_device.h +++ b/iree/hal/vulkan/vulkan_device.h
@@ -15,160 +15,31 @@ #ifndef IREE_HAL_VULKAN_VULKAN_DEVICE_H_ #define IREE_HAL_VULKAN_VULKAN_DEVICE_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on - -#include <functional> -#include <memory> - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "iree/base/memory.h" -#include "iree/hal/cc/allocator.h" -#include "iree/hal/cc/debug_capture_manager.h" -#include "iree/hal/cc/device.h" -#include "iree/hal/cc/driver.h" -#include "iree/hal/cc/semaphore.h" -#include "iree/hal/vulkan/descriptor_pool_cache.h" +#include "iree/hal/api.h" +#include "iree/hal/vulkan/api.h" #include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/emulated_timeline_semaphore.h" #include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/handle_util.h" -#include "iree/hal/vulkan/vma_allocator.h" -namespace iree { -namespace hal { -namespace vulkan { +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus -// A set of queues within a specific queue family on a VkDevice. -struct QueueSet { - // The index of a particular queue family on a VkPhysicalDevice, as described - // by vkGetPhysicalDeviceQueueFamilyProperties. - uint32_t queue_family_index; +// Creates a device that owns and manages its own VkDevice. +// +// The |driver| will be retained for as long as the device is live such that if +// the driver owns the |instance| provided it is ensured to be valid. |driver| +// may be NULL if there is no parent driver to retain (such as when wrapping +// existing VkInstances provided by the application). +iree_status_t iree_hal_vulkan_device_create( + iree_hal_driver_t* driver, iree_string_view_t identifier, + iree_hal_vulkan_features_t enabled_features, + const iree_hal_vulkan_device_options_t* options, + iree_hal_vulkan_syms_t* instance_syms, VkInstance instance, + VkPhysicalDevice physical_device, iree_allocator_t host_allocator, + iree_hal_device_t** out_device); - // Bitfield of queue indices within the queue family at |queue_family_index|. - uint64_t queue_indices; -}; - -class VulkanDevice final : public Device { - public: - struct Options { - // Extensibility descriptions for the device. - ExtensibilitySpec extensibility_spec; - - // Options for Vulkan Memory Allocator (VMA). - VmaAllocator::Options vma_options; - - // Uses timeline semaphore emulation even if native support exists. - bool force_timeline_semaphore_emulation = false; - }; - - // Creates a device that manages its own VkDevice. - static StatusOr<ref_ptr<VulkanDevice>> Create( - ref_ptr<Driver> driver, VkInstance instance, - const DeviceInfo& device_info, VkPhysicalDevice physical_device, - Options options, const ref_ptr<DynamicSymbols>& syms, - DebugCaptureManager* debug_capture_manager); - - // Creates a device that wraps an externally managed VkDevice. - static StatusOr<ref_ptr<VulkanDevice>> Wrap( - ref_ptr<Driver> driver, VkInstance instance, - const DeviceInfo& device_info, VkPhysicalDevice physical_device, - VkDevice logical_device, Options options, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set, - const ref_ptr<DynamicSymbols>& syms); - - ~VulkanDevice() override; - - std::string DebugString() const override; - - const ref_ptr<DynamicSymbols>& syms() const { - return logical_device_->syms(); - } - - Allocator* allocator() const override { return allocator_.get(); } - - absl::Span<CommandQueue*> dispatch_queues() const override { - return absl::MakeSpan(dispatch_queues_); - } - - absl::Span<CommandQueue*> transfer_queues() const override { - return absl::MakeSpan(transfer_queues_); - } - - ref_ptr<ExecutableCache> CreateExecutableCache() override; - - StatusOr<ref_ptr<DescriptorSetLayout>> CreateDescriptorSetLayout( - iree_hal_descriptor_set_layout_usage_type_t usage_type, - absl::Span<const iree_hal_descriptor_set_layout_binding_t> bindings) - override; - - StatusOr<ref_ptr<ExecutableLayout>> CreateExecutableLayout( - absl::Span<DescriptorSetLayout* const> set_layouts, - size_t push_constants) override; - - StatusOr<ref_ptr<DescriptorSet>> CreateDescriptorSet( - DescriptorSetLayout* set_layout, - absl::Span<const iree_hal_descriptor_set_binding_t> bindings) override; - - StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer( - iree_hal_command_buffer_mode_t mode, - iree_hal_command_category_t command_categories) override; - - StatusOr<ref_ptr<Event>> CreateEvent() override; - - StatusOr<ref_ptr<Semaphore>> CreateSemaphore(uint64_t initial_value) override; - Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) override; - StatusOr<int> WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns) override; - - Status WaitIdle(Time deadline_ns) override; - - private: - VulkanDevice( - ref_ptr<Driver> driver, const DeviceInfo& device_info, - VkPhysicalDevice physical_device, ref_ptr<VkDeviceHandle> logical_device, - std::unique_ptr<Allocator> allocator, - absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues, - ref_ptr<VkCommandPoolHandle> dispatch_command_pool, - ref_ptr<VkCommandPoolHandle> transfer_command_pool, - ref_ptr<TimePointSemaphorePool> semaphore_pool, - ref_ptr<TimePointFencePool> fence_pool, - DebugCaptureManager* debug_capture_manager); - - Status WaitSemaphores(absl::Span<const SemaphoreValue> semaphores, - Time deadline_ns, VkSemaphoreWaitFlags wait_flags); - - bool emulating_timeline_semaphores() const { - return semaphore_pool_ != nullptr; - } - - ref_ptr<Driver> driver_; - VkPhysicalDevice physical_device_; - ref_ptr<VkDeviceHandle> logical_device_; - - std::unique_ptr<Allocator> allocator_; - - mutable absl::InlinedVector<std::unique_ptr<CommandQueue>, 4> command_queues_; - mutable absl::InlinedVector<CommandQueue*, 4> dispatch_queues_; - mutable absl::InlinedVector<CommandQueue*, 4> transfer_queues_; - - ref_ptr<DescriptorPoolCache> descriptor_pool_cache_; - - ref_ptr<VkCommandPoolHandle> dispatch_command_pool_; - ref_ptr<VkCommandPoolHandle> transfer_command_pool_; - - // Fields used for emulated timeline semaphores. - ref_ptr<TimePointSemaphorePool> semaphore_pool_; - ref_ptr<TimePointFencePool> fence_pool_; - - DebugCaptureManager* debug_capture_manager_ = nullptr; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #endif // IREE_HAL_VULKAN_VULKAN_DEVICE_H_
diff --git a/iree/hal/vulkan/vulkan_driver.cc b/iree/hal/vulkan/vulkan_driver.cc index b1e44ce..5466f17 100644 --- a/iree/hal/vulkan/vulkan_driver.cc +++ b/iree/hal/vulkan/vulkan_driver.cc
@@ -16,43 +16,338 @@ #include <memory> -#include "absl/container/inlined_vector.h" #include "iree/base/memory.h" -#include "iree/base/status.h" -#include "iree/base/target_platform.h" #include "iree/base/tracing.h" -#include "iree/hal/cc/device_info.h" +#include "iree/hal/vulkan/api.h" +#include "iree/hal/vulkan/debug_reporter.h" +#include "iree/hal/vulkan/dynamic_symbols.h" #include "iree/hal/vulkan/extensibility_util.h" #include "iree/hal/vulkan/status_util.h" +#include "iree/hal/vulkan/vulkan_device.h" -namespace iree { -namespace hal { -namespace vulkan { +using namespace iree::hal::vulkan; -namespace { +typedef struct { + iree_hal_resource_t resource; + iree_allocator_t host_allocator; + + // Identifier used for the driver in the IREE driver registry. + // We allow overriding so that multiple Vulkan versions can be exposed in the + // same process. + iree_string_view_t identifier; + + iree_hal_vulkan_device_options_t device_options; + int default_device_index; + + iree_hal_vulkan_features_t enabled_features; + + // Which optional extensions are active and available on the instance. + iree_hal_vulkan_instance_extensions_t instance_extensions; + + // (Partial) loaded Vulkan symbols. Devices created within the driver may have + // different function pointers for device-specific functions that change + // behavior with enabled layers/extensions. + iree::ref_ptr<DynamicSymbols> syms; + + // The Vulkan instance that all devices created from the driver will share. + VkInstance instance; + bool owns_instance; + + // Optional debug reporter: may be disabled or unavailable (no debug layers). + iree_hal_vulkan_debug_reporter_t* debug_reporter; +} iree_hal_vulkan_driver_t; + +extern const iree_hal_driver_vtable_t iree_hal_vulkan_driver_vtable; + +static iree_hal_vulkan_driver_t* iree_hal_vulkan_driver_cast( + iree_hal_driver_t* base_value) { + IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_driver_vtable); + return (iree_hal_vulkan_driver_t*)base_value; +} + +IREE_API_EXPORT void IREE_API_CALL iree_hal_vulkan_driver_options_initialize( + iree_hal_vulkan_driver_options_t* out_options) { + memset(out_options, 0, sizeof(*out_options)); + out_options->api_version = VK_API_VERSION_1_2; + out_options->requested_features = 0; + iree_hal_vulkan_device_options_initialize(&out_options->device_options); + out_options->default_device_index = 0; +} // Returns a VkApplicationInfo struct populated with the default app info. // We may allow hosting applications to override this via weak-linkage if it's // useful, otherwise this is enough to create the application. -VkApplicationInfo GetDefaultApplicationInfo() { - VkApplicationInfo info; - info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - info.pNext = nullptr; - info.pApplicationName = "IREE-ML"; - info.applicationVersion = 0; - info.pEngineName = "IREE"; - info.engineVersion = 0; -#ifdef IREE_PLATFORM_ANDROID - info.apiVersion = VK_API_VERSION_1_1; -#else - info.apiVersion = VK_API_VERSION_1_2; -#endif - return info; +static void iree_hal_vulkan_driver_populate_default_app_info( + const iree_hal_vulkan_driver_options_t* options, + VkApplicationInfo* out_app_info) { + memset(out_app_info, 0, sizeof(*out_app_info)); + out_app_info->sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + out_app_info->pNext = NULL; + out_app_info->pApplicationName = "IREE-ML"; + out_app_info->applicationVersion = 0; + out_app_info->pEngineName = "IREE"; + out_app_info->engineVersion = 0; + out_app_info->apiVersion = options->api_version; +} + +// NOTE: takes ownership of |instance|. +static iree_status_t iree_hal_vulkan_driver_create_internal( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + const iree_hal_vulkan_string_list_t* enabled_extensions, + iree_hal_vulkan_syms_t* opaque_syms, VkInstance instance, + bool owns_instance, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver) { + auto* instance_syms = (DynamicSymbols*)opaque_syms; + + iree_hal_vulkan_instance_extensions_t instance_extensions = + iree_hal_vulkan_populate_enabled_instance_extensions(enabled_extensions); + + // The real debug messenger (not just the static one used above) can now be + // created as we've loaded all the required symbols. + // TODO(benvanik): strip in min-size release builds. + iree_hal_vulkan_debug_reporter_t* debug_reporter = NULL; + if (instance_extensions.debug_utils) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_debug_reporter_allocate( + instance, instance_syms, /*allocation_callbacks=*/NULL, host_allocator, + &debug_reporter)); + } + + iree_hal_vulkan_driver_t* driver = NULL; + iree_host_size_t total_size = sizeof(*driver) + identifier.size; + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&driver); + if (!iree_status_is_ok(status)) { + // Need to clean up if we fail (as we own these). + iree_hal_vulkan_debug_reporter_free(debug_reporter); + return status; + } + iree_hal_resource_initialize(&iree_hal_vulkan_driver_vtable, + &driver->resource); + driver->host_allocator = host_allocator; + iree_string_view_append_to_buffer( + identifier, &driver->identifier, + (char*)driver + total_size - identifier.size); + memcpy(&driver->device_options, &options->device_options, + sizeof(driver->device_options)); + driver->default_device_index = options->default_device_index; + driver->enabled_features = options->requested_features; + driver->syms = iree::add_ref(instance_syms); + driver->instance = instance; + driver->owns_instance = owns_instance; + driver->debug_reporter = debug_reporter; + *out_driver = (iree_hal_driver_t*)driver; + return status; +} + +static void iree_hal_vulkan_driver_destroy(iree_hal_driver_t* base_driver) { + iree_hal_vulkan_driver_t* driver = iree_hal_vulkan_driver_cast(base_driver); + iree_allocator_t host_allocator = driver->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_vulkan_debug_reporter_free(driver->debug_reporter); + if (driver->owns_instance) { + driver->syms->vkDestroyInstance(driver->instance, /*pAllocator=*/NULL); + } + driver->syms.reset(); + iree_allocator_free(host_allocator, driver); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_vulkan_driver_query_extensibility_set( + iree_hal_vulkan_features_t requested_features, + iree_hal_vulkan_extensibility_set_t set, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_string_list) { + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, 0, NULL, &out_string_list->count)); + out_string_list->values = (const char**)arena->AllocateBytes( + out_string_list->count * sizeof(out_string_list->values[0])); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_query_extensibility_set( + requested_features, set, out_string_list->count, out_string_list->values, + &out_string_list->count)); + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_driver_compute_enabled_extensibility_sets( + iree::hal::vulkan::DynamicSymbols* syms, + iree_hal_vulkan_features_t requested_features, iree::Arena* arena, + iree_hal_vulkan_string_list_t* out_enabled_layers, + iree_hal_vulkan_string_list_t* out_enabled_extensions) { + // Query our required and optional layers and extensions based on the IREE + // features the user requested. + iree_hal_vulkan_string_list_t required_layers; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_REQUIRED, arena, + &required_layers)); + iree_hal_vulkan_string_list_t optional_layers; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_LAYERS_OPTIONAL, arena, + &optional_layers)); + iree_hal_vulkan_string_list_t required_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_REQUIRED, arena, + &required_extensions)); + iree_hal_vulkan_string_list_t optional_extensions; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_query_extensibility_set( + requested_features, + IREE_HAL_VULKAN_EXTENSIBILITY_INSTANCE_EXTENSIONS_OPTIONAL, arena, + &optional_extensions)); + + // Find the layers and extensions we need (or want) that are also available + // on the instance. This will fail when required ones are not present. + IREE_RETURN_IF_ERROR(iree_hal_vulkan_match_available_instance_layers( + syms, &required_layers, &optional_layers, arena, out_enabled_layers)); + IREE_RETURN_IF_ERROR(iree_hal_vulkan_match_available_instance_extensions( + syms, &required_extensions, &optional_extensions, arena, + out_enabled_extensions)); + + return iree_ok_status(); +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_vulkan_driver_create( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* opaque_syms, iree_allocator_t host_allocator, + iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(opaque_syms); + IREE_ASSERT_ARGUMENT(out_driver); + IREE_TRACE_SCOPE(); + + auto* instance_syms = (DynamicSymbols*)opaque_syms; + + // Query required and optional instance layers/extensions for the requested + // features. + iree::Arena arena; + iree_hal_vulkan_string_list_t enabled_layers; + iree_hal_vulkan_string_list_t enabled_extensions; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_driver_compute_enabled_extensibility_sets( + instance_syms, options->requested_features, &arena, &enabled_layers, + &enabled_extensions)); + + // Create the instance this driver will use for all requests. + VkApplicationInfo app_info; + iree_hal_vulkan_driver_populate_default_app_info(options, &app_info); + VkInstanceCreateInfo create_info; + create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + create_info.pNext = NULL; + create_info.flags = 0; + create_info.pApplicationInfo = &app_info; + create_info.enabledLayerCount = enabled_layers.count; + create_info.ppEnabledLayerNames = enabled_layers.values; + create_info.enabledExtensionCount = enabled_extensions.count; + create_info.ppEnabledExtensionNames = enabled_extensions.values; + + VkInstance instance = VK_NULL_HANDLE; + VK_RETURN_IF_ERROR(instance_syms->vkCreateInstance( + &create_info, /*pAllocator=*/NULL, &instance), + "vkCreateInstance: invalid instance configuration"); + + // Now that the instance has been created we can fetch all of the instance + // symbols. + iree_status_t status = instance_syms->LoadFromInstance(instance); + + if (iree_status_is_ok(status)) { + status = iree_hal_vulkan_driver_create_internal( + identifier, options, &enabled_extensions, opaque_syms, instance, + /*owns_instance=*/true, host_allocator, out_driver); + } + + if (!iree_status_is_ok(status)) { + instance_syms->vkDestroyInstance(instance, /*pAllocator=*/NULL); + } + return status; +} + +IREE_API_EXPORT iree_status_t IREE_API_CALL +iree_hal_vulkan_driver_create_using_instance( + iree_string_view_t identifier, + const iree_hal_vulkan_driver_options_t* options, + iree_hal_vulkan_syms_t* opaque_syms, VkInstance instance, + iree_allocator_t host_allocator, iree_hal_driver_t** out_driver) { + IREE_ASSERT_ARGUMENT(options); + IREE_ASSERT_ARGUMENT(opaque_syms); + IREE_ASSERT_ARGUMENT(out_driver); + if (instance == VK_NULL_HANDLE) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "a non-NULL VkInstance must be provided"); + } + IREE_TRACE_ZONE_BEGIN(z0); + + // May be a no-op but don't rely on that so we can be sure we have the right + // function pointers. + auto* instance_syms = (DynamicSymbols*)opaque_syms; + IREE_RETURN_IF_ERROR(instance_syms->LoadFromInstance(instance)); + + // Since the instance is already created we can't actually enable any + // extensions or even query if they are really enabled - we just have to trust + // that the caller already enabled them for us (or we may fail later). + iree::Arena arena; + iree_hal_vulkan_string_list_t enabled_layers; + iree_hal_vulkan_string_list_t enabled_extensions; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_driver_compute_enabled_extensibility_sets( + instance_syms, options->requested_features, &arena, &enabled_layers, + &enabled_extensions)); + + iree_status_t status = iree_hal_vulkan_driver_create_internal( + identifier, options, &enabled_extensions, opaque_syms, instance, + /*owns_instance=*/true, host_allocator, out_driver); + IREE_TRACE_ZONE_END(z0); + return status; +} + +// Enumerates all physical devices on |instance| and returns them as an +// allocated list in |out_physical_devices|, which must be freed by the caller. +static iree_status_t iree_hal_vulkan_driver_enumerate_physical_devices( + iree::hal::vulkan::DynamicSymbols* instance_syms, VkInstance instance, + iree_allocator_t host_allocator, uint32_t* out_physical_device_count, + VkPhysicalDevice** out_physical_devices) { + uint32_t physical_device_count = 0; + VK_RETURN_IF_ERROR(instance_syms->vkEnumeratePhysicalDevices( + instance, &physical_device_count, NULL), + "vkEnumeratePhysicalDevices"); + VkPhysicalDevice* physical_devices = NULL; + IREE_RETURN_IF_ERROR(iree_allocator_malloc( + host_allocator, physical_device_count * sizeof(physical_devices), + (void**)&physical_devices)); + iree_status_t status = VK_RESULT_TO_STATUS( + instance_syms->vkEnumeratePhysicalDevices( + instance, &physical_device_count, physical_devices), + "vkEnumeratePhysicalDevices"); + if (iree_status_is_ok(status)) { + *out_physical_device_count = physical_device_count; + *out_physical_devices = physical_devices; + } else { + iree_allocator_free(host_allocator, physical_devices); + } + return status; +} + +// Returns the size, in bytes, of the iree_hal_device_info_t storage required +// for holding the given |physical_device|. +static iree_host_size_t iree_hal_vulkan_calculate_device_info_size( + VkPhysicalDevice physical_device, iree::hal::vulkan::DynamicSymbols* syms) { + VkPhysicalDeviceProperties physical_device_properties; + syms->vkGetPhysicalDeviceProperties(physical_device, + &physical_device_properties); + return strlen(physical_device_properties.deviceName); } // Populates device information from the given Vulkan physical device handle. -StatusOr<DeviceInfo> PopulateDeviceInfo(VkPhysicalDevice physical_device, - const ref_ptr<DynamicSymbols>& syms) { +// |out_device_info| must point to valid memory and additional data will be +// appended to |buffer_ptr| and the new pointer is returned. +static uint8_t* iree_hal_vulkan_populate_device_info( + VkPhysicalDevice physical_device, DynamicSymbols* syms, uint8_t* buffer_ptr, + iree_hal_device_info_t* out_device_info) { + memset(out_device_info, 0, sizeof(*out_device_info)); + out_device_info->device_id = (iree_hal_device_id_t)physical_device; + VkPhysicalDeviceFeatures physical_device_features; syms->vkGetPhysicalDeviceFeatures(physical_device, &physical_device_features); // TODO(benvanik): check and optionally require these features: @@ -67,257 +362,121 @@ // TODO(benvanik): check and optionally require reasonable limits. // TODO(benvanik): more clever/sanitized device naming. - std::string name = std::string(physical_device_properties.deviceName); + iree_string_view_t device_name = + iree_make_string_view(physical_device_properties.deviceName, + strlen(physical_device_properties.deviceName)); + buffer_ptr += iree_string_view_append_to_buffer( + device_name, &out_device_info->name, (char*)buffer_ptr); - iree_hal_device_feature_t supported_features = IREE_HAL_DEVICE_FEATURE_NONE; - return DeviceInfo("vulkan", std::move(name), supported_features, - reinterpret_cast<iree_hal_device_id_t>(physical_device)); + return buffer_ptr; } -} // namespace +static iree_status_t iree_hal_vulkan_driver_query_available_devices( + iree_hal_driver_t* base_driver, iree_allocator_t host_allocator, + iree_hal_device_info_t** out_device_infos, + iree_host_size_t* out_device_info_count) { + iree_hal_vulkan_driver_t* driver = iree_hal_vulkan_driver_cast(base_driver); -// static -StatusOr<ref_ptr<VulkanDriver>> VulkanDriver::Create( - Options options, ref_ptr<DynamicSymbols> syms) { - IREE_TRACE_SCOPE0("VulkanDriver::Create"); - - // Load and connect to RenderDoc before instance creation. - // Note: RenderDoc assumes that only a single VkDevice is used: - // https://renderdoc.org/docs/behind_scenes/vulkan_support.html#current-support - std::unique_ptr<RenderDocCaptureManager> renderdoc_capture_manager; - if (options.enable_renderdoc) { - renderdoc_capture_manager = std::make_unique<RenderDocCaptureManager>(); - IREE_RETURN_IF_ERROR(renderdoc_capture_manager->Connect()); - } - - // Find the layers and extensions we need (or want) that are also available - // on the instance. This will fail when required ones are not present. - IREE_ASSIGN_OR_RETURN( - auto enabled_layer_names, - MatchAvailableInstanceLayers(options.instance_extensibility, *syms)); - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableInstanceExtensions(options.instance_extensibility, *syms)); - auto instance_extensions = - PopulateEnabledInstanceExtensions(enabled_extension_names); - - // Create the instance this driver will use for all requests. - VkApplicationInfo app_info = GetDefaultApplicationInfo(); - app_info.apiVersion = options.api_version; - VkInstanceCreateInfo create_info; - create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - create_info.pNext = nullptr; - create_info.flags = 0; - create_info.pApplicationInfo = &app_info; - create_info.enabledLayerCount = - static_cast<uint32_t>(enabled_layer_names.size()); - create_info.ppEnabledLayerNames = enabled_layer_names.data(); - create_info.enabledExtensionCount = - static_cast<uint32_t>(enabled_extension_names.size()); - create_info.ppEnabledExtensionNames = enabled_extension_names.data(); - - // If we have the debug_utils extension then we can chain a one-shot messenger - // callback that we can use to log out the instance creation errors. Once we - // have the real instance we can then register a real messenger. - union { - VkDebugUtilsMessengerCreateInfoEXT debug_utils_create_info; - VkDebugReportCallbackCreateInfoEXT debug_report_create_info; - }; - if (instance_extensions.debug_utils) { - create_info.pNext = &debug_utils_create_info; - DebugReporter::PopulateStaticCreateInfo(&debug_utils_create_info); - } else if (instance_extensions.debug_report) { - create_info.pNext = &debug_report_create_info; - DebugReporter::PopulateStaticCreateInfo(&debug_report_create_info); - } - - // Some ICDs appear to leak in here, out of our control. - // Warning: leak checks remain disabled if an error is returned. - IREE_DISABLE_LEAK_CHECKS(); - VkInstance instance = VK_NULL_HANDLE; - VK_RETURN_IF_ERROR( - syms->vkCreateInstance(&create_info, /*pAllocator=*/nullptr, &instance)) - << "Unable to create Vulkan instance"; - IREE_ENABLE_LEAK_CHECKS(); - - // TODO(benvanik): enable validation layers if needed. - - // Now that the instance has been created we can fetch all of the instance - // symbols. - IREE_RETURN_IF_ERROR(syms->LoadFromInstance(instance)); - - // The real debug messenger (not just the static one used above) can now be - // created as we've loaded all the required symbols. - // TODO(benvanik): strip in release builds. - std::unique_ptr<DebugReporter> debug_reporter; - if (instance_extensions.debug_utils) { - IREE_ASSIGN_OR_RETURN(debug_reporter, - DebugReporter::CreateDebugUtilsMessenger( - instance, syms, - /*allocation_callbacks=*/nullptr)); - } else if (instance_extensions.debug_report) { - IREE_ASSIGN_OR_RETURN( - debug_reporter, DebugReporter::CreateDebugReportCallback( - instance, syms, /*allocation_callbacks=*/nullptr)); - } - - return assign_ref(new VulkanDriver( - std::move(syms), instance, - /*owns_instance=*/true, std::move(options.device_options), - options.default_device_index, std::move(debug_reporter), - std::move(renderdoc_capture_manager))); -} - -// static -StatusOr<ref_ptr<VulkanDriver>> VulkanDriver::CreateUsingInstance( - Options options, ref_ptr<DynamicSymbols> syms, VkInstance instance) { - IREE_TRACE_SCOPE0("VulkanDriver::CreateUsingInstance"); - - if (instance == VK_NULL_HANDLE) { - return InvalidArgumentErrorBuilder(IREE_LOC) - << "VkInstance must not be VK_NULL_HANDLE"; - } - - // Find the extensions we need (or want) that are also available on the - // instance. This will fail when required ones are not present. - // - // Since the instance is already created, we can't actually enable any - // extensions or query if they are really enabled - we just have to trust - // that the caller already enabled them for us (or we may fail later). - IREE_ASSIGN_OR_RETURN( - auto enabled_extension_names, - MatchAvailableInstanceExtensions(options.instance_extensibility, *syms)); - auto instance_extensions = - PopulateEnabledInstanceExtensions(enabled_extension_names); - - IREE_RETURN_IF_ERROR(syms->LoadFromInstance(instance)); - - // TODO(benvanik): strip in release builds. - std::unique_ptr<DebugReporter> debug_reporter; - if (instance_extensions.debug_utils) { - IREE_ASSIGN_OR_RETURN(debug_reporter, - DebugReporter::CreateDebugUtilsMessenger( - instance, syms, - /*allocation_callbacks=*/nullptr)); - } else if (instance_extensions.debug_report) { - IREE_ASSIGN_OR_RETURN( - debug_reporter, DebugReporter::CreateDebugReportCallback( - instance, syms, /*allocation_callbacks=*/nullptr)); - } - - // Note: no RenderDocCaptureManager here since the VkInstance is already - // created externally. Applications using this function must provide their - // own RenderDoc / debugger integration as desired. - - return assign_ref( - new VulkanDriver(std::move(syms), instance, /*owns_instance=*/false, - std::move(options.device_options), - options.default_device_index, std::move(debug_reporter), - /*debug_capture_manager=*/nullptr)); -} - -VulkanDriver::VulkanDriver( - ref_ptr<DynamicSymbols> syms, VkInstance instance, bool owns_instance, - VulkanDevice::Options device_options, int default_device_index, - std::unique_ptr<DebugReporter> debug_reporter, - std::unique_ptr<RenderDocCaptureManager> renderdoc_capture_manager) - : Driver("vulkan"), - syms_(std::move(syms)), - instance_(instance), - owns_instance_(owns_instance), - device_options_(std::move(device_options)), - default_device_index_(default_device_index), - debug_reporter_(std::move(debug_reporter)), - renderdoc_capture_manager_(std::move(renderdoc_capture_manager)) {} - -VulkanDriver::~VulkanDriver() { - IREE_TRACE_SCOPE0("VulkanDriver::dtor"); - debug_reporter_.reset(); - if (owns_instance_) { - syms()->vkDestroyInstance(instance_, /*pAllocator=*/nullptr); - } -} - -StatusOr<std::vector<DeviceInfo>> VulkanDriver::EnumerateAvailableDevices() { - IREE_TRACE_SCOPE0("VulkanDriver::EnumerateAvailableDevices"); - - // Query all available devices (at this moment, note that this may change!). + // Query all devices from the Vulkan instance. uint32_t physical_device_count = 0; - VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( - instance_, &physical_device_count, nullptr)); - absl::InlinedVector<VkPhysicalDevice, 2> physical_devices( - physical_device_count); - VK_RETURN_IF_ERROR(syms()->vkEnumeratePhysicalDevices( - instance_, &physical_device_count, physical_devices.data())); + VkPhysicalDevice* physical_devices = NULL; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_enumerate_physical_devices( + driver->syms.get(), driver->instance, host_allocator, + &physical_device_count, &physical_devices)); - // Convert to our HAL structure. - std::vector<DeviceInfo> device_infos; - device_infos.reserve(physical_device_count); - for (auto physical_device : physical_devices) { - // TODO(benvanik): if we fail should we just ignore the device in the list? - IREE_ASSIGN_OR_RETURN(auto device_info, - PopulateDeviceInfo(physical_device, syms())); - device_infos.push_back(std::move(device_info)); + // Allocate the return infos and populate with the devices. + iree_hal_device_info_t* device_infos = NULL; + iree_host_size_t total_size = + physical_device_count * sizeof(iree_hal_device_info_t); + for (uint32_t i = 0; i < physical_device_count; ++i) { + total_size += iree_hal_vulkan_calculate_device_info_size( + physical_devices[i], driver->syms.get()); } - return device_infos; -} - -StatusOr<ref_ptr<Device>> VulkanDriver::CreateDefaultDevice() { - IREE_TRACE_SCOPE0("VulkanDriver::CreateDefaultDevice"); - - // Query available devices. - IREE_ASSIGN_OR_RETURN(auto available_devices, EnumerateAvailableDevices()); - if (default_device_index_ < 0 || - default_device_index_ >= available_devices.size()) { - return NotFoundErrorBuilder(IREE_LOC) - << "Device index " << default_device_index_ << " not found " - << "(of " << available_devices.size() << ")"; + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&device_infos); + if (iree_status_is_ok(status)) { + uint8_t* buffer_ptr = + (uint8_t*)device_infos + + physical_device_count * sizeof(iree_hal_device_info_t); + for (uint32_t i = 0; i < physical_device_count; ++i) { + buffer_ptr = iree_hal_vulkan_populate_device_info( + physical_devices[i], driver->syms.get(), buffer_ptr, + &device_infos[i]); + } + *out_device_info_count = physical_device_count; + *out_device_infos = device_infos; } - // Just create the first one we find. - return CreateDevice(available_devices[default_device_index_].device_id()); + iree_allocator_free(host_allocator, physical_devices); + return status; } -StatusOr<ref_ptr<Device>> VulkanDriver::CreateDevice( - iree_hal_device_id_t device_id) { - IREE_TRACE_SCOPE0("VulkanDriver::CreateDevice"); +static iree_status_t iree_hal_vulkan_driver_select_default_device( + iree::hal::vulkan::DynamicSymbols* instance_syms, VkInstance instance, + int default_device_index, iree_allocator_t host_allocator, + VkPhysicalDevice* out_physical_device) { + uint32_t physical_device_count = 0; + VkPhysicalDevice* physical_devices = NULL; + IREE_RETURN_IF_ERROR(iree_hal_vulkan_driver_enumerate_physical_devices( + instance_syms, instance, host_allocator, &physical_device_count, + &physical_devices)); + iree_status_t status = iree_ok_status(); + if (physical_device_count == 0 || + default_device_index >= physical_device_count) { + status = iree_make_status(IREE_STATUS_NOT_FOUND, + "default device %d not found (of %d enumerated)", + default_device_index, physical_device_count); + } else { + *out_physical_device = physical_devices[default_device_index]; + } + iree_allocator_free(host_allocator, physical_devices); + return status; +} - auto physical_device = reinterpret_cast<VkPhysicalDevice>(device_id); - IREE_ASSIGN_OR_RETURN(auto device_info, - PopulateDeviceInfo(physical_device, syms())); +static iree_status_t iree_hal_vulkan_driver_create_device( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + iree_allocator_t host_allocator, iree_hal_device_t** out_device) { + iree_hal_vulkan_driver_t* driver = iree_hal_vulkan_driver_cast(base_driver); + IREE_TRACE_ZONE_BEGIN(z0); + + // Use either the specified device (enumerated earlier) or whatever default + // one was specified when the driver was created. + VkPhysicalDevice physical_device = (VkPhysicalDevice)device_id; + if (physical_device == VK_NULL_HANDLE) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_hal_vulkan_driver_select_default_device( + driver->syms.get(), driver->instance, driver->default_device_index, + host_allocator, &physical_device)); + } + + // TODO(benvanik): remove HAL module dependence on the identifier for matching + // devices. Today it *must* be vulkan* to work, whereas really that should be + // a device type (vs the identifier, which is arbitrary). + // Query the device name to use as an identifier. + // VkPhysicalDeviceProperties physical_device_properties; + // driver->syms->vkGetPhysicalDeviceProperties(physical_device, + // &physical_device_properties); + // iree_string_view_t device_name = + // iree_make_string_view(physical_device_properties.deviceName, + // strlen(physical_device_properties.deviceName)); + iree_string_view_t device_name = iree_make_cstring_view("vulkan"); // Attempt to create the device. // This may fail if the device was enumerated but is in exclusive use, // disabled by the system, or permission is denied. - IREE_ASSIGN_OR_RETURN( - auto device, - VulkanDevice::Create(add_ref(this), instance(), device_info, - physical_device, device_options_, syms(), - renderdoc_capture_manager_.get())); + iree_status_t status = iree_hal_vulkan_device_create( + base_driver, device_name, driver->enabled_features, + &driver->device_options, (iree_hal_vulkan_syms_t*)driver->syms.get(), + driver->instance, physical_device, host_allocator, out_device); - IREE_LOG(INFO) << "Created Vulkan Device: " << device->info().name(); - - return device; + IREE_TRACE_ZONE_END(z0); + return status; } -StatusOr<ref_ptr<Device>> VulkanDriver::WrapDevice( - VkPhysicalDevice physical_device, VkDevice logical_device, - const QueueSet& compute_queue_set, const QueueSet& transfer_queue_set) { - IREE_TRACE_SCOPE0("VulkanDriver::WrapDevice"); - - IREE_ASSIGN_OR_RETURN(auto device_info, - PopulateDeviceInfo(physical_device, syms())); - - // Attempt to create the device. - // This may fail if the VkDevice does not support all necessary features. - IREE_ASSIGN_OR_RETURN( - auto device, - VulkanDevice::Wrap(add_ref(this), instance(), device_info, - physical_device, logical_device, device_options_, - compute_queue_set, transfer_queue_set, syms())); - return device; -} - -} // namespace vulkan -} // namespace hal -} // namespace iree +const iree_hal_driver_vtable_t iree_hal_vulkan_driver_vtable = { + /*.destroy=*/iree_hal_vulkan_driver_destroy, + /*.query_available_devices=*/ + iree_hal_vulkan_driver_query_available_devices, + /*.create_device=*/iree_hal_vulkan_driver_create_device, +};
diff --git a/iree/hal/vulkan/vulkan_driver.h b/iree/hal/vulkan/vulkan_driver.h index b5500ea..8f53786 100644 --- a/iree/hal/vulkan/vulkan_driver.h +++ b/iree/hal/vulkan/vulkan_driver.h
@@ -15,106 +15,11 @@ #ifndef IREE_HAL_VULKAN_VULKAN_DRIVER_H_ #define IREE_HAL_VULKAN_VULKAN_DRIVER_H_ -// clang-format off: Must be included before all other headers: -#include "iree/hal/vulkan/vulkan_headers.h" -// clang-format on +#include "iree/hal/api.h" +#include "iree/hal/vulkan/api.h" -#include <memory> -#include <vector> - -#include "iree/hal/cc/driver.h" -#include "iree/hal/vulkan/debug_reporter.h" -#include "iree/hal/vulkan/dynamic_symbols.h" -#include "iree/hal/vulkan/extensibility_util.h" -#include "iree/hal/vulkan/renderdoc_capture_manager.h" -#include "iree/hal/vulkan/vulkan_device.h" - -namespace iree { -namespace hal { -namespace vulkan { - -class VulkanDriver final : public Driver { - public: - struct Options { - // Vulkan version that will be requested. - // Driver creation will fail if the required version is not available. - uint32_t api_version = VK_API_VERSION_1_0; - - // Extensibility descriptions for instances. - // See VulkanDevice::Options for device extensibility descriptions. - ExtensibilitySpec instance_extensibility; - - // Options to use for all devices created by the driver. - VulkanDevice::Options device_options; - - // Index of the default Vulkan device to use within the list of available - // devices. Devices are discovered via vkEnumeratePhysicalDevices then - // considered "available" if compatible with the driver options. - int default_device_index = 0; - - // Enables RenderDoc integration, connecting via RenderDoc's API and - // recording Vulkan calls for offline inspection and debugging. - bool enable_renderdoc = false; - }; - - // Creates a VulkanDriver that manages its own VkInstance. - static StatusOr<ref_ptr<VulkanDriver>> Create(Options options, - ref_ptr<DynamicSymbols> syms); - - // Creates a VulkanDriver that shares an externally managed VkInstance. - // - // |options| are checked for compatibility. - // - // |syms| must at least have |vkGetInstanceProcAddr| set. Other symbols will - // be loaded as needed from |instance|. - // - // |instance| must remain valid for the life of the returned VulkanDriver. - static StatusOr<ref_ptr<VulkanDriver>> CreateUsingInstance( - Options options, ref_ptr<DynamicSymbols> syms, VkInstance instance); - - ~VulkanDriver() override; - - const ref_ptr<DynamicSymbols>& syms() const { return syms_; } - - VkInstance instance() const { return instance_; } - - StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override; - - StatusOr<ref_ptr<Device>> CreateDefaultDevice() override; - - StatusOr<ref_ptr<Device>> CreateDevice( - iree_hal_device_id_t device_id) override; - - // Creates a device that wraps an externally managed VkDevice. - // - // The device will schedule commands against the provided queues. - StatusOr<ref_ptr<Device>> WrapDevice(VkPhysicalDevice physical_device, - VkDevice logical_device, - const QueueSet& compute_queue_set, - const QueueSet& transfer_queue_set); - - DebugCaptureManager* debug_capture_manager() override { - return renderdoc_capture_manager_.get(); - } - - private: - VulkanDriver( - ref_ptr<DynamicSymbols> syms, VkInstance instance, bool owns_instance, - VulkanDevice::Options device_options, int default_device_index, - std::unique_ptr<DebugReporter> debug_reporter, - std::unique_ptr<RenderDocCaptureManager> renderdoc_capture_manager); - - ref_ptr<DynamicSymbols> syms_; - VkInstance instance_; - bool owns_instance_; - VulkanDevice::Options device_options_; - int default_device_index_; - std::unique_ptr<DebugReporter> debug_reporter_; - std::unique_ptr<RenderDocCaptureManager> renderdoc_capture_manager_; -}; - -} // namespace vulkan -} // namespace hal -} // namespace iree +// NOTE: the driver API calls are defined in api.h. +// TODO(benvanik): clean that up? api.h is nice because then we only need to +// deploy a single header file for the backend, but it is a bit tricky. #endif // IREE_HAL_VULKAN_VULKAN_DRIVER_H_
diff --git a/iree/samples/vulkan/vulkan_inference_gui.cc b/iree/samples/vulkan/vulkan_inference_gui.cc index 0a85be2..779a43f 100644 --- a/iree/samples/vulkan/vulkan_inference_gui.cc +++ b/iree/samples/vulkan/vulkan_inference_gui.cc
@@ -88,9 +88,8 @@ // Setup Vulkan iree_hal_vulkan_features_t iree_vulkan_features = static_cast<iree_hal_vulkan_features_t>( - IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS | - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS | + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features); std::vector<const char*> extensions = GetInstanceExtensions(window, iree_vulkan_features); @@ -208,8 +207,7 @@ iree_hal_vulkan_driver_options_t options; options.api_version = VK_API_VERSION_1_0; options.features = static_cast<iree_hal_vulkan_features_t>( - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance( options, iree_vk_syms, g_Instance, &iree_vk_driver)); // Create a device sharing our VkDevice and queue.
diff --git a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc index 28bccd0..6ab6197 100644 --- a/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc +++ b/iree/testing/vulkan/iree-run-module-vulkan-gui-main.cc
@@ -168,9 +168,8 @@ // Setup Vulkan iree_hal_vulkan_features_t iree_vulkan_features = static_cast<iree_hal_vulkan_features_t>( - IREE_HAL_VULKAN_ENABLE_VALIDATION_LAYERS | - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + IREE_HAL_VULKAN_FEATURE_ENABLE_VALIDATION_LAYERS | + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); std::vector<const char*> layers = GetInstanceLayers(iree_vulkan_features); std::vector<const char*> extensions = GetInstanceExtensions(window, iree_vulkan_features); @@ -281,8 +280,7 @@ iree_hal_vulkan_driver_options_t options; options.api_version = VK_API_VERSION_1_0; options.features = static_cast<iree_hal_vulkan_features_t>( - IREE_HAL_VULKAN_ENABLE_DEBUG_UTILS | - IREE_HAL_VULKAN_ENABLE_PUSH_DESCRIPTORS); + IREE_HAL_VULKAN_FEATURE_ENABLE_DEBUG_UTILS); IREE_CHECK_OK(iree_hal_vulkan_driver_create_using_instance( options, iree_vk_syms, g_Instance, &iree_vk_driver)); // Create a device sharing our VkDevice and queue. This makes capturing with