// blob: 7f2626318a18cb8ba18f6f2423dcd84be9639ca0 [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/command_buffer.h"
#include <stddef.h>
#include "iree/base/api.h"
#include "iree/hal/command_buffer_validation.h"
#include "iree/hal/detail.h"
#include "iree/hal/device.h"
#include "iree/hal/resource.h"
// Conditionally executes an expression based on whether command buffer
// validation was enabled in the build and the command buffer wants validation.
//
// IF_VALIDATING(command_buffer, expr): when validation is compiled in, runs
// |expr| only if IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED is not set in the
// command buffer's mode. When validation is compiled out the macro expands to
// nothing and |expr| is discarded entirely. Note that the enabled form is a
// bare `if` statement (not do/while(0)), so the `;` written after a call site
// is a separate empty statement.
//
// VALIDATION_STATE(command_buffer): the command buffer's validation_state
// pointer cast to iree_hal_command_buffer_validation_state_t*, or a typed NULL
// when validation is compiled out.
#if IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
#define IF_VALIDATING(command_buffer, expr) \
if (((command_buffer)->mode & IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED) == \
0) { \
expr; \
}
#define VALIDATION_STATE(command_buffer) \
((iree_hal_command_buffer_validation_state_t*)((command_buffer) \
->validation_state))
#else
#define IF_VALIDATING(command_buffer, expr)
#define VALIDATION_STATE(command_buffer) \
((iree_hal_command_buffer_validation_state_t*)NULL)
#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
// Shorthand for IREE_HAL_VTABLE_DISPATCH specialized to the
// iree_hal_command_buffer vtable; every API entry point in this file invokes
// the result as a function to reach the concrete implementation.
#define _VTABLE_DISPATCH(command_buffer, method_name) \
IREE_HAL_VTABLE_DISPATCH(command_buffer, iree_hal_command_buffer, method_name)
//===----------------------------------------------------------------------===//
// String utils
//===----------------------------------------------------------------------===//
// Formats a collective operation as a human-readable identifier.
//
// The result has the form `iree_hal_collective_<kind>_<element_type>` or, for
// the reducing kinds (all_reduce, reduce, reduce_scatter),
// `iree_hal_collective_<kind>_<reduction>_<element_type>`.
//
// |op| supplies the kind, reduction, and element type to format.
// |out_temp| provides the storage the returned view points into; the view is
// therefore only usable for as long as |out_temp| is.
//
// Returns `iree_hal_collective_unknown` when a field needed for formatting
// lies outside of its name table or when formatting produces no output.
IREE_API_EXPORT iree_string_view_t iree_hal_collective_op_format(
    const iree_hal_collective_op_t* op, iree_bitfield_string_temp_t* out_temp) {
  static const iree_string_view_t
      kind_names[IREE_HAL_COLLECTIVE_KIND_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_KIND_ALL_GATHER] = IREE_SVL("all_gather"),
          [IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE] = IREE_SVL("all_reduce"),
          [IREE_HAL_COLLECTIVE_KIND_ALL_TO_ALL] = IREE_SVL("all_to_all"),
          [IREE_HAL_COLLECTIVE_KIND_BROADCAST] = IREE_SVL("broadcast"),
          [IREE_HAL_COLLECTIVE_KIND_REDUCE] = IREE_SVL("reduce"),
          [IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER] =
              IREE_SVL("reduce_scatter"),
          [IREE_HAL_COLLECTIVE_KIND_SEND] = IREE_SVL("send"),
          [IREE_HAL_COLLECTIVE_KIND_RECV] = IREE_SVL("recv"),
          [IREE_HAL_COLLECTIVE_KIND_SEND_RECV] = IREE_SVL("send_recv"),
      };
  static const iree_string_view_t
      reduction_names[IREE_HAL_COLLECTIVE_REDUCTION_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_REDUCTION_SUM] = IREE_SVL("sum"),
          [IREE_HAL_COLLECTIVE_REDUCTION_PRODUCT] = IREE_SVL("product"),
          [IREE_HAL_COLLECTIVE_REDUCTION_MINIMUM] = IREE_SVL("minimum"),
          [IREE_HAL_COLLECTIVE_REDUCTION_MAXIMUM] = IREE_SVL("maximum"),
          [IREE_HAL_COLLECTIVE_REDUCTION_AVERAGE] = IREE_SVL("average"),
      };
  static const iree_string_view_t
      element_type_names[IREE_HAL_COLLECTIVE_ELEMENT_TYPE_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8] = IREE_SVL("si8"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8] = IREE_SVL("ui8"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16] = IREE_SVL("si16"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16] = IREE_SVL("ui16"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32] = IREE_SVL("si32"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32] = IREE_SVL("ui32"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64] = IREE_SVL("si64"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64] = IREE_SVL("ui64"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16] = IREE_SVL("f16"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32] = IREE_SVL("f32"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64] = IREE_SVL("f64"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16] = IREE_SVL("bf16"),
      };

  // Malformed ops are still reported through the assertion macros first.
  IREE_ASSERT_LE((int)op->kind, IREE_HAL_COLLECTIVE_KIND_MAX_VALUE);
  IREE_ASSERT_LE((int)op->reduction, IREE_HAL_COLLECTIVE_REDUCTION_MAX_VALUE);
  IREE_ASSERT_LE((int)op->element_type,
                 IREE_HAL_COLLECTIVE_ELEMENT_TYPE_MAX_VALUE);

  // NOTE(review): the assertions above are presumably compiled out in release
  // builds, so guard the table lookups explicitly instead of reading out of
  // bounds on a bogus op.
  const int kind = (int)op->kind;
  const int element_type = (int)op->element_type;
  if (kind < 0 || kind > IREE_HAL_COLLECTIVE_KIND_MAX_VALUE ||
      element_type < 0 ||
      element_type > IREE_HAL_COLLECTIVE_ELEMENT_TYPE_MAX_VALUE) {
    return IREE_SV("iree_hal_collective_unknown");
  }
  const iree_string_view_t kind_name = kind_names[kind];
  const iree_string_view_t element_type_name =
      element_type_names[element_type];

  int length = 0;
  switch (op->kind) {
    case IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE:
    case IREE_HAL_COLLECTIVE_KIND_REDUCE:
    case IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER: {
      // Only the reducing kinds consult the reduction table.
      const int reduction = (int)op->reduction;
      if (reduction < 0 ||
          reduction > IREE_HAL_COLLECTIVE_REDUCTION_MAX_VALUE) {
        return IREE_SV("iree_hal_collective_unknown");
      }
      const iree_string_view_t reduction_name = reduction_names[reduction];
      length = snprintf(out_temp->buffer, sizeof(out_temp->buffer),
                        "iree_hal_collective_%.*s_%.*s_%.*s",
                        (int)kind_name.size, kind_name.data,
                        (int)reduction_name.size, reduction_name.data,
                        (int)element_type_name.size, element_type_name.data);
      break;
    }
    default:
      length = snprintf(out_temp->buffer, sizeof(out_temp->buffer),
                        "iree_hal_collective_%.*s_%.*s", (int)kind_name.size,
                        kind_name.data, (int)element_type_name.size,
                        element_type_name.data);
      break;
  }
  if (length <= 0) {
    return IREE_SV("iree_hal_collective_unknown");
  }

  // snprintf reports the length the full string would have had; clamp it to
  // what actually fits so the view never extends past the buffer contents.
  const int max_length = (int)sizeof(out_temp->buffer) - 1;
  if (length > max_length) length = max_length;
  return iree_make_string_view(out_temp->buffer, length);
}
// Formats the bits of a command buffer mode bitfield as text.
//
// |value| is the mode bitfield to format and |out_temp| is the scratch storage
// handed to iree_bitfield_format_inline, whose result is returned unchanged.
IREE_API_EXPORT iree_string_view_t
iree_hal_command_buffer_mode_format(iree_hal_command_buffer_mode_t value,
                                    iree_bitfield_string_temp_t* out_temp) {
  // Name table for each known mode bit; the order here is the order passed to
  // the formatter.
  static const iree_bitfield_string_mapping_t mode_names[] = {
      {IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, IREE_SVL("ONE_SHOT")},
      {IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
       IREE_SVL("ALLOW_INLINE_EXECUTION")},
      {IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED, IREE_SVL("UNVALIDATED")},
  };
  return iree_bitfield_format_inline(
      value, sizeof(mode_names) / sizeof(mode_names[0]), mode_names, out_temp);
}
// Formats the bits of a command category bitfield as text.
//
// |value| is the category bitfield to format and |out_temp| is the scratch
// storage handed to iree_bitfield_format_inline, whose result is returned
// unchanged.
IREE_API_EXPORT iree_string_view_t iree_hal_command_category_format(
    iree_hal_command_category_t value, iree_bitfield_string_temp_t* out_temp) {
  // The combined entry is listed ahead of the individual bits it is made of.
  static const iree_bitfield_string_mapping_t category_names[] = {
      // Combined:
      {IREE_HAL_COMMAND_CATEGORY_ANY, IREE_SVL("ANY")},
      // Separate:
      {IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_SVL("TRANSFER")},
      {IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_SVL("DISPATCH")},
  };
  return iree_bitfield_format_inline(
      value, sizeof(category_names) / sizeof(category_names[0]),
      category_names, out_temp);
}
//===----------------------------------------------------------------------===//
// iree_hal_collective_element_t
//===----------------------------------------------------------------------===//
// Returns the size in bytes of a single element of |element_type|.
//
// Unhandled element types trigger an assertion and yield 0.
IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count(
    iree_hal_collective_element_type_t element_type) {
  iree_device_size_t byte_count = 0;
  switch (element_type) {
    // 8-bit elements.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8:
      byte_count = 1;
      break;
    // 16-bit elements.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16:
      byte_count = 2;
      break;
    // 32-bit elements.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32:
      byte_count = 4;
      break;
    // 64-bit elements.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64:
      byte_count = 8;
      break;
    default:
      IREE_ASSERT(false, "unhandled element type for collective op");
      break;
  }
  return byte_count;
}
//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_t
//===----------------------------------------------------------------------===//
// Instantiates the shared HAL retain/release boilerplate for command buffers.
// NOTE(review): presumably this is what defines
// iree_hal_command_buffer_retain/iree_hal_command_buffer_release (the latter
// is called later in this file); confirm against the macro definition.
IREE_HAL_API_RETAIN_RELEASE(command_buffer);
// Returns the number of bytes of validation state a command buffer with the
// given |mode| and |binding_capacity| needs.
//
// The result is 0 when validation is compiled out or when |mode| carries
// IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED. Otherwise it is the size of the
// validation state structure plus one binding requirements entry per binding
// slot.
IREE_API_EXPORT iree_host_size_t iree_hal_command_buffer_validation_state_size(
    iree_hal_command_buffer_mode_t mode, iree_host_size_t binding_capacity) {
#if IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
  if ((mode & IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED) != 0) {
    // The command buffer opted out of validation: no state is needed.
    return 0;
  }
  return sizeof(iree_hal_command_buffer_validation_state_t) +
         binding_capacity * sizeof(iree_hal_buffer_binding_requirements_t);
#else
  return 0;
#endif  // IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
}
// Initializes the common command buffer fields of |command_buffer|.
//
// |device_allocator| is forwarded to the validation initializer.
// |mode|, |command_categories|, |queue_affinity| and |binding_capacity| are
// stored on the command buffer; the recorded binding count starts at 0.
// |validation_state| is storage for validation bookkeeping; it must be
// provided whenever validation is compiled in and |mode| does not carry
// IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED.
// |vtable| is the implementation's dispatch table and is used to initialize
// the embedded resource.
IREE_API_EXPORT void iree_hal_command_buffer_initialize(
    iree_hal_allocator_t* device_allocator, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
    void* validation_state, const iree_hal_command_buffer_vtable_t* vtable,
    iree_hal_command_buffer_t* command_buffer) {
#if IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
  // If validation is compiled in and the command buffer requires validation
  // then check that state was provided.
  IREE_ASSERT(
      iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED) ||
      validation_state);
#else
  // If validation is not compiled in then force the disable bit. This helps
  // prevent issues with dynamic libraries that may be compiled with a different
  // setting, but we don't really support that kind of shady use anyway.
  //
  // The "disable bit" is UNVALIDATED, so it must be set (not cleared):
  // clearing it would advertise a validating command buffer that has no
  // usable validation state.
  mode |= IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED;
#endif  // !IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE

  iree_hal_resource_initialize(vtable, &command_buffer->resource);
  command_buffer->mode = mode;
  command_buffer->allowed_categories = command_categories;
  command_buffer->queue_affinity = queue_affinity;
  command_buffer->binding_capacity = binding_capacity;
  command_buffer->binding_count = 0;
  command_buffer->validation_state = validation_state;

  // Perform initialization validation after we allocate/initialize the concrete
  // implementation.
  IF_VALIDATING(command_buffer, {
    iree_hal_command_buffer_initialize_validation(
        device_allocator, command_buffer, VALIDATION_STATE(command_buffer));
  });
}
// Creates a command buffer on |device| by forwarding to the device's
// create_command_buffer vtable entry.
//
// |out_command_buffer| is cleared to NULL before anything else happens.
// Returns IREE_STATUS_INVALID_ARGUMENT without reaching the device when |mode|
// requests inline execution but is not one-shot, or requests inline execution
// together with a non-zero |binding_capacity|. Otherwise returns the status
// produced by the device.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_create(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(device);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  *out_command_buffer = NULL;

  // Inline command buffers have extra constraints checked up front; the
  // one-shot requirement takes precedence over the binding check.
  if (iree_all_bits_set(mode,
                        IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION) &&
      !iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT)) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "inline command buffers must be one-shot");
  }
  if (iree_all_bits_set(mode,
                        IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION) &&
      binding_capacity > 0) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "inline command buffers cannot have indirect bindings");
  }

  IREE_TRACE_ZONE_BEGIN(z0);
  iree_status_t create_status =
      IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_command_buffer)(
          device, mode, command_categories, queue_affinity, binding_capacity,
          out_command_buffer);
  IREE_TRACE_ZONE_END(z0);
  return create_status;
}
// Returns the mode bitfield stored on |command_buffer|.
IREE_API_EXPORT iree_hal_command_buffer_mode_t
iree_hal_command_buffer_mode(const iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  const iree_hal_command_buffer_mode_t stored_mode = command_buffer->mode;
  return stored_mode;
}
// Returns the command categories stored on |command_buffer|.
IREE_API_EXPORT iree_hal_command_category_t
iree_hal_command_buffer_allowed_categories(
    const iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  const iree_hal_command_category_t stored_categories =
      command_buffer->allowed_categories;
  return stored_categories;
}
// Begins recording on |command_buffer|.
//
// When validation is active the begin-validation hook runs first and any
// error it produces is handed to IREE_RETURN_AND_END_ZONE_IF_ERROR (by its
// name: ends the trace zone and returns; the macro is defined elsewhere).
// Otherwise the call is forwarded to the implementation's `begin` vtable entry
// and that status is returned.
IREE_API_EXPORT iree_status_t
iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_begin_validation(
command_buffer, VALIDATION_STATE(command_buffer)));
});
iree_status_t status =
_VTABLE_DISPATCH(command_buffer, begin)(command_buffer);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Ends recording on |command_buffer|.
//
// When validation is active the end-validation hook runs first and any error
// it produces is handed to IREE_RETURN_AND_END_ZONE_IF_ERROR. Otherwise the
// call is forwarded to the implementation's `end` vtable entry and that status
// is returned.
IREE_API_EXPORT iree_status_t
iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_end_validation(
command_buffer, VALIDATION_STATE(command_buffer)));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, end)(command_buffer);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Opens a debug group on |command_buffer| described by |label|, |label_color|
// and |location|.
//
// The validation hook (when active) is invoked for its side effects only: this
// function returns void, so nothing is propagated to the caller. The
// implementation's `begin_debug_group` vtable entry is always called.
IREE_API_EXPORT void iree_hal_command_buffer_begin_debug_group(
iree_hal_command_buffer_t* command_buffer, iree_string_view_t label,
iree_hal_label_color_t label_color,
const iree_hal_label_location_t* location) {
IREE_ASSERT_ARGUMENT(command_buffer);
IF_VALIDATING(command_buffer,
iree_hal_command_buffer_begin_debug_group_validation(
command_buffer, VALIDATION_STATE(command_buffer), label,
label_color, location));
_VTABLE_DISPATCH(command_buffer, begin_debug_group)
(command_buffer, label, label_color, location);
}
// Closes a debug group on |command_buffer|.
//
// The validation hook (when active) is invoked for its side effects only: this
// function returns void, so nothing is propagated to the caller. The
// implementation's `end_debug_group` vtable entry is always called.
IREE_API_EXPORT void iree_hal_command_buffer_end_debug_group(
iree_hal_command_buffer_t* command_buffer) {
IREE_ASSERT_ARGUMENT(command_buffer);
IF_VALIDATING(command_buffer,
iree_hal_command_buffer_end_debug_group_validation(
command_buffer, VALIDATION_STATE(command_buffer)));
_VTABLE_DISPATCH(command_buffer, end_debug_group)
(command_buffer);
}
// Records an execution barrier on |command_buffer|.
//
// The stage masks, |flags| and the memory/buffer barrier lists are passed
// unchanged to the validation hook (when active) and then to the
// implementation's `execution_barrier` vtable entry. A validation error is
// handed to IREE_RETURN_AND_END_ZONE_IF_ERROR; otherwise the implementation's
// status is returned.
//
// Only |command_buffer| is asserted here; the barrier pointers are not checked
// against their counts in this function.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_execution_barrier(
iree_hal_command_buffer_t* command_buffer,
iree_hal_execution_stage_t source_stage_mask,
iree_hal_execution_stage_t target_stage_mask,
iree_hal_execution_barrier_flags_t flags,
iree_host_size_t memory_barrier_count,
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_hal_command_buffer_execution_barrier_validation(
command_buffer, VALIDATION_STATE(command_buffer), source_stage_mask,
target_stage_mask, flags, memory_barrier_count, memory_barriers,
buffer_barrier_count, buffer_barriers));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, execution_barrier)(
command_buffer, source_stage_mask, target_stage_mask, flags,
memory_barrier_count, memory_barriers, buffer_barrier_count,
buffer_barriers);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a signal of |event| with |source_stage_mask| on |command_buffer|.
//
// Both |command_buffer| and |event| must be non-NULL. The validation hook
// (when active) runs first and its error is handed to
// IREE_RETURN_AND_END_ZONE_IF_ERROR; otherwise the implementation's
// `signal_event` vtable entry is called and its status returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_signal_event(
iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(event);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_signal_event_validation(
command_buffer, VALIDATION_STATE(command_buffer), event,
source_stage_mask));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, signal_event)(
command_buffer, event, source_stage_mask);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a reset of |event| with |source_stage_mask| on |command_buffer|.
//
// Both |command_buffer| and |event| must be non-NULL. The validation hook
// (when active) runs first and its error is handed to
// IREE_RETURN_AND_END_ZONE_IF_ERROR; otherwise the implementation's
// `reset_event` vtable entry is called and its status returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_reset_event(
iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(event);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_reset_event_validation(
command_buffer, VALIDATION_STATE(command_buffer), event,
source_stage_mask));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, reset_event)(
command_buffer, event, source_stage_mask);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a wait on |event_count| events from |events| on |command_buffer|,
// along with the given stage masks and memory/buffer barriers.
//
// Each list pointer may only be NULL when its corresponding count is zero;
// this is asserted below. The validation hook (when active) runs first and its
// error is handed to IREE_RETURN_AND_END_ZONE_IF_ERROR; otherwise the
// implementation's `wait_events` vtable entry is called and its status
// returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_wait_events(
iree_hal_command_buffer_t* command_buffer, iree_host_size_t event_count,
const iree_hal_event_t** events,
iree_hal_execution_stage_t source_stage_mask,
iree_hal_execution_stage_t target_stage_mask,
iree_host_size_t memory_barrier_count,
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
IREE_ASSERT_ARGUMENT(command_buffer);
// A list is only required when its count is non-zero.
IREE_ASSERT_ARGUMENT(!event_count || events);
IREE_ASSERT_ARGUMENT(!memory_barrier_count || memory_barriers);
IREE_ASSERT_ARGUMENT(!buffer_barrier_count || buffer_barriers);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0,
iree_hal_command_buffer_wait_events_validation(
command_buffer, VALIDATION_STATE(command_buffer), event_count,
events, source_stage_mask, target_stage_mask, memory_barrier_count,
memory_barriers, buffer_barrier_count, buffer_barriers));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, wait_events)(
command_buffer, event_count, events, source_stage_mask, target_stage_mask,
memory_barrier_count, memory_barriers, buffer_barrier_count,
buffer_barriers);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a memory advice command for |buffer_ref| on |command_buffer|.
//
// |flags|, |arg0| and |arg1| are passed through untouched; their meaning is
// defined by the advise flags and is not interpreted in this function. The
// validation hook (when active) runs first and its error is handed to
// IREE_RETURN_AND_END_ZONE_IF_ERROR; otherwise the implementation's
// `advise_buffer` vtable entry is called and its status returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_advise_buffer(
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t buffer_ref,
iree_hal_memory_advise_flags_t flags, uint64_t arg0, uint64_t arg1) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_advise_buffer_validation(
command_buffer, VALIDATION_STATE(command_buffer), buffer_ref,
flags, arg0, arg1));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, advise_buffer)(
command_buffer, buffer_ref, flags, arg0, arg1);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a fill of |target_ref| with |pattern| (of |pattern_length| bytes) on
// |command_buffer|.
//
// A zero-length |target_ref| returns OK immediately, before the trace zone is
// opened and before any validation or implementation call. Otherwise the
// validation hook (when active) runs first and its error is handed to
// IREE_RETURN_AND_END_ZONE_IF_ERROR; then the implementation's `fill_buffer`
// vtable entry is called and its status returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_fill_buffer(
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t target_ref,
const void* pattern, iree_host_size_t pattern_length,
iree_hal_fill_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
if (target_ref.length == 0) {
// No-op fill. All other validation is skipped.
return iree_ok_status();
}
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_fill_buffer_validation(
command_buffer, VALIDATION_STATE(command_buffer), target_ref,
pattern, pattern_length, flags));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, fill_buffer)(
command_buffer, target_ref, pattern, pattern_length, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records an update of |target_ref| from host memory at |source_buffer| +
// |source_offset| on |command_buffer|.
//
// |source_buffer| is asserted non-NULL before the zero-length check, so it is
// required even for no-op updates. A zero-length |target_ref| returns OK
// immediately, before the trace zone is opened and before any validation or
// implementation call. Otherwise the validation hook (when active) runs first
// and its error is handed to IREE_RETURN_AND_END_ZONE_IF_ERROR; then the
// implementation's `update_buffer` vtable entry is called and its status
// returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_update_buffer(
iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
iree_host_size_t source_offset, iree_hal_buffer_ref_t target_ref,
iree_hal_update_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(source_buffer);
if (target_ref.length == 0) {
// No-op update. All other validation is skipped.
return iree_ok_status();
}
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_update_buffer_validation(
command_buffer, VALIDATION_STATE(command_buffer), source_buffer,
source_offset, target_ref, flags));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, update_buffer)(
command_buffer, source_buffer, source_offset, target_ref, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a copy from |source_ref| to |target_ref| on |command_buffer|.
//
// Only the target length decides whether the copy is a no-op: a zero-length
// |target_ref| returns OK immediately, before the trace zone is opened and
// before any validation or implementation call. Otherwise the validation hook
// (when active) runs first and its error is handed to
// IREE_RETURN_AND_END_ZONE_IF_ERROR; then the implementation's `copy_buffer`
// vtable entry is called and its status returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_copy_buffer(
iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_ref_t source_ref,
iree_hal_buffer_ref_t target_ref, iree_hal_copy_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
if (target_ref.length == 0) {
// No-op copy. All other validation is skipped.
return iree_ok_status();
}
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_copy_buffer_validation(
command_buffer, VALIDATION_STATE(command_buffer), source_ref,
target_ref, flags));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, copy_buffer)(
command_buffer, source_ref, target_ref, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records the collective operation |op| over |channel| on |command_buffer|.
//
// |param|, |send_ref|, |recv_ref| and |element_count| are passed through
// untouched to validation and to the implementation. Both |command_buffer| and
// |channel| must be non-NULL. The validation hook (when active) runs first and
// its error is handed to IREE_RETURN_AND_END_ZONE_IF_ERROR; then the
// implementation's `collective` vtable entry is called and its status
// returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_collective(
iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,
iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref,
iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(channel);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_collective_validation(
command_buffer, VALIDATION_STATE(command_buffer), channel, op,
param, send_ref, recv_ref, element_count));
});
#if IREE_HAL_VERBOSE_TRACING_ENABLE
// With verbose tracing the formatted op name is attached to the trace zone.
IREE_TRACE({
iree_bitfield_string_temp_t string_temp;
iree_string_view_t collective_str =
iree_hal_collective_op_format(&op, &string_temp);
IREE_TRACE_ZONE_APPEND_TEXT(z0, collective_str.data, collective_str.size);
});
#endif // IREE_HAL_VERBOSE_TRACING_ENABLE
iree_status_t status = _VTABLE_DISPATCH(command_buffer, collective)(
command_buffer, channel, op, param, send_ref, recv_ref, element_count);
IREE_TRACE_ZONE_END(z0);
return status;
}
// Records a dispatch of |entry_point| in |executable| on |command_buffer|.
//
// |workgroup_count| holds the X/Y/Z workgroup counts and must be non-NULL.
// If all three counts are zero the dispatch is treated as a no-op: OK is
// returned before any tracing, validation or implementation call.
// |constants|, |bindings| and |flags| are passed through untouched.
//
// The validation hook (when active) runs before the implementation and its
// error is handed to IREE_RETURN_AND_END_ZONE_IF_ERROR; otherwise the status
// of the implementation's `dispatch` vtable entry is returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(executable);
  // Dereferenced unconditionally just below.
  IREE_ASSERT_ARGUMENT(workgroup_count);

  if ((workgroup_count[0] | workgroup_count[1] | workgroup_count[2]) == 0) {
    // No-op dispatch. All implementations are expected to do this but we ensure
    // it happens here to avoid the overhead of going all the way down into the
    // device layer for something we know should have no (intentional)
    // side-effects. Note that this does mean that validation is skipped and
    // the executable/etc could be bogus but that's fine.
    return iree_ok_status();
  }

  IREE_TRACE_ZONE_BEGIN(z0);
#if IREE_HAL_VERBOSE_TRACING_ENABLE
  // TODO(benvanik): add a tracing.h helper that does the snprintf directly
  // into a tracy_malloc buffer so that we can avoid the memcpy. Today this can
  // take 4-5us which adds too much overhead when trying to get accurate timings
  // with tracing enabled, so this is only compiled in when verbose tracing is
  // enabled. Ideally we'd slice off a much larger allocation and then
  // suballocate from that ourselves so that we could avoid the tracy_malloc
  // overheads per-dispatch.
  IREE_TRACE({
    // Worst case: three 10-digit uint32_t values + two 'x' separators + NUL
    // is 33 bytes; a 32-byte buffer would drop the last digit.
    char xyz_string[3 * 10 + 2 + 1];
    int xyz_string_length =
        snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u",
                 workgroup_count[0], workgroup_count[1], workgroup_count[2]);
    IREE_TRACE_ZONE_APPEND_TEXT(z0, xyz_string, xyz_string_length);
  });
#endif  // IREE_HAL_VERBOSE_TRACING_ENABLE

  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_dispatch_validation(
                command_buffer, VALIDATION_STATE(command_buffer), executable,
                entry_point, workgroup_count, constants, bindings, flags));
  });

  iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)(
      command_buffer, executable, entry_point, workgroup_count, constants,
      bindings, flags);
  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Records a dispatch of |entry_point| in |executable| on |command_buffer|
// whose workgroup counts are supplied through |workgroups_ref| instead of
// being passed directly.
// NOTE(review): the layout expected inside |workgroups_ref| is not visible in
// this function; confirm against the API header.
//
// Unlike the direct dispatch there is no no-op shortcut here. The validation
// hook (when active) runs first and its error is handed to
// IREE_RETURN_AND_END_ZONE_IF_ERROR; then the implementation's
// `dispatch_indirect` vtable entry is called and its status returned.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
iree_hal_command_buffer_t* command_buffer,
iree_hal_executable_t* executable, int32_t entry_point,
iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
IREE_ASSERT_ARGUMENT(command_buffer);
IREE_ASSERT_ARGUMENT(executable);
IREE_TRACE_ZONE_BEGIN(z0);
IF_VALIDATING(command_buffer, {
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_command_buffer_dispatch_indirect_validation(
command_buffer, VALIDATION_STATE(command_buffer), executable,
entry_point, workgroups_ref, constants, bindings, flags));
});
iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch_indirect)(
command_buffer, executable, entry_point, workgroups_ref, constants,
bindings, flags);
IREE_TRACE_ZONE_END(z0);
return status;
}
//===----------------------------------------------------------------------===//
// Validation support
//===----------------------------------------------------------------------===//
// Checks that |command_buffer| may be submitted with |binding_table|.
//
// Returns OK when the command buffer uses no bindings (any provided table is
// ignored), IREE_STATUS_INVALID_ARGUMENT when bindings are used but the table
// is empty, and IREE_STATUS_OUT_OF_RANGE when the table has fewer entries than
// the command buffer uses. Errors from the validation hooks (when validation
// is active) are propagated via IREE_RETURN_IF_ERROR.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_submission(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_buffer_binding_table_t binding_table) {
  IREE_ASSERT_ARGUMENT(command_buffer);

  // First make sure recording completed properly.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_submission_validation(
        command_buffer, VALIDATION_STATE(command_buffer)));
  });

  // A binding table only matters if the command buffer actually recorded
  // indirect bindings; the count used may be below the allocated capacity.
  if (command_buffer->binding_count == 0) {
    return iree_ok_status();
  }

  // From here on at least as many table entries as used bindings are needed.
  if (binding_table.count == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "indirect command buffer requires at least %u "
                            "bindings but no binding table was provided",
                            command_buffer->binding_count);
  }
  if (binding_table.count < command_buffer->binding_count) {
    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                            "indirect command buffer requires at least %u "
                            "bindings but only %" PRIhsz " were provided ",
                            command_buffer->binding_count, binding_table.count);
  }

  // Checking each table entry against the consuming commands is
  // O(binding_count), so it only happens for validating command buffers.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_IF_ERROR(iree_hal_command_buffer_binding_table_validation(
        command_buffer, VALIDATION_STATE(command_buffer), binding_table));
  });

  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// Utilities for command buffer creation
//===----------------------------------------------------------------------===//
// Creates a transfer-only command buffer on |device| and records each of the
// |transfer_count| entries in |transfer_commands| into it.
//
// The command buffer is created with IREE_HAL_COMMAND_CATEGORY_TRANSFER and no
// binding capacity, using the provided |mode| and |queue_affinity|. Fill,
// update and copy commands are recorded in order; recording stops at the first
// failure and an unknown command type yields IREE_STATUS_INVALID_ARGUMENT.
// iree_hal_command_buffer_end is always called after a successful create and
// its status is joined with any earlier recording error.
//
// On success |out_command_buffer| receives the recorded command buffer. On
// failure the command buffer (if created) is released and
// |out_command_buffer| is left as NULL.
IREE_API_EXPORT iree_status_t iree_hal_create_transfer_command_buffer(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t transfer_count,
    const iree_hal_transfer_command_t* transfer_commands,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(device);
  IREE_ASSERT_ARGUMENT(!transfer_count || transfer_commands);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  // Never leave the output uninitialized on failure paths.
  *out_command_buffer = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_command_buffer_t* command_buffer = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_command_buffer_create(
              device, mode, IREE_HAL_COMMAND_CATEGORY_TRANSFER, queue_affinity,
              /*binding_capacity=*/0, &command_buffer));

  iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
  if (iree_status_is_ok(status)) {
    for (iree_host_size_t i = 0; i < transfer_count; ++i) {
      const iree_hal_transfer_command_t* transfer_command =
          &transfer_commands[i];
      switch (transfer_command->type) {
        case IREE_HAL_TRANSFER_COMMAND_TYPE_FILL:
          status = iree_hal_command_buffer_fill_buffer(
              command_buffer,
              iree_hal_make_buffer_ref(transfer_command->fill.target_buffer,
                                       transfer_command->fill.target_offset,
                                       transfer_command->fill.length),
              transfer_command->fill.pattern,
              transfer_command->fill.pattern_length, IREE_HAL_FILL_FLAG_NONE);
          break;
        case IREE_HAL_TRANSFER_COMMAND_TYPE_UPDATE:
          status = iree_hal_command_buffer_update_buffer(
              command_buffer, transfer_command->update.source_buffer,
              transfer_command->update.source_offset,
              iree_hal_make_buffer_ref(transfer_command->update.target_buffer,
                                       transfer_command->update.target_offset,
                                       transfer_command->update.length),
              IREE_HAL_UPDATE_FLAG_NONE);
          break;
        case IREE_HAL_TRANSFER_COMMAND_TYPE_COPY:
          // Source and target ranges share the same length.
          status = iree_hal_command_buffer_copy_buffer(
              command_buffer,
              iree_hal_make_buffer_ref(transfer_command->copy.source_buffer,
                                       transfer_command->copy.source_offset,
                                       transfer_command->copy.length),
              iree_hal_make_buffer_ref(transfer_command->copy.target_buffer,
                                       transfer_command->copy.target_offset,
                                       transfer_command->copy.length),
              IREE_HAL_COPY_FLAG_NONE);
          break;
        default:
          status =
              iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                               "unknown transfer_commands[%" PRIhsz "] type %d",
                               i, (int)transfer_command->type);
          break;
      }
      if (!iree_status_is_ok(status)) break;
    }
  }

  // Always close out recording, keeping whichever error(s) occurred.
  status =
      iree_status_join(status, iree_hal_command_buffer_end(command_buffer));

  if (iree_status_is_ok(status)) {
    *out_command_buffer = command_buffer;
  } else {
    iree_hal_command_buffer_release(command_buffer);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}