// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/mlir_edge/iree/hal/vulkan/direct_command_buffer.h"
#include "third_party/absl/base/attributes.h"
#include "third_party/absl/container/inlined_vector.h"
#include "third_party/absl/synchronization/mutex.h"
#include "third_party/absl/types/source_location.h"
#include "third_party/mlir_edge/iree/base/status.h"
#include "third_party/mlir_edge/iree/base/tracing.h"
#include "third_party/mlir_edge/iree/hal/vulkan/status_util.h"
namespace iree {
namespace hal {
namespace vulkan {
namespace {
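// Converts a HAL ExecutionStage bitfield to the corresponding
// VkPipelineStageFlags bits used for pipeline barriers and events.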
VkPipelineStageFlags ConvertPipelineStageFlags(
ExecutionStageBitfield stage_mask) {
VkPipelineStageFlags flags = 0;
flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandIssue)
? VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT
: 0;
flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandProcess)
? VK_PIPELINE_STAGE_DRAW_INDIRECT_BIT
: 0;
flags |= AnyBitSet(stage_mask & ExecutionStage::kDispatch)
? VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT
: 0;
flags |= AnyBitSet(stage_mask & ExecutionStage::kTransfer)
? VK_PIPELINE_STAGE_TRANSFER_BIT
: 0;
flags |= AnyBitSet(stage_mask & ExecutionStage::kCommandRetire)
? VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT
: 0;
flags |= AnyBitSet(stage_mask & ExecutionStage::kHost)
? VK_PIPELINE_STAGE_HOST_BIT
: 0;
return flags;
}
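// Converts a HAL AccessScope bitfield to the corresponding VkAccessFlags bits
// used in memory and buffer barriers.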
VkAccessFlags ConvertAccessMask(AccessScopeBitfield access_mask) {
VkAccessFlags flags = 0;
flags |= AnyBitSet(access_mask & AccessScope::kIndirectCommandRead)
? VK_ACCESS_INDIRECT_COMMAND_READ_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kConstantRead)
? VK_ACCESS_UNIFORM_READ_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kDispatchRead)
? VK_ACCESS_SHADER_READ_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kDispatchWrite)
? VK_ACCESS_SHADER_WRITE_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kTransferRead)
? VK_ACCESS_TRANSFER_READ_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kTransferWrite)
? VK_ACCESS_TRANSFER_WRITE_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kHostRead)
? VK_ACCESS_HOST_READ_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kHostWrite)
? VK_ACCESS_HOST_WRITE_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kMemoryRead)
? VK_ACCESS_MEMORY_READ_BIT
: 0;
flags |= AnyBitSet(access_mask & AccessScope::kMemoryWrite)
? VK_ACCESS_MEMORY_WRITE_BIT
: 0;
return flags;
}
// Splats a pattern value of 1, 2, or 4 bytes out to a 4-byte value.
uint32_t SplatPattern(const void* pattern, size_t pattern_length) {
switch (pattern_length) {
case 1: {
uint32_t pattern_value = *static_cast<const uint8_t*>(pattern);
return (pattern_value << 24) | (pattern_value << 16) |
(pattern_value << 8) | pattern_value;
}
case 2: {
uint32_t pattern_value = *static_cast<const uint16_t*>(pattern);
return (pattern_value << 16) | pattern_value;
}
case 4: {
uint32_t pattern_value = *static_cast<const uint32_t*>(pattern);
return pattern_value;
}
default:
return 0; // Already verified that this should not be possible.
}
}
} // namespace
DirectCommandBuffer::DirectCommandBuffer(
Allocator* allocator, CommandBufferModeBitfield mode,
CommandCategoryBitfield command_categories,
const ref_ptr<VkCommandPoolHandle>& command_pool,
VkCommandBuffer command_buffer)
: CommandBuffer(allocator, mode, command_categories),
command_pool_(add_ref(command_pool)),
command_buffer_(command_buffer) {}
DirectCommandBuffer::~DirectCommandBuffer() {
IREE_TRACE_SCOPE0("DirectCommandBuffer::dtor");
absl::MutexLock lock(command_pool_->mutex());
syms()->vkFreeCommandBuffers(*command_pool_->logical_device(), *command_pool_,
1, &command_buffer_);
}
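// Downcasts a HAL Event to the Vulkan NativeEvent implementation backing it.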
StatusOr<NativeEvent*> DirectCommandBuffer::CastEvent(Event* event) const {
// TODO(benvanik): assert the event is valid.
return static_cast<NativeEvent*>(event);
}
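// Resolves a HAL Buffer to the VmaBuffer backing its allocation. Note that the
// result refers to the full allocation, so callers must still add the buffer's
// byte_offset() when computing device offsets.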
StatusOr<VmaBuffer*> DirectCommandBuffer::CastBuffer(Buffer* buffer) const {
// TODO(benvanik): assert that the buffer is from the right allocator and
// that it is compatible with our target queue family.
return static_cast<VmaBuffer*>(buffer->allocated_buffer());
}
Status DirectCommandBuffer::Begin() {
IREE_TRACE_SCOPE0("DirectCommandBuffer::Begin");
is_recording_ = true;
VkCommandBufferBeginInfo begin_info;
begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
begin_info.pNext = nullptr;
begin_info.flags = AllBitsSet(mode(), CommandBufferMode::kOneShot)
? VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT
: 0;
begin_info.pInheritanceInfo = nullptr;
VK_RETURN_IF_ERROR(
syms()->vkBeginCommandBuffer(command_buffer_, &begin_info));
return OkStatus();
}
Status DirectCommandBuffer::End() {
IREE_TRACE_SCOPE0("DirectCommandBuffer::End");
VK_RETURN_IF_ERROR(syms()->vkEndCommandBuffer(command_buffer_));
is_recording_ = false;
return OkStatus();
}
Status DirectCommandBuffer::ExecutionBarrier(
ExecutionStageBitfield source_stage_mask,
ExecutionStageBitfield target_stage_mask,
absl::Span<const MemoryBarrier> memory_barriers,
absl::Span<const BufferBarrier> buffer_barriers) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::ExecutionBarrier");
absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos(
memory_barriers.size());
for (int i = 0; i < memory_barriers.size(); ++i) {
const auto& memory_barrier = memory_barriers[i];
auto& info = memory_barrier_infos[i];
info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
info.pNext = nullptr;
info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope);
info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope);
}
absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos(
buffer_barriers.size());
for (int i = 0; i < buffer_barriers.size(); ++i) {
const auto& buffer_barrier = buffer_barriers[i];
auto& info = buffer_barrier_infos[i];
info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
info.pNext = nullptr;
info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope);
info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope);
info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer));
info.buffer = device_buffer->handle();
info.offset = buffer_barrier.offset;
info.size = buffer_barrier.length;
}
syms()->vkCmdPipelineBarrier(
command_buffer_, ConvertPipelineStageFlags(source_stage_mask),
ConvertPipelineStageFlags(target_stage_mask), /*dependencyFlags=*/0,
memory_barrier_infos.size(), memory_barrier_infos.data(),
buffer_barrier_infos.size(), buffer_barrier_infos.data(), 0, nullptr);
return OkStatus();
}
Status DirectCommandBuffer::SignalEvent(
Event* event, ExecutionStageBitfield source_stage_mask) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::SignalEvent");
ASSIGN_OR_RETURN(auto* device_event, CastEvent(event));
syms()->vkCmdSetEvent(command_buffer_, device_event->handle(),
ConvertPipelineStageFlags(source_stage_mask));
return OkStatus();
}
Status DirectCommandBuffer::ResetEvent(
Event* event, ExecutionStageBitfield source_stage_mask) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::ResetEvent");
ASSIGN_OR_RETURN(auto* device_event, CastEvent(event));
syms()->vkCmdResetEvent(command_buffer_, device_event->handle(),
ConvertPipelineStageFlags(source_stage_mask));
return OkStatus();
}
Status DirectCommandBuffer::WaitEvents(
absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
ExecutionStageBitfield target_stage_mask,
absl::Span<const MemoryBarrier> memory_barriers,
absl::Span<const BufferBarrier> buffer_barriers) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::WaitEvents");
absl::InlinedVector<VkEvent, 4> event_handles(events.size());
for (int i = 0; i < events.size(); ++i) {
ASSIGN_OR_RETURN(auto* device_event, CastEvent(events[i]));
event_handles[i] = device_event->handle();
}
absl::InlinedVector<VkMemoryBarrier, 8> memory_barrier_infos(
memory_barriers.size());
for (int i = 0; i < memory_barriers.size(); ++i) {
const auto& memory_barrier = memory_barriers[i];
auto& info = memory_barrier_infos[i];
info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
info.pNext = nullptr;
info.srcAccessMask = ConvertAccessMask(memory_barrier.source_scope);
info.dstAccessMask = ConvertAccessMask(memory_barrier.target_scope);
}
absl::InlinedVector<VkBufferMemoryBarrier, 8> buffer_barrier_infos(
buffer_barriers.size());
for (int i = 0; i < buffer_barriers.size(); ++i) {
const auto& buffer_barrier = buffer_barriers[i];
auto& info = buffer_barrier_infos[i];
info.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER;
info.pNext = nullptr;
info.srcAccessMask = ConvertAccessMask(buffer_barrier.source_scope);
info.dstAccessMask = ConvertAccessMask(buffer_barrier.target_scope);
info.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
info.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
ASSIGN_OR_RETURN(auto* device_buffer, CastBuffer(buffer_barrier.buffer));
info.buffer = device_buffer->handle();
info.offset = buffer_barrier.offset;
info.size = buffer_barrier.length;
}
syms()->vkCmdWaitEvents(
command_buffer_, event_handles.size(), event_handles.data(),
ConvertPipelineStageFlags(source_stage_mask),
ConvertPipelineStageFlags(target_stage_mask), memory_barrier_infos.size(),
memory_barrier_infos.data(), buffer_barrier_infos.size(),
buffer_barrier_infos.data(), 0, nullptr);
return OkStatus();
}
Status DirectCommandBuffer::FillBuffer(Buffer* target_buffer,
device_size_t target_offset,
device_size_t length,
const void* pattern,
size_t pattern_length) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::FillBuffer");
ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
  // Note that vkCmdFillBuffer only accepts a single 4-byte pattern value, so
  // we need to splat our 1/2/4-byte variable-length pattern out to 4 bytes.
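  // The fill offset is relative to the underlying VkBuffer allocation, so fold
  // in the buffer's subrange offset.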
target_offset += target_buffer->byte_offset();
uint32_t dword_pattern = SplatPattern(pattern, pattern_length);
syms()->vkCmdFillBuffer(command_buffer_, target_device_buffer->handle(),
target_offset, length, dword_pattern);
return OkStatus();
}
Status DirectCommandBuffer::DiscardBuffer(Buffer* buffer) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::DiscardBuffer");
// NOTE: we could use this to prevent queue family transitions.
return OkStatus();
}
Status DirectCommandBuffer::UpdateBuffer(const void* source_buffer,
device_size_t source_offset,
Buffer* target_buffer,
device_size_t target_offset,
device_size_t length) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::UpdateBuffer");
ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
  // Vulkan only allows updates of <= 65536 bytes because you really, really,
  // really shouldn't do large updates like this (as it wastes command buffer
  // space and may be slower than just using write-through mapped memory). The
  // recommendation in the spec for larger updates is to split the single
  // update into multiple updates over the entire desired range.
const auto* source_buffer_ptr = static_cast<const uint8_t*>(source_buffer);
target_offset += target_buffer->byte_offset();
while (length > 0) {
device_size_t chunk_length =
std::min(static_cast<device_size_t>(65536u), length);
syms()->vkCmdUpdateBuffer(command_buffer_, target_device_buffer->handle(),
target_offset, chunk_length, source_buffer_ptr);
source_buffer_ptr += chunk_length;
target_offset += chunk_length;
length -= chunk_length;
}
return OkStatus();
}
Status DirectCommandBuffer::CopyBuffer(Buffer* source_buffer,
device_size_t source_offset,
Buffer* target_buffer,
device_size_t target_offset,
device_size_t length) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::CopyBuffer");
ASSIGN_OR_RETURN(auto* source_device_buffer, CastBuffer(source_buffer));
ASSIGN_OR_RETURN(auto* target_device_buffer, CastBuffer(target_buffer));
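  // The provided offsets are relative to each logical HAL buffer; convert them
  // to offsets within the underlying allocations the VkBuffer handles refer to.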
VkBufferCopy region;
region.srcOffset = source_buffer->byte_offset() + source_offset;
region.dstOffset = target_buffer->byte_offset() + target_offset;
region.size = length;
syms()->vkCmdCopyBuffer(command_buffer_, source_device_buffer->handle(),
target_device_buffer->handle(), 1, &region);
return OkStatus();
}
Status DirectCommandBuffer::UpdateAndBindDescriptorSet(
PipelineExecutable* executable, absl::Span<const BufferBinding> bindings) {
absl::InlinedVector<VkDescriptorBufferInfo, 8> buffer_infos;
buffer_infos.resize(bindings.size());
for (int i = 0; i < bindings.size(); ++i) {
ASSIGN_OR_RETURN(auto buffer, CastBuffer(bindings[i].buffer));
buffer_infos[i].buffer = buffer->handle();
// TODO(benvanik): properly subrange (add to BufferBinding).
buffer_infos[i].offset = bindings[i].buffer->byte_offset();
buffer_infos[i].range = bindings[i].buffer->byte_length();
}
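  // Build one VkWriteDescriptorSet per binding, mapping each HAL binding index
  // to its Vulkan binding number via the executable's buffer_binding_set_map.
  // All bindings are bound as storage buffers.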
const auto& descriptor_sets = executable->descriptor_sets();
absl::InlinedVector<VkWriteDescriptorSet, 8> write_infos;
write_infos.resize(bindings.size());
for (int i = 0; i < bindings.size(); ++i) {
auto& write_info = write_infos[i];
write_info.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write_info.pNext = nullptr;
write_info.dstSet = VK_NULL_HANDLE;
write_info.dstBinding = descriptor_sets.buffer_binding_set_map[i];
write_info.dstArrayElement = 0;
write_info.descriptorCount = 1;
write_info.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
write_info.pImageInfo = nullptr;
write_info.pBufferInfo = &buffer_infos[i];
write_info.pTexelBufferView = nullptr;
}
if (command_pool_->logical_device()->enabled_extensions().push_descriptors) {
    // Fast path using push descriptors. These are pooled internally by the
    // command buffer and avoid the need for our own pooling mechanisms.
syms()->vkCmdPushDescriptorSetKHR(
command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
executable->pipeline_layout(), descriptor_sets.buffer_binding_set,
write_infos.size(), write_infos.data());
} else {
// TODO(benvanik): allocate from pool and update.
return UnimplementedErrorBuilder(ABSL_LOC)
<< "Non-push descriptor set path not yet implemented";
}
return OkStatus();
}
Status DirectCommandBuffer::Dispatch(const DispatchRequest& dispatch_request) {
IREE_TRACE_SCOPE0("DirectCommandBuffer::Dispatch");
// Get the compiled and linked pipeline for the specified entry point and
// bind it to the command buffer.
auto* executable =
static_cast<PipelineExecutable*>(dispatch_request.executable);
ASSIGN_OR_RETURN(VkPipeline pipeline, executable->GetPipelineForEntryPoint(
dispatch_request.entry_point));
syms()->vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE,
pipeline);
  // Either allocate, update, and bind a descriptor set, or use push descriptor
  // sets (pooled by the command buffer) when supported.
RETURN_IF_ERROR(
UpdateAndBindDescriptorSet(executable, dispatch_request.bindings));
  // TODO(benvanik): not this, /obviously/. Replace with semantic tags, or get
  // SPIR-V round-tripping what we need so this can be done in proper IR. The
  // dynamic shape infra is another route, where we could just pass shapes in
  // via dynamically updated uniform buffers.
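  // For the hardcoded matmul path, pack each of the three binding shapes into
  // four int32 dims (padding rank-2 shapes with 1s) and hand them to the
  // shader as push constants.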
if (executable->is_matmul()) {
struct ABSL_ATTRIBUTE_PACKED {
int32_t dims[4];
} shapes[3];
for (int i = 0; i < 3; ++i) {
const auto& shape = dispatch_request.bindings[i].shape;
if (shape.size() == 3) {
shapes[i].dims[0] = shape[0];
shapes[i].dims[1] = shape[1];
shapes[i].dims[2] = shape[2];
shapes[i].dims[3] = 1;
} else {
shapes[i].dims[0] = 1;
shapes[i].dims[1] = shape[0];
shapes[i].dims[2] = shape[1];
shapes[i].dims[3] = 1;
}
}
syms()->vkCmdPushConstants(command_buffer_, executable->pipeline_layout(),
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(shapes),
&shapes);
}
// TODO(benvanik): divide workload by caps and issue multiple dispatches.
// TODO(benvanik): track local workgroup/subgroup size and divide into groups.
if (dispatch_request.workload_buffer) {
return UnimplementedErrorBuilder(ABSL_LOC)
<< "Dynamic dispatches not yet implemented";
}
uint32_t group_count_x = dispatch_request.workload[0];
uint32_t group_count_y = dispatch_request.workload[1];
uint32_t group_count_z = dispatch_request.workload[2];
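  // The hardcoded matmul path appears to assume a fixed 16x16 local workgroup
  // size: round the XY workload up to whole workgroups and collapse Z to 1.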
if (executable->is_matmul()) {
group_count_x = (group_count_x + 16 - 1) / 16;
group_count_y = (group_count_y + 16 - 1) / 16;
group_count_z = 1;
}
syms()->vkCmdDispatch(command_buffer_, group_count_x, group_count_y,
group_count_z);
return OkStatus();
}
} // namespace vulkan
} // namespace hal
} // namespace iree