Propagating errors through fences.
diff --git a/iree/task/executor.c b/iree/task/executor.c
index 13d6b27..267b681 100644
--- a/iree/task/executor.c
+++ b/iree/task/executor.c
@@ -270,6 +270,17 @@
IREE_TRACE_ZONE_BEGIN(z0);
iree_task_t* task = NULL;
while ((task = iree_task_list_pop_front(&pending_submission->ready_list))) {
+ // If the scope has been marked as failing then we abort the task.
+ // This needs to happen as a poll here because one or more of the tasks we
+ // are joining may have failed.
+ if (IREE_UNLIKELY(iree_task_scope_has_failed(task->scope))) {
+ iree_task_list_t discard_worklist;
+ iree_task_list_initialize(&discard_worklist);
+ iree_task_discard(task, &discard_worklist);
+ iree_task_list_discard(&discard_worklist);
+ continue;
+ }
+
switch (task->type) {
case IREE_TASK_TYPE_NOP:
// Doesn't do anything; just retire and continue on to any dependents.
diff --git a/iree/task/task.c b/iree/task/task.c
index 02cb20d..537fc77 100644
--- a/iree/task/task.c
+++ b/iree/task/task.c
@@ -326,33 +326,18 @@
iree_task_submission_t* pending_submission) {
IREE_TRACE_ZONE_BEGIN(z0);
- // If the scope has been marked as failing then we abort the barrier.
- // This needs to happen as a poll here because one or more of the tasks we
- // are joining may have failed.
- const bool has_failed = iree_task_scope_has_failed(task->header.scope);
- if (has_failed) {
- // This was the last pending dependency and we know that we can safely
- // abort the completion task by discarding.
- iree_task_list_t discard_worklist;
- iree_task_list_initialize(&discard_worklist);
- iree_task_barrier_discard(task, &discard_worklist);
- iree_task_list_discard(&discard_worklist);
- } else {
- // NOTE: we walk in reverse so that we enqueue in LIFO order.
- for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) {
- iree_task_t* dependent_task =
- task->dependent_tasks[task->dependent_task_count - i - 1];
- if (iree_atomic_fetch_sub_int32(&dependent_task->pending_dependency_count,
- 1, iree_memory_order_acq_rel) == 1) {
- // The dependent task has retired and can now be made ready.
- iree_task_submission_enqueue(pending_submission, dependent_task);
- }
+ // NOTE: we walk in reverse so that we enqueue in LIFO order.
+ for (iree_host_size_t i = 0; i < task->dependent_task_count; ++i) {
+ iree_task_t* dependent_task =
+ task->dependent_tasks[task->dependent_task_count - i - 1];
+ if (iree_atomic_fetch_sub_int32(&dependent_task->pending_dependency_count,
+ 1, iree_memory_order_acq_rel) == 1) {
+ // The dependent task has retired and can now be made ready.
+ iree_task_submission_enqueue(pending_submission, dependent_task);
}
}
- iree_task_retire(&task->header, pending_submission,
- has_failed ? iree_status_from_code(IREE_STATUS_ABORTED)
- : iree_ok_status());
+ iree_task_retire(&task->header, pending_submission, iree_ok_status());
IREE_TRACE_ZONE_END(z0);
}
diff --git a/iree/task/task_test_fence.cc b/iree/task/task_test_fence.cc
index 95ddea4..6626378 100644
--- a/iree/task/task_test_fence.cc
+++ b/iree/task/task_test_fence.cc
@@ -11,8 +11,13 @@
namespace {
+using iree::Status;
+using iree::StatusCode;
+using iree::testing::status::StatusIs;
+
class TaskFenceTest : public TaskTest {};
+// Tests a chain of fences A -> B -> C.
TEST_F(TaskFenceTest, IssueChained) {
iree_task_fence_t task_a;
iree_task_fence_initialize(&scope_, &task_a);
@@ -28,4 +33,46 @@
IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task_a.header, &task_c.header));
}
+// Tests that failures propagate through fences; task B should not be called.
+// A fails -> fence -> B
+TEST_F(TaskFenceTest, IssueChainedFailure) {
+ int did_call_a = 0;
+ iree_task_call_t task_a;
+ iree_task_call_initialize(&scope_,
+ iree_task_make_call_closure(
+ [](void* user_context, iree_task_t* task,
+ iree_task_submission_t* pending_submission) {
+ int* did_call_ptr = (int*)user_context;
+ ++(*did_call_ptr);
+ return iree_make_status(IREE_STATUS_DATA_LOSS,
+ "whoops!");
+ },
+ &did_call_a),
+ &task_a);
+
+ iree_task_fence_t fence_task;
+ iree_task_fence_initialize(&scope_, &fence_task);
+ iree_task_set_completion_task(&task_a.header, &fence_task.header);
+
+ int did_call_b = 0;
+ iree_task_call_t task_b;
+ iree_task_call_initialize(&scope_,
+ iree_task_make_call_closure(
+ [](void* user_context, iree_task_t* task,
+ iree_task_submission_t* pending_submission) {
+ int* did_call_ptr = (int*)user_context;
+ ++(*did_call_ptr);
+ return iree_ok_status();
+ },
+ &did_call_b),
+ &task_b);
+ iree_task_set_completion_task(&fence_task.header, &task_b.header);
+
+ IREE_ASSERT_OK(SubmitTasksAndWaitIdle(&task_a.header, &task_b.header));
+ EXPECT_EQ(1, did_call_a);
+ EXPECT_EQ(0, did_call_b);
+ EXPECT_THAT(Status(iree_task_scope_consume_status(&scope_)),
+ StatusIs(StatusCode::kDataLoss));
+}
+
} // namespace