| // Copyright 2019 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "iree/modules/hal/hal_module.h" |
| |
| #include "absl/base/macros.h" |
| #include "absl/container/inlined_vector.h" |
| #include "absl/memory/memory.h" |
| #include "absl/types/span.h" |
| #include "iree/base/api.h" |
| #include "iree/base/api_util.h" |
| #include "iree/base/tracing.h" |
| #include "iree/hal/api.h" |
| #include "iree/hal/api_detail.h" |
| #include "iree/hal/command_queue.h" |
| #include "iree/hal/device.h" |
| #include "iree/vm/module_abi_cc.h" |
| |
| namespace iree { |
| namespace hal { |
| namespace { |
| |
| //===----------------------------------------------------------------------===// |
| // Type registration |
| //===----------------------------------------------------------------------===// |
| |
| static iree_vm_ref_type_descriptor_t iree_hal_allocator_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_buffer_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_buffer_view_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_command_buffer_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_descriptor_set_layout_descriptor = |
| {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_device_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_executable_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_executable_cache_descriptor = {0}; |
| static iree_vm_ref_type_descriptor_t iree_hal_executable_layout_descriptor = { |
| 0}; |
| |
| IREE_API_EXPORT iree_status_t IREE_API_CALL iree_hal_module_register_types() { |
| static bool has_registered = false; |
| if (has_registered) return IREE_STATUS_OK; |
| |
| IREE_VM_REGISTER_CC_TYPE(Allocator, "hal.allocator", |
| iree_hal_allocator_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(Buffer, "hal.buffer", iree_hal_buffer_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(iree_hal_buffer_view, "hal.buffer_view", |
| iree_hal_buffer_view_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(CommandBuffer, "hal.command_buffer", |
| iree_hal_command_buffer_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(DescriptorSet, "hal.descriptor_set", |
| iree_hal_descriptor_set_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(DescriptorSetLayout, "hal.descriptor_set_layout", |
| iree_hal_descriptor_set_layout_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(Device, "hal.device", iree_hal_device_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(Executable, "hal.executable", |
| iree_hal_executable_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(ExecutableCache, "hal.executable_cache", |
| iree_hal_executable_cache_descriptor); |
| IREE_VM_REGISTER_CC_TYPE(ExecutableLayout, "hal.executable_layout", |
| iree_hal_executable_layout_descriptor); |
| |
| has_registered = true; |
| return IREE_STATUS_OK; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Type wrappers |
| //===----------------------------------------------------------------------===// |
| |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_allocator, iree_hal_allocator_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer, iree_hal_buffer_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_buffer_view, iree_hal_buffer_view_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_command_buffer, |
| iree_hal_command_buffer_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set, |
| iree_hal_descriptor_set_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_descriptor_set_layout, |
| iree_hal_descriptor_set_layout_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_device, iree_hal_device_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable, iree_hal_executable_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_cache, |
| iree_hal_executable_cache_t); |
| IREE_VM_DEFINE_TYPE_ADAPTERS(iree_hal_executable_layout, |
| iree_hal_executable_layout_t); |
| |
| //===----------------------------------------------------------------------===// |
| // Module type definitions |
| //===----------------------------------------------------------------------===// |
| |
| class HALModuleState final { |
| public: |
| HALModuleState(iree_allocator_t allocator, ref_ptr<Device> shared_device, |
| ref_ptr<ExecutableCache> executable_cache) |
| : allocator_(allocator), shared_device_(std::move(shared_device)) {} |
| |
| ~HALModuleState() { |
| for (auto& ref : deferred_releases_) { |
| iree_vm_ref_release(&ref); |
| } |
| deferred_releases_.clear(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // Experimental APIs |
| //===--------------------------------------------------------------------===// |
| // NOTE: Ex* APIs are experimental and likely to be removed soon. Modules |
| // using these APIs are not forward compatible. |
| |
| StatusOr<vm::ref<iree_hal_device_t>> ExSharedDevice() { |
| return vm::retain_ref( |
| reinterpret_cast<iree_hal_device_t*>(shared_device_.get())); |
| } |
| |
| Status ExDeferRelease(absl::optional<vm::opaque_ref> operand) { |
| if (operand.has_value()) { |
| deferred_releases_.push_back({0}); |
| iree_vm_ref_move(&operand.value(), &deferred_releases_.back()); |
| } |
| return OkStatus(); |
| } |
| |
| Status ExSubmitAndWait(vm::ref<iree_hal_device_t> device, |
| vm::ref<iree_hal_command_buffer_t> command_buffer) { |
| IREE_TRACE_SCOPE0("HALModuleState::ExSubmitAndWait"); |
| |
| auto* device_ptr = reinterpret_cast<Device*>(device.get()); |
| auto* queue = device_ptr->dispatch_queues().front(); |
| ASSIGN_OR_RETURN(auto fence, device_ptr->CreateFence(0u)); |
| SubmissionBatch batch; |
| CommandBuffer* command_buffers[1] = { |
| reinterpret_cast<CommandBuffer*>(command_buffer.get())}; |
| batch.command_buffers = absl::MakeConstSpan(command_buffers); |
| RETURN_IF_ERROR(queue->Submit(batch, {fence.get(), 1u})); |
| RETURN_IF_ERROR(queue->WaitIdle()); |
| |
| for (auto& ref : deferred_releases_) { |
| iree_vm_ref_release(&ref); |
| } |
| deferred_releases_.clear(); |
| |
| return OkStatus(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::Allocator |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<int32_t> AllocatorComputeSize( |
| vm::ref<iree_hal_allocator_t> allocator, absl::Span<const int32_t> shape, |
| iree_hal_element_type_t element_type) { |
| iree_device_size_t allocation_size = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_allocator_compute_size(allocator.get(), shape.data(), |
| shape.size(), element_type, |
| &allocation_size), |
| IREE_LOC)); |
| return static_cast<int32_t>(allocation_size); |
| } |
| |
| StatusOr<int32_t> AllocatorComputeOffset( |
| vm::ref<iree_hal_allocator_t> allocator, absl::Span<const int32_t> shape, |
| iree_hal_element_type_t element_type, absl::Span<const int32_t> indices) { |
| iree_device_size_t offset = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_allocator_compute_offset( |
| allocator.get(), shape.data(), shape.size(), element_type, |
| indices.data(), indices.size(), &offset), |
| IREE_LOC)); |
| return static_cast<int32_t>(offset); |
| } |
| |
| StatusOr<std::tuple<int32_t, int32_t>> AllocatorComputeRange( |
| vm::ref<iree_hal_allocator_t> allocator, absl::Span<const int32_t> shape, |
| iree_hal_element_type_t element_type, |
| absl::Span<const int32_t> start_indices, |
| absl::Span<const int32_t> lengths) { |
| iree_device_size_t offset = 0; |
| iree_device_size_t length = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_allocator_compute_range( |
| allocator.get(), shape.data(), shape.size(), element_type, |
| start_indices.data(), start_indices.size(), lengths.data(), |
| lengths.size(), &offset, &length), |
| IREE_LOC)); |
| return std::make_tuple(static_cast<int32_t>(offset), |
| static_cast<int32_t>(length)); |
| } |
| |
| StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorAllocate( |
| vm::ref<iree_hal_allocator_t> allocator, |
| iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage, |
| int32_t allocation_size) { |
| IREE_TRACE_SCOPE0("HALModuleState::AllocatorAllocate"); |
| vm::ref<iree_hal_buffer_t> buffer; |
| RETURN_IF_ERROR(FromApiStatus(iree_hal_allocator_allocate_buffer( |
| allocator.get(), memory_types, |
| buffer_usage, allocation_size, &buffer), |
| IREE_LOC)); |
| return std::move(buffer); |
| } |
| |
| StatusOr<vm::ref<iree_hal_buffer_t>> AllocatorAllocateConst( |
| vm::ref<iree_hal_allocator_t> allocator, |
| iree_hal_memory_type_t memory_types, iree_hal_buffer_usage_t buffer_usage, |
| absl::Span<const int32_t> shape, iree_hal_element_type_t element_type, |
| vm::ref<iree_vm_ro_byte_buffer_t> value) { |
| IREE_TRACE_SCOPE0("HALModuleState::AllocatorAllocateConst"); |
| |
| iree_device_size_t allocation_size = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_allocator_compute_size(allocator.get(), shape.data(), |
| shape.size(), element_type, |
| &allocation_size), |
| IREE_LOC)); |
| if (allocation_size < value->data.data_length) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Constant data is too larger for the minimum allocation size"; |
| } |
| |
| vm::ref<iree_hal_buffer_t> buffer; |
| RETURN_IF_ERROR(FromApiStatus(iree_hal_allocator_allocate_buffer( |
| allocator.get(), memory_types, |
| buffer_usage, allocation_size, &buffer), |
| IREE_LOC)) |
| << "Failed to allocate buffer"; |
| |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_write_data(buffer.get(), 0, value->data.data, |
| value->data.data_length), |
| IREE_LOC)) |
| << "Writing constant data"; |
| |
| return buffer; |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::Buffer |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_allocator_t>> BufferAllocator( |
| vm::ref<iree_hal_buffer_t> buffer) { |
| return vm::retain_ref(iree_hal_buffer_allocator(buffer.get())); |
| } |
| |
| StatusOr<vm::ref<iree_hal_buffer_t>> BufferSubspan( |
| vm::ref<iree_hal_buffer_t> source_buffer, int32_t source_offset, |
| int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferSubspan"); |
| return UnimplementedErrorBuilder(IREE_LOC) << "BufferSubspan"; |
| } |
| |
| Status BufferFill(vm::ref<iree_hal_buffer_t> target_buffer, |
| int32_t target_offset, int32_t length, int32_t pattern) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferFill"); |
| return UnimplementedErrorBuilder(IREE_LOC) << "BufferFill"; |
| } |
| |
| Status BufferReadData(vm::ref<iree_hal_buffer_t> source_buffer, |
| int32_t source_offset, |
| vm::ref<iree_vm_rw_byte_buffer_t> target_buffer, |
| int32_t target_offset, int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferReadData"); |
| return UnimplementedErrorBuilder(IREE_LOC) << "BufferReadData"; |
| } |
| |
| Status BufferWriteData(vm::ref<iree_hal_buffer_t> target_buffer, |
| int32_t target_offset, |
| vm::ref<iree_vm_ro_byte_buffer_t> source_buffer, |
| int32_t source_offset, int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferWriteData"); |
| return UnimplementedErrorBuilder(IREE_LOC) << "BufferWriteData"; |
| } |
| |
| Status BufferCopyData(vm::ref<iree_hal_buffer_t> source_buffer, |
| int32_t source_offset, |
| vm::ref<iree_hal_buffer_t> target_buffer, |
| int32_t target_offset, int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferCopyData"); |
| return UnimplementedErrorBuilder(IREE_LOC) << "BufferCopyData"; |
| } |
| |
| StatusOr<int32_t> BufferLoad(vm::ref<iree_hal_buffer_t> source_buffer, |
| int32_t source_offset, int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferLoad"); |
| |
| uint32_t target_buffer = 0; |
| if (length > sizeof(target_buffer)) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Length " << length << " exceeds max"; |
| } |
| |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_read_data(source_buffer.get(), source_offset, |
| &target_buffer, length), |
| IREE_LOC)) |
| << "Read failed"; |
| return target_buffer; |
| } |
| |
| Status BufferStore(int32_t value, vm::ref<iree_hal_buffer_t> target_buffer, |
| int32_t target_offset, int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::BufferStore"); |
| |
| if (target_offset + length > |
| iree_hal_buffer_byte_length(target_buffer.get())) { |
| return OutOfRangeErrorBuilder(IREE_LOC) << "Out of bounds store"; |
| } else if (length > sizeof(value)) { |
| return InvalidArgumentErrorBuilder(IREE_LOC) |
| << "Length " << length << " exceeds max"; |
| } |
| |
| RETURN_IF_ERROR( |
| FromApiStatus(iree_hal_buffer_write_data(target_buffer.get(), |
| target_offset, &value, length), |
| IREE_LOC)) |
| << "Write failed"; |
| return OkStatus(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::BufferView |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_buffer_view_t>> BufferViewCreate( |
| vm::ref<iree_hal_buffer_t> buffer, absl::Span<const int32_t> shape, |
| iree_hal_element_type_t element_type) { |
| vm::ref<iree_hal_buffer_view_t> buffer_view; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_view_create(buffer.get(), shape.data(), shape.size(), |
| element_type, allocator_, &buffer_view), |
| IREE_LOC)) |
| << "Failed to create buffer view"; |
| return std::move(buffer_view); |
| } |
| |
| StatusOr<vm::ref<iree_hal_buffer_view_t>> BufferViewSubview( |
| vm::ref<iree_hal_buffer_view_t> buffer_view, |
| absl::Span<const int32_t> indices, absl::Span<const int32_t> lengths) { |
| vm::ref<iree_hal_buffer_view_t> new_buffer_view; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_view_subview( |
| buffer_view.get(), indices.data(), indices.size(), lengths.data(), |
| lengths.size(), allocator_, &new_buffer_view), |
| IREE_LOC)) |
| << "Failed to create subview"; |
| return std::move(new_buffer_view); |
| } |
| |
| StatusOr<vm::ref<iree_hal_buffer_t>> BufferViewBuffer( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return vm::retain_ref(iree_hal_buffer_view_buffer(buffer_view.get())); |
| } |
| |
| StatusOr<int32_t> BufferViewByteLength( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return iree_hal_buffer_view_byte_length(buffer_view.get()); |
| } |
| |
| StatusOr<int32_t> BufferViewComputeOffset( |
| vm::ref<iree_hal_buffer_view_t> buffer_view, |
| absl::Span<const int32_t> indices) { |
| iree_device_size_t offset = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_view_compute_offset(buffer_view.get(), indices.data(), |
| indices.size(), &offset), |
| IREE_LOC)); |
| return offset; |
| } |
| |
| StatusOr<std::tuple<int32_t, int32_t>> BufferViewComputeRange( |
| vm::ref<iree_hal_buffer_view_t> buffer_view, |
| absl::Span<const int32_t> start_indices, |
| absl::Span<const int32_t> lengths) { |
| iree_device_size_t start_offset = 0; |
| iree_device_size_t subspan_length = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_view_compute_range( |
| buffer_view.get(), start_indices.data(), start_indices.size(), |
| lengths.data(), lengths.size(), &start_offset, &subspan_length), |
| IREE_LOC)); |
| return std::make_tuple<int32_t, int32_t>( |
| static_cast<int32_t>(start_offset), |
| static_cast<int32_t>(subspan_length)); |
| } |
| |
| StatusOr<int32_t> BufferViewRank( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return static_cast<int32_t>( |
| iree_hal_buffer_view_shape_rank(buffer_view.get())); |
| } |
| |
| StatusOr<int32_t> BufferViewDim(vm::ref<iree_hal_buffer_view_t> buffer_view, |
| int32_t index) { |
| return static_cast<int32_t>( |
| iree_hal_buffer_view_shape_dim(buffer_view.get(), index)); |
| } |
| |
| template <size_t N> |
| StatusOr<std::array<int32_t, N>> BufferViewDimsN( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| std::array<int32_t, N> value; |
| iree_host_size_t rank = 0; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_buffer_view_shape(buffer_view.get(), N, value.data(), &rank), |
| IREE_LOC)); |
| return value; |
| } |
| |
| StatusOr<std::array<int32_t, 1>> BufferViewDims1( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return BufferViewDimsN<1>(std::move(buffer_view)); |
| } |
| |
| StatusOr<std::array<int32_t, 2>> BufferViewDims2( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return BufferViewDimsN<2>(std::move(buffer_view)); |
| } |
| |
| StatusOr<std::array<int32_t, 3>> BufferViewDims3( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return BufferViewDimsN<3>(std::move(buffer_view)); |
| } |
| |
| StatusOr<std::array<int32_t, 4>> BufferViewDims4( |
| vm::ref<iree_hal_buffer_view_t> buffer_view) { |
| return BufferViewDimsN<4>(std::move(buffer_view)); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::CommandBuffer |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_command_buffer_t>> CommandBufferCreate( |
| vm::ref<iree_hal_device_t> device, iree_hal_command_buffer_mode_t modes, |
| iree_hal_command_category_t command_categories) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferCreate"); |
| |
| vm::ref<iree_hal_command_buffer_t> command_buffer; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_command_buffer_create(device.get(), modes, command_categories, |
| IREE_ALLOCATOR_SYSTEM, &command_buffer), |
| IREE_LOC)) |
| << "Failed to create command buffer"; |
| return command_buffer; |
| } |
| |
| Status CommandBufferBegin(vm::ref<iree_hal_command_buffer_t> command_buffer) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferBegin"); |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_command_buffer_begin(command_buffer.get()), IREE_LOC)) |
| << "Failed to begin command buffer recording"; |
| return OkStatus(); |
| } |
| |
| Status CommandBufferEnd(vm::ref<iree_hal_command_buffer_t> command_buffer) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferEnd"); |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_command_buffer_end(command_buffer.get()), IREE_LOC)) |
| << "Failed to end command buffer recording"; |
| return OkStatus(); |
| } |
| |
| Status CommandBufferExecutionBarrier( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| iree_hal_execution_stage_t source_stage_mask, |
| iree_hal_execution_stage_t target_stage_mask, |
| absl::Span<const int32_t> memory_barriers, |
| absl::Span<const int32_t> buffer_barriers) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferExecutionBarrier"); |
| |
| // TODO(benvanik): decode barriers. |
| iree_hal_memory_barrier_t global_barrier; |
| global_barrier.source_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_WRITE; |
| global_barrier.target_scope = IREE_HAL_ACCESS_SCOPE_DISPATCH_READ; |
| RETURN_IF_ERROR( |
| FromApiStatus(iree_hal_command_buffer_execution_barrier( |
| command_buffer.get(), source_stage_mask, |
| target_stage_mask, 1, &global_barrier, 0, nullptr), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferFillBuffer( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_buffer_t> target_buffer, int32_t target_offset, |
| int32_t length, uint32_t pattern) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferFillBuffer"); |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_command_buffer_fill_buffer(command_buffer.get(), |
| target_buffer.get(), target_offset, |
| length, &pattern, sizeof(pattern)), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferCopyBuffer( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_buffer_t> source_buffer, int32_t source_offset, |
| vm::ref<iree_hal_buffer_t> target_buffer, int32_t target_offset, |
| int32_t length) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferCopyBuffer"); |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_command_buffer_copy_buffer( |
| command_buffer.get(), source_buffer.get(), source_offset, |
| target_buffer.get(), target_offset, length), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferPushConstants( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_executable_layout_t> executable_layout, uint32_t offset, |
| absl::Span<const uint32_t> values) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferPushConstants"); |
| RETURN_IF_ERROR( |
| FromApiStatus(iree_hal_command_buffer_push_constants( |
| command_buffer.get(), executable_layout.get(), offset, |
| values.data(), values.size() * sizeof(uint32_t)), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferPushDescriptorSet( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_executable_layout_t> executable_layout, int32_t set, |
| absl::Span<const int32_t> binding_ordinals, |
| absl::Span<const vm::ref<iree_hal_buffer_t>> binding_buffers, |
| absl::Span<const int32_t> binding_offsets, |
| absl::Span<const int32_t> binding_lengths) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferPushDescriptorSet"); |
| absl::InlinedVector<iree_hal_descriptor_set_binding_t, 4> binding_structs( |
| binding_ordinals.size()); |
| for (int i = 0; i < binding_ordinals.size(); ++i) { |
| binding_structs[i] = { |
| binding_ordinals[i], binding_buffers[i].get(), |
| static_cast<iree_device_size_t>(binding_offsets[i]), |
| static_cast<iree_device_size_t>(binding_lengths[i])}; |
| deferred_releases_.push_back( |
| iree_hal_buffer_retain_ref(binding_buffers[i].get())); |
| } |
| RETURN_IF_ERROR( |
| FromApiStatus(iree_hal_command_buffer_push_descriptor_set( |
| command_buffer.get(), executable_layout.get(), set, |
| binding_structs.size(), binding_structs.data()), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferBindDescriptorSet( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_executable_layout_t> executable_layout, int32_t set, |
| vm::ref<iree_hal_descriptor_set_t> descriptor_set, |
| absl::Span<const int32_t> dynamic_offsets) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferBindDescriptorSet"); |
| absl::InlinedVector<iree_device_size_t, 4> dynamic_offset_values( |
| dynamic_offsets.size()); |
| for (int i = 0; i < dynamic_offsets.size(); ++i) { |
| dynamic_offset_values[i] = |
| static_cast<iree_device_size_t>(dynamic_offsets[i]); |
| } |
| RETURN_IF_ERROR( |
| FromApiStatus(iree_hal_command_buffer_bind_descriptor_set( |
| command_buffer.get(), executable_layout.get(), set, |
| descriptor_set.get(), dynamic_offset_values.size(), |
| dynamic_offset_values.data()), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferDispatch( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_executable_t> executable, int32_t entry_point, |
| uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferDispatch"); |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_command_buffer_dispatch(command_buffer.get(), executable.get(), |
| entry_point, workgroup_x, workgroup_y, |
| workgroup_z), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| Status CommandBufferDispatchIndirect( |
| vm::ref<iree_hal_command_buffer_t> command_buffer, |
| vm::ref<iree_hal_executable_t> executable, int32_t entry_point, |
| vm::ref<iree_hal_buffer_t> workgroups_buffer, int32_t workgroups_offset) { |
| IREE_TRACE_SCOPE0("HALModuleState::CommandBufferDispatchIndirect"); |
| RETURN_IF_ERROR( |
| FromApiStatus(iree_hal_command_buffer_dispatch_indirect( |
| command_buffer.get(), executable.get(), entry_point, |
| workgroups_buffer.get(), workgroups_offset), |
| IREE_LOC)); |
| return OkStatus(); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::DescriptorSet |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_descriptor_set_t>> DescriptorSetCreate( |
| vm::ref<iree_hal_device_t> device, |
| vm::ref<iree_hal_descriptor_set_layout_t> set_layout, |
| absl::Span<const int32_t> binding_ordinals, |
| absl::Span<const vm::ref<iree_hal_buffer_t>> binding_buffers, |
| absl::Span<const int32_t> binding_offsets, |
| absl::Span<const int32_t> binding_lengths) { |
| IREE_TRACE_SCOPE0("HALModuleState::DescriptorSetCreate"); |
| absl::InlinedVector<iree_hal_descriptor_set_binding_t, 4> binding_structs( |
| binding_ordinals.size()); |
| for (int i = 0; i < binding_ordinals.size(); ++i) { |
| binding_structs[i] = { |
| binding_ordinals[i], // binding |
| binding_buffers[i].get(), // buffer |
| static_cast<iree_device_size_t>(binding_offsets[i]), // offset |
| static_cast<iree_device_size_t>(binding_lengths[i])}; // length |
| } |
| vm::ref<iree_hal_descriptor_set_t> descriptor_set; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_descriptor_set_create( |
| device.get(), set_layout.get(), binding_structs.size(), |
| binding_structs.data(), allocator_, &descriptor_set), |
| IREE_LOC)); |
| return std::move(descriptor_set); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::DescriptorSetLayout |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_descriptor_set_layout_t>> DescriptorSetLayoutCreate( |
| vm::ref<iree_hal_device_t> device, |
| iree_hal_descriptor_set_layout_usage_type_t usage_type, |
| absl::Span<const std::tuple<int32_t, iree_hal_descriptor_type_t, |
| iree_hal_memory_access_t>> |
| bindings) { |
| IREE_TRACE_SCOPE0("HALModuleState::DescriptorSetLayoutCreate"); |
| // TODO(benvanik): custom marshaling for the structs. |
| absl::InlinedVector<iree_hal_descriptor_set_layout_binding_t, 4> |
| binding_structs(bindings.size()); |
| for (int i = 0; i < bindings.size(); ++i) { |
| binding_structs[i] = {std::get<0>(bindings[i]), std::get<1>(bindings[i]), |
| std::get<2>(bindings[i])}; |
| } |
| vm::ref<iree_hal_descriptor_set_layout_t> descriptor_set_layout; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_descriptor_set_layout_create( |
| device.get(), usage_type, binding_structs.size(), |
| binding_structs.data(), allocator_, &descriptor_set_layout), |
| IREE_LOC)); |
| return std::move(descriptor_set_layout); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::Device |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_allocator_t>> DeviceAllocator( |
| vm::ref<iree_hal_device_t> device) { |
| return vm::retain_ref(iree_hal_device_allocator(device.get())); |
| } |
| |
| StatusOr<int32_t> DeviceMatchID(vm::ref<iree_hal_device_t> device, |
| absl::string_view pattern) { |
| iree_string_view_t device_id = iree_hal_device_id(device.get()); |
| return iree_string_view_match_pattern( |
| device_id, iree_string_view_t{pattern.data(), pattern.size()}) |
| ? 1 |
| : 0; |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::ExecutableCache |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_executable_cache_t>> ExecutableCacheCreate( |
| vm::ref<iree_hal_device_t> device, absl::string_view identifier) { |
| IREE_TRACE_SCOPE0("HALModuleState::ExecutableCacheCreate"); |
| vm::ref<iree_hal_executable_cache_t> executable_cache; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_executable_cache_create( |
| device.get(), |
| iree_string_view_t{identifier.data(), identifier.size()}, |
| allocator_, &executable_cache), |
| IREE_LOC)); |
| return std::move(executable_cache); |
| } |
| |
| StatusOr<int32_t> ExecutableCacheSelectFormat( |
| vm::ref<iree_hal_executable_cache_t> executable_cache, |
| absl::Span<const iree_hal_executable_format_t> available_formats) { |
| IREE_TRACE_SCOPE0("HALModuleState::ExecutableCacheSelectFormat"); |
| for (int i = 0; i < available_formats.size(); ++i) { |
| if (iree_hal_executable_cache_can_prepare_format(executable_cache.get(), |
| available_formats[i])) { |
| return i; |
| } |
| } |
| return -1; |
| } |
| |
| StatusOr<vm::ref<iree_hal_executable_t>> ExecutableCachePrepare( |
| vm::ref<iree_hal_executable_cache_t> executable_cache, |
| vm::ref<iree_hal_executable_layout_t> executable_layout, |
| iree_hal_executable_caching_mode_t caching_mode, |
| vm::ref<iree_vm_ro_byte_buffer_t> executable_data) { |
| IREE_TRACE_SCOPE0("HALModuleState::ExecutableCachePrepare"); |
| vm::ref<iree_hal_executable_t> executable; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_executable_cache_prepare_executable( |
| executable_cache.get(), executable_layout.get(), caching_mode, |
| executable_data->data, allocator_, &executable), |
| IREE_LOC)); |
| return std::move(executable); |
| } |
| |
| //===--------------------------------------------------------------------===// |
| // iree::hal::ExecutableLayout |
| //===--------------------------------------------------------------------===// |
| |
| StatusOr<vm::ref<iree_hal_executable_layout_t>> ExecutableLayoutCreate( |
| vm::ref<iree_hal_device_t> device, |
| absl::Span<const vm::ref<iree_hal_descriptor_set_layout_t>> set_layouts, |
| int32_t push_constants) { |
| IREE_TRACE_SCOPE0("HALModuleState::ExecutableLayoutCreate"); |
| vm::ref<iree_hal_executable_layout_t> executable_layout; |
| RETURN_IF_ERROR(FromApiStatus( |
| iree_hal_executable_layout_create( |
| device.get(), set_layouts.size(), |
| reinterpret_cast<iree_hal_descriptor_set_layout_t**>( |
| const_cast<vm::ref<iree_hal_descriptor_set_layout_t>*>( |
| set_layouts.data())), |
| push_constants, allocator_, &executable_layout), |
| IREE_LOC)); |
| return std::move(executable_layout); |
| } |
| |
| private: |
| iree_allocator_t allocator_; |
| ref_ptr<Device> shared_device_; |
| |
| std::vector<iree_vm_ref_t> deferred_releases_; |
| }; |
| |
| //===----------------------------------------------------------------------===// |
| // VM module interface implementation |
| //===----------------------------------------------------------------------===// |
| |
| static const vm::NativeFunction<HALModuleState> kHALModuleFunctions[] = { |
| vm::MakeNativeFunction("ex.shared_device", &HALModuleState::ExSharedDevice), |
| vm::MakeNativeFunction("ex.defer_release", &HALModuleState::ExDeferRelease), |
| vm::MakeNativeFunction("ex.submit_and_wait", |
| &HALModuleState::ExSubmitAndWait), |
| |
| vm::MakeNativeFunction("allocator.compute_size", |
| &HALModuleState::AllocatorComputeSize), |
| vm::MakeNativeFunction("allocator.compute_offset", |
| &HALModuleState::AllocatorComputeOffset), |
| vm::MakeNativeFunction("allocator.compute_range", |
| &HALModuleState::AllocatorComputeRange), |
| vm::MakeNativeFunction("allocator.allocate", |
| &HALModuleState::AllocatorAllocate), |
| vm::MakeNativeFunction("allocator.allocate.const", |
| &HALModuleState::AllocatorAllocateConst), |
| |
| vm::MakeNativeFunction("buffer.allocator", |
| &HALModuleState::BufferAllocator), |
| vm::MakeNativeFunction("buffer.subspan", &HALModuleState::BufferSubspan), |
| vm::MakeNativeFunction("buffer.fill", &HALModuleState::BufferFill), |
| vm::MakeNativeFunction("buffer.read_data", &HALModuleState::BufferReadData), |
| vm::MakeNativeFunction("buffer.write_data", |
| &HALModuleState::BufferWriteData), |
| vm::MakeNativeFunction("buffer.copy_data", &HALModuleState::BufferCopyData), |
| vm::MakeNativeFunction("buffer.load", &HALModuleState::BufferLoad), |
| vm::MakeNativeFunction("buffer.store", &HALModuleState::BufferStore), |
| |
| vm::MakeNativeFunction("buffer_view.create", |
| &HALModuleState::BufferViewCreate), |
| vm::MakeNativeFunction("buffer_view.subview", |
| &HALModuleState::BufferViewSubview), |
| vm::MakeNativeFunction("buffer_view.buffer", |
| &HALModuleState::BufferViewBuffer), |
| vm::MakeNativeFunction("buffer_view.byte_length", |
| &HALModuleState::BufferViewByteLength), |
| vm::MakeNativeFunction("buffer_view.compute_offset", |
| &HALModuleState::BufferViewComputeOffset), |
| vm::MakeNativeFunction("buffer_view.compute_range", |
| &HALModuleState::BufferViewComputeRange), |
| vm::MakeNativeFunction("buffer_view.rank", &HALModuleState::BufferViewRank), |
| vm::MakeNativeFunction("buffer_view.dim", &HALModuleState::BufferViewDim), |
| vm::MakeNativeFunction("buffer_view.dims.1", |
| &HALModuleState::BufferViewDims1), |
| vm::MakeNativeFunction("buffer_view.dims.2", |
| &HALModuleState::BufferViewDims2), |
| vm::MakeNativeFunction("buffer_view.dims.3", |
| &HALModuleState::BufferViewDims3), |
| vm::MakeNativeFunction("buffer_view.dims.4", |
| &HALModuleState::BufferViewDims4), |
| |
| vm::MakeNativeFunction("command_buffer.create", |
| &HALModuleState::CommandBufferCreate), |
| vm::MakeNativeFunction("command_buffer.begin", |
| &HALModuleState::CommandBufferBegin), |
| vm::MakeNativeFunction("command_buffer.end", |
| &HALModuleState::CommandBufferEnd), |
| vm::MakeNativeFunction("command_buffer.execution_barrier", |
| &HALModuleState::CommandBufferExecutionBarrier), |
| vm::MakeNativeFunction("command_buffer.fill_buffer", |
| &HALModuleState::CommandBufferFillBuffer), |
| vm::MakeNativeFunction("command_buffer.copy_buffer", |
| &HALModuleState::CommandBufferCopyBuffer), |
| vm::MakeNativeFunction("command_buffer.push_constants", |
| &HALModuleState::CommandBufferPushConstants), |
| vm::MakeNativeFunction("command_buffer.push_descriptor_set", |
| &HALModuleState::CommandBufferPushDescriptorSet), |
| vm::MakeNativeFunction("command_buffer.bind_descriptor_set", |
| &HALModuleState::CommandBufferBindDescriptorSet), |
| vm::MakeNativeFunction("command_buffer.dispatch", |
| &HALModuleState::CommandBufferDispatch), |
| vm::MakeNativeFunction("command_buffer.dispatch.indirect", |
| &HALModuleState::CommandBufferDispatchIndirect), |
| |
| vm::MakeNativeFunction("descriptor_set.create", |
| &HALModuleState::DescriptorSetCreate), |
| vm::MakeNativeFunction("descriptor_set_layout.create", |
| &HALModuleState::DescriptorSetLayoutCreate), |
| |
| vm::MakeNativeFunction("device.allocator", |
| &HALModuleState::DeviceAllocator), |
| vm::MakeNativeFunction("device.match.id", &HALModuleState::DeviceMatchID), |
| |
| vm::MakeNativeFunction("executable_cache.create", |
| &HALModuleState::ExecutableCacheCreate), |
| vm::MakeNativeFunction("executable_cache.select_format", |
| &HALModuleState::ExecutableCacheSelectFormat), |
| vm::MakeNativeFunction("executable_cache.prepare", |
| &HALModuleState::ExecutableCachePrepare), |
| |
| vm::MakeNativeFunction("executable_layout.create", |
| &HALModuleState::ExecutableLayoutCreate), |
| }; |
| |
| class HALModule final : public vm::NativeModule<HALModuleState> { |
| public: |
| HALModule(iree_allocator_t allocator, ref_ptr<Device> shared_device) |
| : vm::NativeModule<HALModuleState>( |
| "hal", allocator, absl::MakeConstSpan(kHALModuleFunctions)), |
| shared_device_(std::move(shared_device)) {} |
| ~HALModule() = default; |
| |
| Status Initialize() { |
| IREE_TRACE_SCOPE0("HALModule::Initialize"); |
| |
| executable_cache_ = shared_device_->CreateExecutableCache(); |
| |
| return OkStatus(); |
| } |
| |
| StatusOr<std::unique_ptr<HALModuleState>> CreateState( |
| iree_allocator_t allocator) override { |
| IREE_TRACE_SCOPE0("HALModule::CreateState"); |
| auto state = std::make_unique<HALModuleState>( |
| allocator, add_ref(shared_device_), add_ref(executable_cache_)); |
| // TODO(benvanik): allocate context-specific variables (allocator pool, |
| // etc). |
| return state; |
| } |
| |
| private: |
| ref_ptr<Device> shared_device_; |
| ref_ptr<ExecutableCache> executable_cache_; |
| }; |
| |
| IREE_API_EXPORT iree_status_t IREE_API_CALL |
| iree_hal_module_create(iree_hal_device_t* device, iree_allocator_t allocator, |
| iree_vm_module_t** out_module) { |
| if (!out_module) return IREE_STATUS_INVALID_ARGUMENT; |
| *out_module = nullptr; |
| auto module = std::make_unique<HALModule>( |
| allocator, add_ref(reinterpret_cast<Device*>(device))); |
| IREE_API_RETURN_IF_ERROR(module->Initialize()); |
| *out_module = module.release()->interface(); |
| return IREE_STATUS_OK; |
| } |
| |
| } // namespace |
| } // namespace hal |
| } // namespace iree |