[LLVMGPU][ROCm] Disable kernarg preloading on pre-CDNA3 targets (#18343)
This is not supported on CDNA1/CDNA2 devices which, depending on the
firmware, may fail at runtime when executing code that uses kernel
argument preloading.
I considered adding this to the `TargetWgp` attribute but decided
against it because of how rocm-specific this optimization is.
Tested manually for a bunch of targets and looking at the ISA,
including: `gfx90a`, `gfx940`, `gfx942`, and `gfx1100`.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index c1dc780..de7d3d7 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -35,10 +35,12 @@
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
@@ -132,6 +134,17 @@
}
};
+// Extracts the amdgpu chipset version from the chip architecture in the
+// executable target attribute.
+static FailureOr<amdgpu::Chipset>
+getChipsetVersion(ExecutableTargetAttr targetAttr) {
+ IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(targetAttr);
+ if (!gpuTarget)
+ return failure();
+
+ return amdgpu::Chipset::parse(gpuTarget.getArch());
+}
+
// Set attributes on `funcOp` in order to use upstream's translation of
// ROCDL dialect attributes to LLVM. Primarily this is `rocdl.kernel`
// (sets the calling convention and workgroup size uniformity) but this will
@@ -168,10 +181,17 @@
rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, *attr);
}
+ // Kernel argument preloading is only supported on gfx940 and newer targets
+ // from the CDNA family. This is enabled using the `inreg` function argument
+ // attribute.
+ FailureOr<amdgpu::Chipset> chipset = getChipsetVersion(targetAttr);
+ if (failed(chipset))
+ return;
+ if (chipset->majorVersion != 9 && chipset->minorVersion < 0x40)
+ return;
+
auto inRegAttrName =
builder.getStringAttr(LLVM::LLVMDialect::getInRegAttrName());
- // Currently, `inreg` only enables argument preloading on gfx9,
- // but it is harmless on other targets.
for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i)
funcOp.setArgAttr(i, inRegAttrName, unitAttr);
}