blob: 156b7d3bb7beeaed6c7ed34dbb15fc4dfc2a77d9 [file]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "experimental/metal/kernel_library.h"
#include <stddef.h>
#include "iree/base/api.h"
// flatcc schemas:
#include "iree/base/internal/flatcc/parsing.h"
#include "iree/schemas/metal_executable_def_reader.h"
#include "iree/schemas/metal_executable_def_verifier.h"
// Executable implementation that stores the kernel parameters (MTLLibrary, MTLFunction, pipeline
// state object, threadgroup size, and pipeline layout) for each entry point. The struct and its
// trailing |entry_points| array are allocated as one block; when tracing is enabled a string table
// holding copies of the entry point names follows the array in the same allocation.
typedef struct iree_hal_metal_kernel_library_t {
// Abstract resource used for injecting reference counting and vtable; must be at offset 0.
iree_hal_resource_t resource;
// Allocator that provided this allocation; used again to free it on destroy.
iree_allocator_t host_allocator;
// Number of elements in |entry_points|.
iree_host_size_t entry_point_count;
// Per-entry-point kernel parameters; flexible array member sized when the library is created.
iree_hal_metal_kernel_params_t entry_points[];
} iree_hal_metal_kernel_library_t;
// Forward declaration so that the cast helpers and the create function can reference the vtable;
// the definition follows the function implementations.
static const iree_hal_executable_vtable_t iree_hal_metal_kernel_library_vtable;
// Converts a generic executable handle into the Metal kernel library type. The handle is first
// checked against this implementation's vtable via IREE_HAL_ASSERT_TYPE.
static iree_hal_metal_kernel_library_t* iree_hal_metal_kernel_library_cast(
    iree_hal_executable_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_kernel_library_vtable);
  iree_hal_metal_kernel_library_t* kernel_library = (iree_hal_metal_kernel_library_t*)base_value;
  return kernel_library;
}
// Const-preserving variant of the executable downcast: checks the handle against this
// implementation's vtable via IREE_HAL_ASSERT_TYPE and reinterprets it as a kernel library.
static const iree_hal_metal_kernel_library_t* iree_hal_metal_kernel_library_const_cast(
    const iree_hal_executable_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_metal_kernel_library_vtable);
  const iree_hal_metal_kernel_library_t* kernel_library =
      (const iree_hal_metal_kernel_library_t*)base_value;
  return kernel_library;
}
// Validates the Metal executable flatbuffer up front so that later code can walk it without
// re-checking its structure.
//
// Checks performed, in order:
//   1. the buffer is present and at least 16 bytes long;
//   2. the flatcc-generated verifier accepts it (pointers in-bounds, file safe to walk);
//   3. every entry point has a non-empty name;
//   4. threadgroup sizes are present and there is exactly one per entry point;
//   5. shader libraries, when present, are non-empty and there is one per entry point;
//   6. shader sources, when present, are non-empty and there is one per entry point;
//   7. at least one of shader libraries or shader sources is provided.
//
// Returns IREE_STATUS_INVALID_ARGUMENT describing the first violated condition, or OK.
static iree_status_t iree_hal_metal_kernel_library_flatbuffer_verify(
    iree_const_byte_span_t flatbuffer_data) {
  if (flatbuffer_data.data == NULL || flatbuffer_data.data_length < 16) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "flatbuffer data is not present or less than 16 bytes (%zu total)",
                            flatbuffer_data.data_length);
  }

  // Structural verification generated by flatcc: guarantees the file can be walked safely but says
  // nothing about whether the contents make sense for us.
  int flatcc_result = iree_hal_metal_ExecutableDef_verify_as_root(flatbuffer_data.data,
                                                                  flatbuffer_data.data_length);
  if (flatcc_result != flatcc_verify_ok) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "flatbuffer verification failed: %s",
                            flatcc_verify_error_string(flatcc_result));
  }

  iree_hal_metal_ExecutableDef_table_t def =
      iree_hal_metal_ExecutableDef_as_root(flatbuffer_data.data);

  // Entry points: each one needs a name.
  flatbuffers_string_vec_t names = iree_hal_metal_ExecutableDef_entry_points_get(def);
  size_t name_count = flatbuffers_string_vec_len(names);
  for (size_t ordinal = 0; ordinal < name_count; ++ordinal) {
    flatbuffers_string_t name = flatbuffers_string_vec_at(names, ordinal);
    if (flatbuffers_string_len(name) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "executable entry point %zu has no name", ordinal);
    }
  }

  // Threadgroup sizes: required, one per entry point.
  iree_hal_metal_ThreadgroupSize_vec_t sizes = iree_hal_metal_ExecutableDef_threadgroup_sizes(def);
  size_t size_count = iree_hal_metal_ThreadgroupSize_vec_len(sizes);
  if (size_count == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "no threadgroup sizes present");
  }
  if (name_count != size_count) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "entry points (%zu) and thread group sizes (%zu) count mismatch",
                            name_count, size_count);
  }

  // Precompiled shader libraries: optional, but if given must pair up with the entry points.
  flatbuffers_string_vec_t libraries = iree_hal_metal_ExecutableDef_shader_libraries_get(def);
  size_t library_count = flatbuffers_string_vec_len(libraries);
  for (size_t ordinal = 0; ordinal < library_count; ++ordinal) {
    flatbuffers_string_t library = flatbuffers_string_vec_at(libraries, ordinal);
    if (flatbuffers_string_len(library) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "executable shader library %zu is empty", ordinal);
    }
  }
  if (library_count != 0 && name_count != library_count) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "entry points (%zu) and source libraries (%zu) count mismatch",
                            name_count, library_count);
  }

  // MSL source strings: optional, but if given must pair up with the entry points.
  flatbuffers_string_vec_t sources = iree_hal_metal_ExecutableDef_shader_sources_get(def);
  size_t source_count = flatbuffers_string_vec_len(sources);
  for (size_t ordinal = 0; ordinal < source_count; ++ordinal) {
    flatbuffers_string_t source = flatbuffers_string_vec_at(sources, ordinal);
    if (flatbuffers_string_len(source) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "executable shader source %zu is empty",
                              ordinal);
    }
  }
  if (source_count != 0 && name_count != source_count) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "entry points (%zu) and source strings (%zu) count mismatch",
                            name_count, source_count);
  }

  // At least one way of obtaining the kernels must be available.
  if (library_count == 0 && source_count == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "missing shader library or source strings");
  }

  return iree_ok_status();
}
// Returns an invalid argument status with proper Metal NSError annotations during compute pipeline
// creation.
//
// |iree_error_template| is passed to iree_make_status as the status message.
// |metal_error_template| is used as a printf-style format that receives exactly one `const char*`
// argument (the NSError description), so it must contain a single `%s` conversion.
// |ns_error| may be nil for Metal APIs that report failure without an NSError.
// |entry_point| names the kernel being created and is appended to the annotation.
// |shader_source| is optional: when non-NULL it must be NUL-terminated and is included in the
// annotation; pass NULL when the kernel comes from a precompiled MTLLibrary.
static iree_status_t iree_hal_metal_get_invalid_kernel_status(const char* iree_error_template,
                                                              const char* metal_error_template,
                                                              NSError* ns_error,
                                                              iree_string_view_t entry_point,
                                                              const char* shader_source) {
  iree_status_t status = iree_make_status(IREE_STATUS_INVALID_ARGUMENT, iree_error_template);

  // -cStringUsingEncoding: returns NULL when the text cannot be losslessly converted (Metal
  // compiler diagnostics can contain non-ASCII characters), which would hand NULL to `%s`.
  // -UTF8String does not have that failure mode for a valid string.
  const char* ns_c_error = [ns_error.localizedDescription UTF8String];  // autoreleased
  if (!ns_c_error) {
    // Messaging nil yields NULL: |ns_error| (or its description) was not provided.
    ns_c_error = "<no NSError description available>";
  }
  status = iree_status_annotate_f(status, metal_error_template, ns_c_error);

  if (shader_source) {
    return iree_status_annotate_f(status, "for entry point '%.*s' in MSL source:\n%s\n",
                                  (int)entry_point.size, entry_point.data, shader_source);
  }
  return iree_status_annotate_f(status, "for entry point '%.*s' in MTLLibrary\n",
                                (int)entry_point.size, entry_point.data);
}
// Compiles the MSL |source_code| into a MTLLibrary using |compile_options| and writes it to
// |out_library|. |entry_point| is only used to annotate the error status on failure.
// The caller should release |out_library| after done.
//
// Returns IREE_STATUS_INVALID_ARGUMENT (annotated with the Metal compiler diagnostics and the
// source text) when the library cannot be created; |out_library| is nil in that case.
iree_status_t iree_hal_metal_compile_msl(iree_string_view_t source_code,
                                         iree_string_view_t entry_point, id<MTLDevice> device,
                                         MTLCompileOptions* compile_options,
                                         id<MTLLibrary>* out_library) {
  @autoreleasepool {
    NSError* error = nil;
    NSString* shader_source =
        [[[NSString alloc] initWithBytes:source_code.data
                                  length:source_code.size
                                encoding:[NSString defaultCStringEncoding]] autorelease];
    *out_library = [device newLibraryWithSource:shader_source
                                        options:compile_options
                                          error:&error];  // +1
    if (IREE_UNLIKELY(*out_library == nil)) {
      // The status helper supplies a single `const char*` to this template, so the conversion must
      // be a plain `%s`; `%.*s` would consume a non-existent precision argument.
      //
      // |source_code| is a string view and is not guaranteed to be NUL-terminated, while the helper
      // prints the source with `%s`. Use the NUL-terminated bytes backing |shader_source| instead;
      // they remain valid until the enclosing autorelease pool drains, which happens after the
      // status has been built.
      return iree_hal_metal_get_invalid_kernel_status(
          "failed to create MTLLibrary from shader source",
          "when creating MTLLibrary with NSError: %s", error, entry_point,
          [shader_source UTF8String]);
    }
  }
  return iree_ok_status();
}
// Loads the precompiled Metal library bytes in |source_data| into a MTLLibrary and writes it to
// |out_library|. |entry_point| is only used to annotate the error status on failure.
// The caller should release |out_library| after done.
//
// Returns IREE_STATUS_INVALID_ARGUMENT annotated with the NSError description when the library
// cannot be created; |out_library| is nil in that case.
static iree_status_t iree_hal_metal_load_mtllib(iree_const_byte_span_t source_data,
                                                iree_string_view_t entry_point,
                                                id<MTLDevice> device, id<MTLLibrary>* out_library) {
  @autoreleasepool {
    NSError* error = nil;
    // DISPATCH_DATA_DESTRUCTOR_DEFAULT makes dispatch copy the bytes, so |source_data| does not
    // need to outlive this call.
    dispatch_data_t data = dispatch_data_create(source_data.data, source_data.data_length,
                                                /*queue=*/NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
    *out_library = [device newLibraryWithData:data error:&error];  // +1
    // dispatch_data_create returns an owned object and this file uses manual reference counting,
    // so drop our reference here; otherwise every loaded library leaks its data object.
    dispatch_release(data);  // -1
    if (IREE_UNLIKELY(*out_library == nil)) {
      return iree_hal_metal_get_invalid_kernel_status(
          "failed to create MTLLibrary from shader source",
          "when creating MTLLibrary with NSError: %s", error, entry_point, NULL);
    }
  }
  return iree_ok_status();
}
// Creates MTL compute pipeline objects for the given |entry_point| in |library| and writes to
// |out_function| and |out_pso|. The caller should release |out_function| and |out_pso| after done.
//
// |source_code| is optional and only used for diagnostics: pass the NUL-terminated MSL source when
// |library| was compiled from source, or NULL when it was loaded from a precompiled library.
//
// On failure returns IREE_STATUS_INVALID_ARGUMENT and leaves both |out_function| and |out_pso| set
// to nil so that callers never observe a released or stale object.
static iree_status_t iree_hal_metal_create_pipline_object(
    id<MTLLibrary> library, iree_string_view_t entry_point, const char* source_code,
    id<MTLDevice> device, id<MTLFunction>* out_function, id<MTLComputePipelineState>* out_pso) {
  // Give the outputs a defined value on every path.
  *out_function = nil;
  *out_pso = nil;

  @autoreleasepool {
    NSError* error = nil;
    NSString* function_name =
        [[[NSString alloc] initWithBytes:entry_point.data
                                  length:entry_point.size
                                encoding:[NSString defaultCStringEncoding]] autorelease];
    *out_function = [library newFunctionWithName:function_name];  // +1
    if (IREE_UNLIKELY(*out_function == nil)) {
      // -newFunctionWithName: has no NSError out-parameter, so |error| is still nil here.
      return iree_hal_metal_get_invalid_kernel_status("cannot find entry point in shader source",
                                                      "when creating MTLFunction with NSError: %s",
                                                      error, entry_point, source_code);
    }

    // TODO(#14047): Enable async pipeline creation at runtime.
    *out_pso = [device newComputePipelineStateWithFunction:*out_function error:&error];  // +1
    if (IREE_UNLIKELY(*out_pso == nil)) {
      [*out_function release];  // -1
      // Clear the output so the caller is not handed a pointer to the released function.
      *out_function = nil;
      return iree_hal_metal_get_invalid_kernel_status(
          "invalid shader source", "when creating MTLComputePipelineState with NSError: %s", error,
          entry_point, source_code);
    }
  }
  return iree_ok_status();
}
// Compiles the MSL |source_code| into a MTLLibrary and then creates the MTLFunction and compute
// pipeline state object for |entry_point| in it.
//
// On success |out_library|, |out_function|, and |out_pso| each hold a +1 reference that the caller
// should release after done. On failure nothing is handed to the caller: any object created along
// the way is released and all three outputs that were touched are reset to nil.
//
// NOTE(review): |source_code.data| is forwarded as a C string for diagnostics, so callers are
// presumably expected to pass NUL-terminated source — confirm against call sites.
iree_status_t iree_hal_metal_compile_msl_and_create_pipeline_object(
    iree_string_view_t source_code, iree_string_view_t entry_point, id<MTLDevice> device,
    MTLCompileOptions* compile_options, id<MTLLibrary>* out_library, id<MTLFunction>* out_function,
    id<MTLComputePipelineState>* out_pso) {
  IREE_RETURN_IF_ERROR(
      iree_hal_metal_compile_msl(source_code, entry_point, device, compile_options, out_library));

  iree_status_t status = iree_hal_metal_create_pipline_object(
      *out_library, entry_point, source_code.data, device, out_function, out_pso);
  if (!iree_status_is_ok(status)) {
    // The library was created successfully above; without this release it would leak because the
    // caller only sees the failure status.
    [*out_library release];  // -1
    *out_library = nil;
    // The pipeline helper already released anything it created; make sure no stale pointers
    // escape through the outputs.
    *out_function = nil;
    *out_pso = nil;
  }
  return status;
}
// Creates a Metal kernel library executable from the flatbuffer in
// |executable_params->executable_data|.
//
// The flatbuffer is verified first. For each entry point, the kernel is then either loaded from
// the precompiled shader library (preferred when present) or compiled from MSL source, and a
// compute pipeline state object is built for it. The resulting Metal objects, threadgroup size,
// and retained pipeline layout are stored per entry point.
//
// On success |out_executable| receives the new executable. On failure |out_executable| stays NULL,
// any partially constructed executable is destroyed, and the failing status is returned.
iree_status_t iree_hal_metal_kernel_library_create(
    id<MTLDevice> device, const iree_hal_executable_params_t* executable_params,
    iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) {
  IREE_ASSERT_ARGUMENT(executable_params);
  IREE_ASSERT_ARGUMENT(out_executable);
  *out_executable = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_kernel_library_t* executable = NULL;

  // The trace zone is already open, so it must be closed on this early-out path as well.
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_metal_kernel_library_flatbuffer_verify(executable_params->executable_data));

  iree_hal_metal_ExecutableDef_table_t executable_def =
      iree_hal_metal_ExecutableDef_as_root(executable_params->executable_data.data);

  flatbuffers_string_vec_t entry_points_vec =
      iree_hal_metal_ExecutableDef_entry_points_get(executable_def);
  iree_hal_metal_ThreadgroupSize_vec_t threadgroup_sizes_vec =
      iree_hal_metal_ExecutableDef_threadgroup_sizes(executable_def);
  flatbuffers_string_vec_t shader_libraries_vec =
      iree_hal_metal_ExecutableDef_shader_libraries_get(executable_def);
  flatbuffers_string_vec_t shader_sources_vec =
      iree_hal_metal_ExecutableDef_shader_sources_get(executable_def);
  iree_host_size_t entry_point_count = flatbuffers_string_vec_len(entry_points_vec);

  // Calculate the total number of characters across all entry point names. This is only required
  // when tracing so that we can store copies of the names as the flatbuffer storing the strings
  // may be released while the executable is still live.
  iree_host_size_t total_entry_point_name_chars = 0;
  IREE_TRACE({
    for (iree_host_size_t i = 0; i < entry_point_count; i++) {
      const char* entry_name = flatbuffers_string_vec_at(entry_points_vec, i);
      total_entry_point_name_chars += flatbuffers_string_len(entry_name);
    }
  });

  // Create the kernel library: header + per-entry-point params + (tracing only) name storage.
  // NOTE(review): the failure path below destroys the executable and walks every entry, including
  // ones not populated yet; this relies on iree_allocator_malloc returning zeroed memory — confirm.
  iree_host_size_t total_size = sizeof(*executable) +
                                entry_point_count * sizeof(executable->entry_points[0]) +
                                total_entry_point_name_chars;
  iree_status_t status = iree_allocator_malloc(host_allocator, total_size, (void**)&executable);

  if (iree_status_is_ok(status)) {
    // The name string table lives right after the params array. Derive the pointer only once the
    // allocation is known to have succeeded.
    IREE_TRACE(char* string_table_buffer =
                   (char*)((char*)executable + sizeof(*executable) +
                           entry_point_count * sizeof(executable->entry_points[0])));

    iree_hal_resource_initialize(&iree_hal_metal_kernel_library_vtable, &executable->resource);
    executable->host_allocator = host_allocator;
    executable->entry_point_count = entry_point_count;

    size_t shader_library_count = flatbuffers_string_vec_len(shader_libraries_vec);
    size_t shader_source_count = flatbuffers_string_vec_len(shader_sources_vec);

    // Try to load as Metal library first. Otherwise, compile each MSL source string into a
    // MTLLibrary and get the MTLFunction for the entry point to build the pipeline state object.
    // TODO(#14047): Enable async MSL compilation at runtime.
    MTLCompileOptions* compile_options = [MTLCompileOptions new];  // +1
    compile_options.languageVersion = MTLLanguageVersion3_0;

    // Verification guarantees that each vector that is present has one element per entry point,
    // so the larger of the two counts equals the entry point count.
    for (size_t i = 0, e = iree_max(shader_library_count, shader_source_count); i < e; ++i) {
      id<MTLLibrary> library = nil;
      id<MTLFunction> function = nil;
      id<MTLComputePipelineState> pso = nil;

      // Only set on the compile-from-source path; stays NULL for precompiled libraries so that
      // diagnostics do not try to print source text.
      flatbuffers_string_t source_code = NULL;
      flatbuffers_string_t entry_point = flatbuffers_string_vec_at(entry_points_vec, i);
      iree_string_view_t entry_point_view =
          iree_make_string_view(entry_point, flatbuffers_string_len(entry_point));

      if (shader_library_count != 0) {
        flatbuffers_string_t source_library = flatbuffers_string_vec_at(shader_libraries_vec, i);
        status = iree_hal_metal_load_mtllib(
            iree_make_const_byte_span(source_library, flatbuffers_string_len(source_library)),
            entry_point_view, device, &library);
      } else {
        source_code = flatbuffers_string_vec_at(shader_sources_vec, i);
        status = iree_hal_metal_compile_msl(
            iree_make_string_view(source_code, flatbuffers_string_len(source_code)),
            entry_point_view, device, compile_options, &library);
      }

      if (iree_status_is_ok(status)) {
        status = iree_hal_metal_create_pipline_object(library, entry_point_view, source_code,
                                                      device, &function, &pso);
      }
      if (!iree_status_is_ok(status)) {
        // The library has not been stored in the executable yet, so the destroy path cannot see
        // it; release it here to avoid leaking it. Messaging nil (library creation itself failed)
        // is a no-op.
        [library release];  // -1
        break;
      }

      // Package required parameters for kernel launches for each entry point.
      iree_hal_metal_kernel_params_t* params = &executable->entry_points[i];
      params->library = library;
      params->function = function;
      params->pso = pso;
      params->threadgroup_size[0] = threadgroup_sizes_vec[i].x;
      params->threadgroup_size[1] = threadgroup_sizes_vec[i].y;
      params->threadgroup_size[2] = threadgroup_sizes_vec[i].z;
      // NOTE(review): |pipeline_layouts| is indexed by entry point ordinal without checking how
      // many layouts the caller supplied — confirm callers always provide one per entry point.
      params->layout = executable_params->pipeline_layouts[i];
      iree_hal_pipeline_layout_retain(params->layout);

      // Stash the entry point name in the string table for use when tracing.
      IREE_TRACE({
        iree_host_size_t entry_name_length = flatbuffers_string_len(entry_point);
        memcpy(string_table_buffer, entry_point, entry_name_length);
        params->function_name = iree_make_string_view(string_table_buffer, entry_name_length);
        string_table_buffer += entry_name_length;
      });
    }

    [compile_options release];  // -1
  }

  if (iree_status_is_ok(status)) {
    *out_executable = (iree_hal_executable_t*)executable;
  } else {
    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// Tears down the kernel library: for every entry point, releases the pipeline state object, the
// function, the library, and the retained pipeline layout (in that order), then returns the
// executable's storage to the allocator it was created with.
static void iree_hal_metal_kernel_library_destroy(iree_hal_executable_t* base_executable) {
  iree_hal_metal_kernel_library_t* kernel_library =
      iree_hal_metal_kernel_library_cast(base_executable);
  IREE_TRACE_ZONE_BEGIN(z0);

  iree_hal_metal_kernel_params_t* params = kernel_library->entry_points;
  iree_hal_metal_kernel_params_t* const params_end = params + kernel_library->entry_point_count;
  for (; params != params_end; ++params) {
    [params->pso release];       // -1
    [params->function release];  // -1
    [params->library release];   // -1
    iree_hal_pipeline_layout_release(params->layout);
  }

  // Copy the allocator out before the storage that holds it is freed.
  iree_allocator_t host_allocator = kernel_library->host_allocator;
  iree_allocator_free(host_allocator, kernel_library);

  IREE_TRACE_ZONE_END(z0);
}
// Copies the kernel parameters of the entry point with ordinal |entry_point| into |out_params|.
//
// Returns IREE_STATUS_OUT_OF_RANGE when the ordinal does not refer to an entry point in this
// executable; |out_params| is left untouched in that case. The copy does not retain any of the
// objects referenced by the parameters.
iree_status_t iree_hal_metal_kernel_library_entry_point_kernel_params(
    const iree_hal_executable_t* base_executable, int32_t entry_point,
    iree_hal_metal_kernel_params_t* out_params) {
  const iree_hal_metal_kernel_library_t* kernel_library =
      iree_hal_metal_kernel_library_const_cast(base_executable);

  if (!(entry_point >= kernel_library->entry_point_count)) {
    const iree_hal_metal_kernel_params_t* stored_params =
        &kernel_library->entry_points[entry_point];
    memcpy(out_params, stored_params, sizeof(*out_params));
    return iree_ok_status();
  }

  return iree_make_status(IREE_STATUS_OUT_OF_RANGE, "invalid entry point ordinal %d",
                          entry_point);
}
// Executable vtable for the Metal kernel library. Only |destroy| is populated here; it routes
// destruction of the generic executable handle to this implementation.
static const iree_hal_executable_vtable_t iree_hal_metal_kernel_library_vtable = {
.destroy = iree_hal_metal_kernel_library_destroy,
};