Enabling external HIP HSACO support a la CUDA external PTX support. (#16830)
This required fixing up ROCMTarget a bit, as it was ignoring the target
configuration specified in the IR and just using flags.
The `iree-rocm-link-bc` flag was removed because there were no cases in
which it was valid to be false.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 8759755..0cb90a5 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -69,16 +69,14 @@
static llvm::cl::OptionCategory category("ROCM HAL Target");
binder.opt<std::string>("iree-rocm-target-chip", targetChip,
llvm::cl::cat(category),
- llvm::cl::desc("ROCm target Chip"));
- binder.opt<bool>("iree-rocm-link-bc", linkBitcode, llvm::cl::cat(category),
- llvm::cl::desc("Whether to try Linking to AMD Bitcodes"));
+ llvm::cl::desc("ROCm target chip."));
binder.opt<std::string>("iree-rocm-bc-dir", bitcodeDirectory,
llvm::cl::cat(category),
- llvm::cl::desc("Directory of ROCM Bitcode"));
+ llvm::cl::desc("Directory of ROCM Bitcode."));
binder.opt<int>("iree-rocm-waves-per-eu", wavesPerEu,
llvm::cl::cat(category),
llvm::cl::desc("Optimization hint specifying minimum "
- "number of waves per execution unit"));
+ "number of waves per execution unit."));
binder.opt<std::string>(
"iree-rocm-enable-ukernels", enableROCMUkernels,
llvm::cl::cat(category),
@@ -206,11 +204,12 @@
addConfig("target_arch", b.getStringAttr(options.targetChip));
addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels));
+ if (options.wavesPerEu > 0)
+ addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu));
ArrayAttr mmaAttrs = getROCMSupportedMmaAttrs(context, options.targetChip);
- if (mmaAttrs) {
+ if (mmaAttrs)
addConfig("mma_intrinsics", mmaAttrs);
- }
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"),
@@ -288,10 +287,11 @@
LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
- // Perform the translation in a separate context to avoid any
- // multi-threading issues.
- llvm::LLVMContext context;
- const std::string &bitcodeDirectory = options.getBitcodeDirectory();
+ ModuleOp innerModuleOp = variantOp.getInnerModule();
+ auto targetAttr = variantOp.getTargetAttr();
+ StringRef targetArch = options.targetChip;
+ if (auto attr = getConfigStringAttr(targetAttr, "target_arch"))
+ targetArch = attr->getValue();
// We name our files after the executable name so that they are easy to
// track both during compilation (logs/artifacts/etc), as outputs (final
@@ -300,43 +300,25 @@
auto libraryName =
variantOp->getParentOfType<IREE::HAL::ExecutableOp>().getName().str();
- ModuleOp innerModuleOp = variantOp.getInnerModule();
-
- auto llvmModule =
- mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
- if (!llvmModule) {
- return variantOp.emitError() << "failed to translate the MLIR LLVM "
- "dialect to the native llvm::Module";
- }
-
// Collect all the entry point names.
SmallVector<IREE::HAL::ExecutableExportOp> exportOps;
llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOpMap;
- for (auto op : variantOp.getExportOps()) {
- exportOps.push_back(op);
- exportOpMap[op.getSymName()] = op;
- }
std::vector<std::array<int32_t, 3>> workgroupSizes;
SmallVector<uint32_t> workgroupLocalMemories;
int32_t subgroupSize = 64;
- StringRef subTarget = options.targetChip;
- StringRef GFX9("gfx9");
- for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
- int32_t flatWgSize = 1;
- auto *llvmFunc = llvmModule->getFunction(func.getName());
- if (llvmFunc->isDeclaration())
- continue;
+ for (auto exportOp : variantOp.getExportOps()) {
+ exportOps.push_back(exportOp);
+ exportOpMap[exportOp.getSymName()] = exportOp;
+
std::array<int32_t, 3> workgroupSize;
- auto exportOp = exportOpMap[func.getName()];
if (std::optional<ArrayAttr> workgroupSizeAttr =
exportOp.getWorkgroupSize()) {
- for (auto it : llvm::enumerate(workgroupSizeAttr.value())) {
+ for (auto it : llvm::enumerate(workgroupSizeAttr.value()))
workgroupSize[it.index()] = it.value().cast<IntegerAttr>().getInt();
- flatWgSize *= it.value().cast<IntegerAttr>().getInt();
- }
} else {
workgroupSize = {1, 1, 1};
}
+ workgroupSizes.push_back(workgroupSize);
if (auto setSubgroupSize = getSubgroupSize(exportOp)) {
if (subgroupSize != 32 && subgroupSize != 64) {
@@ -346,132 +328,206 @@
subgroupSize = *setSubgroupSize;
}
- int64_t wavesPerEu = options.wavesPerEu;
- IREE::Codegen::TranslationInfoAttr translationInfo =
- getTranslationInfo(exportOp);
- if (auto translationConfig = translationInfo.getConfiguration()) {
- if (auto attr = dyn_cast_or_null<IntegerAttr>(
- translationConfig.get("amdgpu-waves-per-eu"))) {
- wavesPerEu = attr.getValue().getSExtValue();
- }
- }
-
- workgroupSizes.push_back(workgroupSize);
uint32_t workgroupLocalMemory = 0;
if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
}
workgroupLocalMemories.push_back(workgroupLocalMemory);
- // For GPU kernels,
- // 1. Insert AMDGPU_KERNEL calling convention.
- // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
- // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL
- // and HIP kernels per Clang)
- llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
- std::string wgSizeRange = std::string("1, ") + std::to_string(flatWgSize);
- llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
- if (wavesPerEu > 0)
- llvmFunc->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEu));
- if (subTarget.starts_with(GFX9))
- addPreloadKernArgHint(llvmFunc);
}
- std::unique_ptr<llvm::TargetMachine> targetMachine;
- {
- llvm::Triple triple("amdgcn-amd-amdhsa");
- std::string error;
- const llvm::Target *target =
- llvm::TargetRegistry::lookupTarget("", triple, error);
- if (target == nullptr) {
- return variantOp.emitError() << "cannot initialize target triple";
+ std::string targetHSACO;
+ if (variantOp.isExternal()) {
+ if (!variantOp.getObjects().has_value()) {
+ return variantOp.emitOpError()
+ << "no objects defined for external variant";
+ } else if (variantOp.getObjects()->getValue().size() != 1) {
+ // For now we assume there will be exactly one object file.
+ // In the future we will want to perform a linking step here and ideally
+ // support _also_ linking in the codegen results.
+ return variantOp.emitOpError() << "only one object reference is "
+ "supported for external variants";
}
- llvm::TargetOptions opt;
- opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
- opt.UnsafeFPMath = false;
- opt.NoInfsFPMath = false;
- opt.NoNaNsFPMath = true;
- std::string features;
- if (subTarget.starts_with(GFX9)) {
- features = "+sramecc,-xnack";
+
+ // Read the HSACO from the object file.
+ auto objectAttr = llvm::cast<IREE::HAL::ExecutableObjectAttr>(
+ variantOp.getObjects()->getValue().front());
+ if (auto data = objectAttr.loadData()) {
+ targetHSACO = data.value();
} else {
- // GFX 10 or 11.
- if (subgroupSize == 32)
- features = "+wavefrontsize32";
- if (subgroupSize == 64)
- features = "+wavefrontsize64";
+ return variantOp.emitOpError()
+ << "object file could not be loaded: " << objectAttr;
+ }
+ } else {
+ // Perform the translation in a separate context to avoid any
+ // multi-threading issues.
+ llvm::LLVMContext context;
+
+ auto llvmModule =
+ mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
+ if (!llvmModule) {
+ return variantOp.emitError() << "failed to translate the MLIR LLVM "
+ "dialect to the native llvm::Module";
}
- targetMachine.reset(target->createTargetMachine(
- triple.str(), options.targetChip, features, opt,
- llvm::Reloc::Model::PIC_, std::nullopt,
- llvm::CodeGenOptLevel::Aggressive));
+ for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
+ int32_t flatWgSize = 1;
+ auto *llvmFunc = llvmModule->getFunction(func.getName());
+ if (llvmFunc->isDeclaration())
+ continue;
+ auto exportOp = exportOpMap[func.getName()];
+ if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
+ for (auto it : llvm::enumerate(workgroupSizeAttr.value())) {
+ flatWgSize *= it.value().cast<IntegerAttr>().getInt();
+ }
+ }
- if (targetMachine == nullptr) {
- return variantOp.emitError() << "cannot initialize target machine";
+ // Try to get waves-per-eu from the export-specific translation info in
+ // cases where codegen decides to override the value.
+ // Otherwise, fallback to the default option.
+ int64_t wavesPerEu = 0;
+ IREE::Codegen::TranslationInfoAttr translationInfo =
+ getTranslationInfo(exportOp);
+ if (auto translationConfig = translationInfo.getConfiguration()) {
+ if (auto attr =
+ translationConfig.getAs<IntegerAttr>("waves_per_eu")) {
+ wavesPerEu = attr.getValue().getSExtValue();
+ }
+ }
+ if (wavesPerEu == 0) {
+ if (auto attr = getConfigIntegerAttr(targetAttr, "waves_per_eu"))
+ wavesPerEu = attr->getValue().getSExtValue();
+ }
+
+ // For GPU kernels,
+ // 1. Insert AMDGPU_KERNEL calling convention.
+ // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
+ // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on
+ // OpenCL and HIP kernels per Clang)
+ llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
+ std::string wgSizeRange =
+ std::string("1, ") + std::to_string(flatWgSize);
+ llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
+ if (wavesPerEu > 0) {
+ llvmFunc->addFnAttr("amdgpu-waves-per-eu",
+ std::to_string(wavesPerEu));
+ }
+ if (targetArch.starts_with("gfx9"))
+ addPreloadKernArgHint(llvmFunc);
}
- }
- llvmModule->setDataLayout(targetMachine->createDataLayout());
+ std::unique_ptr<llvm::TargetMachine> targetMachine;
+ {
+ llvm::Triple triple("amdgcn-amd-amdhsa");
+ std::string error;
+ const llvm::Target *target =
+ llvm::TargetRegistry::lookupTarget("", triple, error);
+ if (target == nullptr) {
+ return variantOp.emitError() << "cannot initialize target triple";
+ }
+ llvm::TargetOptions opt;
+ opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
+ opt.UnsafeFPMath = false;
+ opt.NoInfsFPMath = false;
+ opt.NoNaNsFPMath = true;
+ std::string features;
+ if (targetArch.starts_with("gfx9")) {
+ features = "+sramecc,-xnack";
+ } else {
+ // GFX 10 or 11.
+ if (subgroupSize == 32)
+ features = "+wavefrontsize32";
+ if (subgroupSize == 64)
+ features = "+wavefrontsize64";
+ }
- for (llvm::Function &f : llvmModule->functions())
- f.addFnAttr(llvm::Attribute::AlwaysInline);
+ targetMachine.reset(target->createTargetMachine(
+ triple.str(), targetArch, features, opt, llvm::Reloc::Model::PIC_,
+ std::nullopt, llvm::CodeGenOptLevel::Aggressive));
- // Link user modules and libdevice (if required).
- // Note that linking order matters:
- llvm::Linker linker(*llvmModule);
- if (failed(linkCmdlineBitcodeFiles(
- variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc,
- *targetMachine, llvmModule->getContext()))) {
- return failure();
- }
+ if (targetMachine == nullptr) {
+ return variantOp.emitError() << "cannot initialize target machine";
+ }
+ }
- if (!options.enableROCMUkernels.empty() &&
- options.enableROCMUkernels != "none") {
- auto enabledUkernelsStr = StringRef(options.enableROCMUkernels);
- if (failed(linkUkernelBCFiles(
- variantOp.getLoc(), llvmModule.get(), enabledUkernelsStr,
- options.targetChip, bitcodeDirectory,
- llvm::Linker::OverrideFromSrc, *targetMachine)))
+ llvmModule->setDataLayout(targetMachine->createDataLayout());
+
+ for (llvm::Function &f : llvmModule->functions())
+ f.addFnAttr(llvm::Attribute::AlwaysInline);
+
+ // Link user-provided modules.
+ llvm::Linker linker(*llvmModule);
+ if (failed(linkCmdlineBitcodeFiles(
+ variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc,
+ *targetMachine, llvmModule->getContext()))) {
return failure();
- }
- // Link module to Device Library
- if (options.linkBitcode) {
+ }
+
+ // Link module to any enabled ukernels.
+ const std::string &bitcodeDirectory = options.getBitcodeDirectory();
+ StringRef enabledUkernels;
+ if (auto attr = getConfigStringAttr(targetAttr, "ukernels"))
+ enabledUkernels = attr->getValue();
+ if (!enabledUkernels.empty() && enabledUkernels != "none") {
+ if (failed(linkUkernelBitcodeFiles(
+ variantOp.getLoc(), llvmModule.get(), enabledUkernels,
+ targetArch, bitcodeDirectory, llvm::Linker::OverrideFromSrc,
+ *targetMachine))) {
+ return failure();
+ }
+ }
+
+ // Link module to HIP device library.
if (bitcodeDirectory.empty()) {
return variantOp.emitError()
<< "cannot find ROCM bitcode files. Check your installation "
- "consistency and in the worst case, set --iree-rocm-bc-dir= "
- "to an explicit location on your system.";
+ "consistency and in the worst case, set "
+ "--iree-rocm-bc-dir= to a path on your system.";
}
- if (failed(linkROCDLIfNecessary(variantOp.getLoc(), llvmModule.get(),
- options.targetChip, bitcodeDirectory)))
+ if (failed(linkHIPBitcodeIfNeeded(variantOp.getLoc(), llvmModule.get(),
+ targetArch, bitcodeDirectory))) {
return failure();
- }
- if (!serOptions.dumpIntermediatesPath.empty()) {
- dumpModuleToPath(serOptions.dumpIntermediatesPath,
- serOptions.dumpBaseName, variantOp.getName(),
- ".linked.ll", *llvmModule);
- }
- // Add Optimize module
- optimizeModule(*llvmModule, *targetMachine);
- // Store optimized ll.
- if (!serOptions.dumpIntermediatesPath.empty()) {
- dumpModuleToPath(serOptions.dumpIntermediatesPath,
- serOptions.dumpBaseName, variantOp.getName(),
- ".optimized.ll", *llvmModule);
- }
- // Serialize hsaco kernel into the binary that we will embed in the
- // final FlatBuffer.
- std::unique_ptr<llvm::Module> moduleCopy;
- if (!serOptions.dumpIntermediatesPath.empty()) {
- moduleCopy = llvm::CloneModule(*llvmModule);
- if (!moduleCopy)
- llvm::errs() << "Error: cloning LLVM IR failed\n";
- }
- std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine);
- std::string targetHSACO =
- createHsaco(variantOp.getLoc(), targetObj, libraryName);
- if (targetHSACO.empty()) {
- return failure();
+ }
+
+ // Sets HIP platform globals based on the target architecture.
+ if (failed(setHIPGlobals(variantOp.getLoc(), llvmModule.get(),
+ targetArch))) {
+ return failure();
+ }
+
+ if (!serOptions.dumpIntermediatesPath.empty()) {
+ dumpModuleToPath(serOptions.dumpIntermediatesPath,
+ serOptions.dumpBaseName, variantOp.getName(),
+ ".linked.ll", *llvmModule);
+ }
+
+ // Run LLVM optimization passes.
+ optimizeModule(*llvmModule, *targetMachine);
+ if (!serOptions.dumpIntermediatesPath.empty()) {
+ dumpModuleToPath(serOptions.dumpIntermediatesPath,
+ serOptions.dumpBaseName, variantOp.getName(),
+ ".optimized.ll", *llvmModule);
+ }
+
+ // Dump the assembly output.
+ if (!serOptions.dumpIntermediatesPath.empty()) {
+ auto moduleCopy = llvm::CloneModule(*llvmModule);
+ if (!moduleCopy) {
+ llvm::errs() << "Error: cloning LLVM IR failed\n";
+ return failure();
+ }
+ std::string targetISA =
+ translateModuleToISA(*moduleCopy.get(), *targetMachine);
+ dumpDataToPath(serOptions.dumpIntermediatesPath,
+ serOptions.dumpBaseName, variantOp.getName(), ".rocmasm",
+ targetISA);
+ }
+
+ // Serialize hsaco kernel into the binary that we will embed in the
+ // final FlatBuffer.
+ std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine);
+ targetHSACO = createHsaco(variantOp.getLoc(), targetObj, libraryName);
+ if (targetHSACO.empty())
+ return failure();
}
if (!serOptions.dumpBinariesPath.empty()) {
@@ -598,13 +654,6 @@
variantOp.getTarget().getFormat(),
builder.getBufferAttr(executableBuilder.getContext()));
- if (!serOptions.dumpIntermediatesPath.empty()) {
- std::string targetISA =
- translateModuleToISA(*moduleCopy.get(), *targetMachine);
- dumpDataToPath(serOptions.dumpIntermediatesPath, serOptions.dumpBaseName,
- variantOp.getName(), ".rocmasm", targetISA);
- }
-
return success();
}
diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
index 12b7861..4663126 100644
--- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
@@ -14,6 +14,7 @@
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Path.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/SourceMgr.h"
@@ -24,33 +25,16 @@
namespace mlir::iree_compiler::IREE::HAL {
-//===========Link LLVM Module to ROCDL Start===================/
-// Inspiration of code from this section comes from XLA Kernel Gen Project
-// https://github.com/openxla/xla/blob/main/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
-
-bool couldNeedDeviceBitcode(const llvm::Module &module) {
- for (const llvm::Function &function : module.functions()) {
- // The list of prefixes should be in sync with library functions used in
- // target_util.cc.
- if (!function.isIntrinsic() && function.isDeclaration() &&
- (function.getName().starts_with("__ocml_") ||
- function.getName().starts_with("__ockl_"))) {
- return true;
- }
- }
- return false;
-}
-
-std::unique_ptr<llvm::Module> loadIRModule(Location loc,
- const std::string &filename,
- llvm::LLVMContext *llvm_context) {
+static std::unique_ptr<llvm::Module>
+loadIRModule(Location loc, const std::string &filename,
+ llvm::LLVMContext *llvm_context) {
llvm::SMDiagnostic diagnostic;
std::unique_ptr<llvm::Module> module(
llvm::parseIRFile(llvm::StringRef(filename.data(), filename.size()),
diagnostic, *llvm_context));
if (!module) {
- mlir::emitError(loc) << "error loading ROCDL LLVM module: "
+ mlir::emitError(loc) << "error loading HIP LLVM module: "
<< diagnostic.getFilename().str() << ":"
<< diagnostic.getLineNo() << ":"
<< diagnostic.getColumnNo() << ": "
@@ -61,26 +45,27 @@
return module;
}
-LogicalResult
-linkWithBitcodeVector(Location loc, llvm::Module *module,
- const std::vector<std::string> &bitcode_path_vector) {
+static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module,
+ ArrayRef<std::string> bitcodePaths) {
+ if (bitcodePaths.empty())
+ return success();
llvm::Linker linker(*module);
-
- for (auto &bitcode_path : bitcode_path_vector) {
- if (!(llvm::sys::fs::exists(bitcode_path)))
+ for (auto &bitcodePath : bitcodePaths) {
+ if (!llvm::sys::fs::exists(bitcodePath)) {
return mlir::emitError(loc)
<< "AMD bitcode module is required by this module but was "
"not found at "
- << bitcode_path;
- std::unique_ptr<llvm::Module> bitcode_module =
- loadIRModule(loc, bitcode_path, &module->getContext());
- if (!bitcode_module)
+ << bitcodePath;
+ }
+ std::unique_ptr<llvm::Module> bitcodeModule =
+ loadIRModule(loc, bitcodePath, &module->getContext());
+ if (!bitcodeModule)
return failure();
// Ignore the data layout of the module we're importing. This avoids a
// warning from the linker.
- bitcode_module->setDataLayout(module->getDataLayout());
+ bitcodeModule->setDataLayout(module->getDataLayout());
if (linker.linkInModule(
- std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded,
+ std::move(bitcodeModule), llvm::Linker::Flags::LinkOnlyNeeded,
[](llvm::Module &M, const llvm::StringSet<> &GVS) {
llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
return !GV.hasName() || (GVS.count(GV.getName()) == 0);
@@ -92,10 +77,10 @@
return success();
}
-LogicalResult linkPathBitcodeFiles(Location loc, llvm::Linker &linker,
- unsigned linkerFlags, StringRef path,
- llvm::TargetMachine &targetMachine,
- llvm::LLVMContext &context) {
+static LogicalResult linkBitcodeFile(Location loc, llvm::Linker &linker,
+ unsigned linkerFlags, StringRef path,
+ llvm::TargetMachine &targetMachine,
+ llvm::LLVMContext &context) {
auto bitcodeBufferRef = llvm::MemoryBuffer::getFile(path);
if (auto ec = bitcodeBufferRef.getError()) {
return mlir::emitError(loc) << "failed reading user bitcode file `" << path
@@ -134,24 +119,9 @@
return success();
}
-static std::vector<std::string> getROCDLPaths(std::string targetChip,
- std::string bitCodeDir) {
- // AMDGPU bitcodes.
- static const std::vector<std::string> rocdlFilenames({"ocml.bc", "ockl.bc"});
-
- // Construct full path to ROCDL bitcode libraries.
- std::vector<std::string> result;
- std::string app = "/";
- for (auto &filename : rocdlFilenames) {
- result.push_back(bitCodeDir + app + filename);
- }
- return result;
-}
-
static std::vector<std::string> getUkernelPaths(StringRef enabledUkernelsStr,
StringRef targetChip,
- StringRef bitCodeDir) {
- // AMD bitcodes.
+ StringRef bitcodePath) {
std::vector<std::string> selectedUkernelNames;
if (enabledUkernelsStr == "all") {
const char *allUkernelNames[] = {"argmax"};
@@ -173,7 +143,7 @@
for (auto &kernelName : selectedUkernelNames) {
std::string filename =
"rocm_" + kernelName + "_ukernel_" + targetChip.str();
- result.push_back(bitCodeDir.str() + app + filename + ".bc");
+ result.push_back(bitcodePath.str() + app + filename + ".bc");
}
return result;
}
@@ -191,13 +161,13 @@
APInt(globalValue->getValueType()->getIntegerBitWidth(), newValue)));
}
-static LogicalResult linkModuleWithGlobal(Location loc, llvm::Module *module,
- std::string &targetChip) {
+LogicalResult setHIPGlobals(Location loc, llvm::Module *module,
+ StringRef targetChip) {
// Link target chip ISA version as global.
const int kLenOfChipPrefix = 3;
- std::string chipId = targetChip.substr(kLenOfChipPrefix);
+ auto chipId = targetChip.substr(kLenOfChipPrefix);
// i.e gfx90a -> 9000 series.
- int chipArch = stoi(chipId.substr(0, chipId.length() - 1)) * 100;
+ int chipArch = stoi(chipId.substr(0, chipId.size() - 1).str()) * 100;
// Oldest GFX arch supported is gfx60x.
if (chipArch < 6000)
return failure();
@@ -207,8 +177,8 @@
// Get chip code from suffix. i.e gfx1103 -> `3`.
// gfx90a -> `a` == `10`.
// gfx90c -> `c` == `12`.
- std::string chipSuffix = chipId.substr(chipId.length() - 1);
- uint32_t chipCode;
+ auto chipSuffix = chipId.substr(chipId.size() - 1);
+ uint32_t chipCode = 0;
if (chipSuffix == "a") {
chipCode = chipArch + 10;
} else if (chipSuffix == "c") {
@@ -218,7 +188,7 @@
return mlir::emitError(loc)
<< "error linking module with globals: unrecognized chip suffix '"
<< chipSuffix << "' for " << targetChip;
- chipCode = chipArch + stoi(chipSuffix);
+ chipCode = chipArch + stoi(chipSuffix.str());
}
auto *int32Type = llvm::Type::getInt32Ty(module->getContext());
overridePlatformGlobal(module, "__oclc_ISA_version", chipCode, int32Type);
@@ -235,30 +205,45 @@
overridePlatformGlobal(module, globalParam.first, globalParam.second,
boolType);
}
+
return success();
}
-// Links ROCm-Device-Libs into the given module if the module needs it.
-LogicalResult linkROCDLIfNecessary(Location loc, llvm::Module *module,
- std::string targetChip,
- std::string bitCodeDir) {
- if (!couldNeedDeviceBitcode(*module))
- return success();
- if (!succeeded(HAL::linkWithBitcodeVector(
- loc, module, getROCDLPaths(targetChip, bitCodeDir))))
- return failure();
- if (!succeeded(HAL::linkModuleWithGlobal(loc, module, targetChip)))
- return failure();
- return success();
+LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module,
+ StringRef targetChip,
+ StringRef bitcodePath) {
+ bool usesOCML = false;
+ bool usesOCKL = false;
+ for (const llvm::Function &function : module->functions()) {
+ if (!function.isIntrinsic() && function.isDeclaration()) {
+ auto functionName = function.getName();
+ if (functionName.starts_with("__ocml_"))
+ usesOCML = true;
+ else if (functionName.starts_with("__ockl_"))
+ usesOCKL = true;
+ }
+ }
+
+ // Link externally-provided bitcode files when used.
+ SmallVector<std::string> bitcodePaths;
+ if (usesOCML) {
+ bitcodePaths.push_back(
+ (bitcodePath + llvm::sys::path::get_separator() + "ocml.bc").str());
+ }
+ if (usesOCKL) {
+ bitcodePaths.push_back(
+ (bitcodePath + llvm::sys::path::get_separator() + "ockl.bc").str());
+ }
+ return linkWithBitcodeFiles(loc, module, bitcodePaths);
}
-// Links optimized Ukernel bitcodes into the given module if the module needs
-// it.
-LogicalResult linkUkernelBCFiles(Location loc, llvm::Module *module,
- StringRef enabledUkernelsStr,
- StringRef targetChip, StringRef bitCodeDir,
- unsigned linkerFlags,
- llvm::TargetMachine &targetMachine) {
+// Links optimized Ukernel bitcode into the given module if the module needs it.
+LogicalResult linkUkernelBitcodeFiles(Location loc, llvm::Module *module,
+ StringRef enabledUkernelsStr,
+ StringRef targetChip,
+ StringRef bitcodePath,
+ unsigned linkerFlags,
+ llvm::TargetMachine &targetMachine) {
// Early exit if Ukernel not supported on target chip.
if (!iree_compiler::hasUkernelSupportedRocmArch(targetChip)) {
return mlir::emitError(loc)
@@ -266,20 +251,17 @@
<< "' not supported on target chip: " << targetChip;
}
std::vector<std::string> ukernelPaths =
- getUkernelPaths(enabledUkernelsStr, targetChip, bitCodeDir);
+ getUkernelPaths(enabledUkernelsStr, targetChip, bitcodePath);
llvm::Linker linker(*module);
for (auto &path : ukernelPaths) {
- if (failed(linkPathBitcodeFiles(loc, linker, linkerFlags, StringRef(path),
- targetMachine, module->getContext())))
+ if (failed(linkBitcodeFile(loc, linker, linkerFlags, StringRef(path),
+ targetMachine, module->getContext())))
return failure();
}
return success();
}
-//===========Link LLVM Module to ROCDL End===================/
-
-//=====================Create HSACO Begin=============//
// Link object file using lld lnker to generate code object
// Inspiration from this section comes from LLVM-PROJECT-MLIR by
// ROCmSoftwarePlatform
@@ -311,7 +293,6 @@
llvm::FileRemover cleanupHsaco(tempHsacoFilename);
// Invoke lld. Expect a true return value from lld.
- // Searching for LLD
const SmallVector<std::string> &toolNames{"iree-lld", "lld"};
std::string lldProgram = findTool(toolNames);
if (lldProgram.empty()) {
@@ -328,7 +309,7 @@
tempHsacoFilename.str(),
};
- // Executing LLD
+ // Execute LLD.
std::string errorMessage;
int lldResult = llvm::sys::ExecuteAndWait(
unescapeCommandLineComponent(lldProgram),
@@ -350,6 +331,5 @@
hsacoFile->getBuffer().end());
return strHSACO;
}
-//==============Create HSACO End=============//
} // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.h b/compiler/plugins/target/ROCM/ROCMTargetUtils.h
index bfa4146..eb0dd78 100644
--- a/compiler/plugins/target/ROCM/ROCMTargetUtils.h
+++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.h
@@ -13,17 +13,22 @@
namespace mlir::iree_compiler::IREE::HAL {
-// Links LLVM module to ROC Device Library Bit Code
-LogicalResult linkROCDLIfNecessary(Location loc, llvm::Module *module,
- std::string targetChip,
- std::string bitCodeDir);
+// Sets HIP platform globals based on the target architecture.
+LogicalResult setHIPGlobals(Location loc, llvm::Module *module,
+ StringRef targetChip);
+
+// Links HIP device bitcode if the module uses any symbols from it.
+LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module,
+ StringRef targetChip,
+ StringRef bitcodePath);
// Links optimized Ukernel module.
-LogicalResult linkUkernelBCFiles(Location loc, llvm::Module *module,
- StringRef enabledUkernelsStr,
- StringRef targetChip, StringRef bitCodeDir,
- unsigned linkerFlags,
- llvm::TargetMachine &targetMachine);
+LogicalResult linkUkernelBitcodeFiles(Location loc, llvm::Module *module,
+ StringRef enabledUkernelsStr,
+ StringRef targetChip,
+ StringRef bitcodePath,
+ unsigned linkerFlags,
+ llvm::TargetMachine &targetMachine);
// Compiles ISAToHsaco Code
std::string createHsaco(Location loc, const std::string isa, StringRef name);
diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
index 92135e5..7550141 100644
--- a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
@@ -3,22 +3,22 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-if (NOT IREE_TARGET_BACKEND_ROCM)
+if(NOT IREE_TARGET_BACKEND_ROCM)
return()
endif()
# Check if HIP is installed on system.
# HIP is required to compile ukernels.
# TODO: We can do better than this and ensure that headers are always available.
-if (NOT IREE_ROCM_PATH)
+if(NOT IREE_ROCM_PATH)
set(IREE_ROCM_PATH "/opt/rocm")
endif()
-set (IREE_ROCM_VERSION "${IREE_ROCM_PATH}/.info/version")
-if (NOT EXISTS ${IREE_ROCM_VERSION})
+set(IREE_ROCM_VERSION "${IREE_ROCM_PATH}/include/hip/hip_version.h")
+if(NOT EXISTS ${IREE_ROCM_VERSION})
message(STATUS
"hip runtime cannot be found in ${IREE_ROCM_PATH}.
Please try setting IREE_ROCM_PATH to rocm directory.
- Ukernel will not be compiled.")
+ Ukernels will not be compiled.")
return()
endif()
@@ -27,8 +27,8 @@
set(_platform_lib_reldir "iree_platform_libs/rocm")
set(_device_bc_path "${IREE_COMPILER_DYLIB_DIR}/iree_platform_libs/rocm")
-set (_amd_ukernel_libs)
-set (_amd_ukernel_targets)
+set(_amd_ukernel_libs)
+set(_amd_ukernel_targets)
function(iree_rocm_bitcode_library)
cmake_parse_arguments(
_RULE
@@ -45,9 +45,9 @@
endif()
set(_ROCM_ARCH "${_RULE_ROCM_ARCH}")
- set (OPT_FLAG "-O0")
- if (_ROCM_ARCH MATCHES "GFX9")
- set (OPT_FLAG "-O3")
+ set(OPT_FLAG "-O0")
+ if(_ROCM_ARCH MATCHES "GFX9")
+ set(OPT_FLAG "-O3")
endif()
set(_COPTS
"-x" "hip"
diff --git a/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json b/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json
index 03b7382..56ebe3c 100644
--- a/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json
+++ b/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json
@@ -1,9 +1,8 @@
{
"config_name": "gpu_rocm_rdna3",
- "iree_compile_flags" : [
+ "iree_compile_flags": [
"--iree-hal-target-backends=rocm",
- "--iree-rocm-target-chip=gfx1100",
- "--iree-rocm-link-bc=true"
+ "--iree-rocm-target-chip=gfx1100"
],
"iree_run_module_flags": [
"--device=rocm"
@@ -784,7 +783,6 @@
"test_unsqueeze_unsorted_axes",
"test_upsample_nearest",
"test_wrap_pad",
-
// These pass on CPU but fail on ROCm.
],
"expected_run_failures": [
@@ -904,7 +902,6 @@
"test_shape_start_1_end_negative_1",
"test_shape_start_negative_1",
"test_where_long_example",
-
// These pass on CPU but fail on ROCm.
"test_gather_negative_indices",
"test_reduce_l1_default_axes_keepdims_example_expanded",
diff --git a/experimental/regression_suite/tests/pregenerated/test_llama2.py b/experimental/regression_suite/tests/pregenerated/test_llama2.py
index 0ff2ca9..dcde908 100644
--- a/experimental/regression_suite/tests/pregenerated/test_llama2.py
+++ b/experimental/regression_suite/tests/pregenerated/test_llama2.py
@@ -103,7 +103,6 @@
+ [
"--iree-hal-target-backends=rocm",
"--iree-rocm-target-chip=gfx1100",
- "--iree-rocm-link-bc=true",
],
)
diff --git a/experimental/regression_suite/tests/pregenerated/test_ukernel.py b/experimental/regression_suite/tests/pregenerated/test_ukernel.py
index 22a5aa2..f96225d 100644
--- a/experimental/regression_suite/tests/pregenerated/test_ukernel.py
+++ b/experimental/regression_suite/tests/pregenerated/test_ukernel.py
@@ -43,7 +43,6 @@
+ [
"--iree-hal-target-backends=rocm",
"--iree-rocm-target-chip=gfx90a",
- "--iree-rocm-link-bc=true",
"--iree-rocm-enable-ukernels=argmax",
],
)
@@ -58,7 +57,6 @@
+ [
"--iree-hal-target-backends=rocm",
"--iree-rocm-target-chip=gfx940",
- "--iree-rocm-link-bc=true",
"--iree-rocm-enable-ukernels=argmax",
],
)
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
index 795aa38..1f6b76b 100644
--- a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
@@ -32,7 +32,7 @@
print("*****************************", file=sys.stderr)
self.build_configuration(
os.path.join(THIS_DIR, "build", "cmake"),
- extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=ROCM",),
+ extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=rocm",),
)
print("Target populated.", file=sys.stderr)
diff --git a/integrations/pjrt/src/CMakeLists.txt b/integrations/pjrt/src/CMakeLists.txt
index 6310d38..d0f4947 100644
--- a/integrations/pjrt/src/CMakeLists.txt
+++ b/integrations/pjrt/src/CMakeLists.txt
@@ -27,7 +27,7 @@
if(IREE_HAL_DRIVER_CUDA)
add_subdirectory(iree_pjrt/cuda)
endif()
-if("ROCM" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
+if("rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
add_subdirectory(iree_pjrt/rocm)
endif()
if(IREE_HAL_DRIVER_VULKAN)
diff --git a/samples/custom_dispatch/hip/CMakeLists.txt b/samples/custom_dispatch/hip/CMakeLists.txt
new file mode 100644
index 0000000..ae2678c
--- /dev/null
+++ b/samples/custom_dispatch/hip/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_add_all_subdirs()
diff --git a/samples/custom_dispatch/hip/kernels/CMakeLists.txt b/samples/custom_dispatch/hip/kernels/CMakeLists.txt
new file mode 100644
index 0000000..7f68ee4
--- /dev/null
+++ b/samples/custom_dispatch/hip/kernels/CMakeLists.txt
@@ -0,0 +1,67 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if((NOT IREE_TARGET_BACKEND_ROCM) OR
+ (NOT "rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS))
+ return()
+endif()
+
+if(NOT IREE_ROCM_PATH)
+ message(WARNING "IREE_ROCM_PATH not specified; cannot build sample")
+endif()
+
+# NOTE: this is not how one should actually build their HSACO files. Do not use
+# this as an authoritative source for compilation settings or CMake goo. If you
+# choose to go the route of custom HIP kernels you must bring your own build
+# infrastructure. This sample only demonstrates how to use compiled HSACO blobs
+# inside of the IREE compiler and this is the minimum amount of hacking that
+# could be done to do that.
+
+# Builds an HSACO blob using the clang built by IREE from tip-of-tree LLVM.
+function(hip_kernel_hsaco_clang _ARCH)
+ set(_NAME iree_samples_custom_dispatch_hip_kernels_hsaco_${_ARCH})
+ set(_HSACO_SRC_NAME "kernels.cu")
+ get_filename_component(_HSACO_SRC_BASENAME ${_HSACO_SRC_NAME} NAME_WE CACHE)
+ set(_HSACO_OBJ_NAME "${_HSACO_SRC_BASENAME}_${_ARCH}.co")
+ add_custom_command(
+ OUTPUT
+ ${_HSACO_OBJ_NAME}
+ DEPENDS
+ ${_HSACO_SRC_NAME}
+ ${IREE_CLANG_TARGET}
+ COMMAND ${IREE_CLANG_BINARY}
+ -x hip
+ --offload-device-only
+ --offload-arch=${_ARCH}
+ --rocm-path=${IREE_ROCM_PATH}
+ -fuse-cuid=none
+ -O3
+ ${CMAKE_CURRENT_SOURCE_DIR}/${_HSACO_SRC_NAME}
+ -o ${CMAKE_CURRENT_BINARY_DIR}/${_HSACO_OBJ_NAME}
+ VERBATIM
+ )
+ add_custom_target(${_NAME} DEPENDS
+ ${CMAKE_CURRENT_BINARY_DIR}/${_HSACO_OBJ_NAME}
+ )
+ add_dependencies(iree-sample-deps "${_NAME}")
+endfunction()
+
+# Build the kernels_*.co files for each architecture we target.
+hip_kernel_hsaco_clang(gfx1100)
+
+iree_lit_test_suite(
+ NAME
+ example
+ SRCS
+ "example.mlir"
+ TOOLS
+ FileCheck
+ iree-compile
+ iree-run-module
+ LABELS
+ "driver=rocm"
+ "hostonly"
+)
diff --git a/samples/custom_dispatch/hip/kernels/example.mlir b/samples/custom_dispatch/hip/kernels/example.mlir
new file mode 100644
index 0000000..8819d86
--- /dev/null
+++ b/samples/custom_dispatch/hip/kernels/example.mlir
@@ -0,0 +1,153 @@
+// RUN: iree-compile %s \
+// RUN: --iree-hal-executable-object-search-path=$IREE_BINARY_DIR | \
+// RUN: iree-run-module \
+// RUN: --device=rocm \
+// RUN: --module=- \
+// RUN: --function=mixed_invocation \
+// RUN: --input=8xf32=2 \
+// RUN: --input=8xf32=4 | \
+// RUN: FileCheck %s
+
+// The configurations used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+#rocm_gfx1100_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
+ target_arch = "gfx1100"
+}>
+
+// The target devices that the program will run on.
+// These can come from compiler flags and multiple targets can be supported
+// It's possible, for example, to support targeting multiple devices in the same
+// compiled binary.
+#rocm_target = #hal.device.target<"rocm", [
+ #rocm_gfx1100_target
+]>
+
+module @example attributes {hal.device.targets = [#rocm_target]} {
+
+ // Executable containing hand-authored kernels.
+ // Each executable can contain multiple exported functions and variants for
+ // different architectures or even devices. It's also possible to mix hand-
+ // authored functions with code generated ones even for the same functions
+ // such that code generation is used as a fallback when the hand-authored
+ // kernels aren't supported at runtime.
+ hal.executable.source private @executable attributes {
+ // Object files linked into the executable per-target.
+ // Certain backends (today) support either wholesale definition or linking
+ // of partial objects for imports used by generated code. Each compilation
+ // target can have its own unique set of objects to link in and the target
+ // keys can be generic. This allows for an object file to be linked in based
+ // only on the target triple while allowing for more specialized ones
+ // requiring certain CPU features to be only included when building those.
+ objects = #hal.executable.objects<{
+ #rocm_gfx1100_target = [
+ #hal.executable.object<{
+ // Referencing a file path on disk but could also have the data
+ // embedded in order to make the MLIR file hermetic/portable across
+ // compilation pipelines. In the future we'll likely use MLIR's
+ // external resource functionality for this. By allowing for the
+ // objects to be embedded we can support JIT scenarios where some
+ // layer higher or lower may be emitting the objects to link in as
+ // part of the overall compilation.
+ path = "samples/custom_dispatch/hip/kernels/kernels_gfx1100.co"
+ }>
+ ]
+ }>
+ } {
+
+ // TODO(benvanik): demonstrate hal.executable.constant.block for
+ // specialization via host logic. Maps to a read-only buffer passed into
+ // kernels. ROCM doesn't yet have these wired up.
+
+ // Exported function with the C name `simple_mul`.
+ // The ordinal must be assigned by the user and unique for the executable.
+ // The layout defines the required bindings and push constants and can be
+ // thought of as the function signature.
+ hal.executable.export public @simple_mul ordinal(0)
+ layout(#hal.pipeline.layout<push_constants = 1, sets = [
+ <0, bindings = [
+ <0, storage_buffer, ReadOnly>,
+ <1, storage_buffer, ReadOnly>,
+ <2, storage_buffer>
+ ]>
+ ]>) attributes {
+ // Certain backends (like ROCM) require a workgroup size (aka block
+ // size) to be defined ahead of time.
+ workgroup_size = [64 : index, 1 : index, 1 : index],
+ // Bindings are automatically inferred when possible as part of the ABI
+ // but can be overridden if the user wants to use features such as sparse
+ // bindings or multiple descriptor sets. To do so the
+ // `hal.interface.bindings` attribute can be added to a dispatch op as
+ // follows mapping tensor operands/results to the pipeline layout
+ // sets/bindings:
+ hal.interface.bindings = [
+ #hal.interface.binding<0, 0>,
+ #hal.interface.binding<0, 1>,
+ #hal.interface.binding<0, 2>
+ ]
+ } {
+ ^bb0(%device: !hal.device, %workload: index):
+ // This host function is used to compute the XYZ workgroup count
+ // dispatched at runtime. It can query the %device for capabilities
+ // and limits (shared memory size, etc). The other arguments are the
+ // values passed in the dispatch operation (usually things like root
+ // output op tensor dimensions and other abstract values).
+ %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+ %c1 = arith.constant 1 : index
+ hal.return %x, %c1, %c1 : index, index, index
+ }
+
+ // Similar to the above but in-place by using a read/write binding.
+ hal.executable.export public @simple_mul_inplace ordinal(1)
+ layout(#hal.pipeline.layout<push_constants = 1, sets = [
+ <0, bindings = [
+ <0, storage_buffer, ReadOnly>,
+ <1, storage_buffer>
+ ]>
+ ]>) attributes {
+ workgroup_size = [64 : index, 1 : index, 1 : index]
+ } {
+ ^bb0(%device: !hal.device, %workload: index):
+ %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+ %c1 = arith.constant 1 : index
+ hal.return %x, %c1, %c1 : index, index, index
+ }
+
+ } // hal.executable.source
+
+ // Function demonstrating a few hand-authored dispatches mixed with codegen.
+ // Invoke with:
+ // --device=rocm
+ // --function=mixed_invocation
+ // --input=8xf32=2
+ // --input=8xf32=4
+ // CHECK-LABEL: EXEC @mixed_invocation
+ func.func @mixed_invocation(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ // HACK: for hand-authored kernels all primitive values passed in need to
+ // be i32 or a bit-castable type. This is because ABI packing of other types
+ // happens inside of the PackDispatchOperandsPass that is currently not
+ // usable with external functions as it changes the ABI. In the future we
+ // can better define the ABI such that it's possible to match the compiler
+ // expectations around padding/alignment. For now users must do the packing
+ // themselves (splitting i64 into i32+i32, etc).
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+ %dim_i32 = arith.index_cast %dim : index to i32
+
+ // Dispatch a basic `ret = lhs * rhs` kernel.
+ %0 = flow.dispatch @executable::@simple_mul[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+
+ // Code gen some other ops - these will interleave with the hand-authored
+ // ones but naturally won't be able to fuse with them.
+ %1 = arith.addf %0, %arg1 : tensor<?xf32>
+
+ // Dispatch an in-place `rhs *= lhs` kernel.
+ %2 = flow.dispatch @executable::@simple_mul_inplace[%dim](%dim_i32, %0, %1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> %1{%dim}
+
+ // CHECK: 8xf32=96 96 96 96 96 96 96 96
+ return %2 : tensor<?xf32>
+ }
+
+} // module
diff --git a/samples/custom_dispatch/hip/kernels/kernels.cu b/samples/custom_dispatch/hip/kernels/kernels.cu
new file mode 100644
index 0000000..87e29ee
--- /dev/null
+++ b/samples/custom_dispatch/hip/kernels/kernels.cu
@@ -0,0 +1,74 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <hip/hip_runtime.h>
+
+// This minimal example just has some publicly exported (__global__) kernels.
+// It's possible with more build goo to include .cuh files and pull in any
+// HIP functions that do not involve host behavior (kernel launches/etc).
+//
+// NOTE: kernels must be exported with C naming (no C++ mangling) in order to
+// match the names used in the IR declarations.
+//
+// NOTE: arguments are packed as a dense list of
+// ([ordered bindings...], [push constants...]). If a binding is declared as
+// read-only the kernel must not write to it as it may be shared by other
+// invocations.
+//
+// NOTE: today all constants must be i32. If larger types are required there are
+// packing rules that must line up with compiler expectations - passed i64
+// values must be padded to natural 8-byte alignment, for example.
+//
+// NOTE: IREE ensures that all I/O buffers are legal to have the __restrict__
+// keyword defined (no aliasing is induced that is potentially unsafe). It's
+// still possible for users to do bad things but such is the case with native
+// HIP programming.
+//
+// NOTE: I/O buffer base pointers are likely to be nicely aligned (64B minimum
+// but usually larger) but the pointers passed in may be offset by any value
+// as they represent subranges of the underlying buffers. For example if the
+// user slices out elements 3 and 4 out of a 4xf32 tensor then the base buffer
+// pointer will be at +8B. In general if the input wasn't trying to be tricky
+// (bitcasting/etc) then natural alignment is guaranteed (an f32 tensor will
+// always have buffer pointers aligned to 4B).
+
+// `ret = lhs * rhs`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+// <0, bindings = [
+// <0, storage_buffer, ReadOnly>,
+// <1, storage_buffer, ReadOnly>,
+// <2, storage_buffer>
+// ]>
+// ]>
+// workgroup_size = [64 : index, 1 : index, 1 : index]
+extern "C" __global__ void simple_mul(const float* __restrict__ binding0,
+ const float* __restrict__ binding1,
+ float* __restrict__ binding2, int dim) {
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ if (tid < dim) {
+ binding2[tid] = binding0[tid] * binding1[tid];
+ }
+}
+
+// `rhs *= lhs`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+// <0, bindings = [
+// <0, storage_buffer, ReadOnly>,
+// <1, storage_buffer>
+// ]>
+// ]>
+// workgroup_size = [64 : index, 1 : index, 1 : index]
+extern "C" __global__ void simple_mul_inplace(
+ const float* __restrict__ binding0, float* __restrict__ binding1, int dim) {
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ if (tid < dim) {
+ binding1[tid] *= binding0[tid];
+ }
+}