blob: 9086303aeeb67a2ffa0a73c2d213eb30f62a34d4 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/command_buffer_validation.h"
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include "iree/hal/allocator.h"
#include "iree/hal/buffer.h"
#include "iree/hal/detail.h"
#include "iree/hal/event.h"
#include "iree/hal/executable.h"
#include "iree/hal/pipeline_layout.h"
#include "iree/hal/resource.h"
// Verifies that the command buffer is currently recording and that every bit
// of |required_categories| is present in its allowed categories.
// Returns FAILED_PRECONDITION if either condition does not hold.
static iree_status_t iree_hal_command_buffer_validate_categories(
    const iree_hal_command_buffer_t* command_buffer,
    const iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_command_category_t required_categories) {
  // Commands can only be validated while recording is active.
  if (IREE_UNLIKELY(!validation_state->is_recording)) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                            "command buffer is not in a recording state");
  }

  // Fast path: everything that was requested is allowed.
  if (iree_all_bits_set(command_buffer->allowed_categories,
                        required_categories)) {
    return iree_ok_status();
  }

#if IREE_STATUS_MODE
  // Format both bitfields so the failure message is actionable.
  iree_bitfield_string_temp_t required_temp, allowed_temp;
  iree_string_view_t required_str =
      iree_hal_command_category_format(required_categories, &required_temp);
  iree_string_view_t allowed_str = iree_hal_command_category_format(
      command_buffer->allowed_categories, &allowed_temp);
  return iree_make_status(
      IREE_STATUS_FAILED_PRECONDITION,
      "operation requires categories %.*s but command buffer only supports "
      "%.*s",
      (int)required_str.size, required_str.data, (int)allowed_str.size,
      allowed_str.data);
#else
  return iree_status_from_code(IREE_STATUS_FAILED_PRECONDITION);
#endif  // IREE_STATUS_MODE
}
// Asks the device allocator whether |buffer| can be used with
// |required_compatibility| for |intended_usage|.
//
// The query uses the buffer's memory type, the intersection of its allowed
// usage with |intended_usage|, and its allocation size. Returns
// PERMISSION_DENIED when the allocator does not report all of the required
// compatibility bits.
static iree_status_t iree_hal_command_buffer_validate_buffer_compatibility(
    const iree_hal_command_buffer_t* command_buffer,
    const iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_t* buffer,
    iree_hal_buffer_compatibility_t required_compatibility,
    iree_hal_buffer_usage_t intended_usage) {
  const iree_hal_buffer_params_t query_params = {
      .type = iree_hal_buffer_memory_type(buffer),
      .usage = iree_hal_buffer_allowed_usage(buffer) & intended_usage,
  };
  const iree_hal_buffer_compatibility_t reported_compatibility =
      iree_hal_allocator_query_buffer_compatibility(
          validation_state->device_allocator, query_params,
          iree_hal_buffer_allocation_size(buffer), /*out_params=*/NULL,
          /*out_allocation_size=*/NULL);
  if (iree_all_bits_set(reported_compatibility, required_compatibility)) {
    return iree_ok_status();
  }

#if IREE_STATUS_MODE
  // The buffer cannot be used on the queue for the given usage; report what
  // the buffer allows next to what the operation asked for.
  iree_bitfield_string_temp_t allowed_temp, intended_temp;
  iree_string_view_t allowed_str = iree_hal_buffer_usage_format(
      iree_hal_buffer_allowed_usage(buffer), &allowed_temp);
  iree_string_view_t intended_str =
      iree_hal_buffer_usage_format(intended_usage, &intended_temp);
  return iree_make_status(
      IREE_STATUS_PERMISSION_DENIED,
      "requested buffer usage is not supported for the buffer on this queue; "
      "buffer allows %.*s, operation requires %.*s (allocator compatibility "
      "mismatch)",
      (int)allowed_str.size, allowed_str.data, (int)intended_str.size,
      intended_str.data);
#else
  return iree_status_from_code(IREE_STATUS_PERMISSION_DENIED);
#endif  // IREE_STATUS_MODE
}
// Validates a concrete |binding| against the |requirements| accumulated for it.
//
// Checks performed, in order:
//   1. Presence: a slot with no required usage is ignored entirely; otherwise
//      a buffer must be provided (INVALID_ARGUMENT if missing).
//   2. Allocator compatibility of the buffer for the required usage.
//   3. Buffer usage, access, and memory type against the requirements.
//   4. Range: binding.offset + max_byte_offset must not exceed binding.length
//      (OUT_OF_RANGE otherwise).
//   5. Alignment: binding.offset must be a multiple of min_byte_alignment when
//      one is specified (INVALID_ARGUMENT otherwise).
//
// Returns OK when all checks pass or the first failing status otherwise.
static iree_status_t iree_hal_command_buffer_validate_binding_requirements(
    iree_hal_command_buffer_t* command_buffer,
    const iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_binding_t binding,
    iree_hal_buffer_binding_requirements_t requirements) {
  // Check for binding presence.
  if (requirements.usage == IREE_HAL_BUFFER_USAGE_NONE) {
    // Binding slot is unused and its value in the table is ignored.
    return iree_ok_status();
  } else if (!binding.buffer) {
    // Binding is used and required.
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "binding table slot requires a buffer but none was provided");
  }

  // Ensure the buffer is compatible with the device.
  // NOTE: this check is very slow! We may want to disable this outside of debug
  // mode or try to fast path it if the buffer is known-good.
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_compatibility(
      command_buffer, validation_state, binding.buffer,
      requirements.required_compatibility, requirements.usage));

  // Verify buffer compatibility.
  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_usage(
      iree_hal_buffer_allowed_usage(binding.buffer), requirements.usage));
  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_access(
      iree_hal_buffer_allowed_access(binding.buffer), requirements.access));
  IREE_RETURN_IF_ERROR(iree_hal_buffer_validate_memory_type(
      iree_hal_buffer_memory_type(binding.buffer), requirements.type));

  // Verify that the binding range is valid and that any commands that reference
  // it are in range.
  if (requirements.max_byte_offset > 0) {
    // Overflow-safe form of
    //   binding.offset + max_byte_offset > binding.length
    // The naive sum can wrap around for very large offsets and would then
    // compare as in-range; comparing against the remaining headroom cannot.
    const bool is_out_of_range =
        requirements.max_byte_offset > binding.length ||
        binding.offset > binding.length - requirements.max_byte_offset;
    if (IREE_UNLIKELY(is_out_of_range)) {
      // Only used for the diagnostic; may itself have wrapped if the inputs
      // were large enough to overflow.
      const iree_device_size_t end =
          binding.offset + requirements.max_byte_offset;
      return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                              "at least one command attempted to access an "
                              "address outside of the valid bound buffer "
                              "range (length=%" PRIdsz ", end(inc)=%" PRIdsz
                              ", binding offset=%" PRIdsz
                              ", binding length=%" PRIdsz ")",
                              requirements.max_byte_offset, end - 1,
                              binding.offset, binding.length);
    }
  }

  // Ensure the offset and length have an alignment matching the value length.
  if (requirements.min_byte_alignment &&
      (binding.offset % requirements.min_byte_alignment) != 0) {
    // The alignment is cast so the variadic argument always matches the
    // PRIdsz specifier regardless of how the field is declared.
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "binding offset does not match the required "
        "alignment of one or more command (offset=%" PRIdsz
        ", min_byte_alignment=%" PRIdsz ")",
        binding.offset, (iree_device_size_t)requirements.min_byte_alignment);
  }

  return iree_ok_status();
}
// Validates |buffer_ref| against |requirements|.
//
// A direct reference (non-NULL buffer) is validated immediately as a binding
// that starts at offset 0 and spans offset + length of the reference.
// An indirect reference is bounds-checked against the declared binding capacity
// of the command buffer and its requirements are merged into the per-slot
// requirements table of |validation_state|; the command buffer's binding count
// is grown to cover the slot.
static iree_status_t iree_hal_command_buffer_validate_buffer_requirements(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_ref_t buffer_ref,
    iree_hal_buffer_binding_requirements_t requirements) {
  if (!buffer_ref.buffer) {
    // Indirect reference: the slot must be within the declared capacity. The
    // binding table provided later may have more bindings than required so
    // only the declared command buffer capacity is checked.
    if (IREE_UNLIKELY(buffer_ref.buffer_slot >=
                      command_buffer->binding_capacity)) {
      return iree_make_status(
          IREE_STATUS_OUT_OF_RANGE,
          "indirect buffer reference slot %u is out range of the declared "
          "binding capacity of the command buffer %u",
          buffer_ref.buffer_slot, command_buffer->binding_capacity);
    }

    // Track the highest slot referenced so far (as a count).
    if (command_buffer->binding_count < buffer_ref.buffer_slot + 1) {
      command_buffer->binding_count = buffer_ref.buffer_slot + 1;
    }

    // Fold this command's requirements into what the slot already needs.
    iree_hal_buffer_binding_requirements_t* slot_reqs =
        &validation_state->binding_requirements[buffer_ref.buffer_slot];
    slot_reqs->required_compatibility |= requirements.required_compatibility;
    slot_reqs->usage |= requirements.usage;
    slot_reqs->access |= requirements.access;
    slot_reqs->type |= requirements.type;
    if (slot_reqs->max_byte_offset < requirements.max_byte_offset) {
      slot_reqs->max_byte_offset = requirements.max_byte_offset;
    }
    slot_reqs->min_byte_alignment = iree_device_size_lcm(
        slot_reqs->min_byte_alignment, requirements.min_byte_alignment);
    return iree_ok_status();
  }

  // Direct reference: the buffer is known now so validate it inline.
  iree_hal_buffer_binding_t direct_binding = {
      .buffer = buffer_ref.buffer,
      .offset = 0,
      .length = buffer_ref.offset + buffer_ref.length,
  };
  return iree_hal_command_buffer_validate_binding_requirements(
      command_buffer, validation_state, direct_binding, requirements);
}
// Placeholder for validating the bindings used by a dispatch of |entry_point|
// in |executable|. No checks are implemented yet so this always succeeds.
static iree_status_t iree_hal_command_buffer_validate_dispatch_bindings(
    iree_hal_command_buffer_t* command_buffer,
    const iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_executable_t* executable, int32_t entry_point) {
  // Nothing is inspected until the TODOs below are implemented.
  (void)command_buffer;
  (void)validation_state;
  (void)executable;
  (void)entry_point;

  // TODO(benvanik): validate buffers referenced have compatible memory types
  // and access rights.
  // TODO(benvanik): validate no aliasing between inputs/outputs.
  return iree_ok_status();
}
// Initializes |out_validation_state| for a new command buffer: stores the
// |device_allocator| used for compatibility queries and marks the state as not
// recording with no open debug groups. Only those three fields are written.
void iree_hal_command_buffer_initialize_validation(
    iree_hal_allocator_t* device_allocator,
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* out_validation_state) {
  (void)command_buffer;  // Not needed to set up the initial state.
  out_validation_state->is_recording = false;
  out_validation_state->debug_group_depth = 0;
  out_validation_state->device_allocator = device_allocator;
}
// Transitions the validation state into recording.
// Returns FAILED_PRECONDITION if recording has already begun.
iree_status_t iree_hal_command_buffer_begin_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state) {
  (void)command_buffer;
  if (!validation_state->is_recording) {
    validation_state->is_recording = true;
    return iree_ok_status();
  }
  return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                          "command buffer is already in a recording state");
}
// Transitions the validation state out of recording.
// Returns FAILED_PRECONDITION if debug groups are still unbalanced or if the
// command buffer was not recording; the debug group check takes precedence.
iree_status_t iree_hal_command_buffer_end_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state) {
  (void)command_buffer;

  // Every begin_debug_group must have been matched by an end_debug_group.
  if (validation_state->debug_group_depth != 0) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                            "unbalanced debug group depth (expected 0, is %d)",
                            validation_state->debug_group_depth);
  }

  // Ending is only legal while recording.
  if (!validation_state->is_recording) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                            "command buffer is not in a recording state");
  }

  validation_state->is_recording = false;
  return iree_ok_status();
}
// Records that a debug group was opened by bumping the tracked nesting depth.
// The label, color, and location are not inspected.
void iree_hal_command_buffer_begin_debug_group_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_string_view_t label, iree_hal_label_color_t label_color,
    const iree_hal_label_location_t* location) {
  (void)command_buffer;
  (void)label;
  (void)label_color;
  (void)location;
  validation_state->debug_group_depth += 1;
}
// Records that a debug group was closed by lowering the tracked nesting depth.
// No underflow check is made here; an unbalanced depth is reported when
// recording ends (see iree_hal_command_buffer_end_validation).
void iree_hal_command_buffer_end_debug_group_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state) {
  (void)command_buffer;
  validation_state->debug_group_depth -= 1;
}
// Validation hook for execution barriers. Nothing is checked today: no
// category check is made and none of the arguments are inspected, so this
// always succeeds.
iree_status_t iree_hal_command_buffer_execution_barrier_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_hal_execution_barrier_flags_t flags,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  (void)command_buffer;
  (void)validation_state;
  (void)source_stage_mask;
  (void)target_stage_mask;
  (void)flags;
  (void)memory_barrier_count;
  (void)memory_barriers;
  (void)buffer_barrier_count;
  (void)buffer_barriers;

  // NOTE: all command buffer types can perform this so no need to check.
  // TODO(benvanik): additional synchronization validation.
  return iree_ok_status();
}
// Validates an event signal: the command buffer must be recording and must
// allow DISPATCH commands. |event| and |source_stage_mask| are not inspected.
iree_status_t iree_hal_command_buffer_signal_event_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) {
  (void)event;
  (void)source_stage_mask;

  const iree_hal_command_category_t required_categories =
      IREE_HAL_COMMAND_CATEGORY_DISPATCH;
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, required_categories));

  // TODO(benvanik): additional synchronization validation.
  return iree_ok_status();
}
// Validates an event reset: the command buffer must be recording and must
// allow DISPATCH commands. |event| and |source_stage_mask| are not inspected.
iree_status_t iree_hal_command_buffer_reset_event_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_event_t* event, iree_hal_execution_stage_t source_stage_mask) {
  (void)event;
  (void)source_stage_mask;

  const iree_hal_command_category_t required_categories =
      IREE_HAL_COMMAND_CATEGORY_DISPATCH;
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, required_categories));

  // TODO(benvanik): additional synchronization validation.
  return iree_ok_status();
}
// Validates an event wait: the command buffer must be recording and must allow
// DISPATCH commands. The events, stage masks, and barrier lists are not
// inspected.
iree_status_t iree_hal_command_buffer_wait_events_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_host_size_t event_count, const iree_hal_event_t** events,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  (void)event_count;
  (void)events;
  (void)source_stage_mask;
  (void)target_stage_mask;
  (void)memory_barrier_count;
  (void)memory_barriers;
  (void)buffer_barrier_count;
  (void)buffer_barriers;

  const iree_hal_command_category_t required_categories =
      IREE_HAL_COMMAND_CATEGORY_DISPATCH;
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, required_categories));

  // TODO(benvanik): additional synchronization validation.
  return iree_ok_status();
}
// Validates a buffer discard: the command buffer must be recording and allow
// TRANSFER commands, and |buffer_ref| is checked against device-visible memory
// with a maximum byte offset of offset + length.
// Note that no usage bits are requested in the requirements built here.
iree_status_t iree_hal_command_buffer_discard_buffer_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_ref_t buffer_ref) {
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));

  const iree_hal_buffer_binding_requirements_t discard_reqs = {
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      .max_byte_offset = buffer_ref.offset + buffer_ref.length,
  };
  return iree_hal_command_buffer_validate_buffer_requirements(
      command_buffer, validation_state, buffer_ref, discard_reqs);
}
// Validates a buffer fill.
//
// Requires a recording command buffer that allows TRANSFER commands, a
// |pattern_length| of 1, 2, or 4 bytes, and a target offset and length that
// are both multiples of |pattern_length|. The target is then checked as a
// writable, device-visible transfer target aligned to the pattern length.
// The |pattern| pointer itself is not inspected.
iree_status_t iree_hal_command_buffer_fill_buffer_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_ref_t target_ref, const void* pattern,
    iree_host_size_t pattern_length) {
  (void)pattern;

  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));

  // Only 1/2/4 byte patterns are supported. This also guarantees the modulo
  // operations below never divide by zero.
  switch (pattern_length) {
    case 1:
    case 2:
    case 4:
      break;
    default:
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "fill value length is not one of the supported "
                              "values (pattern_length=%" PRIhsz ")",
                              pattern_length);
  }

  // The filled range must consist of whole pattern elements.
  const bool is_offset_aligned = (target_ref.offset % pattern_length) == 0;
  const bool is_length_aligned = (target_ref.length % pattern_length) == 0;
  if (!(is_offset_aligned && is_length_aligned)) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "binding offset and/or length do not match the required alignment of "
        "one or more command (offset=%" PRIdsz ", length=%" PRIdsz
        ", pattern_length=%" PRIhsz ")",
        target_ref.offset, target_ref.length, pattern_length);
  }

  const iree_hal_buffer_binding_requirements_t fill_target_reqs = {
      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER,
      .usage = IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET,
      .access = IREE_HAL_MEMORY_ACCESS_WRITE,
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      .max_byte_offset = target_ref.offset + target_ref.length,
      .min_byte_alignment = pattern_length,
  };
  return iree_hal_command_buffer_validate_buffer_requirements(
      command_buffer, validation_state, target_ref, fill_target_reqs);
}
// Validates an inline buffer update: the command buffer must be recording and
// allow TRANSFER commands, and |target_ref| must be usable as a writable,
// device-visible transfer target covering offset + length bytes.
// |source_buffer| and |source_offset| are not inspected.
iree_status_t iree_hal_command_buffer_update_buffer_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    const void* source_buffer, iree_host_size_t source_offset,
    iree_hal_buffer_ref_t target_ref) {
  (void)source_buffer;
  (void)source_offset;

  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));

  const iree_hal_buffer_binding_requirements_t update_target_reqs = {
      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER,
      .usage = IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET,
      .access = IREE_HAL_MEMORY_ACCESS_WRITE,
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      .max_byte_offset = target_ref.offset + target_ref.length,
  };
  return iree_hal_command_buffer_validate_buffer_requirements(
      command_buffer, validation_state, target_ref, update_target_reqs);
}
// Validates a buffer-to-buffer copy.
//
// Requires a recording command buffer that allows TRANSFER commands and equal
// source/target lengths. The source is checked as a readable transfer source
// and the target as a writable transfer target (both device-visible). Finally
// the two ranges must not overlap.
iree_status_t iree_hal_command_buffer_copy_buffer_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_ref_t source_ref, iree_hal_buffer_ref_t target_ref) {
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_TRANSFER));

  // Both sides of the copy must describe the same number of bytes.
  const bool lengths_match = source_ref.length == target_ref.length;
  if (!lengths_match) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "copy spans between source and target must match "
                            "(source_length=%" PRIdsz ", target_length=%" PRIdsz
                            ")",
                            source_ref.length, target_ref.length);
  }

  // Source side: read access as a transfer source.
  const iree_hal_buffer_binding_requirements_t copy_source_reqs = {
      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER,
      .usage = IREE_HAL_BUFFER_USAGE_TRANSFER_SOURCE,
      .access = IREE_HAL_MEMORY_ACCESS_READ,
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      .max_byte_offset = source_ref.offset + source_ref.length,
  };
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements(
      command_buffer, validation_state, source_ref, copy_source_reqs));

  // Target side: write access as a transfer target.
  const iree_hal_buffer_binding_requirements_t copy_target_reqs = {
      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER,
      .usage = IREE_HAL_BUFFER_USAGE_TRANSFER_TARGET,
      .access = IREE_HAL_MEMORY_ACCESS_WRITE,
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      .max_byte_offset = target_ref.offset + target_ref.length,
  };
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements(
      command_buffer, validation_state, target_ref, copy_target_reqs));

  // Check for overlap - just like memcpy we don't handle that.
  // Note that it's only undefined behavior if violated so we are ok if tricky
  // situations (subspans of subspans of binding table subranges etc) make it
  // through.
  // NOTE(review): for indirect references the .buffer fields are NULL at this
  // point - confirm iree_hal_buffer_test_overlap tolerates NULL buffers.
  const bool is_disjoint =
      iree_hal_buffer_test_overlap(source_ref.buffer, source_ref.offset,
                                   source_ref.length, target_ref.buffer,
                                   target_ref.offset, target_ref.length) ==
      IREE_HAL_BUFFER_OVERLAP_DISJOINT;
  if (!is_disjoint) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "source and target ranges overlap within the same buffer");
  }

  return iree_ok_status();
}
// Validates a collective operation.
//
// Checks performed, in order:
//   - the command buffer is recording and allows DISPATCH commands;
//   - |op| kind, reduction, and element type are within their known ranges;
//   - a reduction is only specified for reducing collective kinds;
//   - |send_ref| / |recv_ref| satisfy the binding requirements when the
//     operation kind uses them.
//
// Operation kinds that use only one binding (SEND uses only the send binding,
// RECV only the recv binding) leave the other one unused. An unused binding is
// rejected with INVALID_ARGUMENT only when the caller supplied a direct buffer
// for it; a reference with a NULL buffer is treated as omitted because it
// cannot be told apart from an empty reference here.
//
// |channel|, |param|, and |element_count| are not inspected.
iree_status_t iree_hal_command_buffer_collective_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_channel_t* channel, iree_hal_collective_op_t op, uint32_t param,
    iree_hal_buffer_ref_t send_ref, iree_hal_buffer_ref_t recv_ref,
    iree_device_size_t element_count) {
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));

  // Range-check the op fields first: op.kind indexes the table below.
  if (op.kind > IREE_HAL_COLLECTIVE_KIND_MAX_VALUE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unknown collective operation");
  } else if (op.reduction > IREE_HAL_COLLECTIVE_REDUCTION_MAX_VALUE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unknown collective reduction");
  } else if (op.element_type > IREE_HAL_COLLECTIVE_ELEMENT_TYPE_MAX_VALUE) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "unknown collective element type");
  }

  // Per-kind properties used to decide which arguments must be present.
  enum iree_hal_collective_info_bits_t {
    IREE_HAL_COLLECTIVE_IS_REDUCTION = 1u << 0,
    IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING = 1u << 1,
    IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING = 1u << 2,
  };
  static const uint32_t
      info_bits_table[IREE_HAL_COLLECTIVE_KIND_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_KIND_ALL_GATHER] =
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE] =
              IREE_HAL_COLLECTIVE_IS_REDUCTION |
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_ALL_TO_ALL] =
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_BROADCAST] =
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_REDUCE] =
              IREE_HAL_COLLECTIVE_IS_REDUCTION |
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER] =
              IREE_HAL_COLLECTIVE_IS_REDUCTION |
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_SEND] =
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_RECV] =
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
          [IREE_HAL_COLLECTIVE_KIND_SEND_RECV] =
              IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
              IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
      };
  const uint32_t info_bits = info_bits_table[op.kind];

  if (!(info_bits & IREE_HAL_COLLECTIVE_IS_REDUCTION) && op.reduction != 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "reduction operation cannot be specified on a non-reducing collective");
  }

  // TODO(benvanik): add queue cap/usage for COLLECTIVE source/dest?
  if (info_bits & IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING) {
    const iree_hal_buffer_binding_requirements_t send_reqs = {
        .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
        .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE_READ,
        .access = IREE_HAL_MEMORY_ACCESS_READ,
        .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
        .max_byte_offset = send_ref.offset + send_ref.length,
    };
    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements(
        command_buffer, validation_state, send_ref, send_reqs));
  } else if (send_ref.buffer) {
    // Only fail when a buffer was actually supplied for the unused binding;
    // failing unconditionally would reject every recv-only operation.
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "collective operation does not use a send buffer binding");
  }

  if (info_bits & IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING) {
    const iree_hal_buffer_binding_requirements_t recv_reqs = {
        .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
        .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE_WRITE,
        .access = IREE_HAL_MEMORY_ACCESS_WRITE,
        .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
        .max_byte_offset = recv_ref.offset + recv_ref.length,
    };
    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements(
        command_buffer, validation_state, recv_ref, recv_reqs));
  } else if (recv_ref.buffer) {
    // Only fail when a buffer was actually supplied for the unused binding;
    // failing unconditionally would reject every send-only operation.
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "collective operation does not use a recv buffer binding");
  }

  return iree_ok_status();
}
// Validates a push constants update: the command buffer must be recording and
// allow DISPATCH commands, and |values_length| must be a multiple of 4 bytes.
// |pipeline_layout|, |offset|, and |values| are not inspected yet.
iree_status_t iree_hal_command_buffer_push_constants_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
    const void* values, iree_host_size_t values_length) {
  (void)pipeline_layout;
  (void)offset;
  (void)values;

  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));

  // Number of trailing bytes that do not form a whole 4-byte constant.
  const iree_host_size_t trailing_bytes = values_length % 4;
  if (IREE_UNLIKELY(trailing_bytes != 0)) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "invalid alignment %" PRIhsz ", must be 4-byte aligned", values_length);
  }

  // TODO(benvanik): validate offset and value count with layout.
  return iree_ok_status();
}
// Validates a descriptor set push: the command buffer must be recording and
// allow DISPATCH commands, and each of the |binding_count| entries in
// |bindings| is checked against device-visible, dispatch-compatible
// requirements covering offset + length bytes. Failures are annotated with the
// set, binding ordinal, and argument index.
// No usage or access bits are requested yet and |pipeline_layout| is not
// inspected.
iree_status_t iree_hal_command_buffer_push_descriptor_set_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
    iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings) {
  (void)pipeline_layout;

  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));

  // TODO(benvanik): validate set index.
  // TODO(benvanik): use pipeline layout to derive usage and access bits.
  // Shared by all bindings; only the maximum byte offset varies per binding.
  iree_hal_buffer_binding_requirements_t binding_reqs = {
      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
      // .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_...,
      // .access = IREE_HAL_MEMORY_ACCESS_...,
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
  };

  for (iree_host_size_t arg = 0; arg < binding_count; ++arg) {
    // TODO(benvanik): validate binding ordinal against pipeline layout.
    const iree_hal_buffer_ref_t* binding_ref = &bindings[arg];
    binding_reqs.max_byte_offset = binding_ref->offset + binding_ref->length;
    IREE_RETURN_IF_ERROR(
        iree_hal_command_buffer_validate_buffer_requirements(
            command_buffer, validation_state, *binding_ref, binding_reqs),
        "set[%u] binding[%u] (arg[%" PRIhsz "])", set, binding_ref->ordinal,
        arg);
  }

  return iree_ok_status();
}
// Validates a direct dispatch: the command buffer must be recording and allow
// DISPATCH commands, and the dispatch bindings for |entry_point| of
// |executable| must validate. The workgroup counts are not inspected.
iree_status_t iree_hal_command_buffer_dispatch_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_executable_t* executable, int32_t entry_point,
    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
  (void)workgroup_x;
  (void)workgroup_y;
  (void)workgroup_z;

  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));

  return iree_hal_command_buffer_validate_dispatch_bindings(
      command_buffer, validation_state, executable, entry_point);
}
// Validates an indirect dispatch whose workgroup counts are read from
// |workgroups_ref|.
//
// Requires a recording command buffer that allows DISPATCH commands, a
// workgroup count offset aligned to sizeof(uint32_t), and a reference length
// of at least 3 uint32_t values. The reference is then checked as readable,
// device-visible indirect dispatch parameters before the dispatch bindings are
// validated.
iree_status_t iree_hal_command_buffer_dispatch_indirect_validation(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_executable_t* executable, int32_t entry_point,
    iree_hal_buffer_ref_t workgroups_ref) {
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
      command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));

  // The parameters are three consecutive uint32_t workgroup counts.
  const size_t count_alignment = sizeof(uint32_t);
  const size_t min_params_length = 3 * sizeof(uint32_t);

  if ((workgroups_ref.offset % count_alignment) != 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "workgroup count offset does not match the required natural alignment "
        "of uint32_t (offset=%" PRIdsz ", min_byte_alignment=%" PRIhsz ")",
        workgroups_ref.offset, count_alignment);
  }
  if (workgroups_ref.length < min_params_length) {
    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                            "workgroup count buffer does not have the capacity "
                            "to store the required 3 uint32_t values "
                            "(length=%" PRIdsz ", min_length=%" PRIhsz ")",
                            workgroups_ref.length, min_params_length);
  }

  const iree_hal_buffer_binding_requirements_t indirect_params_reqs = {
      .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
      .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS,
      .access = IREE_HAL_MEMORY_ACCESS_READ,
      .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
      .max_byte_offset = workgroups_ref.offset + workgroups_ref.length,
      .min_byte_alignment = count_alignment,
  };
  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements(
      command_buffer, validation_state, workgroups_ref, indirect_params_reqs));

  return iree_hal_command_buffer_validate_dispatch_bindings(
      command_buffer, validation_state, executable, entry_point);
}
// Validates the entries of |binding_table| against the per-slot requirements
// recorded in |validation_state|.
//
// Slots [0, command_buffer->binding_count) are checked in order; the first
// failure is annotated with its slot index and returned. The work is wrapped
// in a trace zone that is ended on every exit path.
iree_status_t iree_hal_command_buffer_binding_table_validation(
    iree_hal_command_buffer_t* command_buffer,
    const iree_hal_command_buffer_validation_state_t* validation_state,
    iree_hal_buffer_binding_table_t binding_table) {
  IREE_TRACE_ZONE_BEGIN(z0);
  IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, command_buffer->binding_count);

  // NOTE: we only validate from [0, binding_count) and don't care if there are
  // extra bindings present.
  // NOTE(review): nothing here checks that the table holds at least
  // binding_count entries - presumably the caller guarantees it; confirm.
  for (uint32_t slot = 0; slot < command_buffer->binding_count; ++slot) {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0,
        iree_hal_command_buffer_validate_binding_requirements(
            command_buffer, validation_state, binding_table.bindings[slot],
            validation_state->binding_requirements[slot]),
        "binding table slot %u", slot);
  }

  IREE_TRACE_ZONE_END(z0);
  return iree_ok_status();
}