// blob: 0cb90a50beae8b011070cbb03537c4a868d57fd3 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "./ROCMTargetFeatures.h"
#include "./ROCMTargetUtils.h"
#include <cstdint>
#include <mutex>
#include "compiler/plugins/target/ROCM/ROCMTargetFeatures.h"
#include "iree-dialects/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Dialect/HAL/Target/LLVMLinkerUtils.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/compiler/Utils/ModuleUtils.h"
#include "iree/compiler/Utils/ToolUtils.h"
#include "iree/schemas/rocm_executable_def_builder.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
namespace mlir::iree_compiler::IREE::HAL {
namespace {
// Command-line configurable options for the ROCM HAL target.
//
// bindOptions() exposes the fields below as `--iree-rocm-*` flags. The target
// device/backend classes in this file hold a reference to an instance of this
// class and read the values when building target attributes and serializing
// executables.
class ROCMOptions {
public:
  // Target chip name (bound to --iree-rocm-target-chip).
  std::string targetChip = "gfx908";
  // NOTE(review): not bound to any flag and not read anywhere in this file.
  bool linkBitcode = true;
  // Minimum waves-per-execution-unit hint (bound to --iree-rocm-waves-per-eu).
  // The backend only forwards values greater than zero.
  int wavesPerEu = 0;
  // Microkernel selection (bound to --iree-rocm-enable-ukernels).
  std::string enableROCMUkernels = "none";

  // Returns the bitcode directory given via --iree-rocm-bc-dir, or the
  // platform default library directory when the flag was left empty.
  const std::string &getBitcodeDirectory() const {
    if (bitcodeDirectory.empty()) {
      return getDefaultBitcodeDirectory();
    }
    return bitcodeDirectory;
  }

  // Registers all ROCM flags with |binder| under a shared option category.
  void bindOptions(OptionsBinder &binder) {
    static llvm::cl::OptionCategory category("ROCM HAL Target");
    binder.opt<std::string>("iree-rocm-target-chip", targetChip,
                            llvm::cl::cat(category),
                            llvm::cl::desc("ROCm target chip."));
    binder.opt<std::string>("iree-rocm-bc-dir", bitcodeDirectory,
                            llvm::cl::cat(category),
                            llvm::cl::desc("Directory of ROCM Bitcode."));
    binder.opt<int>("iree-rocm-waves-per-eu", wavesPerEu,
                    llvm::cl::cat(category),
                    llvm::cl::desc("Optimization hint specifying minimum "
                                   "number of waves per execution unit."));
    // The description previously referred to the llvmcpu backend (copy-paste
    // from the CPU target); this flag controls the ROCm backend.
    binder.opt<std::string>(
        "iree-rocm-enable-ukernels", enableROCMUkernels,
        llvm::cl::cat(category),
        llvm::cl::desc(
            "Enables microkernels in the ROCm backend. May be "
            "`default`, `none`, `all`, or a comma-separated list of "
            "specific unprefixed microkernels to enable, e.g. `mmt4d`."));
  }

private:
  // Lazily computes (once) the platform library directory for "rocm".
  static const std::string &getDefaultBitcodeDirectory() {
    static std::string defaultValue =
        mlir::iree_compiler::findPlatformLibDirectory("rocm");
    return defaultValue;
  }

  // User-provided bitcode directory; empty means "use the default".
  std::string bitcodeDirectory;
};
} // namespace
// Prints |module| as textual LLVM IR into an in-memory buffer and forwards the
// text to dumpDataToPath() together with the given path components.
static void dumpModuleToPath(StringRef path, StringRef baseName,
                             StringRef suffix, StringRef extension,
                             llvm::Module &module) {
  llvm::SmallVector<char, 0> irText;
  {
    // The stream writes straight into |irText|.
    llvm::raw_svector_ostream irStream(irText);
    module.print(irStream, nullptr);
  }
  const StringRef irView(irText.data(), irText.size());
  dumpDataToPath(path, baseName, suffix, extension, irView);
}
// Runs the |targetMachine| codegen pipeline over |module| and returns the
// emitted object file contents.
//
// Returns an empty string (after printing a message to stderr) when the target
// machine reports that it cannot emit object files.
static std::string translateModuleToObj(llvm::Module &module,
                                        llvm::TargetMachine &targetMachine) {
  std::string targetObj;
  {
    llvm::raw_string_ostream stream(targetObj);
    llvm::buffer_ostream pstream(stream);
    llvm::legacy::PassManager codegenPasses;
    // addPassesToEmitFile returns true when emission of the requested file
    // type is not supported; running the pipeline in that state would not
    // produce a usable object.
    if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
                                          llvm::CodeGenFileType::ObjectFile)) {
      llvm::errs() << "Error: target machine cannot emit object files\n";
      return std::string();
    }
    codegenPasses.run(module);
    // Leaving this scope destroys the streams, which flushes the buffered
    // output into |targetObj|.
  }
  return targetObj;
}
// Runs the |targetMachine| codegen pipeline over |module| and returns the
// emitted textual assembly.
//
// Note that codegen mutates |module|; the caller in this file passes a clone.
// Returns an empty string (after printing a message to stderr) when the target
// machine reports that it cannot emit assembly files.
static std::string translateModuleToISA(llvm::Module &module,
                                        llvm::TargetMachine &targetMachine) {
  std::string targetISA;
  {
    llvm::raw_string_ostream stream(targetISA);
    llvm::buffer_ostream pstream(stream);
    llvm::legacy::PassManager codegenPasses;
    // addPassesToEmitFile returns true when emission of the requested file
    // type is not supported.
    if (targetMachine.addPassesToEmitFile(
            codegenPasses, pstream, nullptr,
            llvm::CodeGenFileType::AssemblyFile)) {
      llvm::errs() << "Error: target machine cannot emit assembly files\n";
      return std::string();
    }
    codegenPasses.run(module);
    // Leaving this scope destroys the streams, which flushes the buffered
    // output into |targetISA|.
  }
  return targetISA;
}
// Modified from lib/Target/AMDGPU/AMDGPUAttributor.cpp.
// Marks the leading arguments of |F| as `inreg`, hinting that they should be
// preloaded into SGPRs. At most 16 arguments are annotated, and annotation
// stops at the first argument carrying a `byref` or `nest` attribute.
// TODO: Query max number of user SGPRs from target machine.
static void addPreloadKernArgHint(llvm::Function *F) {
  static constexpr size_t kMaxPreloadedArgs = 16;
  const size_t candidateCount = std::min(F->arg_size(), kMaxPreloadedArgs);
  size_t argIndex = 0;
  while (argIndex != candidateCount) {
    llvm::Argument *argument = F->getArg(argIndex);
    // These attributes are incompatible with preloading; give up on this and
    // all following arguments.
    const bool canPreload =
        !argument->hasByRefAttr() && !argument->hasNestAttr();
    if (!canPreload)
      return;
    argument->addAttr(llvm::Attribute::InReg);
    ++argIndex;
  }
}
// HAL target device registered under the "rocm" ID. It produces the default
// #hal.device.target attribute, delegating the executable targets to the
// "rocm" target backend looked up in the target registry.
class ROCMTargetDevice final : public TargetDevice {
public:
  ROCMTargetDevice(const ROCMOptions &options) : options(options) {}

  // Builds the default device target: a config dictionary containing the
  // `legacy_sync` unit attribute plus the backend's default executable
  // targets.
  IREE::HAL::DeviceTargetAttr
  getDefaultDeviceTarget(MLIRContext *context,
                         const TargetRegistry &targetRegistry) const override {
    Builder builder(context);

    // Indicates that the runtime HAL driver operates only in the legacy
    // synchronous mode.
    SmallVector<NamedAttribute> deviceConfigItems;
    deviceConfigItems.emplace_back(builder.getStringAttr("legacy_sync"),
                                   builder.getUnitAttr());
    DictionaryAttr deviceConfigAttr =
        builder.getDictionaryAttr(deviceConfigItems);

    // If we had multiple target environments we would generate one target attr
    // per environment, with each setting its own environment attribute.
    SmallVector<IREE::HAL::ExecutableTargetAttr> defaultExecutableTargets;
    auto rocmBackend = targetRegistry.getTargetBackend("rocm");
    rocmBackend->getDefaultExecutableTargets(context, "rocm", deviceConfigAttr,
                                             defaultExecutableTargets);

    return IREE::HAL::DeviceTargetAttr::get(
        context, builder.getStringAttr("rocm"), deviceConfigAttr,
        defaultExecutableTargets);
  }

private:
  // Session-owned options; kept by reference.
  const ROCMOptions &options;
};
// HAL target backend registered under the "rocm" ID.
//
// It advertises a `rocm-hsaco-fb` executable target, builds the LLVMGPU
// codegen pipelines for non-external variants, and serializes each variant
// into a ROCM executable FlatBuffer embedding the HSACO image.
class ROCMTargetBackend final : public TargetBackend {
public:
  ROCMTargetBackend(const ROCMOptions &options) : options(options) {}

  std::string getLegacyDefaultDeviceID() const override { return "rocm"; }

  // Appends the single default executable target derived from the options.
  void getDefaultExecutableTargets(
      MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
      SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
      const override {
    executableTargetAttrs.push_back(getExecutableTarget(context));
  }

  // Builds the #hal.executable.target attribute for the configured chip.
  // The config dictionary carries `target_arch`, `ukernels`, `waves_per_eu`
  // (only when the option is > 0) and `mma_intrinsics` (only when the chip
  // query returns a non-null array).
  IREE::HAL::ExecutableTargetAttr
  getExecutableTarget(MLIRContext *context) const {
    Builder b(context);
    SmallVector<NamedAttribute> configItems;
    auto addConfig = [&](StringRef name, Attribute value) {
      configItems.emplace_back(b.getStringAttr(name), value);
    };

    addConfig("target_arch", b.getStringAttr(options.targetChip));
    addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels));
    if (options.wavesPerEu > 0)
      addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu));

    ArrayAttr mmaAttrs = getROCMSupportedMmaAttrs(context, options.targetChip);
    if (mmaAttrs)
      addConfig("mma_intrinsics", mmaAttrs);

    return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
        b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"),
        b.getDictionaryAttr(configItems));
  }

  // Registers the LLVM IR translations and dialects this backend relies on.
  void getDependentDialects(DialectRegistry &registry) const override {
    mlir::registerBuiltinDialectTranslation(registry);
    mlir::registerLLVMDialectTranslation(registry);
    mlir::registerROCDLDialectTranslation(registry);
    registry.insert<IREE::Codegen::IREECodegenDialect>();
    registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
    registry.insert<IREE::GPU::IREEGPUDialect>();
    registry.insert<amdgpu::AMDGPUDialect>();
  }

  void buildConfigurationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
                                      OpPassManager &passManager) override {
    // For now we disable configuration if the variant has external object
    // files.
    if (variantOp.isExternal())
      return;

    buildLLVMGPUCodegenConfigurationPassPipeline(passManager);
  }

  void buildTranslationPassPipeline(IREE::HAL::ExecutableVariantOp variantOp,
                                    OpPassManager &passManager) override {
    // For now we disable translation if the variant has external object files.
    // We could instead perform linking with those objects (if they're bitcode
    // ala libdevice.bc, etc).
    if (variantOp.isExternal())
      return;

    buildLLVMGPUCodegenPassPipeline(passManager, true);
  }

  // Performs optimizations on |module| (including LTO-style whole-program
  // ones). Inspired by code section in
  // https://github.com/openxla/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp
  //
  // Runs the default per-module O2 pipeline (with SLP vectorization disabled)
  // bracketed by IR verifier passes.
  static void optimizeModule(llvm::Module &module,
                             llvm::TargetMachine &targetMachine) {
    llvm::LoopAnalysisManager lam;
    llvm::FunctionAnalysisManager fam;
    llvm::CGSCCAnalysisManager cgam;
    llvm::ModuleAnalysisManager mam;

    fam.registerPass([&] { return targetMachine.getTargetIRAnalysis(); });

    llvm::PipelineTuningOptions pto;
    pto.SLPVectorization = false;

    llvm::PassInstrumentationCallbacks pic;

    llvm::StandardInstrumentations si(module.getContext(), false);
    si.registerCallbacks(pic, &mam);

    llvm::PassBuilder pb(&targetMachine, pto, std::nullopt, &pic);
    llvm::ModulePassManager mpm;
    pb.registerModuleAnalyses(mam);
    pb.registerCGSCCAnalyses(cgam);
    pb.registerFunctionAnalyses(fam);
    pb.registerLoopAnalyses(lam);
    pb.crossRegisterProxies(lam, fam, cgam, mam);

    llvm::OptimizationLevel ol = llvm::OptimizationLevel::O2;

    mpm.addPass(llvm::VerifierPass());
    mpm.addPass(pb.buildPerModuleDefaultPipeline(ol));
    mpm.addPass(llvm::VerifierPass());

    mpm.run(module, mam);
  }

  // Serializes |variantOp| into a hal.executable.binary op created with
  // |executableBuilder|.
  //
  // For external variants the HSACO is loaded from the single referenced
  // object. Otherwise the inner module is translated to LLVM IR, annotated
  // with AMDGPU kernel attributes, linked against user bitcode / ukernels /
  // the HIP device library, optimized, and compiled to an HSACO image. The
  // image plus entry point metadata is then packed into a ROCM executable
  // FlatBuffer. Returns failure (after emitting a diagnostic where one is
  // produced here) if any step fails.
  LogicalResult serializeExecutable(const SerializationOptions &serOptions,
                                    IREE::HAL::ExecutableVariantOp variantOp,
                                    OpBuilder &executableBuilder) override {
    ModuleOp innerModuleOp = variantOp.getInnerModule();
    auto targetAttr = variantOp.getTargetAttr();
    // The target attribute overrides the command-line chip when present.
    StringRef targetArch = options.targetChip;
    if (auto attr = getConfigStringAttr(targetAttr, "target_arch"))
      targetArch = attr->getValue();

    // We name our files after the executable name so that they are easy to
    // track both during compilation (logs/artifacts/etc), as outputs (final
    // intermediate code/binary files), and at runtime (loaded
    // libraries/symbols/etc).
    auto libraryName =
        variantOp->getParentOfType<IREE::HAL::ExecutableOp>().getName().str();

    // Collect all the entry point names.
    SmallVector<IREE::HAL::ExecutableExportOp> exportOps;
    llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOpMap;
    std::vector<std::array<int32_t, 3>> workgroupSizes;
    SmallVector<uint32_t> workgroupLocalMemories;
    int32_t subgroupSize = 64;
    for (auto exportOp : variantOp.getExportOps()) {
      exportOps.push_back(exportOp);
      exportOpMap[exportOp.getSymName()] = exportOp;

      // Start from {1, 1, 1} so that a missing or partially specified
      // workgroup size attribute never leaves an element uninitialized.
      std::array<int32_t, 3> workgroupSize = {1, 1, 1};
      if (std::optional<ArrayAttr> workgroupSizeAttr =
              exportOp.getWorkgroupSize()) {
        for (auto it : llvm::enumerate(workgroupSizeAttr.value()))
          workgroupSize[it.index()] = it.value().cast<IntegerAttr>().getInt();
      }
      workgroupSizes.push_back(workgroupSize);

      if (auto setSubgroupSize = getSubgroupSize(exportOp)) {
        // Validate the size requested by this export before adopting it. The
        // check used to inspect the previously accumulated value, which let
        // an invalid request through.
        auto requestedSubgroupSize = *setSubgroupSize;
        if (requestedSubgroupSize != 32 && requestedSubgroupSize != 64) {
          return variantOp.emitError()
                 << "invalid subgroup size " << requestedSubgroupSize;
        }
        subgroupSize = requestedSubgroupSize;
      }

      uint32_t workgroupLocalMemory = 0;
      if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
        workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
      }
      workgroupLocalMemories.push_back(workgroupLocalMemory);
    }

    std::string targetHSACO;
    if (variantOp.isExternal()) {
      if (!variantOp.getObjects().has_value()) {
        return variantOp.emitOpError()
               << "no objects defined for external variant";
      } else if (variantOp.getObjects()->getValue().size() != 1) {
        // For now we assume there will be exactly one object file.
        // In the future we will want to perform a linking step here and ideally
        // support _also_ linking in the codegen results.
        return variantOp.emitOpError() << "only one object reference is "
                                          "supported for external variants";
      }

      // Read the HSACO from the object file.
      auto objectAttr = llvm::cast<IREE::HAL::ExecutableObjectAttr>(
          variantOp.getObjects()->getValue().front());
      if (auto data = objectAttr.loadData()) {
        targetHSACO = data.value();
      } else {
        return variantOp.emitOpError()
               << "object file could not be loaded: " << objectAttr;
      }
    } else {
      // Perform the translation in a separate context to avoid any
      // multi-threading issues.
      llvm::LLVMContext context;

      auto llvmModule =
          mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
      if (!llvmModule) {
        return variantOp.emitError() << "failed to translate the MLIR LLVM "
                                        "dialect to the native llvm::Module";
      }

      for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
        auto *llvmFunc = llvmModule->getFunction(func.getName());
        if (!llvmFunc || llvmFunc->isDeclaration())
          continue;

        // Only functions with a matching export op are kernels. lookup()
        // returns a null op (without inserting into the map) for any other
        // defined function, which we skip instead of dereferencing.
        auto exportOp = exportOpMap.lookup(func.getName());
        if (!exportOp)
          continue;

        // Flattened workgroup size: product of all workgroup dimensions, or 1
        // when the export does not specify a workgroup size.
        int32_t flatWgSize = 1;
        if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
          for (auto it : llvm::enumerate(workgroupSizeAttr.value())) {
            flatWgSize *= it.value().cast<IntegerAttr>().getInt();
          }
        }

        // Try to get waves-per-eu from the export-specific translation info in
        // cases where codegen decides to override the value.
        // Otherwise, fallback to the default option.
        int64_t wavesPerEu = 0;
        IREE::Codegen::TranslationInfoAttr translationInfo =
            getTranslationInfo(exportOp);
        // The export may carry no translation info at all.
        if (translationInfo) {
          if (auto translationConfig = translationInfo.getConfiguration()) {
            if (auto attr =
                    translationConfig.getAs<IntegerAttr>("waves_per_eu")) {
              wavesPerEu = attr.getValue().getSExtValue();
            }
          }
        }
        if (wavesPerEu == 0) {
          if (auto attr = getConfigIntegerAttr(targetAttr, "waves_per_eu"))
            wavesPerEu = attr->getValue().getSExtValue();
        }

        // For GPU kernels,
        // 1. Insert AMDGPU_KERNEL calling convention.
        // 2. Insert the amdgpu-flat-work-group-size attribute with the range
        //    "1, <flattened workgroup size>".
        // 3. Insert amdgpu-waves-per-eu when a positive value was requested.
        // 4. On gfx9* targets, hint that kernel arguments be preloaded.
        llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
        std::string wgSizeRange =
            std::string("1, ") + std::to_string(flatWgSize);
        llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
        if (wavesPerEu > 0) {
          llvmFunc->addFnAttr("amdgpu-waves-per-eu",
                              std::to_string(wavesPerEu));
        }
        if (targetArch.starts_with("gfx9"))
          addPreloadKernArgHint(llvmFunc);
      }

      std::unique_ptr<llvm::TargetMachine> targetMachine;
      {
        llvm::Triple triple("amdgcn-amd-amdhsa");
        std::string error;
        const llvm::Target *target =
            llvm::TargetRegistry::lookupTarget("", triple, error);
        if (target == nullptr) {
          return variantOp.emitError() << "cannot initialize target triple";
        }
        llvm::TargetOptions opt;
        opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
        opt.UnsafeFPMath = false;
        opt.NoInfsFPMath = false;
        opt.NoNaNsFPMath = true;
        std::string features;
        if (targetArch.starts_with("gfx9")) {
          features = "+sramecc,-xnack";
        } else {
          // GFX 10 or 11.
          if (subgroupSize == 32)
            features = "+wavefrontsize32";
          if (subgroupSize == 64)
            features = "+wavefrontsize64";
        }

        targetMachine.reset(target->createTargetMachine(
            triple.str(), targetArch, features, opt, llvm::Reloc::Model::PIC_,
            std::nullopt, llvm::CodeGenOptLevel::Aggressive));

        if (targetMachine == nullptr) {
          return variantOp.emitError() << "cannot initialize target machine";
        }
      }

      llvmModule->setDataLayout(targetMachine->createDataLayout());

      // Mark every function currently in the module as always-inline.
      for (llvm::Function &f : llvmModule->functions())
        f.addFnAttr(llvm::Attribute::AlwaysInline);

      // Link user-provided modules.
      llvm::Linker linker(*llvmModule);
      if (failed(linkCmdlineBitcodeFiles(
              variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc,
              *targetMachine, llvmModule->getContext()))) {
        return failure();
      }

      // Link module to any enabled ukernels.
      const std::string &bitcodeDirectory = options.getBitcodeDirectory();
      StringRef enabledUkernels;
      if (auto attr = getConfigStringAttr(targetAttr, "ukernels"))
        enabledUkernels = attr->getValue();
      if (!enabledUkernels.empty() && enabledUkernels != "none") {
        if (failed(linkUkernelBitcodeFiles(
                variantOp.getLoc(), llvmModule.get(), enabledUkernels,
                targetArch, bitcodeDirectory, llvm::Linker::OverrideFromSrc,
                *targetMachine))) {
          return failure();
        }
      }

      // Link module to HIP device library.
      if (bitcodeDirectory.empty()) {
        return variantOp.emitError()
               << "cannot find ROCM bitcode files. Check your installation "
                  "consistency and in the worst case, set "
                  "--iree-rocm-bc-dir= to a path on your system.";
      }
      if (failed(linkHIPBitcodeIfNeeded(variantOp.getLoc(), llvmModule.get(),
                                        targetArch, bitcodeDirectory))) {
        return failure();
      }

      // Sets HIP platform globals based on the target architecture.
      if (failed(setHIPGlobals(variantOp.getLoc(), llvmModule.get(),
                               targetArch))) {
        return failure();
      }

      if (!serOptions.dumpIntermediatesPath.empty()) {
        dumpModuleToPath(serOptions.dumpIntermediatesPath,
                         serOptions.dumpBaseName, variantOp.getName(),
                         ".linked.ll", *llvmModule);
      }

      // Run LLVM optimization passes.
      optimizeModule(*llvmModule, *targetMachine);
      if (!serOptions.dumpIntermediatesPath.empty()) {
        dumpModuleToPath(serOptions.dumpIntermediatesPath,
                         serOptions.dumpBaseName, variantOp.getName(),
                         ".optimized.ll", *llvmModule);
      }

      // Dump the assembly output.
      if (!serOptions.dumpIntermediatesPath.empty()) {
        // Codegen runs on a clone so that the module used for object emission
        // below is left untouched.
        auto moduleCopy = llvm::CloneModule(*llvmModule);
        if (!moduleCopy) {
          llvm::errs() << "Error: cloning LLVM IR failed\n";
          return failure();
        }
        std::string targetISA =
            translateModuleToISA(*moduleCopy.get(), *targetMachine);
        dumpDataToPath(serOptions.dumpIntermediatesPath,
                       serOptions.dumpBaseName, variantOp.getName(), ".rocmasm",
                       targetISA);
      }

      // Serialize hsaco kernel into the binary that we will embed in the
      // final FlatBuffer.
      std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine);
      targetHSACO = createHsaco(variantOp.getLoc(), targetObj, libraryName);
      if (targetHSACO.empty())
        return failure();
    }

    if (!serOptions.dumpBinariesPath.empty()) {
      dumpDataToPath(serOptions.dumpBinariesPath, serOptions.dumpBaseName,
                     variantOp.getName(), ".hsaco", targetHSACO);
    }

    iree_compiler::FlatbufferBuilder builder;
    iree_hal_rocm_ExecutableDef_start_as_root(builder);

    // Attach embedded source file contents.
    // Sources are serialized in reverse and the resulting refs are reversed
    // back afterwards, so |sourceFileRefs| ends up in the attribute's order.
    SmallVector<iree_hal_rocm_SourceFileDef_ref_t> sourceFileRefs;
    if (auto sourcesAttr = variantOp.getSourcesAttr()) {
      for (auto sourceAttr : llvm::reverse(sourcesAttr.getValue())) {
        if (auto resourceAttr = dyn_cast_if_present<DenseResourceElementsAttr>(
                sourceAttr.getValue())) {
          auto filenameRef = builder.createString(sourceAttr.getName());
          auto contentRef = builder.streamUint8Vec([&](llvm::raw_ostream &os) {
            auto blobData = resourceAttr.getRawHandle().getBlob()->getData();
            os.write(blobData.data(), blobData.size());
            return true;
          });
          sourceFileRefs.push_back(iree_hal_rocm_SourceFileDef_create(
              builder, filenameRef, contentRef));
        }
      }
      std::reverse(sourceFileRefs.begin(), sourceFileRefs.end());
    }

    SmallVector<StringRef> entryPointNames;
    SmallVector<iree_hal_rocm_FileLineLocDef_ref_t> sourceLocationRefs;
    entryPointNames.resize(exportOps.size());
    for (auto exportOp : exportOps) {
      auto ordinalAttr = exportOp.getOrdinalAttr();
      if (!ordinalAttr) {
        return mlir::emitError(exportOp.getLoc())
               << "could not compile rocm binary: export op is missing ordinal";
      }
      // NOTE(review): the ordinal is used as an index without a bounds check;
      // this assumes ordinals are dense in [0, exportOps.size()) — confirm.
      int64_t ordinal = ordinalAttr.getInt();
      entryPointNames[ordinal] = exportOp.getName();

      // Optional source location information for debugging/profiling.
      if (serOptions.debugLevel >= 1) {
        // We only ever resize to the maximum -- so all previous data will
        // be kept as-is.
        sourceLocationRefs.resize(exportOps.size());
        if (auto loc = findFirstFileLoc(exportOp.getLoc())) {
          auto filenameRef = builder.createString(loc->getFilename());
          sourceLocationRefs[ordinal] = iree_hal_rocm_FileLineLocDef_create(
              builder, filenameRef, loc->getLine());
        }
      }
    }

    // Optional compilation stage source files.
    SmallVector<iree_hal_rocm_StageLocationsDef_ref_t> stageLocationsRefs;
    if (serOptions.debugLevel >= 3) {
      for (auto exportOp : exportOps) {
        SmallVector<iree_hal_rocm_StageLocationDef_ref_t> stageLocationRefs;
        if (auto locsAttr = exportOp.getSourceLocsAttr()) {
          for (auto locAttr : locsAttr.getValue()) {
            if (auto loc =
                    findFirstFileLoc(cast<LocationAttr>(locAttr.getValue()))) {
              auto stageNameRef = builder.createString(locAttr.getName());
              auto filenameRef = builder.createString(loc->getFilename());
              stageLocationRefs.push_back(iree_hal_rocm_StageLocationDef_create(
                  builder, stageNameRef,
                  iree_hal_rocm_FileLineLocDef_create(builder, filenameRef,
                                                      loc->getLine())));
            }
          }
        }
        if (!stageLocationRefs.empty()) {
          // We only ever resize to the maximum -- so all previous data will
          // be kept as-is.
          stageLocationsRefs.resize(exportOps.size());
          int64_t ordinal = exportOp.getOrdinalAttr().getInt();
          stageLocationsRefs[ordinal] = iree_hal_rocm_StageLocationsDef_create(
              builder, builder.createOffsetVecDestructive(stageLocationRefs));
        }
      }
    }

    auto hsacoRef = flatbuffers_string_create(builder, targetHSACO.c_str(),
                                              targetHSACO.size());

    auto entryPointsRef = builder.createStringVec(entryPointNames);
    // NOTE(review): block sizes and local memory sizes are emitted in export
    // iteration order while entry point names are placed by ordinal; this
    // presumably relies on both orders agreeing — confirm.
    iree_hal_rocm_BlockSizeDef_vec_start(builder);
    auto blockSizes = workgroupSizes.begin();
    for (int i = 0, e = entryPointNames.size(); i < e; ++i) {
      iree_hal_rocm_BlockSizeDef_vec_push_create(
          builder, (*blockSizes)[0], (*blockSizes)[1], (*blockSizes)[2]);
      ++blockSizes;
    }
    auto workgroupLocalMemoriesRef =
        builder.createInt32Vec(workgroupLocalMemories);
    auto blockSizesRef = iree_hal_rocm_BlockSizeDef_vec_end(builder);
    iree_hal_rocm_ExecutableDef_entry_points_add(builder, entryPointsRef);
    iree_hal_rocm_ExecutableDef_block_sizes_add(builder, blockSizesRef);
    iree_hal_rocm_ExecutableDef_shared_memory_sizes_add(
        builder, workgroupLocalMemoriesRef);
    iree_hal_rocm_ExecutableDef_hsaco_image_add(builder, hsacoRef);
    if (!sourceLocationRefs.empty()) {
      auto sourceLocationsRef =
          builder.createOffsetVecDestructive(sourceLocationRefs);
      iree_hal_rocm_ExecutableDef_source_locations_add(builder,
                                                       sourceLocationsRef);
    }
    if (!stageLocationsRefs.empty()) {
      auto stageLocationsRef =
          builder.createOffsetVecDestructive(stageLocationsRefs);
      iree_hal_rocm_ExecutableDef_stage_locations_add(builder,
                                                      stageLocationsRef);
    }
    if (!sourceFileRefs.empty()) {
      auto sourceFilesRef = builder.createOffsetVecDestructive(sourceFileRefs);
      iree_hal_rocm_ExecutableDef_source_files_add(builder, sourceFilesRef);
    }
    iree_hal_rocm_ExecutableDef_end_as_root(builder);

    // Add the binary data to the target executable.
    executableBuilder.create<iree_compiler::IREE::HAL::ExecutableBinaryOp>(
        variantOp.getLoc(), variantOp.getSymName(),
        variantOp.getTarget().getFormat(),
        builder.getBufferAttr(executableBuilder.getContext()));

    return success();
  }

private:
  // Session-owned options; kept by reference.
  const ROCMOptions &options;
};
namespace {
// Plugin session that registers the "rocm" HAL target device and target
// backend. Both factories hand the session's options to the created object.
struct ROCMSession
    : public PluginSession<ROCMSession, ROCMOptions,
                           PluginActivationPolicy::DefaultActivated> {
  void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) {
    // #hal.device.target<"rocm", ...
    auto createDevice = [this]() {
      return std::make_shared<ROCMTargetDevice>(options);
    };
    targets.add("rocm", createDevice);
  }

  void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) {
    // #hal.executable.target<"rocm", ...
    auto createBackend = [this]() {
      // Initialize the LLVM AMDGPU target components before the backend is
      // handed out.
      LLVMInitializeAMDGPUTarget();
      LLVMInitializeAMDGPUTargetMC();
      LLVMInitializeAMDGPUTargetInfo();
      LLVMInitializeAMDGPUAsmParser();
      LLVMInitializeAMDGPUAsmPrinter();
      return std::make_shared<ROCMTargetBackend>(options);
    };
    targets.add("rocm", createBackend);
  }
};
} // namespace
} // namespace mlir::iree_compiler::IREE::HAL
// C-linkage entry point that registers the ROCM session with |registrar| under
// the plugin ID "hal_target_rocm". Always returns true.
extern "C" bool iree_register_compiler_plugin_hal_target_rocm(
    mlir::iree_compiler::PluginRegistrar *registrar) {
  using mlir::iree_compiler::IREE::HAL::ROCMSession;
  registrar->registerPlugin<ROCMSession>("hal_target_rocm");
  return true;
}
// Invokes the IREE option-flags macro for ROCMOptions. The macro is defined
// outside this file; presumably it provides the global flag storage that
// ROCMOptions::bindOptions() populates — verify against its definition.
IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCMOptions);