[ROCM] Use translation info to store waves-per-eu (#16573)
This allows setting per-kernel waves-per-eu values. For now there are no
configs that take advantage of this on default paths, but we do use this
in custom transform dialect scripts by annotating the translation info
config with this attribute.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index ef66f25..21406b4 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -280,6 +280,16 @@
subgroupSize = *setSubgroupSize;
}
+ int64_t wavesPerEu = options.wavesPerEu;
+ IREE::Codegen::TranslationInfoAttr translationInfo =
+ getTranslationInfo(exportOp);
+ if (auto translationConfig = translationInfo.getConfiguration()) {
+ if (auto attr = dyn_cast_or_null<IntegerAttr>(
+ translationConfig.get("amdgpu-waves-per-eu"))) {
+ wavesPerEu = attr.getValue().getSExtValue();
+ }
+ }
+
workgroupSizes.push_back(workgroupSize);
uint32_t workgroupLocalMemory = 0;
if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
@@ -294,9 +304,8 @@
llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
std::string wgSizeRange = std::string("1, ") + std::to_string(flatWgSize);
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
- if (options.wavesPerEu > 0)
- llvmFunc->addFnAttr("amdgpu-waves-per-eu",
- std::to_string(options.wavesPerEu));
+ if (wavesPerEu > 0)
+ llvmFunc->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEu));
if (subTarget.starts_with(GFX9))
addPreloadKernArgHint(llvmFunc);
}