[Codegen] Improve ROCm-specific LLVM translations (#17742)

Use upstream's translations for attributes like rocdl.kernel to reduce
redundancy.

Fix the parsing of chipset versions (the last two digits are in base 16).

Signed-off-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
diff --git a/.github/workflows/pkgci_regression_test.yml b/.github/workflows/pkgci_regression_test.yml
index eb2d801..2d6d2c9 100644
--- a/.github/workflows/pkgci_regression_test.yml
+++ b/.github/workflows/pkgci_regression_test.yml
@@ -319,9 +319,9 @@
             --goldendispatch-rocm-unet 1714 \
             --goldendispatch-rocm-clip 1569 \
             --goldendispatch-rocm-vae 248 \
-            --goldensize-rocm-unet-bytes 2062938 \
-            --goldensize-rocm-clip-bytes 780328 \
-            --goldensize-rocm-vae-bytes 757933 \
+            --goldensize-rocm-unet-bytes 2073609  \
+            --goldensize-rocm-clip-bytes 783720 \
+            --goldensize-rocm-vae-bytes 764909 \
             --gpu-number 6 \
             --rocm-chip gfx90a \
             --log-cli-level=info \
diff --git a/compiler/plugins/target/ROCM/BUILD.bazel b/compiler/plugins/target/ROCM/BUILD.bazel
index e9e4a7f..b2fa883 100644
--- a/compiler/plugins/target/ROCM/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/BUILD.bazel
@@ -32,6 +32,7 @@
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils:KnownTargets",
         "//compiler/src/iree/compiler/Codegen/LLVMGPU",
         "//compiler/src/iree/compiler/Codegen/Utils",
+        "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/HAL/Target",
         "//compiler/src/iree/compiler/Dialect/HAL/Utils:LLVMLinkerUtils",
         "//compiler/src/iree/compiler/PluginAPI",
diff --git a/compiler/plugins/target/ROCM/CMakeLists.txt b/compiler/plugins/target/ROCM/CMakeLists.txt
index f801205..7e88b8a 100644
--- a/compiler/plugins/target/ROCM/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/CMakeLists.txt
@@ -58,6 +58,7 @@
     iree::compiler::Codegen::Dialect::GPU::TargetUtils::KnownTargets
     iree::compiler::Codegen::LLVMGPU
     iree::compiler::Codegen::Utils
+    iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::HAL::Target
     iree::compiler::Dialect::HAL::Utils::LLVMLinkerUtils
     iree::compiler::PluginAPI
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 1dfe816..5f97a97 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -16,6 +16,7 @@
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
 #include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
 #include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
 #include "iree/compiler/PluginAPI/Client.h"
@@ -39,6 +40,7 @@
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
@@ -118,6 +120,50 @@
   }
 };
 
+// Set attributes on `funcOp` in order to use upstream's translation of
+// ROCDL dialect attributes to LLVM. Primarily this is `rocdl.kernel`
+// (sets the calling convention and workgroup size uniformity) but this will
+// also set both forms of workgroup size metadata from `exportOp` (if it is
+// set; missing trailing dimensions count as 1) and will set the waves_per_eu
+// flag where relevant. Finally, it will mark kernel arguments `inreg` to
+// enable argument preloading on supported architectures.
+static void annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp,
+                                         ExecutableExportOp exportOp,
+                                         ExecutableTargetAttr targetAttr,
+                                         OpBuilder &builder) {
+  auto *rocdlDialect =
+      funcOp.getContext()->getLoadedDialect<ROCDL::ROCDLDialect>();
+  UnitAttr unitAttr = builder.getUnitAttr();
+  rocdlDialect->getKernelAttrHelper().setAttr(funcOp, unitAttr);
+  std::optional<ArrayAttr> workgroupSizeAttr = exportOp.getWorkgroupSize();
+  if (workgroupSizeAttr && workgroupSizeAttr->size() <= 3) {
+    std::array<int32_t, 3> wgSizes = {1, 1, 1}; // Unset dimensions stay 1.
+    int32_t flatWgSize = 1;
+    for (auto [value, attr] : llvm::zip( // Stops at the shorter range.
+             wgSizes, workgroupSizeAttr->getAsRange<IntegerAttr>())) {
+      value = attr.getInt();
+      flatWgSize *= value;
+    }
+    rocdlDialect->getReqdWorkGroupSizeAttrHelper().setAttr(
+        funcOp, builder.getDenseI32ArrayAttr(wgSizes));
+    rocdlDialect->getFlatWorkGroupSizeAttrHelper().setAttr(
+        funcOp,
+        builder.getStringAttr(Twine(flatWgSize) + "," + Twine(flatWgSize)));
+  }
+
+  if (std::optional<IntegerAttr> attr =
+          getConfigIntegerAttr(targetAttr, "waves_per_eu")) {
+    rocdlDialect->getWavesPerEuAttrHelper().setAttr(funcOp, *attr);
+  }
+
+  auto inRegAttrName =
+      builder.getStringAttr(LLVM::LLVMDialect::getInRegAttrName());
+  // Currently, `inreg` only enables argument preloading on gfx9,
+  // but it is harmless on other targets.
+  for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i)
+    funcOp.setArgAttr(i, inRegAttrName, unitAttr);
+}
+
 static void dumpModuleToPath(StringRef path, StringRef baseName,
                              StringRef suffix, StringRef extension,
                              llvm::Module &module) {
@@ -155,21 +201,6 @@
   }
   return targetISA;
 }
-
-// Modified from lib/Target/AMDGPU/AMDGPUAttributor.cpp.
-// Adds argument hints to preload kernel arguments to SGPRs.
-// TODO: Query max number of user SGPRs from target machine.
-static void addPreloadKernArgHint(llvm::Function *F) {
-  static constexpr size_t maxSGPRs = 16;
-  for (size_t i = 0, e = std::min(F->arg_size(), maxSGPRs); i != e; ++i) {
-    llvm::Argument *Arg = F->getArg(i);
-    // Check for incompatible attributes.
-    if (Arg->hasByRefAttr() || Arg->hasNestAttr())
-      break;
-    Arg->addAttr(llvm::Attribute::InReg);
-  }
-}
-
 } // namespace
 
 class ROCMTargetDevice final : public TargetDevice {
@@ -249,6 +280,7 @@
     registry.insert<IREE::VectorExt::IREEVectorExtDialect>();
     registry.insert<IREE::GPU::IREEGPUDialect>();
     registry.insert<amdgpu::AMDGPUDialect>();
+    registry.insert<ROCDL::ROCDLDialect>();
   }
 
   void
@@ -380,7 +412,17 @@
       // multi-threading issues.
       llvm::LLVMContext context;
 
-      auto llvmModule =
+      // Set up attributes so upstream's conversions work right.
+      for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
+        // Un-exported functions are library functions or otherwise
+        // not kernels, so don't need these annotations.
+        if (!exportOpMap.contains(func.getName()))
+          continue;
+        annotateKernelForTranslation(func, exportOpMap[func.getName()],
+                                     targetAttr, executableBuilder);
+      }
+
+      std::unique_ptr<llvm::Module> llvmModule =
           mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
       if (!llvmModule) {
         return variantOp.emitError() << "failed to translate the MLIR LLVM "
@@ -388,35 +430,9 @@
       }
 
       for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
-        int32_t flatWgSize = 1;
         llvm::Function *llvmFunc = llvmModule->getFunction(func.getName());
         if (llvmFunc->isDeclaration())
           continue;
-        auto exportOp = exportOpMap[func.getName()];
-        if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
-          for (Attribute attr : *workgroupSizeAttr) {
-            flatWgSize *= cast<IntegerAttr>(attr).getInt();
-          }
-        }
-
-        // For GPU kernels,
-        // 1. Insert AMDGPU_KERNEL calling convention.
-        // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
-        // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on
-        // OpenCL and HIP kernels per Clang).
-        llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
-        llvmFunc->addFnAttr(
-            "amdgpu-flat-work-group-size",
-            (llvm::Twine("1, ") + llvm::Twine(flatWgSize)).str());
-        if (targetArch.starts_with("gfx9"))
-          addPreloadKernArgHint(llvmFunc);
-
-        // Set the amdgpu-waves-per-eu flag from config if given.
-        if (std::optional<IntegerAttr> attr =
-                getConfigIntegerAttr(targetAttr, "waves_per_eu")) {
-          llvmFunc->addFnAttr("amdgpu-waves-per-eu",
-                              std::to_string(attr->getValue().getSExtValue()));
-        }
 
         // Override flags as given by target func attrs.
         if (auto funcAttrs =
diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
index 61fff22..75884fb 100644
--- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
@@ -166,31 +166,20 @@
                             StringRef targetChip) {
   // Link target chip ISA version as global.
   const int kLenOfChipPrefix = 3;
-  auto chipId = targetChip.substr(kLenOfChipPrefix);
-  // i.e gfx90a -> 9000 series.
-  int chipArch = stoi(chipId.substr(0, chipId.size() - 1).str()) * 100;
+  StringRef chipId = targetChip.substr(kLenOfChipPrefix);
+  int major = 0;
+  int minor = 0;
+  if (chipId.drop_back(2).getAsInteger(10, major))
+    return failure();
+  if (chipId.take_back(2).getAsInteger(16, minor))
+    return failure();
   // Oldest GFX arch supported is gfx60x.
-  if (chipArch < 6000)
+  if (major < 6)
     return failure();
   // Latest GFX arch supported is gfx115x.
-  if (chipArch > 11500)
+  if (major > 11 || (major == 11 && minor > 0x5f))
     return failure();
-  // Get chip code from suffix. i.e gfx1103 -> `3`.
-  // gfx90a -> `a` == `10`.
-  // gfx90c -> `c` == `12`.
-  auto chipSuffix = chipId.substr(chipId.size() - 1);
-  uint32_t chipCode = 0;
-  if (chipSuffix == "a") {
-    chipCode = chipArch + 10;
-  } else if (chipSuffix == "c") {
-    chipCode = chipArch + 12;
-  } else {
-    if (!std::isdigit(chipSuffix[0]))
-      return mlir::emitError(loc)
-             << "error linking module with globals: unrecognized chip suffix '"
-             << chipSuffix << "' for " << targetChip;
-    chipCode = chipArch + stoi(chipSuffix.str());
-  }
+  int chipCode = major * 1000 + minor;
   auto *int32Type = llvm::Type::getInt32Ty(module->getContext());
   overridePlatformGlobal(module, "__oclc_ISA_version", chipCode, int32Type);