[runtime] Add semaphore test where 2 batches wait on a former batch amongst 2 (#17080)
With this this test I added some functions to make the test body
shorter.
Unfortunately, we can't put gtest asserts in functions that have a
non-void return. Because of that in the functions I am using
`IREE_EXPECT*` instead of `IREE_ASSERT*`. This may cause tests other
than the one that is failing to fail also, as the whole executable may
crash if for example a command buffer fails to be created and then we
get a segmentation fault when trying to use it.
I thought this is an acceptable sacrifice to avoid repetition and to
have shorter and more readable tests.
diff --git a/runtime/src/iree/hal/cts/cts_test_base.h b/runtime/src/iree/hal/cts/cts_test_base.h
index ce58832..cb6eb01 100644
--- a/runtime/src/iree/hal/cts/cts_test_base.h
+++ b/runtime/src/iree/hal/cts/cts_test_base.h
@@ -125,6 +125,30 @@
return status;
}
+ iree_hal_command_buffer_t* CreateEmptyCommandBuffer() {
+ iree_hal_command_buffer_t* command_buffer = NULL;
+ IREE_EXPECT_OK(iree_hal_command_buffer_create(
+ device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
+ IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*binding_capacity=*/0, &command_buffer));
+ IREE_EXPECT_OK(iree_hal_command_buffer_begin(command_buffer));
+ IREE_EXPECT_OK(iree_hal_command_buffer_end(command_buffer));
+ return command_buffer;
+ }
+
+ iree_hal_semaphore_t* CreateSemaphore() {
+ iree_hal_semaphore_t* semaphore = NULL;
+ IREE_EXPECT_OK(iree_hal_semaphore_create(device_, 0, &semaphore));
+ return semaphore;
+ }
+
+ void CheckSemaphoreValue(iree_hal_semaphore_t* semaphore,
+ uint64_t expected_value) {
+ uint64_t value;
+ IREE_EXPECT_OK(iree_hal_semaphore_query(semaphore, &value));
+ EXPECT_EQ(expected_value, value);
+ }
+
iree_hal_driver_t* driver_ = NULL;
iree_hal_device_t* device_ = NULL;
iree_hal_allocator_t* device_allocator_ = NULL;
diff --git a/runtime/src/iree/hal/cts/semaphore_submission_test.h b/runtime/src/iree/hal/cts/semaphore_submission_test.h
index ed7f9d8..5b42971 100644
--- a/runtime/src/iree/hal/cts/semaphore_submission_test.h
+++ b/runtime/src/iree/hal/cts/semaphore_submission_test.h
@@ -544,9 +544,84 @@
iree_hal_semaphore_release(semaphore2);
}
-// TODO: test device -> device synchronization: submit multiple batches with
+// Test device -> device synchronization: submit multiple batches with
// multiple later batches waiting on the same signaling from a former batch.
-//
+TEST_P(semaphore_submission_test, TwoBatchesWaitingOn1FormerBatchAmongst2) {
+ // The signaling-wait relation is:
+ // command_buffer11 command_buffer12
+ // ↓
+ // semaphore11
+ // ↙ ↘
+ // command_buffer21 command_buffer22
+ // ↓ ↓
+ // semaphore21 semaphore22
+
+ iree_hal_command_buffer_t* command_buffer11 = CreateEmptyCommandBuffer();
+ iree_hal_command_buffer_t* command_buffer12 = CreateEmptyCommandBuffer();
+ iree_hal_command_buffer_t* command_buffer21 = CreateEmptyCommandBuffer();
+ iree_hal_command_buffer_t* command_buffer22 = CreateEmptyCommandBuffer();
+ iree_hal_semaphore_t* semaphore11 = CreateSemaphore();
+ iree_hal_semaphore_t* semaphore21 = CreateSemaphore();
+ iree_hal_semaphore_t* semaphore22 = CreateSemaphore();
+
+ // All semaphores start from value 0 and reach 1.
+ uint64_t semaphore_signal_wait_value = 1;
+ iree_hal_semaphore_list_t semaphore11_list = {/*count=*/1, &semaphore11,
+ &semaphore_signal_wait_value};
+ iree_hal_semaphore_list_t semaphore21_list = {/*count=*/1, &semaphore21,
+ &semaphore_signal_wait_value};
+ iree_hal_semaphore_list_t semaphore22_list = {/*count=*/1, &semaphore22,
+ &semaphore_signal_wait_value};
+ iree_hal_semaphore_list_t empty_semaphore_list{/*count=*/0, nullptr, nullptr};
+
+ // We submit the command buffers in reverse order.
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/semaphore11_list,
+ /*signal_semaphore_list=*/semaphore22_list, 1, &command_buffer22));
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/semaphore11_list,
+ /*signal_semaphore_list=*/semaphore21_list, 1, &command_buffer21));
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/empty_semaphore_list,
+ /*signal_semaphore_list=*/empty_semaphore_list, 1, &command_buffer12));
+
+ // Assert that semaphores have not advance since we have not yet submitted
+ // command_buffer11.
+ CheckSemaphoreValue(semaphore11, 0);
+ CheckSemaphoreValue(semaphore21, 0);
+ CheckSemaphoreValue(semaphore22, 0);
+
+ // Submit command_buffer11.
+ IREE_ASSERT_OK(iree_hal_device_queue_execute(
+ device_, IREE_HAL_QUEUE_AFFINITY_ANY,
+ /*wait_semaphore_list=*/empty_semaphore_list,
+ /*signal_semaphore_list=*/semaphore11_list, 1, &command_buffer11));
+
+ // Wait and check that semaphore values have advanced.
+ IREE_ASSERT_OK(
+ iree_hal_semaphore_wait(semaphore21, semaphore_signal_wait_value,
+ iree_make_deadline(IREE_TIME_INFINITE_FUTURE)));
+ CheckSemaphoreValue(semaphore21, semaphore_signal_wait_value);
+ // semaphore11 must have also advanced because semaphore21 has advanced.
+ CheckSemaphoreValue(semaphore11, semaphore_signal_wait_value);
+
+ IREE_ASSERT_OK(
+ iree_hal_semaphore_wait(semaphore22, semaphore_signal_wait_value,
+ iree_make_deadline(IREE_TIME_INFINITE_FUTURE)));
+ CheckSemaphoreValue(semaphore22, semaphore_signal_wait_value);
+
+ iree_hal_semaphore_release(semaphore11);
+ iree_hal_semaphore_release(semaphore21);
+ iree_hal_semaphore_release(semaphore22);
+ iree_hal_command_buffer_release(command_buffer11);
+ iree_hal_command_buffer_release(command_buffer12);
+ iree_hal_command_buffer_release(command_buffer21);
+ iree_hal_command_buffer_release(command_buffer22);
+}
+
// TODO: test device -> device synchronization: submit multiple batches with
// a former batch signaling a value greater than all other batches' (different)
// wait values.