[ROCM] Fix redefinition of symbol error for including tensor ukernels (#21780)

Fix the case where multiple functions require the same tensor ukernel to
be included. Without this change, that case leads to a symbol
redefinition error.

Signed-off-by: Jorn Tuyls <jorn.tuyls@gmail.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
index f562c11..2ff03ee 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp
@@ -299,6 +299,7 @@
     MLIRContext *ctx = moduleOp.getContext();
     auto rocmDialect = ctx->getOrLoadDialect<IREE::ROCM::ROCMDialect>();
     SmallVector<FunctionOpInterface> ukernelFunctions;
+    llvm::SmallDenseSet<StringRef> ukernelSymbols;
     auto res = moduleOp.walk([&](Operation *op) {
       auto builtinName =
           dyn_cast_or_null<StringAttr>(op->getAttr(kBuiltinName));
@@ -306,8 +307,7 @@
       if (!builtinName || !ukernelDesc) {
         return WalkResult::advance();
       }
-      if (moduleOp->hasAttr(ukernelDesc.getUkernelName())) {
-        // Avoid parsing and serializing the same ukernel again and again.
+      if (ukernelSymbols.contains(ukernelDesc.getUkernelName())) {
         return WalkResult::advance();
       }
       std::optional<StringRef> maybeBuiltin =
@@ -335,6 +335,7 @@
       funcOp->remove();
       ukernelFunctions.push_back(funcOp);
       op->removeAttr(kBuiltinName);
+      ukernelSymbols.insert(ukernelDesc.getUkernelName());
       return WalkResult::advance();
     });
     if (res.wasInterrupted()) {
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
index e56b7ec..30a8340 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
@@ -31,6 +31,21 @@
       } -> tensor<1x128x1024xf32>
     return %2 : tensor<1x128x1024xf32>
   }
+  // Check that a second function requiring the same ukernel doesn't lead to a 'redefinition of symbol named ...' error.
+  func.func @matmul_f8_medium_expanded_2(%arg0: tensor<1x128x4096xf8E4M3FNUZ>, %arg1: tensor<1024x4096xf8E4M3FNUZ>) -> tensor<1x128x1024xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<1x128x1024xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
+    %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<1x128x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%1 : tensor<1x128x1024xf32>) {
+      ^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
+        %12 = arith.extf %in : f8E4M3FNUZ to f32
+        %13 = arith.extf %in_4 : f8E4M3FNUZ to f32
+        %14 = arith.mulf %12, %13 : f32
+        %15 = arith.addf %out, %14 : f32
+        linalg.yield %15 : f32
+      } -> tensor<1x128x1024xf32>
+    return %2 : tensor<1x128x1024xf32>
+  }
 }
 // CHECK-LABEL: util.func private @pingpong_medium_f8_expanded
 // CHECK:         iree_codegen.inner_tiled