blob: 08ca389907cd9fd50af8fd85ee0b76a60811cea7 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/mlir_edge/iree/hal/vulkan/direct_command_queue.h"
#include <cstdint>
#include "third_party/absl/time/clock.h"
#include "third_party/absl/time/time.h"
#include "third_party/absl/types/source_location.h"
#include "third_party/mlir_edge/iree/base/memory.h"
#include "third_party/mlir_edge/iree/base/status.h"
#include "third_party/mlir_edge/iree/base/tracing.h"
#include "third_party/mlir_edge/iree/hal/vulkan/direct_command_buffer.h"
#include "third_party/mlir_edge/iree/hal/vulkan/legacy_fence.h"
#include "third_party/mlir_edge/iree/hal/vulkan/native_binary_semaphore.h"
#include "third_party/mlir_edge/iree/hal/vulkan/status_util.h"
namespace iree {
namespace hal {
namespace vulkan {
// Constructs a queue wrapper around an already-obtained VkQueue.
//
// |name| and |supported_categories| are handed to the CommandQueue base class.
// |logical_device| is stored through add_ref() and |queue| is stored as-is;
// no Vulkan calls are made here.
DirectCommandQueue::DirectCommandQueue(
    std::string name, CommandCategoryBitfield supported_categories,
    const ref_ptr<VkDeviceHandle>& logical_device, VkQueue queue)
    : CommandQueue(std::move(name), supported_categories),
      logical_device_(add_ref(logical_device)),
      queue_(queue) {
}
// Blocks on vkQueueWaitIdle for the wrapped queue before the object goes away.
DirectCommandQueue::~DirectCommandQueue() {
  IREE_TRACE_SCOPE0("DirectCommandQueue::dtor");

  // Queue access is serialized through |queue_mutex_|, same as Submit/WaitIdle.
  absl::MutexLock lock(&queue_mutex_);

  // The VkResult is intentionally dropped: a destructor has no way to report
  // it to the caller.
  syms()->vkQueueWaitIdle(queue_);
}
// Fills |submit_info| with the Vulkan handles referenced by |batch|.
//
// The wait/signal semaphore arrays, the wait stage mask array and the command
// buffer array that |submit_info| ends up pointing at are all allocated from
// |arena|, so the VkSubmitInfo is only usable while that arena storage is
// (presumably: until the arena is destroyed or reset).
//
// Returns an Unimplemented error if any wait or signal semaphore in |batch|
// holds something other than the first variant alternative (index 0, which is
// treated as a NativeBinarySemaphore). |submit_info| is only written after all
// handles were gathered, so it is left untouched on failure.
Status DirectCommandQueue::TranslateBatchInfo(const SubmissionBatch& batch,
                                              VkSubmitInfo* submit_info,
                                              Arena* arena) {
  // TODO(benvanik): see if we can go to finer-grained stages.
  // For example, if this was just queue ownership transfers then we can use
  // the pseudo-stage of VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT.
  const VkPipelineStageFlags dst_stage_mask =
      VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT;

  // The counts are converted once, explicitly, to the uint32_t width used by
  // the VkSubmitInfo count fields; the same unsigned value then drives the
  // loops so no signed/unsigned comparison takes place.
  const auto wait_count = static_cast<uint32_t>(batch.wait_semaphores.size());
  const auto signal_count =
      static_cast<uint32_t>(batch.signal_semaphores.size());
  const auto command_buffer_count =
      static_cast<uint32_t>(batch.command_buffers.size());

  // Wait semaphores: one handle plus one destination stage mask per entry.
  auto wait_semaphore_handles = arena->AllocateSpan<VkSemaphore>(wait_count);
  auto wait_dst_stage_masks =
      arena->AllocateSpan<VkPipelineStageFlags>(wait_count);
  for (uint32_t i = 0; i < wait_count; ++i) {
    const auto& semaphore_value = batch.wait_semaphores[i];
    if (semaphore_value.index() != 0) {
      // TODO(b/140141417): implement timeline semaphores.
      return UnimplementedErrorBuilder(ABSL_LOC)
             << "Timeline semaphores not yet implemented";
    }
    auto* binary_semaphore =
        static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value));
    wait_semaphore_handles[i] = binary_semaphore->handle();
    wait_dst_stage_masks[i] = dst_stage_mask;
  }

  // Signal semaphores: handles only.
  auto signal_semaphore_handles =
      arena->AllocateSpan<VkSemaphore>(signal_count);
  for (uint32_t i = 0; i < signal_count; ++i) {
    const auto& semaphore_value = batch.signal_semaphores[i];
    if (semaphore_value.index() != 0) {
      // TODO(b/140141417): implement timeline semaphores.
      return UnimplementedErrorBuilder(ABSL_LOC)
             << "Timeline semaphores not yet implemented";
    }
    auto* binary_semaphore =
        static_cast<NativeBinarySemaphore*>(absl::get<0>(semaphore_value));
    signal_semaphore_handles[i] = binary_semaphore->handle();
  }

  // Command buffers: unwrap each HAL command buffer to its VkCommandBuffer.
  auto command_buffer_handles =
      arena->AllocateSpan<VkCommandBuffer>(command_buffer_count);
  for (uint32_t i = 0; i < command_buffer_count; ++i) {
    auto* direct_command_buffer =
        static_cast<DirectCommandBuffer*>(batch.command_buffers[i]->impl());
    command_buffer_handles[i] = direct_command_buffer->handle();
  }

  submit_info->sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submit_info->pNext = nullptr;
  submit_info->waitSemaphoreCount = wait_count;
  submit_info->pWaitSemaphores = wait_semaphore_handles.data();
  submit_info->pWaitDstStageMask = wait_dst_stage_masks.data();
  submit_info->commandBufferCount = command_buffer_count;
  submit_info->pCommandBuffers = command_buffer_handles.data();
  submit_info->signalSemaphoreCount = signal_count;
  submit_info->pSignalSemaphores = signal_semaphore_handles.data();
  return OkStatus();
}
// Submits |batches| to the wrapped VkQueue in a single vkQueueSubmit call.
//
// Each batch is translated to a VkSubmitInfo via TranslateBatchInfo; the
// first translation failure is returned and nothing is submitted. The VkFence
// obtained from |fence| (fence.first->AcquireSignalFence(fence.second)) is
// passed to vkQueueSubmit as the fence for the submission.
//
// Returns the failing status from translation, fence acquisition or
// vkQueueSubmit, otherwise OkStatus().
Status DirectCommandQueue::Submit(absl::Span<const SubmissionBatch> batches,
                                  FenceValue fence) {
  IREE_TRACE_SCOPE0("DirectCommandQueue::Submit");

  // Map the submission batches to VkSubmitInfos.
  // Note that we must keep all arrays referenced alive until submission
  // completes and since there are a bunch of them we use an arena.
  Arena arena(4 * 1024);

  // vkQueueSubmit takes a uint32_t count; convert once, explicitly, and use
  // the same unsigned value as the loop bound.
  const auto batch_count = static_cast<uint32_t>(batches.size());
  auto submit_infos = arena.AllocateSpan<VkSubmitInfo>(batch_count);
  for (uint32_t i = 0; i < batch_count; ++i) {
    RETURN_IF_ERROR(TranslateBatchInfo(batches[i], &submit_infos[i], &arena));
  }

  // TODO(b/140141417): implement timeline semaphore fences and switch here.
  // NOTE(review): assumes every fence handed to this queue is a LegacyFence;
  // confirm against the code that creates fences for this device.
  auto* legacy_fence = static_cast<LegacyFence*>(fence.first);
  ASSIGN_OR_RETURN(VkFence fence_handle,
                   legacy_fence->AcquireSignalFence(fence.second));

  {
    // Queue access is serialized through |queue_mutex_|.
    absl::MutexLock lock(&queue_mutex_);
    VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(
        queue_, batch_count, submit_infos.data(), fence_handle));
  }

  return OkStatus();
}
// No-op flush: always returns OkStatus().
Status DirectCommandQueue::Flush() {
  IREE_TRACE_SCOPE0("DirectCommandQueue::Flush");

  // Submit() hands work to vkQueueSubmit before it returns, so there is never
  // anything buffered on our side to push out.
  return OkStatus();
}
// Waits until the queue is idle or |deadline| is reached.
//
// - absl::InfiniteFuture(): blocks in vkQueueWaitIdle.
// - absl::InfinitePast(): polls once (zero timeout) using a temporary fence.
// - anything else: waits on a temporary fence for at most (deadline - now).
//
// Returns OkStatus() when the queue went idle, DeadlineExceeded when
// |deadline| is already in the past or the fence wait timed out, and the
// translated VkResult for any other Vulkan failure.
Status DirectCommandQueue::WaitIdle(absl::Time deadline) {
  if (deadline == absl::InfiniteFuture()) {
    // Fast path for using vkQueueWaitIdle, which is usually cheaper (as it
    // requires fewer calls into the driver).
    IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#vkQueueWaitIdle");
    absl::MutexLock lock(&queue_mutex_);
    VK_RETURN_IF_ERROR(syms()->vkQueueWaitIdle(queue_));
    return OkStatus();
  }

  IREE_TRACE_SCOPE0("DirectCommandQueue::WaitIdle#Fence");

  // Resolve the relative timeout before touching Vulkan so that an already
  // expired deadline does not cost a fence create/destroy round trip. The
  // infinite-future case was fully handled above and cannot reach this point.
  uint64_t timeout = 0;  // InfinitePast: do not wait, just poll.
  if (deadline != absl::InfinitePast()) {
    // Convert to relative time in nanoseconds.
    // The implementation may not wait with this granularity (like, by 10000x).
    const absl::Time now = absl::Now();
    if (deadline < now) {
      return DeadlineExceededErrorBuilder(ABSL_LOC) << "Deadline in the past";
    }
    timeout = static_cast<uint64_t>(absl::ToInt64Nanoseconds(deadline - now));
  }

  // Create a new fence just for this wait. This keeps us thread-safe as the
  // behavior of wait+reset is racey.
  VkFenceCreateInfo create_info;
  create_info.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
  create_info.pNext = nullptr;
  create_info.flags = 0;
  VkFence fence = VK_NULL_HANDLE;
  VK_RETURN_IF_ERROR(syms()->vkCreateFence(
      *logical_device_, &create_info, logical_device_->allocator(), &fence));
  // NOTE(review): on the VK_TIMEOUT path below this destroys a fence whose
  // empty submission may still be pending on the queue; check this against
  // the vkDestroyFence valid-usage rules.
  auto fence_cleanup = MakeCleanup([this, fence]() {
    syms()->vkDestroyFence(*logical_device_, fence,
                           logical_device_->allocator());
  });

  {
    // An empty submission that only carries the fence: the fence is passed to
    // vkQueueSubmit with zero VkSubmitInfos.
    absl::MutexLock lock(&queue_mutex_);
    VK_RETURN_IF_ERROR(syms()->vkQueueSubmit(queue_, 0, nullptr, fence));
  }

  const VkResult result =
      syms()->vkWaitForFences(*logical_device_, 1, &fence, VK_TRUE, timeout);
  switch (result) {
    case VK_SUCCESS:
      return OkStatus();
    case VK_TIMEOUT:
      return DeadlineExceededErrorBuilder(ABSL_LOC)
             << "Deadline exceeded waiting for idle";
    default:
      return VkResultToStatus(result);
  }
}
} // namespace vulkan
} // namespace hal
} // namespace iree