[metal] Initial bring up of Metal HAL driver (5/n) (#3318)
This commit adds Metal implementation for DescriptorSetLayout,
DescriptorSet, and ExecutableLayout. It also implements
CommandBuffer::ExecutionBarrier() and CommandBuffer::Dispatch().
Co-authored-by: Scott Todd <scotttodd@google.com>
diff --git a/docs/design_docs/metal_hal_driver.md b/docs/design_docs/metal_hal_driver.md
index 68242ce..cf21206 100644
--- a/docs/design_docs/metal_hal_driver.md
+++ b/docs/design_docs/metal_hal_driver.md
@@ -62,6 +62,9 @@
[`hal::Buffer`][hal-buffer] | [`MTLBuffer`][mtl-buffer]
[`hal::Executable`][hal-executable] | [`MTLLibrary`][mtl-library]
[`hal::ExecutableCache`][hal-executable-cache] | N/A
+[`hal::DescriptorSetLayout`][hal-descriptor-set-layout] | N/A
+[`hal::DescriptorSet`][hal-descriptor-set] | N/A
+[`hal::ExecutableLayout`][hal-executable-layout] | N/A
In the following subsections, we go over each pair to provide more details.
@@ -158,6 +161,10 @@
HAL driver does not peforming any cache on GPU programs; it simply reads the
program from the FlatBuffer and hands it over to Metal driver.
+### DescriptorSetLayout, DescriptorSet, ExecutableLayout
+
+See [Resource descriptors](#resource-descriptors) for more details.
+
## Compute Pipeline
### Shader/kernel compilation
@@ -198,6 +205,70 @@
them at Metal run-time. In the future this should be changed to allow encoding
the library instead.
+### Resource descriptors
+
+A descriptor is an opaque handle pointing to a resource that is accessed in
+the compute kernel. IREE's HAL is inspired by the Vulkan API; it models several
+concepts related to GPU resource management explicitly:
+
+* [`hal::DescriptorSetLayout`][hal-descriptor-set-layout]: a schema for
+ describing an array of descriptor bindings. Each descriptor binding specifies
+ the resource type, access mode and other information.
+* [`hal::DescriptorSet`][hal-descriptor-set]: a concrete set of resources that
+ gets bound to a compute pipeline in a batch. It must match the
+ `DescriptorSetLayout` describing its layout. `DescriptorSet` can be thought as
+ the "object" from the `DescriptorSetLayout` "class".
+* [`hal::ExecutableLayout`][hal-executable-layout]: a schema for describing all
+ the resources accessed by a compute pipeline. It includes zero or more
+ `DescriptorSetLayout`s and (optional) push constants.
+
+One can create `DescriptorSetLayout`, `DescriptorSet`, and `ExecutableLayout`
+objects beforehand to avoid incurring overhead during tight computing loops
+and also amortize costs by sharing these objects. However, this isn't totally
+matching Metal's paradigm.
+
+In the Metal framework, the closest concept to `DescriptorSet` would be [argument
+buffer][mtl-argument-buffer]. There is no direct correspondence to
+`DescriptorSetLayout` and `ExecutableLayout`. Rather, the layout is implicitly
+encoded in Metal shaders as MSL structs. The APIs for creating argument buffers
+do not encourage early creation without pipelines: one typically creates them
+for each `MTLFunction`. Besides, unlike Vulkan where different descriptor sets
+can have the same binding number, in Metal even if we have multiple argument
+buffers, the indices for resources are in the same namespace and are typically
+assigned sequentially. That means we need to remap `DescriptorSet`s with a set
+number greater than zero by applying an offset to each of its bindings.
+
+All of this means it's better to defer the creation of the argument buffer
+until the point of compute pipeline creation and dispatch. Therefore, although
+the Metal HAL driver provides the implementation for `DescriptorSet`
+(i.e., `hal::metal::MetalArgumentBuffer`), `DescriptorSetLayout` (i.e.,
+`hal::metal::MetalArgumentBufferLayout`), and `ExecutableLayout` (i.e.,
+`hal::metal::MetalPipelineArgumentBufferLayout`), they are just containers
+holding the information up until the [command buffer
+dispatch](#command-buffer-dispatch) time.
+
+With the above said, the overall idea is still to map one descriptor set to one
+argument buffer. It just means we need to condense and remap the bindings.
+
+### Command buffer dispatch
+
+`MetalCommandBuffer::Dispatch()` performs the following steps with the current
+active `MTLComputeCommandEncoder`:
+
+1. Bind the `MTLComputePipelineState` for the current entry function queried
+ from `MetalKernelLibrary`.
+1. For each bound descriptor set at set #`S`:
+ 1. Create a [`MTLArgumentEncoder`][mtl-argument-encoder] for encoding an
+ associated argument `MTLBuffer`.
+ 1. For each bound resource buffer at binding #`B` in this descriptor set,
+ encode it to the argument buffer index #`B` with
+ `setBuffer::offset::atIndex:` and inform the `MTLComputeCommandEncoder`
+ that the dispatch will use this resource with `useResource:usage:`.
+ 1. Set the argument `MTLBuffer` to buffer index #`S`.
+1. Dispatch with `dispatchThreadgroups:threadsPerThreadgroup:`.
+
+(TODO: condense and remap bindings)
+
## Memory Management
### Storage type
@@ -239,6 +310,9 @@
[hal-buffer]: https://github.com/google/iree/blob/main/iree/hal/buffer.h
[hal-command-queue]: https://github.com/google/iree/blob/main/iree/hal/command_queue.h
[hal-command-buffer]: https://github.com/google/iree/blob/main/iree/hal/command_buffer.h
+[hal-descriptor-set]: https://github.com/google/iree/blob/main/iree/hal/descriptor_set.h
+[hal-descriptor-set-layout]: https://github.com/google/iree/blob/main/iree/hal/descriptor_set_layout.h
+[hal-executable-layout]: https://github.com/google/iree/blob/main/iree/hal/executable_layout.h
[hal-device]: https://github.com/google/iree/blob/main/iree/hal/device.h
[hal-driver]: https://github.com/google/iree/blob/main/iree/hal/driver.h
[hal-executable]: https://github.com/google/iree/blob/main/iree/hal/executable.h
@@ -251,6 +325,8 @@
[metal-kernel-library]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_kernel_library.h
[metal-shared-event]: https://github.com/google/iree/blob/main/iree/hal/metal/metal_shared_event.h
[metal-spirv-target]: https://github.com/google/iree/tree/hal-metal/iree/compiler/Dialect/HAL/Target/MetalSPIRV
+[mtl-argument-buffer]: https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
+[mtl-argument-encoder]: https://developer.apple.com/documentation/metal/mtlargumentencoder?language=objc
[mtl-buffer]: https://developer.apple.com/documentation/metal/mtlbuffer?language=objc
[mtl-command-buffer]: https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc
[mtl-command-encoder]: https://developer.apple.com/documentation/metal/mtlcommandencoder?language=objc
diff --git a/iree/hal/descriptor_set_layout.h b/iree/hal/descriptor_set_layout.h
index 3f111d8..2e80888 100644
--- a/iree/hal/descriptor_set_layout.h
+++ b/iree/hal/descriptor_set_layout.h
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include "absl/strings/str_cat.h"
#include "iree/hal/buffer.h"
#include "iree/hal/resource.h"
@@ -52,6 +53,11 @@
DescriptorType type = DescriptorType::kStorageBuffer;
// Specifies the memory access performed by the executables.
MemoryAccessBitfield access = MemoryAccess::kRead | MemoryAccess::kWrite;
+
+ std::string DebugStringShort() const {
+ return absl::StrCat("binding=", binding, ", type=", type,
+ ", access=", MemoryAccessString(access));
+ }
};
};
diff --git a/iree/hal/metal/CMakeLists.txt b/iree/hal/metal/CMakeLists.txt
index 29a9e89..29b0f67 100644
--- a/iree/hal/metal/CMakeLists.txt
+++ b/iree/hal/metal/CMakeLists.txt
@@ -24,6 +24,10 @@
SRCS
"metal_command_buffer.mm"
DEPS
+ ::metal_kernel_library
+ ::metal_pipeline_argument_buffer
+ absl::flat_hash_map
+ absl::inlined_vector
iree::base::status
iree::base::tracing
iree::hal::command_buffer
@@ -62,6 +66,7 @@
::metal_command_buffer
::metal_command_queue
::metal_direct_allocator
+ ::metal_pipeline_argument_buffer
::metal_pipeline_cache
::metal_shared_event
absl::strings
@@ -154,6 +159,21 @@
iree_cc_library(
NAME
+ metal_pipeline_argument_buffer
+ HDRS
+ "metal_pipeline_argument_buffer.h"
+ SRCS
+ "metal_pipeline_argument_buffer.cc"
+ DEPS
+ absl::inlined_vector
+ absl::span
+ iree::hal::descriptor_set_layout
+ iree::hal::executable_layout
+ PUBLIC
+)
+
+iree_cc_library(
+ NAME
metal_pipeline_cache
HDRS
"metal_pipeline_cache.h"
diff --git a/iree/hal/metal/metal_command_buffer.h b/iree/hal/metal/metal_command_buffer.h
index c58e5be..06973a2 100644
--- a/iree/hal/metal/metal_command_buffer.h
+++ b/iree/hal/metal/metal_command_buffer.h
@@ -17,6 +17,8 @@
#import <Metal/Metal.h>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/inlined_vector.h"
#include "iree/hal/command_buffer.h"
#include "iree/hal/metal/metal_buffer.h"
@@ -88,6 +90,27 @@
device_size_t workgroups_offset) override;
private:
+ // A struct containing all resources states of the current pipeline.
+ struct PipelineStateObject {
+ struct PushState {
+ absl::InlinedVector<DescriptorSet::Binding, 8> resource_bindings;
+ };
+ // Map from set number to push descriptor states
+ absl::flat_hash_map<int32_t, PushState> push_states;
+
+ struct BindState {
+ DescriptorSet* descriptor_set;
+ };
+ // Map from set number to bind descriptor states
+ absl::flat_hash_map<int32_t, BindState> bind_states;
+
+ struct ConstantState {
+ absl::InlinedVector<uint32_t, 16> values;
+ };
+ // Map from set number to push constant states
+ absl::flat_hash_map<uint32_t, ConstantState> constant_states;
+ };
+
MetalCommandBuffer(CommandBufferModeBitfield mode,
CommandCategoryBitfield command_categories,
id<MTLCommandBuffer> command_buffer);
@@ -104,11 +127,15 @@
id<MTLComputeCommandEncoder> GetOrBeginComputeEncoder();
void EndComputeEncoder();
+ private:
bool is_recording_ = false;
id<MTLCommandBuffer> metal_handle_;
id<MTLComputeCommandEncoder> current_compute_encoder_ = nil;
id<MTLBlitCommandEncoder> current_blit_encoder_ = nil;
+
+ absl::flat_hash_map<ExecutableLayout*, PipelineStateObject>
+ pipeline_state_objects_;
};
} // namespace metal
diff --git a/iree/hal/metal/metal_command_buffer.mm b/iree/hal/metal/metal_command_buffer.mm
index 8555793..5257bec 100644
--- a/iree/hal/metal/metal_command_buffer.mm
+++ b/iree/hal/metal/metal_command_buffer.mm
@@ -17,11 +17,24 @@
#include "iree/base/logging.h"
#include "iree/base/status.h"
#include "iree/base/tracing.h"
+#include "iree/hal/metal/metal_kernel_library.h"
+#include "iree/hal/metal/metal_pipeline_argument_buffer.h"
namespace iree {
namespace hal {
namespace metal {
+namespace {
+
+MTLResourceUsage ConvertResourceUsage(MemoryAccessBitfield memory_access) {
+ MTLResourceUsage usage = 0;
+ if (AllBitsSet(memory_access, MemoryAccess::kRead)) usage |= MTLResourceUsageRead;
+ if (AllBitsSet(memory_access, MemoryAccess::kWrite)) usage |= MTLResourceUsageWrite;
+ return usage;
+}
+
+} // namespace
+
// static
StatusOr<ref_ptr<CommandBuffer>> MetalCommandBuffer::Create(
CommandBufferModeBitfield mode, CommandCategoryBitfield command_categories,
@@ -110,7 +123,28 @@
absl::Span<const MemoryBarrier> memory_barriers,
absl::Span<const BufferBarrier> buffer_barriers) {
IREE_TRACE_SCOPE0("MetalCommandBuffer::ExecutionBarrier");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::ExecutionBarrier";
+
+ if (AllBitsSet(source_stage_mask, ExecutionStage::kHost) ||
+ AllBitsSet(target_stage_mask, ExecutionStage::kHost)) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::ExecutionBarrier with host bit set";
+ }
+
+ // If there is a memory barrier specified, we have to place a catch-all barrier for all buffers.
+ // Metal does not provide a more fine-grained control here. But we do have the option to specify a
+ // list of buffers to synchronize if only buffer barriers are specified.
+ if (!memory_barriers.empty()) {
+ [GetOrBeginComputeEncoder() memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ } else if (!buffer_barriers.empty()) {
+ std::vector<id<MTLResource>> buffers;
+ buffers.reserve(buffer_barriers.size());
+ for (const auto& barrier : buffer_barriers) {
+ buffers.push_back(static_cast<MetalBuffer*>(barrier.buffer)->handle());
+ }
+ [GetOrBeginComputeEncoder() memoryBarrierWithResources:buffers.data() count:buffers.size()];
+ }
+
+ return OkStatus();
}
Status MetalCommandBuffer::SignalEvent(Event* event, ExecutionStageBitfield source_stage_mask) {
@@ -215,20 +249,125 @@
Status MetalCommandBuffer::PushDescriptorSet(ExecutableLayout* executable_layout, int32_t set,
absl::Span<const DescriptorSet::Binding> bindings) {
IREE_TRACE_SCOPE0("MetalCommandBuffer::PushDescriptorSet");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::PushDescriptorSet";
+ if (set != 0) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::PushDescriptorSet with set number > 0";
+ }
+ auto& push_state = pipeline_state_objects_[executable_layout].push_states[set];
+ push_state.resource_bindings.assign(bindings.begin(), bindings.end());
+ return OkStatus();
}
Status MetalCommandBuffer::BindDescriptorSet(ExecutableLayout* executable_layout, int32_t set,
DescriptorSet* descriptor_set,
absl::Span<const device_size_t> dynamic_offsets) {
IREE_TRACE_SCOPE0("MetalCommandBuffer::BindDescriptorSet");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::BindDescriptorSet";
+ if (set != 0) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::BindDescriptorSet with set number > 0";
+ }
+ if (!dynamic_offsets.empty()) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::BindDescriptorSet with dynamic offsets";
+ }
+ pipeline_state_objects_[executable_layout].bind_states[set].descriptor_set = descriptor_set;
+ return OkStatus();
}
Status MetalCommandBuffer::Dispatch(Executable* executable, int32_t entry_point,
std::array<uint32_t, 3> workgroups) {
IREE_TRACE_SCOPE0("MetalCommandBuffer::Dispatch");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalCommandBuffer::Dispatch";
+ IREE_DVLOG(2) << "MetalCommandBuffer::Dispatch";
+
+ auto* kernel_library = static_cast<MetalKernelLibrary*>(executable);
+ IREE_ASSIGN_OR_RETURN(auto metal_kernel, kernel_library->GetKernelForEntryPoint(entry_point));
+ IREE_ASSIGN_OR_RETURN(auto metal_pso, kernel_library->GetPipelineStateForEntryPoint(entry_point));
+
+ id<MTLComputeCommandEncoder> compute_encoder = GetOrBeginComputeEncoder();
+ [compute_encoder setComputePipelineState:metal_pso];
+
+ // TODO(antiagainst): only update the PSO for the current executable.
+ for (const auto& pso_kv : pipeline_state_objects_) {
+ const auto* pipeline_layout = static_cast<MetalPipelineArgumentBufferLayout*>(pso_kv.first);
+ IREE_DVLOG(3) << "Current pipeline layout: " << pipeline_layout->DebugString();
+
+ const auto& pso = pso_kv.second;
+ if (pso.push_states.size() > 1) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::Dispatch with more than one push descriptor sets";
+ }
+ if (!pso.bind_states.empty()) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::Dispatch with bound descriptor sets";
+ }
+ if (!pso.constant_states.empty()) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::Dispatch with push constants";
+ }
+
+ IREE_DVLOG(3) << "Encoding push descriptors..";
+ for (const auto& push_kv : pso.push_states) {
+ int32_t set_number = push_kv.first;
+ const PipelineStateObject::PushState& push_state = push_kv.second;
+ IREE_DVLOG(3) << " For set #" << set_number;
+
+ id<MTLArgumentEncoder> argument_encoder =
+ [metal_kernel newArgumentEncoderWithBufferIndex:set_number]; // retained
+ if (!argument_encoder) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Buffer index #" << set_number << " is not an argument buffer";
+ }
+
+ __block id<MTLBuffer> argument_buffer =
+ [metal_handle_.device newBufferWithLength:argument_encoder.encodedLength
+ options:MTLResourceStorageModeShared]; // retained
+ if (!argument_buffer) {
+ return InternalErrorBuilder(IREE_LOC)
+ << "Failed to create argument buffer with length=" << argument_encoder.encodedLength;
+ }
+ [metal_handle_ addCompletedHandler:^(id<MTLCommandBuffer>) {
+ [argument_buffer release];
+ [argument_encoder release];
+ }];
+
+ [argument_encoder setArgumentBuffer:argument_buffer offset:0];
+
+ for (const auto& resource_binding : push_state.resource_bindings) {
+ IREE_DVLOG(3) << " Resource @[" << resource_binding.DebugStringShort() << "]";
+
+ if (resource_binding.length != kWholeBuffer &&
+ resource_binding.length != resource_binding.buffer->allocation_size()) {
+ return UnimplementedErrorBuilder(IREE_LOC)
+ << "MetalCommandBuffer::Dispatch with sub-buffer";
+ }
+
+ IREE_ASSIGN_OR_RETURN(auto buffer, CastBuffer(resource_binding.buffer));
+ [argument_encoder setBuffer:buffer->handle()
+ offset:resource_binding.offset
+ atIndex:resource_binding.binding];
+
+ const auto* set_layout = pipeline_layout->set_layouts()[set_number];
+ const auto* layout_binding = set_layout->GetBindingForIndex(resource_binding.binding);
+ if (!layout_binding) {
+ return InvalidArgumentErrorBuilder(IREE_LOC)
+ << "Cannot find binding #" << resource_binding.binding
+ << " in argument buffer layout";
+ }
+ [compute_encoder useResource:buffer->handle()
+ usage:ConvertResourceUsage(layout_binding->access)];
+ }
+
+ [compute_encoder setBuffer:argument_buffer offset:0 atIndex:set_number];
+ }
+ }
+
+ IREE_DVLOG(2) << "Dispatch workgroup count: (" << workgroups[0] << ", " << workgroups[1] << ", "
+ << workgroups[2] << "), workgroup size: (32, 1, 1)";
+ // TODO(antiagainst): fix workgroup size
+ [compute_encoder dispatchThreadgroups:MTLSizeMake(workgroups[0], workgroups[1], workgroups[2])
+ threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
+
+ return OkStatus();
}
Status MetalCommandBuffer::DispatchIndirect(Executable* executable, int32_t entry_point,
diff --git a/iree/hal/metal/metal_device.mm b/iree/hal/metal/metal_device.mm
index ec608f0..97dcba8 100644
--- a/iree/hal/metal/metal_device.mm
+++ b/iree/hal/metal/metal_device.mm
@@ -25,6 +25,7 @@
#include "iree/hal/metal/metal_command_buffer.h"
#include "iree/hal/metal/metal_command_queue.h"
#include "iree/hal/metal/metal_direct_allocator.h"
+#include "iree/hal/metal/metal_pipeline_argument_buffer.h"
#include "iree/hal/metal/metal_pipeline_cache.h"
#include "iree/hal/metal/metal_shared_event.h"
@@ -83,19 +84,20 @@
DescriptorSetLayout::UsageType usage_type,
absl::Span<const DescriptorSetLayout::Binding> bindings) {
IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSetLayout");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateDescriptorSetLayout";
+ return make_ref<MetalArgumentBufferLayout>(usage_type, bindings);
}
StatusOr<ref_ptr<ExecutableLayout>> MetalDevice::CreateExecutableLayout(
absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants) {
IREE_TRACE_SCOPE0("MetalDevice::CreateExecutableLayout");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateExecutableLayout";
+ return make_ref<MetalPipelineArgumentBufferLayout>(set_layouts, push_constants);
}
StatusOr<ref_ptr<DescriptorSet>> MetalDevice::CreateDescriptorSet(
DescriptorSetLayout* set_layout, absl::Span<const DescriptorSet::Binding> bindings) {
IREE_TRACE_SCOPE0("MetalDevice::CreateDescriptorSet");
- return UnimplementedErrorBuilder(IREE_LOC) << "MetalDevice::CreateDescriptorSet";
+ return make_ref<MetalArgumentBuffer>(static_cast<MetalArgumentBufferLayout*>(set_layout),
+ bindings);
}
StatusOr<ref_ptr<CommandBuffer>> MetalDevice::CreateCommandBuffer(
diff --git a/iree/hal/metal/metal_kernel_library.h b/iree/hal/metal/metal_kernel_library.h
index e9537b9..ecf9c23 100644
--- a/iree/hal/metal/metal_kernel_library.h
+++ b/iree/hal/metal/metal_kernel_library.h
@@ -46,6 +46,9 @@
bool supports_debugging() const override { return false; }
+ // Returns the MTLFunction for the entry point with the given |ordinal|.
+ StatusOr<id<MTLFunction>> GetKernelForEntryPoint(int ordinal) const;
+
// Returns the pipeline state object for the entry point with the given
// |ordinal|.
StatusOr<id<MTLComputePipelineState>> GetPipelineStateForEntryPoint(
diff --git a/iree/hal/metal/metal_kernel_library.mm b/iree/hal/metal/metal_kernel_library.mm
index e588dee..962e5e6 100644
--- a/iree/hal/metal/metal_kernel_library.mm
+++ b/iree/hal/metal/metal_kernel_library.mm
@@ -140,6 +140,13 @@
return pipelines_[ordinal];
}
+StatusOr<id<MTLFunction>> MetalKernelLibrary::GetKernelForEntryPoint(int ordinal) const {
+ if (ordinal < 0 || ordinal >= pipelines_.size()) {
+ return OutOfRangeErrorBuilder(IREE_LOC) << "Invalid entry point ordinal: " << ordinal;
+ }
+ return functions_[ordinal];
+}
+
} // namespace metal
} // namespace hal
} // namespace iree
diff --git a/iree/hal/metal/metal_pipeline_argument_buffer.cc b/iree/hal/metal/metal_pipeline_argument_buffer.cc
new file mode 100644
index 0000000..018beab
--- /dev/null
+++ b/iree/hal/metal/metal_pipeline_argument_buffer.cc
@@ -0,0 +1,81 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "iree/hal/metal/metal_pipeline_argument_buffer.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+MetalArgumentBufferLayout::MetalArgumentBufferLayout(
+ DescriptorSetLayout::UsageType usage_type,
+ absl::Span<const DescriptorSetLayout::Binding> bindings)
+ : usage_type_(usage_type), bindings_(bindings.begin(), bindings.end()) {}
+
+const DescriptorSetLayout::Binding*
+MetalArgumentBufferLayout::GetBindingForIndex(int index) const {
+ for (const auto& binding : bindings_) {
+ if (binding.binding == index) return &binding;
+ }
+ return nullptr;
+}
+
+std::string MetalArgumentBufferLayout::DebugString() const {
+ std::vector<std::string> binding_strings;
+ binding_strings.reserve(bindings_.size());
+ for (const auto& binding : bindings_) {
+ binding_strings.push_back(
+ absl::StrCat("[", binding.DebugStringShort(), "]"));
+ }
+ return absl::StrCat("bindings=[", absl::StrJoin(binding_strings, ", "), "]");
+}
+
+MetalPipelineArgumentBufferLayout::MetalPipelineArgumentBufferLayout(
+ absl::Span<DescriptorSetLayout* const> set_layouts, size_t push_constants)
+ : set_layouts_(set_layouts.size()), push_constants_(push_constants) {
+ for (int i = 0; i < set_layouts.size(); ++i) {
+ set_layouts_[i] = static_cast<MetalArgumentBufferLayout*>(set_layouts[i]);
+ set_layouts_[i]->AddReference();
+ }
+}
+
+MetalPipelineArgumentBufferLayout::~MetalPipelineArgumentBufferLayout() {
+ for (auto* layout : set_layouts_) layout->ReleaseReference();
+}
+
+std::string MetalPipelineArgumentBufferLayout::DebugString() const {
+ std::vector<std::string> set_strings;
+ set_strings.reserve(set_layouts_.size());
+ for (int i = 0; i < set_layouts_.size(); ++i) {
+ set_strings.push_back(
+ absl::StrCat("{set=", i, ", ", set_layouts_[i]->DebugString(), "}"));
+ }
+ return absl::StrCat("sets={", absl::StrJoin(set_strings, "; "), "}");
+}
+
+MetalArgumentBuffer::MetalArgumentBuffer(
+ MetalArgumentBufferLayout* layout,
+ absl::Span<const DescriptorSet::Binding> resources)
+ : layout_(layout), resources_(resources.begin(), resources.end()) {
+ layout_->AddReference();
+}
+
+MetalArgumentBuffer::~MetalArgumentBuffer() { layout_->ReleaseReference(); }
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
diff --git a/iree/hal/metal/metal_pipeline_argument_buffer.h b/iree/hal/metal/metal_pipeline_argument_buffer.h
new file mode 100644
index 0000000..2b349f5
--- /dev/null
+++ b/iree/hal/metal/metal_pipeline_argument_buffer.h
@@ -0,0 +1,84 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef IREE_HAL_METAL_METAL_PIPELINE_ARGUMENT_BUFFER_H_
+#define IREE_HAL_METAL_METAL_PIPELINE_ARGUMENT_BUFFER_H_
+
+#include <string>
+
+#include "absl/container/inlined_vector.h"
+#include "absl/types/span.h"
+#include "iree/hal/descriptor_set.h"
+#include "iree/hal/descriptor_set_layout.h"
+#include "iree/hal/executable_layout.h"
+
+// Metal implementaion classes for resource descriptor related interfaces.
+//
+// See docs/design_docs/metal_hal_driver.md#resource-descriptors for more
+// details.
+
+namespace iree {
+namespace hal {
+namespace metal {
+
+class MetalArgumentBufferLayout final : public DescriptorSetLayout {
+ public:
+ MetalArgumentBufferLayout(UsageType usage_type,
+ absl::Span<const Binding> bindings);
+ ~MetalArgumentBufferLayout() override = default;
+
+ absl::Span<const Binding> bindings() const { return bindings_; }
+ const Binding* GetBindingForIndex(int index) const;
+
+ std::string DebugString() const override;
+
+ private:
+ UsageType usage_type_;
+ absl::InlinedVector<Binding, 8> bindings_;
+};
+
+class MetalPipelineArgumentBufferLayout final : public ExecutableLayout {
+ public:
+ MetalPipelineArgumentBufferLayout(
+ absl::Span<DescriptorSetLayout* const> set_layouts,
+ size_t push_constants);
+ ~MetalPipelineArgumentBufferLayout() override;
+
+ absl::Span<MetalArgumentBufferLayout* const> set_layouts() const {
+ return set_layouts_;
+ }
+
+ std::string DebugString() const override;
+
+ private:
+ absl::InlinedVector<MetalArgumentBufferLayout*, 2> set_layouts_;
+ size_t push_constants_;
+};
+
+class MetalArgumentBuffer final : public DescriptorSet {
+ public:
+ MetalArgumentBuffer(MetalArgumentBufferLayout* layout,
+ absl::Span<const Binding> resources);
+ ~MetalArgumentBuffer() override;
+
+ private:
+ MetalArgumentBufferLayout* layout_;
+ absl::InlinedVector<Binding, 8> resources_;
+};
+
+} // namespace metal
+} // namespace hal
+} // namespace iree
+
+#endif // IREE_HAL_METAL_METAL_PIPELINE_ARGUMENT_BUFFER_H_