Fix the CUDA/ROCM queue transfer compatibility check so that it no longer requires device-visible memory.
diff --git a/experimental/rocm/rocm_allocator.c b/experimental/rocm/rocm_allocator.c
index 741fe41..f0aaf5d 100644
--- a/experimental/rocm/rocm_allocator.c
+++ b/experimental/rocm/rocm_allocator.c
@@ -94,11 +94,13 @@
iree_hal_buffer_compatibility_t compatibility =
IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
+ // ROCM can copy between host and device, so transfer usage alone suffices.
+ if (iree_all_bits_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+ compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+ }
+
// Buffers can only be used on the queue if they are device visible.
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
- if (iree_all_bits_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
- compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
- }
if (iree_all_bits_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
}
diff --git a/iree/hal/cuda/cuda_allocator.c b/iree/hal/cuda/cuda_allocator.c
index af94cf2..e1cec0a 100644
--- a/iree/hal/cuda/cuda_allocator.c
+++ b/iree/hal/cuda/cuda_allocator.c
@@ -123,11 +123,13 @@
iree_hal_buffer_compatibility_t compatibility =
IREE_HAL_BUFFER_COMPATIBILITY_ALLOCATABLE;
+ // CUDA can copy between host and device, so transfer usage alone suffices.
+ if (iree_all_bits_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
+ compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
+ }
+
// Buffers can only be used on the queue if they are device visible.
if (iree_all_bits_set(params->type, IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE)) {
- if (iree_all_bits_set(params->usage, IREE_HAL_BUFFER_USAGE_TRANSFER)) {
- compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_TRANSFER;
- }
if (iree_all_bits_set(params->usage, IREE_HAL_BUFFER_USAGE_DISPATCH)) {
compatibility |= IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH;
}
@@ -220,7 +222,7 @@
iree_hal_buffer_t* buffer = NULL;
if (iree_status_is_ok(status)) {
status = iree_hal_cuda_buffer_wrap(
- (iree_hal_allocator_t*)allocator, memory_type, params->access,
+ allocator->base_device, base_allocator, memory_type, params->access,
params->usage, allocation_size,
/*byte_offset=*/0,
/*byte_length=*/allocation_size, device_ptr, host_ptr, &buffer);
diff --git a/iree/hal/cuda/dynamic_symbol_tables.h b/iree/hal/cuda/dynamic_symbol_tables.h
index e185356..aa2ee87 100644
--- a/iree/hal/cuda/dynamic_symbol_tables.h
+++ b/iree/hal/cuda/dynamic_symbol_tables.h
@@ -9,7 +9,7 @@
CU_PFN_DECL(cuDeviceGet, CUdevice*, int)
CU_PFN_DECL(cuDeviceGetCount, int*)
CU_PFN_DECL(cuDeviceGetName, char*, int, CUdevice)
-CU_PFN_DECL(cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
+CU_PFN_DECL(cuDeviceGetAttribute, int*, CUdevice_attribute, CUdevice)
CU_PFN_DECL(cuGetErrorName, CUresult, const char**)
CU_PFN_DECL(cuGetErrorString, CUresult, const char**)
CU_PFN_DECL(cuGraphAddMemcpyNode, CUgraphNode*, CUgraph, const CUgraphNode*,
@@ -48,7 +48,7 @@
CU_PFN_DECL(cuMemsetD8Async, unsigned long long, unsigned char, size_t,
CUstream)
CU_PFN_DECL(cuMemcpyAsync, CUdeviceptr, CUdeviceptr, size_t, CUstream)
-CU_PFN_DECL(cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
+CU_PFN_DECL(cuMemcpyHtoDAsync_v2, CUdeviceptr, const void*, size_t, CUstream)
CU_PFN_DECL(cuLaunchKernel, CUfunction, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, unsigned int,
- unsigned int, CUstream, void **, void **)
+ unsigned int, CUstream, void**, void**)