blob: e71da9e029fa5c80b63e275e49976d6d52674f8a [file]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/SPIRV/Passes.h"
#include "iree/compiler/Codegen/SPIRV/Utils.h"
#include "iree/compiler/Codegen/Utils/LinkingUtils.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Utils/ModuleUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Pass/Pass.h"
#define DEBUG_TYPE "iree-spirv-link-executable"
namespace mlir::iree_compiler {
#define GEN_PASS_DEF_SPIRVLINKEXECUTABLESPASS
#include "iree/compiler/Codegen/SPIRV/Passes.h.inc"
namespace IREE::HAL {
// Orders two ExecutableTargetAttr by their "iree.spirv.features" lists. The
// lists are walked in parallel; at the first position where the entries differ
// the string values of those entries decide the result. If one list is a
// prefix of the other, the shorter list sorts first.
//
// Note that this is a very specific ordering per the needs of this pass--we
// guarantee that input ExecutableTargetAttr only differ w.r.t. their used
// SPIR-V features, and we want a deterministic order when mutating the IR.
//
// A target whose configuration dictionary or whose "iree.spirv.features" entry
// is absent is treated as having an empty feature list, so the comparison
// never dereferences a null attribute.
static bool operator<(const ExecutableTargetAttr &a,
                      const ExecutableTargetAttr &b) {
  // Returns the feature entries of |target|, or an empty list when they are
  // not present.
  auto getFeatures =
      [](const ExecutableTargetAttr &target) -> ArrayRef<Attribute> {
    DictionaryAttr config = target.getConfiguration();
    if (!config) {
      return {};
    }
    auto features = config.getAs<ArrayAttr>("iree.spirv.features");
    if (!features) {
      return {};
    }
    return features.getValue();
  };

  ArrayRef<Attribute> aFeatures = getFeatures(a);
  ArrayRef<Attribute> bFeatures = getFeatures(b);

  // Only the common prefix can be compared element-wise.
  const size_t commonSize = std::min(aFeatures.size(), bFeatures.size());
  for (size_t i = 0; i < commonSize; ++i) {
    if (aFeatures[i] != bFeatures[i]) {
      return cast<StringAttr>(aFeatures[i]).getValue() <
             cast<StringAttr>(bFeatures[i]).getValue();
    }
  }

  // All shared positions are identical: the shorter list orders first.
  return aFeatures.size() < bFeatures.size();
}
namespace {
// Collects, in module order, every hal.executable in |moduleOp| that has at
// least one variant for which usesSPIRVCodeGen() holds and that carries no
// object references. Variants with object references are not counted because
// we only support full replacement of the modules and not yet linking.
static SmallVector<IREE::HAL::ExecutableOp>
gatherExecutablesForSPIRVCodegen(mlir::ModuleOp moduleOp) {
  // A variant qualifies when it goes through SPIR-V codegen and has no
  // objects attached.
  auto isLinkableVariant = [](IREE::HAL::ExecutableVariantOp candidate) {
    if (!usesSPIRVCodeGen(candidate)) {
      return false;
    }
    return !candidate.getObjects().has_value();
  };

  SmallVector<IREE::HAL::ExecutableOp> linkable;
  for (IREE::HAL::ExecutableOp candidateOp :
       moduleOp.getOps<IREE::HAL::ExecutableOp>()) {
    auto variantOps = candidateOp.getOps<IREE::HAL::ExecutableVariantOp>();
    if (!llvm::any_of(variantOps, isLinkableVariant)) {
      continue;
    }
    linkable.push_back(candidateOp);
  }
  return linkable;
}
// Pass that groups the SPIR-V hal.executable ops of a module by the sorted
// list of targets of their SPIR-V variants and, for every group containing
// more than one executable, creates one new "linked" hal.executable at the
// beginning of the module via linkOneExecutableBucket().
struct SPIRVLinkExecutablesPass final
    : impl::SPIRVLinkExecutablesPassBase<SPIRVLinkExecutablesPass> {
  // Buckets the source executables by target requirements and links each
  // bucket that holds at least two executables. Signals pass failure as soon
  // as linking any bucket fails.
  void runOnOperation() override {
    mlir::ModuleOp moduleOp = getOperation();

    // Collect all source executable ops.
    auto sourceExecutableOps = gatherExecutablesForSPIRVCodegen(moduleOp);
    if (sourceExecutableOps.size() <= 1) {
      return;
    }

    // Note that at runtime, for a particular executable, only one variant of it
    // will be loaded. So, all variants of an executable are expected to provide
    // the exact same set of entry points; this way we can guarantee no matter
    // which variant is chosen, we have all entry points to call into. The same
    // entry point in different variants may have different target requirements
    // though.
    //
    // The input to the linking stage are a collection of executables, each may
    // have multiple variants, but only ever provide one entry point. Together
    // with the above restriction, we can link two executables if and only if
    // their variants have the exact same set of target requirements. Under such
    // circumstances, we can make sure for a particular target requirement
    // (loaded as one variant during runtime), we can provide all entry points.

    // Build a map from all variants' target requirements to their wrapping
    // executable ops. std::map keeps the buckets ordered by key, which makes
    // the iteration order below deterministic.
    std::map<SmallVector<ExecutableTargetAttr, 0>,
             SmallVector<IREE::HAL::ExecutableOp>>
        executableBuckets;

    SmallVector<ExecutableTargetAttr, 0> currentTargets;
    for (IREE::HAL::ExecutableOp executable : sourceExecutableOps) {
      // Go through all variants and collect all their target requirements and
      // sort as the unique key. The clear() also resets the vector after it
      // may have been moved from at the end of the previous iteration.
      currentTargets.clear();
      for (auto variant : executable.getOps<IREE::HAL::ExecutableVariantOp>()) {
        if (usesSPIRVCodeGen(variant)) {
          currentTargets.push_back(variant.getTarget());
        }
      }
      llvm::sort(currentTargets);

      LLVM_DEBUG({
        llvm::dbgs() << "executable op @" << executable.getSymName()
                     << " targets:\n";
        for (ExecutableTargetAttr attr : currentTargets) {
          llvm::dbgs() << " " << attr << "\n";
        }
      });

      // Put this executable into its proper bucket.
      executableBuckets[std::move(currentTargets)].push_back(executable);
    }

    // Scan through the buckets and drop those with only one executable, given
    // there is nothing to link in such cases.
    for (auto it = executableBuckets.begin(), ie = executableBuckets.end();
         it != ie;) {
      if (it->second.size() <= 1) {
        it = executableBuckets.erase(it);
      } else {
        ++it;
      }
    }

    // Guess a base module name, if needed, to make the output files readable.
    std::string baseModuleName =
        guessModuleName(moduleOp, "spirv_module") + "_linked_spirv";

    // Go reverse order with index, so when we keep inserting at the beginning,
    // the final IR has ascending order.
    int bucketIndex = static_cast<int>(executableBuckets.size());
    // Bind by reference: the key and the bucket are vectors, and copying them
    // for every bucket is unnecessary.
    for (auto &[key, bucket] : llvm::reverse(executableBuckets)) {
      --bucketIndex;

      // Build a unique name for this particular executable. The index suffix
      // is only added when there is more than one bucket.
      std::string moduleName =
          executableBuckets.size() == 1
              ? baseModuleName
              : llvm::formatv("{}_{}", baseModuleName, bucketIndex);

      LLVM_DEBUG({
        llvm::dbgs() << "executable bucket #" << bucketIndex << " targets:\n";
        for (ExecutableTargetAttr attr : key) {
          llvm::dbgs() << " " << attr << "\n";
        }
        llvm::dbgs() << "executable bucket #" << bucketIndex
                     << " executables:\n";
        for (IREE::HAL::ExecutableOp executable : bucket) {
          llvm::dbgs() << " " << executable.getSymName() << "\n";
        }
      });

      if (failed(linkOneExecutableBucket(moduleOp, moduleName, key, bucket))) {
        return signalPassFailure();
      }
    }
  }

  // Links all executables that are known to be in the same bucket.
  //
  // Creates a new hal.executable named |linkedExecutableName| at the beginning
  // of |moduleOp|, taking its visibility from the first source executable.
  // For every entry of |executableTargetAttrs| it adds a
  // hal.executable.variant holding an empty inner module and calls
  // linkExecutablesInto() with |sourceExecutableOps| to populate it.
  //
  // Returns failure as soon as linking into any variant fails.
  LogicalResult linkOneExecutableBucket(
      mlir::ModuleOp moduleOp, StringRef linkedExecutableName,
      ArrayRef<ExecutableTargetAttr> executableTargetAttrs,
      SmallVectorImpl<IREE::HAL::ExecutableOp> &sourceExecutableOps) const {
    OpBuilder moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody());

    // Create our new "linked" hal.executable.
    auto linkedExecutableOp = IREE::HAL::ExecutableOp::create(
        moduleBuilder, moduleOp.getLoc(), linkedExecutableName);
    linkedExecutableOp.setVisibility(
        sourceExecutableOps.front().getVisibility());
    OpBuilder executableBuilder =
        OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock());

    // Callback that moves the spirv.module of a source inner module to the end
    // of the linked inner module. It does not depend on the target being
    // processed, so it is built once for all variants.
    auto mergeModuleFn =
        [](mlir::ModuleOp sourceInnerModule, mlir::ModuleOp linkedInnerModule,
           DenseMap<StringRef, Operation *> &symbolMap) -> LogicalResult {
      // spirv.module is isolated from above. It does not define symbols or
      // reference outside symbols too. So we can just simply move it to the
      // linked inner module.
      auto srcModules = sourceInnerModule.getOps<spirv::ModuleOp>();
      // Checked at runtime rather than with an assert so that builds without
      // assertions neither dereference an end iterator nor silently drop
      // additional modules.
      if (!llvm::hasSingleElement(srcModules)) {
        return sourceInnerModule.emitError(
            "expected exactly one spirv.module in the module to link");
      }
      Operation *srcModule = *srcModules.begin();

      // Appending at the block end covers both the empty and the non-empty
      // target block.
      Block &targetBlock = *linkedInnerModule->getRegion(0).begin();
      srcModule->moveBefore(&targetBlock, targetBlock.end());
      return success();
    };

    for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) {
      // Add our hal.executable.variant with an empty module. The index suffix
      // is only added when the executable has more than one variant.
      std::string linkedVariantName =
          executableTargetAttrs.size() == 1
              ? attr.getSymbolNameFragment()
              : llvm::formatv("{}_{}", attr.getSymbolNameFragment(), index);
      auto linkedTargetOp = IREE::HAL::ExecutableVariantOp::create(
          executableBuilder, moduleOp.getLoc(), linkedVariantName, attr);
      auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock());
      mlir::ModuleOp::create(targetBuilder, moduleOp.getLoc());

      // Try linking together all executables in moduleOp.
      if (failed(linkExecutablesInto(moduleOp, sourceExecutableOps,
                                     linkedExecutableOp, linkedTargetOp,
                                     mergeModuleFn))) {
        return failure();
      }
    }
    return success();
  }
};
} // namespace
} // namespace IREE::HAL
} // namespace mlir::iree_compiler