| // Copyright 2022 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree/compiler/Codegen/LLVMCPU/PassDetail.h" |
| #include "iree/compiler/Codegen/LLVMCPU/Passes.h" |
| #include "iree/compiler/Codegen/Utils/LinkingUtils.h" |
| #include "iree/compiler/Utils/ModuleUtils.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "mlir/Pass/Pass.h" |
| |
| namespace mlir::iree_compiler { |
| |
| namespace { |
| |
| struct LLVMCPULinkExecutablesPass |
| : public LLVMCPULinkExecutablesBase<LLVMCPULinkExecutablesPass> { |
| LLVMCPULinkExecutablesPass() = default; |
| void runOnOperation() override { |
| auto moduleOp = getOperation(); |
| auto moduleBuilder = OpBuilder::atBlockBegin(moduleOp.getBody()); |
| |
| auto sourceExecutableOps = |
| llvm::to_vector<8>(moduleOp.getOps<IREE::HAL::ExecutableOp>()); |
| if (sourceExecutableOps.size() <= 1) |
| return; |
| |
| // Guess a module name, if needed, to make the output files readable. |
| auto moduleName = guessModuleName(moduleOp, "llvm_module"); |
| |
| // Create our new "linked" hal.executable. |
| std::string linkedExecutableName = |
| llvm::formatv("{0}_linked_{1}", moduleName, "llvm_cpu"); |
| auto linkedExecutableOp = moduleBuilder.create<IREE::HAL::ExecutableOp>( |
| moduleOp.getLoc(), linkedExecutableName); |
| linkedExecutableOp.setVisibility( |
| sourceExecutableOps.front().getVisibility()); |
| auto executableBuilder = |
| OpBuilder::atBlockBegin(&linkedExecutableOp.getBlock()); |
| |
| // Gather all unique executable targets - we may have multiple. |
| auto executableTargetAttrs = gatherExecutableTargets(sourceExecutableOps); |
| for (auto [index, attr] : llvm::enumerate(executableTargetAttrs)) { |
| // Add our hal.executable.variant with an empty module. |
| std::string linkedVariantName = |
| executableTargetAttrs.size() == 1 |
| ? attr.getSymbolNameFragment() |
| : llvm::formatv("{0}_{1}", attr.getSymbolNameFragment(), index); |
| auto linkedTargetOp = |
| executableBuilder.create<IREE::HAL::ExecutableVariantOp>( |
| moduleOp.getLoc(), linkedVariantName, attr); |
| auto targetBuilder = OpBuilder::atBlockBegin(&linkedTargetOp.getBlock()); |
| targetBuilder.create<mlir::ModuleOp>(moduleOp.getLoc()); |
| |
| auto mergeModuleFn = [](mlir::ModuleOp sourceInnerModule, |
| mlir::ModuleOp linkedInnerModule, |
| DenseMap<StringRef, Operation *> &symbolMap) { |
| return mergeModuleInto(sourceInnerModule, linkedInnerModule, symbolMap); |
| }; |
| |
| // Try linking together all executables in moduleOp. |
| if (failed(linkExecutablesInto(moduleOp, sourceExecutableOps, |
| linkedExecutableOp, linkedTargetOp, |
| mergeModuleFn))) { |
| return signalPassFailure(); |
| } |
| } |
| } |
| }; |
| |
| } // namespace |
| |
| std::unique_ptr<OperationPass<mlir::ModuleOp>> |
| createLLVMCPULinkExecutablesPass() { |
| return std::make_unique<LLVMCPULinkExecutablesPass>(); |
| } |
| |
| } // namespace mlir::iree_compiler |