[LLVMGPU][ROCm] Add validation on finalized llvm bitcode (#18552)
Check that there are no unresolved external functions that will
otherwise compile fine but be rejected by the driver.
The validation happens on the llvm bitcode after bitcode linking, when
there is no chance for anything to resolve these external functions.
This is to guard against future issues similar to
https://github.com/iree-org/iree/issues/18534 that are close to
undebuggable for end users.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 2d02550..9b4705e 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -36,6 +36,7 @@
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
@@ -369,6 +370,21 @@
mpm.run(module, mam);
}
+ LogicalResult
+ validateFinalizedModule(IREE::HAL::ExecutableVariantOp variantOp,
+ llvm::Module &module) {
+ for (llvm::Function &func : module.functions()) {
+ if (func.isDeclaration() && !func.isIntrinsic() && !func.use_empty()) {
+ llvm::User *liveUser = *func.user_begin();
+ return variantOp.emitError()
+ << "found an unresolved external function '" << func.getName()
+ << "' in the final bitcode. A remaining live user is\n"
+ << llvm::formatv("{0}", *liveUser);
+ }
+ }
+ return success();
+ }
+
LogicalResult serializeExecutable(const SerializationOptions &serOptions,
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
@@ -579,6 +595,10 @@
".optimized.ll", *llvmModule);
}
+ if (failed(validateFinalizedModule(variantOp, *llvmModule))) {
+ return failure();
+ }
+
// Dump the assembly output.
if (!serOptions.dumpIntermediatesPath.empty()) {
auto moduleCopy = llvm::CloneModule(*llvmModule);
diff --git a/compiler/plugins/target/ROCM/test/CMakeLists.txt b/compiler/plugins/target/ROCM/test/CMakeLists.txt
index 7c457c5..df185a0 100644
--- a/compiler/plugins/target/ROCM/test/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/test/CMakeLists.txt
@@ -9,6 +9,7 @@
NAME
lit
SRCS
+ "external_function_validation.mlir"
"smoketest.mlir"
"target_device_features.mlir"
TOOLS
diff --git a/compiler/plugins/target/ROCM/test/external_function_validation.mlir b/compiler/plugins/target/ROCM/test/external_function_validation.mlir
new file mode 100644
index 0000000..dcebcce
--- /dev/null
+++ b/compiler/plugins/target/ROCM/test/external_function_validation.mlir
@@ -0,0 +1,41 @@
+// RUN: iree-opt --pass-pipeline='builtin.module(hal.executable(iree-hal-serialize-target-executables{target=rocm}))' \
+// RUN: --verify-diagnostics %s -o -
+
+// The final bitcode validation should error out on any external functions that
+// remain in the final bitcode (post device bitcode linking).
+
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
+ {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
+ wgp = <compute = fp16, storage = b16,
+ subgroup = none, dot = none, mma = [],
+ subgroup_size_choices = [64],
+ max_workgroup_sizes = [1024, 1024, 1024],
+ max_thread_count_per_workgroup = 1024,
+ max_workgroup_memory_bytes = 65536,
+ max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
+ ukernels = "none"}>
+#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>],
+ flags = Indirect>
+builtin.module {
+ hal.executable public @test {
+ // expected-error @+2 {{found an unresolved external function 'external_func' in the final bitcode}}
+ // expected-error @+1 {{failed to serialize executable for target backend rocm}}
+ hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
+ hal.executable.export public @test ordinal(0) layout(#pipeline_layout)
+ attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]} {
+ ^bb0(%arg0: !hal.device):
+ %c128 = arith.constant 128 : index
+ %c2 = arith.constant 2 : index
+ %c1 = arith.constant 1 : index
+ hal.return %c128, %c2, %c1 : index, index, index
+ }
+ builtin.module {
+ llvm.func @external_func() attributes {sym_visibility = "private"}
+ llvm.func @test() {
+ llvm.call @external_func() : () -> ()
+ llvm.return
+ }
+ }
+ }
+ }
+}