Enabling external HIP HSACO support, à la CUDA external PTX support. (#16830)

This required fixing up ROCMTarget a bit as it was ignoring the target
configuration specified in IR and just using flags.

The `iree-rocm-link-bc` flag was removed because there are no cases in
which it was valid to be false.
diff --git a/compiler/plugins/target/ROCM/ROCMTarget.cpp b/compiler/plugins/target/ROCM/ROCMTarget.cpp
index 8759755..0cb90a5 100644
--- a/compiler/plugins/target/ROCM/ROCMTarget.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTarget.cpp
@@ -69,16 +69,14 @@
     static llvm::cl::OptionCategory category("ROCM HAL Target");
     binder.opt<std::string>("iree-rocm-target-chip", targetChip,
                             llvm::cl::cat(category),
-                            llvm::cl::desc("ROCm target Chip"));
-    binder.opt<bool>("iree-rocm-link-bc", linkBitcode, llvm::cl::cat(category),
-                     llvm::cl::desc("Whether to try Linking to AMD Bitcodes"));
+                            llvm::cl::desc("ROCm target chip."));
     binder.opt<std::string>("iree-rocm-bc-dir", bitcodeDirectory,
                             llvm::cl::cat(category),
-                            llvm::cl::desc("Directory of ROCM Bitcode"));
+                            llvm::cl::desc("Directory of ROCM Bitcode."));
     binder.opt<int>("iree-rocm-waves-per-eu", wavesPerEu,
                     llvm::cl::cat(category),
                     llvm::cl::desc("Optimization hint specifying minimum "
-                                   "number of waves per execution unit"));
+                                   "number of waves per execution unit."));
     binder.opt<std::string>(
         "iree-rocm-enable-ukernels", enableROCMUkernels,
         llvm::cl::cat(category),
@@ -206,11 +204,12 @@
 
     addConfig("target_arch", b.getStringAttr(options.targetChip));
     addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels));
+    if (options.wavesPerEu > 0)
+      addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu));
 
     ArrayAttr mmaAttrs = getROCMSupportedMmaAttrs(context, options.targetChip);
-    if (mmaAttrs) {
+    if (mmaAttrs)
       addConfig("mma_intrinsics", mmaAttrs);
-    }
 
     return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
         b.getStringAttr("rocm"), b.getStringAttr("rocm-hsaco-fb"),
@@ -288,10 +287,11 @@
   LogicalResult serializeExecutable(const SerializationOptions &serOptions,
                                     IREE::HAL::ExecutableVariantOp variantOp,
                                     OpBuilder &executableBuilder) override {
-    // Perform the translation in a separate context to avoid any
-    // multi-threading issues.
-    llvm::LLVMContext context;
-    const std::string &bitcodeDirectory = options.getBitcodeDirectory();
+    ModuleOp innerModuleOp = variantOp.getInnerModule();
+    auto targetAttr = variantOp.getTargetAttr();
+    StringRef targetArch = options.targetChip;
+    if (auto attr = getConfigStringAttr(targetAttr, "target_arch"))
+      targetArch = attr->getValue();
 
     // We name our files after the executable name so that they are easy to
     // track both during compilation (logs/artifacts/etc), as outputs (final
@@ -300,43 +300,25 @@
     auto libraryName =
         variantOp->getParentOfType<IREE::HAL::ExecutableOp>().getName().str();
 
-    ModuleOp innerModuleOp = variantOp.getInnerModule();
-
-    auto llvmModule =
-        mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
-    if (!llvmModule) {
-      return variantOp.emitError() << "failed to translate the MLIR LLVM "
-                                      "dialect to the native llvm::Module";
-    }
-
     // Collect all the entry point names.
     SmallVector<IREE::HAL::ExecutableExportOp> exportOps;
     llvm::StringMap<IREE::HAL::ExecutableExportOp> exportOpMap;
-    for (auto op : variantOp.getExportOps()) {
-      exportOps.push_back(op);
-      exportOpMap[op.getSymName()] = op;
-    }
     std::vector<std::array<int32_t, 3>> workgroupSizes;
     SmallVector<uint32_t> workgroupLocalMemories;
     int32_t subgroupSize = 64;
-    StringRef subTarget = options.targetChip;
-    StringRef GFX9("gfx9");
-    for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
-      int32_t flatWgSize = 1;
-      auto *llvmFunc = llvmModule->getFunction(func.getName());
-      if (llvmFunc->isDeclaration())
-        continue;
+    for (auto exportOp : variantOp.getExportOps()) {
+      exportOps.push_back(exportOp);
+      exportOpMap[exportOp.getSymName()] = exportOp;
+
       std::array<int32_t, 3> workgroupSize;
-      auto exportOp = exportOpMap[func.getName()];
       if (std::optional<ArrayAttr> workgroupSizeAttr =
               exportOp.getWorkgroupSize()) {
-        for (auto it : llvm::enumerate(workgroupSizeAttr.value())) {
+        for (auto it : llvm::enumerate(workgroupSizeAttr.value()))
           workgroupSize[it.index()] = it.value().cast<IntegerAttr>().getInt();
-          flatWgSize *= it.value().cast<IntegerAttr>().getInt();
-        }
       } else {
         workgroupSize = {1, 1, 1};
       }
+      workgroupSizes.push_back(workgroupSize);
 
       if (auto setSubgroupSize = getSubgroupSize(exportOp)) {
         if (subgroupSize != 32 && subgroupSize != 64) {
@@ -346,132 +328,206 @@
         subgroupSize = *setSubgroupSize;
       }
 
-      int64_t wavesPerEu = options.wavesPerEu;
-      IREE::Codegen::TranslationInfoAttr translationInfo =
-          getTranslationInfo(exportOp);
-      if (auto translationConfig = translationInfo.getConfiguration()) {
-        if (auto attr = dyn_cast_or_null<IntegerAttr>(
-                translationConfig.get("amdgpu-waves-per-eu"))) {
-          wavesPerEu = attr.getValue().getSExtValue();
-        }
-      }
-
-      workgroupSizes.push_back(workgroupSize);
       uint32_t workgroupLocalMemory = 0;
       if (auto workgroupLocalMemoryAttr = exportOp.getWorkgroupLocalMemory()) {
         workgroupLocalMemory = workgroupLocalMemoryAttr->getSExtValue();
       }
       workgroupLocalMemories.push_back(workgroupLocalMemory);
-      // For GPU kernels,
-      // 1. Insert AMDGPU_KERNEL calling convention.
-      // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
-      // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on OpenCL
-      // and HIP kernels per Clang)
-      llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
-      std::string wgSizeRange = std::string("1, ") + std::to_string(flatWgSize);
-      llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
-      if (wavesPerEu > 0)
-        llvmFunc->addFnAttr("amdgpu-waves-per-eu", std::to_string(wavesPerEu));
-      if (subTarget.starts_with(GFX9))
-        addPreloadKernArgHint(llvmFunc);
     }
 
-    std::unique_ptr<llvm::TargetMachine> targetMachine;
-    {
-      llvm::Triple triple("amdgcn-amd-amdhsa");
-      std::string error;
-      const llvm::Target *target =
-          llvm::TargetRegistry::lookupTarget("", triple, error);
-      if (target == nullptr) {
-        return variantOp.emitError() << "cannot initialize target triple";
+    std::string targetHSACO;
+    if (variantOp.isExternal()) {
+      if (!variantOp.getObjects().has_value()) {
+        return variantOp.emitOpError()
+               << "no objects defined for external variant";
+      } else if (variantOp.getObjects()->getValue().size() != 1) {
+        // For now we assume there will be exactly one object file.
+        // In the future we will want to perform a linking step here and ideally
+        // support _also_ linking in the codegen results.
+        return variantOp.emitOpError() << "only one object reference is "
+                                          "supported for external variants";
       }
-      llvm::TargetOptions opt;
-      opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
-      opt.UnsafeFPMath = false;
-      opt.NoInfsFPMath = false;
-      opt.NoNaNsFPMath = true;
-      std::string features;
-      if (subTarget.starts_with(GFX9)) {
-        features = "+sramecc,-xnack";
+
+      // Read the HSACO from the object file.
+      auto objectAttr = llvm::cast<IREE::HAL::ExecutableObjectAttr>(
+          variantOp.getObjects()->getValue().front());
+      if (auto data = objectAttr.loadData()) {
+        targetHSACO = data.value();
       } else {
-        // GFX 10 or 11.
-        if (subgroupSize == 32)
-          features = "+wavefrontsize32";
-        if (subgroupSize == 64)
-          features = "+wavefrontsize64";
+        return variantOp.emitOpError()
+               << "object file could not be loaded: " << objectAttr;
+      }
+    } else {
+      // Perform the translation in a separate context to avoid any
+      // multi-threading issues.
+      llvm::LLVMContext context;
+
+      auto llvmModule =
+          mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
+      if (!llvmModule) {
+        return variantOp.emitError() << "failed to translate the MLIR LLVM "
+                                        "dialect to the native llvm::Module";
       }
 
-      targetMachine.reset(target->createTargetMachine(
-          triple.str(), options.targetChip, features, opt,
-          llvm::Reloc::Model::PIC_, std::nullopt,
-          llvm::CodeGenOptLevel::Aggressive));
+      for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
+        int32_t flatWgSize = 1;
+        auto *llvmFunc = llvmModule->getFunction(func.getName());
+        if (llvmFunc->isDeclaration())
+          continue;
+        auto exportOp = exportOpMap[func.getName()];
+        if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
+          for (auto it : llvm::enumerate(workgroupSizeAttr.value())) {
+            flatWgSize *= it.value().cast<IntegerAttr>().getInt();
+          }
+        }
 
-      if (targetMachine == nullptr) {
-        return variantOp.emitError() << "cannot initialize target machine";
+        // Try to get waves-per-eu from the export-specific translation info in
+        // cases where codegen decides to override the value.
+        // Otherwise, fallback to the default option.
+        int64_t wavesPerEu = 0;
+        IREE::Codegen::TranslationInfoAttr translationInfo =
+            getTranslationInfo(exportOp);
+        if (auto translationConfig = translationInfo.getConfiguration()) {
+          if (auto attr =
+                  translationConfig.getAs<IntegerAttr>("waves_per_eu")) {
+            wavesPerEu = attr.getValue().getSExtValue();
+          }
+        }
+        if (wavesPerEu == 0) {
+          if (auto attr = getConfigIntegerAttr(targetAttr, "waves_per_eu"))
+            wavesPerEu = attr->getValue().getSExtValue();
+        }
+
+        // For GPU kernels,
+        // 1. Insert AMDGPU_KERNEL calling convention.
+        // 2. Insert amdgpu-flat-workgroup-size(1, 256) attribute.
+        // 3. Insert amdgpu-implicitarg-num-bytes=56 (which must be set on
+        // OpenCL and HIP kernels per Clang)
+        llvmFunc->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
+        std::string wgSizeRange =
+            std::string("1, ") + std::to_string(flatWgSize);
+        llvmFunc->addFnAttr("amdgpu-flat-work-group-size", wgSizeRange);
+        if (wavesPerEu > 0) {
+          llvmFunc->addFnAttr("amdgpu-waves-per-eu",
+                              std::to_string(wavesPerEu));
+        }
+        if (targetArch.starts_with("gfx9"))
+          addPreloadKernArgHint(llvmFunc);
       }
-    }
 
-    llvmModule->setDataLayout(targetMachine->createDataLayout());
+      std::unique_ptr<llvm::TargetMachine> targetMachine;
+      {
+        llvm::Triple triple("amdgcn-amd-amdhsa");
+        std::string error;
+        const llvm::Target *target =
+            llvm::TargetRegistry::lookupTarget("", triple, error);
+        if (target == nullptr) {
+          return variantOp.emitError() << "cannot initialize target triple";
+        }
+        llvm::TargetOptions opt;
+        opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
+        opt.UnsafeFPMath = false;
+        opt.NoInfsFPMath = false;
+        opt.NoNaNsFPMath = true;
+        std::string features;
+        if (targetArch.starts_with("gfx9")) {
+          features = "+sramecc,-xnack";
+        } else {
+          // GFX 10 or 11.
+          if (subgroupSize == 32)
+            features = "+wavefrontsize32";
+          if (subgroupSize == 64)
+            features = "+wavefrontsize64";
+        }
 
-    for (llvm::Function &f : llvmModule->functions())
-      f.addFnAttr(llvm::Attribute::AlwaysInline);
+        targetMachine.reset(target->createTargetMachine(
+            triple.str(), targetArch, features, opt, llvm::Reloc::Model::PIC_,
+            std::nullopt, llvm::CodeGenOptLevel::Aggressive));
 
-    // Link user modules and libdevice (if required).
-    // Note that linking order matters:
-    llvm::Linker linker(*llvmModule);
-    if (failed(linkCmdlineBitcodeFiles(
-            variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc,
-            *targetMachine, llvmModule->getContext()))) {
-      return failure();
-    }
+        if (targetMachine == nullptr) {
+          return variantOp.emitError() << "cannot initialize target machine";
+        }
+      }
 
-    if (!options.enableROCMUkernels.empty() &&
-        options.enableROCMUkernels != "none") {
-      auto enabledUkernelsStr = StringRef(options.enableROCMUkernels);
-      if (failed(linkUkernelBCFiles(
-              variantOp.getLoc(), llvmModule.get(), enabledUkernelsStr,
-              options.targetChip, bitcodeDirectory,
-              llvm::Linker::OverrideFromSrc, *targetMachine)))
+      llvmModule->setDataLayout(targetMachine->createDataLayout());
+
+      for (llvm::Function &f : llvmModule->functions())
+        f.addFnAttr(llvm::Attribute::AlwaysInline);
+
+      // Link user-provided modules.
+      llvm::Linker linker(*llvmModule);
+      if (failed(linkCmdlineBitcodeFiles(
+              variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc,
+              *targetMachine, llvmModule->getContext()))) {
         return failure();
-    }
-    // Link module to Device Library
-    if (options.linkBitcode) {
+      }
+
+      // Link module to any enabled ukernels.
+      const std::string &bitcodeDirectory = options.getBitcodeDirectory();
+      StringRef enabledUkernels;
+      if (auto attr = getConfigStringAttr(targetAttr, "ukernels"))
+        enabledUkernels = attr->getValue();
+      if (!enabledUkernels.empty() && enabledUkernels != "none") {
+        if (failed(linkUkernelBitcodeFiles(
+                variantOp.getLoc(), llvmModule.get(), enabledUkernels,
+                targetArch, bitcodeDirectory, llvm::Linker::OverrideFromSrc,
+                *targetMachine))) {
+          return failure();
+        }
+      }
+
+      // Link module to HIP device library.
       if (bitcodeDirectory.empty()) {
         return variantOp.emitError()
                << "cannot find ROCM bitcode files. Check your installation "
-                  "consistency and in the worst case, set --iree-rocm-bc-dir= "
-                  "to an explicit location on your system.";
+                  "consistency and in the worst case, set "
+                  "--iree-rocm-bc-dir= to a path on your system.";
       }
-      if (failed(linkROCDLIfNecessary(variantOp.getLoc(), llvmModule.get(),
-                                      options.targetChip, bitcodeDirectory)))
+      if (failed(linkHIPBitcodeIfNeeded(variantOp.getLoc(), llvmModule.get(),
+                                        targetArch, bitcodeDirectory))) {
         return failure();
-    }
-    if (!serOptions.dumpIntermediatesPath.empty()) {
-      dumpModuleToPath(serOptions.dumpIntermediatesPath,
-                       serOptions.dumpBaseName, variantOp.getName(),
-                       ".linked.ll", *llvmModule);
-    }
-    // Add Optimize module
-    optimizeModule(*llvmModule, *targetMachine);
-    // Store optimized ll.
-    if (!serOptions.dumpIntermediatesPath.empty()) {
-      dumpModuleToPath(serOptions.dumpIntermediatesPath,
-                       serOptions.dumpBaseName, variantOp.getName(),
-                       ".optimized.ll", *llvmModule);
-    }
-    // Serialize hsaco kernel into the binary that we will embed in the
-    // final FlatBuffer.
-    std::unique_ptr<llvm::Module> moduleCopy;
-    if (!serOptions.dumpIntermediatesPath.empty()) {
-      moduleCopy = llvm::CloneModule(*llvmModule);
-      if (!moduleCopy)
-        llvm::errs() << "Error: cloning LLVM IR failed\n";
-    }
-    std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine);
-    std::string targetHSACO =
-        createHsaco(variantOp.getLoc(), targetObj, libraryName);
-    if (targetHSACO.empty()) {
-      return failure();
+      }
+
+      // Sets HIP platform globals based on the target architecture.
+      if (failed(setHIPGlobals(variantOp.getLoc(), llvmModule.get(),
+                               targetArch))) {
+        return failure();
+      }
+
+      if (!serOptions.dumpIntermediatesPath.empty()) {
+        dumpModuleToPath(serOptions.dumpIntermediatesPath,
+                         serOptions.dumpBaseName, variantOp.getName(),
+                         ".linked.ll", *llvmModule);
+      }
+
+      // Run LLVM optimization passes.
+      optimizeModule(*llvmModule, *targetMachine);
+      if (!serOptions.dumpIntermediatesPath.empty()) {
+        dumpModuleToPath(serOptions.dumpIntermediatesPath,
+                         serOptions.dumpBaseName, variantOp.getName(),
+                         ".optimized.ll", *llvmModule);
+      }
+
+      // Dump the assembly output.
+      if (!serOptions.dumpIntermediatesPath.empty()) {
+        auto moduleCopy = llvm::CloneModule(*llvmModule);
+        if (!moduleCopy) {
+          llvm::errs() << "Error: cloning LLVM IR failed\n";
+          return failure();
+        }
+        std::string targetISA =
+            translateModuleToISA(*moduleCopy.get(), *targetMachine);
+        dumpDataToPath(serOptions.dumpIntermediatesPath,
+                       serOptions.dumpBaseName, variantOp.getName(), ".rocmasm",
+                       targetISA);
+      }
+
+      // Serialize hsaco kernel into the binary that we will embed in the
+      // final FlatBuffer.
+      std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine);
+      targetHSACO = createHsaco(variantOp.getLoc(), targetObj, libraryName);
+      if (targetHSACO.empty())
+        return failure();
     }
 
     if (!serOptions.dumpBinariesPath.empty()) {
@@ -598,13 +654,6 @@
         variantOp.getTarget().getFormat(),
         builder.getBufferAttr(executableBuilder.getContext()));
 
-    if (!serOptions.dumpIntermediatesPath.empty()) {
-      std::string targetISA =
-          translateModuleToISA(*moduleCopy.get(), *targetMachine);
-      dumpDataToPath(serOptions.dumpIntermediatesPath, serOptions.dumpBaseName,
-                     variantOp.getName(), ".rocmasm", targetISA);
-    }
-
     return success();
   }
 
diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
index 12b7861..4663126 100644
--- a/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
+++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.cpp
@@ -14,6 +14,7 @@
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Linker/Linker.h"
 #include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/Path.h"
 #include "llvm/Support/Process.h"
 #include "llvm/Support/Program.h"
 #include "llvm/Support/SourceMgr.h"
@@ -24,33 +25,16 @@
 
 namespace mlir::iree_compiler::IREE::HAL {
 
-//===========Link LLVM Module to ROCDL Start===================/
-// Inspiration of code from this section comes from XLA Kernel Gen Project
-// https://github.com/openxla/xla/blob/main/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
-
-bool couldNeedDeviceBitcode(const llvm::Module &module) {
-  for (const llvm::Function &function : module.functions()) {
-    // The list of prefixes should be in sync with library functions used in
-    // target_util.cc.
-    if (!function.isIntrinsic() && function.isDeclaration() &&
-        (function.getName().starts_with("__ocml_") ||
-         function.getName().starts_with("__ockl_"))) {
-      return true;
-    }
-  }
-  return false;
-}
-
-std::unique_ptr<llvm::Module> loadIRModule(Location loc,
-                                           const std::string &filename,
-                                           llvm::LLVMContext *llvm_context) {
+static std::unique_ptr<llvm::Module>
+loadIRModule(Location loc, const std::string &filename,
+             llvm::LLVMContext *llvm_context) {
   llvm::SMDiagnostic diagnostic;
   std::unique_ptr<llvm::Module> module(
       llvm::parseIRFile(llvm::StringRef(filename.data(), filename.size()),
                         diagnostic, *llvm_context));
 
   if (!module) {
-    mlir::emitError(loc) << "error loading ROCDL LLVM module: "
+    mlir::emitError(loc) << "error loading HIP LLVM module: "
                          << diagnostic.getFilename().str() << ":"
                          << diagnostic.getLineNo() << ":"
                          << diagnostic.getColumnNo() << ": "
@@ -61,26 +45,27 @@
   return module;
 }
 
-LogicalResult
-linkWithBitcodeVector(Location loc, llvm::Module *module,
-                      const std::vector<std::string> &bitcode_path_vector) {
+static LogicalResult linkWithBitcodeFiles(Location loc, llvm::Module *module,
+                                          ArrayRef<std::string> bitcodePaths) {
+  if (bitcodePaths.empty())
+    return success();
   llvm::Linker linker(*module);
-
-  for (auto &bitcode_path : bitcode_path_vector) {
-    if (!(llvm::sys::fs::exists(bitcode_path)))
+  for (auto &bitcodePath : bitcodePaths) {
+    if (!llvm::sys::fs::exists(bitcodePath)) {
       return mlir::emitError(loc)
              << "AMD bitcode module is required by this module but was "
                 "not found at "
-             << bitcode_path;
-    std::unique_ptr<llvm::Module> bitcode_module =
-        loadIRModule(loc, bitcode_path, &module->getContext());
-    if (!bitcode_module)
+             << bitcodePath;
+    }
+    std::unique_ptr<llvm::Module> bitcodeModule =
+        loadIRModule(loc, bitcodePath, &module->getContext());
+    if (!bitcodeModule)
       return failure();
     // Ignore the data layout of the module we're importing. This avoids a
     // warning from the linker.
-    bitcode_module->setDataLayout(module->getDataLayout());
+    bitcodeModule->setDataLayout(module->getDataLayout());
     if (linker.linkInModule(
-            std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded,
+            std::move(bitcodeModule), llvm::Linker::Flags::LinkOnlyNeeded,
             [](llvm::Module &M, const llvm::StringSet<> &GVS) {
               llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
                 return !GV.hasName() || (GVS.count(GV.getName()) == 0);
@@ -92,10 +77,10 @@
   return success();
 }
 
-LogicalResult linkPathBitcodeFiles(Location loc, llvm::Linker &linker,
-                                   unsigned linkerFlags, StringRef path,
-                                   llvm::TargetMachine &targetMachine,
-                                   llvm::LLVMContext &context) {
+static LogicalResult linkBitcodeFile(Location loc, llvm::Linker &linker,
+                                     unsigned linkerFlags, StringRef path,
+                                     llvm::TargetMachine &targetMachine,
+                                     llvm::LLVMContext &context) {
   auto bitcodeBufferRef = llvm::MemoryBuffer::getFile(path);
   if (auto ec = bitcodeBufferRef.getError()) {
     return mlir::emitError(loc) << "failed reading user bitcode file `" << path
@@ -134,24 +119,9 @@
   return success();
 }
 
-static std::vector<std::string> getROCDLPaths(std::string targetChip,
-                                              std::string bitCodeDir) {
-  // AMDGPU bitcodes.
-  static const std::vector<std::string> rocdlFilenames({"ocml.bc", "ockl.bc"});
-
-  // Construct full path to ROCDL bitcode libraries.
-  std::vector<std::string> result;
-  std::string app = "/";
-  for (auto &filename : rocdlFilenames) {
-    result.push_back(bitCodeDir + app + filename);
-  }
-  return result;
-}
-
 static std::vector<std::string> getUkernelPaths(StringRef enabledUkernelsStr,
                                                 StringRef targetChip,
-                                                StringRef bitCodeDir) {
-  // AMD bitcodes.
+                                                StringRef bitcodePath) {
   std::vector<std::string> selectedUkernelNames;
   if (enabledUkernelsStr == "all") {
     const char *allUkernelNames[] = {"argmax"};
@@ -173,7 +143,7 @@
   for (auto &kernelName : selectedUkernelNames) {
     std::string filename =
         "rocm_" + kernelName + "_ukernel_" + targetChip.str();
-    result.push_back(bitCodeDir.str() + app + filename + ".bc");
+    result.push_back(bitcodePath.str() + app + filename + ".bc");
   }
   return result;
 }
@@ -191,13 +161,13 @@
       APInt(globalValue->getValueType()->getIntegerBitWidth(), newValue)));
 }
 
-static LogicalResult linkModuleWithGlobal(Location loc, llvm::Module *module,
-                                          std::string &targetChip) {
+LogicalResult setHIPGlobals(Location loc, llvm::Module *module,
+                            StringRef targetChip) {
   // Link target chip ISA version as global.
   const int kLenOfChipPrefix = 3;
-  std::string chipId = targetChip.substr(kLenOfChipPrefix);
+  auto chipId = targetChip.substr(kLenOfChipPrefix);
   // i.e gfx90a -> 9000 series.
-  int chipArch = stoi(chipId.substr(0, chipId.length() - 1)) * 100;
+  int chipArch = stoi(chipId.substr(0, chipId.size() - 1).str()) * 100;
   // Oldest GFX arch supported is gfx60x.
   if (chipArch < 6000)
     return failure();
@@ -207,8 +177,8 @@
   // Get chip code from suffix. i.e gfx1103 -> `3`.
   // gfx90a -> `a` == `10`.
   // gfx90c -> `c` == `12`.
-  std::string chipSuffix = chipId.substr(chipId.length() - 1);
-  uint32_t chipCode;
+  auto chipSuffix = chipId.substr(chipId.size() - 1);
+  uint32_t chipCode = 0;
   if (chipSuffix == "a") {
     chipCode = chipArch + 10;
   } else if (chipSuffix == "c") {
@@ -218,7 +188,7 @@
       return mlir::emitError(loc)
              << "error linking module with globals: unrecognized chip suffix '"
              << chipSuffix << "' for " << targetChip;
-    chipCode = chipArch + stoi(chipSuffix);
+    chipCode = chipArch + stoi(chipSuffix.str());
   }
   auto *int32Type = llvm::Type::getInt32Ty(module->getContext());
   overridePlatformGlobal(module, "__oclc_ISA_version", chipCode, int32Type);
@@ -235,30 +205,45 @@
     overridePlatformGlobal(module, globalParam.first, globalParam.second,
                            boolType);
   }
+
   return success();
 }
 
-// Links ROCm-Device-Libs into the given module if the module needs it.
-LogicalResult linkROCDLIfNecessary(Location loc, llvm::Module *module,
-                                   std::string targetChip,
-                                   std::string bitCodeDir) {
-  if (!couldNeedDeviceBitcode(*module))
-    return success();
-  if (!succeeded(HAL::linkWithBitcodeVector(
-          loc, module, getROCDLPaths(targetChip, bitCodeDir))))
-    return failure();
-  if (!succeeded(HAL::linkModuleWithGlobal(loc, module, targetChip)))
-    return failure();
-  return success();
+LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module,
+                                     StringRef targetChip,
+                                     StringRef bitcodePath) {
+  bool usesOCML = false;
+  bool usesOCKL = false;
+  for (const llvm::Function &function : module->functions()) {
+    if (!function.isIntrinsic() && function.isDeclaration()) {
+      auto functionName = function.getName();
+      if (functionName.starts_with("__ocml_"))
+        usesOCML = true;
+      else if (functionName.starts_with("__ockl_"))
+        usesOCKL = true;
+    }
+  }
+
+  // Link externally-provided bitcode files when used.
+  SmallVector<std::string> bitcodePaths;
+  if (usesOCML) {
+    bitcodePaths.push_back(
+        (bitcodePath + llvm::sys::path::get_separator() + "ocml.bc").str());
+  }
+  if (usesOCKL) {
+    bitcodePaths.push_back(
+        (bitcodePath + llvm::sys::path::get_separator() + "ockl.bc").str());
+  }
+  return linkWithBitcodeFiles(loc, module, bitcodePaths);
 }
 
-// Links optimized Ukernel bitcodes into the given module if the module needs
-// it.
-LogicalResult linkUkernelBCFiles(Location loc, llvm::Module *module,
-                                 StringRef enabledUkernelsStr,
-                                 StringRef targetChip, StringRef bitCodeDir,
-                                 unsigned linkerFlags,
-                                 llvm::TargetMachine &targetMachine) {
+// Links optimized Ukernel bitcode into the given module if the module needs it.
+LogicalResult linkUkernelBitcodeFiles(Location loc, llvm::Module *module,
+                                      StringRef enabledUkernelsStr,
+                                      StringRef targetChip,
+                                      StringRef bitcodePath,
+                                      unsigned linkerFlags,
+                                      llvm::TargetMachine &targetMachine) {
   // Early exit if Ukernel not supported on target chip.
   if (!iree_compiler::hasUkernelSupportedRocmArch(targetChip)) {
     return mlir::emitError(loc)
@@ -266,20 +251,17 @@
            << "' not supported on target chip: " << targetChip;
   }
   std::vector<std::string> ukernelPaths =
-      getUkernelPaths(enabledUkernelsStr, targetChip, bitCodeDir);
+      getUkernelPaths(enabledUkernelsStr, targetChip, bitcodePath);
   llvm::Linker linker(*module);
   for (auto &path : ukernelPaths) {
-    if (failed(linkPathBitcodeFiles(loc, linker, linkerFlags, StringRef(path),
-                                    targetMachine, module->getContext())))
+    if (failed(linkBitcodeFile(loc, linker, linkerFlags, StringRef(path),
+                               targetMachine, module->getContext())))
       return failure();
   }
 
   return success();
 }
 
-//===========Link LLVM Module to ROCDL End===================/
-
-//=====================Create HSACO Begin=============//
 // Link object file using lld lnker to generate code object
 // Inspiration from this section comes from LLVM-PROJECT-MLIR by
 // ROCmSoftwarePlatform
@@ -311,7 +293,6 @@
   llvm::FileRemover cleanupHsaco(tempHsacoFilename);
 
   // Invoke lld. Expect a true return value from lld.
-  // Searching for LLD
   const SmallVector<std::string> &toolNames{"iree-lld", "lld"};
   std::string lldProgram = findTool(toolNames);
   if (lldProgram.empty()) {
@@ -328,7 +309,7 @@
       tempHsacoFilename.str(),
   };
 
-  // Executing LLD
+  // Execute LLD.
   std::string errorMessage;
   int lldResult = llvm::sys::ExecuteAndWait(
       unescapeCommandLineComponent(lldProgram),
@@ -350,6 +331,5 @@
                        hsacoFile->getBuffer().end());
   return strHSACO;
 }
-//==============Create HSACO End=============//
 
 } // namespace mlir::iree_compiler::IREE::HAL
diff --git a/compiler/plugins/target/ROCM/ROCMTargetUtils.h b/compiler/plugins/target/ROCM/ROCMTargetUtils.h
index bfa4146..eb0dd78 100644
--- a/compiler/plugins/target/ROCM/ROCMTargetUtils.h
+++ b/compiler/plugins/target/ROCM/ROCMTargetUtils.h
@@ -13,17 +13,22 @@
 
 namespace mlir::iree_compiler::IREE::HAL {
 
-// Links LLVM module to ROC Device Library Bit Code
-LogicalResult linkROCDLIfNecessary(Location loc, llvm::Module *module,
-                                   std::string targetChip,
-                                   std::string bitCodeDir);
+// Sets HIP platform globals based on the target architecture.
+LogicalResult setHIPGlobals(Location loc, llvm::Module *module,
+                            StringRef targetChip);
+
+// Links HIP device bitcode if the module uses any symbols from it.
+LogicalResult linkHIPBitcodeIfNeeded(Location loc, llvm::Module *module,
+                                     StringRef targetChip,
+                                     StringRef bitcodePath);
 
 // Links optimized Ukernel module.
-LogicalResult linkUkernelBCFiles(Location loc, llvm::Module *module,
-                                 StringRef enabledUkernelsStr,
-                                 StringRef targetChip, StringRef bitCodeDir,
-                                 unsigned linkerFlags,
-                                 llvm::TargetMachine &targetMachine);
+LogicalResult linkUkernelBitcodeFiles(Location loc, llvm::Module *module,
+                                      StringRef enabledUkernelsStr,
+                                      StringRef targetChip,
+                                      StringRef bitcodePath,
+                                      unsigned linkerFlags,
+                                      llvm::TargetMachine &targetMachine);
 // Compiles ISAToHsaco Code
 std::string createHsaco(Location loc, const std::string isa, StringRef name);
 
diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
index 92135e5..7550141 100644
--- a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
@@ -3,22 +3,22 @@
 # Licensed under the Apache License v2.0 with LLVM Exceptions.
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-if (NOT IREE_TARGET_BACKEND_ROCM)
+if(NOT IREE_TARGET_BACKEND_ROCM)
   return()
 endif()
 
 # Check if HIP is installed on system.
 # HIP is required to compile ukernels.
 # TODO: We can do better than this and ensure that headers are always available.
-if (NOT IREE_ROCM_PATH)
+if(NOT IREE_ROCM_PATH)
   set(IREE_ROCM_PATH "/opt/rocm")
 endif()
-set (IREE_ROCM_VERSION "${IREE_ROCM_PATH}/.info/version")
-if (NOT EXISTS ${IREE_ROCM_VERSION})
+set(IREE_ROCM_VERSION "${IREE_ROCM_PATH}/include/hip/hip_version.h")
+if(NOT EXISTS ${IREE_ROCM_VERSION})
   message(STATUS
           "hip runtime cannot be found in ${IREE_ROCM_PATH}.
           Please try setting IREE_ROCM_PATH to rocm directory.
-          Ukernel will not be compiled.")
+          Ukernels will not be compiled.")
   return()
 endif()
 
@@ -27,8 +27,8 @@
 
 set(_platform_lib_reldir "iree_platform_libs/rocm")
 set(_device_bc_path "${IREE_COMPILER_DYLIB_DIR}/iree_platform_libs/rocm")
-set (_amd_ukernel_libs)
-set (_amd_ukernel_targets)
+set(_amd_ukernel_libs)
+set(_amd_ukernel_targets)
 function(iree_rocm_bitcode_library)
   cmake_parse_arguments(
     _RULE
@@ -45,9 +45,9 @@
   endif()
 
   set(_ROCM_ARCH "${_RULE_ROCM_ARCH}")
-  set (OPT_FLAG "-O0")
-  if (_ROCM_ARCH MATCHES "GFX9")
-    set (OPT_FLAG "-O3")
+  set(OPT_FLAG "-O0")
+  if(_ROCM_ARCH MATCHES "GFX9")
+    set(OPT_FLAG "-O3")
   endif()
   set(_COPTS
     "-x" "hip"
diff --git a/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json b/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json
index 03b7382..56ebe3c 100644
--- a/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json
+++ b/experimental/regression_suite/external_test_suite/config_gpu_rocm_rdna3.json
@@ -1,9 +1,8 @@
 {
   "config_name": "gpu_rocm_rdna3",
-  "iree_compile_flags" : [
+  "iree_compile_flags": [
     "--iree-hal-target-backends=rocm",
-    "--iree-rocm-target-chip=gfx1100",
-    "--iree-rocm-link-bc=true"
+    "--iree-rocm-target-chip=gfx1100"
   ],
   "iree_run_module_flags": [
     "--device=rocm"
@@ -784,7 +783,6 @@
     "test_unsqueeze_unsorted_axes",
     "test_upsample_nearest",
     "test_wrap_pad",
-
     // These pass on CPU but fail on ROCm.
   ],
   "expected_run_failures": [
@@ -904,7 +902,6 @@
     "test_shape_start_1_end_negative_1",
     "test_shape_start_negative_1",
     "test_where_long_example",
-
     // These pass on CPU but fail on ROCm.
     "test_gather_negative_indices",
     "test_reduce_l1_default_axes_keepdims_example_expanded",
diff --git a/experimental/regression_suite/tests/pregenerated/test_llama2.py b/experimental/regression_suite/tests/pregenerated/test_llama2.py
index 0ff2ca9..dcde908 100644
--- a/experimental/regression_suite/tests/pregenerated/test_llama2.py
+++ b/experimental/regression_suite/tests/pregenerated/test_llama2.py
@@ -103,7 +103,6 @@
         + [
             "--iree-hal-target-backends=rocm",
             "--iree-rocm-target-chip=gfx1100",
-            "--iree-rocm-link-bc=true",
         ],
     )
 
diff --git a/experimental/regression_suite/tests/pregenerated/test_ukernel.py b/experimental/regression_suite/tests/pregenerated/test_ukernel.py
index 22a5aa2..f96225d 100644
--- a/experimental/regression_suite/tests/pregenerated/test_ukernel.py
+++ b/experimental/regression_suite/tests/pregenerated/test_ukernel.py
@@ -43,7 +43,6 @@
         + [
             "--iree-hal-target-backends=rocm",
             "--iree-rocm-target-chip=gfx90a",
-            "--iree-rocm-link-bc=true",
             "--iree-rocm-enable-ukernels=argmax",
         ],
     )
@@ -58,7 +57,6 @@
         + [
             "--iree-hal-target-backends=rocm",
             "--iree-rocm-target-chip=gfx940",
-            "--iree-rocm-link-bc=true",
             "--iree-rocm-enable-ukernels=argmax",
         ],
     )
diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
index 795aa38..1f6b76b 100644
--- a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
+++ b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py
@@ -32,7 +32,7 @@
         print("*****************************", file=sys.stderr)
         self.build_configuration(
             os.path.join(THIS_DIR, "build", "cmake"),
-            extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=ROCM",),
+            extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=rocm",),
         )
         print("Target populated.", file=sys.stderr)
 
diff --git a/integrations/pjrt/src/CMakeLists.txt b/integrations/pjrt/src/CMakeLists.txt
index 6310d38..d0f4947 100644
--- a/integrations/pjrt/src/CMakeLists.txt
+++ b/integrations/pjrt/src/CMakeLists.txt
@@ -27,7 +27,7 @@
 if(IREE_HAL_DRIVER_CUDA)
   add_subdirectory(iree_pjrt/cuda)
 endif()
-if("ROCM" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
+if("rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS)
   add_subdirectory(iree_pjrt/rocm)
 endif()
 if(IREE_HAL_DRIVER_VULKAN)
diff --git a/samples/custom_dispatch/hip/CMakeLists.txt b/samples/custom_dispatch/hip/CMakeLists.txt
new file mode 100644
index 0000000..ae2678c
--- /dev/null
+++ b/samples/custom_dispatch/hip/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+iree_add_all_subdirs()
diff --git a/samples/custom_dispatch/hip/kernels/CMakeLists.txt b/samples/custom_dispatch/hip/kernels/CMakeLists.txt
new file mode 100644
index 0000000..7f68ee4
--- /dev/null
+++ b/samples/custom_dispatch/hip/kernels/CMakeLists.txt
@@ -0,0 +1,67 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+if((NOT IREE_TARGET_BACKEND_ROCM) OR
+   (NOT "rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS))
+  return()
+endif()
+
+if(NOT IREE_ROCM_PATH)
+  message(WARNING "IREE_ROCM_PATH not specified; cannot build sample")
+endif()
+
+# NOTE: this is not how one should actually build their HSACO files. Do not use
+# this as an authoritative source for compilation settings or CMake goo. If you
+# choose to go the route of custom HIP kernels you must bring your own build
+# infrastructure. This sample only demonstrates how to use compiled HSACO blobs
+# inside of the IREE compiler and this is the minimum amount of hacking that
+# could be done to do that.
+
+# Builds a HSACO blob using the clang built by IREE from tip-of-tree LLVM.
+function(hip_kernel_hsaco_clang _ARCH)
+  set(_NAME iree_samples_custom_dispatch_hip_kernels_hsaco_${_ARCH})
+  set(_HSACO_SRC_NAME "kernels.cu")
+  get_filename_component(_HSACO_SRC_BASENAME ${_HSACO_SRC_NAME} NAME_WE CACHE)
+  set(_HSACO_OBJ_NAME "${_HSACO_SRC_BASENAME}_${_ARCH}.co")
+  add_custom_command(
+    OUTPUT
+      ${_HSACO_OBJ_NAME}
+    DEPENDS
+      ${_HSACO_SRC_NAME}
+      ${IREE_CLANG_TARGET}
+    COMMAND ${IREE_CLANG_BINARY}
+      -x hip
+      --offload-device-only
+      --offload-arch=${_ARCH}
+      --rocm-path=${IREE_ROCM_PATH}
+      -fuse-cuid=none
+      -O3
+      ${CMAKE_CURRENT_SOURCE_DIR}/${_HSACO_SRC_NAME}
+      -o ${CMAKE_CURRENT_BINARY_DIR}/${_HSACO_OBJ_NAME}
+    VERBATIM
+  )
+  add_custom_target(${_NAME} DEPENDS
+    ${CMAKE_CURRENT_BINARY_DIR}/${_HSACO_OBJ_NAME}
+  )
+  add_dependencies(iree-sample-deps "${_NAME}")
+endfunction()
+
+# Build the kernels_*.co files for each architecture we target.
+hip_kernel_hsaco_clang(gfx1100)
+
+iree_lit_test_suite(
+  NAME
+    example
+  SRCS
+    "example.mlir"
+  TOOLS
+    FileCheck
+    iree-compile
+    iree-run-module
+  LABELS
+    "driver=rocm"
+    "hostonly"
+)
diff --git a/samples/custom_dispatch/hip/kernels/example.mlir b/samples/custom_dispatch/hip/kernels/example.mlir
new file mode 100644
index 0000000..8819d86
--- /dev/null
+++ b/samples/custom_dispatch/hip/kernels/example.mlir
@@ -0,0 +1,153 @@
+// RUN: iree-compile %s \
+// RUN:     --iree-hal-executable-object-search-path=$IREE_BINARY_DIR | \
+// RUN: iree-run-module \
+// RUN:     --device=rocm \
+// RUN:     --module=- \
+// RUN:     --function=mixed_invocation \
+// RUN:     --input=8xf32=2 \
+// RUN:     --input=8xf32=4 | \
+// RUN: FileCheck %s
+
+// The configurations used for executable compilation.
+// This lets the compiler and runtime know the format and requirements of the
+// executable binaries produced and multiple variants with differing formats
+// and compilation options (architectures, etc) can be embedded for runtime
+// selection.
+#rocm_gfx1100_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
+  target_arch = "gfx1100"
+}>
+
+// The target devices that the program will run on.
+// These can come from compiler flags and multiple targets can be supported
+// It's possible, for example, to support targeting multiple devices in the same
+// compiled binary.
+#rocm_target = #hal.device.target<"rocm", [
+  #rocm_gfx1100_target
+]>
+
+module @example attributes {hal.device.targets = [#rocm_target]} {
+
+  // Executable containing hand-authored kernels.
+  // Each executable can contain multiple exported functions and variants for
+  // different architectures or even devices. It's also possible to mix hand-
+  // authored functions with code generated ones even for the same functions
+  // such that code generation is used as a fallback when the hand-authored
+  // kernels aren't supported at runtime.
+  hal.executable.source private @executable attributes {
+    // Object files linked into the executable per-target.
+    // Certain backends (today) support either wholesale definition or linking
+    // of partial objects for imports used by generated code. Each compilation
+    // target can have its own unique set of objects to link in and the target
+    // keys can be generic. This allows for an object file to be linked in based
+    // only on the target triple while allowing for more specialized ones
+    // requiring certain CPU features to be only included when building those.
+    objects = #hal.executable.objects<{
+      #rocm_gfx1100_target = [
+        #hal.executable.object<{
+          // Referencing a file path on disk but could also have the data
+          // embedded in order to make the MLIR file hermetic/portable across
+          // compilation pipelines. In the future we'll likely use MLIR's
+          // external resource functionality for this. By allowing for the
+          // objects to be embedded we can support JIT scenarios where some
+          // layer higher or lower may be emitting the objects to link in as
+          // part of the overall compilation.
+          path = "samples/custom_dispatch/hip/kernels/kernels_gfx1100.co"
+        }>
+      ]
+    }>
+  } {
+
+    // TODO(benvanik): demonstrate hal.executable.constant.block for
+    // specialization via host logic. Maps to a read-only buffer passed into
+    // kernels. ROCM doesn't yet have these wired up.
+
+    // Exported function with the C name `simple_mul`.
+    // The ordinal must be assigned by the user and unique for the executable.
+    // The layout defines the required bindings and push constants and can be
+    // thought of as the function signature.
+    hal.executable.export public @simple_mul ordinal(0)
+        layout(#hal.pipeline.layout<push_constants = 1, sets = [
+          <0, bindings = [
+              <0, storage_buffer, ReadOnly>,
+              <1, storage_buffer, ReadOnly>,
+              <2, storage_buffer>
+          ]>
+        ]>) attributes {
+      // Certain backends (like ROCM) require a workgroup size (aka block
+      // size) to be defined ahead of time.
+      workgroup_size = [64 : index, 1 : index, 1 : index],
+      // Bindings are automatically inferred when possible as part of the ABI
+      // but can be overridden if the user wants to use features such as sparse
+      // bindings or multiple descriptor sets. To do so the
+      // `hal.interface.bindings` attribute can be added to a dispatch op as
+      // follows mapping tensor operands/results to the pipeline layout
+      // sets/bindings:
+      hal.interface.bindings = [
+        #hal.interface.binding<0, 0>,
+        #hal.interface.binding<0, 1>,
+        #hal.interface.binding<0, 2>
+      ]
+    } {
+    ^bb0(%device: !hal.device, %workload: index):
+      // This host function is used to compute the XYZ workgroup count
+      // dispatched at runtime. It can query the %device for capabilities
+      // and limits (shared memory size, etc). The other arguments are the
+      // values passed in the dispatch operation (usually things like root
+      // output op tensor dimensions and other abstract values).
+      %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+      %c1 = arith.constant 1 : index
+      hal.return %x, %c1, %c1 : index, index, index
+    }
+
+    // Similar to the above but in-place by using a read/write binding.
+    hal.executable.export public @simple_mul_inplace ordinal(1)
+        layout(#hal.pipeline.layout<push_constants = 1, sets = [
+          <0, bindings = [
+              <0, storage_buffer, ReadOnly>,
+              <1, storage_buffer>
+          ]>
+        ]>) attributes {
+      workgroup_size = [64 : index, 1 : index, 1 : index]
+    } {
+    ^bb0(%device: !hal.device, %workload: index):
+      %x = affine.apply affine_map<()[s0] -> (s0 ceildiv 64)>()[%workload]
+      %c1 = arith.constant 1 : index
+      hal.return %x, %c1, %c1 : index, index, index
+    }
+
+  }  // hal.executable.source
+
+  // Function demonstrating a few hand-authored dispatches mixed with codegen.
+  // Invoke with:
+  //  --device=rocm
+  //  --function=mixed_invocation
+  //  --input=8xf32=2
+  //  --input=8xf32=4
+  // CHECK-LABEL: EXEC @mixed_invocation
+  func.func @mixed_invocation(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+    // HACK: for hand-authored kernels all primitive values passed in need to
+    // be i32 or a bit-castable type. This is because ABI packing of other types
+    // happens inside of the PackDispatchOperandsPass that is currently not
+    // usable with external functions as it changes the ABI. In the future we
+    // can better define the ABI such that it's possible to match the compiler
+    // expectations around padding/alignment. For now users must do the packing
+    // themselves (splitting i64 into i32+i32, etc).
+    %c0 = arith.constant 0 : index
+    %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+    %dim_i32 = arith.index_cast %dim : index to i32
+
+    // Dispatch a basic `ret = lhs * rhs` kernel.
+    %0 = flow.dispatch @executable::@simple_mul[%dim](%dim_i32, %arg0, %arg1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> tensor<?xf32>{%dim}
+
+    // Code gen some other ops - these will interleave with the hand-authored
+    // ones but naturally won't be able to fuse with them.
+    %1 = arith.addf %0, %arg1 : tensor<?xf32>
+
+    // Dispatch an in-place `rhs *= lhs` kernel.
+    %2 = flow.dispatch @executable::@simple_mul_inplace[%dim](%dim_i32, %0, %1) : (i32, tensor<?xf32>{%dim}, tensor<?xf32>{%dim}) -> %1{%dim}
+
+    // CHECK: 8xf32=96 96 96 96 96 96 96 96
+    return %2 : tensor<?xf32>
+  }
+
+}  // module
diff --git a/samples/custom_dispatch/hip/kernels/kernels.cu b/samples/custom_dispatch/hip/kernels/kernels.cu
new file mode 100644
index 0000000..87e29ee
--- /dev/null
+++ b/samples/custom_dispatch/hip/kernels/kernels.cu
@@ -0,0 +1,74 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <hip/hip_runtime.h>
+
+// This minimal example just has some publicly exported (__global__) kernels.
+// It's possible with more build goo to include .cuh files and pull in any
+// HIP functions that do not involve host behavior (kernel launches/etc).
+//
+// NOTE: kernels must be exported with C naming (no C++ mangling) in order to
+// match the names used in the IR declarations.
+//
+// NOTE: arguments are packed as a dense list of
+// ([ordered bindings...], [push constants...]). If a binding is declared as
+// read-only the kernel must not write to it as it may be shared by other
+// invocations.
+//
+// NOTE: today all constants must be i32. If larger types are required there are
+// packing rules that must line up with compiler expectations - passed i64
+// values must be padded to natural 8-byte alignment, for example.
+//
+// NOTE: IREE ensures that all I/O buffers are legal to have the __restrict__
+// keyword defined (no aliasing is induced that is potentially unsafe). It's
+// still possible for users to do bad things but such is the case with native
+// HIP programming.
+//
+// NOTE: I/O buffer base pointers are likely to be nicely aligned (64B minimum
+// but usually larger) but the pointers passed in may be offset by any value
+// as they represent subranges of the underlying buffers. For example if the
+// user slices out elements 3 and 4 out of a 4xf32 tensor then the base buffer
+// pointer will be at +8B. In general if the input wasn't trying to be tricky
+// (bitcasting/etc) then natural alignment is guaranteed (an f32 tensor will
+// always have buffer pointers aligned to 4B).
+
+// `ret = lhs * rhs`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+//   <0, bindings = [
+//       <0, storage_buffer, ReadOnly>,
+//       <1, storage_buffer, ReadOnly>,
+//       <2, storage_buffer>
+//   ]>
+// ]>
+// workgroup_size = [64 : index, 1 : index, 1 : index]
+extern "C" __global__ void simple_mul(const float* __restrict__ binding0,
+                                      const float* __restrict__ binding1,
+                                      float* __restrict__ binding2, int dim) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid < dim) {
+    binding2[tid] = binding0[tid] * binding1[tid];
+  }
+}
+
+// `rhs *= lhs`
+//
+// Conforms to ABI:
+// #hal.pipeline.layout<push_constants = 1, sets = [
+//   <0, bindings = [
+//       <0, storage_buffer, ReadOnly>,
+//       <1, storage_buffer>
+//   ]>
+// ]>
+// workgroup_size = [64 : index, 1 : index, 1 : index]
+extern "C" __global__ void simple_mul_inplace(
+    const float* __restrict__ binding0, float* __restrict__ binding1, int dim) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid < dim) {
+    binding1[tid] *= binding0[tid];
+  }
+}