blob: 701dd567e8b95ecd35c05566273e1f7583a95674 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/mlir_edge/iree/hal/command_buffer_validation.h"
#include "third_party/mlir_edge/iree/base/logging.h"
#include "third_party/mlir_edge/iree/base/status.h"
namespace iree {
namespace hal {
namespace {
// Validation layer for command buffers.
//
// Holds a reference to another CommandBuffer and checks the arguments of each
// recording call before forwarding the call to it. Intended to be enabled
// whenever the command buffer is being driven by unsafe code or when early and
// readable diagnostics are needed.
class ValidatingCommandBuffer : public CommandBuffer {
 public:
  // Wraps |impl|; all validated calls are forwarded to it.
  explicit ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl);
  ~ValidatingCommandBuffer() override;

  // Returns the wrapped command buffer.
  CommandBuffer* impl() { return impl_.get(); }

  // Reports the recording state of the wrapped command buffer.
  bool is_recording() const override;

  // Begin fails if the wrapped buffer is already recording; End fails if it is
  // not recording. Otherwise the call is forwarded.
  Status Begin() override;
  Status End() override;

  // Synchronization commands. Only the command categories are validated before
  // the call is forwarded.
  Status ExecutionBarrier(
      ExecutionStageBitfield source_stage_mask,
      ExecutionStageBitfield target_stage_mask,
      absl::Span<const MemoryBarrier> memory_barriers,
      absl::Span<const BufferBarrier> buffer_barriers) override;
  Status SignalEvent(Event* event,
                     ExecutionStageBitfield source_stage_mask) override;
  Status ResetEvent(Event* event,
                    ExecutionStageBitfield source_stage_mask) override;
  Status WaitEvents(absl::Span<Event*> events,
                    ExecutionStageBitfield source_stage_mask,
                    ExecutionStageBitfield target_stage_mask,
                    absl::Span<const MemoryBarrier> memory_barriers,
                    absl::Span<const BufferBarrier> buffer_barriers) override;

  // Transfer commands. Categories and buffer properties are validated before
  // the call is forwarded.
  Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
                    device_size_t length, const void* pattern,
                    size_t pattern_length) override;
  Status DiscardBuffer(Buffer* buffer) override;
  Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
                      Buffer* target_buffer, device_size_t target_offset,
                      device_size_t length) override;
  Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
                    Buffer* target_buffer, device_size_t target_offset,
                    device_size_t length) override;

  // Validates the categories and every buffer binding of |dispatch_request|
  // before forwarding it.
  Status Dispatch(const DispatchRequest& dispatch_request) override;

 private:
  // Fails unless every bit of |required_categories| is present in this command
  // buffer's command_categories().
  Status ValidateCategories(CommandCategoryBitfield required_categories) const;

  // Fails unless the memory type of |buffer| contains every bit of
  // |memory_type|.
  Status ValidateCompatibleMemoryType(Buffer* buffer,
                                      MemoryTypeBitfield memory_type) const;

  // Fails unless the allowed access of |buffer| contains every bit of
  // |memory_access|.
  Status ValidateAccess(Buffer* buffer,
                        MemoryAccessBitfield memory_access) const;

  // Fails if the allocator cannot use |buffer| for |usage| or if |buffer| was
  // not allocated with every bit of |usage|.
  Status ValidateUsage(Buffer* buffer, BufferUsageBitfield usage) const;

  // Fails unless [byte_offset, byte_offset + byte_length) lies within
  // |buffer|.
  Status ValidateRange(Buffer* buffer, device_size_t byte_offset,
                       device_size_t byte_length) const;

  // The wrapped command buffer that receives all forwarded calls.
  ref_ptr<CommandBuffer> impl_;
};
// Wraps |impl|. The allocator, mode, and command categories of |impl| are
// passed to the CommandBuffer base; the base initializer runs before |impl| is
// moved into |impl_|, so those reads see the still-valid argument.
ValidatingCommandBuffer::ValidatingCommandBuffer(ref_ptr<CommandBuffer> impl)
    : CommandBuffer(impl->allocator(), impl->mode(),
                    impl->command_categories()),
      impl_(std::move(impl)) {}
// Out-of-line defaulted destructor; destroys the impl_ member along with the
// object.
ValidatingCommandBuffer::~ValidatingCommandBuffer() = default;
// Returns whether the wrapped command buffer is currently recording; the shim
// keeps no recording state of its own.
bool ValidatingCommandBuffer::is_recording() const {
  const bool wrapped_is_recording = impl_->is_recording();
  return wrapped_is_recording;
}
// Begins recording on the wrapped command buffer.
// Returns a failed-precondition error, without forwarding, when the wrapped
// command buffer is already recording.
Status ValidatingCommandBuffer::Begin() {
  DVLOG(3) << "CommandBuffer::Begin()";
  const bool already_recording = impl_->is_recording();
  if (!already_recording) {
    return impl_->Begin();
  }
  return FailedPreconditionErrorBuilder(ABSL_LOC)
         << "Command buffer is already recording";
}
// Ends recording on the wrapped command buffer.
// Returns a failed-precondition error, without forwarding, when the wrapped
// command buffer is not currently recording.
Status ValidatingCommandBuffer::End() {
  DVLOG(3) << "CommandBuffer::End()";
  const bool currently_recording = impl_->is_recording();
  if (currently_recording) {
    return impl_->End();
  }
  return FailedPreconditionErrorBuilder(ABSL_LOC)
         << "Command buffer is not recording";
}
// Checks |required_categories| against this command buffer's
// command_categories() using AllBitsSet. Returns OK when the check passes and
// a failed-precondition error naming both bitfields otherwise.
Status ValidatingCommandBuffer::ValidateCategories(
    CommandCategoryBitfield required_categories) const {
  if (AllBitsSet(command_categories(), required_categories)) {
    return OkStatus();
  }
  return FailedPreconditionErrorBuilder(ABSL_LOC)
         << "Operation requires categories "
         << CommandCategoryString(required_categories)
         << " but buffer only supports "
         << CommandCategoryString(command_categories());
}
// Checks that the memory type of |buffer| contains every bit of |memory_type|.
// Returns OK when it does and a permission-denied error listing both the
// buffer's memory type and the required one otherwise.
Status ValidatingCommandBuffer::ValidateCompatibleMemoryType(
    Buffer* buffer, MemoryTypeBitfield memory_type) const {
  // Masking with the required bits must give back exactly the required bits;
  // anything else means at least one of them is absent from the buffer.
  const bool missing_required_bits =
      (buffer->memory_type() & memory_type) != memory_type;
  if (!missing_required_bits) {
    return OkStatus();
  }
  return PermissionDeniedErrorBuilder(ABSL_LOC)
         << "Buffer memory type is not compatible with the requested "
            "operation; buffer has "
         << MemoryTypeString(buffer->memory_type()) << ", operation requires "
         << MemoryTypeString(memory_type);
}
// Checks that the allowed access of |buffer| contains every bit of
// |memory_access|. Returns OK when it does and a permission-denied error
// listing the allowed and the required access otherwise.
Status ValidatingCommandBuffer::ValidateAccess(
    Buffer* buffer, MemoryAccessBitfield memory_access) const {
  // All requested access bits must be present; the buffer may allow more than
  // was requested.
  const bool missing_access_bits =
      (buffer->allowed_access() & memory_access) != memory_access;
  if (!missing_access_bits) {
    return OkStatus();
  }
  return PermissionDeniedErrorBuilder(ABSL_LOC)
         << "The buffer does not support the requested access type; buffer "
            "allows "
         << MemoryAccessString(buffer->allowed_access())
         << ", operation requires " << MemoryAccessString(memory_access);
}
// Checks that |buffer| may be used for |usage|.
//
// Two conditions are tested, in order:
//   1. the allocator reports that it can use |buffer| for |usage|;
//   2. the usage bits |buffer| was created with contain every bit of |usage|.
// A permission-denied error describing the first failing condition is
// returned; OK is returned when both hold.
Status ValidatingCommandBuffer::ValidateUsage(Buffer* buffer,
                                              BufferUsageBitfield usage) const {
  const bool allocator_accepts = allocator()->CanUseBuffer(buffer, usage);
  if (!allocator_accepts) {
    return PermissionDeniedErrorBuilder(ABSL_LOC)
           << "Requested usage of " << buffer->DebugString()
           << " is not supported for the buffer on this queue; "
              "buffer allows "
           << BufferUsageString(buffer->usage()) << ", queue requires "
           << BufferUsageString(usage);
  }

  const bool missing_usage_bits = (buffer->usage() & usage) != usage;
  if (missing_usage_bits) {
    return PermissionDeniedErrorBuilder(ABSL_LOC)
           << "Requested usage was not specified when the buffer was "
              "allocated; buffer allows "
           << BufferUsageString(buffer->usage()) << ", operation requires "
           << BufferUsageString(usage);
  }

  return OkStatus();
}
// Validates that the byte range [byte_offset, byte_offset + byte_length) lies
// entirely within |buffer|.
//
// Returns:
//   - an out-of-range error if |byte_offset| is past the end of the buffer;
//   - OK if |byte_length| is zero (an empty range at any in-bounds offset,
//     including one equal to the buffer length, is accepted);
//   - an out-of-range error if the range extends past the end of the buffer;
//   - OK otherwise.
//
// The end of the range is never computed as |byte_offset| + |byte_length| for
// the bounds check, since that sum can wrap around for large inputs and make
// an out-of-range request appear valid.
Status ValidatingCommandBuffer::ValidateRange(Buffer* buffer,
                                              device_size_t byte_offset,
                                              device_size_t byte_length) const {
  const auto buffer_byte_length = buffer->byte_length();

  // Check if the start of the range runs off the end of the buffer.
  if (byte_offset > buffer_byte_length) {
    return OutOfRangeErrorBuilder(ABSL_LOC)
           << "Attempted to access an address off the end of the valid buffer "
              "range (offset="
           << byte_offset << ", length=" << byte_length
           << ", buffer byte_length=" << buffer_byte_length << ")";
  }

  if (byte_length == 0) {
    // Fine to have a zero length.
    return OkStatus();
  }

  // Check if the end runs over the allocation. byte_offset is known to be
  // <= buffer_byte_length here, so the subtraction cannot underflow and the
  // comparison is exact even when byte_offset + byte_length would overflow.
  if (byte_length > buffer_byte_length - byte_offset) {
    // Only used for the diagnostic below. byte_length is non-zero so the
    // decrement is safe.
    // NOTE(review): assumes device_size_t is unsigned, so that for extreme
    // inputs this value wraps rather than overflowing — confirm.
    const device_size_t inclusive_end = byte_offset + (byte_length - 1);
    return OutOfRangeErrorBuilder(ABSL_LOC)
           << "Attempted to access an address outside of the valid buffer "
              "range (offset="
           << byte_offset << ", length=" << byte_length
           << ", end(inc)=" << inclusive_end
           << ", buffer byte_length=" << buffer_byte_length << ")";
  }

  return OkStatus();
}
// Validates that this command buffer supports the combined transfer and
// dispatch categories, then forwards the barrier unchanged to the wrapped
// command buffer. The stage masks and barrier lists are not inspected.
Status ValidatingCommandBuffer::ExecutionBarrier(
    ExecutionStageBitfield source_stage_mask,
    ExecutionStageBitfield target_stage_mask,
    absl::Span<const MemoryBarrier> memory_barriers,
    absl::Span<const BufferBarrier> buffer_barriers) {
  DVLOG(3) << "CommandBuffer::ExecutionBarrier(...)";

  // TODO(benvanik): additional synchronization validation.
  const auto required_categories =
      CommandCategory::kTransfer | CommandCategory::kDispatch;
  RETURN_IF_ERROR(ValidateCategories(required_categories));

  return impl_->ExecutionBarrier(source_stage_mask, target_stage_mask,
                                 memory_barriers, buffer_barriers);
}
// Validates that this command buffer supports the dispatch category, then
// forwards the signal to the wrapped command buffer. |event| and the stage
// mask are passed through without inspection.
Status ValidatingCommandBuffer::SignalEvent(
    Event* event, ExecutionStageBitfield source_stage_mask) {
  DVLOG(3) << "CommandBuffer::SignalEvent(...)";

  // TODO(benvanik): additional synchronization validation.
  const CommandCategoryBitfield required_categories =
      CommandCategory::kDispatch;
  RETURN_IF_ERROR(ValidateCategories(required_categories));

  return impl_->SignalEvent(event, source_stage_mask);
}
// Validates that this command buffer supports the dispatch category, then
// forwards the reset to the wrapped command buffer. |event| and the stage mask
// are passed through without inspection.
Status ValidatingCommandBuffer::ResetEvent(
    Event* event, ExecutionStageBitfield source_stage_mask) {
  DVLOG(3) << "CommandBuffer::ResetEvent(...)";

  // TODO(benvanik): additional synchronization validation.
  const CommandCategoryBitfield required_categories =
      CommandCategory::kDispatch;
  RETURN_IF_ERROR(ValidateCategories(required_categories));

  return impl_->ResetEvent(event, source_stage_mask);
}
// Validates that this command buffer supports the dispatch category, then
// forwards the wait to the wrapped command buffer. The events, stage masks,
// and barrier lists are passed through without inspection.
Status ValidatingCommandBuffer::WaitEvents(
    absl::Span<Event*> events, ExecutionStageBitfield source_stage_mask,
    ExecutionStageBitfield target_stage_mask,
    absl::Span<const MemoryBarrier> memory_barriers,
    absl::Span<const BufferBarrier> buffer_barriers) {
  DVLOG(3) << "CommandBuffer::WaitEvents(...)";

  // TODO(benvanik): additional synchronization validation.
  const CommandCategoryBitfield required_categories =
      CommandCategory::kDispatch;
  RETURN_IF_ERROR(ValidateCategories(required_categories));

  return impl_->WaitEvents(events, source_stage_mask, target_stage_mask,
                           memory_barriers, buffer_barriers);
}
// Validates a fill of |length| bytes at |target_offset| in |target_buffer| and
// forwards it to the wrapped command buffer.
//
// Checks, in order: transfer category support; that the target is
// device-visible, writable, usable for transfer, and that the range is in
// bounds; that |pattern_length| is 1, 2, or 4; and that both |target_offset|
// and |length| are multiples of |pattern_length|. The first failing check
// determines the returned error. |pattern| itself is not inspected.
Status ValidatingCommandBuffer::FillBuffer(Buffer* target_buffer,
                                           device_size_t target_offset,
                                           device_size_t length,
                                           const void* pattern,
                                           size_t pattern_length) {
  DVLOG(3) << "CommandBuffer::FillBuffer(" << target_buffer->DebugString()
           << ", " << target_offset << ", " << length << ", ??, "
           << pattern_length << ")";

  RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));
  RETURN_IF_ERROR(
      ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
  RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
  RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
  RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));

  // Only 1-, 2-, and 4-byte fill patterns are accepted.
  switch (pattern_length) {
    case 1:
    case 2:
    case 4:
      break;
    default:
      return InvalidArgumentErrorBuilder(ABSL_LOC)
             << "Fill value length is not one of the supported values "
                "(pattern_length="
             << pattern_length << ")";
  }

  // The offset and length must both be whole multiples of the pattern size.
  // pattern_length is known to be non-zero at this point.
  const bool offset_misaligned = (target_offset % pattern_length) != 0;
  const bool length_misaligned = (length % pattern_length) != 0;
  if (offset_misaligned || length_misaligned) {
    return InvalidArgumentErrorBuilder(ABSL_LOC)
           << "Fill offset and/or length do not match the natural alignment of "
              "the fill value (target_offset="
           << target_offset << ", length=" << length
           << ", pattern_length=" << pattern_length << ")";
  }

  return impl_->FillBuffer(target_buffer, target_offset, length, pattern,
                           pattern_length);
}
// Validates a discard of |buffer| and forwards it to the wrapped command
// buffer. Requires transfer category support and a device-visible buffer; the
// usage check is performed with BufferUsage::kNone.
Status ValidatingCommandBuffer::DiscardBuffer(Buffer* buffer) {
  DVLOG(3) << "CommandBuffer::DiscardBuffer(" << buffer->DebugString() << ")";

  const CommandCategoryBitfield required_categories =
      CommandCategory::kTransfer;
  RETURN_IF_ERROR(ValidateCategories(required_categories));
  RETURN_IF_ERROR(
      ValidateCompatibleMemoryType(buffer, MemoryType::kDeviceVisible));
  RETURN_IF_ERROR(ValidateUsage(buffer, BufferUsage::kNone));

  return impl_->DiscardBuffer(buffer);
}
// Validates an update of |length| bytes at |target_offset| in |target_buffer|
// and forwards it to the wrapped command buffer.
//
// Requires transfer category support and a target that is device-visible,
// writable, usable for transfer, and large enough for the target range.
// |source_buffer| and |source_offset| are forwarded as-is; no checks are
// applied to them here.
Status ValidatingCommandBuffer::UpdateBuffer(const void* source_buffer,
                                             device_size_t source_offset,
                                             Buffer* target_buffer,
                                             device_size_t target_offset,
                                             device_size_t length) {
  DVLOG(3) << "CommandBuffer::UpdateBuffer(" << source_buffer << ", "
           << source_offset << ", " << target_buffer->DebugString() << ", "
           << target_offset << ", " << length << ")";

  const CommandCategoryBitfield required_categories =
      CommandCategory::kTransfer;
  RETURN_IF_ERROR(ValidateCategories(required_categories));

  // Target buffer checks.
  RETURN_IF_ERROR(
      ValidateCompatibleMemoryType(target_buffer, MemoryType::kDeviceVisible));
  RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
  RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
  RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));

  return impl_->UpdateBuffer(source_buffer, source_offset, target_buffer,
                             target_offset, length);
}
// Validates a copy of |length| bytes from |source_buffer| at |source_offset|
// to |target_buffer| at |target_offset| and forwards it to the wrapped command
// buffer.
//
// Checks, in order: transfer category support; that at least one of the two
// buffers is device-visible; read access on the source and write access on the
// target; transfer usage on both; that both ranges are in bounds; and that
// Buffer::TestOverlap reports the two ranges as disjoint.
Status ValidatingCommandBuffer::CopyBuffer(Buffer* source_buffer,
                                           device_size_t source_offset,
                                           Buffer* target_buffer,
                                           device_size_t target_offset,
                                           device_size_t length) {
  DVLOG(3) << "CommandBuffer::CopyBuffer(" << source_buffer->DebugString()
           << ", " << source_offset << ", " << target_buffer->DebugString()
           << ", " << target_offset << ", " << length << ")";

  RETURN_IF_ERROR(ValidateCategories(CommandCategory::kTransfer));

  // Host->device, device->host, and device->device copies are all allowed, so
  // it is enough for either side to be device-visible.
  // TODO(b/117338171): host->host copies.
  if (!(AnyBitSet(source_buffer->memory_type() & MemoryType::kDeviceVisible) ||
        AnyBitSet(target_buffer->memory_type() & MemoryType::kDeviceVisible))) {
    return PermissionDeniedErrorBuilder(ABSL_LOC)
           << "At least one buffer must be device-visible for a copy; "
              "source_buffer="
           << MemoryTypeString(source_buffer->memory_type())
           << ", target_buffer="
           << MemoryTypeString(target_buffer->memory_type());
  }

  RETURN_IF_ERROR(ValidateAccess(source_buffer, MemoryAccess::kRead));
  RETURN_IF_ERROR(ValidateAccess(target_buffer, MemoryAccess::kWrite));
  RETURN_IF_ERROR(ValidateUsage(source_buffer, BufferUsage::kTransfer));
  RETURN_IF_ERROR(ValidateUsage(target_buffer, BufferUsage::kTransfer));
  RETURN_IF_ERROR(ValidateRange(source_buffer, source_offset, length));
  RETURN_IF_ERROR(ValidateRange(target_buffer, target_offset, length));

  // Like memcpy, overlapping source and target ranges are not handled.
  const auto range_overlap =
      Buffer::TestOverlap(source_buffer, source_offset, length, target_buffer,
                          target_offset, length);
  if (range_overlap != Buffer::Overlap::kDisjoint) {
    return InvalidArgumentErrorBuilder(ABSL_LOC)
           << "Source and target ranges overlap within the same buffer";
  }

  return impl_->CopyBuffer(source_buffer, source_offset, target_buffer,
                           target_offset, length);
}
// Validates |dispatch_request| and forwards it to the wrapped command buffer.
//
// Requires dispatch category support. Each binding's buffer must be
// device-visible, must allow the access the binding declares, and must be
// usable for dispatch. A memory type failure is annotated with the binding's
// access and a short description of the buffer.
Status ValidatingCommandBuffer::Dispatch(
    const DispatchRequest& dispatch_request) {
  DVLOG(3) << "CommandBuffer::Dispatch(?)";

  RETURN_IF_ERROR(ValidateCategories(CommandCategory::kDispatch));

  // Every referenced buffer needs a compatible memory type, access rights,
  // and usage.
  for (const auto& bound : dispatch_request.bindings) {
    Buffer* const bound_buffer = bound.buffer;
    RETURN_IF_ERROR(ValidateCompatibleMemoryType(bound_buffer,
                                                 MemoryType::kDeviceVisible))
        << "input buffer: " << MemoryAccessString(bound.access) << " "
        << bound_buffer->DebugStringShort();
    RETURN_IF_ERROR(ValidateAccess(bound_buffer, bound.access));
    RETURN_IF_ERROR(ValidateUsage(bound_buffer, BufferUsage::kDispatch));
    // TODO(benvanik): validate it matches the executable expectations.
    // TODO(benvanik): validate buffer contains enough data for shape+size.
  }

  // TODO(benvanik): validate no aliasing?
  return impl_->Dispatch(dispatch_request);
}
} // namespace
// Returns a new ValidatingCommandBuffer that takes ownership of the |impl|
// reference and validates calls before forwarding them to it.
ref_ptr<CommandBuffer> WrapCommandBufferWithValidation(
    ref_ptr<CommandBuffer> impl) {
  return make_ref<ValidatingCommandBuffer>(
      /*impl=*/std::move(impl));
}
} // namespace hal
} // namespace iree