[runtime][hip] Propagate errors through semaphores (#18021)
Adds proper handling of errors that occur when executing operations on
the device or when a semaphore in the wait list fails. These errors now
propagate to downstream semaphores that are in the signal list of the
operation (a condensed usage sketch appears at the end of this description).
This change also includes some refactoring of the pending queue actions:
* Make the context hold a sticky status instead of just a status code
(a sketch of this scheme follows the list).
* Remove the worker threads' "error" state. This is now handled by the
context's status.
* Remove "exit committed" thread state in favor of standard thread
joining.
* Make the "exit requested" thread state a separate boolean variable and
guard against submitting more work after an exit is requested. Also wait
on all work to complete before exiting worker threads, not just on the
currently ran actions.
* Increment the pending work items count immediately when an action is
enqueued instead of when it is scheduled on the HIP stream. This is
required to properly count outstanding work.
* Remove and merge some of the redundant state for the worker and
completion threads.
* Remove reference counting from the pending queue actions context. It
has a clear owner, which is the device.
* Rework when the threads exit: essentially only when an exit has been
requested and there are no more queued or executing actions. Errors don't
cause the threads to exit.
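
As a rough illustration of the counting and sticky-status scheme above, here
is a minimal, self-contained sketch. The names are hypothetical, and plain
pthreads plus an int error code stand in for iree_slim_mutex_t and
iree_status_t; the real logic lives in pending_queue_actions.c.

#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>

typedef struct {
  pthread_mutex_t mutex;
  int status;                 // Sticky: only the first error is recorded.
  size_t pending_work_items;  // Counted up front when an action is enqueued.
  bool exit_requested;        // Set once by the owner; no new work after this.
} context_t;

// Enqueue path: count the action's work items immediately (+1 for issuing it,
// +1 for cleaning it up) and reject new work once an exit was requested.
static bool context_enqueue(context_t* ctx) {
  pthread_mutex_lock(&ctx->mutex);
  bool accepted = !ctx->exit_requested && ctx->status == 0;
  if (accepted) ctx->pending_work_items += 2;
  pthread_mutex_unlock(&ctx->mutex);
  return accepted;  // On rejection the caller fails the signal semaphores.
}

// Issue/cleanup path: record the first error (sticky) and drain the count.
static void context_complete_work_item(context_t* ctx, int status) {
  pthread_mutex_lock(&ctx->mutex);
  if (ctx->status == 0 && status != 0) ctx->status = status;
  --ctx->pending_work_items;
  pthread_mutex_unlock(&ctx->mutex);
}

// Worker/completion threads exit only when an exit was requested and all
// counted work has drained; an error by itself never stops the threads.
static bool context_may_exit(context_t* ctx) {
  pthread_mutex_lock(&ctx->mutex);
  bool may_exit = ctx->exit_requested && ctx->pending_work_items == 0;
  pthread_mutex_unlock(&ctx->mutex);
  return may_exit;
}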
Not included here is moving the destruction and cleanup of actions from
the worker thread to the completion thread. This is an optimization and
code simplification that is now possible because we are no longer using
HIP stream callbacks, so cleanup could happen right after an action
completes. Note that actions already get destroyed on the completion
thread when they fail off the happy path.
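
For reference, here is a condensed sketch of the behavior this enables from
the HAL API side, assuming |wait_semaphore| and |signal_semaphore| were used
as the wait/signal semaphores of an already-submitted batch; the
authoritative version is the new SemaphoreSubmissionTest.PropagateFailSignal
CTS test in the diff below.

#include "iree/hal/api.h"

static void observe_propagated_failure(
    iree_hal_semaphore_t* wait_semaphore,
    iree_hal_semaphore_t* signal_semaphore) {
  // Failing the semaphore the batch waits on...
  iree_hal_semaphore_fail(
      wait_semaphore,
      iree_make_status(IREE_STATUS_CANCELLED, "upstream work failed"));

  // ...propagates to the semaphore the batch would have signaled: waiting on
  // it now returns IREE_STATUS_ABORTED instead of blocking forever.
  iree_status_t wait_status = iree_hal_semaphore_wait(
      signal_semaphore, /*value=*/1,
      iree_make_deadline(IREE_TIME_INFINITE_FUTURE));

  // Querying reports the failure sentinel value; the returned status carries
  // the (cloned) original error message.
  uint64_t value = 0;
  iree_status_t query_status =
      iree_hal_semaphore_query(signal_semaphore, &value);
  // Expect: value == IREE_HAL_SEMAPHORE_FAILURE_VALUE.

  iree_status_ignore(wait_status);
  iree_status_ignore(query_status);
}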
diff --git a/runtime/src/iree/base/internal/threading.h b/runtime/src/iree/base/internal/threading.h
index 5a4a4d9..545b51c 100644
--- a/runtime/src/iree/base/internal/threading.h
+++ b/runtime/src/iree/base/internal/threading.h
@@ -205,6 +205,9 @@
// This has no effect if the thread is not suspended.
void iree_thread_resume(iree_thread_t* thread);
+// Blocks the current thread until |thread| has finished its execution.
+void iree_thread_join(iree_thread_t* thread);
+
void iree_thread_yield(void);
#ifdef __cplusplus
diff --git a/runtime/src/iree/base/internal/threading_darwin.c b/runtime/src/iree/base/internal/threading_darwin.c
index 90708a1..8f611b8 100644
--- a/runtime/src/iree/base/internal/threading_darwin.c
+++ b/runtime/src/iree/base/internal/threading_darwin.c
@@ -256,6 +256,12 @@
IREE_TRACE_ZONE_END(z0);
}
+void iree_thread_join(iree_thread_t* thread) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ pthread_join(thread->handle, NULL);
+ IREE_TRACE_ZONE_END(z0);
+}
+
void iree_thread_yield(void) { sched_yield(); }
#endif // IREE_PLATFORM_APPLE
diff --git a/runtime/src/iree/base/internal/threading_pthreads.c b/runtime/src/iree/base/internal/threading_pthreads.c
index bfb7897..ec0f107 100644
--- a/runtime/src/iree/base/internal/threading_pthreads.c
+++ b/runtime/src/iree/base/internal/threading_pthreads.c
@@ -351,6 +351,12 @@
IREE_TRACE_ZONE_END(z0);
}
+void iree_thread_join(iree_thread_t* thread) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ pthread_join(thread->handle, NULL);
+ IREE_TRACE_ZONE_END(z0);
+}
+
void iree_thread_yield(void) { sched_yield(); }
#endif // IREE_PLATFORM_*
diff --git a/runtime/src/iree/base/internal/threading_win32.c b/runtime/src/iree/base/internal/threading_win32.c
index 160bee0..944c24a 100644
--- a/runtime/src/iree/base/internal/threading_win32.c
+++ b/runtime/src/iree/base/internal/threading_win32.c
@@ -322,6 +322,12 @@
IREE_TRACE_ZONE_END(z0);
}
+void iree_thread_join(iree_thread_t* thread) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ WaitForSingleObject(thread->handle, INFINITE);
+ IREE_TRACE_ZONE_END(z0);
+}
+
void iree_thread_yield(void) { YieldProcessor(); }
#endif // IREE_PLATFORM_WINDOWS
diff --git a/runtime/src/iree/hal/cts/cts_test_base.h b/runtime/src/iree/hal/cts/cts_test_base.h
index 919933e..9639bfc 100644
--- a/runtime/src/iree/hal/cts/cts_test_base.h
+++ b/runtime/src/iree/hal/cts/cts_test_base.h
@@ -9,6 +9,7 @@
#include <set>
#include <string>
+#include <string_view>
#include "iree/base/api.h"
#include "iree/base/string_view.h"
@@ -276,6 +277,24 @@
IREE_EXPECT_OK(iree_hal_semaphore_query(semaphore, &value));
EXPECT_EQ(expected_value, value);
}
+
+ // Checks that status |a| contains status |b|.
+ // That is, the codes of a and b are equal and the message of b is contained
+ // in the message of a.
+ void CheckStatusContains(iree_status_t a, iree_status_t b) {
+ EXPECT_EQ(iree_status_code(a), iree_status_code(b));
+ iree_allocator_t allocator = iree_allocator_system();
+ char* a_str = NULL;
+ iree_host_size_t a_str_length = 0;
+ EXPECT_TRUE(iree_status_to_string(a, &allocator, &a_str, &a_str_length));
+ char* b_str = NULL;
+ iree_host_size_t b_str_length = 0;
+ EXPECT_TRUE(iree_status_to_string(b, &allocator, &b_str, &b_str_length));
+ EXPECT_TRUE(std::string_view(a_str).find(std::string_view(b_str)) !=
+ std::string_view::npos);
+ iree_allocator_free(allocator, a_str);
+ iree_allocator_free(allocator, b_str);
+ }
};
} // namespace cts
diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h
index 40d0a0f..5f0a5ba 100644
--- a/runtime/src/iree/hal/cts/semaphore_submission_test.h
+++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h
@@ -841,6 +841,78 @@
iree_hal_command_buffer_release(command_buffer);
}
+// Submit a batch and check that the signal semaphore fails when a semaphore
+// in the wait list fails.
+TEST_F(SemaphoreSubmissionTest, PropagateFailSignal) {
+ // signal-wait relation:
+ //
+ // semaphore1
+ // ↓
+ // command_buffer
+ // ↓
+ // semaphore2
+
+ iree_hal_command_buffer_t* command_buffer = CreateEmptyCommandBuffer();
+ iree_hal_semaphore_t* semaphore1 = CreateSemaphore();
+ iree_hal_semaphore_t* semaphore2 = CreateSemaphore();
+
+ // Submit the command buffer.
+ uint64_t semaphore1_wait_value = 1;
+ iree_hal_semaphore_list_t command_buffer_wait_list = {
+ /*count=*/1, &semaphore1, &semaphore1_wait_value};
+ uint64_t semaphore2_signal_value = 1;
+ iree_hal_semaphore_list_t command_buffer_signal_list = {
+ /*count=*/1, &semaphore2, &semaphore2_signal_value};
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/command_buffer_wait_list,
+ /*signal_semaphore_list=*/command_buffer_signal_list, 1, &command_buffer,
+ /*binding_tables=*/NULL));
+
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_CANCELLED, "PropagateFailSignal test.");
+ std::thread signal_thread([&]() {
+ iree_hal_semaphore_fail(semaphore1, iree_status_clone(status));
+ });
+
+ iree_status_t wait_status =
+ iree_hal_semaphore_wait(semaphore2, semaphore2_signal_value,
+ iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
+ EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
+ uint64_t value = 1234;
+ iree_status_t query_status = iree_hal_semaphore_query(semaphore2, &value);
+ EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
+ CheckStatusContains(query_status, status);
+
+ signal_thread.join();
+ iree_hal_semaphore_release(semaphore1);
+ iree_hal_semaphore_release(semaphore2);
+ iree_hal_command_buffer_release(command_buffer);
+ iree_status_ignore(status);
+ iree_status_ignore(wait_status);
+ iree_status_ignore(query_status);
+}
+
+// Submit an invalid dispatch and check that the wait semaphore fails.
+TEST_F(SemaphoreSubmissionTest, PropagateDispatchFailure) {
+ // signal-wait relation:
+ //
+ // semaphore1
+ // ↓
+ // command_buffer
+ // ↓
+ // semaphore2
+
+ // TODO(sogartar):
+ // I tried to add a kernel that stores into a null pointer or traps
+ // (aborts), but with HIP that causes the whole executable to abort, which
+ // is not what we want.
+ // We want a failure of the kernel launch or when waiting on the stream for
+ // the kernel to complete.
+ // This needs to be a "soft" failure that results in a returned error from
+ // the underlying API call.
+}
+
} // namespace iree::hal::cts
#endif // IREE_HAL_CTS_SEMAPHORE_SUBMISSION_TEST_H_
diff --git a/runtime/src/iree/hal/cts/semaphore_test.h b/runtime/src/iree/hal/cts/semaphore_test.h
index 2f84b6f..4ec5d52 100644
--- a/runtime/src/iree/hal/cts/semaphore_test.h
+++ b/runtime/src/iree/hal/cts/semaphore_test.h
@@ -377,6 +377,140 @@
iree_hal_semaphore_release(semaphore2);
}
+// Fail a semaphore and then wait on it.
+TEST_F(SemaphoreTest, FailThenWait) {
+ iree_hal_semaphore_t* semaphore = this->CreateSemaphore();
+
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_CANCELLED, "FailThenWait test.");
+ iree_hal_semaphore_fail(semaphore, iree_status_clone(status));
+
+ iree_status_t wait_status = iree_hal_semaphore_wait(
+ semaphore, 1, iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
+ EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
+ uint64_t value = 1234;
+ iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
+ EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
+ CheckStatusContains(query_status, status);
+
+ iree_hal_semaphore_release(semaphore);
+ iree_status_ignore(status);
+ iree_status_ignore(wait_status);
+ iree_status_ignore(query_status);
+}
+
+// Wait on a semaphore that is then failed.
+TEST_F(SemaphoreTest, WaitThenFail) {
+ iree_hal_semaphore_t* semaphore = this->CreateSemaphore();
+
+ // It is possible that the order becomes fail then wait.
+ // We assume that this is less likely since starting the thread takes time.
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_CANCELLED, "WaitThenFail test.");
+ std::thread signal_thread(
+ [&]() { iree_hal_semaphore_fail(semaphore, iree_status_clone(status)); });
+
+ iree_status_t wait_status = iree_hal_semaphore_wait(
+ semaphore, 1, iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
+ EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
+ uint64_t value = 1234;
+ iree_status_t query_status = iree_hal_semaphore_query(semaphore, &value);
+ EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
+ CheckStatusContains(query_status, status);
+
+ signal_thread.join();
+ iree_hal_semaphore_release(semaphore);
+ iree_status_ignore(status);
+ iree_status_ignore(wait_status);
+ iree_status_ignore(query_status);
+}
+
+// Wait on 2 semaphores, then fail one of them.
+TEST_F(SemaphoreTest, MultiWaitThenFail) {
+ iree_hal_semaphore_t* semaphore1 = this->CreateSemaphore();
+ iree_hal_semaphore_t* semaphore2 = this->CreateSemaphore();
+
+ // It is possible that the order becomes fail then wait.
+ // We assume that this is less likely since starting the thread takes time.
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_CANCELLED, "MultiWaitThenFail test.");
+ std::thread signal_thread([&]() {
+ iree_hal_semaphore_fail(semaphore1, iree_status_clone(status));
+ });
+
+ iree_hal_semaphore_t* semaphore_array[] = {semaphore1, semaphore2};
+ uint64_t payload_array[] = {1, 1};
+ iree_hal_semaphore_list_t semaphore_list = {
+ IREE_ARRAYSIZE(semaphore_array),
+ semaphore_array,
+ payload_array,
+ };
+ iree_status_t wait_status = iree_hal_semaphore_list_wait(
+ semaphore_list, iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
+ EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
+ uint64_t value = 1234;
+ iree_status_t semaphore1_query_status =
+ iree_hal_semaphore_query(semaphore1, &value);
+ EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
+ CheckStatusContains(semaphore1_query_status, status);
+
+ // semaphore2 must not have changed.
+ uint64_t semaphore2_value = 1234;
+ IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore2, &semaphore2_value));
+ EXPECT_EQ(semaphore2_value, 0);
+
+ signal_thread.join();
+ iree_hal_semaphore_release(semaphore1);
+ iree_hal_semaphore_release(semaphore2);
+ iree_status_ignore(status);
+ iree_status_ignore(wait_status);
+ iree_status_ignore(semaphore1_query_status);
+}
+
+// Wait on 2 semaphores using iree_hal_device_wait_semaphores, then fail
+// one of them.
+TEST_F(SemaphoreTest, DeviceMultiWaitThenFail) {
+ iree_hal_semaphore_t* semaphore1 = this->CreateSemaphore();
+ iree_hal_semaphore_t* semaphore2 = this->CreateSemaphore();
+
+ // It is possible that the order becomes fail then wait.
+ // We assume that this is less likely since starting the thread takes time.
+ iree_status_t status =
+ iree_make_status(IREE_STATUS_CANCELLED, "DeviceMultiWaitThenFail test.");
+ std::thread signal_thread([&]() {
+ iree_hal_semaphore_fail(semaphore1, iree_status_clone(status));
+ });
+
+ iree_hal_semaphore_t* semaphore_array[] = {semaphore1, semaphore2};
+ uint64_t payload_array[] = {1, 1};
+ iree_hal_semaphore_list_t semaphore_list = {
+ IREE_ARRAYSIZE(semaphore_array),
+ semaphore_array,
+ payload_array,
+ };
+ iree_status_t wait_status = iree_hal_device_wait_semaphores(
+ device_, IREE_HAL_WAIT_MODE_ANY, semaphore_list,
+ iree_make_deadline(IREE_TIME_INFINITE_FUTURE));
+ EXPECT_EQ(iree_status_code(wait_status), IREE_STATUS_ABORTED);
+ uint64_t value = 1234;
+ iree_status_t semaphore1_query_status =
+ iree_hal_semaphore_query(semaphore1, &value);
+ EXPECT_EQ(value, IREE_HAL_SEMAPHORE_FAILURE_VALUE);
+ CheckStatusContains(semaphore1_query_status, status);
+
+ // semaphore2 must not have changed.
+ uint64_t semaphore2_value = 1234;
+ IREE_ASSERT_OK(iree_hal_semaphore_query(semaphore2, &semaphore2_value));
+ EXPECT_EQ(semaphore2_value, 0);
+
+ signal_thread.join();
+ iree_hal_semaphore_release(semaphore1);
+ iree_hal_semaphore_release(semaphore2);
+ iree_status_ignore(status);
+ iree_status_ignore(wait_status);
+ iree_status_ignore(semaphore1_query_status);
+}
+
} // namespace iree::hal::cts
#endif // IREE_HAL_CTS_SEMAPHORE_TEST_H_
diff --git a/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
index ae6843e..99f2bb6 100644
--- a/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/cuda/cts/CMakeLists.txt
@@ -4,6 +4,16 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+unset(FILTER_TESTS)
+string(APPEND FILTER_TESTS "SemaphoreTest.WaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.FailThenWait:")
+string(APPEND FILTER_TESTS "SemaphoreTest.MultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.DeviceMultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreSubmissionTest.PropagateFailSignal:")
+set(FILTER_TESTS_ARGS
+ "--gtest_filter=-${FILTER_TESTS}"
+)
+
iree_hal_cts_test_suite(
DRIVER_NAME
cuda
@@ -19,6 +29,7 @@
"\"PTXE\""
ARGS
"--cuda_use_streams=false"
+ ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::cuda::registration
EXCLUDED_TESTS
@@ -44,6 +55,7 @@
"\"PTXE\""
ARGS
"--cuda_use_streams=true"
+ ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::cuda::registration
EXCLUDED_TESTS
diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c
index 6939bc9..99705a6 100644
--- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c
+++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c
@@ -8,8 +8,8 @@
#include "iree/base/internal/synchronization.h"
#include "iree/base/internal/wait_handle.h"
+#include "iree/base/status.h"
#include "iree/hal/drivers/hip/dynamic_symbols.h"
-#include "iree/hal/drivers/hip/status_util.h"
#include "iree/hal/drivers/hip/timepoint_pool.h"
#include "iree/hal/utils/semaphore_base.h"
@@ -187,6 +187,13 @@
// Notify timepoints - note that this must happen outside the lock.
iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE,
status_code);
+
+ // Advance the pending queue actions if possible. This also must happen
+ // outside the lock to avoid nesting.
+ status = iree_hal_hip_pending_queue_actions_issue(
+ semaphore->pending_queue_actions);
+ iree_status_ignore(status);
+
IREE_TRACE_ZONE_END(z0);
}
@@ -316,6 +323,14 @@
return iree_ok_status();
}
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_make_status(IREE_STATUS_ABORTED);
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
// Wait until the timepoint resolves.
// If satisfied the timepoint is automatically cleaned up and we are done. If
// the deadline is reached before satisfied then we have to clean it up.
@@ -326,6 +341,17 @@
iree_hal_semaphore_cancel_timepoint(&semaphore->base, &timepoint->base);
}
iree_hal_hip_timepoint_pool_release(semaphore->timepoint_pool, 1, &timepoint);
+ if (!iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+ status = iree_make_status(IREE_STATUS_ABORTED);
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -406,6 +432,23 @@
iree_wait_set_free(wait_set);
iree_arena_deinitialize(&arena);
+ if (!iree_status_is_ok(status)) {
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+ }
+
+ for (iree_host_size_t i = 0; i < semaphore_list.count; ++i) {
+ iree_hal_hip_semaphore_t* semaphore =
+ iree_hal_hip_semaphore_cast(semaphore_list.semaphores[i]);
+ iree_slim_mutex_lock(&semaphore->mutex);
+ if (semaphore->current_value == IREE_HAL_SEMAPHORE_FAILURE_VALUE) {
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ status = iree_make_status(IREE_STATUS_ABORTED);
+ break;
+ }
+ iree_slim_mutex_unlock(&semaphore->mutex);
+ }
+
IREE_TRACE_ZONE_END(z0);
return status;
}
diff --git a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
index eec5acd..6d72330 100644
--- a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
+++ b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c
@@ -10,6 +10,7 @@
#include <stddef.h>
#include "iree/base/api.h"
+#include "iree/base/assert.h"
#include "iree/base/internal/arena.h"
#include "iree/base/internal/atomic_slist.h"
#include "iree/base/internal/atomics.h"
@@ -45,6 +46,13 @@
IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE,
} iree_hal_hip_queue_action_state_t;
+// How many work items must complete in order for an action to complete.
+// We keep track of the remaining work for an action so we don't exit worker
+// threads prematurely.
+// +1 for issuing an execution of an action.
+// +1 for cleaning up a zombie action.
+static const iree_host_size_t total_work_items_to_complete_an_action = 2;
+
// A pending queue action.
//
// Note that this struct does not have internal synchronization; it's expected
@@ -101,6 +109,9 @@
bool is_pending;
} iree_hal_hip_queue_action_t;
+static void iree_hal_hip_queue_action_fail_locked(
+ iree_hal_hip_queue_action_t* action, iree_status_t status);
+
static void iree_hal_hip_queue_action_clear_events(
iree_hal_hip_queue_action_t* action) {
for (iree_host_size_t i = 0; i < action->event_count; ++i) {
@@ -248,7 +259,7 @@
typedef struct iree_hal_hip_completion_list_node_t {
// The callback and user data for that callback. To be called
// when the associated event has completed.
- iree_status_t (*callback)(void* user_data);
+ iree_status_t (*callback)(iree_status_t, void* user_data);
void* user_data;
// The event to wait for on the completion thread.
hipEvent_t event;
@@ -316,7 +327,7 @@
iree_slim_mutex_deinitialize(&list->guard_mutex);
}
-static iree_hal_hip_queue_action_t* iree_hal_hip_atomic_slist_entry_pop_front(
+static iree_hal_hip_queue_action_t* iree_hal_hip_entry_list_node_pop_front(
iree_hal_hip_entry_list_node_t* list) {
IREE_ASSERT(list->ready_list_head);
@@ -331,17 +342,27 @@
return action;
}
+static void iree_hal_hip_entry_list_node_push_front(
+ iree_hal_hip_entry_list_node_t* entry,
+ iree_hal_hip_queue_action_t* action) {
+ IREE_ASSERT(!action->next && !action->prev);
+
+ iree_hal_hip_queue_action_t* head = entry->ready_list_head;
+ entry->ready_list_head = action;
+ if (head) {
+ action->next = head;
+ head->prev = action;
+ }
+}
+
// The ready-list processing worker's working/exiting state.
//
// States in the list has increasing priorities--meaning normally ones appearing
// earlier can overwrite ones appearing later without checking; but not the
// reverse order.
typedef enum iree_hal_hip_worker_state_e {
- IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING = 0, // Worker to main thread
- IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING = 1, // Main to worker thread
- IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED = -1, // Main to worker thread
- IREE_HAL_HIP_WORKER_STATE_EXIT_COMMITTED = -2, // Worker to main thread
- IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR = -3, // Worker to main thread
+ IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING = 0, // Worker to any thread
+ IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING = 1, // Any to worker thread
} iree_hal_hip_worker_state_t;
// The data structure needed by a ready-list processing worker thread to issue
@@ -359,29 +380,8 @@
typedef struct iree_hal_hip_working_area_t {
// Notification from the parent thread to request worker state changes.
iree_notification_t state_notification;
- // Notification to the parent thread to indicate the worker committed exiting.
- // TODO: maybe remove this. We can just wait on the worker thread to exit.
- iree_notification_t exit_notification;
- iree_hal_hip_entry_list_t ready_worklist;
- iree_atomic_int32_t worker_state; // atomic
- // TODO: use status to provide more context for the error.
- iree_atomic_intptr_t error_code; // atomic
-
- // The number of asynchronous work items that are scheduled and not
- // complete.
- // These are
- // * the number of callbacks that are scheduled on the host stream.
- // * the number of pending action cleanup.
- // We need to wait for them to finish before destroying the context.
- iree_slim_mutex_t pending_work_items_count_mutex;
- iree_notification_t pending_work_items_count_notification;
- int32_t pending_work_items_count
- IREE_GUARDED_BY(pending_work_items_count_mutex);
-
- iree_allocator_t host_allocator; // const
-
- const iree_hal_hip_dynamic_symbols_t* symbols;
- hipDevice_t device;
+ iree_hal_hip_entry_list_t ready_worklist; // atomic
+ iree_atomic_int32_t worker_state; // atomic
} iree_hal_hip_working_area_t;
// This data structure is shared by the parent thread. It is responsible
@@ -390,29 +390,11 @@
// This replaces the use of hipLaunchHostFunc, which causes the stream to block
// and wait for the CPU work to complete. It also picks up completed
// events with significantly less latency than hipLaunchHostFunc.
-
typedef struct iree_hal_hip_completion_area_t {
// Notification from the parent thread to request completion state changes.
iree_notification_t state_notification;
- // Notification to the parent thread to indicate the worker committed exiting.
- iree_notification_t exit_notification;
- iree_hal_hip_completion_list_t completion_list;
- iree_atomic_int32_t worker_state; // atomic
-
- iree_atomic_intptr_t error_code; // atomic
-
- // The number of asynchronous completions items that are scheduled and not
- // yet waited on.
- // We need to wait for them to finish before destroying the context.
- iree_slim_mutex_t pending_completion_count_mutex;
- iree_notification_t pending_completion_count_notification;
- int32_t pending_completion_count
- IREE_GUARDED_BY(pending_completion_count_mutex);
-
- iree_allocator_t host_allocator;
-
- const iree_hal_hip_dynamic_symbols_t* symbols;
- hipDevice_t device;
+ iree_hal_hip_completion_list_t completion_list; // atomic
+ iree_atomic_int32_t worker_state; // atomic
} iree_hal_hip_completion_area_t;
static void iree_hal_hip_working_area_initialize(
@@ -420,31 +402,19 @@
const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_working_area_t* working_area) {
iree_notification_initialize(&working_area->state_notification);
- iree_notification_initialize(&working_area->exit_notification);
- iree_hal_hip_ready_action_list_initialize(&working_area->ready_worklist);
+ iree_hal_hip_ready_action_list_deinitialize(&working_area->ready_worklist,
+ host_allocator);
iree_atomic_store_int32(&working_area->worker_state,
IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
iree_memory_order_release);
- iree_atomic_store_int32(&working_area->error_code, IREE_STATUS_OK,
- iree_memory_order_release);
- iree_slim_mutex_initialize(&working_area->pending_work_items_count_mutex);
- iree_notification_initialize(
- &working_area->pending_work_items_count_notification);
- working_area->pending_work_items_count = 0;
- working_area->host_allocator = host_allocator;
- working_area->symbols = symbols;
- working_area->device = device;
}
static void iree_hal_hip_working_area_deinitialize(
- iree_hal_hip_working_area_t* working_area) {
+ iree_hal_hip_working_area_t* working_area,
+ iree_allocator_t host_allocator) {
iree_hal_hip_ready_action_list_deinitialize(&working_area->ready_worklist,
- working_area->host_allocator);
- iree_notification_deinitialize(&working_area->exit_notification);
+ host_allocator);
iree_notification_deinitialize(&working_area->state_notification);
- iree_slim_mutex_deinitialize(&working_area->pending_work_items_count_mutex);
- iree_notification_deinitialize(
- &working_area->pending_work_items_count_notification);
}
static void iree_hal_hip_completion_area_initialize(
@@ -452,41 +422,27 @@
const iree_hal_hip_dynamic_symbols_t* symbols,
iree_hal_hip_completion_area_t* completion_area) {
iree_notification_initialize(&completion_area->state_notification);
- iree_notification_initialize(&completion_area->exit_notification);
iree_hal_hip_completion_list_initialize(&completion_area->completion_list);
iree_atomic_store_int32(&completion_area->worker_state,
IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
iree_memory_order_release);
- iree_atomic_store_int32(&completion_area->error_code, IREE_STATUS_OK,
- iree_memory_order_release);
- iree_slim_mutex_initialize(&completion_area->pending_completion_count_mutex);
- iree_notification_initialize(
- &completion_area->pending_completion_count_notification);
- completion_area->pending_completion_count = 0;
- completion_area->host_allocator = host_allocator;
- completion_area->symbols = symbols;
- completion_area->device = device;
}
static void iree_hal_hip_completion_area_deinitialize(
- iree_hal_hip_completion_area_t* completion_area) {
+ iree_hal_hip_completion_area_t* completion_area,
+ const iree_hal_hip_dynamic_symbols_t* symbols,
+ iree_allocator_t host_allocator) {
iree_hal_hip_completion_list_deinitialize(&completion_area->completion_list,
- completion_area->symbols,
- completion_area->host_allocator);
- iree_notification_deinitialize(&completion_area->exit_notification);
+ symbols, host_allocator);
iree_notification_deinitialize(&completion_area->state_notification);
- iree_slim_mutex_deinitialize(
- &completion_area->pending_completion_count_mutex);
- iree_notification_deinitialize(
- &completion_area->pending_completion_count_notification);
}
// The main function for the ready-list processing worker thread.
static int iree_hal_hip_worker_execute(
- iree_hal_hip_working_area_t* working_area);
+ iree_hal_hip_pending_queue_actions_t* actions);
static int iree_hal_hip_completion_execute(
- iree_hal_hip_completion_area_t* working_area);
+ iree_hal_hip_pending_queue_actions_t* actions);
//===----------------------------------------------------------------------===//
// Pending queue actions
@@ -505,7 +461,7 @@
// The symbols used to create and destroy hipEvent_t objects.
const iree_hal_hip_dynamic_symbols_t* symbols;
- // Non-recursive mutex guarding access to the action list.
+ // Non-recursive mutex guarding access.
iree_slim_mutex_t action_mutex;
// The double-linked list of pending actions.
@@ -525,12 +481,28 @@
// Completion thread's working area.
iree_hal_hip_completion_area_t completion_area;
+ // A sticky error status for the queue, guarded by the action mutex.
+ // Once set with an error, all subsequent actions that have not completed
+ // will fail with this error.
+ iree_status_t status IREE_GUARDED_BY(action_mutex);
+
// The associated hip device.
hipDevice_t device;
-};
-static const iree_hal_resource_vtable_t
- iree_hal_hip_pending_queue_actions_vtable;
+ // The number of asynchronous work items that are scheduled and not
+ // complete.
+ // These are
+ // * the number of actions issued.
+ // * the number of pending action cleanups.
+ // The work and completion threads can exit only when there are no more
+ // pending work items.
+ iree_host_size_t pending_work_items_count IREE_GUARDED_BY(action_mutex);
+
+ // The owner can request an exit of the worker threads.
+ // Once all pending enqueued work is complete the threads will exit.
+ // No actions can be enqueued after requesting an exit.
+ bool exit_requested IREE_GUARDED_BY(action_mutex);
+};
iree_status_t iree_hal_hip_pending_queue_actions_create(
const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device,
@@ -545,8 +517,6 @@
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_allocator_malloc(host_allocator, sizeof(*actions),
(void**)&actions));
- iree_hal_resource_initialize(&iree_hal_hip_pending_queue_actions_vtable,
- &actions->resource);
actions->host_allocator = host_allocator;
actions->block_pool = block_pool;
actions->symbols = symbols;
@@ -570,15 +540,15 @@
params.name = IREE_SV("deferque_worker");
params.create_suspended = false;
iree_status_t status = iree_thread_create(
- (iree_thread_entry_t)iree_hal_hip_worker_execute, working_area, params,
+ (iree_thread_entry_t)iree_hal_hip_worker_execute, actions, params,
actions->host_allocator, &actions->worker_thread);
params.name = IREE_SV("done_worker");
params.create_suspended = false;
if (iree_status_is_ok(status)) {
status = iree_thread_create(
- (iree_thread_entry_t)iree_hal_hip_completion_execute, completion_area,
- params, actions->host_allocator, &actions->completion_thread);
+ (iree_thread_entry_t)iree_hal_hip_completion_execute, actions, params,
+ actions->host_allocator, &actions->completion_thread);
}
if (iree_status_is_ok(status)) {
@@ -596,58 +566,62 @@
return (iree_hal_hip_pending_queue_actions_t*)base_value;
}
-static bool iree_hal_hip_worker_committed_exiting(
- iree_hal_hip_working_area_t* working_area);
-static bool iree_hal_hip_completion_committed_exiting(
- iree_hal_hip_completion_area_t* working_area);
+static void iree_hal_hip_pending_queue_actions_notify_worker_thread(
+ iree_hal_hip_working_area_t* working_area) {
+ iree_atomic_store_int32(&working_area->worker_state,
+ IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
+ iree_memory_order_release);
+ iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+}
+
+static void iree_hal_hip_pending_queue_actions_notify_completion_thread(
+ iree_hal_hip_completion_area_t* completion_area) {
+ iree_atomic_store_int32(&completion_area->worker_state,
+ IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
+ iree_memory_order_release);
+ iree_notification_post(&completion_area->state_notification,
+ IREE_ALL_WAITERS);
+}
+
+// Notifies worker and completion threads that there is work available to
+// process.
+static void iree_hal_hip_pending_queue_actions_notify_threads(
+ iree_hal_hip_pending_queue_actions_t* actions) {
+ iree_hal_hip_pending_queue_actions_notify_worker_thread(
+ &actions->working_area);
+ iree_hal_hip_pending_queue_actions_notify_completion_thread(
+ &actions->completion_area);
+}
+
+static void iree_hal_hip_pending_queue_actions_request_exit(
+ iree_hal_hip_pending_queue_actions_t* actions) {
+ iree_slim_mutex_lock(&actions->action_mutex);
+ actions->exit_requested = true;
+ iree_slim_mutex_unlock(&actions->action_mutex);
+
+ iree_hal_hip_pending_queue_actions_notify_threads(actions);
+}
void iree_hal_hip_pending_queue_actions_destroy(
iree_hal_resource_t* base_actions) {
+ IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_hip_pending_queue_actions_t* actions =
iree_hal_hip_pending_queue_actions_cast(base_actions);
iree_allocator_t host_allocator = actions->host_allocator;
- iree_hal_hip_working_area_t* working_area = &actions->working_area;
- iree_hal_hip_completion_area_t* completion_area = &actions->completion_area;
- IREE_TRACE_ZONE_BEGIN(z0);
- // Request the worker to exit.
- iree_hal_hip_worker_state_t prev_state =
- (iree_hal_hip_worker_state_t)iree_atomic_exchange_int32(
- &working_area->worker_state, IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED,
- iree_memory_order_acq_rel);
- iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+ // Request the workers to exit.
+ iree_hal_hip_pending_queue_actions_request_exit(actions);
- // Check potential exit states from the worker.
- if (prev_state != IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR) {
- // Wait until the worker acknowledged exiting.
- iree_notification_await(
- &working_area->exit_notification,
- (iree_condition_fn_t)iree_hal_hip_worker_committed_exiting,
- working_area, iree_infinite_timeout());
- }
-
- // Now we can delete worker related resources.
+ iree_thread_join(actions->worker_thread);
iree_thread_release(actions->worker_thread);
- iree_hal_hip_working_area_deinitialize(working_area);
- // Request the completion thread to exit.
- prev_state = (iree_hal_hip_worker_state_t)iree_atomic_exchange_int32(
- &completion_area->worker_state, IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED,
- iree_memory_order_acq_rel);
- iree_notification_post(&completion_area->state_notification,
- IREE_ALL_WAITERS);
-
- // Check potential exit states from the completion thread.
- if (prev_state != IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR) {
- // Wait until the worker acknowledged exiting.
- iree_notification_await(
- &completion_area->exit_notification,
- (iree_condition_fn_t)iree_hal_hip_completion_committed_exiting,
- completion_area, iree_infinite_timeout());
- }
-
+ iree_thread_join(actions->completion_thread);
iree_thread_release(actions->completion_thread);
- iree_hal_hip_completion_area_deinitialize(completion_area);
+
+ iree_hal_hip_working_area_deinitialize(&actions->working_area,
+ actions->host_allocator);
+ iree_hal_hip_completion_area_deinitialize(
+ &actions->completion_area, actions->symbols, actions->host_allocator);
iree_slim_mutex_deinitialize(&actions->action_mutex);
iree_hal_hip_queue_action_list_destroy(actions->action_list.head);
@@ -656,11 +630,6 @@
IREE_TRACE_ZONE_END(z0);
}
-static const iree_hal_resource_vtable_t
- iree_hal_hip_pending_queue_actions_vtable = {
- .destroy = iree_hal_hip_pending_queue_actions_destroy,
-};
-
static void iree_hal_hip_queue_action_destroy(
iree_hal_hip_queue_action_t* action) {
IREE_TRACE_ZONE_BEGIN(z0);
@@ -685,30 +654,10 @@
}
static void iree_hal_hip_queue_decrement_work_items_count(
- iree_hal_hip_working_area_t* working_area) {
- iree_slim_mutex_lock(&working_area->pending_work_items_count_mutex);
- --working_area->pending_work_items_count;
- if (working_area->pending_work_items_count == 0) {
- // Notify inside the lock to make sure that we are done touching anything
- // since the context may get destroyed in the meantime.
- iree_notification_post(&working_area->pending_work_items_count_notification,
- IREE_ALL_WAITERS);
- }
- iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
-}
-
-static void iree_hal_hip_queue_decrement_completion_count(
- iree_hal_hip_completion_area_t* completion_area) {
- iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
- --completion_area->pending_completion_count;
- if (completion_area->pending_completion_count == 0) {
- // Notify inside the lock to make sure that we are done touching anything
- // since the context may get destroyed in the meantime.
- iree_notification_post(
- &completion_area->pending_completion_count_notification,
- IREE_ALL_WAITERS);
- }
- iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
+ iree_hal_hip_pending_queue_actions_t* actions) {
+ iree_slim_mutex_lock(&actions->action_mutex);
+ --actions->pending_work_items_count;
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution(
@@ -854,12 +803,21 @@
}
if (iree_status_is_ok(status)) {
- // Retain the owning queue to make sure the action outlives it.
- iree_hal_resource_retain(actions);
-
// Now everything is okay and we can enqueue the action.
iree_slim_mutex_lock(&actions->action_mutex);
- iree_hal_hip_queue_action_list_push_back(&actions->action_list, action);
+ if (actions->exit_requested) {
+ status = iree_make_status(
+ IREE_STATUS_ABORTED,
+ "can not issue more executions, exit already requested");
+ iree_hal_hip_queue_action_fail_locked(action, status);
+ } else {
+ iree_hal_hip_queue_action_list_push_back(&actions->action_list, action);
+ // One work item is the callback that makes it across from the
+ // completion thread.
+ // The other is the cleanup of the action.
+ actions->pending_work_items_count +=
+ total_work_items_to_complete_an_action;
+ }
iree_slim_mutex_unlock(&actions->action_mutex);
} else {
iree_hal_resource_set_free(action->resource_set);
@@ -870,52 +828,127 @@
return status;
}
-static void iree_hal_hip_post_error_to_worker_state(
- iree_hal_hip_working_area_t* working_area, iree_status_code_t code) {
- // Write error code, but don't overwrite existing error codes.
- intptr_t prev_error_code = IREE_STATUS_OK;
- iree_atomic_compare_exchange_strong_int32(
- &working_area->error_code, /*expected=*/&prev_error_code,
- /*desired=*/code,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
-
- // This state has the highest priority so just overwrite.
- iree_atomic_store_int32(&working_area->worker_state,
- IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR,
- iree_memory_order_release);
- iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
+// Does not consume |status|.
+static void iree_hal_hip_pending_queue_actions_fail_status_locked(
+ iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) {
+ if (iree_status_is_ok(actions->status) && status != actions->status) {
+ actions->status = iree_status_clone(status);
+ }
}
-static void iree_hal_hip_post_error_to_completion_state(
- iree_hal_hip_completion_area_t* completion_area, iree_status_code_t code) {
- // Write error code, but don't overwrite existing error codes.
- intptr_t prev_error_code = IREE_STATUS_OK;
- iree_atomic_compare_exchange_strong_int32(
- &completion_area->error_code, /*expected=*/&prev_error_code,
- /*desired=*/code,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
+// Fails and destroys the action.
+// Does not consume |status|.
+// Decrements pending work items count accordingly based on the unfulfilled
+// number of work items.
+static void iree_hal_hip_queue_action_fail_locked(
+ iree_hal_hip_queue_action_t* action, iree_status_t status) {
+ IREE_ASSERT(!iree_status_is_ok(status));
+ iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
- // This state has the highest priority so just overwrite.
- iree_atomic_store_int32(&completion_area->worker_state,
- IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR,
- iree_memory_order_release);
- iree_notification_post(&completion_area->state_notification,
- IREE_ALL_WAITERS);
+ // Unlock since failing the semaphore will use |actions|.
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ iree_hal_semaphore_list_fail(action->signal_semaphore_list,
+ iree_status_clone(status));
+
+ iree_host_size_t work_items_remaining = 0;
+ switch (action->state) {
+ case IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE:
+ work_items_remaining = total_work_items_to_complete_an_action;
+ break;
+ case IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE:
+ work_items_remaining = 1;
+ break;
+ default:
+ // Someone forgot to handle all enum values?
+ iree_abort();
+ }
+ iree_slim_mutex_lock(&actions->action_mutex);
+ action->owning_actions->pending_work_items_count -= work_items_remaining;
+ iree_hal_hip_pending_queue_actions_fail_status_locked(actions, status);
+ iree_hal_hip_queue_action_destroy(action);
+}
+
+// Fails and destroys the action.
+// Does not consume |status|.
+static void iree_hal_hip_queue_action_fail(iree_hal_hip_queue_action_t* action,
+ iree_status_t status) {
+ iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_hip_queue_action_fail_locked(action, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_hip_queue_action_raw_list_fail_locked(
+ iree_hal_hip_queue_action_t* head_action, iree_status_t status) {
+ while (head_action) {
+ iree_hal_hip_queue_action_t* next_action = head_action->next;
+ iree_hal_hip_queue_action_fail_locked(head_action, status);
+ head_action = next_action;
+ }
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_hip_ready_action_list_fail_locked(
+ iree_hal_hip_entry_list_t* list, iree_status_t status) {
+ iree_hal_hip_entry_list_node_t* entry = iree_hal_hip_entry_list_pop(list);
+ while (entry) {
+ iree_hal_hip_queue_action_raw_list_fail_locked(entry->ready_list_head,
+ status);
+ entry = iree_hal_hip_entry_list_pop(list);
+ }
+}
+
+// Fails and destroys all actions.
+// Does not consume |status|.
+static void iree_hal_hip_queue_action_list_fail_locked(
+ iree_hal_hip_queue_action_list_t* list, iree_status_t status) {
+ iree_hal_hip_queue_action_t* action;
+ if (iree_hal_hip_queue_action_list_is_empty(list)) {
+ return;
+ }
+ do {
+ action = iree_hal_hip_queue_action_list_pop_front(list);
+ iree_hal_hip_queue_action_fail_locked(action, status);
+ } while (action);
+}
+
+// Fails and destroys all actions and sets status of |actions|.
+// Does not consume |status|.
+// Assumes the caller is holding the action_mutex.
+static void iree_hal_hip_pending_queue_actions_fail_locked(
+ iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) {
+ iree_hal_hip_pending_queue_actions_fail_status_locked(actions, status);
+ iree_hal_hip_queue_action_list_fail_locked(&actions->action_list, status);
+ iree_hal_hip_ready_action_list_fail_locked(
+ &actions->working_area.ready_worklist, status);
+}
+
+// Does not consume |status|.
+static void iree_hal_hip_pending_queue_actions_fail(
+ iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) {
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_hip_pending_queue_actions_fail_locked(actions, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
// Releases resources after action completion on the GPU and advances timeline
// and pending actions queue.
static iree_status_t iree_hal_hip_execution_device_signal_host_callback(
- void* user_data) {
+ iree_status_t status, void* user_data) {
IREE_TRACE_ZONE_BEGIN(z0);
iree_hal_hip_queue_action_t* action = (iree_hal_hip_queue_action_t*)user_data;
IREE_ASSERT_EQ(action->kind, IREE_HAL_HIP_QUEUE_ACTION_TYPE_EXECUTION);
IREE_ASSERT_EQ(action->state, IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE);
iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
- iree_status_t status;
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_hip_queue_action_fail(action, status);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
// Need to signal the list before zombifying the action, because in the mean
// time someone else may issue the pending queue actions.
@@ -923,9 +956,9 @@
// may run while we are still using the semaphore list, causing a crash.
status = iree_hal_semaphore_list_signal(action->signal_semaphore_list);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- IREE_ASSERT(false && "cannot signal semaphores in host callback");
- iree_hal_hip_post_error_to_worker_state(&actions->working_area,
- iree_status_code(status));
+ iree_hal_hip_queue_action_fail(action, status);
+ IREE_TRACE_ZONE_END(z0);
+ return status;
}
// Flip the action state to zombie and enqueue it again so that we can let
@@ -935,18 +968,12 @@
action->state = IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE;
iree_slim_mutex_lock(&actions->action_mutex);
iree_hal_hip_queue_action_list_push_back(&actions->action_list, action);
+ // The callback (work item) is complete.
+ --actions->pending_work_items_count;
iree_slim_mutex_unlock(&actions->action_mutex);
// We need to trigger execution of this action again, so it gets cleaned up.
status = iree_hal_hip_pending_queue_actions_issue(actions);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- IREE_ASSERT(false && "cannot issue action for cleanup in host callback");
- iree_hal_hip_post_error_to_worker_state(&actions->working_area,
- iree_status_code(status));
- }
-
- // The callback (work item) is complete.
- iree_hal_hip_queue_decrement_work_items_count(&actions->working_area);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -957,8 +984,8 @@
iree_hal_hip_queue_action_t* action) {
IREE_ASSERT_EQ(action->kind, IREE_HAL_HIP_QUEUE_ACTION_TYPE_EXECUTION);
IREE_ASSERT_EQ(action->is_pending, false);
- const iree_hal_hip_dynamic_symbols_t* symbols =
- action->owning_actions->symbols;
+ iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
+ const iree_hal_hip_dynamic_symbols_t* symbols = actions->symbols;
IREE_TRACE_ZONE_BEGIN(z0);
// No need to lock given that this action is already detched from the pending
@@ -1063,19 +1090,11 @@
z0, symbols,
hipEventRecord(completion_event, action->dispatch_hip_stream),
"hipEventRecord");
- iree_slim_mutex_lock(
- &action->owning_actions->working_area.pending_work_items_count_mutex);
- // One work item is the callback that makes it across from the
- // completion thread.
- // The other is the cleanup of the action.
- action->owning_actions->working_area.pending_work_items_count += 2;
- iree_slim_mutex_unlock(
- &action->owning_actions->working_area.pending_work_items_count_mutex);
iree_hal_hip_completion_list_node_t* entry = NULL;
// TODO: avoid host allocator malloc; use some pool for the allocation.
- iree_status_t status = iree_allocator_malloc(
- action->owning_actions->host_allocator, sizeof(*entry), (void**)&entry);
+ iree_status_t status = iree_allocator_malloc(actions->host_allocator,
+ sizeof(*entry), (void**)&entry);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
IREE_TRACE_ZONE_END(z0);
@@ -1088,40 +1107,11 @@
entry->created_event = created_event;
entry->callback = iree_hal_hip_execution_device_signal_host_callback;
entry->user_data = action;
+ iree_hal_hip_completion_list_push(&actions->completion_area.completion_list,
+ entry);
- iree_hal_hip_completion_list_push(
- &action->owning_actions->completion_area.completion_list, entry);
-
- iree_slim_mutex_lock(
- &action->owning_actions->completion_area.pending_completion_count_mutex);
-
- action->owning_actions->completion_area.pending_completion_count += 1;
-
- iree_slim_mutex_unlock(
- &action->owning_actions->completion_area.pending_completion_count_mutex);
-
- // We can only overwrite the worker state if the previous state is idle
- // waiting; we cannot overwrite exit related states. so we need to perform
- // atomic compare and exchange here.
- iree_hal_hip_worker_state_t prev_state =
- IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING;
- iree_atomic_compare_exchange_strong_int32(
- &action->owning_actions->completion_area.worker_state,
- /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(
- &action->owning_actions->completion_area.state_notification,
- IREE_ALL_WAITERS);
-
- // Handle potential error cases from the worker thread.
- if (prev_state == IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR) {
- iree_status_code_t code = iree_atomic_load_int32(
- &action->owning_actions->completion_area.error_code,
- iree_memory_order_acquire);
- IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_status_from_code(code));
- }
+ iree_hal_hip_pending_queue_actions_notify_completion_thread(
+ &actions->completion_area);
IREE_TRACE_ZONE_END(z0);
return iree_ok_status();
@@ -1137,7 +1127,7 @@
// Now we fully executed and cleaned up this action. Decrease the work items
// counter.
- iree_hal_hip_queue_decrement_work_items_count(&actions->working_area);
+ iree_hal_hip_queue_decrement_work_items_count(actions);
IREE_TRACE_ZONE_END(z0);
}
@@ -1157,8 +1147,16 @@
return iree_ok_status();
}
- // Scan through the list and categorize actions into pending and ready lists.
+ if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) {
+ iree_hal_hip_queue_action_list_fail_locked(&actions->action_list,
+ actions->status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+ }
+
iree_status_t status = iree_ok_status();
+ // Scan through the list and categorize actions into pending and ready lists.
while (!iree_hal_hip_queue_action_list_is_empty(&actions->action_list)) {
iree_hal_hip_queue_action_t* action =
iree_hal_hip_queue_action_list_pop_front(&actions->action_list);
@@ -1167,6 +1165,7 @@
uint64_t* values = action->wait_semaphore_list.payload_values;
action->is_pending = false;
+ bool action_failed = false;
// Cleanup actions are immediately ready to release. Otherwise, look at all
// wait semaphores to make sure that they are either already ready or we can
@@ -1176,8 +1175,14 @@
// If this semaphore has already signaled past the desired value, we can
// just ignore it.
uint64_t value = 0;
- status = iree_hal_semaphore_query(semaphores[i], &value);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+ iree_status_t semaphore_status =
+ iree_hal_semaphore_query(semaphores[i], &value);
+ if (IREE_UNLIKELY(!iree_status_is_ok(semaphore_status))) {
+ iree_hal_hip_queue_action_fail_locked(action, semaphore_status);
+ iree_status_ignore(semaphore_status);
+ action_failed = true;
+ break;
+ }
if (value >= values[i]) {
// No need to wait on this timepoint as it has already occurred and
// we can remove it from the wait list.
@@ -1206,6 +1211,10 @@
IREE_STATUS_RESOURCE_EXHAUSTED,
"exceeded maximum queue action wait event limit");
iree_hal_hip_event_release(wait_event);
+ if (iree_status_is_ok(actions->status)) {
+ actions->status = status;
+ }
+ iree_hal_hip_queue_action_fail_locked(action, status);
break;
}
action->events[action->event_count++] = wait_event;
@@ -1217,11 +1226,16 @@
}
}
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- // Some error happened during processing the current action.
- // Put it back to the pending list so we don't leak.
- action->is_pending = true;
- iree_hal_hip_queue_action_list_push_back(&pending_list, action);
+ if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) {
+ if (!action_failed) {
+ iree_hal_hip_queue_action_fail_locked(action, actions->status);
+ }
+ iree_hal_hip_queue_action_list_fail_locked(&actions->action_list,
+ actions->status);
+ break;
+ }
+
+ if (action_failed) {
break;
}
@@ -1251,9 +1265,10 @@
}
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- // Release all actions in the ready list to avoid leaking.
- iree_hal_hip_queue_action_list_destroy(ready_list.head);
- iree_allocator_free(actions->host_allocator, entry);
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_hal_hip_pending_queue_actions_fail_status_locked(actions, status);
+ iree_hal_hip_queue_action_list_fail_locked(&ready_list, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
IREE_TRACE_ZONE_END(z0);
return status;
}
@@ -1263,25 +1278,8 @@
entry->ready_list_head = ready_list.head;
iree_hal_hip_entry_list_push(&actions->working_area.ready_worklist, entry);
- // We can only overwrite the worker state if the previous state is idle
- // waiting; we cannot overwrite exit related states. so we need to perform
- // atomic compare and exchange here.
- iree_hal_hip_worker_state_t prev_state =
- IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING;
- iree_atomic_compare_exchange_strong_int32(
- &actions->working_area.worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(&actions->working_area.state_notification,
- IREE_ALL_WAITERS);
-
- // Handle potential error cases from the worker thread.
- if (prev_state == IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR) {
- iree_status_code_t code = iree_atomic_load_int32(
- &actions->working_area.error_code, iree_memory_order_acquire);
- status = iree_status_from_code(code);
- }
+ iree_hal_hip_pending_queue_actions_notify_worker_thread(
+ &actions->working_area);
IREE_TRACE_ZONE_END(z0);
return status;
@@ -1291,246 +1289,152 @@
// Worker routines
//===----------------------------------------------------------------------===//
-static bool iree_hal_hip_worker_has_incoming_request_or_error(
+static bool iree_hal_hip_worker_has_incoming_request(
iree_hal_hip_working_area_t* working_area) {
iree_hal_hip_worker_state_t value = iree_atomic_load_int32(
&working_area->worker_state, iree_memory_order_acquire);
- return value == IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING ||
- value == IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED ||
- value == IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR;
+ return value == IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING;
}
-static bool iree_hal_hip_completion_has_incoming_request_or_error(
+static bool iree_hal_hip_completion_has_incoming_request(
iree_hal_hip_completion_area_t* completion_area) {
iree_hal_hip_worker_state_t value = iree_atomic_load_int32(
&completion_area->worker_state, iree_memory_order_acquire);
- return value == IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING ||
- value == IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED ||
- value == IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR;
-}
-
-static bool iree_hal_hip_worker_committed_exiting(
- iree_hal_hip_working_area_t* working_area) {
- return iree_atomic_load_int32(&working_area->worker_state,
- iree_memory_order_acquire) ==
- IREE_HAL_HIP_WORKER_STATE_EXIT_COMMITTED;
-}
-
-static bool iree_hal_hip_completion_committed_exiting(
- iree_hal_hip_completion_area_t* working_area) {
- return iree_atomic_load_int32(&working_area->worker_state,
- iree_memory_order_acquire) ==
- IREE_HAL_HIP_WORKER_STATE_EXIT_COMMITTED;
+ return value == IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING;
}
// Processes all ready actions in the given |worklist|.
-static iree_status_t iree_hal_hip_worker_process_ready_list(
- iree_allocator_t host_allocator, iree_hal_hip_entry_list_t* worklist) {
+static void iree_hal_hip_worker_process_ready_list(
+ iree_hal_hip_pending_queue_actions_t* actions) {
IREE_TRACE_ZONE_BEGIN(z0);
- iree_status_t status = iree_ok_status();
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_status_t status = actions->status;
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_hip_ready_action_list_fail_locked(
+ &actions->working_area.ready_worklist, status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ iree_status_ignore(status);
+ return;
+ }
+ iree_slim_mutex_unlock(&actions->action_mutex);
+
while (true) {
iree_hal_hip_entry_list_node_t* entry =
- iree_hal_hip_entry_list_pop(worklist);
+ iree_hal_hip_entry_list_pop(&actions->working_area.ready_worklist);
if (!entry) break;
// Process the current batch of ready actions.
while (entry->ready_list_head) {
iree_hal_hip_queue_action_t* action =
- iree_hal_hip_atomic_slist_entry_pop_front(entry);
+ iree_hal_hip_entry_list_node_pop_front(entry);
switch (action->state) {
case IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE:
status = iree_hal_hip_pending_queue_actions_issue_execution(action);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_hip_queue_action_destroy(action);
- }
break;
case IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE:
iree_hal_hip_pending_queue_actions_issue_cleanup(action);
break;
}
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) break;
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_hip_entry_list_node_push_front(entry, action);
+ iree_hal_hip_entry_list_push(&actions->working_area.ready_worklist,
+ entry);
+ break;
+ }
}
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- // Let common destruction path take care of destroying the worklist.
- // When we know all host stream callbacks are done and not touching
- // anything.
- iree_hal_hip_entry_list_push(worklist, entry);
break;
}
- iree_allocator_free(host_allocator, entry);
+ iree_allocator_free(actions->host_allocator, entry);
+ }
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_hip_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
}
IREE_TRACE_ZONE_END(z0);
- return status;
}
-static bool iree_hal_hip_worker_has_no_pending_work_items(
- iree_hal_hip_working_area_t* working_area) {
- iree_slim_mutex_lock(&working_area->pending_work_items_count_mutex);
- bool result = (working_area->pending_work_items_count == 0);
- iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
- return result;
-}
+static void iree_hal_hip_worker_process_completion(
+ iree_hal_hip_pending_queue_actions_t* actions) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_hal_hip_completion_list_t* worklist =
+ &actions->completion_area.completion_list;
+ iree_slim_mutex_lock(&actions->action_mutex);
+ iree_status_t status = iree_status_clone(actions->status);
+ iree_slim_mutex_unlock(&actions->action_mutex);
-static bool iree_hal_hip_completion_has_no_pending_completion_items(
- iree_hal_hip_completion_area_t* completion_area) {
- iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
- bool result = (completion_area->pending_completion_count == 0);
- iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
- return result;
-}
-
-// Wait for all work items to finish.
-static void iree_hal_hip_worker_wait_pending_work_items(
- iree_hal_hip_working_area_t* working_area) {
- iree_notification_await(
- &working_area->pending_work_items_count_notification,
- (iree_condition_fn_t)iree_hal_hip_worker_has_no_pending_work_items,
- working_area, iree_infinite_timeout());
- // Lock then unlock to make sure that all callbacks are really done.
- // Not even touching the notification.
- iree_slim_mutex_lock(&working_area->pending_work_items_count_mutex);
- iree_slim_mutex_unlock(&working_area->pending_work_items_count_mutex);
-}
-
-// Wait for all work items to finish.
-static void iree_hal_hip_completion_wait_pending_completion_items(
- iree_hal_hip_completion_area_t* completion_area) {
- iree_notification_await(
- &completion_area->pending_completion_count_notification,
- (iree_condition_fn_t)
- iree_hal_hip_completion_has_no_pending_completion_items,
- completion_area, iree_infinite_timeout());
- // Lock then unlock to make sure that all callbacks are really done.
- // Not even touching the notification.
- iree_slim_mutex_lock(&completion_area->pending_completion_count_mutex);
- iree_slim_mutex_unlock(&completion_area->pending_completion_count_mutex);
-}
-
-static iree_status_t iree_hal_hip_worker_process_completion(
- iree_hal_hip_completion_list_t* worklist,
- iree_hal_hip_completion_area_t* completion_area) {
- iree_status_t status = iree_ok_status();
while (true) {
iree_hal_hip_completion_list_node_t* entry =
iree_hal_hip_completion_list_pop(worklist);
if (!entry) break;
- IREE_TRACE_ZONE_BEGIN_NAMED(z1, "hipEventSynchronize");
- hipError_t result =
- completion_area->symbols->hipEventSynchronize(entry->event);
- IREE_TRACE_ZONE_END(z1);
- if (IREE_UNLIKELY(result != hipSuccess)) {
- // Let common destruction path take care of destroying the worklist.
- // When we know all host stream callbacks are done and not touching
- // anything.
- iree_hal_hip_completion_list_push(worklist, entry);
- status =
- iree_make_status(IREE_STATUS_ABORTED, "could not wait on hip event");
- break;
+ if (IREE_LIKELY(iree_status_is_ok(status))) {
+ IREE_TRACE_ZONE_BEGIN_NAMED(z1, "hipEventSynchronize");
+ status = IREE_HIP_RESULT_TO_STATUS(actions->symbols,
+ hipEventSynchronize(entry->event));
+ IREE_TRACE_ZONE_END(z1);
}
- status = entry->callback(entry->user_data);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- break;
- }
+
+ status =
+ iree_status_join(status, entry->callback(status, entry->user_data));
if (IREE_UNLIKELY(entry->created_event)) {
- IREE_HIP_IGNORE_ERROR(completion_area->symbols,
- hipEventDestroy(entry->event));
+ status = iree_status_join(
+ status, IREE_HIP_RESULT_TO_STATUS(actions->symbols,
+ hipEventDestroy(entry->event)));
}
- iree_allocator_free(completion_area->host_allocator, entry);
-
- // Now we fully executed and cleaned up this entry. Decrease the work
- // items counter.
- iree_hal_hip_queue_decrement_completion_count(completion_area);
+ iree_allocator_free(actions->host_allocator, entry);
}
- return status;
+
+ if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
+ iree_hal_hip_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
+ }
+
+ IREE_TRACE_ZONE_END(z0);
}

// The main function for the completion worker thread.
static int iree_hal_hip_completion_execute(
- iree_hal_hip_completion_area_t* completion_area) {
- iree_hal_hip_completion_list_t* worklist = &completion_area->completion_list;
+ iree_hal_hip_pending_queue_actions_t* actions) {
+ iree_hal_hip_completion_area_t* completion_area = &actions->completion_area;
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
- completion_area->symbols, hipSetDevice(completion_area->device),
- "hipSetDevice");
+ actions->symbols, hipSetDevice(actions->device), "hipSetDevice");
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_hip_completion_wait_pending_completion_items(completion_area);
- iree_hal_hip_post_error_to_completion_state(completion_area,
- iree_status_code(status));
- return -1;
+ iree_hal_hip_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
}
while (true) {
iree_notification_await(
&completion_area->state_notification,
- (iree_condition_fn_t)
- iree_hal_hip_completion_has_incoming_request_or_error,
+ (iree_condition_fn_t)iree_hal_hip_completion_has_incoming_request,
completion_area, iree_infinite_timeout());
- IREE_TRACE_ZONE_BEGIN(z0);
// Immediately flip the state to idle waiting if and only if the previous
// state is workload pending. We do it before processing ready list to make
// sure that we don't accidentally ignore new workload pushed after done
// ready list processing but before overwriting the state from this worker
- // thread. Also we don't want to overwrite other exit states. So we need to
- // perform atomic compare and exchange here.
- iree_hal_hip_worker_state_t prev_state =
- IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING;
- iree_atomic_compare_exchange_strong_int32(
- &completion_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
+ // thread.
+ iree_atomic_store_int32(&completion_area->worker_state,
+ IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
+ iree_memory_order_release);
+ iree_hal_hip_worker_process_completion(actions);
- int32_t worker_state = iree_atomic_load_int32(
- &completion_area->worker_state, iree_memory_order_acquire);
- // Exit if HIP callbacks have posted any errors.
- if (IREE_UNLIKELY(worker_state == IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR)) {
- iree_hal_hip_completion_wait_pending_completion_items(completion_area);
- IREE_TRACE_ZONE_END(z0);
- return -1;
- }
- // Check if we received request to stop processing and exit this thread.
- bool should_exit =
- (worker_state == IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED);
-
- iree_status_t status =
- iree_hal_hip_worker_process_completion(worklist, completion_area);
-
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_hip_completion_wait_pending_completion_items(completion_area);
- iree_hal_hip_post_error_to_completion_state(completion_area,
- iree_status_code(status));
- IREE_TRACE_ZONE_END(z0);
- return -1;
- }
-
- if (IREE_UNLIKELY(should_exit &&
- iree_hal_hip_completion_has_no_pending_completion_items(
- completion_area))) {
- iree_hal_hip_completion_wait_pending_completion_items(completion_area);
- // Signal that this thread is committed to exit.
- // This state has a priority that is only lower than error exit.
- // A HIP callback may have posted an error, make sure we don't
- // overwrite this error state.
- iree_hal_hip_worker_state_t prev_state =
- IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED;
- iree_atomic_compare_exchange_strong_int32(
- &completion_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_HIP_WORKER_STATE_EXIT_COMMITTED,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(&completion_area->exit_notification,
- IREE_ALL_WAITERS);
- IREE_TRACE_ZONE_END(z0);
+ iree_slim_mutex_lock(&actions->action_mutex);
+ if (IREE_UNLIKELY(actions->exit_requested &&
+ !actions->pending_work_items_count)) {
+ iree_slim_mutex_unlock(&actions->action_mutex);
return 0;
}
- IREE_TRACE_ZONE_END(z0);
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
return 0;
@@ -1538,20 +1442,19 @@
// The main function for the ready-list processing worker thread.
static int iree_hal_hip_worker_execute(
- iree_hal_hip_working_area_t* working_area) {
- iree_hal_hip_entry_list_t* worklist = &working_area->ready_worklist;
+ iree_hal_hip_pending_queue_actions_t* actions) {
+ iree_hal_hip_working_area_t* working_area = &actions->working_area;
// Hip stores thread-local data based on the device. Some hip commands pull
// the device from there, and it defaults to device 0 (e.g. hipEventCreate),
// this will cause failures when using it with other devices (or streams from
// other devices). Force the correct device onto this thread.
iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
- working_area->symbols, hipSetDevice(working_area->device),
- "hipSetDevice");
+ actions->symbols, hipSetDevice(actions->device), "hipSetDevice");
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_hip_worker_wait_pending_work_items(working_area);
- iree_hal_hip_post_error_to_worker_state(working_area,
- iree_status_code(status));
+ iree_hal_hip_pending_queue_actions_fail(actions, status);
+ iree_status_ignore(status);
+ // We can safely exit here because there are no actions in flight yet.
return -1;
}
@@ -1565,63 +1468,29 @@
// host stream callbacks.
iree_notification_await(
&working_area->state_notification,
- (iree_condition_fn_t)iree_hal_hip_worker_has_incoming_request_or_error,
+ (iree_condition_fn_t)iree_hal_hip_worker_has_incoming_request,
working_area, iree_infinite_timeout());
// Immediately flip the state to idle waiting if and only if the previous
// state is workload pending. We do it before processing ready list to make
// sure that we don't accidentally ignore new workload pushed after done
// ready list processing but before overwriting the state from this worker
- // thread. Also we don't want to overwrite other exit states. So we need to
- // perform atomic compare and exchange here.
- iree_hal_hip_worker_state_t prev_state =
- IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING;
- iree_atomic_compare_exchange_strong_int32(
- &working_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
+ // thread.
+ iree_atomic_store_int32(&working_area->worker_state,
+ IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
+ iree_memory_order_release);
- int32_t worker_state = iree_atomic_load_int32(&working_area->worker_state,
- iree_memory_order_acquire);
- // Exit if HIP callbacks have posted any errors.
- if (IREE_UNLIKELY(worker_state == IREE_HAL_HIP_WORKER_STATE_EXIT_ERROR)) {
- iree_hal_hip_worker_wait_pending_work_items(working_area);
- return -1;
- }
- // Check if we received request to stop processing and exit this thread.
- bool should_exit =
- (worker_state == IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED);
+ iree_hal_hip_worker_process_ready_list(actions);
- // Process the ready list. We also want this even requested to exit.
- iree_status_t status = iree_hal_hip_worker_process_ready_list(
- working_area->host_allocator, worklist);
- if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
- iree_hal_hip_worker_wait_pending_work_items(working_area);
- iree_hal_hip_post_error_to_worker_state(working_area,
- iree_status_code(status));
- return -1;
- }
-
- if (IREE_UNLIKELY(
- should_exit &&
- iree_hal_hip_worker_has_no_pending_work_items(working_area))) {
- iree_hal_hip_worker_wait_pending_work_items(working_area);
- // Signal that this thread is committed to exit.
- // This state has a priority that is only lower than error exit.
- // A HIP callback may have posted an error, make sure we don't
- // overwrite this error state.
- iree_hal_hip_worker_state_t prev_state =
- IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED;
- iree_atomic_compare_exchange_strong_int32(
- &working_area->worker_state, /*expected=*/&prev_state,
- /*desired=*/IREE_HAL_HIP_WORKER_STATE_EXIT_COMMITTED,
- /*order_succ=*/iree_memory_order_acq_rel,
- /*order_fail=*/iree_memory_order_acquire);
- iree_notification_post(&working_area->exit_notification,
- IREE_ALL_WAITERS);
+ iree_slim_mutex_lock(&actions->action_mutex);
+ if (IREE_UNLIKELY(actions->exit_requested &&
+ !actions->pending_work_items_count)) {
+ iree_slim_mutex_unlock(&actions->action_mutex);
+ iree_hal_hip_pending_queue_actions_notify_completion_thread(
+ &actions->completion_area);
return 0;
}
+ iree_slim_mutex_unlock(&actions->action_mutex);
}
return 0;
}
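The worker and completion loops above funnel every failure into iree_hal_hip_pending_queue_actions_fail(), whose body lies outside these hunks. Below is a rough sketch of the "first error wins" contract the call sites imply; it is an assumption, not the actual implementation. The call sites keep ownership and iree_status_ignore() the status afterwards, so the helper is assumed to clone it, and the latch-then-wake shape plus anything beyond the fields visible in the diff (action_mutex, status, worker_state, state_notification, ready_worklist) is likewise assumed.

static void iree_hal_hip_pending_queue_actions_fail(
    iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) {
  iree_slim_mutex_lock(&actions->action_mutex);
  if (iree_status_is_ok(actions->status)) {
    // Latch the first failure; it stays sticky on the context so any action
    // enqueued or issued later observes it.
    actions->status = iree_status_clone(status);
  }
  // Actions already sitting in the ready list can no longer be issued; fail
  // them so the semaphores in their signal lists see the error (assumed to
  // clone the status per action).
  iree_hal_hip_ready_action_list_fail_locked(
      &actions->working_area.ready_worklist, actions->status);
  iree_slim_mutex_unlock(&actions->action_mutex);
  // Wake both threads the same way a new workload does: mark a pending
  // workload and post the state notification. They re-check the sticky
  // status, drain the remaining work items, and only exit once an exit was
  // requested and pending_work_items_count reaches zero.
  iree_atomic_store_int32(&actions->working_area.worker_state,
                          IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
                          iree_memory_order_release);
  iree_notification_post(&actions->working_area.state_notification,
                         IREE_ALL_WAITERS);
  iree_hal_hip_pending_queue_actions_notify_completion_thread(
      &actions->completion_area);
}
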
diff --git a/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt
index dcdd5f2..759077a 100644
--- a/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/local_sync/cts/CMakeLists.txt
@@ -8,6 +8,16 @@
# set(EXECUTABLE_FORMAT_PREFIX "system")
set(NATIVE_EXECUTABLE_FORMAT "\"${EXECUTABLE_FORMAT_PREFIX}-elf-\" IREE_ARCH")
+unset(FILTER_TESTS)
+string(APPEND FILTER_TESTS "SemaphoreTest.WaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.FailThenWait:")
+string(APPEND FILTER_TESTS "SemaphoreTest.MultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.DeviceMultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreSubmissionTest.PropagateFailSignal:")
+set(FILTER_TESTS_ARGS
+ "--gtest_filter=-${FILTER_TESTS}"
+)
+
if(IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
iree_hal_cts_test_suite(
DRIVER_NAME
@@ -22,6 +32,8 @@
"llvm-cpu"
EXECUTABLE_FORMAT
"${NATIVE_EXECUTABLE_FORMAT}"
+ ARGS
+ ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::local_sync::registration
EXCLUDED_TESTS
@@ -45,6 +57,8 @@
"vmvx"
EXECUTABLE_FORMAT
"\"vmvx-bytecode-fb\""
+ ARGS
+ ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::local_sync::registration
EXCLUDED_TESTS
diff --git a/runtime/src/iree/hal/drivers/local_task/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/local_task/cts/CMakeLists.txt
index d569b2d..5658351 100644
--- a/runtime/src/iree/hal/drivers/local_task/cts/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/local_task/cts/CMakeLists.txt
@@ -8,6 +8,16 @@
# set(EXECUTABLE_FORMAT_PREFIX "system")
set(NATIVE_EXECUTABLE_FORMAT "\"${EXECUTABLE_FORMAT_PREFIX}-elf-\" IREE_ARCH")
+unset(FILTER_TESTS)
+string(APPEND FILTER_TESTS "SemaphoreTest.WaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.FailThenWait:")
+string(APPEND FILTER_TESTS "SemaphoreTest.MultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.DeviceMultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreSubmissionTest.PropagateFailSignal:")
+set(FILTER_TESTS_ARGS
+ "--gtest_filter=-${FILTER_TESTS}"
+)
+
if(IREE_HAL_EXECUTABLE_LOADER_EMBEDDED_ELF)
iree_hal_cts_test_suite(
DRIVER_NAME
@@ -22,6 +32,8 @@
"llvm-cpu"
EXECUTABLE_FORMAT
"${NATIVE_EXECUTABLE_FORMAT}"
+ ARGS
+ ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::local_task::registration
LABELS
@@ -43,6 +55,8 @@
"vmvx"
EXECUTABLE_FORMAT
"\"vmvx-bytecode-fb\""
+ ARGS
+ ${FILTER_TESTS_ARGS}
DEPS
iree::hal::drivers::local_task::registration
LABELS
diff --git a/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt b/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt
index c2135a1..d6d5001 100644
--- a/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/vulkan/cts/CMakeLists.txt
@@ -4,12 +4,22 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-unset(ADDITIONAL_TEST_ARGS)
+# TODO: investigate why these tests fail on Vulkan.
+unset(FILTER_TESTS)
+string(APPEND FILTER_TESTS "SemaphoreTest.WaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.FailThenWait:")
+string(APPEND FILTER_TESTS "SemaphoreTest.MultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreTest.DeviceMultiWaitThenFail:")
+string(APPEND FILTER_TESTS "SemaphoreSubmissionTest.PropagateFailSignal:")
+
if(ANDROID)
# Disable this test on Android due to flaky failures on Google Pixel 6 Pro
# and Moto Edge X30.
- set(ADDITIONAL_TEST_ARGS "--gtest_filter=-*WaitForFiniteTime*")
+ string(APPEND FILTER_TESTS "SemaphoreTest.WaitForFiniteTime:")
endif()
+set(ADDITIONAL_TEST_ARGS
+ "--gtest_filter=-${FILTER_TESTS}"
+)
iree_hal_cts_test_suite(
DRIVER_NAME
diff --git a/runtime/src/iree/hal/semaphore.h b/runtime/src/iree/hal/semaphore.h
index ec0b9af..0d874cc 100644
--- a/runtime/src/iree/hal/semaphore.h
+++ b/runtime/src/iree/hal/semaphore.h
@@ -164,6 +164,7 @@
// Signals each semaphore in |semaphore_list| to indicate failure with
// |signal_status|.
+// Takes ownership of |signal_status|.
IREE_API_EXPORT void iree_hal_semaphore_list_fail(
iree_hal_semaphore_list_t semaphore_list, iree_status_t signal_status);
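Because iree_hal_semaphore_list_fail() now documents that it takes ownership of |signal_status|, a caller that also needs to surface the same error must pass a clone. A minimal caller-side sketch; only the iree_hal_*/iree_status_* calls are real APIs, the wrapper name is illustrative:

#include "iree/hal/semaphore.h"

// Fails every semaphore in |signal_list| with |status| while still returning
// the error to the caller. iree_hal_semaphore_list_fail() consumes the status
// it is handed, so a clone is passed in and the original is returned.
static iree_status_t example_fail_signal_list(
    iree_hal_semaphore_list_t signal_list, iree_status_t status) {
  iree_hal_semaphore_list_fail(signal_list, iree_status_clone(status));
  return status;
}

If the caller has no further use for the error, it can hand the original status over directly and skip the clone.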