A couple of fixes picked up in the fusilli tests using sanitizers. (#23617)
This fixes some *SAN issues found in the fusilli tests.
---------
Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
diff --git a/runtime/src/iree/base/internal/dynamic_library.h b/runtime/src/iree/base/internal/dynamic_library.h
index f1f2370..cf10fc5 100644
--- a/runtime/src/iree/base/internal/dynamic_library.h
+++ b/runtime/src/iree/base/internal/dynamic_library.h
@@ -18,6 +18,7 @@
// Defines the behavior of the dynamic library loader.
enum iree_dynamic_library_flag_bits_t {
IREE_DYNAMIC_LIBRARY_FLAG_NONE = 0u,
+ IREE_DYNAMIC_LIBRARY_FLAG_NODELETE = 1
};
typedef uint32_t iree_dynamic_library_flags_t;
diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c
index adc0b87..7cc236a 100644
--- a/runtime/src/iree/hal/drivers/hip/hip_device.c
+++ b/runtime/src/iree/hal/drivers/hip/hip_device.c
@@ -1087,6 +1087,7 @@
typedef struct iree_hal_hip_dispatch_completed_data_t {
iree_hal_resource_t resource;
+ iree_allocator_t host_allocator;
iree_notification_t notification;
iree_slim_mutex_t completed_mutex;
bool completed;
@@ -1098,6 +1099,7 @@
(iree_hal_hip_dispatch_completed_data_t*)resource;
iree_slim_mutex_deinitialize(&data->completed_mutex);
iree_notification_deinitialize(&data->notification);
+ iree_allocator_free(data->host_allocator, data);
}
static const iree_hal_resource_vtable_t
@@ -1112,11 +1114,11 @@
IREE_RETURN_IF_ERROR(
iree_allocator_malloc(host_allocator, sizeof(**out), (void**)out));
-
iree_slim_mutex_initialize(&(*out)->completed_mutex);
iree_notification_initialize(&(*out)->notification);
iree_hal_resource_initialize(&iree_hal_hip_dispatch_completed_data_vtable_t,
&(*out)->resource);
+ (*out)->host_allocator = host_allocator;
return iree_ok_status();
}
@@ -1201,15 +1203,19 @@
wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values);
data->wait_semaphore_list.count = wait_semaphore_list.count;
data->wait_semaphore_list.semaphores = (iree_hal_semaphore_t**)callback_ptr;
- memcpy(data->wait_semaphore_list.semaphores, wait_semaphore_list.semaphores,
- wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores));
+ if (wait_semaphore_list.count > 0) {
+ memcpy(data->wait_semaphore_list.semaphores, wait_semaphore_list.semaphores,
+ wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores));
+ }
data->wait_semaphore_list.payload_values =
(uint64_t*)(callback_ptr + wait_semaphore_list.count *
sizeof(*wait_semaphore_list.semaphores));
- memcpy(
- data->wait_semaphore_list.payload_values,
- wait_semaphore_list.payload_values,
- wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values));
+ if (wait_semaphore_list.count > 0) {
+ memcpy(data->wait_semaphore_list.payload_values,
+ wait_semaphore_list.payload_values,
+ wait_semaphore_list.count *
+ sizeof(*wait_semaphore_list.payload_values));
+ }
for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) {
iree_hal_resource_retain(wait_semaphore_list.semaphores[i]);
}
@@ -1218,16 +1224,21 @@
// Copy signal list for later access.
data->signal_semaphore_list.count = signal_semaphore_list.count;
data->signal_semaphore_list.semaphores = (iree_hal_semaphore_t**)callback_ptr;
- memcpy(
- data->signal_semaphore_list.semaphores, signal_semaphore_list.semaphores,
- signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores));
+ if (signal_semaphore_list.count > 0) {
+ memcpy(data->signal_semaphore_list.semaphores,
+ signal_semaphore_list.semaphores,
+ signal_semaphore_list.count *
+ sizeof(*signal_semaphore_list.semaphores));
+ }
data->signal_semaphore_list.payload_values =
(uint64_t*)(callback_ptr + signal_semaphore_list.count *
sizeof(*signal_semaphore_list.semaphores));
- memcpy(data->signal_semaphore_list.payload_values,
- signal_semaphore_list.payload_values,
- signal_semaphore_list.count *
- sizeof(*signal_semaphore_list.payload_values));
+ if (signal_semaphore_list.count > 0) {
+ memcpy(data->signal_semaphore_list.payload_values,
+ signal_semaphore_list.payload_values,
+ signal_semaphore_list.count *
+ sizeof(*signal_semaphore_list.payload_values));
+ }
for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) {
iree_hal_resource_retain(signal_semaphore_list.semaphores[i]);
}
@@ -2608,8 +2619,10 @@
sizeof(*callback_data) +
additional_data_for_base);
callback_data->binding_table.bindings = binding_element_ptr;
- memcpy(binding_element_ptr, binding_table.bindings,
- sizeof(*binding_element_ptr) * binding_table.count);
+ if (binding_table.count > 0) {
+ memcpy(binding_element_ptr, binding_table.bindings,
+ sizeof(*binding_element_ptr) * binding_table.count);
+ }
status = iree_hal_resource_set_insert_strided(
callback_data->resource_set, binding_table.count,
diff --git a/runtime/src/iree/hal/drivers/hip/rccl_dynamic_symbols.c b/runtime/src/iree/hal/drivers/hip/rccl_dynamic_symbols.c
index 914bcc2..e83bc3d 100644
--- a/runtime/src/iree/hal/drivers/hip/rccl_dynamic_symbols.c
+++ b/runtime/src/iree/hal/drivers/hip/rccl_dynamic_symbols.c
@@ -101,7 +101,7 @@
memset(out_syms, 0, sizeof(*out_syms));
iree_status_t status = iree_dynamic_library_load_from_files(
IREE_ARRAYSIZE(iree_hal_hip_nccl_dylib_names),
- iree_hal_hip_nccl_dylib_names, IREE_DYNAMIC_LIBRARY_FLAG_NONE,
+ iree_hal_hip_nccl_dylib_names, IREE_DYNAMIC_LIBRARY_FLAG_NODELETE,
host_allocator, &out_syms->dylib);
if (iree_status_is_not_found(status)) {
iree_status_ignore(status);