blob: fb91041ba886a2add133b08af73c7306d132a88c [file] [log] [blame]
// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/hal/drivers/vulkan/native_executable.h"
#include <cstddef>
#include <cstdint>
#include <cstring>
#include "iree/base/api.h"
#include "iree/hal/drivers/vulkan/dynamic_symbol_tables.h"
#include "iree/hal/drivers/vulkan/dynamic_symbols.h"
#include "iree/hal/drivers/vulkan/handle_util.h"
#include "iree/hal/drivers/vulkan/native_pipeline_layout.h"
#include "iree/hal/drivers/vulkan/status_util.h"
#include "iree/hal/drivers/vulkan/util/ref_ptr.h"
// flatcc schemas:
#include "iree/base/internal/flatcc/parsing.h"
#include "iree/schemas/spirv_executable_def_reader.h"
#include "iree/schemas/spirv_executable_def_verifier.h"
using namespace iree::hal::vulkan;
// Per-entry-point state stored inline at the tail of a native executable.
typedef struct iree_hal_vulkan_entry_point_t {
// Compute pipeline for this entry point; VK_NULL_HANDLE until created.
VkPipeline pipeline;
// Entry point name. This is a non-owning view into the executable FlatBuffer
// (it is built directly from the FlatBuffer string, no copy is made).
// NOTE(review): presumably the FlatBuffer data must outlive the executable -
// confirm with callers.
iree_string_view_t name;
// Optional debug information.
// Only populated when the FlatBuffer carries source locations.
IREE_TRACE(iree_string_view_t source_filename;)
IREE_TRACE(uint32_t source_line;)
} iree_hal_vulkan_entry_point_t;
// Creates a VkShaderModule on |logical_device| from the bytes in |code| and
// stores the handle in |out_shader_module|. The device's Vulkan allocation
// callbacks are used for the module.
static iree_status_t iree_hal_vulkan_create_shader_module(
    VkDeviceHandle* logical_device, iree_const_byte_span_t code,
    VkShaderModule* out_shader_module) {
  IREE_TRACE_SCOPE();

  // Value-initialization leaves pNext NULL and flags 0.
  VkShaderModuleCreateInfo module_info = {};
  module_info.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  module_info.codeSize = code.data_length;  // Size in bytes, not words.
  module_info.pCode = (const uint32_t*)code.data;

  VK_RETURN_IF_ERROR(logical_device->syms()->vkCreateShaderModule(
                         *logical_device, &module_info,
                         logical_device->allocator(), out_shader_module),
                     "vkCreateShaderModule");
  return iree_ok_status();
}
// Destroys |handle| on |logical_device|. A VK_NULL_HANDLE is ignored.
static void iree_hal_vulkan_destroy_shader_module(
    VkDeviceHandle* logical_device, VkShaderModule handle) {
  if (handle != VK_NULL_HANDLE) {
    logical_device->syms()->vkDestroyShaderModule(
        *logical_device, handle, logical_device->allocator());
  }
}
// Creates one compute pipeline per entry point of |executable_def| in a single
// vkCreateComputePipelines batch.
//
// All pipelines share |shader_module| and the specialization constants from
// |executable_params|; each constant is mapped as a 32-bit value whose
// constant ID equals its index. Pipeline |i| uses
// |executable_params->pipeline_layouts[i]|, so the caller must have ensured
// that at least |pipeline_count| layouts were provided.
//
// On success the pipeline handle of each |out_entry_points| element is set and
// (when the debug utils entry point is available) the pipeline is named after
// its entry point. On failure |out_entry_points| is left untouched and any
// pipelines the driver did manage to create are destroyed before returning.
static iree_status_t iree_hal_vulkan_create_pipelines(
    VkDeviceHandle* logical_device, VkPipelineCache pipeline_cache,
    const iree_hal_executable_params_t* executable_params,
    iree_hal_spirv_ExecutableDef_table_t executable_def,
    VkShaderModule shader_module, iree_host_size_t pipeline_count,
    iree_hal_vulkan_entry_point_t* out_entry_points) {
  IREE_TRACE_SCOPE();

  // All transient arrays live in one heap allocation. The pipeline handle
  // array is placed directly after the create infos so that it keeps their
  // (stricter) alignment; the remaining arrays follow in their original order.
  const size_t create_info_size =
      pipeline_count * sizeof(VkComputePipelineCreateInfo);
  const size_t pipeline_handles_size = pipeline_count * sizeof(VkPipeline);
  const size_t spec_map_size =
      executable_params->constant_count * sizeof(VkSpecializationMapEntry);
  const size_t subgroup_control_size =
      pipeline_count *
      sizeof(VkPipelineShaderStageRequiredSubgroupSizeCreateInfo);
  uint8_t* scratch_memory = NULL;
  IREE_RETURN_IF_ERROR(iree_allocator_malloc(
      logical_device->host_allocator(),
      create_info_size + pipeline_handles_size + spec_map_size +
          subgroup_control_size,
      (void**)&scratch_memory));
  uint8_t* scratch_ptr = scratch_memory;
  VkComputePipelineCreateInfo* create_infos =
      (VkComputePipelineCreateInfo*)scratch_ptr;
  scratch_ptr += create_info_size;
  VkPipeline* pipelines = (VkPipeline*)scratch_ptr;
  scratch_ptr += pipeline_handles_size;
  VkSpecializationMapEntry* spec_map_entries =
      (VkSpecializationMapEntry*)scratch_ptr;
  scratch_ptr += spec_map_size;
  VkPipelineShaderStageRequiredSubgroupSizeCreateInfo*
      subgroup_control_entries =
          (VkPipelineShaderStageRequiredSubgroupSizeCreateInfo*)scratch_ptr;

  // Start from null handles so the failure path below can tell which
  // pipelines were actually created regardless of what the driver wrote.
  for (iree_host_size_t i = 0; i < pipeline_count; ++i) {
    pipelines[i] = VK_NULL_HANDLE;
  }

  // One specialization info shared by every pipeline.
  VkSpecializationInfo spec_info;
  memset(&spec_info, 0, sizeof(spec_info));
  spec_info.mapEntryCount = executable_params->constant_count;
  spec_info.pMapEntries = spec_map_entries;
  spec_info.dataSize = executable_params->constant_count * sizeof(uint32_t);
  spec_info.pData = executable_params->constants;
  for (iree_host_size_t i = 0; i < executable_params->constant_count; ++i) {
    spec_map_entries[i].constantID = (uint32_t)i;
    spec_map_entries[i].offset = (uint32_t)(i * sizeof(uint32_t));
    spec_map_entries[i].size = sizeof(uint32_t);
  }

  flatbuffers_string_vec_t entry_points_vec =
      iree_hal_spirv_ExecutableDef_entry_points_get(executable_def);
  flatbuffers_uint32_vec_t subgroup_sizes_vec =
      iree_hal_spirv_ExecutableDef_subgroup_sizes_get(executable_def);
  for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count;
       ++entry_ordinal) {
    VkComputePipelineCreateInfo* create_info = &create_infos[entry_ordinal];
    create_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
    create_info->pNext = NULL;
    create_info->flags = 0;
    if (!iree_all_bits_set(
            executable_params->caching_mode,
            IREE_HAL_EXECUTABLE_CACHING_MODE_ALLOW_OPTIMIZATION)) {
      create_info->flags |= VK_PIPELINE_CREATE_DISABLE_OPTIMIZATION_BIT;
    }
    // The first pipeline is the base; all others are declared derivatives of
    // it via basePipelineIndex = 0 below.
    if (entry_ordinal == 0) {
      create_info->flags |= VK_PIPELINE_CREATE_ALLOW_DERIVATIVES_BIT;
    } else {
      create_info->flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT;
    }
    create_info->layout = iree_hal_vulkan_native_pipeline_layout_handle(
        executable_params->pipeline_layouts[entry_ordinal]);
    create_info->basePipelineHandle = VK_NULL_HANDLE;
    create_info->basePipelineIndex = 0;

    VkPipelineShaderStageCreateInfo* stage_create_info = &create_info->stage;
    stage_create_info->sType =
        VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    stage_create_info->flags = 0;
    stage_create_info->stage = VK_SHADER_STAGE_COMPUTE_BIT;
    stage_create_info->module = shader_module;
    stage_create_info->pName =
        flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
    stage_create_info->pSpecializationInfo = &spec_info;

    // If subgroup size is not 0, request the said subgroup size via
    // VK_EXT_subgroup_size_control (promoted to core since v1.3).
    stage_create_info->pNext = NULL;
    if (subgroup_sizes_vec) {
      const uint32_t subgroup_size =
          flatbuffers_uint32_vec_at(subgroup_sizes_vec, entry_ordinal);
      if (subgroup_size) {
        VkPipelineShaderStageRequiredSubgroupSizeCreateInfo* subgroup_control =
            &subgroup_control_entries[entry_ordinal];
        subgroup_control->sType =
            VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO;
        subgroup_control->pNext = NULL;
        subgroup_control->requiredSubgroupSize = subgroup_size;
        stage_create_info->pNext = subgroup_control;
      }
    }
  }

  iree_status_t status = VK_RESULT_TO_STATUS(
      logical_device->syms()->vkCreateComputePipelines(
          *logical_device, pipeline_cache, (uint32_t)pipeline_count,
          create_infos, logical_device->allocator(), pipelines),
      "vkCreateComputePipelines");
  if (iree_status_is_ok(status)) {
    PFN_vkSetDebugUtilsObjectNameEXT set_name =
        logical_device->syms()->vkSetDebugUtilsObjectNameEXT;
    for (iree_host_size_t i = 0; i < pipeline_count; ++i) {
      out_entry_points[i].pipeline = pipelines[i];

      // Set pipeline name for tooling.
      if (set_name) {
        VkDebugUtilsObjectNameInfoEXT name_info = {};
        name_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_OBJECT_NAME_INFO_EXT;
        name_info.pNext = NULL;
        name_info.objectHandle = (uint64_t)pipelines[i];
        name_info.objectType = VK_OBJECT_TYPE_PIPELINE;
        name_info.pObjectName = flatbuffers_string_vec_at(entry_points_vec, i);
        set_name(*logical_device, &name_info);
      }
    }
  } else {
    // Batched creation can fail for only a subset of the pipelines; the ones
    // that were created are valid objects and would leak if not destroyed.
    for (iree_host_size_t i = 0; i < pipeline_count; ++i) {
      if (pipelines[i] != VK_NULL_HANDLE) {
        logical_device->syms()->vkDestroyPipeline(
            *logical_device, pipelines[i], logical_device->allocator());
      }
    }
  }

  iree_allocator_free(logical_device->host_allocator(), scratch_memory);
  return status;
}
// Destroys the pipeline |handle| on |logical_device|. A VK_NULL_HANDLE is
// ignored.
static void iree_hal_vulkan_destroy_pipeline(VkDeviceHandle* logical_device,
                                             VkPipeline handle) {
  IREE_TRACE_SCOPE();
  if (handle != VK_NULL_HANDLE) {
    logical_device->syms()->vkDestroyPipeline(*logical_device, handle,
                                              logical_device->allocator());
  }
}
// Validates the SPIR-V executable FlatBuffer up front so that later code can
// walk it without re-checking bounds. Beyond the flatcc structural check this
// enforces that:
//   - the number of entry points equals |expected_entry_point_count|,
//   - every entry point has a non-empty name,
//   - the subgroup size list, when present, has one value per entry point,
//   - the SPIR-V code vector is not empty.
// Some conditions (such as omitted names on functions with internal linkage)
// may still need care by users of the FlatBuffer.
static iree_status_t iree_hal_spirv_executable_flatbuffer_verify(
    iree_const_byte_span_t flatbuffer_data,
    iree_host_size_t expected_entry_point_count) {
  const bool has_minimum_data =
      flatbuffer_data.data && flatbuffer_data.data_length >= 16;
  if (!has_minimum_data) {
    return iree_make_status(
        IREE_STATUS_INVALID_ARGUMENT,
        "FlatBuffer data is not present or less than 16 bytes (%" PRIhsz
        " total)",
        flatbuffer_data.data_length);
  }

  // The flatcc generated verifier guarantees that all offsets are in-bounds
  // and the buffer is safe to traverse; it says nothing about whether the
  // contents make sense for us, which is what the checks below are for.
  const int verify_result = iree_hal_spirv_ExecutableDef_verify_as_root(
      flatbuffer_data.data, flatbuffer_data.data_length);
  if (verify_result != flatcc_verify_ok) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "FlatBuffer verification failed: %s",
                            flatcc_verify_error_string(verify_result));
  }

  iree_hal_spirv_ExecutableDef_table_t def =
      iree_hal_spirv_ExecutableDef_as_root(flatbuffer_data.data);

  // Entry points: count must match and all names must be non-empty.
  flatbuffers_string_vec_t names =
      iree_hal_spirv_ExecutableDef_entry_points_get(def);
  const size_t name_count = flatbuffers_string_vec_len(names);
  if (name_count != expected_entry_point_count) {
    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
                            "executable provides %zu entry points but caller "
                            "provided %" PRIhsz "; must match",
                            name_count, expected_entry_point_count);
  }
  for (size_t ordinal = 0; ordinal < name_count; ++ordinal) {
    flatbuffers_string_t entry_name =
        flatbuffers_string_vec_at(names, ordinal);
    if (flatbuffers_string_len(entry_name) == 0) {
      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                              "executable entry point %zu has no name",
                              ordinal);
    }
  }

  // Subgroup sizes are optional but must be 1:1 with entry points if given.
  flatbuffers_uint32_vec_t subgroup_sizes =
      iree_hal_spirv_ExecutableDef_subgroup_sizes_get(def);
  if (subgroup_sizes) {
    const size_t subgroup_size_count = flatbuffers_vec_len(subgroup_sizes);
    if (subgroup_size_count != expected_entry_point_count) {
      return iree_make_status(
          IREE_STATUS_INVALID_ARGUMENT,
          "executable has %" PRIhsz
          " entry points but %zu subgroup sizes are defined",
          expected_entry_point_count, subgroup_size_count);
    }
  }

  // There must be some SPIR-V to compile.
  flatbuffers_uint32_vec_t code = iree_hal_spirv_ExecutableDef_code_get(def);
  if (flatbuffers_uint32_vec_len(code) == 0) {
    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
                            "executable SPIR-V code is missing/empty");
  }

  return iree_ok_status();
}
// Vulkan implementation of a HAL executable: a set of compute pipelines, one
// per entry point, allocated together with this header in a single block.
typedef struct iree_hal_vulkan_native_executable_t {
// Base HAL resource; initialized with the native executable vtable.
iree_hal_resource_t resource;
// Device the pipelines were created on; used to destroy them and to obtain
// the host allocator that frees this struct. Stored as a raw pointer (no
// reference is taken in this file).
VkDeviceHandle* logical_device;
// Number of elements in |entry_points|.
iree_host_size_t entry_point_count;
// Trailing storage for |entry_point_count| entry points.
iree_hal_vulkan_entry_point_t entry_points[];
} iree_hal_vulkan_native_executable_t;
namespace {
// Forward declaration so the functions below can reference the vtable; the
// definition appears later in this file.
extern const iree_hal_executable_vtable_t
iree_hal_vulkan_native_executable_vtable;
} // namespace
// Downcasts a generic HAL executable to the Vulkan native executable type,
// asserting (via IREE_HAL_ASSERT_TYPE) that it carries this file's vtable.
static iree_hal_vulkan_native_executable_t*
iree_hal_vulkan_native_executable_cast(iree_hal_executable_t* base_value) {
  IREE_HAL_ASSERT_TYPE(base_value, &iree_hal_vulkan_native_executable_vtable);
  return reinterpret_cast<iree_hal_vulkan_native_executable_t*>(base_value);
}
// Creates a native Vulkan executable from the SPIR-V executable FlatBuffer in
// |executable_params->executable_data|.
//
// The FlatBuffer is verified first (its entry point count must equal
// |executable_params->pipeline_layout_count|), then a temporary shader module
// is built from the embedded SPIR-V and one compute pipeline is created per
// entry point using |pipeline_cache|. The shader module is destroyed before
// returning in all cases.
//
// On success |out_executable| receives the new executable. On failure
// |out_executable| is left NULL and anything partially constructed is
// released.
//
// Entry point names (and, when tracing is enabled, source file names) stored
// in the executable are views into the FlatBuffer rather than copies.
iree_status_t iree_hal_vulkan_native_executable_create(
    iree::hal::vulkan::VkDeviceHandle* logical_device,
    VkPipelineCache pipeline_cache,
    const iree_hal_executable_params_t* executable_params,
    iree_hal_executable_t** out_executable) {
  IREE_ASSERT_ARGUMENT(logical_device);
  IREE_ASSERT_ARGUMENT(executable_params);
  IREE_ASSERT_ARGUMENT(out_executable);
  *out_executable = NULL;
  IREE_TRACE_ZONE_BEGIN(z0);

  // Verify and fetch the executable FlatBuffer wrapper.
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_spirv_executable_flatbuffer_verify(
              executable_params->executable_data,
              executable_params->pipeline_layout_count));
  iree_hal_spirv_ExecutableDef_table_t executable_def =
      iree_hal_spirv_ExecutableDef_as_root(
          executable_params->executable_data.data);

  // Create the shader module.
  flatbuffers_uint32_vec_t code_vec =
      iree_hal_spirv_ExecutableDef_code_get(executable_def);
  VkShaderModule shader_module = VK_NULL_HANDLE;
  IREE_RETURN_AND_END_ZONE_IF_ERROR(
      z0, iree_hal_vulkan_create_shader_module(
              logical_device,
              iree_make_const_byte_span(
                  code_vec,
                  flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t)),
              &shader_module));

  // Allocate the executable with trailing storage for each entry point.
  flatbuffers_string_vec_t entry_points_vec =
      iree_hal_spirv_ExecutableDef_entry_points_get(executable_def);
  iree_host_size_t entry_point_count =
      flatbuffers_string_vec_len(entry_points_vec);
  iree_hal_vulkan_native_executable_t* executable = NULL;
  iree_host_size_t total_size =
      sizeof(*executable) +
      entry_point_count * sizeof(*executable->entry_points);
  iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(),
                                               total_size, (void**)&executable);
  if (iree_status_is_ok(status)) {
    iree_hal_resource_initialize(&iree_hal_vulkan_native_executable_vtable,
                                 &executable->resource);
    executable->logical_device = logical_device;
    executable->entry_point_count = entry_point_count;
    // Zeroed entries have VK_NULL_HANDLE pipelines and empty names, which
    // keeps destruction of a partially initialized executable safe.
    memset(executable->entry_points, 0,
           entry_point_count * sizeof(*executable->entry_points));
  }

  // Create pipelines for each entry point.
  if (iree_status_is_ok(status)) {
    status = iree_hal_vulkan_create_pipelines(
        logical_device, pipeline_cache, executable_params, executable_def,
        shader_module, executable->entry_point_count, executable->entry_points);
  }

  // The module is only needed while the pipelines are being created.
  iree_hal_vulkan_destroy_shader_module(logical_device, shader_module);

  if (iree_status_is_ok(status)) {
    for (iree_host_size_t i = 0; i < entry_point_count; ++i) {
      flatbuffers_string_t name =
          flatbuffers_string_vec_at(entry_points_vec, i);
      executable->entry_points[i].name =
          iree_make_string_view(name, flatbuffers_string_len(name));
      IREE_TRACE_ZONE_APPEND_TEXT(z0, name);
    }
  }

  IREE_TRACE({
    if (iree_status_is_ok(status) &&
        iree_hal_spirv_ExecutableDef_source_locations_is_present(
            executable_def)) {
      iree_hal_spirv_FileLineLocDef_vec_t source_locs_vec =
          iree_hal_spirv_ExecutableDef_source_locations_get(executable_def);
      // The verifier does not check the length of this optional vector, so
      // never index past it; entry points without a location stay zeroed.
      const iree_host_size_t source_loc_count =
          iree_hal_spirv_FileLineLocDef_vec_len(source_locs_vec);
      for (iree_host_size_t i = 0;
           i < entry_point_count && i < source_loc_count; ++i) {
        iree_hal_spirv_FileLineLocDef_table_t source_loc =
            iree_hal_spirv_FileLineLocDef_vec_at(source_locs_vec, i);
        flatbuffers_string_t filename =
            iree_hal_spirv_FileLineLocDef_filename_get(source_loc);
        uint32_t line = iree_hal_spirv_FileLineLocDef_line_get(source_loc);
        executable->entry_points[i].source_filename =
            iree_make_string_view(filename, flatbuffers_string_len(filename));
        executable->entry_points[i].source_line = line;
      }
    }
  });

  if (iree_status_is_ok(status)) {
    *out_executable = (iree_hal_executable_t*)executable;
  } else if (executable) {
    // |executable| is NULL when the allocation itself failed; there is then
    // nothing to destroy.
    iree_hal_executable_destroy((iree_hal_executable_t*)executable);
  }

  IREE_TRACE_ZONE_END(z0);
  return status;
}
// vtable destroy hook: destroys every entry point pipeline and then frees the
// executable itself with the logical device's host allocator.
static void iree_hal_vulkan_native_executable_destroy(
    iree_hal_executable_t* base_executable) {
  iree_hal_vulkan_native_executable_t* executable =
      iree_hal_vulkan_native_executable_cast(base_executable);
  // Captured up front; |executable| is not touched again once it is freed.
  iree_allocator_t host_allocator =
      executable->logical_device->host_allocator();
  IREE_TRACE_ZONE_BEGIN(z0);

  VkDeviceHandle* device = executable->logical_device;
  const iree_hal_vulkan_entry_point_t* entry = executable->entry_points;
  const iree_hal_vulkan_entry_point_t* const entries_end =
      entry + executable->entry_point_count;
  for (; entry != entries_end; ++entry) {
    iree_hal_vulkan_destroy_pipeline(device, entry->pipeline);
  }

  iree_allocator_free(host_allocator, executable);

  IREE_TRACE_ZONE_END(z0);
}
// Fills |out_source_location| with the source location of the entry point at
// |entry_ordinal|. The output is always zeroed first and is left zeroed when
// |entry_ordinal| is out of range. By default the file name is set to the
// function name and the line to 0; inside IREE_TRACE these are replaced with
// the file name and line recorded on the entry point.
void iree_hal_vulkan_native_executable_entry_point_source_location(
    iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal,
    iree_hal_vulkan_source_location_t* out_source_location) {
  iree_hal_vulkan_native_executable_t* executable =
      iree_hal_vulkan_native_executable_cast(base_executable);
  memset(out_source_location, 0, sizeof(*out_source_location));
  if (entry_ordinal < executable->entry_point_count) {
    const iree_hal_vulkan_entry_point_t* entry =
        &executable->entry_points[entry_ordinal];
    out_source_location->func_name = entry->name;
    out_source_location->file_name = entry->name;
    out_source_location->line = 0;
    IREE_TRACE({
      out_source_location->file_name = entry->source_filename;
      out_source_location->line = entry->source_line;
    });
  }
}
// Looks up the pipeline handle for the entry point at |entry_ordinal| and
// stores it in |out_pipeline_handle|. Returns IREE_STATUS_OUT_OF_RANGE (and
// leaves the output untouched) if the ordinal is not a valid entry point.
iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point(
    iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal,
    VkPipeline* out_pipeline_handle) {
  iree_hal_vulkan_native_executable_t* executable =
      iree_hal_vulkan_native_executable_cast(base_executable);
  if (entry_ordinal < executable->entry_point_count) {
    const iree_hal_vulkan_entry_point_t* entry =
        &executable->entry_points[entry_ordinal];
    *out_pipeline_handle = entry->pipeline;
    return iree_ok_status();
  }
  return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
                          "invalid entry point ordinal %" PRIhsz,
                          entry_ordinal);
}
namespace {
// vtable that identifies Vulkan native executables (it is what the cast
// helper asserts against) and routes destruction to the destroy function
// defined in this file.
const iree_hal_executable_vtable_t iree_hal_vulkan_native_executable_vtable = {
/*.destroy=*/iree_hal_vulkan_native_executable_destroy,
};
} // namespace