[ROCM] Fix redefinition of symbol error for including tensor ukernels (#21780)
Fix the case where multiple functions require the same tensor ukernel to be
included. Without this PR, this leads to a symbol redefinition error.
Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
index f562c11..2ff03ee 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
@@ -299,6 +299,7 @@
MLIRContext *ctx = moduleOp.getContext();
auto rocmDialect = ctx->getOrLoadDialect<IREE::ROCM::ROCMDialect>();
SmallVector<FunctionOpInterface> ukernelFunctions;
+ llvm::SmallDenseSet<StringRef> ukernelSymbols;
auto res = moduleOp.walk([&](Operation *op) {
auto builtinName =
dyn_cast_or_null<StringAttr>(op->getAttr(kBuiltinName));
@@ -306,8 +307,7 @@
if (!builtinName || !ukernelDesc) {
return WalkResult::advance();
}
- if (moduleOp->hasAttr(ukernelDesc.getUkernelName())) {
- // Avoid parsing and serializing the same ukernel again and again.
+ if (ukernelSymbols.contains(ukernelDesc.getUkernelName())) {
return WalkResult::advance();
}
std::optional<StringRef> maybeBuiltin =
@@ -335,6 +335,7 @@
funcOp->remove();
ukernelFunctions.push_back(funcOp);
op->removeAttr(kBuiltinName);
+ ukernelSymbols.insert(ukernelDesc.getUkernelName());
return WalkResult::advance();
});
if (res.wasInterrupted()) {
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
index e56b7ec..30a8340 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
@@ -31,6 +31,21 @@
} -> tensor<1x128x1024xf32>
return %2 : tensor<1x128x1024xf32>
}
+ // Check that a second function requiring the same ukernel doesn't lead to a 'redefinition of symbol named ...' error.
+ func.func @matmul_f8_medium_expanded_2(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x128x1024xf32>
+ %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+ %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x128x1024xf32>) {
+ ^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
+ %12 = arith.extf %in : f8E4M3FNUZ to f32
+ %13 = arith.extf %in_4 : f8E4M3FNUZ to f32
+ %14 = arith.mulf %12, %13 : f32
+ %15 = arith.addf %out, %14 : f32
+ linalg.yield %15 : f32
+ } -> tensor<1x128x1024xf32>
+ return %2 : tensor<1x128x1024xf32>
+ }
}
// CHECK-LABEL: util.func private @pingpong_medium_f8_expanded
// CHECK: iree_codegen.inner_tiled