Start on a HAL CTS for easier HAL development and testing.
PiperOrigin-RevId: 277134730
diff --git a/iree/hal/cts/BUILD b/iree/hal/cts/BUILD
new file mode 100644
index 0000000..118b2f6
--- /dev/null
+++ b/iree/hal/cts/BUILD
@@ -0,0 +1,54 @@
+# Conformance Test Suite (CTS) for HAL implementations.
+
+load("//iree:build_defs.bzl", "PLATFORM_VULKAN_TEST_DEPS")
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+cc_library(
+ name = "cts_test_base",
+ testonly = True,
+ hdrs = ["cts_test_base.h"],
+ data = [
+ # When building with --config=asan you must specify the following
+ # envvar when using Vulkan + a local Nvidia GPU:
+ # LSAN_OPTIONS=suppressions=third_party/iree/tools/sanitizer_suppressions.txt
+ "//iree/tools:sanitizer_suppressions.txt",
+ ],
+ deps = [
+ "//iree/base:status",
+ "//iree/base:status_matchers",
+ "//iree/hal:driver_registry",
+
+ # HAL driver modules.
+ "//iree/hal/interpreter:interpreter_driver_module", # build-cleaner: keep
+ "//iree/hal/vulkan:vulkan_driver_module", # build-cleaner: keep
+ # "//iree/hal/dawn:dawn_driver_module", # build-cleaner: keep
+ ] + PLATFORM_VULKAN_TEST_DEPS,
+)
+
+cc_test(
+ name = "allocator_test",
+ srcs = ["allocator_test.cc"],
+ deps = [
+ ":cts_test_base",
+ "//iree/base:status",
+ "//iree/base:status_matchers",
+ "//iree/hal:driver_registry",
+ "//iree/testing:gtest",
+ ],
+)
+
+cc_test(
+ name = "command_buffer_test",
+ srcs = ["command_buffer_test.cc"],
+ deps = [
+ ":cts_test_base",
+ "//iree/base:status",
+ "//iree/base:status_matchers",
+ "//iree/hal:driver_registry",
+ "//iree/testing:gtest",
+ ],
+)
diff --git a/iree/hal/cts/allocator_test.cc b/iree/hal/cts/allocator_test.cc
new file mode 100644
index 0000000..48bd2ad
--- /dev/null
+++ b/iree/hal/cts/allocator_test.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/status.h"
+#include "iree/base/status_matchers.h"
+#include "iree/hal/cts/cts_test_base.h"
+#include "iree/hal/driver_registry.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace cts {
+
+class AllocatorTest : public CtsTestBase {
+ protected:
+ virtual void SetUp() {
+ CtsTestBase::SetUp();
+
+ if (!device_) {
+ return;
+ }
+
+ allocator_ = device_->allocator();
+ }
+
+ Allocator* allocator_ = nullptr;
+};
+
+TEST_P(AllocatorTest, CanAllocate) {
+ EXPECT_TRUE(allocator_->CanAllocate(
+ MemoryType::kHostLocal | MemoryType::kDeviceVisible,
+ BufferUsage::kMapping, 1024));
+ EXPECT_TRUE(allocator_->CanAllocate(
+ MemoryType::kHostVisible | MemoryType::kDeviceLocal,
+ BufferUsage::kMapping, 1024));
+
+ // TODO(scotttodd): Minimum memory types and buffer usages necessary for use
+ // TODO(scotttodd): Test upper limits of memory size for allocations (1GB+)?
+}
+
+TEST_P(AllocatorTest, Allocate) {
+ MemoryType memory_type = MemoryType::kHostLocal | MemoryType::kDeviceVisible;
+ BufferUsage usage = BufferUsage::kMapping;
+ size_t allocation_size = 1024;
+
+ ASSERT_OK_AND_ASSIGN(
+ auto buffer, allocator_->Allocate(memory_type, usage, allocation_size));
+
+ EXPECT_EQ(allocator_, buffer->allocator());
+ // At a mimimum, the requested memory type should be respected.
+ // Additional bits may be optionally set depending on the allocator.
+ EXPECT_TRUE((buffer->memory_type() & memory_type) == memory_type);
+ EXPECT_TRUE((buffer->usage() & usage) == usage);
+ EXPECT_GE(buffer->allocation_size(), allocation_size); // Larger is okay.
+}
+
+INSTANTIATE_TEST_SUITE_P(AllDrivers, AllocatorTest,
+ ::testing::ValuesIn(DriverRegistry::shared_registry()
+ ->EnumerateAvailableDrivers()),
+ GenerateTestName());
+
+} // namespace cts
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/cts/command_buffer_test.cc b/iree/hal/cts/command_buffer_test.cc
new file mode 100644
index 0000000..018dc3c
--- /dev/null
+++ b/iree/hal/cts/command_buffer_test.cc
@@ -0,0 +1,49 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/base/status.h"
+#include "iree/base/status_matchers.h"
+#include "iree/hal/cts/cts_test_base.h"
+#include "iree/hal/driver_registry.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace cts {
+
+class CommandBufferTest : public CtsTestBase {};
+
+TEST_P(CommandBufferTest, CreateCommandBuffer) {
+ ASSERT_OK_AND_ASSIGN(auto command_buffer, device_->CreateCommandBuffer(
+ CommandBufferMode::kOneShot,
+ CommandCategory::kDispatch));
+
+ EXPECT_EQ(device_->allocator(), command_buffer->allocator());
+ EXPECT_TRUE((command_buffer->mode() & CommandBufferMode::kOneShot) ==
+ CommandBufferMode::kOneShot);
+ EXPECT_TRUE((command_buffer->command_categories() &
+ CommandCategory::kDispatch) == CommandCategory::kDispatch);
+ EXPECT_FALSE(command_buffer->is_recording());
+}
+
+// TODO(scotttodd): Begin, End, UpdateBuffer, CopyBuffer, Dispatch, Sync, etc.
+
+INSTANTIATE_TEST_SUITE_P(AllDrivers, CommandBufferTest,
+ ::testing::ValuesIn(DriverRegistry::shared_registry()
+ ->EnumerateAvailableDrivers()),
+ GenerateTestName());
+
+} // namespace cts
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/cts/cts_test_base.h b/iree/hal/cts/cts_test_base.h
new file mode 100644
index 0000000..f199314
--- /dev/null
+++ b/iree/hal/cts/cts_test_base.h
@@ -0,0 +1,65 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_CTS_CTS_TEST_BASE_H_
+#define IREE_HAL_CTS_CTS_TEST_BASE_H_
+
+#include "iree/base/status.h"
+#include "iree/base/status_matchers.h"
+#include "iree/hal/driver_registry.h"
+#include "iree/testing/gtest.h"
+
+namespace iree {
+namespace hal {
+namespace cts {
+
+// Common setup for tests parameterized across all registered drivers.
+class CtsTestBase : public ::testing::TestWithParam<std::string> {
+ protected:
+ virtual void SetUp() {
+ const std::string& driver_name = GetParam();
+
+ // Create driver with the given name and create its default device.
+ // Skip drivers that are (gracefully) unavailable, fail if creation fails.
+ LOG(INFO) << "Creating driver '" << driver_name << "'...";
+ auto driver_or = DriverRegistry::shared_registry()->Create(driver_name);
+ if (IsUnavailable(driver_or.status())) {
+ LOG(WARNING) << "Skipping test as driver is unavailable: "
+ << driver_or.status();
+ GTEST_SKIP();
+ return;
+ }
+ ASSERT_OK_AND_ASSIGN(driver_, driver_or);
+ LOG(INFO) << "Creating default device...";
+ ASSERT_OK_AND_ASSIGN(device_, driver_->CreateDefaultDevice());
+ LOG(INFO) << "Created device '" << device_->info().name() << "'";
+ }
+
+ std::shared_ptr<Driver> driver_;
+ std::shared_ptr<Device> device_;
+};
+
+struct GenerateTestName {
+ template <class ParamType>
+ std::string operator()(
+ const ::testing::TestParamInfo<ParamType>& info) const {
+ return info.param;
+ }
+};
+
+} // namespace cts
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_CTS_CTS_TEST_BASE_H_