[ROCM] Expose amdgpu-waves-per-eu opt hint (#16010)
This PR exposes the amdgpu-waves-per-eu flag as
iree-rocm-waves-per-eu. For FA2, setting this to 2 for 16x16384x128xf16
results in a speedup of
25TFlops.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 4792fb0..1b57579 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -46,6 +46,7 @@
std::string targetChip = "gfx908";
bool linkBitcode = false;
std::string bitcodeDirectory;
+ int wavesPerEu = 0;
void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("ROCM HAL Target");
@@ -57,6 +58,10 @@
binder.opt<std::string>("iree-rocm-bc-dir", bitcodeDirectory,
llvm::cl::cat(category),
llvm::cl::desc("Directory of ROCM Bitcode"));
+ binder.opt<int>("iree-rocm-waves-per-eu", wavesPerEu,
+ llvm::cl::cat(category),
+ llvm::cl::desc("Optimization hint specifying minimum "
+ "number of waves per execution unit"));
}
};
} // namespace
@@ -254,6 +259,9 @@
llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
std::string wgSizeRange = std::string("1, ") + std::to_string(flatWgSize);
llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
+ if (options.wavesPerEu > 0)
+ llvmFunc->addFnAttr("amdgpu-waves-per-eu",
+ std::to_string(options.wavesPerEu));
}
std::unique_ptr<llvm::TargetMachine> targetMachine;