| // Copyright 2019 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #ifndef IREE_HAL_CTS_CTS_TEST_BASE_H_ |
| #define IREE_HAL_CTS_CTS_TEST_BASE_H_ |
| |
| #include <set> |
| #include <string> |
| #include <string_view> |
| |
| #include "iree/base/api.h" |
| #include "iree/base/string_view.h" |
| #include "iree/hal/api.h" |
| #include "iree/testing/gtest.h" |
| #include "iree/testing/status_matchers.h" |
| |
| namespace iree { |
| namespace hal { |
| namespace cts { |
| |
// Returns the name of the driver under test.
// Leaf test binaries must implement this function.
// CTSTestBase uses this name both to create the driver and in skip messages.
const char* get_test_driver_name();

// Registers the driver referenced by get_test_driver_name with |registry|.
// Leaf test binaries must implement this function.
// CTSTestBase::SetUpTestSuite calls this for every test suite and ignores an
// IREE_STATUS_ALREADY_EXISTS result, so repeated registration is tolerated.
iree_status_t register_test_driver(iree_hal_driver_registry_t* registry);

// Returns the executable format for the driver under test.
// Leaf test binaries must implement this function.
// NOTE(review): presumably a format identifier string accepted by the driver's
// executable loader — confirm against the leaf implementations.
const char* get_test_executable_format();

// Returns a file's executable data for the driver under test.
// Leaf test binaries must implement this function.
// NOTE(review): ownership/lifetime of the returned span is not visible here;
// presumably it references statically embedded data — confirm.
iree_const_byte_span_t get_test_executable_data(iree_string_view_t file_name);
| |
// Command buffer recording mode exercised by a parameterized CTS test.
// GenerateTestName renders kDirect/kIndirect as "direct"/"indirect".
enum class RecordingType : int {
  kDirect = 0,
  kIndirect = 1,
};
| |
| struct GenerateTestName { |
| template <typename ParamType> |
| std::string operator()( |
| const ::testing::TestParamInfo<ParamType>& info) const { |
| switch (info.param) { |
| default: |
| return ""; |
| case RecordingType::kDirect: |
| return "direct"; |
| case RecordingType::kIndirect: |
| return "indirect"; |
| } |
| } |
| }; |
| |
| // Gets a HAL driver with the provided name, if available. |
| static iree_status_t TryGetDriver(const std::string& driver_name, |
| iree_hal_driver_t** out_driver) { |
| static std::set<std::string> unavailable_driver_names; |
| |
| // If creation failed before, don't try again. |
| if (unavailable_driver_names.find(driver_name) != |
| unavailable_driver_names.end()) { |
| return iree_make_status(IREE_STATUS_UNAVAILABLE, "driver unavailable"); |
| } |
| |
| // No existing driver, attempt to create. |
| iree_hal_driver_t* driver = NULL; |
| iree_status_t status = iree_hal_driver_registry_try_create( |
| iree_hal_driver_registry_default(), |
| iree_make_string_view(driver_name.data(), driver_name.size()), |
| iree_allocator_system(), &driver); |
| if (iree_status_is_unavailable(status)) { |
| unavailable_driver_names.insert(driver_name); |
| } |
| if (iree_status_is_ok(status)) { |
| *out_driver = driver; |
| } |
| return status; |
| } |
| |
| // Statics available in CTSTestBase without template magic. |
| // Note that this header is intended to be included in a single .cc so we can |
| // define the static member storage here. |
| class CTSTestResources { |
| public: |
| static iree_hal_driver_t* driver_; |
| static iree_hal_device_t* device_; |
| static iree_hal_allocator_t* device_allocator_; |
| }; |
| /*static*/ iree_hal_driver_t* CTSTestResources::driver_ = NULL; |
| /*static*/ iree_hal_device_t* CTSTestResources::device_ = NULL; |
| /*static*/ iree_hal_allocator_t* CTSTestResources::device_allocator_ = NULL; |
| |
| // Common setup for tests parameterized on driver names. |
| template <typename BaseType = ::testing::Test> |
| class CTSTestBase : public BaseType, public CTSTestResources { |
| public: |
| static void SetUpTestSuite() { |
| iree_status_t status = |
| register_test_driver(iree_hal_driver_registry_default()); |
| if (iree_status_is_already_exists(status)) { |
| status = iree_status_ignore(status); |
| } |
| IREE_CHECK_OK(status); |
| |
| // Get driver with the given name and create its default device. |
| // Skip drivers that are (gracefully) unavailable, fail if creation fails. |
| iree_hal_driver_t* driver = NULL; |
| status = TryGetDriver(get_test_driver_name(), &driver); |
| if (iree_status_is_unavailable(status)) { |
| iree_status_ignore(status); |
| return; |
| } |
| IREE_CHECK_OK(status); |
| driver_ = driver; |
| |
| iree_hal_device_t* device = NULL; |
| status = iree_hal_driver_create_default_device( |
| driver_, iree_allocator_system(), &device); |
| if (iree_status_is_unavailable(status)) { |
| iree_status_ignore(status); |
| return; |
| } |
| IREE_CHECK_OK(status); |
| device_ = device; |
| |
| device_allocator_ = iree_hal_device_allocator(device_); |
| iree_hal_allocator_retain(device_allocator_); |
| } |
| |
| static void TearDownTestSuite() { |
| if (device_allocator_) { |
| iree_hal_allocator_release(device_allocator_); |
| device_allocator_ = NULL; |
| } |
| if (device_) { |
| iree_hal_device_release(device_); |
| device_ = NULL; |
| } |
| if (driver_) { |
| iree_hal_driver_release(driver_); |
| driver_ = NULL; |
| } |
| } |
| |
| virtual void SetUp() { |
| if (!driver_) { |
| GTEST_SKIP() << "Skipping test as '" << get_test_driver_name() |
| << "' driver is unavailable"; |
| return; |
| } |
| if (!device_) { |
| GTEST_SKIP() << "Skipping test as default device for '" |
| << get_test_driver_name() << "' driver is unavailable"; |
| return; |
| } |
| } |
| |
| virtual void TearDown() {} |
| |
| void CreateUninitializedDeviceBuffer(iree_device_size_t buffer_size, |
| iree_hal_buffer_t** out_buffer) { |
| iree_hal_buffer_params_t params = {0}; |
| params.type = |
| IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; |
| params.usage = |
| IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | IREE_HAL_BUFFER_USAGE_TRANSFER; |
| iree_hal_buffer_t* device_buffer = NULL; |
| IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer( |
| iree_hal_device_allocator(device_), params, buffer_size, |
| &device_buffer)); |
| *out_buffer = device_buffer; |
| } |
| |
| void CreateZeroedDeviceBuffer(iree_device_size_t buffer_size, |
| iree_hal_buffer_t** out_buffer) { |
| iree_hal_buffer_params_t params = {0}; |
| params.type = |
| IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; |
| params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | |
| IREE_HAL_BUFFER_USAGE_TRANSFER | |
| IREE_HAL_BUFFER_USAGE_MAPPING; |
| iree_hal_buffer_t* device_buffer = NULL; |
| IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer( |
| iree_hal_device_allocator(device_), params, buffer_size, |
| &device_buffer)); |
| IREE_ASSERT_OK( |
| iree_hal_buffer_map_zero(device_buffer, 0, IREE_HAL_WHOLE_BUFFER)); |
| *out_buffer = device_buffer; |
| } |
| |
| template <typename PatternType> |
| void CreateFilledDeviceBuffer(iree_device_size_t buffer_size, |
| PatternType pattern, |
| iree_hal_buffer_t** out_buffer) { |
| iree_hal_buffer_params_t params = {0}; |
| params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE | |
| IREE_HAL_MEMORY_TYPE_HOST_VISIBLE; |
| params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | |
| IREE_HAL_BUFFER_USAGE_TRANSFER | |
| IREE_HAL_BUFFER_USAGE_MAPPING; |
| iree_hal_buffer_t* device_buffer = NULL; |
| IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer( |
| iree_hal_device_allocator(device_), params, buffer_size, |
| &device_buffer)); |
| IREE_ASSERT_OK(iree_hal_buffer_map_fill( |
| device_buffer, 0, IREE_HAL_WHOLE_BUFFER, &pattern, sizeof(pattern))); |
| *out_buffer = device_buffer; |
| } |
| |
| // Submits |command_buffer| to the device and waits for it to complete before |
| // returning. |
| iree_status_t SubmitCommandBufferAndWait( |
| iree_hal_command_buffer_t* command_buffer, |
| iree_hal_buffer_binding_table_t binding_table = |
| iree_hal_buffer_binding_table_empty()) { |
| // No wait semaphores. |
| iree_hal_semaphore_list_t wait_semaphores = iree_hal_semaphore_list_empty(); |
| |
| // One signal semaphore from 0 -> 1. |
| iree_hal_semaphore_t* signal_semaphore = NULL; |
| IREE_RETURN_IF_ERROR(iree_hal_semaphore_create( |
| device_, 0ull, IREE_HAL_SEMAPHORE_FLAG_NONE, &signal_semaphore)); |
| uint64_t target_payload_value = 1ull; |
| iree_hal_semaphore_list_t signal_semaphores = { |
| /*count=*/1, |
| /*semaphores=*/&signal_semaphore, |
| /*payload_values=*/&target_payload_value, |
| }; |
| |
| iree_status_t status = iree_hal_device_queue_execute( |
| device_, IREE_HAL_QUEUE_AFFINITY_ANY, wait_semaphores, |
| signal_semaphores, command_buffer, binding_table); |
| if (iree_status_is_ok(status)) { |
| status = iree_hal_semaphore_wait(signal_semaphore, target_payload_value, |
| iree_infinite_timeout()); |
| } |
| |
| iree_hal_semaphore_release(signal_semaphore); |
| return status; |
| } |
| |
| iree_hal_command_buffer_t* CreateEmptyCommandBuffer( |
| iree_host_size_t binding_capacity = 0) { |
| iree_hal_command_buffer_t* command_buffer = NULL; |
| IREE_EXPECT_OK(iree_hal_command_buffer_create( |
| device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT, |
| IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY, |
| binding_capacity, &command_buffer)); |
| IREE_EXPECT_OK(iree_hal_command_buffer_begin(command_buffer)); |
| IREE_EXPECT_OK(iree_hal_command_buffer_end(command_buffer)); |
| return command_buffer; |
| } |
| |
| iree_hal_semaphore_t* CreateSemaphore() { |
| iree_hal_semaphore_t* semaphore = NULL; |
| IREE_EXPECT_OK(iree_hal_semaphore_create( |
| device_, 0, IREE_HAL_SEMAPHORE_FLAG_NONE, &semaphore)); |
| return semaphore; |
| } |
| |
| void CheckSemaphoreValue(iree_hal_semaphore_t* semaphore, |
| uint64_t expected_value) { |
| uint64_t value; |
| IREE_EXPECT_OK(iree_hal_semaphore_query(semaphore, &value)); |
| EXPECT_EQ(expected_value, value); |
| } |
| |
| // Check that a contains b. |
| // That is the codes of a and b are equal and the message of b is contained |
| // in the message of a. |
| void CheckStatusContains(iree_status_t a, iree_status_t b) { |
| EXPECT_EQ(iree_status_code(a), iree_status_code(b)); |
| iree_allocator_t allocator = iree_allocator_system(); |
| char* a_str = NULL; |
| iree_host_size_t a_str_length = 0; |
| EXPECT_TRUE(iree_status_to_string(a, &allocator, &a_str, &a_str_length)); |
| char* b_str = NULL; |
| iree_host_size_t b_str_length = 0; |
| EXPECT_TRUE(iree_status_to_string(b, &allocator, &b_str, &b_str_length)); |
| EXPECT_TRUE(std::string_view(a_str).find(std::string_view(b_str)) != |
| std::string_view::npos); |
| iree_allocator_free(allocator, a_str); |
| iree_allocator_free(allocator, b_str); |
| } |
| }; |
| |
| } // namespace cts |
| } // namespace hal |
| } // namespace iree |
| |
| #endif // IREE_HAL_CTS_CTS_TEST_BASE_H_ |