[rocm] Backport and adjust some HIP allocator and buffer changes (#16627)
We want to replace the rocm HAL driver with the new HIP HAL driver soon.
Though the latter is not exactly there yet with some known missing
features and issues. So this commit backports some changes:
* Properly initialized the block pool in the device
* Added update buffer implementation for command buffer
* Fleshed out logic in the allocator compatibility query
Additionally, we mark device local + host visible as low performance
entirely--it's much slower than the device only memory. We have seen 10x
perf degradation.
diff --git a/experimental/rocm/context_wrapper.h b/experimental/rocm/context_wrapper.h
index 819b67a..66451cb 100644
--- a/experimental/rocm/context_wrapper.h
+++ b/experimental/rocm/context_wrapper.h
@@ -14,8 +14,9 @@
// Structure to wrap all objects constant within a context. This makes it
// simpler to pass it to the different objects and saves memory.
typedef struct iree_hal_rocm_context_wrapper_t {
- hipDevice_t rocm_device;
hipCtx_t rocm_context;
+ hipDevice_t rocm_device;
+ hipStream_t rocm_stream;
iree_allocator_t host_allocator;
iree_hal_rocm_dynamic_symbols_t *syms;
} iree_hal_rocm_context_wrapper_t;
diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c
index abd91a6..0697aee 100644
--- a/experimental/rocm/direct_command_buffer.c
+++ b/experimental/rocm/direct_command_buffer.c
@@ -27,6 +27,11 @@
iree_arena_block_pool_t* block_pool;
iree_hal_rocm_tracing_context_t* tracing_context;
+ // Staging arena used for host->device transfers.
+ // Used for when we need HIP to be able to reference memory as it performs
+ // asynchronous operations.
+ iree_arena_allocator_t arena;
+
// Keep track of the current set of kernel arguments.
int32_t push_constant[IREE_HAL_ROCM_MAX_PUSH_CONSTANT_COUNT];
void* current_descriptor[];
@@ -79,6 +84,7 @@
command_buffer->context = context;
command_buffer->tracing_context = tracing_context;
command_buffer->block_pool = block_pool;
+ iree_arena_initialize(block_pool, &command_buffer->arena);
hipDeviceptr_t* device_ptrs =
(hipDeviceptr_t*)(command_buffer->current_descriptor +
IREE_HAL_ROCM_MAX_KERNEL_ARG);
@@ -99,6 +105,7 @@
iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
IREE_TRACE_ZONE_BEGIN(z0);
+ iree_arena_deinitialize(&command_buffer->arena);
iree_allocator_free(command_buffer->context->host_allocator, command_buffer);
IREE_TRACE_ZONE_END(z0);
@@ -227,8 +234,34 @@
iree_hal_command_buffer_t* base_command_buffer, const void* source_buffer,
iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
iree_device_size_t target_offset, iree_device_size_t length) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
- "need rocm implementation");
+ iree_hal_rocm_direct_command_buffer_t* command_buffer =
+ iree_hal_rocm_direct_command_buffer_cast(base_command_buffer);
+
+ // Allocate scratch space in the arena for the data and copy it in.
+ // The update buffer API requires that the command buffer capture the host
+ // memory at the time the method is called in case the caller wants to reuse
+ // the memory. Because HIP memcpys are async if we didn't copy it's possible
+ // for the reused memory to change before the stream reaches the copy
+ // operation and get the wrong data.
+ const uint8_t* src = (const uint8_t*)source_buffer + source_offset;
+ uint8_t* storage = NULL;
+ IREE_RETURN_IF_ERROR(
+ iree_arena_allocate(&command_buffer->arena, length, (void**)&storage));
+ memcpy(storage, src, length);
+ src = storage;
+
+ // Issue the copy using the scratch memory as the source.
+ hipDeviceptr_t target_device_buffer = iree_hal_rocm_buffer_device_pointer(
+ iree_hal_buffer_allocated_buffer(target_buffer));
+ hipDeviceptr_t dst = (uint8_t*)target_device_buffer +
+ iree_hal_buffer_byte_offset(target_buffer) +
+ target_offset;
+ ROCM_RETURN_IF_ERROR(command_buffer->context->syms,
+ hipMemcpyHtoDAsync(dst, (void*)src, length,
+ command_buffer->context->rocm_stream),
+ "hipMemcpyHtoDAsync");
+
+ return iree_ok_status();
}
static iree_status_t iree_hal_rocm_direct_command_buffer_copy_buffer(
diff --git a/experimental/rocm/dynamic_symbol_tables.h b/experimental/rocm/dynamic_symbol_tables.h
index 785f0ed..2cdab9c 100644
--- a/experimental/rocm/dynamic_symbol_tables.h
+++ b/experimental/rocm/dynamic_symbol_tables.h
@@ -61,3 +61,5 @@
RC_PFN_DECL(hipCtxGetDevice, hipDevice_t *)
RC_PFN_DECL(hipCtxSetCurrent, hipCtx_t)
RC_PFN_DECL(hipDevicePrimaryCtxRelease, hipDevice_t)
+RC_PFN_DECL(hipMemPrefetchAsync, const void *, size_t, int, hipStream_t)
+RC_PFN_DECL(hipMemcpyHtoDAsync, hipDeviceptr_t, void *, size_t, hipStream_t)
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c
index 3c63c71..e8d166d 100644
--- a/experimental/rocm/rocm_allocator.c
+++ b/experimental/rocm/rocm_allocator.c
@@ -18,6 +18,8 @@
iree_hal_device_t* base_device;
iree_hal_rocm_context_wrapper_t* context;
+ bool supports_concurrent_managed_access;
+
IREE_STATISTICS(iree_hal_allocator_statistics_t statistics;)
} iree_hal_rocm_allocator_t;
@@ -34,6 +36,28 @@
iree_hal_allocator_t** out_allocator) {
IREE_ASSERT_ARGUMENT(context);
IREE_TRACE_ZONE_BEGIN(z0);
+
+ // To support device-local + host-visible memory we need concurrent managed
+ // access indicating that the host and devices can concurrently access the
+ // device memory. If we don't have this feature then we fall back to forcing
+ // all device-local + host-visible memory into host-local + device-visible
+ // page-locked memory. The compiler tries to avoid this for high-traffic
+ // buffers except for readback staging buffers.
+ int supports_concurrent_managed_access = 0;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, ROCM_RESULT_TO_STATUS(
+ context->syms,
+ hipDeviceGetAttribute(&supports_concurrent_managed_access,
+ hipDeviceAttributeConcurrentManagedAccess,
+ context->rocm_device),
+ "hipDeviceGetAttribute"));
+
+ IREE_TRACE_ZONE_APPEND_TEXT(
+ z0, supports_concurrent_managed_access
+ ? "has CONCURRENT_MANAGED_ACCESS"
+ : "no CONCURRENT_MANAGED_ACCESS (expect slow accesses on "
+ "device-local + host-visible memory)");
+
iree_hal_rocm_allocator_t* allocator = NULL;
iree_status_t status = iree_allocator_malloc(
context->host_allocator, sizeof(*allocator), (void**)&allocator);
@@ -41,6 +65,8 @@
iree_hal_resource_initialize(&iree_hal_rocm_allocator_vtable,
&allocator->resource);
allocator->context = context;
+ allocator->supports_concurrent_managed_access =
+ supports_concurrent_managed_access != 0;
*out_allocator = (iree_hal_allocator_t*)allocator;
}
@@ -141,6 +167,9 @@
iree_hal_allocator_t* IREE_RESTRICT base_allocator,
iree_hal_buffer_params_t* IREE_RESTRICT params,
iree_device_size_t* IREE_RESTRICT allocation_size) {
+ iree_hal_rocm_allocator_t* allocator =
+ iree_hal_rocm_allocator_cast(base_allocator);
+
// All buffers can be allocated on the heap.
iree_hal_buffer_compatibility_t compatibility =
IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
@@ -151,12 +180,31 @@
// Buffers can only be used on the queue if they are device visible.
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
+ if (iree_any_bit_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+ compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+ }
if (iree_any_bit_set(params->usage,
IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
}
}
+ if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE)) {
+ compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_LOW_PERFORMANCE;
+ // If concurrent managed access is not supported then make device-local +
+ // host-visible allocations fall back to host-local + device-visible
+ // page-locked memory. This will be significantly slower for the device to
+ // access but the compiler only uses this type for readback staging buffers
+ // and it's better to function than function fast.
+ if (!allocator->supports_concurrent_managed_access) {
+ params->type &= ~(IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL |
+ IREE_HAL_MEMORY_TYPE_HOST_VISIBLE);
+ params->type |=
+ IREE_HAL_MEMORY_TYPE_HOST_LOCAL | IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE;
+ }
+ }
+
// We are now optimal.
params->type &= ~IREE_HAL_MEMORY_TYPE_OPTIMAL;
@@ -209,6 +257,15 @@
status = ROCM_RESULT_TO_STATUS(
allocator->context->syms,
hipMallocManaged(&device_ptr, allocation_size, hipMemAttachGlobal));
+ if (iree_status_is_ok(status) &&
+ allocator->supports_concurrent_managed_access) {
+ // Prefetch the buffer on the GPU device.
+ status = ROCM_RESULT_TO_STATUS(
+ allocator->context->syms,
+ hipMemPrefetchAsync(device_ptr, allocation_size,
+ allocator->context->rocm_device,
+ allocator->context->rocm_stream));
+ }
host_ptr = (void*)device_ptr;
} else {
// Device only.
diff --git a/experimental/rocm/rocm_device.c b/experimental/rocm/rocm_device.c
index 86147f5..7200668 100644
--- a/experimental/rocm/rocm_device.c
+++ b/experimental/rocm/rocm_device.c
@@ -99,10 +99,13 @@
uint8_t* buffer_ptr = (uint8_t*)device + sizeof(*device);
buffer_ptr += iree_string_view_append_to_buffer(
identifier, &device->identifier, (char*)buffer_ptr);
+ iree_arena_block_pool_initialize(/*arena_block_size=*/32 * 1024,
+ host_allocator, &device->block_pool);
device->device = rocm_device;
device->stream = stream;
device->context_wrapper.rocm_context = context;
device->context_wrapper.rocm_device = rocm_device;
+ device->context_wrapper.rocm_stream = stream;
device->context_wrapper.host_allocator = host_allocator;
device->context_wrapper.syms = syms;
// Enable tracing for the (currently only) stream - no-op if disabled.