[metal] Initial bring up of Metal HAL driver (1/n) (#3045)
This commit starts a Metal HAL driver. It registers the Metal HAL
driver and implements the following functionalities:
* All hal::Driver APIs
* hal::Device APIs for creating queues/buffers/semaphores
and waiting on semaphores
* All hal::CommandQueue APIs
* All hal::Semaphore APIs
* hal::CommandBuffer APIs for lifetime
All other APIs will return unimplemented error for now.
hal/cts/ tests are enhanced to cover hal::Driver APIs and
more semaphore wait APIs. Existing hal/cts/ tests are all passing,
except allocator_test, which is not yet implemented at the moment.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 76fa35d..a410308 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -87,6 +87,7 @@
set(IREE_ALL_HAL_DRIVERS
DyLib
LLVM
+ Metal
VMLA
Vulkan
)
@@ -98,6 +99,14 @@
# TODO(ataei): Enable dylib/dylib-llvm-aot for android.
list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD LLVM DyLib)
endif()
+
+ # For Apple platforms we need to use Metal instead of Vulkan.
+ if(APPLE)
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Vulkan)
+ else()
+ # And Metal isn't available on non-Apple platforms for sure.
+ list(REMOVE_ITEM IREE_HAL_DRIVERS_TO_BUILD Metal)
+ endif()
endif()
message(STATUS "Building HAL drivers: ${IREE_HAL_DRIVERS_TO_BUILD}")
@@ -121,6 +130,9 @@
if(${IREE_HAL_DRIVER_LLVM})
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::llvmjit::llvmjit_driver_module)
endif()
+if(${IREE_HAL_DRIVER_METAL})
+ list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::metal::metal_driver_module)
+endif()
if(${IREE_HAL_DRIVER_VMLA})
list(APPEND IREE_HAL_DRIVER_MODULES iree::hal::vmla::vmla_driver_module)
endif()
diff --git a/docs/design_docs/metal_hal_driver.md b/docs/design_docs/metal_hal_driver.md
new file mode 100644
index 0000000..9c2b6c4
--- /dev/null
+++ b/docs/design_docs/metal_hal_driver.md
@@ -0,0 +1,148 @@
+# Metal HAL Driver
+
+This document lists technical details regarding the Metal HAL driver. Note that
+the Metal HAL driver is working in progress; this document is expected to be
+updated along the way.
+
+IREE provides a [Hardware Abstraction Layer (HAL)][iree-hal] as a common
+interface to different compute accelerators. IREE HAL's design draws inspiration
+from modern GPU architecture and APIs; so implementing a HAL driver using modern
+GPU APIs is generally straightforward. This applies to the Metal HAL driver.
+
+## Overall Design Choices
+
+### Metal Versions
+
+The Metal HAL driver expects Metal 2+. Metal 2 introduces useful features like
+argument buffer, performance shaders, and others, that can improve performance
+and make IREE HAL implementation simpler. Metal 2 was released late 2017 and
+are supported since macOS High Sierra and iOS 11. It is already dominant
+([macOS][macos-version-share], [iOS][ios-version-share]) right now.
+
+### Programming Languages and Libraries
+
+The Metal HAL driver lives under the [`iree/hal/metal/`][iree-metal] directory.
+Header (`.h`) and implementation (`.mm`) files are put adjacent to each other.
+
+The Metal framework only exposes Objective-C or Swift programming language APIs.
+Metal HAL driver needs to inherit from common HAL abstraction classes, which are
+C++. So we use [Objective-C++][objcxx] for implementing the Metal HAL driver.
+The headers try to stay with pure C/C++ syntax when possible, except for
+`#import <Metal/Metal.h>` and using Metal `id` types.
+
+### Object Lifetime Management
+
+Objective-C uses refcount for tracking object lifetime and managing memory.
+This is traditionally done manually by sending `retain` and `release` messages
+to Objective-C objects. Modern Objective-C allows developers to opt in to use
+[Automatic Reference Counting][objc-arc] to let the compiler to automatically
+deduce and insert `retain`/`release` where possible to simplify the burdern
+of manual management.
+
+We don't use ARC in the Metal HAL driver given that IREE has its own object
+[refcount][iree-refptr] and lifetime management mechanism. Metal HAL GPU objects
+are tracked with that to be consistent with others. Each Metal HAL GPU object
+`retain`s the underlying Metal `id<MTL*>` object on construction and `release`s
+on destruction.
+
+## GPU Objects
+
+Metal is one of the main modern GPU APIs that provide more explicit control over
+the hardware. The mapping between IREE HAL classes and Metal protocols are
+relatively straightforward:
+
+IREE HAL Class | Metal Protocol
+:-------------:|:-------------:
+[`hal::Driver`][hal-driver] | N/A
+[`hal::Device`][hal-device] | [`MTLDevice`][mtl-device]
+[`hal::CommandQueue`][hal-command-queue] | [`MTLCommandQueue`][mtl-command-queue]
+[`hal::CommandBuffer`][hal-command-buffer] | [`MTLCommandBuffer`][mtl-command-buffer]
+[`hal::Semaphore`][hal-semaphore] | [`MTLSharedEvent`][mtl-shared-event]
+
+In the following subsections, we go over each pair to provide more details.
+
+### Driver
+
+There is no native driver abstraction in Metal. IREE's Metal HAL driver still
+provides a [`hal::metal::MetalDriver`][metal-driver] subclass inheriting from
+common [`hal::Driver`][hal-driver] class. `hal::metal::MetalDriver` just
+`retain`s all available Metal devices in the system during its lifetime to
+provide similar interface as other HAL drivers.
+
+### Device
+
+[`hal::metal::MetalDevice`][metal-device] inherits [`hal::Device`][hal-device]
+to provide the interface to Metal GPU device by wrapping a `id<MTLDevice>`.
+Upon construction, `hal::metal::MetalDevice` creates and retains one queue for
+both dispatch and transfer during its lifetime.
+
+Metal requres command buffers to be created from a `MTLCommandQueue`. In IREE
+HAL, command buffers are directly created from the `hal::Device`.
+`hal::metal::MetalDevice` chooses the proper queue to create the command
+buffer under the hood.
+
+### Command queue
+
+IREE HAL command queue follows Vulkan for modelling submission. Specifically,
+`hal::CommandQueue::Submit()` takes a `SubmissionBatch`, which contains a list
+of waiting `hal::Semaphore`s, a list of command buffers, and a list signaling
+`hal::Semaphore`s. There is no direct mapping in Metal; so
+[`hal::metal::MetalCommandQueue`][metal-command-queue] performs the submission
+in three steps:
+
+1. Create a new `MTLCommandBuffer` to `encodeWaitForEvent:value` for all
+ waiting `hal::Semaphore`s and commit this command buffer.
+1. Commit all command buffers in the `SubmissionBatch`.
+1. Create a new `MTLCommandBuffer` to `encodeSignalEvent:value` for all
+ signaling `hal::Semaphore`s and commit this command buffer.
+
+There is also no direct `WaitIdle()` for
+[`MTLCommandQueue`][mtl-command-queue]s. `hal::metal::MetalCommandQueue`
+implements `WaitIdle()` by committing an empty `MTLCommandBuffer` and
+registering a complete handler for it to signal a semaphore to wake the current
+thread, which is put into sleep by waiting on the semaphore.
+
+### Command buffer
+
+In Metal, commands are recorded into a command buffer with three different kinds
+of [command encoders][mtl-command-encoder]: `MTLRenderCommandEncoder`,
+`MTLComputeCommandEncoder`, `MTLBlitCommandEncoder`, and
+`MTLParallelRenderCommandEncoder`. Each encoder has its own create/end call.
+There is no overall begin/end call for the whold command buffer. So even
+[`hal::metal::MetalCommandBuffer`][metal-command-buffer] implements an overall
+`Begin()`/`End()` call, under the hood it may create a new command encoder
+for a specific API call.
+
+### Timeline semaphore
+
+[`hal::Semaphore`][hal-semaphore] allows host->device, device->host, host->host,
+and device->device synchronization. It maps to Vulkan timeline semaphore.
+In Metal world, the counterpart would be [`MTLSharedEvent`][mtl-shared-event].
+Most of the `hal::Semaphore` APIs are simple to implement in
+[`MetalSharedEvent`][metal-shared-event], with `Wait()` as an exception.
+A listener is registered on the `MTLSharedEvent` with
+`notifyListener:atValue:block:` to singal a semaphore to wake the current
+thread, which is put into sleep by waiting on the semaphore.
+
+[macos-version-share]: https://gs.statcounter.com/macos-version-market-share/desktop/worldwide
+[ios-version-share]: https://developer.apple.com/support/app-store/
+[iree-hal]: https://github.com/google/iree/tree/main/iree/hal
+[iree-metal]: https://github.com/google/iree/tree/main/iree/hal/metal
+[iree-refptr]: https://github.com/google/iree/blob/main/iree/base/ref_ptr.h
+[hal-driver]: https://github.com/google/iree/blob/main/iree/hal/driver.h
+[hal-device]: https://github.com/google/iree/blob/main/iree/hal/device.h
+[hal-command-queue]: https://github.com/google/iree/blob/main/iree/hal/command_queue.h
+[hal-command-buffer]: https://github.com/google/iree/blob/main/iree/hal/command_buffer.h
+[hal-semaphore]: https://github.com/google/iree/blob/main/iree/hal/semaphore.h
+[metal-driver]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_driver.h
+[metal-device]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_device.h
+[metal-command-queue]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_command_queue.h
+[metal-command-buffer]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_command_buffer.h
+[metal-shared-event]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_shared_event.h
+[mtl-device]: https://developer.apple.com/documentation/metal/mtldevice?language=objc
+[mtl-command-queue]: https://developer.apple.com/documentation/metal/mtlcommandqueue?language=objc
+[mtl-command-buffer]: https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc
+[mtl-command-encoder]: https://developer.apple.com/documentation/metal/mtlcommandencoder?language=objc
+[mtl-shared-event]: https://developer.apple.com/documentation/metal/mtlsharedevent?language=objc
+[objc-arc]: https://en.wikipedia.org/wiki/Automatic_Reference_Counting
+[objcxx]: https://en.wikipedia.org/wiki/Objective-C#Objective-C++
diff --git a/docs/get_started/getting_started_linux_vulkan.md b/docs/get_started/getting_started_linux_vulkan.md
index 5a0fd7f..8ccf6c1 100644
--- a/docs/get_started/getting_started_linux_vulkan.md
+++ b/docs/get_started/getting_started_linux_vulkan.md
@@ -58,16 +58,16 @@
HAL, which includes checking for supported layers and extensions.
Run the
-[device creation test](https://github.com/google/iree/blob/main/iree/hal/cts/device_creation_test.cc):
+[driver test](https://github.com/google/iree/blob/main/iree/hal/cts/driver_test.cc):
```shell
# -- CMake --
$ export VK_LOADER_DEBUG=all
-$ cmake --build build/ --target iree_hal_cts_device_creation_test
-$ ./build/iree/hal/cts/iree_hal_cts_device_creation_test
+$ cmake --build build/ --target iree_hal_cts_driver_test
+$ ./build/iree/hal/cts/iree_hal_cts_driver_test
# -- Bazel --
-$ bazel test iree/hal/cts:device_creation_test --test_env=VK_LOADER_DEBUG=all --test_output=all
+$ bazel test iree/hal/cts:driver_test --test_env=VK_LOADER_DEBUG=all --test_output=all
```
If these tests pass, you can skip down to the next section.
diff --git a/docs/get_started/getting_started_macos_bazel.md b/docs/get_started/getting_started_macos_bazel.md
index 3fc3dcb..1dce426 100644
--- a/docs/get_started/getting_started_macos_bazel.md
+++ b/docs/get_started/getting_started_macos_bazel.md
@@ -126,9 +126,11 @@
### Further Reading
* For an introduction to IREE's project structure and developer tools, see
- [Developer Overview](../developing_iree/developer_overview.md) <!-- TODO:
+ [Developer Overview](../developing_iree/developer_overview.md).
+* To understand how IREE implements HAL over Metal, see
+ [Metal HAL Driver](../design_docs/metal_hal_driver.md).
+<!-- TODO:
Link to macOS versions of these guides once they are developed.
-* To target GPUs using Vulkan, see
- [Getting Started on Linux with Vulkan](getting_started_linux_vulkan.md)
* To use IREE's Python bindings, see
- [Getting Started with Python](getting_started_python.md) -->
+ [Getting Started with Python](getting_started_python.md)
+-->
diff --git a/docs/get_started/getting_started_macos_cmake.md b/docs/get_started/getting_started_macos_cmake.md
index 7b916cd..b71284a 100644
--- a/docs/get_started/getting_started_macos_cmake.md
+++ b/docs/get_started/getting_started_macos_cmake.md
@@ -110,9 +110,11 @@
### Further Reading
* For an introduction to IREE's project structure and developer tools, see
- [Developer Overview](../developing_iree/developer_overview.md) <!-- TODO:
+ [Developer Overview](../developing_iree/developer_overview.md).
+* To understand how IREE implements HAL over Metal, see
+ [Metal HAL Driver](../design_docs/metal_hal_driver.md).
+<!-- TODO:
Link to macOS versions of these guides once they are developed.
-* To target GPUs using Vulkan, see
- [Getting Started on Linux with Vulkan](getting_started_linux_vulkan.md)
* To use IREE's Python bindings, see
- [Getting Started with Python](getting_started_python.md) -->
+ [Getting Started with Python](getting_started_python.md)
+-->
diff --git a/docs/get_started/getting_started_windows_vulkan.md b/docs/get_started/getting_started_windows_vulkan.md
index e263052..dcee4a1 100644
--- a/docs/get_started/getting_started_windows_vulkan.md
+++ b/docs/get_started/getting_started_windows_vulkan.md
@@ -58,16 +58,16 @@
HAL, which includes checking for supported layers and extensions.
Run the
-[device creation test](https://github.com/google/iree/blob/main/iree/hal/cts/device_creation_test.cc):
+[driver test](https://github.com/google/iree/blob/main/iree/hal/cts/driver_test.cc):
```powershell
# -- CMake --
> set VK_LOADER_DEBUG=all
-> cmake --build build\ --target iree_hal_cts_device_creation_test
-> .\build\iree\hal\cts\iree_hal_cts_device_creation_test.exe
+> cmake --build build\ --target iree_hal_cts_driver_test
+> .\build\iree\hal\cts\iree_hal_cts_driver_test.exe
# -- Bazel --
-> bazel test iree/hal/cts:device_creation_test --test_env=VK_LOADER_DEBUG=all --test_output=all
+> bazel test iree/hal/cts:driver_test --test_env=VK_LOADER_DEBUG=all --test_output=all
```
If these tests pass, you can skip down to the next section.
diff --git a/iree/hal/cts/BUILD b/iree/hal/cts/BUILD
index 6740480..445e5af 100644
--- a/iree/hal/cts/BUILD
+++ b/iree/hal/cts/BUILD
@@ -80,8 +80,8 @@
)
cc_test(
- name = "device_creation_test",
- srcs = ["device_creation_test.cc"],
+ name = "driver_test",
+ srcs = ["driver_test.cc"],
deps = [
":cts_test_base",
"//iree/hal:driver_registry",
diff --git a/iree/hal/cts/CMakeLists.txt b/iree/hal/cts/CMakeLists.txt
index 604d8ad..3de079d 100644
--- a/iree/hal/cts/CMakeLists.txt
+++ b/iree/hal/cts/CMakeLists.txt
@@ -71,9 +71,9 @@
iree_cc_test(
NAME
- device_creation_test
+ driver_test
SRCS
- "device_creation_test.cc"
+ "driver_test.cc"
DEPS
::cts_test_base
iree::hal::driver_registry
diff --git a/iree/hal/cts/device_creation_test.cc b/iree/hal/cts/driver_test.cc
similarity index 67%
rename from iree/hal/cts/device_creation_test.cc
rename to iree/hal/cts/driver_test.cc
index d8d516b..372657b 100644
--- a/iree/hal/cts/device_creation_test.cc
+++ b/iree/hal/cts/driver_test.cc
@@ -15,18 +15,28 @@
#include "iree/hal/cts/cts_test_base.h"
#include "iree/hal/driver_registry.h"
#include "iree/testing/gtest.h"
+#include "iree/testing/status_matchers.h"
namespace iree {
namespace hal {
namespace cts {
-class DeviceCreationTest : public CtsTestBase {};
+class DriverTest : public CtsTestBase {};
-TEST_P(DeviceCreationTest, CreateDevice) {
+TEST_P(DriverTest, CreateDefaultDevice) {
LOG(INFO) << "Device details:\n" << device_->DebugString();
}
-INSTANTIATE_TEST_SUITE_P(AllDrivers, DeviceCreationTest,
+TEST_P(DriverTest, EnumerateAndCreateAvailableDevices) {
+ IREE_ASSERT_OK_AND_ASSIGN(auto devices, driver_->EnumerateAvailableDevices());
+
+ for (int i = 0; i < devices.size(); ++i) {
+ IREE_ASSERT_OK_AND_ASSIGN(auto device, driver_->CreateDevice(devices[i]));
+ LOG(INFO) << "Device #" << i << " details:\n" << device->DebugString();
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(AllDrivers, DriverTest,
::testing::ValuesIn(DriverRegistry::shared_registry()
->EnumerateAvailableDrivers()),
GenerateTestName());
diff --git a/iree/hal/cts/semaphore_test.cc b/iree/hal/cts/semaphore_test.cc
index 185acb7..c4299ca 100644
--- a/iree/hal/cts/semaphore_test.cc
+++ b/iree/hal/cts/semaphore_test.cc
@@ -88,6 +88,39 @@
// Waiting on a failed semaphore is undefined behavior. Some backends may
// return UnknownError while others may succeed.
+// Waiting all semaphores but not all are signaled.
+TEST_P(SemaphoreTest, WaitAllButNotAllSignaled) {
+ IREE_ASSERT_OK_AND_ASSIGN(auto a, device_->CreateSemaphore(0u));
+ IREE_ASSERT_OK_AND_ASSIGN(auto b, device_->CreateSemaphore(1u));
+ // NOTE: we don't actually block here because otherwise we'd lock up.
+ // Result status is undefined - some backends may return DeadlineExceededError
+ // while others may return success.
+ device_->WaitAllSemaphores({{a.get(), 1u}, {b.get(), 1u}}, InfinitePast())
+ .IgnoreError();
+}
+
+// Waiting all semaphores and all are signaled.
+TEST_P(SemaphoreTest, WaitAllAndAllSignaled) {
+ IREE_ASSERT_OK_AND_ASSIGN(auto a, device_->CreateSemaphore(1u));
+ IREE_ASSERT_OK_AND_ASSIGN(auto b, device_->CreateSemaphore(1u));
+ IREE_ASSERT_OK(device_->WaitAllSemaphores({{a.get(), 1u}, {b.get(), 1u}},
+ InfiniteFuture()));
+}
+
+// Waiting any semaphore to signal.
+TEST_P(SemaphoreTest, WaitAny) {
+ // TODO: fix this.
+ if (driver_->name() == "dylib" || driver_->name() == "llvmjit" ||
+ driver_->name() == "vmla" || driver_->name() == "vulkan") {
+ GTEST_SKIP();
+ }
+
+ IREE_ASSERT_OK_AND_ASSIGN(auto a, device_->CreateSemaphore(0u));
+ IREE_ASSERT_OK_AND_ASSIGN(auto b, device_->CreateSemaphore(1u));
+ IREE_ASSERT_OK(device_->WaitAnySemaphore({{a.get(), 1u}, {b.get(), 1u}},
+ InfiniteFuture()));
+}
+
// Tests threading behavior by ping-ponging between the test main thread and
// a little thread.
TEST_P(SemaphoreTest, PingPong) {
diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt
new file mode 100644
index 0000000..a7e643d
--- /dev/null
+++ b/iree/hal/metal/CMakeLists.txt
@@ -0,0 +1,125 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+if(NOT ${IREE_HAL_DRIVER_METAL})
+ return()
+endif()
+
+iree_cc_library(
+ NAME
+ metal_command_buffer
+ HDRS
+ "metal_command_buffer.h"
+ SRCS
+ "metal_command_buffer.mm"
+ DEPS
+ iree::base::status
+ iree::base::tracing
+ iree::hal::command_buffer
+ LINKOPTS
+ "-framework Metal"
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ metal_command_queue
+ HDRS
+ "metal_command_queue.h"
+ SRCS
+ "metal_command_queue.mm"
+ DEPS
+ ::metal_command_buffer
+ ::metal_shared_event
+ iree::base::status
+ iree::base::time
+ iree::base::tracing
+ iree::hal::command_queue
+ LINKOPTS
+ "-framework Metal"
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ metal_device
+ HDRS
+ "metal_device.h"
+ SRCS
+ "metal_device.mm"
+ DEPS
+ ::metal_command_buffer
+ ::metal_command_queue
+ ::metal_shared_event
+ absl::strings
+ absl::span
+ iree::base::status
+ iree::base::time
+ iree::base::tracing
+ iree::hal::allocator
+ iree::hal::command_queue
+ iree::hal::device
+ iree::hal::driver
+ iree::hal::semaphore
+ LINKOPTS
+ "-framework Metal"
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ metal_driver
+ HDRS
+ "metal_driver.h"
+ SRCS
+ "metal_driver.mm"
+ DEPS
+ ::metal_device
+ iree::base::status
+ iree::base::tracing
+ iree::hal::device_info
+ iree::hal::driver
+ LINKOPTS
+ "-framework Metal"
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ metal_driver_module
+ SRCS
+ "metal_driver_module.cc"
+ DEPS
+ ::metal_driver
+ iree::base::init
+ iree::base::status
+ iree::hal::driver_registry
+ ALWAYSLINK
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
+ metal_shared_event
+ HDRS
+ "metal_shared_event.h"
+ SRCS
+ "metal_shared_event.mm"
+ DEPS
+ iree::base::tracing
+ iree::hal::semaphore
+ LINKOPTS
+ "-framework Metal"
+ PUBLIC
+)
diff --git a/iree/hal/metal/README.md b/iree/hal/metal/README.md
new file mode 100644
index 0000000..382066b
--- /dev/null
+++ b/iree/hal/metal/README.md
@@ -0,0 +1,5 @@
+# Metal HAL Driver
+
+This directory contains the source code for the Metal HAL driver. See the
+[design doc](https://google.github.io/iree/design-docs/metal-hal-driver)
+for more details.
diff --git a/iree/hal/metal/dispatch_time_util.h b/iree/hal/metal/dispatch_time_util.h
new file mode 100644
index 0000000..6023d38
--- /dev/null
+++ b/iree/hal/metal/dispatch_time_util.h
@@ -0,0 +1,44 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_APPLE_TIME_UTIL_H_
+#define IREE_HAL_METAL_APPLE_TIME_UTIL_H_
+
+#include <dispatch/dispatch.h>
+
+#include "iree/base/time.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// Converts a relative iree::Duration against the currrent time to the
+// corresponding dispatch_time_t value.
+static inline dispatch_time_t DurationToDispatchTime(Duration duration_ns) {
+ if (duration_ns == InfiniteDuration()) return DISPATCH_TIME_FOREVER;
+ if (duration_ns == ZeroDuration()) return DISPATCH_TIME_NOW;
+ return dispatch_time(DISPATCH_TIME_NOW, static_cast<uint64_t>(duration_ns));
+}
+
+// Converts an absolute iree::Time time to the corresponding dispatch_time_t
+// value.
+static inline dispatch_time_t DeadlineToDispatchTime(Time deadline_ns) {
+ return DurationToDispatchTime(DeadlineToRelativeTimeoutNanos(deadline_ns));
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_APPLE_TIME_UTIL_H_
diff --git a/iree/hal/metal/metal_command_buffer.h b/iree/hal/metal/metal_command_buffer.h
new file mode 100644
index 0000000..c3d532e
--- /dev/null
+++ b/iree/hal/metal/metal_command_buffer.h
@@ -0,0 +1,102 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_COMMAND_BUFFER_H_
+#define IREE_HAL_METAL_METAL_COMMAND_BUFFER_H_
+
+#import <Metal/Metal.h>
+
+#include "iree/hal/command_buffer.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A command buffer implementation for Metal that directly wraps a
+// MTLCommandBuffer.
+//
+// Objects of this class are not expected to be accessed by multiple threads.
+class MetalCommandBuffer final : public CommandBuffer {
+ public:
+ static StatusOr<ref_ptr<CommandBuffer>> Create(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories,
+ id<MTLCommandBuffer> command_buffer);
+ ~MetalCommandBuffer() override;
+
+ id<MTLCommandBuffer> handle() const { return metal_handle_; }
+
+ bool is_recording() const override { return is_recording_; }
+
+ Status Begin() override;
+ Status End() override;
+
+ Status ExecutionBarrier(
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status SignalEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+ Status ResetEvent(Event* event,
+ ExecutionStageBitfield source_stage_mask) override;
+ Status WaitEvents(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) override;
+
+ Status FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) override;
+ Status DiscardBuffer(Buffer* buffer) override;
+ Status UpdateBuffer(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+ Status CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) override;
+
+ Status PushConstants(ExecutableLayout* executable_layout, size_t offset,
+ absl::Span<const uint32_t> values) override;
+
+ Status PushDescriptorSet(
+ ExecutableLayout* executable_layout, int32_t set,
+ absl::Span<const DescriptorSet::Binding> bindings) override;
+ Status BindDescriptorSet(
+ ExecutableLayout* executable_layout, int32_t set,
+ DescriptorSet* descriptor_set,
+ absl::Span<const device_size_t> dynamic_offsets) override;
+
+ Status Dispatch(Executable* executable, int32_t entry_point,
+ std::array<uint32_t, 3> workgroups) override;
+ Status DispatchIndirect(Executable* executable, int32_t entry_point,
+ Buffer* workgroups_buffer,
+ device_size_t workgroups_offset) override;
+
+ private:
+ MetalCommandBuffer(CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories,
+ id<MTLCommandBuffer> command_buffer);
+
+ bool is_recording_ = false;
+ id<MTLCommandBuffer> metal_handle_;
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_COMMAND_BUFFER_H_
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm
new file mode 100644
index 0000000..17dec6c
--- /dev/null
+++ b/iree/hal/metal/metal_command_buffer.mm
@@ -0,0 +1,140 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_command_buffer.h"
+
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// static
+StatusOr<ref_ptr<CommandBuffer>> MetalCommandBuffer::Create(
+ CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories,
+ id<MTLCommandBuffer> command_buffer) {
+ return assign_ref(new MetalCommandBuffer(mode, command_categories, command_buffer));
+}
+
+MetalCommandBuffer::MetalCommandBuffer(CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories,
+ id<MTLCommandBuffer> command_buffer)
+ : CommandBuffer(mode, command_categories), metal_handle_([command_buffer retain]) {}
+
+MetalCommandBuffer::~MetalCommandBuffer() {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::dtor");
+ [metal_handle_ release];
+}
+
+Status MetalCommandBuffer::Begin() {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::Begin");
+ is_recording_ = true;
+ return OkStatus();
+}
+
+Status MetalCommandBuffer::End() {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::End");
+ is_recording_ = false;
+ return OkStatus();
+}
+
+Status MetalCommandBuffer::ExecutionBarrier(ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::ExecutionBarrier");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::ExecutionBarrier";
+}
+
+Status MetalCommandBuffer::SignalEvent(Event* event, ExecutionStageBitfield source_stage_mask) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::SignalEvent");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::SignalEvent";
+}
+
+Status MetalCommandBuffer::ResetEvent(Event* event, ExecutionStageBitfield source_stage_mask) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::ResetEvent");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::ResetEvent";
+}
+
+Status MetalCommandBuffer::WaitEvents(absl::Span<Event*> events,
+ ExecutionStageBitfield source_stage_mask,
+ ExecutionStageBitfield target_stage_mask,
+ absl::Span<const MemoryBarrier> memory_barriers,
+ absl::Span<const BufferBarrier> buffer_barriers) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::WaitEvents");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::WaitEvents";
+}
+
+Status MetalCommandBuffer::FillBuffer(Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length, const void* pattern,
+ size_t pattern_length) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::FillBuffer");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::FillBuffer";
+}
+
+Status MetalCommandBuffer::DiscardBuffer(Buffer* buffer) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::DiscardBuffer");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::DiscardBuffer";
+}
+
+Status MetalCommandBuffer::UpdateBuffer(const void* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::UpdateBuffer");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::UpdateBuffer";
+}
+
+Status MetalCommandBuffer::CopyBuffer(Buffer* source_buffer, device_size_t source_offset,
+ Buffer* target_buffer, device_size_t target_offset,
+ device_size_t length) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::CopyBuffer");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::CopyBuffer";
+}
+
+Status MetalCommandBuffer::PushConstants(ExecutableLayout* executable_layout, size_t offset,
+ absl::Span<const uint32_t> values) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::PushConstants");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::PushConstants";
+}
+
+Status MetalCommandBuffer::PushDescriptorSet(ExecutableLayout* executable_layout, int32_t set,
+ absl::Span<const DescriptorSet::Binding> bindings) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::PushDescriptorSet");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::PushDescriptorSet";
+}
+
+Status MetalCommandBuffer::BindDescriptorSet(ExecutableLayout* executable_layout, int32_t set,
+ DescriptorSet* descriptor_set,
+ absl::Span<const device_size_t> dynamic_offsets) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::BindDescriptorSet");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::BindDescriptorSet";
+}
+
+Status MetalCommandBuffer::Dispatch(Executable* executable, int32_t entry_point,
+ std::array<uint32_t, 3> workgroups) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::Dispatch");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::Dispatch";
+}
+
+Status MetalCommandBuffer::DispatchIndirect(Executable* executable, int32_t entry_point,
+ Buffer* workgroups_buffer,
+ device_size_t workgroups_offset) {
+ IREE_TRACE_SCOPE0("MetalCommandBuffer::DispatchIndirect");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::DispatchIndirect";
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/metal/metal_command_queue.h b/iree/hal/metal/metal_command_queue.h
new file mode 100644
index 0000000..caf43f0
--- /dev/null
+++ b/iree/hal/metal/metal_command_queue.h
@@ -0,0 +1,54 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_COMMAND_QUEUE_H_
+#define IREE_HAL_METAL_METAL_COMMAND_QUEUE_H_
+
+#import <Metal/Metal.h>
+
+#include "iree/base/arena.h"
+#include "iree/base/status.h"
+#include "iree/base/time.h"
+#include "iree/hal/command_queue.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A command queue implementation for Metal that directly wraps a
+// MTLCommandQueue.
+//
+// Thread-safe.
+class MetalCommandQueue final : public CommandQueue {
+ public:
+ MetalCommandQueue(std::string name,
+ CommandCategoryBitfield supported_categories,
+ id<MTLCommandQueue> queue);
+ ~MetalCommandQueue() override;
+
+ id<MTLCommandQueue> handle() const { return metal_handle_; }
+
+ Status Submit(absl::Span<const SubmissionBatch> batches) override;
+
+ Status WaitIdle(Time deadline_ns) override;
+
+ private:
+ id<MTLCommandQueue> metal_handle_;
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_COMMAND_QUEUE_H_
diff --git a/iree/hal/metal/metal_command_queue.mm b/iree/hal/metal/metal_command_queue.mm
new file mode 100644
index 0000000..3729fd0
--- /dev/null
+++ b/iree/hal/metal/metal_command_queue.mm
@@ -0,0 +1,93 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_command_queue.h"
+
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/metal/dispatch_time_util.h"
+#include "iree/hal/metal/metal_command_buffer.h"
+#include "iree/hal/metal/metal_shared_event.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+MetalCommandQueue::MetalCommandQueue(std::string name, CommandCategoryBitfield supported_categories,
+ id<MTLCommandQueue> queue)
+ : CommandQueue(std::move(name), supported_categories), metal_handle_([queue retain]) {}
+
+MetalCommandQueue::~MetalCommandQueue() { [metal_handle_ release]; }
+
+Status MetalCommandQueue::Submit(absl::Span<const SubmissionBatch> batches) {
+ IREE_TRACE_SCOPE0("MetalCommandQueue::Submit");
+ for (const auto& batch : batches) {
+ @autoreleasepool {
+ // Wait for semaphores blocking this batch.
+ if (!batch.wait_semaphores.empty()) {
+ id<MTLCommandBuffer> wait_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+ for (const auto& semaphore : batch.wait_semaphores) {
+ auto* event = static_cast<MetalSharedEvent*>(semaphore.semaphore);
+ [wait_buffer encodeWaitForEvent:event->handle() value:semaphore.value];
+ }
+ [wait_buffer commit];
+ }
+
+ // Commit command buffers to the queue.
+ for (const auto* command_buffer : batch.command_buffers) {
+ const auto* cmdbuf = static_cast<const MetalCommandBuffer*>(command_buffer);
+ [cmdbuf->handle() commit];
+ }
+
+ // Signal semaphores advanced by this batch.
+ if (!batch.signal_semaphores.empty()) {
+ id<MTLCommandBuffer> signal_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+ for (const auto& semaphore : batch.signal_semaphores) {
+ auto* event = static_cast<MetalSharedEvent*>(semaphore.semaphore);
+ [signal_buffer encodeSignalEvent:event->handle() value:semaphore.value];
+ }
+ [signal_buffer commit];
+ }
+ }
+ }
+ return OkStatus();
+}
+
+Status MetalCommandQueue::WaitIdle(Time deadline_ns) {
+ IREE_TRACE_SCOPE0("MetalCommandQueue::WaitIdle");
+
+ dispatch_time_t timeout = DeadlineToDispatchTime(deadline_ns);
+
+ // Submit an empty command buffer and wait for it to complete. That will indicate all previous
+ // work has completed too.
+ @autoreleasepool {
+ id<MTLCommandBuffer> comand_buffer = [metal_handle_ commandBufferWithUnretainedReferences];
+ __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+ [comand_buffer addCompletedHandler:^(id<MTLCommandBuffer>) {
+ dispatch_semaphore_signal(work_done);
+ }];
+ [comand_buffer commit];
+ long timed_out = dispatch_semaphore_wait(work_done, timeout);
+ dispatch_release(work_done);
+ if (timed_out) {
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline exceeded waiting for dispatch_semaphore_t";
+ }
+ return OkStatus();
+ }
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/metal/metal_device.h b/iree/hal/metal/metal_device.h
new file mode 100644
index 0000000..7802664
--- /dev/null
+++ b/iree/hal/metal/metal_device.h
@@ -0,0 +1,107 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_DEVICE_H_
+#define IREE_HAL_METAL_METAL_DEVICE_H_
+
+#import <Metal/Metal.h>
+
+#include "absl/types/span.h"
+#include "iree/base/memory.h"
+#include "iree/hal/allocator.h"
+#include "iree/hal/device.h"
+#include "iree/hal/driver.h"
+#include "iree/hal/semaphore.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A device implementation for Metal that directly wraps a MTLDevice.
+class MetalDevice final : public Device {
+ public:
+ // Creates a device that retains the underlying Metal GPU device.
+ // The DriverDeviceID in |device_info| is expected to be an id<MTLDevice>.
+ static StatusOr<ref_ptr<MetalDevice>> Create(ref_ptr<Driver> driver,
+ const DeviceInfo& device_info);
+
+ ~MetalDevice() override;
+
+ std::string DebugString() const override;
+
+ Allocator* allocator() const override { return nullptr; }
+
+ absl::Span<CommandQueue*> dispatch_queues() const override {
+ return absl::MakeSpan(&common_queue_, 1);
+ }
+
+ absl::Span<CommandQueue*> transfer_queues() const override {
+ return absl::MakeSpan(&common_queue_, 1);
+ }
+
+ ref_ptr<ExecutableCache> CreateExecutableCache() override;
+
+ StatusOr<ref_ptr<DescriptorSetLayout>> CreateDescriptorSetLayout(
+ DescriptorSetLayout::UsageType usage_type,
+ absl::Span<const DescriptorSetLayout::Binding> bindings) override;
+
+ StatusOr<ref_ptr<ExecutableLayout>> CreateExecutableLayout(
+ absl::Span<DescriptorSetLayout* const> set_layouts,
+ size_t push_constants) override;
+
+ StatusOr<ref_ptr<DescriptorSet>> CreateDescriptorSet(
+ DescriptorSetLayout* set_layout,
+ absl::Span<const DescriptorSet::Binding> bindings) override;
+
+ StatusOr<ref_ptr<CommandBuffer>> CreateCommandBuffer(
+ CommandBufferModeBitfield mode,
+ CommandCategoryBitfield command_categories) override;
+
+ StatusOr<ref_ptr<Event>> CreateEvent() override;
+
+ StatusOr<ref_ptr<Semaphore>> CreateSemaphore(uint64_t initial_value) override;
+ Status WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores,
+ Time deadline_ns) override;
+ StatusOr<int> WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores,
+ Time deadline_ns) override;
+
+ Status WaitIdle(Time deadline_ns) override;
+
+ private:
+ MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info);
+
+ ref_ptr<Driver> driver_;
+ id<MTLDevice> metal_handle_;
+
+ // Metal does not have clear graphics/dispatch/transfer queue distinction like
+ // Vulkan; one just use the same newCommandQueue() API call on MTLDevice to
+ // get command queues. Command encoders differ for different categories of
+ // commands though. We expose one queue here for everything. This can be
+ // changed later if more queues prove to be useful.
+
+ std::unique_ptr<CommandQueue> command_queue_;
+ mutable CommandQueue* common_queue_ = nullptr;
+
+ // A dispatch queue and associated event listener for running Objective-C
+ // blocks. This is typically used to wake up threads waiting on some HAL
+ // semaphore.
+ dispatch_queue_t wait_notifier_;
+ MTLSharedEventListener* event_listener_;
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_DEVICE_H_
diff --git a/iree/hal/metal/metal_device.mm b/iree/hal/metal/metal_device.mm
new file mode 100644
index 0000000..f293641
--- /dev/null
+++ b/iree/hal/metal/metal_device.mm
@@ -0,0 +1,192 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_device.h"
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
+#include "iree/base/status.h"
+#include "iree/base/time.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/command_buffer_validation.h"
+#include "iree/hal/metal/dispatch_time_util.h"
+#include "iree/hal/metal/metal_command_buffer.h"
+#include "iree/hal/metal/metal_command_queue.h"
+#include "iree/hal/metal/metal_shared_event.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// static
+StatusOr<ref_ptr<MetalDevice>> MetalDevice::Create(ref_ptr<Driver> driver,
+ const DeviceInfo& device_info) {
+ return assign_ref(new MetalDevice(std::move(driver), device_info));
+}
+
+MetalDevice::MetalDevice(ref_ptr<Driver> driver, const DeviceInfo& device_info)
+ : Device(device_info),
+ driver_(std::move(driver)),
+ metal_handle_([(__bridge id<MTLDevice>)device_info.device_id() retain]) {
+ IREE_TRACE_SCOPE0("MetalDevice::ctor");
+
+ // Grab one queue for dispatch and transfer.
+ std::string name = absl::StrCat(device_info.name(), ":queue");
+ id<MTLCommandQueue> metal_queue = [metal_handle_ newCommandQueue]; // retained
+ command_queue_ = absl::make_unique<MetalCommandQueue>(
+ name, CommandCategory::kDispatch | CommandCategory::kTransfer, metal_queue);
+ common_queue_ = command_queue_.get();
+ // MetalCommandQueue retains by itself. Release here to avoid leaking.
+ [metal_queue release];
+
+ wait_notifier_ = dispatch_queue_create("com.google.iree.semaphore_wait_notifier", NULL);
+ event_listener_ = [[MTLSharedEventListener alloc] initWithDispatchQueue:wait_notifier_];
+}
+
+MetalDevice::~MetalDevice() {
+ IREE_TRACE_SCOPE0("MetalDevice::dtor");
+ [event_listener_ release];
+ dispatch_release(wait_notifier_);
+ [metal_handle_ release];
+}
+
+std::string MetalDevice::DebugString() const {
+ return absl::StrCat(Device::DebugString(), //
+ "\n[MetalDevice]", //
+ "\n - Dispatch Queues: 1", //
+ "\n - Transfer Queues: 1");
+}
+
+ref_ptr<ExecutableCache> MetalDevice::CreateExecutableCache() {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableCache");
+ return nullptr;
+}
+
+StatusOr<ref_ptr<DescriptorSetLayout>> MetalDevice::CreateDescriptorSetLayout(
+ DescriptorSetLayout::UsageType usage_type,
+ absl::Span<const DescriptorSetLayout::Binding> bindings) {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSetLayout");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateDescriptorSetLayout";
+}
+
+StatusOr<ref_ptr<ExecutableLayout>> MetalDevice::CreateExecutableLayout(
+ absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableLayout");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateExecutableLayout";
+}
+
+StatusOr<ref_ptr<DescriptorSet>> MetalDevice::CreateDescriptorSet(
+ DescriptorSetLayout* set_layout, absl::Span<const DescriptorSet::Binding> bindings) {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSet");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateDescriptorSet";
+}
+
+StatusOr<ref_ptr<CommandBuffer>> MetalDevice::CreateCommandBuffer(
+ CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories) {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateCommandBuffer");
+ @autoreleasepool {
+ StatusOr<ref_ptr<CommandBuffer>> command_buffer;
+ // We use commandBufferWithUnretainedReferences here to be performant. This is okay becasue
+ // IREE tracks the lifetime of various objects with the help from compilers.
+ id<MTLCommandBuffer> cmdbuf = [static_cast<MetalCommandQueue*>(common_queue_)->handle()
+ commandBufferWithUnretainedReferences];
+ command_buffer = MetalCommandBuffer::Create(mode, command_categories, cmdbuf);
+ // TODO: WrapCommandBufferWithValidation(allocator(), std::move(impl));
+ return command_buffer;
+ }
+}
+
+StatusOr<ref_ptr<Event>> MetalDevice::CreateEvent() {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateEvent");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateEvent";
+}
+
+StatusOr<ref_ptr<Semaphore>> MetalDevice::CreateSemaphore(uint64_t initial_value) {
+ IREE_TRACE_SCOPE0("MetalDevice::CreateSemaphore");
+ return MetalSharedEvent::Create(metal_handle_, event_listener_, initial_value);
+}
+
+Status MetalDevice::WaitAllSemaphores(absl::Span<const SemaphoreValue> semaphores,
+ Time deadline_ns) {
+ IREE_TRACE_SCOPE0("MetalDevice::WaitAllSemaphores");
+ // Go through all MetalSharedEvents and wait on each of them given we need all of them to be
+ // signaled anyway.
+ for (int i = 0; i < semaphores.size(); ++i) {
+ auto* semaphore = static_cast<MetalSharedEvent*>(semaphores[i].semaphore);
+ IREE_RETURN_IF_ERROR(semaphore->Wait(semaphores[i].value, deadline_ns));
+ }
+ return OkStatus();
+}
+
+StatusOr<int> MetalDevice::WaitAnySemaphore(absl::Span<const SemaphoreValue> semaphores,
+ Time deadline_ns) {
+ IREE_TRACE_SCOPE0("MetalDevice::WaitAnySemaphore");
+
+ if (semaphores.empty()) {
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "expected to have at least one semaphore";
+ }
+
+ // If there is just one semaphore, just wait on it.
+ if (semaphores.size() == 1) {
+ auto* semaphore = static_cast<MetalSharedEvent*>(semaphores[0].semaphore);
+ IREE_RETURN_IF_ERROR(semaphore->Wait(semaphores[0].value, deadline_ns));
+ return 0;
+ }
+
+ // Otherwise, we need to go down a more complicated path by registering listeners to all
+ // MTLSharedEvents to notify us when at least one of them has done the work on GPU by signaling a
+ // semaphore. The signaling will happen in a new dispatch queue; the current thread will wait on
+ // the semaphore.
+
+ dispatch_time_t timeout = DeadlineToDispatchTime(deadline_ns);
+
+ // Store the handle as a __block variable to allow blocks accessing the same copy for the
+ // semaphore handle on heap.
+ // Use an initial value of zero so that any semaphore signal will unblock the wait.
+ __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+ // Also create a __block variable to store the index for the signaled semaphore.
+ __block int signaled_index = 0;
+
+ // The dispatch queue created in the above is a serial one. So even if multiple semaphores signal,
+ // the semaphore signaling should be serialized.
+ for (int i = 0; i < semaphores.size(); ++i) {
+ auto* semaphore = static_cast<MetalSharedEvent*>(semaphores[i].semaphore);
+ [semaphore->handle() notifyListener:event_listener_
+ atValue:semaphores[i].value
+ block:^(id<MTLSharedEvent>, uint64_t) {
+ dispatch_semaphore_signal(work_done);
+ // This should capture the *current* index for each semaphore.
+ signaled_index = i;
+ }];
+ }
+
+ long timed_out = dispatch_semaphore_wait(work_done, timeout);
+
+ dispatch_release(work_done);
+
+ if (timed_out) {
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline exceeded waiting for dispatch_semaphore_t";
+ }
+ return signaled_index;
+}
+
+Status MetalDevice::WaitIdle(Time deadline_ns) {
+ IREE_TRACE_SCOPE0("MetalDevice::WaitIdle");
+ return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::WaitIdle";
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/metal/metal_driver.h b/iree/hal/metal/metal_driver.h
new file mode 100644
index 0000000..7f458b6
--- /dev/null
+++ b/iree/hal/metal/metal_driver.h
@@ -0,0 +1,50 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_DRIVER_H_
+#define IREE_HAL_METAL_METAL_DRIVER_H_
+
+#include "iree/hal/driver.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A pseudo Metal GPU driver which retains all available Metal GPU devices
+// during its lifetime.
+//
+// It uses the DriverDeviceID to store the underlying id<MTLDevice>.
+class MetalDriver final : public Driver {
+ public:
+ static StatusOr<ref_ptr<MetalDriver>> Create();
+
+ ~MetalDriver() override;
+
+ StatusOr<std::vector<DeviceInfo>> EnumerateAvailableDevices() override;
+
+ StatusOr<ref_ptr<Device>> CreateDefaultDevice() override;
+
+ StatusOr<ref_ptr<Device>> CreateDevice(DriverDeviceID device_id) override;
+
+ private:
+ explicit MetalDriver(std::vector<DeviceInfo> devices);
+
+ std::vector<DeviceInfo> devices_;
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_DRIVER_H_
diff --git a/iree/hal/metal/metal_driver.mm b/iree/hal/metal/metal_driver.mm
new file mode 100644
index 0000000..de3aca1
--- /dev/null
+++ b/iree/hal/metal/metal_driver.mm
@@ -0,0 +1,107 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_driver.h"
+
+#import <Metal/Metal.h>
+
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/metal/metal_device.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+namespace {
+
+// Returns an autoreleased array of available Metal GPU devices.
+NSArray<id<MTLDevice>>* GetAvailableMetalDevices() {
+#if defined(IREE_PLATFORM_MACOS)
+ // For macOS, we might have more than one GPU devices.
+ return [MTLCopyAllDevices() autorelease];
+#else
+ // For other Apple platforms, we only have one GPU device.
+ id<MTLDevice> device = [MTLCreateSystemDefaultDevice() autorelease];
+ return [NSArray arrayWithObject:device];
+#endif
+}
+
+} // namespace
+
+// static
+StatusOr<ref_ptr<MetalDriver>> MetalDriver::Create() {
+ IREE_TRACE_SCOPE0("MetalDriver::Create");
+
+ @autoreleasepool {
+ NSArray<id<MTLDevice>>* devices = GetAvailableMetalDevices();
+ if (devices == nil) {
+ return UnavailableErrorBuilder(IREE_LOC) << "no Metal GPU devices available";
+ }
+
+ std::vector<DeviceInfo> device_infos;
+ for (id<MTLDevice> device in devices) {
+ std::string name = std::string([device.name UTF8String]);
+ DeviceFeatureBitfield supported_features = DeviceFeature::kNone;
+ DriverDeviceID device_id = reinterpret_cast<DriverDeviceID>((__bridge void*)device);
+ device_infos.emplace_back("metal", std::move(name), supported_features, device_id);
+ }
+ return assign_ref(new MetalDriver(std::move(device_infos)));
+ }
+}
+
+MetalDriver::MetalDriver(std::vector<DeviceInfo> devices)
+ : Driver("metal"), devices_(std::move(devices)) {
+ // Retain all the retained Metal GPU devices.
+ for (const auto& device : devices_) {
+ [(__bridge id<MTLDevice>)device.device_id() retain];
+ }
+}
+
+MetalDriver::~MetalDriver() {
+ IREE_TRACE_SCOPE0("MetalDriver::dtor");
+
+ // Release all the retained Metal GPU devices.
+ for (const auto& device : devices_) {
+ [(__bridge id<MTLDevice>)device.device_id() release];
+ }
+}
+
+StatusOr<std::vector<DeviceInfo>> MetalDriver::EnumerateAvailableDevices() {
+ IREE_TRACE_SCOPE0("MetalDriver::EnumerateAvailableDevices");
+
+ return devices_;
+}
+
+StatusOr<ref_ptr<Device>> MetalDriver::CreateDefaultDevice() {
+ IREE_TRACE_SCOPE0("MetalDriver::CreateDefaultDevice");
+
+ if (devices_.empty()) {
+ return UnavailableErrorBuilder(IREE_LOC) << "no Metal GPU devices available";
+ }
+ return CreateDevice(devices_.front().device_id());
+}
+
+StatusOr<ref_ptr<Device>> MetalDriver::CreateDevice(DriverDeviceID device_id) {
+ IREE_TRACE_SCOPE0("MetalDriver::CreateDevice");
+
+ for (const DeviceInfo& info : devices_) {
+ if (info.device_id() == device_id) return MetalDevice::Create(add_ref(this), info);
+ }
+ return InvalidArgumentErrorBuilder(IREE_LOC) << "unknown driver device id: " << device_id;
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/metal/metal_driver_module.cc b/iree/hal/metal/metal_driver_module.cc
new file mode 100644
index 0000000..ba2ec58
--- /dev/null
+++ b/iree/hal/metal/metal_driver_module.cc
@@ -0,0 +1,38 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+
+#include "iree/base/init.h"
+#include "iree/base/status.h"
+#include "iree/hal/driver_registry.h"
+#include "iree/hal/metal/metal_driver.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+namespace {
+
+StatusOr<ref_ptr<Driver>> CreateMetalDriver() { return MetalDriver::Create(); }
+
+} // namespace
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+IREE_REGISTER_MODULE_INITIALIZER(iree_hal_metal_driver, {
+ IREE_QCHECK_OK(::iree::hal::DriverRegistry::shared_registry()->Register(
+ "metal", ::iree::hal::metal::CreateMetalDriver));
+});
+IREE_REGISTER_MODULE_INITIALIZER_SEQUENCE(iree_hal, iree_hal_metal_driver);
diff --git a/iree/hal/metal/metal_shared_event.h b/iree/hal/metal/metal_shared_event.h
new file mode 100644
index 0000000..10a96ce
--- /dev/null
+++ b/iree/hal/metal/metal_shared_event.h
@@ -0,0 +1,68 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_SHARED_EVENT_H_
+#define IREE_HAL_METAL_METAL_SHARED_EVENT_H_
+
+#import <Metal/Metal.h>
+
+#include "absl/synchronization/mutex.h"
+#include "iree/hal/semaphore.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// A semaphore implementation for Metal that directly wraps a MTLSharedEvent.
+class MetalSharedEvent final : public Semaphore {
+ public:
+ // Creates a MetalSharedEvent with the given |initial_value|.
+ static StatusOr<ref_ptr<Semaphore>> Create(
+ id<MTLDevice> device, MTLSharedEventListener* event_listener,
+ uint64_t initial_value);
+
+ ~MetalSharedEvent() override;
+
+ id<MTLSharedEvent> handle() const { return metal_handle_; }
+
+ StatusOr<uint64_t> Query() override;
+
+ Status Signal(uint64_t value) override;
+
+ void Fail(Status status) override;
+
+ Status Wait(uint64_t value, Time deadline_ns) override;
+
+ private:
+ MetalSharedEvent(id<MTLDevice> device, MTLSharedEventListener* event_listener,
+ uint64_t initial_value);
+
+ id<MTLSharedEvent> metal_handle_;
+
+ // An event listener for waiting and signaling. Its lifetime is managed by
+ // the parent device.
+ MTLSharedEventListener* event_listener_;
+
+ // NOTE: the MTLSharedEvent is the source of truth. We only need to access
+ // this status (and thus take the lock) when we want to either signal failure
+ // or query the status in the case of the semaphore being set to UINT64_MAX.
+ mutable absl::Mutex status_mutex_;
+ Status status_ ABSL_GUARDED_BY(status_mutex_);
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_SHARED_EVENT_H_
diff --git a/iree/hal/metal/metal_shared_event.mm b/iree/hal/metal/metal_shared_event.mm
new file mode 100644
index 0000000..325c30a
--- /dev/null
+++ b/iree/hal/metal/metal_shared_event.mm
@@ -0,0 +1,108 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_shared_event.h"
+
+#include "iree/base/status.h"
+#include "iree/base/tracing.h"
+#include "iree/hal/metal/dispatch_time_util.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+// static
+StatusOr<ref_ptr<Semaphore>> MetalSharedEvent::Create(id<MTLDevice> device,
+ MTLSharedEventListener* event_listener,
+ uint64_t initial_value) {
+ return assign_ref(new MetalSharedEvent(device, event_listener, initial_value));
+}
+
+MetalSharedEvent::MetalSharedEvent(id<MTLDevice> device, MTLSharedEventListener* event_listener,
+ uint64_t initial_value)
+ : metal_handle_([device newSharedEvent]), event_listener_(event_listener) {
+ IREE_TRACE_SCOPE0("MetalSharedEvent::ctor");
+ metal_handle_.signaledValue = initial_value;
+}
+
+MetalSharedEvent::~MetalSharedEvent() {
+ IREE_TRACE_SCOPE0("MetalSharedEvent::dtor");
+ [metal_handle_ release];
+}
+
+StatusOr<uint64_t> MetalSharedEvent::Query() {
+ IREE_TRACE_SCOPE0("MetalSharedEvent::Query");
+ uint64_t value = metal_handle_.signaledValue;
+ if (value == UINT64_MAX) {
+ absl::MutexLock lock(&status_mutex_);
+ return status_;
+ }
+ return value;
+}
+
+Status MetalSharedEvent::Signal(uint64_t value) {
+ IREE_TRACE_SCOPE0("MetalSharedEvent::Signal");
+ metal_handle_.signaledValue = value;
+ return OkStatus();
+}
+
+void MetalSharedEvent::Fail(Status status) {
+ IREE_TRACE_SCOPE0("MetalSharedEvent::Fail");
+ absl::MutexLock lock(&status_mutex_);
+ status_ = std::move(status);
+ metal_handle_.signaledValue = UINT64_MAX;
+}
+
+Status MetalSharedEvent::Wait(uint64_t value, Time deadline_ns) {
+ IREE_TRACE_SCOPE0("MetalSharedEvent::Wait");
+
+ Duration duration_ns = DeadlineToRelativeTimeoutNanos(deadline_ns);
+ dispatch_time_t timeout = DurationToDispatchTime(duration_ns);
+
+ // Quick path for impatient waiting to avoid all the overhead of dispatch queues and semaphores.
+ if (duration_ns == ZeroDuration()) {
+ if (metal_handle_.signaledValue < value) {
+ return DeadlineExceededErrorBuilder(IREE_LOC) << "Deadline exceeded waiting for semaphores";
+ }
+ return OkStatus();
+ }
+
+ // Theoretically we don't really need to mark the semaphore handle as __block given that the
+ // handle itself is not modified and there is only one block and it will copy the handle.
+ // But marking it as __block serves as good documentation purpose.
+ __block dispatch_semaphore_t work_done = dispatch_semaphore_create(0);
+
+ // Use a listener to the MTLSharedEvent to notify us when the work is done on GPU by signaling a
+ // semaphore. The signaling will happen in a new dispatch queue; the current thread will wait on
+ // the semaphore.
+ [metal_handle_ notifyListener:event_listener_
+ atValue:value
+ block:^(id<MTLSharedEvent>, uint64_t) {
+ dispatch_semaphore_signal(work_done);
+ }];
+
+ long timed_out = dispatch_semaphore_wait(work_done, timeout);
+
+ dispatch_release(work_done);
+
+ if (timed_out) {
+ return DeadlineExceededErrorBuilder(IREE_LOC)
+ << "Deadline exceeded waiting for dispatch_semaphore_t";
+ }
+ return OkStatus();
+}
+
+} // namespace metal
+} // namespace hal
+} // namespace iree