[vulkan] Enable initial executable linking (#15802)

This commit turns on executable linking for Vulkan. Specifically, it
merges `spirv.module` ops in `hal.executable.variant` ops which has the
same target attribute. Each `hal.executable.variant` op is then
serialized into one final FlatBuffer executable. Inside, it now then
contains multiple SPIR-V blobs. Each entry point gets an index pointing
to one of the SPIR-V blobs.

Fixes https://github.com/openxla/iree/issues/7824
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
index 476b7dd..5c09b4b 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -678,6 +678,17 @@
   });
 }
 
+// NOTE: this runs on the top-level program module containing all hal.executable
+// ops.
+void buildSPIRVLinkingPassPipeline(OpPassManager &passManager) {
+  // Link together executables. This may produce some IR duplication.
+  passManager.addPass(createSPIRVLinkExecutablesPass());
+
+  // Cleanup IR duplication.
+  passManager.addNestedPass<IREE::HAL::ExecutableOp>(
+      mlir::createCanonicalizerPass());
+}
+
 //===---------------------------------------------------------------------===//
 // Register SPIRV Passes
 //===---------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
index 9b8bcbe..cf0938c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.h
@@ -109,6 +109,9 @@
 /// Links SPIR-V HAL executables within the top-level program module.
 std::unique_ptr<OperationPass<mlir::ModuleOp>> createSPIRVLinkExecutablesPass();
 
+/// Populates passes needed to link HAL executables across SPIRV targets.
+void buildSPIRVLinkingPassPipeline(OpPassManager &passManager);
+
 /// Pass to set the lowering strategy for the target variant.
 std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
 createSPIRVSelectLoweringStrategyPass();
@@ -180,7 +183,7 @@
     ArrayRef<int64_t> workgroupSize);
 
 //----------------------------------------------------------------------------//
-// Register SPIRV Passes
+// Register SPIR-V Passes
 //----------------------------------------------------------------------------//
 
 void registerCodegenSPIRVPasses();
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
index 4d1e409..de65e2c 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVLinkExecutables.cpp
@@ -27,6 +27,19 @@
     if (sourceExecutableOps.size() <= 1)
       return;
 
+    // Retain only non-external source executables. Linking right now happens as
+    // placing spirv.module ops into the same hal.executable.variant ops.
+    // External source executables won't have any spirv.modules inside.
+    int retainSize = 0;
+    for (int i = 0, e = sourceExecutableOps.size(); i < e; ++i) {
+      IREE::HAL::ExecutableOp executable = sourceExecutableOps[i];
+      if (llvm::none_of(executable.getOps<IREE::HAL::ExecutableVariantOp>(),
+                        [](auto op) { return op.getObjects().has_value(); })) {
+        sourceExecutableOps[retainSize++] = executable;
+      }
+    }
+    sourceExecutableOps.resize(retainSize);
+
     // Guess a module name, if needed, to make the output files readable.
     std::string moduleName = guessModuleName(moduleOp, "spirv_module");
 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 09ac284..c823020 100644
--- a/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/compiler/src/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -162,6 +162,10 @@
     buildSPIRVCodegenPassPipeline(passManager, /*enableFastMath=*/false);
   }
 
+  void buildLinkingPassPipeline(OpPassManager &passManager) override {
+    buildSPIRVLinkingPassPipeline(passManager);
+  }
+
   LogicalResult serializeExecutable(const SerializationOptions &options,
                                     IREE::HAL::ExecutableVariantOp variantOp,
                                     OpBuilder &executableBuilder) override {
@@ -174,46 +178,61 @@
 
     ModuleOp innerModuleOp = variantOp.getInnerModule();
     auto spirvModuleOps = innerModuleOp.getOps<spirv::ModuleOp>();
-    if (!llvm::hasSingleElement(spirvModuleOps)) {
-      return variantOp.emitError()
-             << "should only contain exactly one spirv.module op";
-    }
-    auto spvModuleOp = *spirvModuleOps.begin();
-    if (!options.dumpIntermediatesPath.empty()) {
-      std::string assembly;
-      llvm::raw_string_ostream os(assembly);
-      spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
-      dumpDataToPath(options.dumpIntermediatesPath, options.dumpBaseName,
-                     variantOp.getName(), ".mlir", assembly);
+    if (spirvModuleOps.empty()) {
+      return variantOp.emitError() << "should contain some spirv.module ops";
     }
 
     FlatbufferBuilder builder;
     iree_hal_spirv_ExecutableDef_start_as_root(builder);
 
-    // Serialize the spirv::ModuleOp into the binary that we will embed in the
-    // final FlatBuffer.
-    SmallVector<uint32_t, 256> spvBinary;
-    if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
-      return variantOp.emitError() << "failed to serialize spirv.module";
-    }
-    if (!options.dumpBinariesPath.empty()) {
-      dumpDataToPath<uint32_t>(options.dumpBinariesPath, options.dumpBaseName,
-                               variantOp.getName(), ".spv", spvBinary);
-    }
-
-    auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
-                                                    spvBinary.size());
-
-    // The runtime uses ordinals instead of names. We provide the list of entry
-    // point names here that are then passed in VkShaderModuleCreateInfo.
     SmallVector<StringRef> entryPointNames;
     SmallVector<uint32_t> subgroupSizes;
+    SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs;
+    SmallVector<uint32_t> shaderModuleIndices;
     SmallVector<iree_hal_spirv_FileLineLocDef_ref_t> sourceLocationRefs;
     bool hasAnySubgroupSizes = false;
-    spvModuleOp.walk([&](spirv::EntryPointOp exportOp) {
-      entryPointNames.push_back(exportOp.getFn());
 
-      auto fn = spvModuleOp.lookupSymbol<spirv::FuncOp>(exportOp.getFn());
+    // Iterate over all spirv.module ops and encode them into the FlatBuffer
+    // data structure.
+    for (spirv::ModuleOp spvModuleOp : spirvModuleOps) {
+      // Currently the spirv.module op should only have one entry point. Get it.
+      auto spirvEntryPoints = spvModuleOp.getOps<spirv::EntryPointOp>();
+      if (!llvm::hasSingleElement(spirvEntryPoints)) {
+        return spvModuleOp.emitError()
+               << "expected to contain exactly one entry point";
+      }
+      spirv::EntryPointOp spvEntryPoint = *spirvEntryPoints.begin();
+
+      if (!options.dumpIntermediatesPath.empty()) {
+        std::string assembly;
+        llvm::raw_string_ostream os(assembly);
+        spvModuleOp.print(os, OpPrintingFlags().useLocalScope());
+        dumpDataToPath(options.dumpIntermediatesPath, options.dumpBaseName,
+                       spvEntryPoint.getFn(), ".spirv.mlir", assembly);
+      }
+
+      // Serialize the spirv::ModuleOp into the binary blob.
+      SmallVector<uint32_t, 0> spvBinary;
+      if (failed(spirv::serialize(spvModuleOp, spvBinary)) ||
+          spvBinary.empty()) {
+        return spvModuleOp.emitError() << "failed to serialize";
+      }
+      if (!options.dumpBinariesPath.empty()) {
+        dumpDataToPath<uint32_t>(options.dumpBinariesPath, options.dumpBaseName,
+                                 spvEntryPoint.getFn(), ".spv", spvBinary);
+      }
+      auto spvCodeRef = flatbuffers_uint32_vec_create(builder, spvBinary.data(),
+                                                      spvBinary.size());
+      shaderModuleIndices.push_back(shaderModuleRefs.size());
+      shaderModuleRefs.push_back(
+          iree_hal_spirv_ShaderModuleDef_create(builder, spvCodeRef));
+
+      // The IREE runtime uses ordinals instead of names. We need to attach the
+      // entry point name for VkShaderModuleCreateInfo.
+      entryPointNames.push_back(spvEntryPoint.getFn());
+
+      // If there are subgroup size requests, we need to pick up too.
+      auto fn = spvModuleOp.lookupSymbol<spirv::FuncOp>(spvEntryPoint.getFn());
       auto abi = fn->getAttrOfType<spirv::EntryPointABIAttr>(
           spirv::getEntryPointABIAttrName());
       if (abi && abi.getSubgroupSize()) {
@@ -225,29 +244,37 @@
 
       // Optional source location information for debugging/profiling.
       if (options.debugLevel >= 1) {
-        if (auto loc = findFirstFileLoc(exportOp.getLoc())) {
+        if (auto loc = findFirstFileLoc(spvEntryPoint.getLoc())) {
           auto filenameRef = builder.createString(loc->getFilename());
           sourceLocationRefs.push_back(iree_hal_spirv_FileLineLocDef_create(
               builder, filenameRef, loc->getLine()));
         }
-      }
-    });
+      };
+    }
+
+    // Add top-level executable fields following their order of definition.
     auto entryPointsRef = builder.createStringVec(entryPointNames);
     flatbuffers_int32_vec_ref_t subgroupSizesRef =
         hasAnySubgroupSizes ? builder.createInt32Vec(subgroupSizes) : 0;
-
+    flatbuffers_int32_vec_ref_t shaderModuleIndicesRef =
+        builder.createInt32Vec(shaderModuleIndices);
     iree_hal_spirv_ExecutableDef_entry_points_add(builder, entryPointsRef);
     if (subgroupSizesRef) {
       iree_hal_spirv_ExecutableDef_subgroup_sizes_add(builder,
                                                       subgroupSizesRef);
     }
-    iree_hal_spirv_ExecutableDef_code_add(builder, spvCodeRef);
+    iree_hal_spirv_ExecutableDef_shader_module_indices_add(
+        builder, shaderModuleIndicesRef);
+    auto shaderModulesRef =
+        builder.createOffsetVecDestructive(shaderModuleRefs);
+    iree_hal_spirv_ExecutableDef_shader_modules_add(builder, shaderModulesRef);
     if (!sourceLocationRefs.empty()) {
       auto sourceLocationsRef =
           builder.createOffsetVecDestructive(sourceLocationRefs);
       iree_hal_spirv_ExecutableDef_source_locations_add(builder,
                                                         sourceLocationsRef);
     }
+
     iree_hal_spirv_ExecutableDef_end_as_root(builder);
 
     // Add the binary data to the target executable.
@@ -281,6 +308,9 @@
     for (auto exportOp : variantOp.getExportOps()) {
       entryPointNames.emplace_back(exportOp.getSymName());
     }
+    // We only have one object file for now. So all entry points have shader
+    // module index 0.
+    SmallVector<uint32_t, 8> shaderModuleIndices(entryPointNames.size(), 0);
 
     // Load .spv object file.
     auto objectAttr = llvm::cast<IREE::HAL::ExecutableObjectAttr>(
@@ -303,11 +333,20 @@
     auto spvCodeRef = flatbuffers_uint32_vec_create(
         builder, reinterpret_cast<const uint32_t *>(spvBinary.data()),
         spvBinary.size() / sizeof(uint32_t));
+    SmallVector<iree_hal_spirv_ShaderModuleDef_ref_t> shaderModuleRefs;
+    shaderModuleRefs.push_back(
+        iree_hal_spirv_ShaderModuleDef_create(builder, spvCodeRef));
 
+    // Add top-level executable fields following their order of definition.
     auto entryPointsRef = builder.createStringVec(entryPointNames);
-
+    auto shaderModuleIndicesRef = builder.createInt32Vec(shaderModuleIndices);
     iree_hal_spirv_ExecutableDef_entry_points_add(builder, entryPointsRef);
-    iree_hal_spirv_ExecutableDef_code_add(builder, spvCodeRef);
+    iree_hal_spirv_ExecutableDef_shader_module_indices_add(
+        builder, shaderModuleIndicesRef);
+    auto shaderModulesRef =
+        builder.createOffsetVecDestructive(shaderModuleRefs);
+    iree_hal_spirv_ExecutableDef_shader_modules_add(builder, shaderModulesRef);
+
     iree_hal_spirv_ExecutableDef_end_as_root(builder);
 
     // Add the binary data to the target executable.
diff --git a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
index fb91041..eb9d0d9 100644
--- a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
+++ b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc
@@ -11,7 +11,6 @@
 #include <cstring>
 
 #include "iree/base/api.h"
-#include "iree/hal/drivers/vulkan/dynamic_symbol_tables.h"
 #include "iree/hal/drivers/vulkan/dynamic_symbols.h"
 #include "iree/hal/drivers/vulkan/handle_util.h"
 #include "iree/hal/drivers/vulkan/native_pipeline_layout.h"
@@ -64,7 +63,7 @@
     VkDeviceHandle* logical_device, VkPipelineCache pipeline_cache,
     const iree_hal_executable_params_t* executable_params,
     iree_hal_spirv_ExecutableDef_table_t executable_def,
-    VkShaderModule shader_module, iree_host_size_t pipeline_count,
+    VkShaderModule* shader_modules, iree_host_size_t pipeline_count,
     iree_hal_vulkan_entry_point_t* out_entry_points) {
   IREE_TRACE_SCOPE();
   uint8_t* scratch_memory = NULL;
@@ -102,6 +101,8 @@
 
   flatbuffers_string_vec_t entry_points_vec =
       iree_hal_spirv_ExecutableDef_entry_points_get(executable_def);
+  flatbuffers_uint32_vec_t shader_module_indices_vec =
+      iree_hal_spirv_ExecutableDef_shader_module_indices_get(executable_def);
   flatbuffers_uint32_vec_t subgroup_sizes_vec =
       iree_hal_spirv_ExecutableDef_subgroup_sizes_get(executable_def);
   for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count;
@@ -130,7 +131,10 @@
         VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
     stage_create_info->flags = 0;
     stage_create_info->stage = VK_SHADER_STAGE_COMPUTE_BIT;
-    stage_create_info->module = shader_module;
+    uint32_t shader_module_index =
+        flatbuffers_uint32_vec_at(shader_module_indices_vec, entry_ordinal);
+    // We have verified that shader_module_index is within the range.
+    stage_create_info->module = shader_modules[shader_module_index];
     stage_create_info->pName =
         flatbuffers_string_vec_at(entry_points_vec, entry_ordinal);
     stage_create_info->pSpecializationInfo = &spec_info;
@@ -248,10 +252,45 @@
     }
   }
 
-  if (flatbuffers_uint32_vec_len(
-          iree_hal_spirv_ExecutableDef_code_get(executable_def)) == 0) {
-    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
-                            "executable SPIR-V code is missing/empty");
+  iree_hal_spirv_ShaderModuleDef_vec_t shader_modules_vec =
+      iree_hal_spirv_ExecutableDef_shader_modules_get(executable_def);
+  size_t shader_module_count = flatbuffers_vec_len(shader_modules_vec);
+  if (shader_module_count == 0) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "executable provides no shader modules");
+  }
+  for (size_t i = 0; i < shader_module_count; ++i) {
+    iree_hal_spirv_ShaderModuleDef_table_t shader_module =
+        iree_hal_spirv_ShaderModuleDef_vec_at(shader_modules_vec, i);
+    size_t code_size = flatbuffers_uint32_vec_len(
+        iree_hal_spirv_ShaderModuleDef_code_get(shader_module));
+    if (code_size == 0) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "executable SPIR-V code in shader module #%zu is missing", i);
+    }
+  }
+
+  flatbuffers_uint32_vec_t shader_module_indices_vec =
+      iree_hal_spirv_ExecutableDef_shader_module_indices_get(executable_def);
+  size_t shader_module_index_count =
+      flatbuffers_vec_len(shader_module_indices_vec);
+  if (shader_module_index_count != expected_entry_point_count) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "executable has %" PRIhsz
+        " entry points but %zu shader module indices are defined",
+        expected_entry_point_count, shader_module_index_count);
+  }
+  for (size_t i = 0; i < shader_module_index_count; ++i) {
+    uint32_t index = flatbuffers_uint32_vec_at(shader_module_indices_vec, i);
+    if (index >= shader_module_count) {
+      return iree_make_status(
+          IREE_STATUS_INVALID_ARGUMENT,
+          "executable entry point shader module index %u out of range; "
+          "executable only has %zu total shader modules",
+          index, shader_module_count);
+    }
   }
 
   return iree_ok_status();
@@ -284,6 +323,7 @@
   IREE_ASSERT_ARGUMENT(executable_params);
   IREE_ASSERT_ARGUMENT(out_executable);
   *out_executable = NULL;
+  iree_allocator_t host_allocator = logical_device->host_allocator();
   IREE_TRACE_ZONE_BEGIN(z0);
 
   // Verify and fetch the executable FlatBuffer wrapper.
@@ -295,17 +335,30 @@
       iree_hal_spirv_ExecutableDef_as_root(
           executable_params->executable_data.data);
 
-  // Create the shader module.
-  flatbuffers_uint32_vec_t code_vec =
-      iree_hal_spirv_ExecutableDef_code_get(executable_def);
-  VkShaderModule shader_module = VK_NULL_HANDLE;
+  // Allocate space for Vulkan shader module handles.
+  iree_hal_spirv_ShaderModuleDef_vec_t shader_modules_vec =
+      iree_hal_spirv_ExecutableDef_shader_modules_get(executable_def);
+  size_t shader_module_count = flatbuffers_vec_len(shader_modules_vec);
+  VkShaderModule* shader_modules = NULL;
   IREE_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, iree_hal_vulkan_create_shader_module(
-              logical_device,
-              iree_make_const_byte_span(
-                  code_vec,
-                  flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t)),
-              &shader_module));
+      z0, iree_allocator_malloc(host_allocator,
+                                shader_module_count * sizeof(VkShaderModule),
+                                (void**)&shader_modules));
+
+  // Create all shader modules.
+  // TODO: perform the shader module creation in multiple threaded manner.
+  iree_status_t status = iree_ok_status();
+  for (size_t i = 0; i < shader_module_count; ++i) {
+    iree_hal_spirv_ShaderModuleDef_table_t shader_module =
+        iree_hal_spirv_ShaderModuleDef_vec_at(shader_modules_vec, i);
+    flatbuffers_uint32_vec_t code_vec =
+        iree_hal_spirv_ShaderModuleDef_code_get(shader_module);
+    size_t code_size = flatbuffers_uint32_vec_len(code_vec) * sizeof(uint32_t);
+    status = iree_hal_vulkan_create_shader_module(
+        logical_device, iree_make_const_byte_span(code_vec, code_size),
+        &shader_modules[i]);
+    if (!iree_status_is_ok(status)) break;
+  }
 
   // Create pipelines for each entry point.
   flatbuffers_string_vec_t entry_points_vec =
@@ -314,11 +367,13 @@
       flatbuffers_string_vec_len(entry_points_vec);
 
   iree_hal_vulkan_native_executable_t* executable = NULL;
-  iree_host_size_t total_size =
-      sizeof(*executable) +
-      entry_point_count * sizeof(*executable->entry_points);
-  iree_status_t status = iree_allocator_malloc(logical_device->host_allocator(),
-                                               total_size, (void**)&executable);
+  if (iree_status_is_ok(status)) {
+    status = iree_allocator_malloc(
+        host_allocator,
+        sizeof(*executable) +
+            entry_point_count * sizeof(*executable->entry_points),
+        (void**)&executable);
+  }
   if (iree_status_is_ok(status)) {
     iree_hal_resource_initialize(&iree_hal_vulkan_native_executable_vtable,
                                  &executable->resource);
@@ -330,9 +385,14 @@
   if (iree_status_is_ok(status)) {
     status = iree_hal_vulkan_create_pipelines(
         logical_device, pipeline_cache, executable_params, executable_def,
-        shader_module, executable->entry_point_count, executable->entry_points);
+        shader_modules, executable->entry_point_count,
+        executable->entry_points);
   }
-  iree_hal_vulkan_destroy_shader_module(logical_device, shader_module);
+  // Pipelines are created and we don't need the shader modules anymore.
+  // Note that if error happens before, we also destroy the shader modules here.
+  for (size_t i = 0; i < shader_module_count; ++i) {
+    iree_hal_vulkan_destroy_shader_module(logical_device, shader_modules[i]);
+  }
 
   if (iree_status_is_ok(status)) {
     flatbuffers_string_vec_t entry_points_vec =
diff --git a/runtime/src/iree/schemas/spirv_executable_def.fbs b/runtime/src/iree/schemas/spirv_executable_def.fbs
index 15ebff1..e6ead6b 100644
--- a/runtime/src/iree/schemas/spirv_executable_def.fbs
+++ b/runtime/src/iree/schemas/spirv_executable_def.fbs
@@ -10,6 +10,11 @@
 file_identifier "SPVE";
 file_extension "spve";
 
+table ShaderModuleDef {
+  // SPIR-V code blob.
+  code:[uint32];
+}
+
 // Source code location denoted by a file name and line within that file.
 table FileLineLocDef {
   filename:string;
@@ -23,11 +28,20 @@
   // A map of entry point ordinals to string names as used in the shader module.
   entry_points:[string];
 
-  // Required subgroup size for each entry point. 0 means no requirement.
+  // A list of required subgroup sizes for each entry point. 0 means no
+  // requirement.
+  // This list has the same size as the entry_points list.
   subgroup_sizes:[uint32];
 
-  // SPIR-V code words.
-  code:[uint32];
+  // A map of entry point ordinals to the indices of the containing shader
+  // modules (the following field).
+  // This list has the same size as the entry_points list.
+  shader_module_indices:[uint32];
+
+  // A list of shader modules hosting various entry points. Each shader module
+  // contains at least one entry point.
+  // This list may not have the same size as the entry_points list.
+  shader_modules:[ShaderModuleDef];
 
   // A map of entry point ordinals to source locations.
   // This information is optional and may be used by debuggers and profilers to