Adding iree_hal_channel_t and the iree_hal_command_buffer_collective API.
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index e068db0..3c2dcd9 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -262,6 +262,15 @@
return iree_ok_status();
}
+static iree_status_t iree_hal_rocm_direct_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "need rocm implementation");
+}
+
static iree_status_t iree_hal_rocm_direct_command_buffer_push_constants(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
@@ -398,6 +407,7 @@
.fill_buffer = iree_hal_rocm_direct_command_buffer_fill_buffer,
.update_buffer = iree_hal_rocm_direct_command_buffer_update_buffer,
.copy_buffer = iree_hal_rocm_direct_command_buffer_copy_buffer,
+ .collective = iree_hal_rocm_direct_command_buffer_collective,
.push_constants = iree_hal_rocm_direct_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_rocm_direct_command_buffer_push_descriptor_set,
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index 0e3815a..3c1a9f3 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -180,6 +180,13 @@
return iree_hal_allocator_trim(device->device_allocator);
}
+static iree_status_t iree_hal_rocm_device_create_channel(
+ iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not implemented");
+}
+
static iree_status_t iree_hal_rocm_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
@@ -303,14 +310,14 @@
}
static iree_status_t iree_hal_rocm_device_profiling_begin(
- iree_hal_device_t* device,
+ iree_hal_device_t* base_device,
const iree_hal_device_profiling_options_t* options) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
static iree_status_t iree_hal_rocm_device_profiling_end(
- iree_hal_device_t* device) {
+ iree_hal_device_t* base_device) {
// Unimplemented (and that's ok).
return iree_ok_status();
}
@@ -322,6 +329,7 @@
.device_allocator = iree_hal_rocm_device_allocator,
.trim = iree_hal_rocm_device_trim,
.query_i64 = iree_hal_rocm_device_query_i64,
+ .create_channel = iree_hal_rocm_device_create_channel,
.create_command_buffer = iree_hal_rocm_device_create_command_buffer,
.create_descriptor_set_layout =
iree_hal_rocm_device_create_descriptor_set_layout,
diff --git a/runtime/iree.natvis b/runtime/iree.natvis
index e6e49ff..bde544c 100644
--- a/runtime/iree.natvis
+++ b/runtime/iree.natvis
@@ -588,6 +588,7 @@
<DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.allocator")==0">{(iree_hal_allocator_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.buffer")==0">{(iree_hal_buffer_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.buffer_view")==0">{(iree_hal_buffer_view_t*)ptr}</DisplayString>
+ <DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.channel")==0">{(iree_hal_channel_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.command_buffer")==0">{(iree_hal_command_buffer_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.descriptor_set_layout")==0">{(iree_hal_descriptor_set_layout_t*)ptr}</DisplayString>
<DisplayString Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.device")==0">{(iree_hal_device_t*)ptr}</DisplayString>
@@ -609,6 +610,7 @@
<ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.allocator")==0">(iree_hal_allocator_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.buffer")==0">(iree_hal_buffer_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.buffer_view")==0">(iree_hal_buffer_view_t*)ptr</ExpandedItem>
+ <ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.channel")==0">(iree_hal_channel_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.command_buffer")==0">(iree_hal_command_buffer_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.descriptor_set_layout")==0">(iree_hal_descriptor_set_layout_t*)ptr</ExpandedItem>
<ExpandedItem Condition="ptr!=0 && strcmp(iree_vm_ref_type_descriptors[type]->type_name.data, "hal.device")==0">(iree_hal_device_t*)ptr</ExpandedItem>
diff --git a/runtime/src/iree/hal/BUILD b/runtime/src/iree/hal/BUILD
index 17ab96f..dac9195 100644
--- a/runtime/src/iree/hal/BUILD
+++ b/runtime/src/iree/hal/BUILD
@@ -34,6 +34,8 @@
"buffer_view.h",
"buffer_view_util.c",
"buffer_view_util.h",
+ "channel.c",
+ "channel.h",
"command_buffer.c",
"command_buffer.h",
"command_buffer_validation.c",
diff --git a/runtime/src/iree/hal/CMakeLists.txt b/runtime/src/iree/hal/CMakeLists.txt
index c25ea4b..349c5f0 100644
--- a/runtime/src/iree/hal/CMakeLists.txt
+++ b/runtime/src/iree/hal/CMakeLists.txt
@@ -27,6 +27,8 @@
"buffer_view.h"
"buffer_view_util.c"
"buffer_view_util.h"
+ "channel.c"
+ "channel.h"
"command_buffer.c"
"command_buffer.h"
"command_buffer_validation.c"
diff --git a/runtime/src/iree/hal/api.h b/runtime/src/iree/hal/api.h
index f8ec0f0..6c98563 100644
--- a/runtime/src/iree/hal/api.h
+++ b/runtime/src/iree/hal/api.h
@@ -13,6 +13,7 @@
#include "iree/hal/buffer.h" // IWYU pragma: export
#include "iree/hal/buffer_view.h" // IWYU pragma: export
#include "iree/hal/buffer_view_util.h" // IWYU pragma: export
+#include "iree/hal/channel.h" // IWYU pragma: export
#include "iree/hal/command_buffer.h" // IWYU pragma: export
#include "iree/hal/device.h" // IWYU pragma: export
#include "iree/hal/driver.h" // IWYU pragma: export
diff --git a/runtime/src/iree/hal/channel.c b/runtime/src/iree/hal/channel.c
new file mode 100644
index 0000000..59ed353
--- /dev/null
+++ b/runtime/src/iree/hal/channel.c
@@ -0,0 +1,57 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/channel.h"
+
+#include <stddef.h>
+
+#include "iree/base/tracing.h"
+#include "iree/hal/detail.h"
+#include "iree/hal/device.h"
+#include "iree/hal/resource.h"
+
+#define _VTABLE_DISPATCH(channel, method_name) \
+ IREE_HAL_VTABLE_DISPATCH(channel, iree_hal_channel, method_name)
+
+IREE_HAL_API_RETAIN_RELEASE(channel);
+
+IREE_API_EXPORT iree_status_t iree_hal_channel_create(
+ iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
+ IREE_ASSERT_ARGUMENT(device);
+ IREE_ASSERT_ARGUMENT(out_channel);
+ *out_channel = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status =
+ IREE_HAL_VTABLE_DISPATCH(device, iree_hal_device, create_channel)(
+ device, queue_affinity, params, out_channel);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+IREE_API_EXPORT void iree_hal_channel_query_rank_and_count(
+ const iree_hal_channel_t* channel, int32_t* out_rank, int32_t* out_count) {
+ IREE_ASSERT_ARGUMENT(channel);
+ int32_t rank = 0;
+ int32_t count = 0;
+ _VTABLE_DISPATCH(channel, query_rank_and_count)(channel, &rank, &count);
+ if (out_rank) *out_rank = rank;
+ if (out_count) *out_count = count;
+}
+
+IREE_API_EXPORT int32_t
+iree_hal_channel_rank(const iree_hal_channel_t* channel) {
+ int32_t rank = 0;
+ iree_hal_channel_query_rank_and_count(channel, &rank, NULL);
+ return rank;
+}
+
+IREE_API_EXPORT int32_t
+iree_hal_channel_count(const iree_hal_channel_t* channel) {
+ int32_t count = 0;
+ iree_hal_channel_query_rank_and_count(channel, NULL, &count);
+ return count;
+}
diff --git a/runtime/src/iree/hal/channel.h b/runtime/src/iree/hal/channel.h
new file mode 100644
index 0000000..fa67312
--- /dev/null
+++ b/runtime/src/iree/hal/channel.h
@@ -0,0 +1,112 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_CHANNEL_H_
+#define IREE_HAL_CHANNEL_H_
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "iree/base/api.h"
+#include "iree/hal/allocator.h"
+#include "iree/hal/resource.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct iree_hal_device_t iree_hal_device_t;
+
+//===----------------------------------------------------------------------===//
+// iree_hal_channel_t
+//===----------------------------------------------------------------------===//
+
+enum iree_hal_channel_flag_bits_t {
+ IREE_HAL_CHANNEL_FLAG_NONE = 0u,
+};
+typedef uint32_t iree_hal_channel_flags_t;
+
+// Specifies that the channel should use environment settings if available.
+#define IREE_HAL_CHANNEL_RANK_DEFAULT ((int32_t)-1)
+#define IREE_HAL_CHANNEL_COUNT_DEFAULT ((int32_t)-1)
+
+// Parameters defining how a channel should be configured.
+typedef struct {
+ // Flags controlling channel behavior.
+ iree_hal_channel_flags_t flags;
+ // Implementation-defined identifier for the channel.
+ // May be empty to indicate that the environment should be used to populate
+ // the identifier.
+ //
+ // Equivalent to:
+ // ncclUniqueId
+ iree_const_byte_span_t id;
+ // Rank of the participant within the collective group.
+ // May be IREE_HAL_CHANNEL_RANK_DEFAULT to indicate that the environment
+ // should be used to populate the rank.
+ int32_t rank;
+ // Total number of participants within the collective group.
+ // May be IREE_HAL_CHANNEL_COUNT_DEFAULT to indicate that the environment
+ // should be used to populate the count.
+ int32_t count;
+} iree_hal_channel_params_t;
+
+// A collective communication channel representing a single rank.
+//
+// Equivalent to:
+// MPI_Comm
+// ncclComm_t
+// ccl::communicator
+typedef struct iree_hal_channel_t iree_hal_channel_t;
+
+// Creates a channel on |device| for use by all queues defined in
+// |queue_affinity|. |params| may specify the channel parameters or leave its
+// fields as default to indicate that the value should be sourced from the
+// environment.
+IREE_API_EXPORT iree_status_t iree_hal_channel_create(
+ iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel);
+
+// Retains the given |channel| for the caller.
+IREE_API_EXPORT void iree_hal_channel_retain(iree_hal_channel_t* channel);
+
+// Releases the given |channel| from the caller.
+IREE_API_EXPORT void iree_hal_channel_release(iree_hal_channel_t* channel);
+
+// Returns the rank the channel represents as a participant in a collective
+// group in `[0, count)` and the total participant count.
+IREE_API_EXPORT void iree_hal_channel_query_rank_and_count(
+ const iree_hal_channel_t* channel, int32_t* out_rank, int32_t* out_count);
+
+// Returns the rank the channel represents as a participant in a collective
+// group in `[0, count)`.
+IREE_API_EXPORT int32_t
+iree_hal_channel_rank(const iree_hal_channel_t* channel);
+
+// Returns the total participant count in a collective group.
+IREE_API_EXPORT int32_t
+iree_hal_channel_count(const iree_hal_channel_t* channel);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_channel_t implementation details
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_channel_vtable_t {
+ void(IREE_API_PTR* destroy)(iree_hal_channel_t* channel);
+
+ void(IREE_API_PTR* query_rank_and_count)(const iree_hal_channel_t* channel,
+ int32_t* out_rank,
+ int32_t* out_count);
+} iree_hal_channel_vtable_t;
+IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_channel_vtable_t);
+
+IREE_API_EXPORT void iree_hal_channel_destroy(iree_hal_channel_t* channel);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_CHANNEL_H_
diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c
index d53cd2a..5145d2f 100644
--- a/runtime/src/iree/hal/command_buffer.c
+++ b/runtime/src/iree/hal/command_buffer.c
@@ -365,6 +365,27 @@
return status;
}
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_collective(
+ iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ IREE_ASSERT_ARGUMENT(command_buffer);
+ IREE_ASSERT_ARGUMENT(channel);
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IF_VALIDATING(command_buffer, {
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, iree_hal_command_buffer_collective_validation(
+ command_buffer, VALIDATION_STATE(command_buffer), channel, op,
+ param, send_binding, recv_binding, element_count));
+ });
+ iree_status_t status = _VTABLE_DISPATCH(command_buffer, collective)(
+ command_buffer, channel, op, param, send_binding, recv_binding,
+ element_count);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_constants(
iree_hal_command_buffer_t* command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h
index 718abe4..883959e 100644
--- a/runtime/src/iree/hal/command_buffer.h
+++ b/runtime/src/iree/hal/command_buffer.h
@@ -13,6 +13,7 @@
#include "iree/base/api.h"
#include "iree/hal/allocator.h"
#include "iree/hal/buffer.h"
+#include "iree/hal/channel.h"
#include "iree/hal/event.h"
#include "iree/hal/executable.h"
#include "iree/hal/pipeline_layout.h"
@@ -206,6 +207,146 @@
iree_device_size_t length;
} iree_hal_descriptor_set_binding_t;
+// Specifies the type of collective operation.
+enum iree_hal_collective_kind_e {
+ // Gathers N*|element_count| elements of the specified type in |recv_binding|
+ // by sourcing |element_count| elements from the |send_binding| of each rank
+ // and concatenating them.
+ //
+ // |param|: unused
+ // |send_binding|: local elements to add at offset rank
+ // |recv_binding|: concatenated results from all ranks
+ // In-place: |send_binding| == |recv_binding| + rank * |element_count|
+ // Equivalent to:
+ // ncclAllGather
+ IREE_HAL_COLLECTIVE_KIND_ALL_GATHER = 0u,
+
+ // Reduces |element_count| elements of the specified type in |send_binding|
+ // using the specified reduction operation and places identical copies of the
+ // result in each |recv_binding|.
+ //
+ // |param|: unused
+ // |send_binding|: local elements to reduce
+ // |recv_binding|: copy of the reduction results
+ // In-place: |send_binding| == |recv_binding|
+ // Equivalent to:
+ // ncclAllReduce
+ IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE,
+
+ // Copies |element_count| elements of the specified type from |send_binding|
+ // on the specified rank |param| to all other ranks |recv_binding|s.
+ //
+ // |param|: source rank of the broadcast value
+ // |send_binding|: only used on the source rank
+ // |recv_binding|: only used on non-source ranks
+ // In-place: |send_binding| == |recv_binding|
+ // Equivalent to:
+ // ncclBroadcast
+ IREE_HAL_COLLECTIVE_KIND_BROADCAST,
+
+ // Reduces |element_count| elements of the specified type in |send_binding|
+ // using the specified reduction operation and places the results in the
+ // |recv_binding| of the target rank |param|.
+ //
+ // |param|: target rank of the resulting value
+ // |send_binding|: used on all ranks
+ // |recv_binding|: only used on the target rank
+ // In-place: |send_binding| == |recv_binding|
+ // Equivalent to:
+ // ncclReduce
+ IREE_HAL_COLLECTIVE_KIND_REDUCE,
+
+ // Reduces |element_count| elements of the specified type in |send_binding|
+ // from all ranks using the specified reduction operation and scatters the
+ // reduced results over the ranks such that the |recv_binding| on rank i
+ // will contain the i-th block of the results.
+ //
+ // |param|: unused
+ // |send_binding|: used on all ranks
+ // |recv_binding|: partial results for the hosting rank
+ // In-place: |recv_binding| == |send_binding| + rank * |element_count|
+ // Equivalent to:
+ // ncclReduceScatter
+ IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER,
+
+ // Sends |element_count| elements of the specified type in |send_binding| to
+ // the target rank |param|.
+ //
+ // |param|: target performing an IREE_HAL_COLLECTIVE_KIND_RECV
+ // |send_binding|: used on source
+ // |recv_binding|: unused
+ // Equivalent to:
+ // ncclSend
+ IREE_HAL_COLLECTIVE_KIND_SEND,
+
+ // Receives |element_count| elements of the specified type in |recv_binding|
+ // from source rank |param|.
+ //
+ // |param|: source performing an IREE_HAL_COLLECTIVE_KIND_SEND
+ // |send_binding|: unused
+ // |recv_binding|: used on target
+ // Equivalent to:
+ // ncclRecv
+ IREE_HAL_COLLECTIVE_KIND_RECV,
+
+ // Maximum enumeration value for collective operations.
+ IREE_HAL_COLLECTIVE_KIND_MAX_VALUE = IREE_HAL_COLLECTIVE_KIND_RECV,
+};
+typedef uint8_t iree_hal_collective_kind_t;
+
+// Specifies the reduction operator of a collective reduction operation.
+enum iree_hal_collective_reduction_e {
+ // Specifies that the reduction operation computes a sum (addition).
+ IREE_HAL_COLLECTIVE_REDUCTION_SUM = 0,
+ // Specifies that the reduction operation computes a product (multiplication).
+ IREE_HAL_COLLECTIVE_REDUCTION_PRODUCT,
+ // Specifies that the reduction operation computes a minimum (min).
+ IREE_HAL_COLLECTIVE_REDUCTION_MINIMUM,
+ // Specifies that the reduction operation computes a maximum (max).
+ IREE_HAL_COLLECTIVE_REDUCTION_MAXIMUM,
+ // Specifies that the reduction operation computes an average (avg).
+ IREE_HAL_COLLECTIVE_REDUCTION_AVERAGE,
+};
+typedef uint8_t iree_hal_collective_reduction_t;
+
+// Specifies the element type as processed by a collective operation.
+// Note that these types are a much restricted set compared to
+// iree_hal_element_type_t as most collective compute libraries only expose a
+// limited number of primitives as some may be backed by fixed-function
+// hardware.
+enum iree_hal_collective_element_type_e {
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_8 = 0,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_8,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_16, // not commonly implemented
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_16, // not commonly implemented
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_32,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_32,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_SINT_64,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_UINT_64,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_16,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_32,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_FLOAT_64,
+ IREE_HAL_COLLECTIVE_ELEMENT_TYPE_BFLOAT_16,
+};
+typedef uint8_t iree_hal_collective_element_type_t;
+
+// Describes a collective operation.
+typedef union {
+ uint32_t packed; // packed value
+ struct {
+ // Collective operation.
+ iree_hal_collective_kind_t kind;
+ // Reduction type (for reduction ops).
+ iree_hal_collective_reduction_t reduction;
+ // Element type.
+ iree_hal_collective_element_type_t element_type;
+ // Reserved for future use.
+ uint8_t reserved;
+ };
+} iree_hal_collective_op_t;
+static_assert(sizeof(iree_hal_collective_op_t) == sizeof(uint32_t),
+ "must pack");
+
// Describes a subrange of a buffer that can be bound to a binding slot.
typedef struct iree_hal_buffer_binding_t {
// Buffer being bound to the slot, if any.
@@ -478,6 +619,15 @@
iree_device_size_t source_offset, iree_hal_buffer_t* target_buffer,
iree_device_size_t target_offset, iree_device_size_t length);
+// Dispatches a collective operation defined by |op| using the given buffers.
+// |param| must be specified for operations that require a root/peer rank
+// identifier and is otherwise ignored.
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_collective(
+ iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count);
+
// Pushes an inline set of constants that can be accessed by subsequent
// dispatches using a compatible pipeline layout.
//
@@ -683,6 +833,12 @@
iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
iree_device_size_t length);
+ iree_status_t(IREE_API_PTR* collective)(
+ iree_hal_command_buffer_t* command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count);
+
iree_status_t(IREE_API_PTR* push_constants)(
iree_hal_command_buffer_t* command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c
index 0c483a8..2436371 100644
--- a/runtime/src/iree/hal/command_buffer_validation.c
+++ b/runtime/src/iree/hal/command_buffer_validation.c
@@ -362,6 +362,96 @@
return iree_ok_status();
}
+iree_status_t iree_hal_command_buffer_collective_validation(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_command_buffer_validation_state_t* validation_state,
+ iree_hal_channel_t* channel, iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories(
+ command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH));
+
+ if (op.kind > IREE_HAL_COLLECTIVE_KIND_MAX_VALUE) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "unknown collective operation");
+ }
+ enum iree_hal_collective_info_bits_t {
+ IREE_HAL_COLLECTIVE_IS_REDUCTION = 1u << 0,
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING = 1u << 1,
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING = 1u << 2,
+ };
+ static const uint32_t
+ info_bits_table[IREE_HAL_COLLECTIVE_KIND_MAX_VALUE + 1] = {
+ [IREE_HAL_COLLECTIVE_KIND_ALL_GATHER] =
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
+ [IREE_HAL_COLLECTIVE_KIND_ALL_REDUCE] =
+ IREE_HAL_COLLECTIVE_IS_REDUCTION |
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
+ [IREE_HAL_COLLECTIVE_KIND_BROADCAST] =
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
+ [IREE_HAL_COLLECTIVE_KIND_REDUCE] =
+ IREE_HAL_COLLECTIVE_IS_REDUCTION |
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
+ [IREE_HAL_COLLECTIVE_KIND_REDUCE_SCATTER] =
+ IREE_HAL_COLLECTIVE_IS_REDUCTION |
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING |
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
+ [IREE_HAL_COLLECTIVE_KIND_SEND] =
+ IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING,
+ [IREE_HAL_COLLECTIVE_KIND_RECV] =
+ IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING,
+ };
+ const uint32_t info_bits = info_bits_table[op.kind];
+ if (!(info_bits & IREE_HAL_COLLECTIVE_IS_REDUCTION) && op.reduction != 0) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "reduction operation cannot be specified on a non-reducing collective");
+ }
+
+ // TODO(benvanik): add queue cap/usage for COLLECTIVE source/dest?
+ if (info_bits & IREE_HAL_COLLECTIVE_REQUIRES_SEND_BINDING) {
+ if (!send_binding.buffer) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "collective operation requires a send buffer binding");
+ } else {
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_validate_buffer_compatibility(
+ command_buffer, validation_state, send_binding.buffer,
+ IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
+ IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE_READ));
+ }
+ } else {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "collective operation does not use a send buffer binding");
+ }
+
+ if (info_bits & IREE_HAL_COLLECTIVE_REQUIRES_RECV_BINDING) {
+ if (!recv_binding.buffer) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "collective operation requires a recv buffer binding");
+ } else {
+ IREE_RETURN_IF_ERROR(
+ iree_hal_command_buffer_validate_buffer_compatibility(
+ command_buffer, validation_state, recv_binding.buffer,
+ IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
+ IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE_WRITE));
+ }
+ } else {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "collective operation does not use a recv buffer binding");
+ }
+
+ return iree_ok_status();
+}
+
iree_status_t iree_hal_command_buffer_push_constants_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h
index 9a4f2b7..687acd0 100644
--- a/runtime/src/iree/hal/command_buffer_validation.h
+++ b/runtime/src/iree/hal/command_buffer_validation.h
@@ -90,6 +90,13 @@
iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
iree_device_size_t length);
+iree_status_t iree_hal_command_buffer_collective_validation(
+ iree_hal_command_buffer_t* command_buffer,
+ iree_hal_command_buffer_validation_state_t* validation_state,
+ iree_hal_channel_t* channel, iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count);
+
iree_status_t iree_hal_command_buffer_push_constants_validation(
iree_hal_command_buffer_t* command_buffer,
iree_hal_command_buffer_validation_state_t* validation_state,
diff --git a/runtime/src/iree/hal/device.h b/runtime/src/iree/hal/device.h
index ea7f914..d8e5c73 100644
--- a/runtime/src/iree/hal/device.h
+++ b/runtime/src/iree/hal/device.h
@@ -13,6 +13,7 @@
#include "iree/base/api.h"
#include "iree/hal/allocator.h"
#include "iree/hal/buffer.h"
+#include "iree/hal/channel.h"
#include "iree/hal/command_buffer.h"
#include "iree/hal/event.h"
#include "iree/hal/executable_cache.h"
@@ -467,6 +468,10 @@
iree_string_view_t key,
int64_t* out_value);
+ iree_status_t(IREE_API_PTR* create_channel)(
+ iree_hal_device_t* device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel);
+
iree_status_t(IREE_API_PTR* create_command_buffer)(
iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
diff --git a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
index 7dcaffe..85b2d2b 100644
--- a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
@@ -33,6 +33,8 @@
"graph_command_buffer.h"
"native_executable.c"
"native_executable.h"
+ "nccl_channel.c"
+ "nccl_channel.h"
"nop_executable_cache.c"
"nop_executable_cache.h"
"pipeline_layout.c"
@@ -52,6 +54,7 @@
iree::base::tracing
iree::hal
iree::hal::utils::buffer_transfer
+ iree::hal::utils::collective_batch
iree::hal::utils::deferred_command_buffer
iree::hal::utils::resource_set
iree::hal::utils::semaphore_base
diff --git a/runtime/src/iree/hal/drivers/cuda/api.h b/runtime/src/iree/hal/drivers/cuda/api.h
index 93d2b9f..55bcb20 100644
--- a/runtime/src/iree/hal/drivers/cuda/api.h
+++ b/runtime/src/iree/hal/drivers/cuda/api.h
@@ -24,6 +24,11 @@
IREE_HAL_CUDA_COMMAND_BUFFER_MODE_STREAM = 1,
} iree_hal_cuda_command_buffer_mode_t;
+// ncclUniqueId exposed without exporting the NCCL headers.
+typedef struct {
+ char data[128];
+} iree_hal_cuda_nccl_id_t;
+
// Parameters configuring an iree_hal_cuda_device_t.
// Must be initialized with iree_hal_cuda_device_params_initialize prior to use.
typedef struct iree_hal_cuda_device_params_t {
@@ -44,6 +49,19 @@
// Only command buffers produced by the compiler that have the
// IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION bit set will use this.
bool allow_inline_execution;
+
+ // Opaque NCCL ID used during channel creation when empty IDs are provided.
+ // Today this is used for all communicators created but in the future this may
+ // just be used as a default when not otherwise specified on channel creation.
+ iree_hal_cuda_nccl_id_t nccl_default_id;
+ // Default base rank to use when creating collective channels.
+ // This will be added to the local rank assigned to communicators when
+ // IREE_HAL_CHANNEL_RANK_DEFAULT is specified on creation calls.
+ int nccl_default_rank;
+ // Default total number of participants to use when creating collective
+ // channels. This will be used when IREE_HAL_CHANNEL_COUNT_DEFAULT is specified
+ // creation calls.
+ int nccl_default_count;
} iree_hal_cuda_device_params_t;
// Initializes |out_params| to default values.
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 5d4c89c..19a9fc8 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -11,6 +11,7 @@
#include <string.h>
#include "iree/base/internal/arena.h"
+#include "iree/base/internal/math.h"
#include "iree/base/tracing.h"
#include "iree/hal/drivers/cuda/context_wrapper.h"
#include "iree/hal/drivers/cuda/cuda_allocator.h"
@@ -18,6 +19,7 @@
#include "iree/hal/drivers/cuda/dynamic_symbols.h"
#include "iree/hal/drivers/cuda/event_semaphore.h"
#include "iree/hal/drivers/cuda/graph_command_buffer.h"
+#include "iree/hal/drivers/cuda/nccl_channel.h"
#include "iree/hal/drivers/cuda/nop_executable_cache.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/status_util.h"
@@ -66,8 +68,9 @@
void iree_hal_cuda_device_params_initialize(
iree_hal_cuda_device_params_t* out_params) {
+ memset(out_params, 0, sizeof(*out_params));
out_params->arena_block_size = 32 * 1024;
- out_params->queue_count = 8;
+ out_params->queue_count = 1;
out_params->command_buffer_mode = IREE_HAL_CUDA_COMMAND_BUFFER_MODE_GRAPH;
out_params->allow_inline_execution = false;
}
@@ -120,7 +123,7 @@
(iree_hal_device_t*)device, &device->context_wrapper,
IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
IREE_HAL_COMMAND_CATEGORY_ANY, /*binding_capacity=*/0, device->stream,
- /*block_pool=*/NULL, &device->stream_command_buffer);
+ &device->block_pool, &device->stream_command_buffer);
}
if (iree_status_is_ok(status)) {
@@ -234,6 +237,80 @@
(int)category.size, category.data, (int)key.size, key.data);
}
+// Returns true if |id| is all zeros indicating an empty ID.
+static bool iree_hal_cuda_nccl_id_is_empty(const iree_hal_cuda_nccl_id_t* id) {
+ for (iree_host_size_t i = 0; i < IREE_ARRAYSIZE(id->data); ++i) {
+ if (id->data[i] != 0) return false;
+ }
+ return true;
+}
+
+static iree_status_t iree_hal_cuda_device_create_channel(
+ iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
+ iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
+
+ // TODO(#9580): check if nccl symbols are available - if not then we fail
+ // here and have the error propagated up to users. If we wanted to delay load
+ // NCCL we'd want to take a lock here, load it, and merge the symbols into the
+ // dynamic symbol table.
+ if (true) {
+ return iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "NCCL unavailable and collective operations cannot be performed");
+ }
+
+ // Try to use the ID specified in the parameters and fall back to the default.
+ iree_hal_cuda_nccl_id_t id;
+ if (iree_const_byte_span_is_empty(params.id)) {
+ // User wants the default.
+ id = device->params.nccl_default_id;
+ } else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) {
+ // User provided something but it's not what we expect.
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "NCCL ID must be %d bytes matching the ncclUniqueId struct",
+ (int)IREE_ARRAYSIZE(id.data));
+ } else {
+ // User provided the ID - we treat it as opaque here and let NCCL validate.
+ memcpy(id.data, params.id.data, IREE_ARRAYSIZE(id.data));
+ }
+ if (iree_hal_cuda_nccl_id_is_empty(&id)) {
+ // TODO: maybe this is ok? a localhost alias or something?
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "no default NCCL ID specified (all zeros)");
+ }
+
+ // Today we only allow a single logical device per channel.
+ // We could multiplex channels but it'd be better to surface that to the
+ // compiler so that it can emit the right rank math.
+ int requested_count = iree_math_count_ones_u64(queue_affinity);
+ if (requested_count != 1) {
+ return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+ "exactly one participant is allowed in a "
+ "channel but %d were specified",
+ requested_count);
+ }
+
+ // Users can either specify a specific rank or allow this device
+ // implementation to decide. This allows us to run the same programs acting as
+ // different ranks by setting flags/environment variables/API options/etc.
+ int rank = params.rank;
+ if (rank == IREE_HAL_CHANNEL_RANK_DEFAULT) {
+ rank = device->params.nccl_default_rank;
+ }
+ int count = params.count;
+ if (count == IREE_HAL_CHANNEL_COUNT_DEFAULT) {
+ count = device->params.nccl_default_count;
+ }
+
+ // TODO: when we support multiple logical devices we'll want to pass in the
+ // context of the device mapped to the queue_affinity. For now since this
+ // implementation only supports one device we pass in the only one we have.
+ return iree_hal_cuda_nccl_channel_create(&device->context_wrapper, &id, rank,
+ count, out_channel);
+}
+
static iree_status_t iree_hal_cuda_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
@@ -417,6 +494,7 @@
.device_allocator = iree_hal_cuda_device_allocator,
.trim = iree_hal_cuda_device_trim,
.query_i64 = iree_hal_cuda_device_query_i64,
+ .create_channel = iree_hal_cuda_device_create_channel,
.create_command_buffer = iree_hal_cuda_device_create_command_buffer,
.create_descriptor_set_layout =
iree_hal_cuda_device_create_descriptor_set_layout,
diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
index 2915cc1..a6a1b90 100644
--- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
@@ -15,8 +15,10 @@
#include "iree/hal/drivers/cuda/cuda_buffer.h"
#include "iree/hal/drivers/cuda/dynamic_symbols.h"
#include "iree/hal/drivers/cuda/native_executable.h"
+#include "iree/hal/drivers/cuda/nccl_channel.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/status_util.h"
+#include "iree/hal/utils/collective_batch.h"
#include "iree/hal/utils/resource_set.h"
#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
@@ -45,7 +47,12 @@
// Keep track of the last node added to the command buffer as we are currently
// serializing all the nodes (each node depends on the previous one).
CUgraphNode last_node;
+
+ // Iteratively constructed batch of collective operations.
+ iree_hal_collective_batch_t collective_batch;
+
int32_t push_constant[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];
+
// Keep track of the current set of kernel arguments.
void* current_descriptor[];
} iree_hal_cuda_graph_command_buffer_t;
@@ -105,6 +112,11 @@
status = iree_hal_resource_set_allocate(block_pool,
&command_buffer->resource_set);
}
+ if (iree_status_is_ok(status)) {
+ iree_hal_collective_batch_initialize(&command_buffer->arena,
+ command_buffer->resource_set,
+ &command_buffer->collective_batch);
+ }
if (iree_status_is_ok(status)) {
*out_command_buffer = &command_buffer->base;
@@ -121,6 +133,9 @@
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ // Drop any pending collective batches before we tear things down.
+ iree_hal_collective_batch_reset(&command_buffer->collective_batch);
+
if (command_buffer->graph != NULL) {
CUDA_IGNORE_ERROR(command_buffer->context->syms,
cuGraphDestroy(command_buffer->graph));
@@ -133,6 +148,7 @@
}
command_buffer->last_node = NULL;
+ iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch);
iree_hal_resource_set_free(command_buffer->resource_set);
iree_arena_deinitialize(&command_buffer->arena);
iree_allocator_free(command_buffer->context->host_allocator, command_buffer);
@@ -164,6 +180,47 @@
return NULL;
}
+// Flushes any pending batched collective operations.
+// Must be called before any other non-collective nodes are added to the graph
+// or a barrier is encountered.
+static iree_status_t iree_hal_cuda_graph_command_buffer_flush_collectives(
+ iree_hal_cuda_graph_command_buffer_t* command_buffer) {
+ // NOTE: we could move this out into callers by way of an always-inline shim -
+ // that would make this a single compare against the command buffer state we
+ // are likely to access immediately after anyway and keep overheads minimal.
+ if (IREE_LIKELY(iree_hal_collective_batch_is_empty(
+ &command_buffer->collective_batch))) {
+ return iree_ok_status();
+ }
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // TODO(#9580): use CUDA graph capture so that the NCCL calls end up in the
+ // graph:
+ // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/cudagraph.html
+ //
+ // Something like:
+ // syms->cuStreamBeginCapture(nccl_stream);
+ // iree_hal_cuda_nccl_submit_batch(command_buffer->context,
+ // &command_buffer->collective_batch,
+ // nccl_stream);
+ // syms->cuStreamEndCapture(nccl_stream, &child_graph);
+ // syms->cuGraphAddChildGraphNode(..., child_graph);
+ // syms->cuGraphDestroy(child_graph); // probably, I think it gets cloned
+ //
+ // Note that we'll want to create a scratch stream that we use to perform the
+ // capture - we could memoize that on the command buffer or on the device
+ // (though that introduces potential threading issues). There may be a special
+ // stream mode for these capture-only streams that is lighter weight than a
+ // normal stream.
+ iree_status_t status = iree_make_status(
+ IREE_STATUS_UNIMPLEMENTED,
+ "CUDA graph capture of collective operations not yet implemented");
+
+ iree_hal_collective_batch_reset(&command_buffer->collective_batch);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
static iree_status_t iree_hal_cuda_graph_command_buffer_begin(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
@@ -188,6 +245,10 @@
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ // Flush any pending collective batches.
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
// Reset state used during recording.
command_buffer->last_node = NULL;
@@ -230,24 +291,42 @@
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
+ iree_hal_cuda_graph_command_buffer_t* command_buffer =
+ iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
// TODO: Implement barrier with Graph edges. Right now all the nodes are
- // serialized.
+ // serialized so this is a no-op.
+
return iree_ok_status();
}
static iree_status_t iree_hal_cuda_graph_command_buffer_signal_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
+ iree_hal_cuda_graph_command_buffer_t* command_buffer =
+ iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
// TODO: Implement barrier with Graph edges. Right now all the nodes are
- // serialized.
+ // serialized so this is a no-op.
+
return iree_ok_status();
}
static iree_status_t iree_hal_cuda_graph_command_buffer_reset_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
+ iree_hal_cuda_graph_command_buffer_t* command_buffer =
+ iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
// TODO: Implement barrier with Graph edges. Right now all the nodes are
- // serialized.
+ // serialized so this is a no-op.
+
return iree_ok_status();
}
@@ -260,15 +339,21 @@
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
+ iree_hal_cuda_graph_command_buffer_t* command_buffer =
+ iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
// TODO: Implement barrier with Graph edges. Right now all the nodes are
- // serialized.
+ // serialized so this is a no-op.
+
return iree_ok_status();
}
static iree_status_t iree_hal_cuda_graph_command_buffer_discard_buffer(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_buffer_t* buffer) {
- // We could mark the memory as invalidated so that if managed CUDA does not
- // try to copy it back to the host.
+ // We could mark the memory as invalidated so that if this is a managed buffer
+ // CUDA does not try to copy it back to the host.
return iree_ok_status();
}
@@ -301,6 +386,8 @@
iree_host_size_t pattern_length) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
command_buffer->resource_set, 1, &target_buffer));
@@ -309,6 +396,7 @@
iree_hal_buffer_allocated_buffer(target_buffer));
target_offset += iree_hal_buffer_byte_offset(target_buffer);
uint32_t dword_pattern = iree_hal_cuda_splat_pattern(pattern, pattern_length);
+
CUDA_MEMSET_NODE_PARAMS params = {
.dst = target_device_buffer + target_offset,
.elementSize = pattern_length,
@@ -326,6 +414,7 @@
dep, numNode, ¶ms,
command_buffer->context->cu_context),
"cuGraphAddMemsetNode");
+
return iree_ok_status();
}
@@ -335,6 +424,8 @@
iree_device_size_t target_offset, iree_device_size_t length) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
// Allocate scratch space in the arena for the data and copy it in.
// The update buffer API requires that the command buffer capture the host
@@ -362,15 +453,18 @@
.Height = 1,
.Depth = 1,
};
+
// Serialize all the nodes for now.
CUgraphNode dep[] = {command_buffer->last_node};
size_t numNode = command_buffer->last_node ? 1 : 0;
+
CUDA_RETURN_IF_ERROR(
command_buffer->context->syms,
cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->graph,
dep, numNode, ¶ms,
command_buffer->context->cu_context),
"cuGraphAddMemcpyNode");
+
return iree_ok_status();
}
@@ -381,6 +475,8 @@
iree_device_size_t length) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
const iree_hal_buffer_t* buffers[2] = {source_buffer, target_buffer};
IREE_RETURN_IF_ERROR(
@@ -392,6 +488,7 @@
CUdeviceptr source_device_buffer = iree_hal_cuda_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(source_buffer));
source_offset += iree_hal_buffer_byte_offset(source_buffer);
+
CUDA_MEMCPY3D params = {
.srcMemoryType = CU_MEMORYTYPE_DEVICE,
.srcDevice = source_device_buffer,
@@ -403,18 +500,33 @@
.Height = 1,
.Depth = 1,
};
+
// Serialize all the nodes for now.
CUgraphNode dep[] = {command_buffer->last_node};
size_t numNode = command_buffer->last_node ? 1 : 0;
+
CUDA_RETURN_IF_ERROR(
command_buffer->context->syms,
cuGraphAddMemcpyNode(&command_buffer->last_node, command_buffer->graph,
dep, numNode, ¶ms,
command_buffer->context->cu_context),
"cuGraphAddMemcpyNode");
+
return iree_ok_status();
}
+static iree_status_t iree_hal_cuda_graph_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ iree_hal_cuda_graph_command_buffer_t* command_buffer =
+ iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ return iree_hal_collective_batch_append(&command_buffer->collective_batch,
+ channel, op, param, send_binding,
+ recv_binding, element_count);
+}
+
static iree_status_t iree_hal_cuda_graph_command_buffer_push_constants(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
@@ -451,8 +563,10 @@
const iree_hal_descriptor_set_binding_t* bindings) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+
iree_host_size_t base_binding =
iree_hal_cuda_base_binding_index(pipeline_layout, set);
+
// Convention with the compiler side. We map bindings to kernel argument.
// We compact the bindings to get a dense set of arguments and keep them order
// based on the binding index.
@@ -463,10 +577,13 @@
iree_hal_cuda_binding_mapping_t buffer = {i, bindings[i].binding};
binding_used[i] = buffer;
}
+ // TODO: remove this sort - it's thankfully small (1-8 on average) but we
+ // should be able to avoid it like we do on the CPU side with a bitmap.
qsort(binding_used, binding_count, sizeof(iree_hal_cuda_binding_mapping_t),
compare_binding_index);
IREE_ASSERT_LT(binding_count, IREE_HAL_CUDA_MAX_BINDING_COUNT,
"binding count larger than the max expected");
+
for (iree_host_size_t i = 0; i < binding_count; i++) {
const iree_hal_descriptor_set_binding_t* binding =
&bindings[binding_used[i].index];
@@ -483,6 +600,7 @@
command_buffer->resource_set, 1, &binding->buffer));
}
}
+
return iree_ok_status();
}
@@ -492,6 +610,9 @@
uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer));
+
IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
command_buffer->resource_set, 1, &executable));
iree_hal_pipeline_layout_t* layout =
@@ -500,17 +621,20 @@
iree_hal_cuda_pipeline_layout_num_constants(layout);
iree_host_size_t constant_base_index =
iree_hal_cuda_push_constant_index(layout);
+
// Patch the push constants in the kernel arguments.
for (iree_host_size_t i = 0; i < num_constants; i++) {
*((uint32_t*)command_buffer->current_descriptor[i + constant_base_index]) =
command_buffer->push_constant[i];
}
+
int32_t block_size_x, block_size_y, block_size_z;
int32_t shared_memory_size;
IREE_RETURN_IF_ERROR(iree_hal_cuda_native_executable_block_size(
executable, entry_point, &block_size_x, &block_size_y, &block_size_z));
IREE_RETURN_IF_ERROR(iree_hal_cuda_native_executable_shared_memory_size(
executable, entry_point, &shared_memory_size));
+
CUDA_KERNEL_NODE_PARAMS params = {
.func = iree_hal_cuda_native_executable_for_entry_point(executable,
entry_point),
@@ -523,14 +647,17 @@
.kernelParams = command_buffer->current_descriptor,
.sharedMemBytes = shared_memory_size,
};
+
// Serialize all the nodes for now.
CUgraphNode dep[] = {command_buffer->last_node};
size_t numNodes = command_buffer->last_node ? 1 : 0;
+
CUDA_RETURN_IF_ERROR(
command_buffer->context->syms,
cuGraphAddKernelNode(&command_buffer->last_node, command_buffer->graph,
dep, numNodes, ¶ms),
"cuGraphAddKernelNode");
+
return iree_ok_status();
}
@@ -574,6 +701,7 @@
.fill_buffer = iree_hal_cuda_graph_command_buffer_fill_buffer,
.update_buffer = iree_hal_cuda_graph_command_buffer_update_buffer,
.copy_buffer = iree_hal_cuda_graph_command_buffer_copy_buffer,
+ .collective = iree_hal_cuda_graph_command_buffer_collective,
.push_constants = iree_hal_cuda_graph_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_cuda_graph_command_buffer_push_descriptor_set,
diff --git a/runtime/src/iree/hal/drivers/cuda/nccl_channel.c b/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
new file mode 100644
index 0000000..186064e
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/cuda/nccl_channel.c
@@ -0,0 +1,175 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/drivers/cuda/nccl_channel.h"
+
+#include <stddef.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+// Returns the same value as NCCL's init.cc hashUniqueId.
+// These magic constants were chosen by their implementation and unlikely to
+// be stable as it's not part of their public API. Only to be used for
+// correlating debug logging/traces. We keep it internal here too so that we
+// aren't tempted to use it either.
+static uint64_t iree_hal_cuda_nccl_hash_id(const iree_hal_cuda_nccl_id_t* id) {
+ uint64_t hash = 0xDEADBEEF;
+ for (iree_host_size_t i = 0; i < sizeof(*id); i++) {
+ hash ^= hash >> 32;
+ hash *= 0x8DB3DB47FA2994ADull;
+ hash += id->data[i];
+ }
+ return hash;
+}
+
+typedef struct iree_hal_cuda_nccl_channel_t {
+ iree_hal_resource_t resource;
+ iree_hal_cuda_context_wrapper_t* context_wrapper;
+
+ // Hash of the unique ID used to create the communicator.
+ // This is consistent with the hashes NCCL itself uses for logging but is not
+ // guaranteed to be unique - only use for informational purposes.
+ uint64_t id_hash;
+
+ // This participant's rank in the communicator.
+ // Equivalent to ncclCommUserRank.
+ int rank;
+ // Total number of participants in the communicator.
+ // Equivalent to ncclCommCount.
+ int count;
+
+ // Communicator handle.
+ ncclComm_t comm;
+} iree_hal_cuda_nccl_channel_t;
+
+static const iree_hal_channel_vtable_t iree_hal_cuda_nccl_channel_vtable;
+
+static iree_hal_cuda_nccl_channel_t* iree_hal_cuda_nccl_channel_cast(
+ iree_hal_channel_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_cuda_nccl_channel_vtable);
+ return (iree_hal_cuda_nccl_channel_t*)base_value;
+}
+
+iree_status_t iree_hal_cuda_nccl_channel_create(
+ iree_hal_cuda_context_wrapper_t* context_wrapper,
+ const iree_hal_cuda_nccl_id_t* id, int rank, int count,
+ iree_hal_channel_t** out_channel) {
+ IREE_ASSERT_ARGUMENT(context_wrapper);
+ IREE_ASSERT_ARGUMENT(out_channel);
+ *out_channel = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ const uint64_t id_hash = iree_hal_cuda_nccl_hash_id(id);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, id_hash);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, rank);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, count);
+
+ // TODO(#9580): actually use nccl to create a communicator.
+ // Something like:
+ // ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
+ // config.blocking = 0;
+ // syms->ncclCommInitRankConfig(&comm, count, *id, rank, &config);
+ // NOTE: CHECK ERRORS! we can safely return here as we haven't allocated the
+ // channel wrapper yet.
+ ncclComm_t comm = NULL;
+ if (!comm) {
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(
+ IREE_STATUS_INTERNAL,
+ "failed to create NCCL communicator for rank=%d count=%d", rank, count);
+ }
+
+ iree_hal_cuda_nccl_channel_t* channel = NULL;
+ iree_status_t status = iree_allocator_malloc(
+ context_wrapper->host_allocator, sizeof(*channel), (void**)&channel);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_cuda_nccl_channel_vtable,
+ &channel->resource);
+ channel->context_wrapper = context_wrapper;
+ channel->id_hash = id_hash;
+ channel->rank = rank;
+ channel->count = count;
+ channel->comm = comm;
+ *out_channel = (iree_hal_channel_t*)channel;
+ }
+
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_cuda_nccl_channel_destroy(
+ iree_hal_channel_t* base_channel) {
+ iree_hal_cuda_nccl_channel_t* channel =
+ iree_hal_cuda_nccl_channel_cast(base_channel);
+ iree_allocator_t host_allocator = channel->context_wrapper->host_allocator;
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->id_hash);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->rank);
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, channel->count);
+
+ // TODO(#9580): tear down nccl - blocking if needed.
+ // We could be smarter about starting finalization of all channels async and
+ // then waiting for them to complete but we aren't currently optimizing for
+ // lifetime performance. To do that we'd probably want to track each open
+ // channel on the device that created them and manage teardown there.
+ //
+ // Recommended:
+ // syms->ncclCommFinalize(channel->comm); // non-blocking!
+ // while (ncclCommGetAsyncError == ncclInProgress) sleep(1);
+ // syms->ncclCommDestroy(channel->comm)
+ // Should work the same (as we are doing a blocking teardown):
+ // syms->ncclCommDestroy(channel->comm)
+
+ iree_allocator_free(host_allocator, channel);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+void iree_hal_cuda_nccl_channel_query_rank_and_count(
+ const iree_hal_channel_t* base_channel, int32_t* out_rank,
+ int32_t* out_count) {
+ IREE_ASSERT_ARGUMENT(base_channel);
+ iree_hal_cuda_nccl_channel_t* channel =
+ iree_hal_cuda_nccl_channel_cast((iree_hal_channel_t*)base_channel);
+ // NOTE: since it's cheap we keep rank/count local - this lets us trace them
+ // out without needing to call into NCCL each time.
+ *out_rank = channel->rank;
+ *out_count = channel->count;
+}
+
+ncclComm_t iree_hal_cuda_nccl_channel_comm(iree_hal_channel_t* base_channel) {
+ IREE_ASSERT_ARGUMENT(base_channel);
+ iree_hal_cuda_nccl_channel_t* channel =
+ iree_hal_cuda_nccl_channel_cast(base_channel);
+ return channel->comm;
+}
+
+iree_status_t iree_hal_cuda_nccl_submit_batch(
+ iree_hal_cuda_context_wrapper_t* context,
+ const iree_hal_collective_batch_t* batch, CUstream stream) {
+ IREE_ASSERT_ARGUMENT(context);
+ IREE_ASSERT_ARGUMENT(batch);
+ IREE_ASSERT_ARGUMENT(stream);
+
+ // TODO(#9580): issue the operations in the batch. Note that the channel may
+ // change between ops and the communicator should be retrieved from each.
+ //
+ // Something like:
+ // make context->cu_context active (for when using multiple devices)
+ // syms->ncclGroupStart();
+ // for each entry in batch:
+ // ncclComm_t comm = iree_hal_cuda_nccl_channel_comm(entry->channel);
+ // syms->nccl*(comm, ...);
+ // syms->ncclGroupEnd();
+
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "NCCL submission not yet implemented");
+}
+
+static const iree_hal_channel_vtable_t iree_hal_cuda_nccl_channel_vtable = {
+ .destroy = iree_hal_cuda_nccl_channel_destroy,
+ .query_rank_and_count = iree_hal_cuda_nccl_channel_query_rank_and_count,
+};
diff --git a/runtime/src/iree/hal/drivers/cuda/nccl_channel.h b/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
new file mode 100644
index 0000000..51ffd54
--- /dev/null
+++ b/runtime/src/iree/hal/drivers/cuda/nccl_channel.h
@@ -0,0 +1,44 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_DRIVERS_CUDA_NCCL_CHANNEL_H_
+#define IREE_HAL_DRIVERS_CUDA_NCCL_CHANNEL_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+#include "iree/hal/drivers/cuda/api.h"
+#include "iree/hal/drivers/cuda/context_wrapper.h"
+#include "iree/hal/drivers/cuda/cuda_headers.h"
+#include "iree/hal/utils/collective_batch.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef struct ncclComm* ncclComm_t;
+
+// Creates a new NCCL communicator channel.
+iree_status_t iree_hal_cuda_nccl_channel_create(
+ iree_hal_cuda_context_wrapper_t* context_wrapper,
+ const iree_hal_cuda_nccl_id_t* id, int rank, int count,
+ iree_hal_channel_t** out_channel);
+
+// Returns the NCCL communicator for the given |channel|, if available.
+ncclComm_t iree_hal_cuda_nccl_channel_comm(iree_hal_channel_t* channel);
+
+// Performs a non-blocking submission of |batch| to |stream|.
+// The backing storage of |batch| is dropped immediately but all resources
+// referenced will be retained by the parent command buffer for its lifetime.
+// Note that operations in the batch may apply to different channels.
+iree_status_t iree_hal_cuda_nccl_submit_batch(
+ iree_hal_cuda_context_wrapper_t* context,
+ const iree_hal_collective_batch_t* batch, CUstream stream);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_DRIVERS_CUDA_NCCL_CHANNEL_H_
diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
index 3ddfcc9..380c8ae 100644
--- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
@@ -10,26 +10,35 @@
#include "iree/hal/drivers/cuda/cuda_buffer.h"
#include "iree/hal/drivers/cuda/cuda_event.h"
#include "iree/hal/drivers/cuda/native_executable.h"
+#include "iree/hal/drivers/cuda/nccl_channel.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/status_util.h"
+#include "iree/hal/utils/collective_batch.h"
+#include "iree/hal/utils/resource_set.h"
#define IREE_HAL_CUDA_MAX_BINDING_COUNT 64
// Kernel arguments contains binding and push constants.
#define IREE_HAL_CUDA_MAX_KERNEL_ARG 128
-// This records the commands on the calling thread without additional threading
-// indirection.
typedef struct {
iree_hal_command_buffer_t base;
iree_hal_cuda_context_wrapper_t* context;
CUstream stream;
+ // Maintains a reference to all resources used within the command buffer.
+ // Reset on each begin.
+ iree_hal_resource_set_t* resource_set;
+
// Staging arena used for host->device transfers.
// Used for when we need CUDA to be able to reference memory as it performs
// asynchronous operations.
iree_arena_allocator_t arena;
+ // Iteratively constructed batch of collective operations.
+ iree_hal_collective_batch_t collective_batch;
+
int32_t push_constant[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT];
+
// Keep track of the current set of kernel arguments.
void* current_descriptor[IREE_HAL_CUDA_MAX_KERNEL_ARG];
CUdeviceptr* device_ptrs[IREE_HAL_CUDA_MAX_KERNEL_ARG];
@@ -80,6 +89,14 @@
for (size_t i = 0; i < IREE_HAL_CUDA_MAX_KERNEL_ARG; i++) {
command_buffer->current_descriptor[i] = &command_buffer->device_ptrs[i];
}
+
+ status = iree_hal_resource_set_allocate(block_pool,
+ &command_buffer->resource_set);
+ }
+ if (iree_status_is_ok(status)) {
+ iree_hal_collective_batch_initialize(&command_buffer->arena,
+ command_buffer->resource_set,
+ &command_buffer->collective_batch);
}
*out_command_buffer = &command_buffer->base;
@@ -93,6 +110,8 @@
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch);
+ iree_hal_resource_set_free(command_buffer->resource_set);
iree_arena_deinitialize(&command_buffer->arena);
iree_allocator_free(command_buffer->context->host_allocator, command_buffer);
@@ -114,6 +133,27 @@
return NULL;
}
+// Flushes any pending batched collective operations.
+// Must be called before any other non-collective work is submitted to the
+// stream or a barrier is encountered.
+static iree_status_t iree_hal_cuda_stream_command_buffer_flush_collectives(
+ iree_hal_cuda_stream_command_buffer_t* command_buffer) {
+ // NOTE: we could move this out into callers by way of an always-inline shim -
+ // that would make this a single compare against the command buffer state we
+ // are likely to access immediately after anyway and keep overheads minimal.
+ if (IREE_LIKELY(iree_hal_collective_batch_is_empty(
+ &command_buffer->collective_batch))) {
+ return iree_ok_status();
+ }
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_status_t status = iree_hal_cuda_nccl_submit_batch(
+ command_buffer->context, &command_buffer->collective_batch,
+ command_buffer->stream);
+ iree_hal_collective_batch_reset(&command_buffer->collective_batch);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
static iree_status_t iree_hal_cuda_stream_command_buffer_begin(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
@@ -124,6 +164,10 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_end(
iree_hal_command_buffer_t* base_command_buffer) {
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
return iree_ok_status();
}
@@ -136,6 +180,10 @@
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
// TODO(jinchen62): implement CUDA barrier
return iree_ok_status();
}
@@ -143,6 +191,10 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_signal_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
// TODO(jinchen62): implement CUDA barrier
return iree_ok_status();
}
@@ -150,6 +202,10 @@
static iree_status_t iree_hal_cuda_stream_command_buffer_reset_event(
iree_hal_command_buffer_t* base_command_buffer, iree_hal_event_t* event,
iree_hal_execution_stage_t source_stage_mask) {
+ // Flush pending collectives so they are ordered before the event reset once
+ // CUDA barriers/events are implemented.
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
// TODO(jinchen62): implement CUDA barrier
return iree_ok_status();
}
@@ -163,6 +219,10 @@
const iree_hal_memory_barrier_t* memory_barriers,
iree_host_size_t buffer_barrier_count,
const iree_hal_buffer_barrier_t* buffer_barriers) {
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
// TODO(jinchen62): implement CUDA barrier
return iree_ok_status();
}
@@ -181,6 +241,8 @@
iree_host_size_t pattern_length) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(target_buffer));
@@ -216,6 +278,7 @@
return iree_make_status(IREE_STATUS_INTERNAL,
"unsupported fill pattern length");
}
+
return iree_ok_status();
}
@@ -225,6 +288,8 @@
iree_device_size_t target_offset, iree_device_size_t length) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
// Allocate scratch space in the arena for the data and copy it in.
// The update buffer API requires that the command buffer capture the host
@@ -250,6 +315,7 @@
command_buffer->context->syms,
cuMemcpyHtoDAsync_v2(dst, src, length, command_buffer->stream),
"cuMemcpyHtoDAsync_v2");
+
return iree_ok_status();
}
@@ -260,6 +326,8 @@
iree_device_size_t length) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
CUdeviceptr target_device_buffer = iree_hal_cuda_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(target_buffer));
@@ -272,20 +340,35 @@
CUDA_RETURN_IF_ERROR(command_buffer->context->syms,
cuMemcpyAsync(dst, src, length, command_buffer->stream),
"cuMemcpyAsync");
+
return iree_ok_status();
}
+// Enqueues a collective operation by appending it to the command buffer's
+// pending collective batch; the batch is issued as a group at the next flush
+// point (barrier/event/dispatch/buffer command or end of recording).
+// The channel and binding buffers are retained by the batch's resource set.
+static iree_status_t iree_hal_cuda_stream_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ iree_hal_cuda_stream_command_buffer_t* command_buffer =
+ iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ return iree_hal_collective_batch_append(&command_buffer->collective_batch,
+ channel, op, param, send_binding,
+ recv_binding, element_count);
+}
+
static iree_status_t iree_hal_cuda_stream_command_buffer_push_constants(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
const void* values, iree_host_size_t values_length) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+
+ // Stages 32-bit push constant values into command buffer state; they are
+ // consumed later at dispatch time. |offset| and |values_length| are in bytes.
+ // NOTE(review): assumes both are multiples of sizeof(int32_t) and that the
+ // range fits within the push_constant array — no bounds check here; confirm
+ // callers validate against the layout's constant count.
iree_host_size_t constant_base_index = offset / sizeof(int32_t);
for (iree_host_size_t i = 0; i < values_length / sizeof(int32_t); i++) {
command_buffer->push_constant[i + constant_base_index] =
((uint32_t*)values)[i];
}
+
return iree_ok_status();
}
@@ -311,8 +394,10 @@
const iree_hal_descriptor_set_binding_t* bindings) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+
iree_host_size_t base_binding =
iree_hal_cuda_base_binding_index(pipeline_layout, set);
+
// Convention with the compiler side. We map bindings to kernel argument.
// We compact the bindings to get a dense set of arguments and keep them order
// based on the binding index.
@@ -323,10 +408,13 @@
iree_hal_cuda_binding_mapping_t buffer = {i, bindings[i].binding};
binding_used[i] = buffer;
}
+ // TODO: remove this sort - it's thankfully small (1-8 on average) but we
+ // should be able to avoid it like we do on the CPU side with a bitmap.
qsort(binding_used, binding_count, sizeof(iree_hal_cuda_binding_mapping_t),
compare_binding_index);
assert(binding_count < IREE_HAL_CUDA_MAX_BINDING_COUNT &&
"binding count larger than the max expected.");
+
for (iree_host_size_t i = 0; i < binding_count; i++) {
iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
CUdeviceptr device_ptr =
@@ -338,6 +426,7 @@
*((CUdeviceptr*)command_buffer->current_descriptor[i + base_binding]) =
device_ptr;
}
+
return iree_ok_status();
}
@@ -347,6 +436,9 @@
uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
+ IREE_RETURN_IF_ERROR(
+ iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer));
+
iree_hal_pipeline_layout_t* layout =
iree_hal_cuda_executable_get_layout(executable, entry_point);
iree_host_size_t num_constants =
@@ -374,6 +466,7 @@
command_buffer->stream, command_buffer->current_descriptor,
NULL),
"cuLaunchKernel");
+
return iree_ok_status();
}
@@ -411,6 +504,7 @@
.fill_buffer = iree_hal_cuda_stream_command_buffer_fill_buffer,
.update_buffer = iree_hal_cuda_stream_command_buffer_update_buffer,
.copy_buffer = iree_hal_cuda_stream_command_buffer_copy_buffer,
+ .collective = iree_hal_cuda_stream_command_buffer_collective,
.push_constants = iree_hal_cuda_stream_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_cuda_stream_command_buffer_push_descriptor_set,
diff --git a/runtime/src/iree/hal/drivers/local_sync/sync_device.c b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
index 1bea5b5..92eebdf 100644
--- a/runtime/src/iree/hal/drivers/local_sync/sync_device.c
+++ b/runtime/src/iree/hal/drivers/local_sync/sync_device.c
@@ -189,6 +189,13 @@
(int)category.size, category.data, (int)key.size, key.data);
}
+// Stub: the local-sync device does not support collective channels yet and
+// always fails with UNIMPLEMENTED; |out_channel| is left untouched.
+static iree_status_t iree_hal_sync_device_create_channel(
+ iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not implemented");
+}
+
static iree_status_t iree_hal_sync_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
@@ -400,6 +407,7 @@
.device_allocator = iree_hal_sync_device_allocator,
.trim = iree_hal_sync_device_trim,
.query_i64 = iree_hal_sync_device_query_i64,
+ .create_channel = iree_hal_sync_device_create_channel,
.create_command_buffer = iree_hal_sync_device_create_command_buffer,
.create_descriptor_set_layout =
iree_hal_sync_device_create_descriptor_set_layout,
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
index da96b12..a4c9fe6 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@ -705,6 +705,37 @@
}
//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_collective
+//===----------------------------------------------------------------------===//
+
+// Records a collective operation on the task-system command buffer.
+// Currently a placeholder that fails with UNIMPLEMENTED; the comments below
+// sketch the intended design for wiring collectives into the task graph.
+static iree_status_t iree_hal_task_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ // The channel can be used as a vtable if we want to inject collective APIs -
+ // the device creation function would set up the channel once and we'll
+ // receive it here each time. When interacting with the task system we want to
+ // get wait handles we can model with iree_task_wait_t.
+ //
+ // An example basic flow:
+ // insert iree_task_call_t:
+ // chains with prior commands and makes the collective API call
+ // insert iree_task_wait_t with API wait handle or our event:
+ // chains with call
+ //
+ // What we probably want to do, though, is group the commands based on
+ // execution barriers. When a new collective command comes in we should
+ // reserve an event from the event pool, create the call to issue the
+ // collective operation, and then track the event in the command buffer state.
+ // When another collective call comes in we'll do the same and append the
+ // event. At the next execution barrier (or non-collective command) we'd
+ // flush to a multi-wait on all of the pending events.
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not yet implemented on the task system");
+}
+
+//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_push_constants
//===----------------------------------------------------------------------===//
// NOTE: command buffer state change only; enqueues no tasks.
@@ -995,6 +1026,10 @@
return iree_ok_status();
}
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_execute_commands
+//===----------------------------------------------------------------------===//
+
static iree_status_t iree_hal_task_command_buffer_execute_commands(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_command_buffer_t* base_commands,
@@ -1031,6 +1066,7 @@
.fill_buffer = iree_hal_task_command_buffer_fill_buffer,
.update_buffer = iree_hal_task_command_buffer_update_buffer,
.copy_buffer = iree_hal_task_command_buffer_copy_buffer,
+ .collective = iree_hal_task_command_buffer_collective,
.push_constants = iree_hal_task_command_buffer_push_constants,
.push_descriptor_set = iree_hal_task_command_buffer_push_descriptor_set,
.dispatch = iree_hal_task_command_buffer_dispatch,
diff --git a/runtime/src/iree/hal/drivers/local_task/task_device.c b/runtime/src/iree/hal/drivers/local_task/task_device.c
index 5f96fc8..7109110 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_device.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_device.c
@@ -246,6 +246,13 @@
return queue_affinity % device->queue_count;
}
+// Stub: the local-task device does not support collective channels yet and
+// always fails with UNIMPLEMENTED; |out_channel| is left untouched.
+static iree_status_t iree_hal_task_device_create_channel(
+ iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not implemented");
+}
+
static iree_status_t iree_hal_task_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
@@ -419,6 +426,7 @@
.device_allocator = iree_hal_task_device_allocator,
.trim = iree_hal_task_device_trim,
.query_i64 = iree_hal_task_device_query_i64,
+ .create_channel = iree_hal_task_device_create_channel,
.create_command_buffer = iree_hal_task_device_create_command_buffer,
.create_descriptor_set_layout =
iree_hal_task_device_create_descriptor_set_layout,
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
index 7a47af2..df7732b 100644
--- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
@@ -654,6 +654,15 @@
return iree_ok_status();
}
+// Stub: collectives are not yet wired up on Vulkan; recording any collective
+// op into a Vulkan direct command buffer fails with UNIMPLEMENTED.
+static iree_status_t iree_hal_vulkan_direct_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not yet implemented on Vulkan");
+}
+
static iree_status_t iree_hal_vulkan_direct_command_buffer_push_constants(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset,
@@ -835,6 +844,8 @@
/*.update_buffer=*/
iree_hal_vulkan_direct_command_buffer_update_buffer,
/*.copy_buffer=*/iree_hal_vulkan_direct_command_buffer_copy_buffer,
+ /*.collective=*/
+ iree_hal_vulkan_direct_command_buffer_collective,
/*.push_constants=*/
iree_hal_vulkan_direct_command_buffer_push_constants,
/*.push_descriptor_set=*/
diff --git a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
index 81cee5a..737baa3 100644
--- a/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/vulkan_device.cc
@@ -1091,6 +1091,13 @@
return device->dispatch_queues[queue_affinity % device->dispatch_queue_count];
}
+// Stub: the Vulkan device does not support collective channels yet and always
+// fails with UNIMPLEMENTED; |out_channel| is left untouched.
+static iree_status_t iree_hal_vulkan_device_create_channel(
+ iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity,
+ iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not implemented");
+}
+
static iree_status_t iree_hal_vulkan_device_create_command_buffer(
iree_hal_device_t* base_device, iree_hal_command_buffer_mode_t mode,
iree_hal_command_category_t command_categories,
@@ -1327,6 +1334,7 @@
/*.device_allocator=*/iree_hal_vulkan_device_allocator,
/*.trim=*/iree_hal_vulkan_device_trim,
/*.query_i64=*/iree_hal_vulkan_device_query_i64,
+ /*.create_channel=*/iree_hal_vulkan_device_create_channel,
/*.create_command_buffer=*/iree_hal_vulkan_device_create_command_buffer,
/*.create_descriptor_set_layout=*/
iree_hal_vulkan_device_create_descriptor_set_layout,
diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c
index e9c148a..edede8f 100644
--- a/runtime/src/iree/hal/local/inline_command_buffer.c
+++ b/runtime/src/iree/hal/local/inline_command_buffer.c
@@ -359,6 +359,19 @@
}
//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_collective
+//===----------------------------------------------------------------------===//
+
+// Stub: collectives are not yet implemented for inline (synchronous CPU)
+// execution; recording any collective op fails with UNIMPLEMENTED.
+static iree_status_t iree_hal_inline_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+ "collectives not yet implemented on CPU");
+}
+
+//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_push_constants
//===----------------------------------------------------------------------===//
// NOTE: command buffer state change only; enqueues no tasks.
@@ -567,6 +580,10 @@
workgroup_count.y, workgroup_count.z);
}
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_execute_commands
+//===----------------------------------------------------------------------===//
+
static iree_status_t iree_hal_inline_command_buffer_execute_commands(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_command_buffer_t* base_commands,
@@ -599,6 +616,7 @@
.fill_buffer = iree_hal_inline_command_buffer_fill_buffer,
.update_buffer = iree_hal_inline_command_buffer_update_buffer,
.copy_buffer = iree_hal_inline_command_buffer_copy_buffer,
+ .collective = iree_hal_inline_command_buffer_collective,
.push_constants = iree_hal_inline_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_inline_command_buffer_push_descriptor_set,
diff --git a/runtime/src/iree/hal/utils/BUILD b/runtime/src/iree/hal/utils/BUILD
index a6ad31e..2f8a29f 100644
--- a/runtime/src/iree/hal/utils/BUILD
+++ b/runtime/src/iree/hal/utils/BUILD
@@ -26,6 +26,20 @@
)
iree_runtime_cc_library(
+ name = "collective_batch",
+ srcs = ["collective_batch.c"],
+ hdrs = ["collective_batch.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":resource_set",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base:tracing",
+ "//runtime/src/iree/base/internal:arena",
+ "//runtime/src/iree/hal",
+ ],
+)
+
+iree_runtime_cc_library(
name = "deferred_command_buffer",
srcs = ["deferred_command_buffer.c"],
hdrs = ["deferred_command_buffer.h"],
diff --git a/runtime/src/iree/hal/utils/CMakeLists.txt b/runtime/src/iree/hal/utils/CMakeLists.txt
index e646d46..c390d78 100644
--- a/runtime/src/iree/hal/utils/CMakeLists.txt
+++ b/runtime/src/iree/hal/utils/CMakeLists.txt
@@ -26,6 +26,22 @@
iree_cc_library(
NAME
+ collective_batch
+ HDRS
+ "collective_batch.h"
+ SRCS
+ "collective_batch.c"
+ DEPS
+ ::resource_set
+ iree::base
+ iree::base::internal::arena
+ iree::base::tracing
+ iree::hal
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
deferred_command_buffer
HDRS
"deferred_command_buffer.h"
diff --git a/runtime/src/iree/hal/utils/collective_batch.c b/runtime/src/iree/hal/utils/collective_batch.c
new file mode 100644
index 0000000..7f73bc2
--- /dev/null
+++ b/runtime/src/iree/hal/utils/collective_batch.c
@@ -0,0 +1,112 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/hal/utils/collective_batch.h"
+
+#include "iree/base/tracing.h"
+
+//===----------------------------------------------------------------------===//
+// Collective batching utility
+//===----------------------------------------------------------------------===//
+
+#define IREE_HAL_COLLECTIVE_BATCH_INITIAL_CAPACITY 16
+
+// Initializes |out_batch| to an empty state. Performs no allocation: entry
+// storage is sliced from |arena| lazily on the first append. |arena| and
+// |resource_set| are borrowed and must outlive the batch.
+IREE_API_EXPORT void iree_hal_collective_batch_initialize(
+ iree_arena_allocator_t* arena, iree_hal_resource_set_t* resource_set,
+ iree_hal_collective_batch_t* out_batch) {
+ out_batch->arena = arena;
+ out_batch->resource_set = resource_set;
+ out_batch->capacity = 0;
+ out_batch->count = 0;
+ out_batch->entries = NULL;
+}
+
+IREE_API_EXPORT void iree_hal_collective_batch_deinitialize(
+ iree_hal_collective_batch_t* batch) {
+ // Since we are just allocating from the arena we don't need to do anything
+ // but clear our pointers for debugging clarity. The storage itself is
+ // reclaimed when the arena owner resets/releases the arena.
+ iree_hal_collective_batch_reset(batch);
+}
+
+// Returns true when no operations have been appended since initialization or
+// the most recent reset.
+IREE_API_EXPORT bool iree_hal_collective_batch_is_empty(
+ const iree_hal_collective_batch_t* batch) {
+ return batch->count == 0;
+}
+
+IREE_API_EXPORT void iree_hal_collective_batch_reset(
+ iree_hal_collective_batch_t* batch) {
+ // Reset the count to zero but keep the arena storage for reuse: capacity and
+ // the entries pointer are intentionally preserved so subsequent appends can
+ // fill the existing allocation without growing again.
+ // We could memset the contents if we wanted to make debugging easier as ASAN
+ // won't be able to help us but it'd probably be better to use ASAN hooks to
+ // mark the memory as invalid instead.
+ batch->count = 0;
+}
+
+// Grows the storage of the |batch| by 2x by slicing off new memory from the
+// arena and copying over the existing contents.
+// Grows the storage of the |batch| by 2x by slicing off new memory from the
+// arena and copying over the existing contents.
+static iree_status_t iree_hal_collective_batch_grow(
+ iree_hal_collective_batch_t* batch) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ // Calculate new capacity. Note that we start empty.
+ iree_host_size_t new_capacity =
+ batch->capacity == 0 ? IREE_HAL_COLLECTIVE_BATCH_INITIAL_CAPACITY
+ : batch->capacity * 2;
+ IREE_TRACE_ZONE_APPEND_VALUE(z0, new_capacity);
+
+ // Allocate new storage - this may fail if the system (or block pool) is over
+ // capacity.
+ iree_hal_collective_batch_entry_t* new_entries = NULL;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ iree_arena_allocate(batch->arena, new_capacity * sizeof(*batch->entries),
+ (void**)&new_entries));
+
+ // Copy over existing items, if any. The copy is guarded because on the
+ // first growth batch->entries is NULL and passing NULL to memcpy is
+ // undefined behavior even when the size is zero. We let the old entry list
+ // go as it'll eventually be cleaned up when the arena is reset.
+ if (batch->count > 0) {
+ memcpy(new_entries, batch->entries, batch->count * sizeof(*batch->entries));
+ }
+ batch->capacity = new_capacity;
+ batch->entries = new_entries;
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Appends one collective operation to |batch|, growing arena-backed storage
+// as needed and retaining the channel and any bound buffers on the batch's
+// resource set so they stay live until the parent command buffer is released.
+IREE_API_EXPORT iree_status_t iree_hal_collective_batch_append(
+ iree_hal_collective_batch_t* batch, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ // Grow the entry storage if required.
+ if (batch->count + 1 > batch->capacity) {
+ IREE_RETURN_IF_ERROR(iree_hal_collective_batch_grow(batch));
+ }
+
+ // Insert resources into the resource set to keep them live.
+ // NOTE(review): if this insert fails we return before recording the entry;
+ // resources already inserted are merely over-retained until the resource set
+ // is released, which appears harmless — confirm against resource_set docs.
+ iree_host_size_t resource_count = 0;
+ void* resources[3] = {NULL};
+ resources[resource_count++] = channel;
+ if (send_binding.buffer) {
+ resources[resource_count++] = send_binding.buffer;
+ }
+ if (recv_binding.buffer) {
+ resources[resource_count++] = recv_binding.buffer;
+ }
+ IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(batch->resource_set,
+ resource_count, resources));
+
+ // Append entry to the list.
+ batch->entries[batch->count++] = (iree_hal_collective_batch_entry_t){
+ .channel = channel,
+ .op = op,
+ .param = param,
+ .send_binding = send_binding,
+ .recv_binding = recv_binding,
+ .element_count = element_count,
+ };
+
+ return iree_ok_status();
+}
diff --git a/runtime/src/iree/hal/utils/collective_batch.h b/runtime/src/iree/hal/utils/collective_batch.h
new file mode 100644
index 0000000..7ed9823
--- /dev/null
+++ b/runtime/src/iree/hal/utils/collective_batch.h
@@ -0,0 +1,98 @@
+// Copyright 2022 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_HAL_UTILS_COLLECTIVE_BATCH_H_
+#define IREE_HAL_UTILS_COLLECTIVE_BATCH_H_
+
+#include "iree/base/api.h"
+#include "iree/base/internal/arena.h"
+#include "iree/hal/api.h"
+#include "iree/hal/utils/resource_set.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// Collective batching utility
+//===----------------------------------------------------------------------===//
+
+// Recorded collective operation in a batch.
+// The specified channel and binding buffers will be retained by the resource
+// set for the lifetime of the parent command buffer.
+typedef struct {
+ // Channel the operation is issued on; retained by the resource set.
+ iree_hal_channel_t* channel;
+ // Collective operation to perform.
+ iree_hal_collective_op_t op;
+ // Operation-specific parameter — presumably a root rank or similar value
+ // defined by the op; TODO(review) confirm against iree_hal_collective_op_t.
+ uint32_t param;
+ // Source buffer range (may have a NULL buffer for ops with no send side).
+ iree_hal_buffer_binding_t send_binding;
+ // Destination buffer range (may have a NULL buffer for ops with no recv
+ // side).
+ iree_hal_buffer_binding_t recv_binding;
+ // Number of elements transferred (element type is defined by the op).
+ iree_device_size_t element_count;
+} iree_hal_collective_batch_entry_t;
+
+// Builds batches of collective operations for grouped submission.
+// This is to be embedded in command buffer implementations and used to
+// incrementally build batches of collective operations that can be submitted to
+// implementations as atomic operations. The compiler is _supposed_ to emit
+// collectives within a barrier scope though that's not verified by the API
+// today.
+//
+// Referenced resources, such as channels and buffers, are retained on the
+// resource set owned by the command buffer. This allows for async submissions
+// to backing implementations to remain valid even if the code submitting the
+// command buffers may drop their reference while it is in-flight.
+typedef struct {
+ // Arena used for scratch allocations used during batch construction.
+ // This is owned by the parent of the collective batch and the lifetime of the
+ // arena contents is controlled by the parent.
+ iree_arena_allocator_t* arena;
+
+ // Resource set that submitted channels and buffers will be retained in.
+ iree_hal_resource_set_t* resource_set;
+
+ // Growable list of accumulated operations (starts empty).
+ // We could use a linked list into arena storage but we don't need to persist
+ // the contents beyond a single flush. Instead we slice out some storage as
+ // needed and grow by slicing off more and copying over the existing contents.
+ // This should stabilized to the maximum batch size pretty fast with minimal
+ // command buffer overhead. If we notice people doing counts following the
+ // fibonacci sequence we could rework things but in average usage we expect
+ // 1-16 entries on average.
+ iree_host_size_t capacity;
+ iree_host_size_t count;
+ iree_hal_collective_batch_entry_t* entries;
+} iree_hal_collective_batch_t;
+
+// Initializes |out_batch| for use using |arena| for any transient allocations
+// required. All resources used will be inserted into |resource_set|.
+IREE_API_EXPORT void iree_hal_collective_batch_initialize(
+ iree_arena_allocator_t* arena, iree_hal_resource_set_t* resource_set,
+ iree_hal_collective_batch_t* out_batch);
+
+// Deinitializes |batch| and releases any allocated memory.
+IREE_API_EXPORT void iree_hal_collective_batch_deinitialize(
+ iree_hal_collective_batch_t* batch);
+
+// Returns true if the batch is empty.
+IREE_API_EXPORT bool iree_hal_collective_batch_is_empty(
+ const iree_hal_collective_batch_t* batch);
+
+// Resets the collective batch and drops all storage.
+IREE_API_EXPORT void iree_hal_collective_batch_reset(
+ iree_hal_collective_batch_t* batch);
+
+// Appends a collective operation to the batch.
+// Referenced resources will be retained.
+IREE_API_EXPORT iree_status_t iree_hal_collective_batch_append(
+ iree_hal_collective_batch_t* batch, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_HAL_UTILS_COLLECTIVE_BATCH_H_
diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c
index 4fd9e59..41c7dcf 100644
--- a/runtime/src/iree/hal/utils/deferred_command_buffer.c
+++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c
@@ -23,6 +23,7 @@
IREE_HAL_CMD_FILL_BUFFER,
IREE_HAL_CMD_UPDATE_BUFFER,
IREE_HAL_CMD_COPY_BUFFER,
+ IREE_HAL_CMD_COLLECTIVE,
IREE_HAL_CMD_PUSH_CONSTANTS,
IREE_HAL_CMD_PUSH_DESCRIPTOR_SET,
IREE_HAL_CMD_DISPATCH,
@@ -601,6 +602,56 @@
}
//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_COLLECTIVE
+//===----------------------------------------------------------------------===//
+
+// Recorded IREE_HAL_CMD_COLLECTIVE payload; mirrors the arguments of
+// iree_hal_command_buffer_collective so the call can be replayed verbatim.
+typedef struct iree_hal_cmd_collective_t {
+ iree_hal_cmd_header_t header;
+ iree_hal_channel_t* channel;
+ iree_hal_collective_op_t op;
+ uint32_t param;
+ iree_hal_buffer_binding_t send_binding;
+ iree_hal_buffer_binding_t recv_binding;
+ iree_device_size_t element_count;
+} iree_hal_cmd_collective_t;
+
+// Records a collective operation into the deferred command list for later
+// replay. The channel and any non-NULL binding buffers are retained on the
+// command buffer's resource set before the command is appended.
+static iree_status_t iree_hal_deferred_command_buffer_collective(
+ iree_hal_command_buffer_t* base_command_buffer, iree_hal_channel_t* channel,
+ iree_hal_collective_op_t op, uint32_t param,
+ iree_hal_buffer_binding_t send_binding,
+ iree_hal_buffer_binding_t recv_binding, iree_device_size_t element_count) {
+ iree_hal_deferred_command_buffer_t* command_buffer =
+ iree_hal_deferred_command_buffer_cast(base_command_buffer);
+ iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list;
+ // Retain referenced resources first so the recorded command can never
+ // outlive them.
+ iree_host_size_t resource_count = 0;
+ const void* resources[3] = {NULL, NULL, NULL};
+ resources[resource_count++] = channel;
+ if (send_binding.buffer) resources[resource_count++] = send_binding.buffer;
+ if (recv_binding.buffer) resources[resource_count++] = recv_binding.buffer;
+ IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+ command_buffer->resource_set, resource_count, resources));
+ iree_hal_cmd_collective_t* cmd = NULL;
+ IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+ cmd_list, IREE_HAL_CMD_COLLECTIVE, sizeof(*cmd), (void**)&cmd));
+ cmd->channel = channel;
+ cmd->op = op;
+ cmd->param = param;
+ cmd->send_binding = send_binding;
+ cmd->recv_binding = recv_binding;
+ cmd->element_count = element_count;
+ return iree_ok_status();
+}
+
+// Replays a recorded collective command against |target_command_buffer|.
+// NOTE(review): |binding_table| is unused — the bindings were captured at
+// record time; confirm indirect/table-relative bindings are not expected for
+// collective commands.
+static iree_status_t iree_hal_deferred_command_buffer_apply_collective(
+ iree_hal_command_buffer_t* target_command_buffer,
+ iree_hal_buffer_binding_table_t binding_table,
+ const iree_hal_cmd_collective_t* cmd) {
+ return iree_hal_command_buffer_collective(
+ target_command_buffer, cmd->channel, cmd->op, cmd->param,
+ cmd->send_binding, cmd->recv_binding, cmd->element_count);
+}
+
+//===----------------------------------------------------------------------===//
// IREE_HAL_CMD_PUSH_CONSTANTS
//===----------------------------------------------------------------------===//
@@ -845,6 +896,8 @@
iree_hal_deferred_command_buffer_apply_update_buffer,
[IREE_HAL_CMD_COPY_BUFFER] = (iree_hal_cmd_apply_fn_t)
iree_hal_deferred_command_buffer_apply_copy_buffer,
+ [IREE_HAL_CMD_COLLECTIVE] = (iree_hal_cmd_apply_fn_t)
+ iree_hal_deferred_command_buffer_apply_collective,
[IREE_HAL_CMD_PUSH_CONSTANTS] = (iree_hal_cmd_apply_fn_t)
iree_hal_deferred_command_buffer_apply_push_constants,
[IREE_HAL_CMD_PUSH_DESCRIPTOR_SET] = (iree_hal_cmd_apply_fn_t)
@@ -909,6 +962,7 @@
.fill_buffer = iree_hal_deferred_command_buffer_fill_buffer,
.update_buffer = iree_hal_deferred_command_buffer_update_buffer,
.copy_buffer = iree_hal_deferred_command_buffer_copy_buffer,
+ .collective = iree_hal_deferred_command_buffer_collective,
.push_constants = iree_hal_deferred_command_buffer_push_constants,
.push_descriptor_set =
iree_hal_deferred_command_buffer_push_descriptor_set,