[metal] Implement descriptor set and pipeline layout APIs
We don't have direct mappings of these concepts in Metal; so just
capture the content as normal C structs for later usage.
diff --git a/experimental/metal/CMakeLists.txt b/experimental/metal/CMakeLists.txt
index ec60332..6dc91af 100644
--- a/experimental/metal/CMakeLists.txt
+++ b/experimental/metal/CMakeLists.txt
@@ -25,6 +25,8 @@
"metal_driver.m"
"metal_shared_event.h"
"metal_shared_event.m"
+ "pipeline_layout.h"
+ "pipeline_layout.m"
DEPS
iree::base
iree::base::core_headers
diff --git a/experimental/metal/cts/CMakeLists.txt b/experimental/metal/cts/CMakeLists.txt
index 6f06e17..1c50975 100644
--- a/experimental/metal/cts/CMakeLists.txt
+++ b/experimental/metal/cts/CMakeLists.txt
@@ -20,7 +20,9 @@
INCLUDED_TESTS
"allocator"
"buffer_mapping"
+ "descriptor_set_layout"
"driver"
+ "pipeline_layout"
"semaphore"
)
diff --git a/experimental/metal/metal_device.m b/experimental/metal/metal_device.m
index 8099dbf..d2db5c1 100644
--- a/experimental/metal/metal_device.m
+++ b/experimental/metal/metal_device.m
@@ -8,6 +8,7 @@
#include "experimental/metal/direct_allocator.h"
#include "experimental/metal/metal_shared_event.h"
+#include "experimental/metal/pipeline_layout.h"
#include "iree/base/api.h"
#include "iree/base/tracing.h"
#include "iree/hal/api.h"
@@ -153,7 +154,9 @@
iree_hal_device_t* base_device, iree_hal_descriptor_set_layout_flags_t flags,
iree_host_size_t binding_count, const iree_hal_descriptor_set_layout_binding_t* bindings,
iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented descriptor set create");
+ iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
+ return iree_hal_metal_descriptor_set_layout_create(device->host_allocator, flags, binding_count,
+ bindings, out_descriptor_set_layout);
}
static iree_status_t iree_hal_metal_device_create_event(iree_hal_device_t* base_device,
@@ -171,7 +174,9 @@
iree_hal_device_t* base_device, iree_host_size_t push_constants,
iree_host_size_t set_layout_count, iree_hal_descriptor_set_layout_t* const* set_layouts,
iree_hal_pipeline_layout_t** out_pipeline_layout) {
- return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "unimplmented pipeline layout create");
+ iree_hal_metal_device_t* device = iree_hal_metal_device_cast(base_device);
+ return iree_hal_metal_pipeline_layout_create(device->host_allocator, set_layout_count,
+ set_layouts, push_constants, out_pipeline_layout);
}
static iree_status_t iree_hal_metal_device_create_semaphore(iree_hal_device_t* base_device,
diff --git a/experimental/metal/pipeline_layout.h b/experimental/metal/pipeline_layout.h
new file mode 100644
index 0000000..5fc96d0
--- /dev/null
+++ b/experimental/metal/pipeline_layout.h
@@ -0,0 +1,64 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_EXPERIMENTAL_METAL_PIPELINE_LAYOUT_H_
+#define IREE_EXPERIMENTAL_METAL_PIPELINE_LAYOUT_H_
+
+#include "iree/base/api.h"
+#include "iree/hal/api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+//===----------------------------------------------------------------------===//
+// iree_hal_metal_descriptor_set_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates a descriptor set layout for the given |bindings|.
+//
+// |out_descriptor_set_layout| must be released by the caller (see
+// iree_hal_descriptor_set_layout_release).
+iree_status_t iree_hal_metal_descriptor_set_layout_create(
+ iree_allocator_t host_allocator,
+ iree_hal_descriptor_set_layout_flags_t flags,
+ iree_host_size_t binding_count,
+ const iree_hal_descriptor_set_layout_binding_t* bindings,
+ iree_hal_descriptor_set_layout_t** out_descriptor_set_layout);
+
+// Returns the information about a given |binding| in //
+// |base_descriptor_set_layout|.
+iree_hal_descriptor_set_layout_binding_t*
+iree_hal_metal_descriptor_set_layout_binding(
+ iree_hal_descriptor_set_layout_t* base_descriptor_set_layout,
+ uint32_t binding);
+
+//===----------------------------------------------------------------------===//
+// iree_hal_metal_pipeline_layout_t
+//===----------------------------------------------------------------------===//
+
+// Creates a pipeline layout with the given |set_layouts| and
+// |push_constant_count|.
+//
+// |out_pipeline_layout| must be released by the caller (see
+// iree_hal_pipeline_layout_release).
+iree_status_t iree_hal_metal_pipeline_layout_create(
+ iree_allocator_t host_allocator, iree_host_size_t set_layout_count,
+ iree_hal_descriptor_set_layout_t* const* set_layouts,
+ iree_host_size_t push_constant_count,
+ iree_hal_pipeline_layout_t** out_pipeline_layout);
+
+// Returns the descriptor set layout of the given |set| in
+// |base_pipeline_layout|.
+iree_hal_descriptor_set_layout_t*
+iree_hal_metal_pipeline_layout_descriptor_set_layout(
+ iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // IREE_EXPERIMENTAL_METAL_PIPELINE_LAYOUT_H_
diff --git a/experimental/metal/pipeline_layout.m b/experimental/metal/pipeline_layout.m
new file mode 100644
index 0000000..d778182
--- /dev/null
+++ b/experimental/metal/pipeline_layout.m
@@ -0,0 +1,167 @@
+// Copyright 2023 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "experimental/metal/pipeline_layout.h"
+
+#include <stddef.h>
+
+#include "iree/base/api.h"
+#include "iree/base/tracing.h"
+
+//===------------------------------------------------------------------------------------------===//
+// iree_hal_metal_descriptor_set_layout_t
+//===------------------------------------------------------------------------------------------===//
+
+typedef struct iree_hal_metal_descriptor_set_layout_t {
+ // Abstract resource used for injecting reference counting and vtable; must be at offset 0.
+ iree_hal_resource_t resource;
+
+ iree_allocator_t host_allocator;
+
+ iree_host_size_t binding_count;
+ iree_hal_descriptor_set_layout_binding_t bindings[];
+} iree_hal_metal_descriptor_set_layout_t;
+
+static const iree_hal_descriptor_set_layout_vtable_t iree_hal_metal_descriptor_set_layout_vtable;
+
+static iree_hal_metal_descriptor_set_layout_t* iree_hal_metal_descriptor_set_layout_cast(
+ iree_hal_descriptor_set_layout_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_descriptor_set_layout_vtable);
+ return (iree_hal_metal_descriptor_set_layout_t*)base_value;
+}
+
+iree_status_t iree_hal_metal_descriptor_set_layout_create(
+ iree_allocator_t host_allocator, iree_hal_descriptor_set_layout_flags_t flags,
+ iree_host_size_t binding_count, const iree_hal_descriptor_set_layout_binding_t* bindings,
+ iree_hal_descriptor_set_layout_t** out_descriptor_set_layout) {
+ IREE_ASSERT_ARGUMENT(!binding_count || bindings);
+ IREE_ASSERT_ARGUMENT(out_descriptor_set_layout);
+ *out_descriptor_set_layout = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_metal_descriptor_set_layout_t* descriptor_set_layout = NULL;
+ iree_host_size_t bindings_size = binding_count * sizeof(descriptor_set_layout->bindings[0]);
+ iree_status_t status =
+ iree_allocator_malloc(host_allocator, sizeof(*descriptor_set_layout) + bindings_size,
+ (void**)&descriptor_set_layout);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_metal_descriptor_set_layout_vtable,
+ &descriptor_set_layout->resource);
+ descriptor_set_layout->host_allocator = host_allocator;
+ descriptor_set_layout->binding_count = binding_count;
+ memcpy(descriptor_set_layout->bindings, bindings, bindings_size);
+ *out_descriptor_set_layout = (iree_hal_descriptor_set_layout_t*)descriptor_set_layout;
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_metal_descriptor_set_layout_destroy(
+ iree_hal_descriptor_set_layout_t* base_descriptor_set_layout) {
+ iree_hal_metal_descriptor_set_layout_t* descriptor_set_layout =
+ iree_hal_metal_descriptor_set_layout_cast(base_descriptor_set_layout);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_allocator_free(descriptor_set_layout->host_allocator, descriptor_set_layout);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+iree_hal_descriptor_set_layout_binding_t* iree_hal_metal_descriptor_set_layout_binding(
+ iree_hal_descriptor_set_layout_t* base_descriptor_set_layout, uint32_t binding) {
+ iree_hal_metal_descriptor_set_layout_t* descriptor_set_layout =
+ iree_hal_metal_descriptor_set_layout_cast(base_descriptor_set_layout);
+ for (iree_host_size_t i = 0; i < descriptor_set_layout->binding_count; ++i) {
+ if (descriptor_set_layout->bindings[i].binding == binding) {
+ return &descriptor_set_layout->bindings[i];
+ }
+ }
+ return NULL;
+}
+
+static const iree_hal_descriptor_set_layout_vtable_t iree_hal_metal_descriptor_set_layout_vtable = {
+ .destroy = iree_hal_metal_descriptor_set_layout_destroy,
+};
+
+//===------------------------------------------------------------------------------------------===//
+// iree_hal_metal_pipeline_layout_t
+//===------------------------------------------------------------------------------------------===//
+
+typedef struct iree_hal_metal_pipeline_layout_t {
+ // Abstract resource used for injecting reference counting and vtable; must be at offset 0.
+ iree_hal_resource_t resource;
+
+ iree_allocator_t host_allocator;
+
+ iree_host_size_t push_constant_count;
+
+ iree_host_size_t set_layout_count;
+ iree_hal_descriptor_set_layout_t* set_layouts[];
+} iree_hal_metal_pipeline_layout_t;
+
+static const iree_hal_pipeline_layout_vtable_t iree_hal_metal_pipeline_layout_vtable;
+
+static iree_hal_metal_pipeline_layout_t* iree_hal_metal_pipeline_layout_cast(
+ iree_hal_pipeline_layout_t* base_value) {
+ IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_pipeline_layout_vtable);
+ return (iree_hal_metal_pipeline_layout_t*)base_value;
+}
+
+iree_status_t iree_hal_metal_pipeline_layout_create(
+ iree_allocator_t host_allocator, iree_host_size_t set_layout_count,
+ iree_hal_descriptor_set_layout_t* const* set_layouts, iree_host_size_t push_constant_count,
+ iree_hal_pipeline_layout_t** out_pipeline_layout) {
+ IREE_ASSERT_ARGUMENT(!set_layout_count || set_layouts);
+ IREE_ASSERT_ARGUMENT(out_pipeline_layout);
+ *out_pipeline_layout = NULL;
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ iree_hal_metal_pipeline_layout_t* pipeline_layout = NULL;
+ iree_host_size_t total_size =
+ sizeof(*pipeline_layout) + set_layout_count * sizeof(pipeline_layout->set_layouts[0]);
+ iree_status_t status =
+ iree_allocator_malloc(host_allocator, total_size, (void**)&pipeline_layout);
+ if (iree_status_is_ok(status)) {
+ iree_hal_resource_initialize(&iree_hal_metal_pipeline_layout_vtable,
+ &pipeline_layout->resource);
+ pipeline_layout->host_allocator = host_allocator;
+ pipeline_layout->push_constant_count = push_constant_count;
+ pipeline_layout->set_layout_count = set_layout_count;
+ for (iree_host_size_t i = 0; i < set_layout_count; ++i) {
+ pipeline_layout->set_layouts[i] = set_layouts[i];
+ iree_hal_descriptor_set_layout_retain(set_layouts[i]);
+ }
+ *out_pipeline_layout = (iree_hal_pipeline_layout_t*)pipeline_layout;
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void iree_hal_metal_pipeline_layout_destroy(
+ iree_hal_pipeline_layout_t* base_pipeline_layout) {
+ iree_hal_metal_pipeline_layout_t* pipeline_layout =
+ iree_hal_metal_pipeline_layout_cast(base_pipeline_layout);
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ for (iree_host_size_t i = 0; i < pipeline_layout->set_layout_count; ++i) {
+ iree_hal_descriptor_set_layout_release(pipeline_layout->set_layouts[i]);
+ }
+ iree_allocator_free(pipeline_layout->host_allocator, pipeline_layout);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+iree_hal_descriptor_set_layout_t* iree_hal_metal_pipeline_layout_descriptor_set_layout(
+ iree_hal_pipeline_layout_t* base_pipeline_layout, uint32_t set) {
+ iree_hal_metal_pipeline_layout_t* pipeline_layout =
+ iree_hal_metal_pipeline_layout_cast(base_pipeline_layout);
+ if (set < pipeline_layout->set_layout_count) return pipeline_layout->set_layouts[set];
+ return NULL;
+}
+
+static const iree_hal_pipeline_layout_vtable_t iree_hal_metal_pipeline_layout_vtable = {
+ .destroy = iree_hal_metal_pipeline_layout_destroy,
+};