[vulkan] Enable initial executable linking (#15802)
This commit turns on executable linking for Vulkan. Specifically, it
merges the `spirv.module` ops of `hal.executable.variant` ops that have
the same target attribute. Each `hal.executable.variant` op is then
serialized into one final FlatBuffer executable, which can now contain
multiple SPIR-V blobs. Each entry point gets an index pointing to one
of the SPIR-V blobs.
Fixes https://github.com/openxla/iree/issues/7824
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 476b7dd..5c09b4b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -678,6 +678,17 @@
});
}
+// NOTE: this runs on the top-level program module containing all hal.executable
+// ops.
+void buildSPIRVLinkingPassPipeline(OpPassManager &passManager) {
+ // Link together executables. This may produce some IR duplication.
+ passManager.addPass(createSPIRVLinkExecutablesPass());
+
+ // Cleanup IR duplication.
+ passManager.addNestedPass<IREE::HAL::ExecutableOp>(
+ mlir::createCanonicalizerPass());
+}
+
//===---------------------------------------------------------------------===//
// Register SPIRV Passes
//===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index 9b8bcbe..cf0938c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -109,6 +109,9 @@
/// Links SPIR-V HAL executables within the top-level program module.
std::unique_ptr<OperationPass<mlir::ModuleOp>> createSPIRVLinkExecutablesPass();
+/// Populates passes needed to link HAL executables across SPIRV targets.
+void buildSPIRVLinkingPassPipeline(OpPassManager &passManager);
+
/// Pass to set the lowering strategy for the target variant.
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createSPIRVSelectLoweringStrategyPass();
@@ -180,7 +183,7 @@
ArrayRef<int64_t> workgroupSize);
//----------------------------------------------------------------------------//
-// Register SPIRV Passes
+// Register SPIR-V Passes
//----------------------------------------------------------------------------//
void registerCodegenSPIRVPasses();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
index 4d1e409..de65e2c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
@@ -27,6 +27,19 @@
if (sourceExecutableOps.size() <= 1)
return;
+ // Retain only non-external source executables. Linking right now happens as
+ // placing spirv.module ops into the same hal.executable.variant ops.
+ // External source executables won't have any spirv.modules inside.
+ int retainSize = 0;
+ for (int i = 0, e = sourceExecutableOps.size(); i < e; ++i) {
+ IREE::HAL::ExecutableOp executable = sourceExecutableOps[i];
+ if (llvm::none_of(executable.getOps<IREE::HAL::ExecutableVariantOp>(),
+ [](auto op) { return op.getObjects().has_value(); })) {
+ sourceExecutableOps[retainSize++] = executable;
+ }
+ }
+ sourceExecutableOps.resize(retainSize);
+
// Guess a module name, if needed, to make the output files readable.
std::string moduleName = guessModuleName(moduleOp, "spirv_module");
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 09ac284..c823020 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -162,6 +162,10 @@
buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
}
+ void buildLinkingPassPipeline(OpPassManager &passManager) override {
+ buildSPIRVLinkingPassPipeline(passManager);
+ }
+
LogicalResult serializeExecutable(const SerializationOptions &options,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
@@ -174,46 +178,61 @@
ModuleOp innerModuleOp = variantOp.getInnerModule();
auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
- if (!llvm::hasSingleElement(spirvModuleOps)) {
- return variantOp.emitError()
- << "should only contain exactly one spirv.module op";
- }
- auto spvModuleOp = *spirvModuleOps.begin();
- if (!options.dumpIntermediatesPath.empty()) {
- std::string assembly;
- llvm::raw_string_ostream os(assembly);
- spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
- dumpDataToPath(options.dumpIntermediatesPath, options.dumpBaseName,
- variantOp.getName(), ".mlir", assembly);
+ if (spirvModuleOps.empty()) {
+ return variantOp.emitError() << "should contain some spirv.module ops";
}
FlatbufferBuilder builder;
iree_hal_spirv_ExecutableDef_start_as_root(builder);
- // Serialize the spirv::ModuleOp into the binary that we will embed in the
- // final FlatBuffer.
- SmallVector<uint32_t, 256> spvBinary;
- if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
- return variantOp.emitError() << "failed to serialize spirv.module";
- }
- if (!options.dumpBinariesPath.empty()) {
- dumpDataToPath<uint32_t>(options.dumpBinariesPath, options.dumpBaseName,
- variantOp.getName(), ".spv", spvBinary);
- }
-
- auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
- spvBinary.size());
-
- // The runtime uses ordinals instead of names. We provide the list of entry
- // point names here that are then passed in VkShaderModuleCreateInfo.
SmallVector<StringRef> entryPointNames;
SmallVector<uint32_t> subgroupSizes;
+ SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs;
+ SmallVector<uint32_t> shaderModuleIndices;
SmallVector<iree_hal_spirv_FileLineLocDef_ref_t> sourceLocationRefs;
bool hasAnySubgroupSizes = false;
- spvModuleOp.walk([&](spirv::EntryPointOp exportOp) {
- entryPointNames.push_back(exportOp.getFn());
- auto fn = spvModuleOp.lookupSymbol<spirv::FuncOp>(exportOp.getFn());
+ // Iterate over all spirv.module ops and encode them into the FlatBuffer
+ // data structure.
+ for (spirv::ModuleOp spvModuleOp : spirvModuleOps) {
+ // Currently the spirv.module op should only have one entry point. Get it.
+ auto spirvEntryPoints = spvModuleOp.getOps<spirv::EntryPointOp>();
+ if (!llvm::hasSingleElement(spirvEntryPoints)) {
+ return spvModuleOp.emitError()
+ << "expected to contain exactly one entry point";
+ }
+ spirv::EntryPointOp spvEntryPoint = *spirvEntryPoints.begin();
+
+ if (!options.dumpIntermediatesPath.empty()) {
+ std::string assembly;
+ llvm::raw_string_ostream os(assembly);
+ spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
+ dumpDataToPath(options.dumpIntermediatesPath, options.dumpBaseName,
+ spvEntryPoint.getFn(), ".spirv.mlir", assembly);
+ }
+
+ // Serialize the spirv::ModuleOp into the binary blob.
+ SmallVector<uint32_t, 0> spvBinary;
+ if (failed(spirv::serialize(spvModuleOp, spvBinary)) ||
+ spvBinary.empty()) {
+ return spvModuleOp.emitError() << "failed to serialize";
+ }
+ if (!options.dumpBinariesPath.empty()) {
+ dumpDataToPath<uint32_t>(options.dumpBinariesPath, options.dumpBaseName,
+ spvEntryPoint.getFn(), ".spv", spvBinary);
+ }
+ auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
+ spvBinary.size());
+ shaderModuleIndices.push_back(shaderModuleRefs.size());
+ shaderModuleRefs.push_back(
+ iree_hal_spirv_ShaderModuleDef_create(builder, spvCodeRef));
+
+ // The IREE runtime uses ordinals instead of names. We need to attach the
+ // entry point name for VkShaderModuleCreateInfo.
+ entryPointNames.push_back(spvEntryPoint.getFn());
+
+ // If there are subgroup size requests, we need to pick them up too.
+ auto fn = spvModuleOp.lookupSymbol<spirv::FuncOp>(spvEntryPoint.getFn());
auto abi = fn->getAttrOfType<spirv::EntryPointABIAttr>(
spirv::getEntryPointABIAttrName());
if (abi && abi.getSubgroupSize()) {
@@ -225,29 +244,37 @@
// Optional source location information for debugging/profiling.
if (options.debugLevel >= 1) {
- if (auto loc = findFirstFileLoc(exportOp.getLoc())) {
+ if (auto loc = findFirstFileLoc(spvEntryPoint.getLoc())) {
auto filenameRef = builder.createString(loc->getFilename());
sourceLocationRefs.push_back(iree_hal_spirv_FileLineLocDef_create(
builder, filenameRef, loc->getLine()));
}
- }
- });
+ };
+ }
+
+ // Add top-level executable fields following their order of definition.
auto entryPointsRef = builder.createStringVec(entryPointNames);
flatbuffers_int32_vec_ref_t subgroupSizesRef =
hasAnySubgroupSizes ? builder.createInt32Vec(subgroupSizes) : 0;
-
+ flatbuffers_int32_vec_ref_t shaderModuleIndicesRef =
+ builder.createInt32Vec(shaderModuleIndices);
iree_hal_spirv_ExecutableDef_entry_points_add(builder, entryPointsRef);
if (subgroupSizesRef) {
iree_hal_spirv_ExecutableDef_subgroup_sizes_add(builder,
subgroupSizesRef);
}
- iree_hal_spirv_ExecutableDef_code_add(builder, spvCodeRef);
+ iree_hal_spirv_ExecutableDef_shader_module_indices_add(
+ builder, shaderModuleIndicesRef);
+ auto shaderModulesRef =
+ builder.createOffsetVecDestructive(shaderModuleRefs);
+ iree_hal_spirv_ExecutableDef_shader_modules_add(builder, shaderModulesRef);
if (!sourceLocationRefs.empty()) {
auto sourceLocationsRef =
builder.createOffsetVecDestructive(sourceLocationRefs);
iree_hal_spirv_ExecutableDef_source_locations_add(builder,
sourceLocationsRef);
}
+
iree_hal_spirv_ExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
@@ -281,6 +308,9 @@
for (auto exportOp : variantOp.getExportOps()) {
entryPointNames.emplace_back(exportOp.getSymName());
}
+ // We only have one object file for now. So all entry points have shader
+ // module index 0.
+ SmallVector<uint32_t, 8> shaderModuleIndices(entryPointNames.size(), 0);
// Load .spv object file.
auto objectAttr = llvm::cast<IREE::HAL::ExecutableObjectAttr>(
@@ -303,11 +333,20 @@
auto spvCodeRef = flatbuffers_uint32_vec_create(
builder, reinterpret_cast<const uint32_t *>(spvBinary.data()),
spvBinary.size() / sizeof(uint32_t));
+ SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs;
+ shaderModuleRefs.push_back(
+ iree_hal_spirv_ShaderModuleDef_create(builder, spvCodeRef));
+ // Add top-level executable fields following their order of definition.
auto entryPointsRef = builder.createStringVec(entryPointNames);
-
+ auto shaderModuleIndicesRef = builder.createInt32Vec(shaderModuleIndices);
iree_hal_spirv_ExecutableDef_entry_points_add(builder, entryPointsRef);
- iree_hal_spirv_ExecutableDef_code_add(builder, spvCodeRef);
+ iree_hal_spirv_ExecutableDef_shader_module_indices_add(
+ builder, shaderModuleIndicesRef);
+ auto shaderModulesRef =
+ builder.createOffsetVecDestructive(shaderModuleRefs);
+ iree_hal_spirv_ExecutableDef_shader_modules_add(builder, shaderModulesRef);
+
iree_hal_spirv_ExecutableDef_end_as_root(builder);
// Add the binary data to the target executable.
diff --git a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
index fb91041..eb9d0d9 100644
--- a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
@@ -11,7 +11,6 @@
#include <cstring>
#include "iree/base/api.h"
-#include "iree/hal/drivers/vulkan/dynamic_symbol_tables.h"
#include "iree/hal/drivers/vulkan/dynamic_symbols.h"
#include "iree/hal/drivers/vulkan/handle_util.h"
#include "iree/hal/drivers/vulkan/native_pipeline_layout.h"
@@ -64,7 +63,7 @@
VkDeviceHandle* logical_device, VkPipelineCache pipeline_cache,
const iree_hal_executable_params_t* executable_params,
iree_hal_spirv_ExecutableDef_table_t executable_def,
- VkShaderModule shader_module, iree_host_size_t pipeline_count,
+ VkShaderModule* shader_modules, iree_host_size_t pipeline_count,
iree_hal_vulkan_entry_point_t* out_entry_points) {
IREE_TRACE_SCOPE();
uint8_t* scratch_memory = NULL;
@@ -102,6 +101,8 @@
flatbuffers_string_vec_t entry_points_vec =
iree_hal_spirv_ExecutableDef_entry_points_get(executable_def);
+ flatbuffers_uint32_vec_t shader_module_indices_vec =
+ iree_hal_spirv_ExecutableDef_shader_module_indices_get(executable_def);
flatbuffers_uint32_vec_t subgroup_sizes_vec =
iree_hal_spirv_ExecutableDef_subgroup_sizes_get(executable_def);
for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count;
@@ -130,7 +131,10 @@
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
stage_create_info->flags = 0;
stage_create_info->stage = VK_SHADER_STAGE_COMPUTE_BIT;
- stage_create_info->module = shader_module;
+ uint32_t shader_module_index =
+ flatbuffers_uint32_vec_at(shader_module_indices_vec, entry_ordinal);
+ // We have verified that shader_module_index is within the range.
+ stage_create_info->module = shader_modules[shader_module_index];
stage_create_info->pName =
flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
stage_create_info->pSpecializationInfo = &spec_info;
@@ -248,10 +252,45 @@
}
}
- if (flatbuffers_uint32_vec_len(
- iree_hal_spirv_ExecutableDef_code_get(executable_def)) == 0) {
- return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
- "executable SPIR-V code is missing/empty");
+ iree_hal_spirv_ShaderModuleDef_vec_t shader_modules_vec =
+ iree_hal_spirv_ExecutableDef_shader_modules_get(executable_def);
+ size_t shader_module_count = flatbuffers_vec_len(shader_modules_vec);
+ if (shader_module_count == 0) {
+ return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+ "executable provides no shader modules");
+ }
+ for (size_t i = 0; i < shader_module_count; ++i) {
+ iree_hal_spirv_ShaderModuleDef_table_t shader_module =
+ iree_hal_spirv_ShaderModuleDef_vec_at(shader_modules_vec, i);
+ size_t code_size = flatbuffers_uint32_vec_len(
+ iree_hal_spirv_ShaderModuleDef_code_get(shader_module));
+ if (code_size == 0) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "executable SPIR-V code in shader module #%zu is missing", i);
+ }
+ }
+
+ flatbuffers_uint32_vec_t shader_module_indices_vec =
+ iree_hal_spirv_ExecutableDef_shader_module_indices_get(executable_def);
+ size_t shader_module_index_count =
+ flatbuffers_vec_len(shader_module_indices_vec);
+ if (shader_module_index_count != expected_entry_point_count) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "executable has %" PRIhsz
+ " entry points but %zu shader module indices are defined",
+ expected_entry_point_count, shader_module_index_count);
+ }
+ for (size_t i = 0; i < shader_module_index_count; ++i) {
+ uint32_t index = flatbuffers_uint32_vec_at(shader_module_indices_vec, i);
+ if (index >= shader_module_count) {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "executable entry point shader module index %u out of range; "
+ "executable only has %zu total shader modules",
+ index, shader_module_count);
+ }
}
return iree_ok_status();
@@ -284,6 +323,7 @@
IREE_ASSERT_ARGUMENT(executable_params);
IREE_ASSERT_ARGUMENT(out_executable);
*out_executable = NULL;
+ iree_allocator_t host_allocator = logical_device->host_allocator();
IREE_TRACE_ZONE_BEGIN(z0);
// Verify and fetch the executable FlatBuffer wrapper.
@@ -295,17 +335,30 @@
iree_hal_spirv_ExecutableDef_as_root(
executable_params->executable_data.data);
- // Create the shader module.
- flatbuffers_uint32_vec_t code_vec =
- iree_hal_spirv_ExecutableDef_code_get(executable_def);
- VkShaderModule shader_module = VK_NULL_HANDLE;
+ // Allocate space for Vulkan shader module handles.
+ iree_hal_spirv_ShaderModuleDef_vec_t shader_modules_vec =
+ iree_hal_spirv_ExecutableDef_shader_modules_get(executable_def);
+ size_t shader_module_count = flatbuffers_vec_len(shader_modules_vec);
+ VkShaderModule* shader_modules = NULL;
IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, iree_hal_vulkan_create_shader_module(
- logical_device,
- iree_make_const_byte_span(
- code_vec,
- flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t)),
- &shader_module));
+ z0, iree_allocator_malloc(host_allocator,
+ shader_module_count * sizeof(VkShaderModule),
+ (void**)&shader_modules));
+
+ // Create all shader modules.
+ // TODO: create the shader modules using multiple threads.
+ iree_status_t status = iree_ok_status();
+ for (size_t i = 0; i < shader_module_count; ++i) {
+ iree_hal_spirv_ShaderModuleDef_table_t shader_module =
+ iree_hal_spirv_ShaderModuleDef_vec_at(shader_modules_vec, i);
+ flatbuffers_uint32_vec_t code_vec =
+ iree_hal_spirv_ShaderModuleDef_code_get(shader_module);
+ size_t code_size = flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t);
+ status = iree_hal_vulkan_create_shader_module(
+ logical_device, iree_make_const_byte_span(code_vec, code_size),
+ &shader_modules[i]);
+ if (!iree_status_is_ok(status)) break;
+ }
// Create pipelines for each entry point.
flatbuffers_string_vec_t entry_points_vec =
@@ -314,11 +367,13 @@
flatbuffers_string_vec_len(entry_points_vec);
iree_hal_vulkan_native_executable_t* executable = NULL;
- iree_host_size_t total_size =
- sizeof(*executable) +
- entry_point_count * sizeof(*executable->entry_points);
- iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(),
- total_size, (void**)&executable);
+ if (iree_status_is_ok(status)) {
+ status = iree_allocator_malloc(
+ host_allocator,
+ sizeof(*executable) +
+ entry_point_count * sizeof(*executable->entry_points),
+ (void**)&executable);
+ }
if (iree_status_is_ok(status)) {
iree_hal_resource_initialize(&iree_hal_vulkan_native_executable_vtable,
&executable->resource);
@@ -330,9 +385,14 @@
if (iree_status_is_ok(status)) {
status = iree_hal_vulkan_create_pipelines(
logical_device, pipeline_cache, executable_params, executable_def,
- shader_module, executable->entry_point_count, executable->entry_points);
+ shader_modules, executable->entry_point_count,
+ executable->entry_points);
}
- iree_hal_vulkan_destroy_shader_module(logical_device, shader_module);
+ // Pipelines are created and we don't need the shader modules anymore.
+ // Note that on an earlier error we also destroy the shader modules here.
+ for (size_t i = 0; i < shader_module_count; ++i) {
+ iree_hal_vulkan_destroy_shader_module(logical_device, shader_modules[i]);
+ }
if (iree_status_is_ok(status)) {
flatbuffers_string_vec_t entry_points_vec =
diff --git a/runtime/src/iree/schemas/spirv_executable_def.fbs b/runtime/src/iree/schemas/spirv_executable_def.fbs
index 15ebff1..e6ead6b 100644
--- a/runtime/src/iree/schemas/spirv_executable_def.fbs
+++ b/runtime/src/iree/schemas/spirv_executable_def.fbs
@@ -10,6 +10,11 @@
file_identifier "SPVE";
file_extension "spve";
+table ShaderModuleDef {
+ // SPIR-V code blob.
+ code:[uint32];
+}
+
// Source code location denoted by a file name and line within that file.
table FileLineLocDef {
filename:string;
@@ -23,11 +28,20 @@
// A map of entry point ordinals to string names as used in the shader module.
entry_points:[string];
- // Required subgroup size for each entry point. 0 means no requirement.
+ // A list of required subgroup sizes for each entry point. 0 means no
+ // requirement.
+ // This list has the same size as the entry_points list.
subgroup_sizes:[uint32];
- // SPIR-V code words.
- code:[uint32];
+ // A map of entry point ordinals to the indices of the containing shader
+ // modules (the following field).
+ // This list has the same size as the entry_points list.
+ shader_module_indices:[uint32];
+
+ // A list of shader modules hosting various entry points. Each shader module
+ // contains at least one entry point.
+ // This list may not have the same size as the entry_points list.
+ shader_modules:[ShaderModuleDef];
// A map of entry point ordinals to source locations.
// This information is optional and may be used by debuggers and profilers to