blob: e2bba37914070af2c40385e212f3e62f80e53e8e [file] [log] [blame]
// Copyright 2020 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/command_buffer.h"
#include <stddef.h>
#include "iree/base/api.h"
#include "iree/hal/command_buffer_validation.h"
#include "iree/hal/detail.h"
#include "iree/hal/device.h"
#include "iree/hal/resource.h"
// Conditionally executes an expression based on whether command buffer
// validation was enabled in the build and the command buffer wants validation.
//
// IF_VALIDATING(command_buffer, expr):
//   With IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE set this expands to an
//   `if` statement that runs |expr| only when the command buffer mode does not
//   have IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED set. Without it the macro
//   expands to nothing and |expr| is not compiled at all.
//   NOTE: the expansion is a plain `if (...) { ... }` (no do/while(0) wrapper)
//   so it should only be used as a standalone statement and never directly
//   before an `else`.
//
// VALIDATION_STATE(command_buffer):
//   Yields the address of the `validation` member of the command buffer when
//   validation is compiled in and a typed NULL pointer otherwise.
#if IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
#define IF_VALIDATING(command_buffer, expr) \
if (((command_buffer)->mode & IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED) == \
0) { \
expr; \
}
#define VALIDATION_STATE(command_buffer) (&(command_buffer)->validation)
#else
#define IF_VALIDATING(command_buffer, expr)
#define VALIDATION_STATE(command_buffer) \
((iree_hal_command_buffer_validation_state_t*)NULL)
#endif // IREE_HAL_COMMAND_BUFFER_VALIDATION_ENABLE
// Shorthand for IREE_HAL_VTABLE_DISPATCH specialized to the
// iree_hal_command_buffer type. Callers in this file invoke the result like a
// function, e.g. `_VTABLE_DISPATCH(command_buffer, begin)(command_buffer)`.
#define _VTABLE_DISPATCH(command_buffer, method_name) \
IREE_HAL_VTABLE_DISPATCH(command_buffer, iree_hal_command_buffer, method_name)
//===----------------------------------------------------------------------===//
// String utils
//===----------------------------------------------------------------------===//
// Formats |op| as a string of the form
// `iree_hal_collective_<kind>[_<reduction>]_<element_type>` into the storage
// provided by |out_temp| and returns a view of it.
//
// The reduction name is only included for the reducing kinds (all_reduce,
// reduce, reduce_scatter). The returned view references |out_temp| and is only
// valid while that storage is live. If formatting fails the constant string
// `iree_hal_collective_unknown` is returned instead.
IREE_API_EXPORT iree_string_view_t iree_hal_collective_op_format(
    const iree_hal_collective_op_t* op, iree_bitfield_string_temp_t* out_temp) {
  IREE_ASSERT_ARGUMENT(op);
  IREE_ASSERT_ARGUMENT(out_temp);

  // Lookup tables indexed directly by the enum values.
  static const iree_string_view_t
      kind_names[IREE_HAL_COLLECTIVE_KIND_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_KIND_ALL_GATHER] = IREE_SVL("all_gather"),
          [IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE] = IREE_SVL("all_reduce"),
          [IREE_HAL_COLLECTIVE_KIND_ALL_TO_ALL] = IREE_SVL("all_to_all"),
          [IREE_HAL_COLLECTIVE_KIND_BROADCAST] = IREE_SVL("broadcast"),
          [IREE_HAL_COLLECTIVE_KIND_REDUCE] = IREE_SVL("reduce"),
          [IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER] =
              IREE_SVL("reduce_scatter"),
          [IREE_HAL_COLLECTIVE_KIND_SEND] = IREE_SVL("send"),
          [IREE_HAL_COLLECTIVE_KIND_RECV] = IREE_SVL("recv"),
          [IREE_HAL_COLLECTIVE_KIND_SEND_RECV] = IREE_SVL("send_recv"),
      };
  static const iree_string_view_t
      reduction_names[IREE_HAL_COLLECTIVE_REDUCTION_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_REDUCTION_SUM] = IREE_SVL("sum"),
          [IREE_HAL_COLLECTIVE_REDUCTION_PRODUCT] = IREE_SVL("product"),
          [IREE_HAL_COLLECTIVE_REDUCTION_MINIMUM] = IREE_SVL("minimum"),
          [IREE_HAL_COLLECTIVE_REDUCTION_MAXIMUM] = IREE_SVL("maximum"),
          [IREE_HAL_COLLECTIVE_REDUCTION_AVERAGE] = IREE_SVL("average"),
      };
  static const iree_string_view_t
      element_type_names[IREE_HAL_COLLECTIVE_ELEMENT_TYPE_MAX_VALUE + 1] = {
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8] = IREE_SVL("si8"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8] = IREE_SVL("ui8"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16] = IREE_SVL("si16"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16] = IREE_SVL("ui16"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32] = IREE_SVL("si32"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32] = IREE_SVL("ui32"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64] = IREE_SVL("si64"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64] = IREE_SVL("ui64"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16] = IREE_SVL("f16"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32] = IREE_SVL("f32"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64] = IREE_SVL("f64"),
          [IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16] = IREE_SVL("bf16"),
      };

  // Guard the table lookups below against out-of-range enum values.
  IREE_ASSERT_LE((int)op->kind, IREE_HAL_COLLECTIVE_KIND_MAX_VALUE);
  IREE_ASSERT_LE((int)op->reduction, IREE_HAL_COLLECTIVE_REDUCTION_MAX_VALUE);
  IREE_ASSERT_LE((int)op->element_type,
                 IREE_HAL_COLLECTIVE_ELEMENT_TYPE_MAX_VALUE);

  const iree_string_view_t kind_name = kind_names[(int)op->kind];
  const iree_string_view_t element_type_name =
      element_type_names[(int)op->element_type];

  int length = 0;
  switch (op->kind) {
    default:
      length = snprintf(out_temp->buffer, sizeof(out_temp->buffer),
                        "iree_hal_collective_%.*s_%.*s", (int)kind_name.size,
                        kind_name.data, (int)element_type_name.size,
                        element_type_name.data);
      break;
    case IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE:
    case IREE_HAL_COLLECTIVE_KIND_REDUCE:
    case IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER: {
      // Only reducing collectives carry a meaningful reduction operation.
      const iree_string_view_t reduction_name =
          reduction_names[(int)op->reduction];
      length = snprintf(out_temp->buffer, sizeof(out_temp->buffer),
                        "iree_hal_collective_%.*s_%.*s_%.*s",
                        (int)kind_name.size, kind_name.data,
                        (int)reduction_name.size, reduction_name.data,
                        (int)element_type_name.size, element_type_name.data);
      break;
    }
  }

  if (length <= 0) {
    return IREE_SV("iree_hal_collective_unknown");
  }
  // snprintf reports the length the complete string would have had; when the
  // output was truncated only sizeof(buffer) - 1 characters were stored so the
  // view must be clamped to avoid referencing bytes past the buffer.
  if ((size_t)length >= sizeof(out_temp->buffer)) {
    length = (int)sizeof(out_temp->buffer) - 1;
  }
  return iree_make_string_view(out_temp->buffer, length);
}
// Formats the command buffer mode bitfield |value| as a string using
// |out_temp| as backing storage for the returned view.
IREE_API_EXPORT iree_string_view_t
iree_hal_command_buffer_mode_format(iree_hal_command_buffer_mode_t value,
                                    iree_bitfield_string_temp_t* out_temp) {
  // Pairs each mode bit with the name used when printing it.
  static const iree_bitfield_string_mapping_t kModeNames[] = {
      {IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, IREE_SVL("ONE_SHOT")},
      {IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
       IREE_SVL("ALLOW_INLINE_EXECUTION")},
      {IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED, IREE_SVL("UNVALIDATED")},
  };
  const iree_host_size_t mode_name_count = IREE_ARRAYSIZE(kModeNames);
  return iree_bitfield_format_inline(value, mode_name_count, kModeNames,
                                     out_temp);
}
// Formats the command category bitfield |value| as a string using |out_temp|
// as backing storage for the returned view.
IREE_API_EXPORT iree_string_view_t iree_hal_command_category_format(
    iree_hal_command_category_t value, iree_bitfield_string_temp_t* out_temp) {
  // The combined ANY entry is listed ahead of the individual category bits.
  static const iree_bitfield_string_mapping_t kCategoryNames[] = {
      // Combined:
      {IREE_HAL_COMMAND_CATEGORY_ANY, IREE_SVL("ANY")},
      // Separate:
      {IREE_HAL_COMMAND_CATEGORY_TRANSFER, IREE_SVL("TRANSFER")},
      {IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_SVL("DISPATCH")},
  };
  const iree_host_size_t category_name_count = IREE_ARRAYSIZE(kCategoryNames);
  return iree_bitfield_format_inline(value, category_name_count,
                                     kCategoryNames, out_temp);
}
//===----------------------------------------------------------------------===//
// iree_hal_collective_element_t
//===----------------------------------------------------------------------===//
// Returns the size in bytes of a single element of |element_type|.
// Unrecognized element types trip an assertion and yield 0.
IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count(
    iree_hal_collective_element_type_t element_type) {
  iree_device_size_t byte_count = 0;
  switch (element_type) {
    // 8-bit types.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8:
      byte_count = 1;
      break;
    // 16-bit types.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16:
      byte_count = 2;
      break;
    // 32-bit types.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32:
      byte_count = 4;
      break;
    // 64-bit types.
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64:
    case IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64:
      byte_count = 8;
      break;
    default:
      IREE_ASSERT(false, "unhandled element type for collective op");
      break;
  }
  return byte_count;
}
//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_t
//===----------------------------------------------------------------------===//
// Expands the shared retain/release boilerplate macro for the command_buffer
// type. Presumably this is what defines iree_hal_command_buffer_retain and
// iree_hal_command_buffer_release (the latter is called later in this file);
// confirm against the macro definition.
IREE_HAL_API_RETAIN_RELEASE(command_buffer);
// Initializes the common base fields of |command_buffer| with the given
// creation parameters and |vtable| and then, when validation is active for the
// command buffer, initializes its validation state.
IREE_API_EXPORT void iree_hal_command_buffer_initialize(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
    const iree_hal_command_buffer_vtable_t* vtable,
    iree_hal_command_buffer_t* command_buffer) {
  // The resource header carries the vtable used for all later dispatches.
  iree_hal_resource_initialize(vtable, &command_buffer->resource);

  // Record the creation parameters; |mode| must be stored before the
  // validation check below as IF_VALIDATING reads it.
  command_buffer->binding_capacity = binding_capacity;
  command_buffer->queue_affinity = queue_affinity;
  command_buffer->allowed_categories = command_categories;
  command_buffer->mode = mode;

  // Validation setup happens last, once the base fields above are populated.
  IF_VALIDATING(command_buffer, {
    iree_hal_command_buffer_initialize_validation(
        device, command_buffer, VALIDATION_STATE(command_buffer));
  });
}
// Creates a command buffer on |device| by dispatching to the device's
// create_command_buffer implementation.
//
// |out_command_buffer| is reset to NULL before anything else happens. Returns
// IREE_STATUS_INVALID_ARGUMENT if |mode| requests inline execution without
// also being one-shot.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_create(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_command_category_t command_categories,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t binding_capacity,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(device);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  *out_command_buffer = NULL;

  // Inline execution is only permitted for one-shot command buffers.
  const bool allows_inline = iree_all_bits_set(
      mode, IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION);
  const bool is_one_shot =
      iree_all_bits_set(mode, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT);
  if (allows_inline && !is_one_shot) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "inline command buffers must be one-shot");
  }

  IREE_TRACE_ZONE_BEGIN(z0);
  iree_status_t create_status =
      IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_command_buffer)(
          device, mode, command_categories, queue_affinity, binding_capacity,
          out_command_buffer);
  IREE_TRACE_ZONE_END(z0);
  return create_status;
}
// Returns the mode bits stored on |command_buffer|.
IREE_API_EXPORT iree_hal_command_buffer_mode_t
iree_hal_command_buffer_mode(const iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  const iree_hal_command_buffer_mode_t recorded_mode = command_buffer->mode;
  return recorded_mode;
}
// Returns the command categories stored on |command_buffer|.
IREE_API_EXPORT iree_hal_command_category_t
iree_hal_command_buffer_allowed_categories(
    const iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  const iree_hal_command_category_t categories =
      command_buffer->allowed_categories;
  return categories;
}
// Begins recording on |command_buffer|: runs the validation layer (when
// active) and then dispatches to the implementation's `begin`.
IREE_API_EXPORT iree_status_t
iree_hal_command_buffer_begin(iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns before the
  // implementation is reached.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_begin_validation(
                command_buffer, VALIDATION_STATE(command_buffer)));
  });

  iree_status_t begin_status =
      _VTABLE_DISPATCH(command_buffer, begin)(command_buffer);
  IREE_TRACE_ZONE_END(z0);
  return begin_status;
}
// Ends recording on |command_buffer|: runs the validation layer (when active)
// and then dispatches to the implementation's `end`.
IREE_API_EXPORT iree_status_t
iree_hal_command_buffer_end(iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns before the
  // implementation is reached.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_end_validation(
                command_buffer, VALIDATION_STATE(command_buffer)));
  });

  iree_status_t end_status =
      _VTABLE_DISPATCH(command_buffer, end)(command_buffer);
  IREE_TRACE_ZONE_END(z0);
  return end_status;
}
// Opens a debug group on |command_buffer| with the given |label|,
// |label_color|, and |location|. The validation hook (when active) runs before
// the implementation's `begin_debug_group`; neither reports a status.
IREE_API_EXPORT void iree_hal_command_buffer_begin_debug_group(
    iree_hal_command_buffer_t* command_buffer, iree_string_view_t label,
    iree_hal_label_color_t label_color,
    const iree_hal_label_location_t* location) {
  IREE_ASSERT_ARGUMENT(command_buffer);

  IF_VALIDATING(command_buffer, {
    iree_hal_command_buffer_begin_debug_group_validation(
        command_buffer, VALIDATION_STATE(command_buffer), label, label_color,
        location);
  });

  _VTABLE_DISPATCH(command_buffer, begin_debug_group)(command_buffer, label,
                                                      label_color, location);
}
// Closes a debug group on |command_buffer|. The validation hook (when active)
// runs before the implementation's `end_debug_group`; neither reports a
// status.
IREE_API_EXPORT void iree_hal_command_buffer_end_debug_group(
    iree_hal_command_buffer_t* command_buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);

  IF_VALIDATING(command_buffer, {
    iree_hal_command_buffer_end_debug_group_validation(
        command_buffer, VALIDATION_STATE(command_buffer));
  });

  _VTABLE_DISPATCH(command_buffer, end_debug_group)(command_buffer);
}
// Records an execution barrier on |command_buffer|. All arguments are passed
// unchanged first to the validation layer (when active) and then to the
// implementation's `execution_barrier`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_execution_barrier(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_hal_execution_barrier_flags_t flags,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_execution_barrier_validation(
                command_buffer, VALIDATION_STATE(command_buffer),
                source_stage_mask, target_stage_mask, flags,
                memory_barrier_count, memory_barriers, buffer_barrier_count,
                buffer_barriers));
  });

  iree_status_t barrier_status =
      _VTABLE_DISPATCH(command_buffer, execution_barrier)(
          command_buffer, source_stage_mask, target_stage_mask, flags,
          memory_barrier_count, memory_barriers, buffer_barrier_count,
          buffer_barriers);
  IREE_TRACE_ZONE_END(z0);
  return barrier_status;
}
// Records a signal of |event| on |command_buffer| with |source_stage_mask|.
// Validation (when active) runs before the implementation's `signal_event`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_signal_event(
    iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(event);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_signal_event_validation(
                command_buffer, VALIDATION_STATE(command_buffer), event,
                source_stage_mask));
  });

  iree_status_t signal_status = _VTABLE_DISPATCH(
      command_buffer, signal_event)(command_buffer, event, source_stage_mask);
  IREE_TRACE_ZONE_END(z0);
  return signal_status;
}
// Records a reset of |event| on |command_buffer| with |source_stage_mask|.
// Validation (when active) runs before the implementation's `reset_event`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_reset_event(
    iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
    iree_hal_execution_stage_t source_stage_mask) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(event);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_reset_event_validation(
                command_buffer, VALIDATION_STATE(command_buffer), event,
                source_stage_mask));
  });

  iree_status_t reset_status = _VTABLE_DISPATCH(
      command_buffer, reset_event)(command_buffer, event, source_stage_mask);
  IREE_TRACE_ZONE_END(z0);
  return reset_status;
}
// Records a wait on |events| on |command_buffer| along with the provided
// memory and buffer barriers. Each list pointer may only be NULL when its
// matching count is zero. Validation (when active) runs before the
// implementation's `wait_events`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_wait_events(
    iree_hal_command_buffer_t* command_buffer, iree_host_size_t event_count,
    const iree_hal_event_t** events,
    iree_hal_execution_stage_t source_stage_mask,
    iree_hal_execution_stage_t target_stage_mask,
    iree_host_size_t memory_barrier_count,
    const iree_hal_memory_barrier_t* memory_barriers,
    iree_host_size_t buffer_barrier_count,
    const iree_hal_buffer_barrier_t* buffer_barriers) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  // A non-zero count requires the corresponding list to be present.
  IREE_ASSERT_ARGUMENT(!event_count || events);
  IREE_ASSERT_ARGUMENT(!memory_barrier_count || memory_barriers);
  IREE_ASSERT_ARGUMENT(!buffer_barrier_count || buffer_barriers);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_wait_events_validation(
                command_buffer, VALIDATION_STATE(command_buffer), event_count,
                events, source_stage_mask, target_stage_mask,
                memory_barrier_count, memory_barriers, buffer_barrier_count,
                buffer_barriers));
  });

  iree_status_t wait_status = _VTABLE_DISPATCH(command_buffer, wait_events)(
      command_buffer, event_count, events, source_stage_mask,
      target_stage_mask, memory_barrier_count, memory_barriers,
      buffer_barrier_count, buffer_barriers);
  IREE_TRACE_ZONE_END(z0);
  return wait_status;
}
// Records a discard of |buffer| on |command_buffer|. Validation (when active)
// runs before the implementation's `discard_buffer`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_discard_buffer(
    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* buffer) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_discard_buffer_validation(
                command_buffer, VALIDATION_STATE(command_buffer), buffer));
  });

  iree_status_t discard_status =
      _VTABLE_DISPATCH(command_buffer, discard_buffer)(command_buffer, buffer);
  IREE_TRACE_ZONE_END(z0);
  return discard_status;
}
// Records a fill of |length| bytes of |target_buffer| starting at
// |target_offset| using |pattern| (of |pattern_length| bytes).
//
// A zero |length| returns OK immediately without validating or dispatching.
// Otherwise validation (when active) runs before the implementation's
// `fill_buffer`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_fill_buffer(
    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length,
    const void* pattern, iree_host_size_t pattern_length) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(target_buffer);

  // Nothing to fill: treat as a no-op and skip all other validation.
  if (length == 0) return iree_ok_status();

  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_fill_buffer_validation(
                command_buffer, VALIDATION_STATE(command_buffer),
                target_buffer, target_offset, length, pattern,
                pattern_length));
  });

  iree_status_t fill_status = _VTABLE_DISPATCH(command_buffer, fill_buffer)(
      command_buffer, target_buffer, target_offset, length, pattern,
      pattern_length);
  IREE_TRACE_ZONE_END(z0);
  return fill_status;
}
// Records an update of |target_buffer| at |target_offset| with |length| bytes
// taken from the host memory |source_buffer| at |source_offset|.
//
// A zero |length| returns OK immediately without validating or dispatching.
// Otherwise validation (when active) runs before the implementation's
// `update_buffer`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_update_buffer(
    iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
    iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(source_buffer);
  IREE_ASSERT_ARGUMENT(target_buffer);

  // Nothing to update: treat as a no-op and skip all other validation.
  if (length == 0) return iree_ok_status();

  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_update_buffer_validation(
                command_buffer, VALIDATION_STATE(command_buffer),
                source_buffer, source_offset, target_buffer, target_offset,
                length));
  });

  iree_status_t update_status =
      _VTABLE_DISPATCH(command_buffer, update_buffer)(
          command_buffer, source_buffer, source_offset, target_buffer,
          target_offset, length);
  IREE_TRACE_ZONE_END(z0);
  return update_status;
}
// Records a copy of |length| bytes from |source_buffer| at |source_offset| to
// |target_buffer| at |target_offset|.
//
// A zero |length| returns OK immediately without validating or dispatching.
// Otherwise validation (when active) runs before the implementation's
// `copy_buffer`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_copy_buffer(
    iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_t* source_buffer,
    iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
    iree_device_size_t target_offset, iree_device_size_t length) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  if (length == 0) {
    // No-op copy. All other validation is skipped.
    return iree_ok_status();
  }
  // Match the argument checks of the fill/update entry points. These are
  // placed after the zero-length early-out so that the existing no-op
  // behavior (which never touches the buffers) is unchanged.
  IREE_ASSERT_ARGUMENT(source_buffer);
  IREE_ASSERT_ARGUMENT(target_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);
  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_copy_buffer_validation(
                command_buffer, VALIDATION_STATE(command_buffer),
                source_buffer, source_offset, target_buffer, target_offset,
                length));
  });
  iree_status_t status = _VTABLE_DISPATCH(command_buffer, copy_buffer)(
      command_buffer, source_buffer, source_offset, target_buffer,
      target_offset, length);
  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Records the collective operation |op| on |channel| using |send_binding| and
// |recv_binding| for |element_count| elements. Validation (when active) runs
// before the implementation's `collective`. With verbose tracing compiled in
// the formatted op name is appended to the trace zone.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_collective(
    iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,
    iree_hal_collective_op_t op, uint32_t param,
    iree_hal_buffer_binding_t send_binding,
    iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(channel);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_collective_validation(
                command_buffer, VALIDATION_STATE(command_buffer), channel, op,
                param, send_binding, recv_binding, element_count));
  });

#if IREE_HAL_VERBOSE_TRACING_ENABLE
  // Annotate the zone with the textual form of the collective op.
  IREE_TRACE({
    iree_bitfield_string_temp_t string_temp;
    iree_string_view_t collective_str =
        iree_hal_collective_op_format(&op, &string_temp);
    IREE_TRACE_ZONE_APPEND_TEXT(z0, collective_str.data, collective_str.size);
  });
#endif  // IREE_HAL_VERBOSE_TRACING_ENABLE

  iree_status_t collective_status =
      _VTABLE_DISPATCH(command_buffer, collective)(
          command_buffer, channel, op, param, send_binding, recv_binding,
          element_count);
  IREE_TRACE_ZONE_END(z0);
  return collective_status;
}
// Records a push of |values_length| bytes of constant data from |values| at
// |offset| for |pipeline_layout|.
//
// A zero |values_length| returns OK immediately without validating or
// dispatching. Otherwise validation (when active) runs before the
// implementation's `push_constants`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_constants(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
    const void* values, iree_host_size_t values_length) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(pipeline_layout);
  IREE_ASSERT_ARGUMENT(values);

  // Nothing to push; this is not expected to be the common case.
  if (IREE_UNLIKELY(values_length == 0)) return iree_ok_status();

  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_push_constants_validation(
                command_buffer, VALIDATION_STATE(command_buffer),
                pipeline_layout, offset, values, values_length));
  });

  iree_status_t push_status = _VTABLE_DISPATCH(command_buffer, push_constants)(
      command_buffer, pipeline_layout, offset, values, values_length);
  IREE_TRACE_ZONE_END(z0);
  return push_status;
}
// Records a push of |binding_count| |bindings| to descriptor set |set| of
// |pipeline_layout|. |bindings| may only be NULL when |binding_count| is zero.
// Validation (when active) runs before the implementation's
// `push_descriptor_set`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_descriptor_set(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set,
    iree_host_size_t binding_count,
    const iree_hal_descriptor_set_binding_t* bindings) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(pipeline_layout);
  // A non-zero count requires the binding list to be present.
  IREE_ASSERT_ARGUMENT(!binding_count || bindings);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_push_descriptor_set_validation(
                command_buffer, VALIDATION_STATE(command_buffer),
                pipeline_layout, set, binding_count, bindings));
  });

  iree_status_t push_status =
      _VTABLE_DISPATCH(command_buffer, push_descriptor_set)(
          command_buffer, pipeline_layout, set, binding_count, bindings);
  IREE_TRACE_ZONE_END(z0);
  return push_status;
}
// Records a dispatch of |entry_point| in |executable| over a grid of
// |workgroup_x| x |workgroup_y| x |workgroup_z| workgroups.
//
// When all three workgroup counts are zero the call returns OK immediately
// without validating or dispatching. Otherwise validation (when active) runs
// before the implementation's `dispatch`. With verbose tracing compiled in the
// workgroup counts are appended to the trace zone as text.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(executable);
  if ((workgroup_x | workgroup_y | workgroup_z) == 0) {
    // No-op dispatch. All implementations are expected to do this but we ensure
    // it happens here to avoid the overhead of going all the way down into the
    // device layer for something we know should have no (intentional)
    // side-effects. Note that this does mean that validation is skipped and
    // the executable/etc could be bogus but that's fine.
    return iree_ok_status();
  }
  IREE_TRACE_ZONE_BEGIN(z0);
  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_dispatch_validation(
                command_buffer, VALIDATION_STATE(command_buffer), executable,
                entry_point, workgroup_x, workgroup_y, workgroup_z));
  });
#if IREE_HAL_VERBOSE_TRACING_ENABLE
  // TODO(benvanik): add a tracing.h helper that does the snprintf directly
  // into a tracy_malloc buffer so that we can avoid the memcpy. Today this can
  // take 4-5us which adds too much overhead when trying to get accurate timings
  // with tracing enabled. Because benchmarks shouldn't be run with asserts
  // enabled we only enable these when assertions are enabled. Ideally we'd
  // slice off a much larger allocation and then suballocate from that ourselves
  // so that we could avoid the tracy_malloc overheads per-dispatch.
  IREE_TRACE({
    // Worst case is three 10-digit uint32_t values plus two 'x' separators
    // plus the NUL terminator: 3 * 10 + 2 + 1 = 33 bytes. A smaller buffer
    // would truncate the text while snprintf still reports the full length.
    char xyz_string[3 * 10 + 2 + 1];
    int xyz_string_length =
        snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u",
                 workgroup_x, workgroup_y, workgroup_z);
    IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, xyz_string, xyz_string_length);
  });
#endif  // IREE_HAL_VERBOSE_TRACING_ENABLE
  iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)(
      command_buffer, executable, entry_point, workgroup_x, workgroup_y,
      workgroup_z);
  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Records a dispatch of |entry_point| in |executable| whose workgroup counts
// are sourced from |workgroups_buffer| at |workgroups_offset|. Validation
// (when active) runs before the implementation's `dispatch_indirect`.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect(
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_executable_t* executable, int32_t entry_point,
    iree_hal_buffer_t* workgroups_buffer,
    iree_device_size_t workgroups_offset) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IREE_ASSERT_ARGUMENT(executable);
  IREE_ASSERT_ARGUMENT(workgroups_buffer);
  IREE_TRACE_ZONE_BEGIN(z0);

  // A validation failure ends the trace zone and returns early.
  IF_VALIDATING(command_buffer, {
    IREE_RETURN_AND_END_ZONE_IF_ERROR(
        z0, iree_hal_command_buffer_dispatch_indirect_validation(
                command_buffer, VALIDATION_STATE(command_buffer), executable,
                entry_point, workgroups_buffer, workgroups_offset));
  });

  iree_status_t dispatch_status =
      _VTABLE_DISPATCH(command_buffer, dispatch_indirect)(
          command_buffer, executable, entry_point, workgroups_buffer,
          workgroups_offset);
  IREE_TRACE_ZONE_END(z0);
  return dispatch_status;
}
//===----------------------------------------------------------------------===//
// Validation support
//===----------------------------------------------------------------------===//
// Checks that |binding_table| provides at least as many bindings as
// |command_buffer| was created to require (its binding_capacity).
//
// Returns OK when validation is not active, when the command buffer has a
// binding capacity of zero (any provided table is ignored), or when the table
// is large enough. Returns IREE_STATUS_INVALID_ARGUMENT when a table is
// required but is missing or too small.
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_validate_binding_table(
    iree_hal_command_buffer_t* command_buffer,
    const iree_hal_buffer_binding_table_t* binding_table) {
  IREE_ASSERT_ARGUMENT(command_buffer);
  IF_VALIDATING(command_buffer, {
    // Binding tables are only inspected when the command buffer needs one.
    if (command_buffer->binding_capacity != 0 &&
        (!binding_table ||
         binding_table->count < command_buffer->binding_capacity)) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "indirect command buffer requires at least %u "
                              "bindings but only %" PRIhsz " were provided ",
                              command_buffer->binding_capacity,
                              binding_table ? binding_table->count : 0);
    }
    // TODO(benvanik): validate each binding against the requirements of the
    // command buffer.
  });
  return iree_ok_status();
}
//===----------------------------------------------------------------------===//
// Utilities for command buffer creation
//===----------------------------------------------------------------------===//
// Creates a transfer-category command buffer on |device| and records the
// |transfer_count| commands in |transfer_commands| into it in order.
//
// On success the recorded command buffer is returned in |out_command_buffer|
// and the caller owns the reference. On failure |out_command_buffer| is set to
// NULL and the partially recorded command buffer is released.
//
// Returns IREE_STATUS_INVALID_ARGUMENT if a command has an unknown type; any
// failure from creating, beginning, recording into, or ending the command
// buffer is propagated.
IREE_API_EXPORT iree_status_t iree_hal_create_transfer_command_buffer(
    iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
    iree_hal_queue_affinity_t queue_affinity, iree_host_size_t transfer_count,
    const iree_hal_transfer_command_t* transfer_commands,
    iree_hal_command_buffer_t** out_command_buffer) {
  IREE_ASSERT_ARGUMENT(device);
  IREE_ASSERT_ARGUMENT(!transfer_count || transfer_commands);
  IREE_ASSERT_ARGUMENT(out_command_buffer);
  // Ensure callers never observe an indeterminate pointer on failure, matching
  // iree_hal_command_buffer_create.
  *out_command_buffer = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_command_buffer_t* command_buffer = NULL;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_command_buffer_create(
              device, mode, IREE_HAL_COMMAND_CATEGORY_TRANSFER, queue_affinity,
              /*binding_capacity=*/0, &command_buffer));

  iree_status_t status = iree_hal_command_buffer_begin(command_buffer);
  if (iree_status_is_ok(status)) {
    // Record each command in order, stopping at the first failure.
    for (iree_host_size_t i = 0; i < transfer_count; ++i) {
      const iree_hal_transfer_command_t* transfer_command =
          &transfer_commands[i];
      switch (transfer_command->type) {
        case IREE_HAL_TRANSFER_COMMAND_TYPE_FILL:
          status = iree_hal_command_buffer_fill_buffer(
              command_buffer, transfer_command->fill.target_buffer,
              transfer_command->fill.target_offset,
              transfer_command->fill.length, transfer_command->fill.pattern,
              transfer_command->fill.pattern_length);
          break;
        case IREE_HAL_TRANSFER_COMMAND_TYPE_COPY:
          status = iree_hal_command_buffer_copy_buffer(
              command_buffer, transfer_command->copy.source_buffer,
              transfer_command->copy.source_offset,
              transfer_command->copy.target_buffer,
              transfer_command->copy.target_offset,
              transfer_command->copy.length);
          break;
        case IREE_HAL_TRANSFER_COMMAND_TYPE_UPDATE:
          status = iree_hal_command_buffer_update_buffer(
              command_buffer, transfer_command->update.source_buffer,
              transfer_command->update.source_offset,
              transfer_command->update.target_buffer,
              transfer_command->update.target_offset,
              transfer_command->update.length);
          break;
        default:
          status =
              iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                               "unknown transfer_commands[%" PRIhsz "] type %d",
                               i, (int)transfer_command->type);
          break;
      }
      if (!iree_status_is_ok(status)) break;
    }
  }

  // Always end recording, even after a failure above, and fold any failure
  // from ending into the status that is returned.
  status =
      iree_status_join(status, iree_hal_command_buffer_end(command_buffer));

  if (iree_status_is_ok(status)) {
    // Ownership of the command buffer transfers to the caller.
    *out_command_buffer = command_buffer;
  } else {
    iree_hal_command_buffer_release(command_buffer);
  }
  IREE_TRACE_ZONE_END(z0);
  return status;
}