// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "ROCMTargetUtils.h"

#include <cstdint>

#include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.h"
#include "compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/Passes.h"
#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_bitcode.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/CodegenOptions.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
#include "iree/compiler/Dialect/HAL/Utils/ExecutableDebugInfoUtils.h"
#include "iree/compiler/Dialect/HAL/Utils/LLVMLinkerUtils.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/PluginAPI/Client.h"
#include "iree/compiler/Utils/EmbeddedDataDirectory.h"
#include "iree/compiler/Utils/FlatbufferUtils.h"
#include "iree/compiler/Utils/ToolUtils.h"
#include "iree/schemas/amdgpu_executable_def_builder.h"
#include "iree/schemas/hip_executable_def_builder.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"

namespace mlir::iree_compiler::IREE::HAL {
namespace {

// Container format for the serialized executable. Selected on the command line
// through --iree-rocm-container-type (see ROCMOptions::bindOptions).
enum class ContainerType {
  // Automatically detect the container type from the target ABI attribute.
  Auto,
  // HIP ExecutableDef flatbuffer.
  HIP,
  // AMDGPU ExecutableDef flatbuffer.
  AMDGPU,
  // Raw HSACO image (ELF).
  HSACO,
};

// Tri-state setting of a single AMDGPU target feature as parsed from a
// '+feature'/'-feature' list. `Any` doubles as the "not yet seen" marker used
// for duplicate detection while parsing.
enum class AMDGPUTargetFeatureMode {
  // Feature mode is not specified by the target ID.
  Any,
  // Feature is explicitly disabled.
  Off,
  // Feature is explicitly enabled.
  On,
};

// Parsed modes of the target features that may be set explicitly by the user.
// Every feature starts out as `Any`, i.e. not mentioned in the feature list.
struct AMDGPUTargetFeatureModes {
  // Requested mode of the SRAM ECC feature ("sramecc").
  AMDGPUTargetFeatureMode sramecc{AMDGPUTargetFeatureMode::Any};
  // Requested mode of the XNACK feature ("xnack").
  AMDGPUTargetFeatureMode xnack{AMDGPUTargetFeatureMode::Any};
};

// Stores |featureMode| into |targetFeatureMode| if the latter has not been set
// yet (is still `Any`). A slot that already holds `On`/`Off` means the feature
// was listed more than once; in that case an error naming |featureName| is
// emitted at |loc| and failure is returned without touching the slot.
static LogicalResult
setAMDGPUTargetFeatureMode(Location loc, StringRef featureName,
                           AMDGPUTargetFeatureMode featureMode,
                           AMDGPUTargetFeatureMode &targetFeatureMode) {
  const bool isStillUnset = targetFeatureMode == AMDGPUTargetFeatureMode::Any;
  if (isStillUnset) {
    targetFeatureMode = featureMode;
    return success();
  }
  return emitError(loc, "duplicate ROCM target feature '")
         << featureName << "'";
}

// Parses a comma-separated feature list such as "+sramecc,-xnack" into
// per-feature modes.
//
// Each entry must start with '+' (enable) or '-' (disable) and name either
// 'sramecc' or 'xnack'; each feature may appear at most once. Any violation
// emits an error at |loc| and yields failure. An empty list yields modes that
// are all `Any`.
static FailureOr<AMDGPUTargetFeatureModes>
parseAMDGPUTargetFeatureModes(Location loc, StringRef targetFeatures) {
  SmallVector<StringRef> entries;
  llvm::SplitString(targetFeatures, entries, ",");

  AMDGPUTargetFeatureModes parsed;
  for (StringRef entry : entries) {
    // Strip the mandatory sign prefix and remember what it requested.
    StringRef name = entry;
    AMDGPUTargetFeatureMode requested = AMDGPUTargetFeatureMode::Any;
    if (name.consume_front("+")) {
      requested = AMDGPUTargetFeatureMode::On;
    } else if (name.consume_front("-")) {
      requested = AMDGPUTargetFeatureMode::Off;
    }
    if (requested == AMDGPUTargetFeatureMode::Any) {
      emitError(loc, "ROCM target feature must be prefixed with '+' or '-'; "
                     "but seen '")
          << entry << "'";
      return failure();
    }

    // Resolve which mode slot the feature name refers to. We only support
    // these two features to be set explicitly. Features like wavefrontsize are
    // controlled and tuned by the compiler.
    AMDGPUTargetFeatureMode *slot = nullptr;
    if (name == "sramecc") {
      slot = &parsed.sramecc;
    } else if (name == "xnack") {
      slot = &parsed.xnack;
    }
    if (!slot) {
      emitError(loc,
                "ROCM target feature can only be 'sramecc' or 'xnack'; but "
                "seen '")
          << name << "'";
      return failure();
    }

    // Rejects a feature that was already given earlier in the list.
    if (failed(setAMDGPUTargetFeatureMode(loc, name, requested, *slot))) {
      return failure();
    }
  }
  return parsed;
}

// Appends ":<featureName>+" (mode On) or ":<featureName>-" (mode Off) to
// |targetID|. Leaves |targetID| untouched when the mode is `Any`.
static void appendAMDGPUTargetFeatureSuffix(std::string &targetID,
                                            StringRef featureName,
                                            AMDGPUTargetFeatureMode mode) {
  const char *sign = nullptr;
  if (mode == AMDGPUTargetFeatureMode::On) {
    sign = "+";
  } else if (mode == AMDGPUTargetFeatureMode::Off) {
    sign = "-";
  }
  if (!sign) {
    // Unspecified feature: contributes nothing to the target ID.
    return;
  }
  targetID += ":";
  targetID += featureName;
  targetID += sign;
}

// Builds a target ID string of the form "<arch>[:sramecc(+|-)][:xnack(+|-)]"
// from |targetArch| and the '+'/'-' prefixed |targetFeatures| list. Features
// that are not mentioned add no suffix. Returns failure (after a diagnostic
// was emitted at |loc| by the parser) if the feature list is malformed.
static FailureOr<std::string> buildAMDGPUTargetID(Location loc,
                                                  StringRef targetArch,
                                                  StringRef targetFeatures) {
  FailureOr<AMDGPUTargetFeatureModes> parsedModes =
      parseAMDGPUTargetFeatureModes(loc, targetFeatures);
  if (failed(parsedModes)) {
    return failure();
  }
  const AMDGPUTargetFeatureModes &featureModes = *parsedModes;

  // Suffixes are always appended in the fixed order sramecc, xnack.
  std::string result = targetArch.str();
  appendAMDGPUTargetFeatureSuffix(result, "sramecc", featureModes.sramecc);
  appendAMDGPUTargetFeatureSuffix(result, "xnack", featureModes.xnack);
  return result;
}

// Command-line configurable options of the ROCM target backend.
//
// Every field is the storage for one or two flags registered in bindOptions():
// the primary `--iree-rocm-*` flag and, for all options except the container
// type, a deprecated `--iree-hip-*` alias bound to the same field.
struct ROCMOptions {
  // Target name (--iree-rocm-target). Must be non-empty and recognized by
  // GPU::normalizeHIPTarget for verify() to succeed.
  std::string target = "";
  // Comma-separated '+'/'-' prefixed feature list
  // (--iree-rocm-target-features), e.g. "+sramecc,+xnack".
  std::string targetFeatures = "";
  // Serialized executable container type (--iree-rocm-container-type).
  ContainerType containerType = ContainerType::Auto;
  // Directory of ROCM bitcode (--iree-rocm-bc-dir); defaults to the platform
  // library directory looked up for "rocm".
  std::string bitcodeDirectory = getDefaultBitcodeDirectory();
  // Minimum waves-per-execution-unit hint (--iree-rocm-waves-per-eu).
  int wavesPerEu = 0;
  // Microkernel selection (--iree-rocm-enable-ukernels): `default`, `none`,
  // `all`, or a comma-separated list of microkernel names.
  std::string enableROCMUkernels = "none";
  // How encodings are resolved (--iree-rocm-encoding-layout-resolver).
  std::string encodingLayoutResolver =
      GPU::kDataTilingEncodingLayoutResolverName;
  // Enables SLP vectorization in LLVM opt (--iree-rocm-llvm-slp-vec).
  bool slpVectorization = true;
  // Enables LLVM global instruction selection (--iree-rocm-llvm-global-isel).
  bool globalISel = false;
  // Enables runtime specialization of dynamically shaped dispatches
  // (--iree-rocm-specialize-dispatches).
  bool specializeDispatches = false;
  // Enables MLIR-based ukernels (--iree-rocm-enable-tensor-ukernels).
  bool enableTensorUKernels = false;
  // Denormal handling mode for f32 math (--iree-rocm-denormal-fp-math-f32).
  IREE::Codegen::DenormalFpMath denormalFpMathF32 =
      IREE::Codegen::DenormalFpMath::None;
  // Reports register spilling (--iree-rocm-enable-register-spill-warning).
  bool enableRegSpillWarning = false;
  // Generates and embeds DWARF debug info (--iree-rocm-emit-debug-info).
  bool debugSymbols = false;

  // Registers all flags with |binder|, using the fields of this struct as
  // storage. All flags are placed in the "ROCM HAL Target" option category.
  void bindOptions(OptionsBinder &binder) {
    using namespace llvm;
    static cl::OptionCategory category("ROCM HAL Target");

    // Target selection.
    binder.opt<std::string>(
        "iree-rocm-target", target, cl::cat(category),
        cl::desc(
            // clang-format off
            "ROCM target as expected by LLVM AMDGPU backend; e.g., "
            "'gfx90a'/'gfx942' for targeting MI250/MI300 GPUs. "
            "Additionally this also supports architecture code names like "
            "'cdna3'/'rdna3' or some product names like 'mi300x'/'rtx7900xtx' "
            "for a better experience. See "
            "https://iree.dev/guides/deployment-configurations/gpu-rocm/ "
            "for more details."
            // clang-format on
            ));
    binder.opt<std::string>(
        "iree-hip-target", target, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-target instead."),
        Deprecated("use --iree-rocm-target instead"));

    // Target features.
    binder.opt<std::string>(
        "iree-rocm-target-features", targetFeatures, cl::cat(category),
        cl::desc("ROCM target features as expected by LLVM AMDGPU backend; "
                 "e.g., '+sramecc,+xnack'."));
    binder.opt<std::string>(
        "iree-hip-target-features", targetFeatures, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-target-features instead."),
        Deprecated("use --iree-rocm-target-features instead"));

    // Container type. This is the only option without a deprecated
    // --iree-hip-* alias.
    binder.opt<ContainerType>(
        "iree-rocm-container-type", containerType,
        llvm::cl::desc("Serialized executable container type."),
        llvm::cl::cat(category),
        llvm::cl::values(clEnumValN(ContainerType::Auto, "auto",
                                    "Automatically detect the container type "
                                    "from the target ABI attribute."),
                         clEnumValN(ContainerType::HIP, "hip",
                                    "HIP ExecutableDef flatbuffer."),
                         clEnumValN(ContainerType::AMDGPU, "amdgpu",
                                    "AMDGPU ExecutableDef flatbuffer."),
                         clEnumValN(ContainerType::HSACO, "hsaco",
                                    "Raw HSACO image (ELF).")));

    // Bitcode directory.
    binder.opt<std::string>("iree-rocm-bc-dir", bitcodeDirectory,
                            cl::cat(category),
                            cl::desc("Directory of ROCM Bitcode."));
    binder.opt<std::string>(
        "iree-hip-bc-dir", bitcodeDirectory, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-bc-dir instead."),
        Deprecated("use --iree-rocm-bc-dir instead"));

    // Waves-per-EU hint.
    binder.opt<int>("iree-rocm-waves-per-eu", wavesPerEu, cl::cat(category),
                    cl::desc("Optimization hint specifying minimum "
                             "number of waves per execution unit."));
    binder.opt<int>(
        "iree-hip-waves-per-eu", wavesPerEu, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-waves-per-eu instead."),
        Deprecated("use --iree-rocm-waves-per-eu instead"));

    // Microkernel selection.
    binder.opt<std::string>(
        "iree-rocm-enable-ukernels", enableROCMUkernels, cl::cat(category),
        cl::desc("Enables microkernels in the ROCM compiler backend. May be "
                 "`default`, `none`, `all`, or a comma-separated list of "
                 "specific unprefixed microkernels to enable, e.g. `mmt4d`."));
    binder.opt<std::string>(
        "iree-hip-enable-ukernels", enableROCMUkernels, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-enable-ukernels instead."),
        Deprecated("use --iree-rocm-enable-ukernels instead"));

    // Encoding layout resolver.
    binder.opt<std::string>(
        "iree-rocm-encoding-layout-resolver", encodingLayoutResolver,
        cl::cat(category),
        cl::desc("Selects the way that encodings will be "
                 "resolved. Options are: `none` (resolve to "
                 "identity layout), `pad` (additional padding "
                 "on allocations to maximize cache bandwidth), "
                 "and `data-tiling` (enable data tiled layouts)"));
    binder.opt<std::string>(
        "iree-hip-encoding-layout-resolver", encodingLayoutResolver,
        cl::cat(category),
        cl::desc(
            "Deprecated; use --iree-rocm-encoding-layout-resolver instead."),
        Deprecated("use --iree-rocm-encoding-layout-resolver instead"));

    // LLVM SLP vectorization.
    binder.opt<bool>("iree-rocm-llvm-slp-vec", slpVectorization,
                     cl::cat(category),
                     cl::desc("Enable slp vectorization in llvm opt."));
    binder.opt<bool>(
        "iree-hip-llvm-slp-vec", slpVectorization, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-llvm-slp-vec instead."),
        Deprecated("use --iree-rocm-llvm-slp-vec instead"));

    // LLVM global instruction selection.
    binder.opt<bool>("iree-rocm-llvm-global-isel", globalISel,
                     cl::cat(category),
                     cl::desc("Enable global instruction selection in llvm."));
    binder.opt<bool>(
        "iree-hip-llvm-global-isel", globalISel, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-llvm-global-isel instead."),
        Deprecated("use --iree-rocm-llvm-global-isel instead"));

    // Dispatch specialization.
    binder.opt<bool>(
        "iree-rocm-specialize-dispatches", specializeDispatches,
        cl::cat(category),
        cl::desc(
            "Enable runtime specialization of dynamically shaped dispatches."));
    binder.opt<bool>(
        "iree-hip-specialize-dispatches", specializeDispatches,
        cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-specialize-dispatches instead."),
        Deprecated("use --iree-rocm-specialize-dispatches instead"));

    // MLIR-based (tensor) ukernels.
    binder.opt<bool>("iree-rocm-enable-tensor-ukernels", enableTensorUKernels,
                     cl::cat(category),
                     cl::desc("Enable MLIR-based ukernels."));
    binder.opt<bool>(
        "iree-hip-enable-tensor-ukernels", enableTensorUKernels,
        cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-enable-tensor-ukernels instead."),
        Deprecated("use --iree-rocm-enable-tensor-ukernels instead"));

    // f32 denormal mode. Only the two non-default modes are selectable by
    // name; the field default (`None`) applies when the flag is absent.
    binder.opt<IREE::Codegen::DenormalFpMath>(
        "iree-rocm-denormal-fp-math-f32", denormalFpMathF32, cl::cat(category),
        cl::desc("Denormal floating point math mode for f32"),
        cl::values(
            clEnumValN(IREE::Codegen::DenormalFpMath::PreserveSign,
                       "preserve-sign",
                       "Convert denormals to zero while preserving sign"),
            clEnumValN(IREE::Codegen::DenormalFpMath::PositiveZero,
                       "positive-zero", "Convert denormals to positive zero")));
    binder.opt<IREE::Codegen::DenormalFpMath>(
        "iree-hip-denormal-fp-math-f32", denormalFpMathF32, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-denormal-fp-math-f32 instead."),
        Deprecated("use --iree-rocm-denormal-fp-math-f32 instead"),
        cl::values(
            clEnumValN(IREE::Codegen::DenormalFpMath::PreserveSign,
                       "preserve-sign",
                       "Convert denormals to zero while preserving sign"),
            clEnumValN(IREE::Codegen::DenormalFpMath::PositiveZero,
                       "positive-zero", "Convert denormals to positive zero")));

    // Register spill warning.
    binder.opt<bool>("iree-rocm-enable-register-spill-warning",
                     enableRegSpillWarning, cl::cat(category),
                     cl::desc("Report register spilling for AMD GPUs"));
    binder.opt<bool>(
        "iree-hip-enable-register-spill-warning", enableRegSpillWarning,
        cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-enable-register-spill-warning "
                 "instead."),
        Deprecated("use --iree-rocm-enable-register-spill-warning instead"));

    // Debug info emission.
    binder.opt<bool>("iree-rocm-emit-debug-info", debugSymbols,
                     cl::cat(category),
                     cl::desc("Generate and embed debug information (DWARF)."));
    binder.opt<bool>(
        "iree-hip-emit-debug-info", debugSymbols, cl::cat(category),
        cl::desc("Deprecated; use --iree-rocm-emit-debug-info instead."),
        Deprecated("use --iree-rocm-emit-debug-info instead"));
  }

  // Validates the target-related options. Emits an error at an unknown
  // location and returns failure when:
  //  - no target was provided,
  //  - the target is not recognized by GPU::normalizeHIPTarget, or
  //  - the target feature list does not parse.
  LogicalResult verify(mlir::Builder &builder) const {
    if (target.empty()) {
      return emitError(builder.getUnknownLoc())
             << "ROCM target not set; did you forget to pass "
                "'--iree-rocm-target'?";
    }
    if (GPU::normalizeHIPTarget(target).empty()) {
      return emitError(builder.getUnknownLoc(), "Unknown ROCM target '")
             << target << "'";
    }
    // Only the diagnostics matter here; the parsed modes are discarded.
    if (failed(parseAMDGPUTargetFeatureModes(builder.getUnknownLoc(),
                                             targetFeatures))) {
      return failure();
    }
    return success();
  }

private:
  // Returns the default value of |bitcodeDirectory|: the platform library
  // directory reported by findPlatformLibDirectory for "rocm".
  static std::string getDefaultBitcodeDirectory() {
    return mlir::iree_compiler::findPlatformLibDirectory("rocm");
  }
};

// Looks up the string-valued "abi" entry in the configuration dictionary of
// |targetAttr|. Yields an empty string when the attribute is null, has no
// configuration, or the configuration carries no string "abi" entry.
static StringRef getABI(IREE::HAL::ExecutableTargetAttr targetAttr) {
  if (!targetAttr) {
    return "";
  }
  auto config = targetAttr.getConfiguration();
  if (!config) {
    return "";
  }
  if (auto abiAttr = config.getAs<StringAttr>("abi")) {
    return abiAttr.getValue();
  }
  return "";
}

// Renders |module| as textual LLVM IR, preceded by the optional |header| text,
// and hands the resulting buffer to dumpDataToPath together with the
// |path|/|baseName|/|suffix|/|extension| naming components.
static void dumpModuleToPath(StringRef path, StringRef baseName,
                             StringRef suffix, StringRef extension,
                             llvm::Module &module, StringRef header = {}) {
  llvm::SmallVector<char, 0> buffer;
  {
    llvm::raw_svector_ostream os(buffer);
    os << header;
    module.print(os, /*AAW=*/nullptr);
  }
  const StringRef contents(buffer.data(), buffer.size());
  dumpDataToPath(path, baseName, suffix, extension, contents);
}

// Runs the code generation passes of |targetMachine| over |module| and returns
// the emitted object file contents as a byte string.
static std::string translateModuleToObj(llvm::Module &module,
                                        llvm::TargetMachine &targetMachine) {
  std::string objectBytes;
  {
    // The streams are scoped so that they are destroyed - and therefore
    // flushed into |objectBytes| - before the string is returned.
    llvm::raw_string_ostream rawStream(objectBytes);
    llvm::buffer_ostream bufferedStream(rawStream);
    llvm::legacy::PassManager emitPasses;
    targetMachine.addPassesToEmitFile(emitPasses, bufferedStream,
                                      /*DwoOut=*/nullptr,
                                      llvm::CodeGenFileType::ObjectFile);
    emitPasses.run(module);
  }
  return objectBytes;
}

// Runs the code generation passes of |targetMachine| over |module| and returns
// the emitted textual assembly (ISA listing).
static std::string translateModuleToISA(llvm::Module &module,
                                        llvm::TargetMachine &targetMachine) {
  std::string assemblyText;
  {
    // The streams are scoped so that they are destroyed - and therefore
    // flushed into |assemblyText| - before the string is returned.
    llvm::raw_string_ostream rawStream(assemblyText);
    llvm::buffer_ostream bufferedStream(rawStream);
    llvm::legacy::PassManager emitPasses;
    targetMachine.addPassesToEmitFile(emitPasses, bufferedStream,
                                      /*DwoOut=*/nullptr,
                                      llvm::CodeGenFileType::AssemblyFile);
    emitPasses.run(module);
  }
  return assemblyText;
}

// Inspects the AMDGPU kernel metadata embedded in the code object |obj| and
// emits one warning on |variantOp| for every kernel that reports a non-zero
// SGPR or VGPR spill count.
//
// This is a best-effort diagnostic: if the metadata cannot be read from |obj|
// nothing is reported. The returned llvm::Error is explicitly consumed in that
// case; merely testing it and letting it go out of scope would leave a failure
// value unhandled, which aborts in LLVM builds that check for unhandled
// errors.
static void checkRegisterSpilling(IREE::HAL::ExecutableVariantOp &variantOp,
                                  StringRef obj) {
  // Required out-parameter of the metadata reader; not used afterwards.
  uint16_t abiVersion = 0;
  llvm::StringMap<llvm::offloading::amdgpu::AMDGPUKernelMetaData> infoMap;

  if (llvm::Error err = llvm::offloading::amdgpu::getAMDGPUMetaDataFromImage(
          llvm::MemoryBufferRef(obj, ""), infoMap, abiVersion)) {
    llvm::consumeError(std::move(err));
    return;
  }

  for (const auto &[dispatchName, metaData] : infoMap) {
    if (metaData.SGPRSpillCount > 0 || metaData.VGPRSpillCount > 0) {
      emitWarning(variantOp.getLoc())
          << "Register spill: " << "VGPRSpillCount: "
          << metaData.VGPRSpillCount
          << " / SGPRSpillCount: " << metaData.SGPRSpillCount
          << " / Dispatch: " << dispatchName;
    }
  }
}

class ROCMTargetBackend final : public TargetBackend {
public:
  // Constructs the backend from the ROCM target |options| and the GPU codegen
  // options. |codegenOpts| is taken by value and moved into the
  // |codegenOptions| member; |targetOptions| is initialized from |options|.
  ROCMTargetBackend(const ROCMOptions &options, GPUCodegenOptions codegenOpts)
      : targetOptions(options), codegenOptions(std::move(codegenOpts)) {}

  // Returns the legacy default device ID of this backend, which is "hip".
  std::string getLegacyDefaultDeviceID() const final { return "hip"; }

  // Appends the executable target built for |deviceID| to
  // |executableTargetAttrs|. Nothing is appended when no target could be built
  // (getExecutableTarget returned null). |deviceConfigAttr| is not consulted.
  void getDefaultExecutableTargets(
      MLIRContext *context, StringRef deviceID, DictionaryAttr deviceConfigAttr,
      SmallVectorImpl<IREE::HAL::ExecutableTargetAttr> &executableTargetAttrs)
      const final {
    IREE::HAL::ExecutableTargetAttr defaultTarget =
        getExecutableTarget(deviceID, context);
    if (!defaultTarget) {
      return;
    }
    executableTargetAttrs.push_back(defaultTarget);
  }

  // Builds the `#hal.executable.target<"rocm", ...>` attribute for |deviceID|
  // from the backend's target options.
  //
  // Returns null if the options fail verification or, for the "amdgpu" device,
  // if the target ID cannot be built from the target features (diagnostics are
  // emitted at an unknown location in both cases).
  //
  // The target format is the AMDGPU target ID (e.g. "<arch>:sramecc+:xnack-")
  // for the "amdgpu" device and "rocm-hsaco-fb" for everything else. The
  // configuration dictionary records the ABI (|deviceID|), the GPU target
  // details, an optional encoding resolver, an optional default tuning spec,
  // the ukernel selection and the optional waves-per-eu/denormal settings.
  IREE::HAL::ExecutableTargetAttr
  getExecutableTarget(StringRef deviceID, MLIRContext *context) const {
    Builder b(context);
    SmallVector<NamedAttribute, 4> configItems;
    auto addConfig = [&](StringRef name, Attribute value) {
      configItems.emplace_back(name, value);
    };

    if (failed(targetOptions.verify(b))) {
      return nullptr;
    }

    addConfig("abi", b.getStringAttr(deviceID));
    std::string format;
    if (deviceID == "amdgpu") {
      FailureOr<std::string> targetID =
          buildAMDGPUTargetID(b.getUnknownLoc(), targetOptions.target,
                              targetOptions.targetFeatures);
      if (failed(targetID)) {
        return nullptr;
      }
      format = *targetID;
    } else {
      format = "rocm-hsaco-fb"; // legacy HIP
    }

    if (auto target = GPU::getHIPTargetDetails(
            targetOptions.target, targetOptions.targetFeatures, context)) {
      addConfigGPUTarget(context, target, configItems);
      if (targetOptions.encodingLayoutResolver !=
          GPU::kNoEncodingLayoutResolverName) {
        if (Attribute encoding = GPU::getHIPTargetEncodingLayoutAttr(
                target, targetOptions.encodingLayoutResolver)) {
          addConfig(IREE::Encoding::kEncodingResolverAttrName, encoding);
        }
      }

      // Look for a default tuning spec.
      auto rocmDialect = context->getOrLoadDialect<IREE::ROCM::ROCMDialect>();
      // First check for a spec based on the sku. The name is only accepted if
      // the dialect actually ships a builtin with that name; otherwise we must
      // fall through to the arch-based lookup below instead of referencing a
      // spec that does not exist.
      std::optional<std::string> maybeSpecName = std::nullopt;
      if (IREE::GPU::TargetChipAttr chip = target.getChip()) {
        if (StringAttr sku = chip.getSku()) {
          std::string specName =
              llvm::formatv("iree_default_tuning_spec_{}.mlir", sku.strref());
          if (rocmDialect->hasBuiltin(specName)) {
            maybeSpecName = specName;
          }
        }
      }

      // Then, if none was found, look for one based on the target arch.
      if (!maybeSpecName) {
        std::string specName =
            llvm::formatv("iree_default_tuning_spec_{}.mlir", target.getArch());
        if (rocmDialect->hasBuiltin(specName)) {
          maybeSpecName = specName;
        }
      }

      if (maybeSpecName) {
        addConfig("iree_codegen.default_tuning_spec",
                  IREE::ROCM::BuiltinTuningModuleAttr::get(
                      context, maybeSpecName.value()));
      }
    }

    // The raw ukernel selection string is always recorded; a ukernel provider
    // is only attached when ukernels are not disabled outright.
    addConfig("ukernels", b.getStringAttr(targetOptions.enableROCMUkernels));
    if (targetOptions.enableROCMUkernels != "none") {
      addConfig("iree_codegen.ukernel_provider",
                IREE::ROCM::UKernelProviderAttr::get(context));
    }
    if (targetOptions.wavesPerEu > 0) {
      addConfigWavesPerEu(b.getContext(), targetOptions.wavesPerEu,
                          configItems);
    }
    if (targetOptions.denormalFpMathF32 !=
        IREE::Codegen::DenormalFpMath::None) {
      addConfigDenormalFpMathF32(b.getContext(),
                                 targetOptions.denormalFpMathF32, configItems);
    }

    if (targetOptions.enableTensorUKernels) {
      addConfig(kUKernelProviderName,
                IREE::ROCM::TensorUKernelProviderAttr::get(context));
    }

    return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
        b.getStringAttr("rocm"), b.getStringAttr(format),
        b.getDictionaryAttr(configItems));
  }

  // Registers into |registry| the dialect translations and dialects this
  // backend depends on.
  void getDependentDialects(DialectRegistry &registry) const final {
    // Translations to LLVM IR: builtin, LLVM and ROCDL dialects.
    mlir::registerBuiltinDialectTranslation(registry);
    mlir::registerLLVMDialectTranslation(registry);
    mlir::registerROCDLDialectTranslation(registry);

    // Dialects inserted directly, in the listed order.
    registry.insert<IREE::Codegen::IREECodegenDialect,
                    IREE::VectorExt::IREEVectorExtDialect,
                    IREE::GPU::IREEGPUDialect, IREE::ROCM::ROCMDialect,
                    IREE::Util::UtilDialect, mlir::gpu::GPUDialect>();

    // Configuration may load and manipulate transform dialect libraries.
    registerTransformDialectTranslationDependentDialects(registry);
  }

  // Populates |passManager| with the configuration pipeline for |targetAttr|:
  //  1. optional builtin PDL pattern passes (one for dispatch specialization,
  //     one for tensor ukernels), each nested under its own ModuleOp nest,
  //  2. the codegen configuration pre-processing pipeline,
  //  3. the common LLVMGPU configuration pipeline, the optional PDL driver
  //     pass, tuning spec/user config materialization and lowering strategy
  //     selection,
  //  4. optional SMT constraint insertion/verification.
  void
  buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
                                 OpPassManager &passManager) final {
    // Adds one ApplyBuiltinPDLPatterns pass whose options are first adjusted by
    // |enableFeature| and then given the sku (if known) followed by the arch as
    // targets. Does nothing if |targetAttr| carries no GPU target.
    auto addBuiltinPDLPatternsPass = [&](auto &&enableFeature) {
      auto gpuTarget = getGPUTargetAttr(targetAttr.getContext(), targetAttr);
      if (!gpuTarget) {
        return;
      }
      ROCM::ApplyBuiltinPDLPatternsPassOptions pdlOptions;
      enableFeature(pdlOptions);
      if (IREE::GPU::TargetChipAttr chip = gpuTarget.getChip()) {
        if (StringAttr sku = chip.getSku()) {
          pdlOptions.targets.push_back(sku.str());
        }
      }
      pdlOptions.targets.push_back(gpuTarget.getArch().str());
      OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
      FunctionLikeNest(nestedModulePM).addPass([&]() {
        return ROCM::createApplyBuiltinPDLPatternsPass(pdlOptions);
      });
    };

    if (targetOptions.specializeDispatches) {
      addBuiltinPDLPatternsPass(
          [](ROCM::ApplyBuiltinPDLPatternsPassOptions &pdlOptions) {
            pdlOptions.enableSpecialization = true;
          });
    }
    if (targetOptions.enableTensorUKernels) {
      addBuiltinPDLPatternsPass(
          [](ROCM::ApplyBuiltinPDLPatternsPassOptions &pdlOptions) {
            pdlOptions.enableTensorUKernels = true;
          });
    }

    buildCodegenConfigurationPreProcessingPassPipeline(passManager);

    OpPassManager &configModulePM = passManager.nest<ModuleOp>();
    buildLLVMGPUCodegenCommonConfigurationPassPipeline(configModulePM);
    if (targetOptions.enableTensorUKernels) {
      configModulePM.addPass(
          IREE::ROCM::createApplyBuiltinPDLPatternsDriverPass());
    }

    MaterializeTuningSpecsPassOptions tuningSpecOptions;
    tuningSpecOptions.tuningSpecPath = codegenOptions.tuningSpecPath;
    configModulePM.addPass(createMaterializeTuningSpecsPass(tuningSpecOptions));

    configModulePM.addPass(createMaterializeUserConfigsPass());
    configModulePM.addPass(createLLVMGPUSelectLoweringStrategyPass(
        LLVMGPUSelectLoweringStrategyPassOptions{codegenOptions}));

    if (shouldEmitPipelineConstraints()) {
      FunctionLikeNest(configModulePM)
          .addPass(createInsertSMTConstraintsPass)
          .addPass(createVerifySMTConstraintsPass);
    }
  }

  // Populates |passManager| with the translation pipeline: the LLVMGPU codegen
  // pipeline nested on ModuleOp, followed by the codegen translation
  // post-processing pipeline. |targetAttr| is not consulted.
  void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
                                    OpPassManager &passManager) final {
    OpPassManager &codegenModulePM = passManager.nest<ModuleOp>();
    // NOTE(review): the literal `true` presumably selects the ROCM flavor of
    // the LLVMGPU pipeline - confirm against buildLLVMGPUCodegenPassPipeline.
    buildLLVMGPUCodegenPassPipeline(codegenModulePM, true,
                                    targetOptions.debugSymbols);
    buildCodegenTranslationPostProcessingPassPipeline(passManager);
  }

  // Populates |passManager| with the LLVMGPU linking pipeline, passing "rocm"
  // as its second argument.
  void buildLinkingPassPipeline(OpPassManager &passManager) final {
    buildLLVMGPULinkingPassPipeline(passManager, "rocm");
  }

  // Performs optimizations on |module| (including LTO-style whole-program
  // ones). Inspired by code section in
  // https://github.com/iree-org/iree/blob/main/compiler/plugins/target/CUDA/CUDATarget.cpp
  //
  // Runs LLVM's default per-module pipeline at O2 for |targetMachine|, with
  // verifier passes before and after it. |slpVectorization| toggles SLP
  // vectorization in the pipeline tuning options. A textual description of the
  // pipeline that is run is written to |outPassesString|.
  static void optimizeModule(llvm::Module &module,
                             llvm::TargetMachine &targetMachine,
                             bool slpVectorization,
                             std::string &outPassesString) {
    // Analysis managers for the new pass manager, one per IR unit kind.
    llvm::LoopAnalysisManager lam;
    llvm::FunctionAnalysisManager fam;
    llvm::CGSCCAnalysisManager cgam;
    llvm::ModuleAnalysisManager mam;

    // Registered before the PassBuilder registers its function analyses below.
    fam.registerPass([&] { return targetMachine.getTargetIRAnalysis(); });

    llvm::PipelineTuningOptions pto;
    pto.SLPVectorization = slpVectorization;

    llvm::PassInstrumentationCallbacks pic;

    // NOTE(review): the `false` argument presumably disables debug logging in
    // the standard instrumentations - confirm.
    llvm::StandardInstrumentations si(module.getContext(), false);
    si.registerCallbacks(pic, &mam);

    llvm::PassBuilder pb(&targetMachine, pto, std::nullopt, &pic);
    llvm::ModulePassManager mpm;
    pb.registerModuleAnalyses(mam);
    pb.registerCGSCCAnalyses(cgam);
    pb.registerFunctionAnalyses(fam);
    pb.registerLoopAnalyses(lam);
    pb.crossRegisterProxies(lam, fam, cgam, mam);

    llvm::OptimizationLevel ol = llvm::OptimizationLevel::O2;

    // Verify the module both before and after the optimization pipeline.
    mpm.addPass(llvm::VerifierPass());
    mpm.addPass(pb.buildPerModuleDefaultPipeline(ol));
    mpm.addPass(llvm::VerifierPass());
    // Record the pipeline, mapping class names to pass names where the
    // instrumentation callbacks know one and keeping the class name otherwise.
    llvm::raw_string_ostream os(outPassesString);
    mpm.printPipeline(os, [&pic](StringRef className) {
      auto passName = pic.getPassNameForClassName(className);
      return passName.empty() ? className : passName;
    });
    mpm.run(module, mam);
  }

  // Checks that |module| contains no unresolved external function that is
  // still in use. A function counts as unresolved when it is a declaration, is
  // not an intrinsic and has at least one use. On the first such function an
  // error naming it and printing one of its users is emitted on |variantOp|
  // and failure is returned.
  LogicalResult
  validateFinalizedModule(IREE::HAL::ExecutableVariantOp variantOp,
                          llvm::Module &module) {
    for (llvm::Function &func : module.functions()) {
      const bool isAcceptable =
          !func.isDeclaration() || func.isIntrinsic() || func.use_empty();
      if (isAcceptable) {
        continue;
      }
      llvm::User *firstUser = *func.user_begin();
      return variantOp.emitError()
             << "found an unresolved external function '" << func.getName()
             << "' in the final bitcode. A remaining live user is\n"
             << llvm::formatv("{}", *firstUser);
    }
    return success();
  }

  LogicalResult
  serializeExecutable(const SerializationOptions &serializationOptions,
                      IREE::HAL::ExecutableVariantOp variantOp,
                      OpBuilder &executableBuilder) final {
    ModuleOp innerModuleOp = variantOp.getInnerModule();
    auto targetAttr = variantOp.getTargetAttr();
    StringRef targetArch = targetOptions.target;
    StringRef targetFeatures = targetOptions.targetFeatures;
    if (auto attr = getGPUTargetAttr(variantOp.getContext(), targetAttr)) {
      targetArch = attr.getArch();
      targetFeatures = attr.getFeatures();
    }

    // We name our files after the executable name so that they are easy to
    // track both during compilation (logs/artifacts/etc), as outputs (final
    // intermediate code/binary files), and at runtime (loaded
    // libraries/symbols/etc).
    const std::string libraryName =
        variantOp->getParentOfType<IREE::HAL::ExecutableOp>().getName().str();

    // Collect all the entry point names.
    auto exportOps = llvm::to_vector_of<IREE::HAL::ExecutableExportOp>(
        variantOp.getExportOps());
    std::optional<uint32_t> subgroupSize;
    for (IREE::HAL::ExecutableExportOp exportOp : exportOps) {
      // TODO: put this either on the variant or propagate as a function
      // attribute instead - today this *must* be consistent across all exports
      // and it shouldn't need to be.
      if (auto setSubgroupSize = exportOp.getSubgroupSizeAsUInt()) {
        if (setSubgroupSize.value() != 32 && setSubgroupSize.value() != 64) {
          return variantOp.emitError()
                 << "invalid subgroup size " << setSubgroupSize.value();
        }
        if (subgroupSize.has_value() &&
            setSubgroupSize.value() != subgroupSize.value()) {
          return variantOp.emitError()
                 << "multiple exports with different subgroup sizes; this is a "
                    "limitation of the IREE compilation process and should be "
                    "fixed";
        }
        subgroupSize = setSubgroupSize.value();
      }
    }

    std::string targetHSACO;
    if (variantOp.isExternal()) {
      if (!variantOp.getObjects().has_value()) {
        return variantOp.emitOpError()
               << "no objects defined for external variant";
      }

      if (variantOp.getObjects()->getValue().size() != 1) {
        // For now we assume there will be exactly one object file.
        // In the future we will want to perform a linking step here and ideally
        // support _also_ linking in the codegen results.
        return variantOp.emitOpError() << "only one object reference is "
                                          "supported for external variants";
      }

      // Read the HSACO from the object file.
      auto objectAttr = cast<IREE::HAL::ExecutableObjectAttr>(
          variantOp.getObjects()->getValue().front());
      if (auto data = objectAttr.loadData()) {
        targetHSACO = data.value();
      } else {
        return variantOp.emitOpError()
               << "object file could not be loaded: " << objectAttr;
      }
    } else {
      auto maybeChipset = amdgpu::Chipset::parse(targetArch);
      if (failed(maybeChipset)) {
        return variantOp.emitOpError()
               << "could not parse AMDGPU chipset name '" << targetArch << "'";
      }
      amdgpu::Chipset chipset = *maybeChipset;
      // Perform the translation in a separate context to avoid any
      // multi-threading issues.
      llvm::LLVMContext context;
      std::unique_ptr<llvm::Module> llvmModule =
          mlir::translateModuleToLLVMIR(innerModuleOp, context, libraryName);
      if (!llvmModule) {
        return variantOp.emitError() << "failed to translate the MLIR LLVM "
                                        "dialect to the native llvm::Module";
      }

      for (auto func : innerModuleOp.getOps<LLVM::LLVMFuncOp>()) {
        llvm::Function *llvmFunc = llvmModule->getFunction(func.getName());
        if (llvmFunc->isDeclaration()) {
          continue;
        }

        // Override flags as given by target func attrs.
        if (auto funcAttrs =
                func->getAttrOfType<DictionaryAttr>("llvm_func_attrs")) {
          for (NamedAttribute funcAttr : funcAttrs) {
            auto value = dyn_cast<StringAttr>(funcAttr.getValue());
            if (!value) {
              return variantOp->emitError()
                     << "llvm_func_attrs attribute must be a dictionary of "
                        "strings. Attribute "
                     << funcAttr.getName() << " is not a StringAttr.";
            }
            llvmFunc->addFnAttr(funcAttr.getName(), value.getValue());
          }
        }
      }

      std::unique_ptr<llvm::TargetMachine> targetMachine;
      bool isWave64 = true;
      {
        llvm::Triple triple("amdgcn-amd-amdhsa");
        std::string error;
        const llvm::Target *target =
            llvm::TargetRegistry::lookupTarget("", triple, error);
        if (!target) {
          return variantOp.emitError() << "cannot initialize target triple";
        }
        llvm::TargetOptions opt;
        opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
        // Be extra cautious while this is less tested, and prevent unknown
        // fallbacks from global isel.
        //
        // When GlobalISelAbort is set, any failure of GlobalISel,
        // whether due to being not yet implemented or incorrect IR will result
        // in an immediate abortion of compilation. This disables the fallback
        // path of AMDGPUPassConfig::addInstSelector and 2 legacy passes which
        // might work around unimplemented cases or errors in GlobalISel
        // resulting in a successful compilation but would make one assume
        // results are with GlobalISel when they are not.
        opt.EnableGlobalISel = targetOptions.globalISel;
        opt.GlobalISelAbort = targetOptions.globalISel
                                  ? llvm::GlobalISelAbortMode::Enable
                                  : llvm::GlobalISelAbortMode::Disable;
        SmallVector<std::string> features;
        if (chipset.majorVersion >= 10 && chipset.majorVersion <= 12) {
          switch (subgroupSize.value_or(64)) {
          case 32:
            isWave64 = false;
            features.emplace_back("+wavefrontsize32");
            break;
          default:
          case 64:
            features.emplace_back("+wavefrontsize64");
            break;
          }
        }

        // Mixed precision fma instructions have complicated semantics on
        // gfx9+ GPUs and can lead to numeric issues as seen in
        // https://github.com/iree-org/iree/issues/18746 so we disable this
        // feature.
        if (targetArch.starts_with("gfx9")) {
          features.emplace_back("-fma-mix-insts");
        }

        if (!targetFeatures.empty()) {
          features.emplace_back(targetFeatures.str());
        }

        std::string featureStr = llvm::join(features, ",");

        targetMachine.reset(target->createTargetMachine(
            triple, targetArch, featureStr, opt, llvm::Reloc::Model::PIC_,
            std::nullopt, llvm::CodeGenOptLevel::Aggressive));

        if (!targetMachine) {
          return variantOp.emitError() << "cannot initialize target machine";
        }
      }

      llvmModule->setDataLayout(targetMachine->createDataLayout());

      // Code object version * 100.
      constexpr uint32_t abiVersion = 500;
      // Let the backend know what code object version we're compiling for. This
      // insulates us from changes to the default code object version that our
      // CI or users may not be prepared for.
      llvmModule->addModuleFlag(llvm::Module::Error,
                                "amdhsa_code_object_version", abiVersion);

      for (llvm::Function &f : llvmModule->functions()) {
        f.addFnAttr(llvm::Attribute::AlwaysInline);
      }

      // Link user-provided modules.
      llvm::Linker linker(*llvmModule);
      if (failed(linkCmdlineBitcodeFiles(
              variantOp.getLoc(), linker, llvm::Linker::OverrideFromSrc,
              *targetMachine, llvmModule->getContext()))) {
        return failure();
      }

      // Link bitcode (*.bc) object attrs specified by the input program.
      // Note that this happens after the command-line files so that the command
      // line ones override the symbols coming from the embedded files.
      auto specializationCallback = [&](llvm::Module &userModule) {
        // TODO: inject __nvvm_reflect-style functions/globals for
        // bitcode specialization based on the targetMachine and configuration.
        // These could use any information we have on the IREE side as well as
        // the TargetMachine.
      };
      unsigned linkerFlags =
          llvm::Linker::LinkOnlyNeeded | llvm::Linker::OverrideFromSrc;
      if (failed(linkBitcodeObjects(variantOp.getLoc(), linker, linkerFlags,
                                    *targetMachine, variantOp.getObjectsAttr(),
                                    llvmModule->getContext(),
                                    specializationCallback))) {
        return mlir::emitError(variantOp.getLoc())
               << "failed linking in user objects for target triple '"
               << targetArch.str() << "'";
      }

      // Link module to HIP device library.
      if (targetOptions.bitcodeDirectory.empty()) {
        return variantOp.emitError()
               << "cannot find ROCM bitcode files. Check your installation "
                  "consistency and in the worst case, set "
                  "--iree-rocm-bc-dir= to a path on your system.";
      }
      if (failed(linkHIPBitcodeIfNeeded(variantOp.getLoc(), llvmModule.get(),
                                        targetArch,
                                        targetOptions.bitcodeDirectory))) {
        return failure();
      }

      // Sets HIP platform globals based on the target architecture.
      if (failed(setHIPGlobals(variantOp.getLoc(), llvmModule.get(), chipset,
                               isWave64, abiVersion))) {
        return failure();
      }

      if (!serializationOptions.dumpIntermediatesPath.empty()) {
        dumpModuleToPath(serializationOptions.dumpIntermediatesPath,
                         serializationOptions.dumpBaseName, variantOp.getName(),
                         ".linked.ll", *llvmModule);
      }

      // For example 'gfx942'.
      StringRef targetCPU = targetMachine->getTargetCPU();

      // For example 'amdgcn-amd-amdhsa'.
      std::string targetTriple = targetMachine->getTargetTriple().str();

      // Run LLVM optimization passes.
      std::string passesString;
      optimizeModule(*llvmModule, *targetMachine,
                     targetOptions.slpVectorization, passesString);
      if (!serializationOptions.dumpIntermediatesPath.empty()) {
        // Additional context on '-mcpu' flag in PR comments, see for example:
        // https://github.com/iree-org/iree/pull/20716#issuecomment-2851650421
        std::string header = llvm::formatv(
            R"TXT(; To reproduce the .optimized.ll from the .linked.ll, run:
; opt -S -mtriple={} -mcpu={} --passes='{}' <.linked.ll>
; The flag '-S' is to emit LLVMIR.
; The behavior of some passes depends on '-mtriple' and '-mcpu'.

)TXT",
            targetTriple, targetCPU, passesString);

        dumpModuleToPath(serializationOptions.dumpIntermediatesPath,
                         serializationOptions.dumpBaseName, variantOp.getName(),
                         ".optimized.ll", *llvmModule, header);
      }

      if (failed(validateFinalizedModule(variantOp, *llvmModule))) {
        return failure();
      }

      // Dump the assembly output.
      if (!serializationOptions.dumpIntermediatesPath.empty()) {
        auto moduleCopy = llvm::CloneModule(*llvmModule);
        if (!moduleCopy) {
          llvm::errs() << "Error: cloning LLVM IR failed\n";
          return failure();
        }
        std::string asmHeader = llvm::formatv(
            R"TXT(; To reproduce the .rocmasm from .optimized.ll, run:
; llc -mtriple={} -mcpu={} -mattr='{}' -O3 <.optimized.ll> -o <out.rocmasm>

)TXT",
            targetTriple, targetCPU, targetMachine->getTargetFeatureString());

        std::string targetISA =
            translateModuleToISA(*moduleCopy.get(), *targetMachine);
        dumpDataToPath(serializationOptions.dumpIntermediatesPath,
                       serializationOptions.dumpBaseName, variantOp.getName(),
                       ".rocmasm", asmHeader + targetISA);
      }

      // Serialize hsaco kernel into the binary that we will embed in the
      // final FlatBuffer.
      std::string targetObj = translateModuleToObj(*llvmModule, *targetMachine);
      targetHSACO = createHsaco(variantOp.getLoc(), targetObj, libraryName);
      if (targetHSACO.empty()) {
        return failure();
      }

      if (targetOptions.enableRegSpillWarning) {
        checkRegisterSpilling(variantOp, targetObj);
      }
    }

    if (!serializationOptions.dumpBinariesPath.empty()) {
      dumpDataToPath(serializationOptions.dumpBinariesPath,
                     serializationOptions.dumpBaseName, variantOp.getName(),
                     ".hsaco", targetHSACO);
    }

    // Determine container type from the target ABI attribute.
    ContainerType containerType = targetOptions.containerType;
    if (containerType == ContainerType::Auto) {
      if (getABI(targetAttr) == "amdgpu") {
        containerType = ContainerType::AMDGPU;
      } else {
        containerType = ContainerType::HIP;
      }
    }

    // Wrap the HSACO ELF binary in the requested container type (if any).
    StringAttr executableBinaryFormat = variantOp.getTarget().getFormat();
    FailureOr<DenseIntElementsAttr> binaryContainer;
    switch (containerType) {
    case ContainerType::Auto: {
      // Resolved above; unreachable. Fall-through to error case.
      assert(false && "auto container type must have resolved earlier");
      break;
    }
    case ContainerType::AMDGPU: {
      FailureOr<std::string> targetID =
          buildAMDGPUTargetID(variantOp.getLoc(), targetArch, targetFeatures);
      if (failed(targetID)) {
        return failure();
      }
      executableBinaryFormat = executableBuilder.getStringAttr(*targetID);
      binaryContainer = serializeAMDGPUBinaryContainer(
          serializationOptions, variantOp, exportOps, *targetID, targetHSACO);
      break;
    }
    case ContainerType::HIP: {
      binaryContainer = serializeHIPBinaryContainer(
          serializationOptions, variantOp, exportOps, targetHSACO);
      break;
    }
    case ContainerType::HSACO: {
      SmallVector<uint8_t> image;
      image.resize(targetHSACO.size());
      std::memcpy(image.data(), targetHSACO.data(), image.size());
      binaryContainer = DenseIntElementsAttr::get(
          VectorType::get({static_cast<int64_t>(targetHSACO.size())},
                          executableBuilder.getI8Type()),
          image);
      break;
    }
    }
    if (failed(binaryContainer) || !binaryContainer.value()) {
      return failure();
    }

    // Add the binary data to the target executable.
    auto binaryOp = IREE::HAL::ExecutableBinaryOp::create(
        executableBuilder, variantOp.getLoc(), variantOp.getSymName(),
        executableBinaryFormat, binaryContainer.value());
    binaryOp.setMimeTypeAttr(
        executableBuilder.getStringAttr("application/x-flatbuffers"));

    return success();
  }

protected:
  // Packs the HSACO image and the per-export metadata into an
  // `iree_hal_amdgpu_ExecutableDef` flatbuffer and returns it as a
  // header-prefixed buffer attribute.
  //
  // Each export is written to the slot selected by its ordinal attribute and
  // records its `.kd`-suffixed symbol name, workgroup size, constant count,
  // per-binding flags and debug info reference. `targetID` is stored as the
  // executable's ISA string and `hsacoModule` as the single module image.
  //
  // Emits an error and returns failure when an export has no ordinal or when
  // its ordinal does not select a slot in [0, exportOps.size()).
  FailureOr<DenseIntElementsAttr> serializeAMDGPUBinaryContainer(
      const SerializationOptions &serializationOptions,
      IREE::HAL::ExecutableVariantOp variantOp,
      ArrayRef<IREE::HAL::ExecutableExportOp> exportOps, StringRef targetID,
      StringRef hsacoModule) {
    iree_compiler::FlatbufferBuilder builder;
    iree_hal_amdgpu_ExecutableDef_start_as_root(builder);

    // Attach embedded source file contents.
    auto sourceFilesRef = createSourceFilesVec(
        serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder);

    // Only a single module today.
    SmallVector<iree_hal_amdgpu_ModuleDef_ref_t> moduleRefs;
    {
      auto hsacoImageRef = flatbuffers_string_create(
          builder, hsacoModule.data(), hsacoModule.size());
      moduleRefs.push_back(
          iree_hal_amdgpu_ModuleDef_create(builder, hsacoImageRef));
    }
    auto modulesRef = builder.createOffsetVecDestructive(moduleRefs);

    // Generate optional per-export debug information.
    // May be empty if no debug information was requested.
    auto exportDebugInfos =
        createExportDefs(serializationOptions.debugLevel, exportOps, builder);

    // One slot per export, indexed by export ordinal (not by iteration order).
    SmallVector<iree_hal_amdgpu_ExportDef_ref_t> exportRefs;
    exportRefs.resize(exportOps.size(), 0);
    for (auto exportOp : exportOps) {
      auto ordinalAttr = exportOp.getOrdinalAttr();
      if (!ordinalAttr) {
        return mlir::emitError(exportOp.getLoc())
               << "could not compile rocm binary: export op is missing ordinal";
      }
      int64_t ordinal = ordinalAttr.getInt();
      // The ordinal comes from an attribute and is used as an index below;
      // reject values that would read or write outside the per-export tables.
      if (ordinal < 0 || ordinal >= static_cast<int64_t>(exportRefs.size())) {
        return mlir::emitError(exportOp.getLoc())
               << "could not compile rocm binary: export ordinal " << ordinal
               << " is out of range for " << exportRefs.size() << " exports";
      }

      // Symbol names include a `.kd` suffix as that's what HSA expects.
      auto symbolNameKd = (exportOp.getName() + ".kd").str();
      auto symbolNameRef = builder.createString(symbolNameKd);

      // Left zero-initialized when the export carries no workgroup size.
      iree_hal_amdgpu_Dims_t workgroupSize = {0};
      if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
        auto workgroupSizeDims = workgroupSizeAttr->getValue();
        workgroupSize.x = cast<IntegerAttr>(workgroupSizeDims[0]).getInt();
        workgroupSize.y = cast<IntegerAttr>(workgroupSizeDims[1]).getInt();
        workgroupSize.z = cast<IntegerAttr>(workgroupSizeDims[2]).getInt();
      }

      // Translate the HAL descriptor flags of each binding into the
      // flatbuffer's binding bits.
      auto layoutAttr = exportOp.getLayoutAttr();
      uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
      SmallVector<iree_hal_amdgpu_BindingBits_enum_t> bindingFlags;
      for (auto bindingAttr : layoutAttr.getBindings()) {
        iree_hal_amdgpu_BindingBits_enum_t flags = 0;
        if (allEnumBitsSet(bindingAttr.getFlags(),
                           IREE::HAL::DescriptorFlags::ReadOnly)) {
          flags |= iree_hal_amdgpu_BindingBits_READ_ONLY;
        }
        if (allEnumBitsSet(bindingAttr.getFlags(),
                           IREE::HAL::DescriptorFlags::Indirect)) {
          flags |= iree_hal_amdgpu_BindingBits_INDIRECT;
        }
        bindingFlags.push_back(flags);
      }
      auto bindingFlagsRef = iree_hal_amdgpu_BindingBits_vec_create(
          builder, bindingFlags.data(), bindingFlags.size());

      // All referenced strings/vectors exist at this point; build the table.
      iree_hal_amdgpu_ExportDef_start(builder);
      iree_hal_amdgpu_ExportDef_symbol_name_add(builder, symbolNameRef);
      iree_hal_amdgpu_ExportDef_workgroup_size_add(builder, &workgroupSize);
      iree_hal_amdgpu_ExportDef_constant_count_add(builder, constantCount);
      iree_hal_amdgpu_ExportDef_binding_flags_add(builder, bindingFlagsRef);
      iree_hal_amdgpu_ExportDef_debug_info_add(builder,
                                               exportDebugInfos[ordinal]);
      exportRefs[ordinal] = iree_hal_amdgpu_ExportDef_end(builder);
    }
    auto exportsRef = builder.createOffsetVecDestructive(exportRefs);

    auto isaRef = builder.createString(targetID);
    iree_hal_amdgpu_ExecutableDef_isa_add(builder, isaRef);
    iree_hal_amdgpu_ExecutableDef_exports_add(builder, exportsRef);
    iree_hal_amdgpu_ExecutableDef_modules_add(builder, modulesRef);
    iree_hal_amdgpu_ExecutableDef_source_files_add(builder, sourceFilesRef);
    iree_hal_amdgpu_ExecutableDef_end_as_root(builder);

    return builder.getHeaderPrefixedBufferAttr(
        variantOp.getContext(),
        /*magic=*/iree_hal_amdgpu_ExecutableDef_file_identifier,
        /*version=*/0);
  }

  // Packs the HSACO image and the per-export metadata into an
  // `iree_hal_hip_ExecutableDef` flatbuffer and returns it as a
  // header-prefixed buffer attribute.
  //
  // Each export is written to the slot selected by its ordinal attribute and
  // records its kernel name, block dimensions, constant count, per-binding
  // flags and debug info reference. All exports refer to module ordinal 0, the
  // single module holding `hsacoModule`.
  //
  // Emits an error and returns failure when an export has no ordinal or when
  // its ordinal does not select a slot in [0, exportOps.size()).
  FailureOr<DenseIntElementsAttr>
  serializeHIPBinaryContainer(const SerializationOptions &serializationOptions,
                              IREE::HAL::ExecutableVariantOp variantOp,
                              ArrayRef<IREE::HAL::ExecutableExportOp> exportOps,
                              StringRef hsacoModule) {
    iree_compiler::FlatbufferBuilder builder;
    iree_hal_hip_ExecutableDef_start_as_root(builder);

    // Attach embedded source file contents.
    auto sourceFilesRef = createSourceFilesVec(
        serializationOptions.debugLevel, variantOp.getSourcesAttr(), builder);

    // Only a single module today.
    SmallVector<iree_hal_hip_ModuleDef_ref_t> moduleRefs;
    {
      auto hsacoImageRef = flatbuffers_string_create(
          builder, hsacoModule.data(), hsacoModule.size());
      moduleRefs.push_back(
          iree_hal_hip_ModuleDef_create(builder, hsacoImageRef));
    }
    auto modulesRef = builder.createOffsetVecDestructive(moduleRefs);

    // Generate optional per-export debug information.
    // May be empty if no debug information was requested.
    auto exportDebugInfos =
        createExportDefs(serializationOptions.debugLevel, exportOps, builder);

    // One slot per export, indexed by export ordinal (not by iteration order).
    SmallVector<iree_hal_hip_ExportDef_ref_t> exportRefs;
    exportRefs.resize(exportOps.size(), 0);
    for (auto exportOp : exportOps) {
      auto ordinalAttr = exportOp.getOrdinalAttr();
      if (!ordinalAttr) {
        return mlir::emitError(exportOp.getLoc())
               << "could not compile rocm binary: export op is missing ordinal";
      }
      int64_t ordinal = ordinalAttr.getInt();
      // The ordinal comes from an attribute and is used as an index below;
      // reject values that would read or write outside the per-export tables.
      if (ordinal < 0 || ordinal >= static_cast<int64_t>(exportRefs.size())) {
        return mlir::emitError(exportOp.getLoc())
               << "could not compile rocm binary: export ordinal " << ordinal
               << " is out of range for " << exportRefs.size() << " exports";
      }

      auto kernelNameRef = builder.createString(exportOp.getName());

      // Left zero-initialized when the export carries no workgroup size.
      iree_hal_hip_BlockDims_t blockDims = {0};
      if (auto workgroupSizeAttr = exportOp.getWorkgroupSize()) {
        auto workgroupSize = workgroupSizeAttr->getValue();
        blockDims.x = cast<IntegerAttr>(workgroupSize[0]).getInt();
        blockDims.y = cast<IntegerAttr>(workgroupSize[1]).getInt();
        blockDims.z = cast<IntegerAttr>(workgroupSize[2]).getInt();
      }

      // Translate the HAL descriptor flags of each binding into the
      // flatbuffer's binding bits.
      auto layoutAttr = exportOp.getLayoutAttr();
      uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
      SmallVector<iree_hal_hip_BindingBits_enum_t> bindingFlags;
      for (auto bindingAttr : layoutAttr.getBindings()) {
        iree_hal_hip_BindingBits_enum_t flags = 0;
        if (allEnumBitsSet(bindingAttr.getFlags(),
                           IREE::HAL::DescriptorFlags::ReadOnly)) {
          flags |= iree_hal_hip_BindingBits_READ_ONLY;
        }
        if (allEnumBitsSet(bindingAttr.getFlags(),
                           IREE::HAL::DescriptorFlags::Indirect)) {
          flags |= iree_hal_hip_BindingBits_INDIRECT;
        }
        bindingFlags.push_back(flags);
      }
      auto bindingFlagsRef = iree_hal_hip_BindingBits_vec_create(
          builder, bindingFlags.data(), bindingFlags.size());

      // All referenced strings/vectors exist at this point; build the table.
      iree_hal_hip_ExportDef_start(builder);
      iree_hal_hip_ExportDef_module_ordinal_add(builder, 0); // always 0 today
      iree_hal_hip_ExportDef_kernel_name_add(builder, kernelNameRef);
      iree_hal_hip_ExportDef_block_dims_add(builder, &blockDims);
      iree_hal_hip_ExportDef_constant_count_add(builder, constantCount);
      iree_hal_hip_ExportDef_binding_flags_add(builder, bindingFlagsRef);
      iree_hal_hip_ExportDef_debug_info_add(builder, exportDebugInfos[ordinal]);
      exportRefs[ordinal] = iree_hal_hip_ExportDef_end(builder);
    }
    auto exportsRef = builder.createOffsetVecDestructive(exportRefs);

    iree_hal_hip_ExecutableDef_exports_add(builder, exportsRef);
    iree_hal_hip_ExecutableDef_modules_add(builder, modulesRef);
    iree_hal_hip_ExecutableDef_source_files_add(builder, sourceFilesRef);
    iree_hal_hip_ExecutableDef_end_as_root(builder);

    return builder.getHeaderPrefixedBufferAttr(
        variantOp.getContext(),
        /*magic=*/iree_hal_hip_ExecutableDef_file_identifier,
        /*version=*/0);
  }

private:
  // ROCM target options, held by reference: the referenced object is not owned
  // here and must outlive this instance.
  const ROCMOptions &targetOptions;
  // GPU codegen options, held by value (const) rather than by reference.
  const GPUCodegenOptions codegenOptions;
};

// Target device that describes itself as "amdgpu". Its default device target
// carries the executable targets reported by the "rocm" backend for "amdgpu".
class AMDGPUTargetDevice final : public TargetDevice {
public:
  AMDGPUTargetDevice(const ROCMOptions & /*options*/) {}

  // Builds the default `#hal.device.target<"amdgpu", ...>` attribute using
  // empty device and executable configuration dictionaries.
  IREE::HAL::DeviceTargetAttr
  getDefaultDeviceTarget(MLIRContext *context,
                         const TargetRegistry &targetRegistry) const final {
    Builder attrBuilder(context);
    auto emptyDeviceConfig = attrBuilder.getDictionaryAttr({});
    auto emptyExecutableConfig = attrBuilder.getDictionaryAttr({});

    // Only one target environment exists today; with several we would emit one
    // target attr per environment, each carrying its own environment
    // attribute.
    SmallVector<IREE::HAL::ExecutableTargetAttr> defaultTargets;
    auto rocmBackend = targetRegistry.getTargetBackend("rocm");
    rocmBackend->getDefaultExecutableTargets(context, "amdgpu",
                                             emptyExecutableConfig,
                                             defaultTargets);

    auto deviceID = attrBuilder.getStringAttr("amdgpu");
    return IREE::HAL::DeviceTargetAttr::get(context, deviceID,
                                            emptyDeviceConfig, defaultTargets);
  }
};

// Target device that describes itself as "hip". Its default device target
// carries the executable targets reported by the "rocm" backend for "hip".
class HIPTargetDevice final : public TargetDevice {
public:
  HIPTargetDevice(const ROCMOptions & /*options*/) {}

  // Builds the default `#hal.device.target<"hip", ...>` attribute using empty
  // device and executable configuration dictionaries.
  IREE::HAL::DeviceTargetAttr
  getDefaultDeviceTarget(MLIRContext *context,
                         const TargetRegistry &targetRegistry) const final {
    Builder attrBuilder(context);
    auto emptyDeviceConfig = attrBuilder.getDictionaryAttr({});
    auto emptyExecutableConfig = attrBuilder.getDictionaryAttr({});

    // Only one target environment exists today; with several we would emit one
    // target attr per environment, each carrying its own environment
    // attribute.
    SmallVector<IREE::HAL::ExecutableTargetAttr> defaultTargets;
    auto rocmBackend = targetRegistry.getTargetBackend("rocm");
    rocmBackend->getDefaultExecutableTargets(context, "hip",
                                             emptyExecutableConfig,
                                             defaultTargets);

    auto deviceID = attrBuilder.getStringAttr("hip");
    return IREE::HAL::DeviceTargetAttr::get(context, deviceID,
                                            emptyDeviceConfig, defaultTargets);
  }
};

// Plugin session for the ROCM HAL target. Registers the ROCM passes and
// dialect, exposes the "amdgpu" and "hip" target devices and the "rocm" target
// backend, and carries a GPUCodegenOptions instance next to the ROCMOptions
// held by the base session.
struct ROCMSession final
    : PluginSession<ROCMSession, ROCMOptions,
                    PluginActivationPolicy::DefaultActivated> {
  static void registerPasses() { IREE::ROCM::registerROCMTargetPasses(); }

  void onRegisterDialects(DialectRegistry &registry) final {
    registry.insert<IREE::ROCM::ROCMDialect>();
  }

  void populateHALTargetDevices(IREE::HAL::TargetDeviceList &targets) final {
    // #hal.device.target<"amdgpu", ...
    targets.add("amdgpu", [this]() {
      return std::make_shared<AMDGPUTargetDevice>(options);
    });
    // #hal.device.target<"hip", ...
    targets.add("hip", [this]() {
      return std::make_shared<HIPTargetDevice>(options);
    });
  }

  void populateHALTargetBackends(IREE::HAL::TargetBackendList &targets) final {
    // #hal.executable.target<"rocm", ...
    targets.add("rocm", [this]() {
      // The LLVM AMDGPU components are initialized before the backend that
      // uses them is constructed.
      LLVMInitializeAMDGPUTarget();
      LLVMInitializeAMDGPUTargetMC();
      LLVMInitializeAMDGPUTargetInfo();
      LLVMInitializeAMDGPUAsmParser();
      LLVMInitializeAMDGPUAsmPrinter();
      return std::make_shared<ROCMTargetBackend>(options, codegenOptions);
    });
  }

  // Registration override that additionally binds GPUCodegenOptions to each
  // session it creates.
  struct Registration : PluginSession::Registration {
    using PluginSession::Registration::Registration;

    // Runs the base CLI initialization, then records the flag-backed
    // GPUCodegenOptions instance for later session bootstrapping.
    void initializeCLI() final {
      PluginSession::Registration::initializeCLI();
      globalCLICodegenOptions = &GPUCodegenOptions::FromFlags::get();
    }

    // Creates a session whose target and codegen options start from the
    // global CLI values (when recorded) and are then bound to
    // `localOptionsBinder`.
    std::unique_ptr<AbstractPluginSession>
    createUninitializedSession(OptionsBinder &localOptionsBinder) final {
      auto session = std::make_unique<ROCMSession>();

      // Target options: seed from the global CLI if available, then bind.
      if (globalCLIOptions) {
        session->options = **globalCLIOptions;
      }
      session->options.bindOptions(localOptionsBinder);

      // Codegen options: seed from the global CLI if available, then bind.
      if (globalCLICodegenOptions) {
        session->codegenOptions = **globalCLICodegenOptions;
      }
      session->codegenOptions.bindOptions(localOptionsBinder);

      return session;
    }

    std::optional<GPUCodegenOptions *> globalCLICodegenOptions;
  };

  GPUCodegenOptions codegenOptions;
};

// Iterate over ukernel bitcode embedded-data files, and insert them into the
// EmbeddedDataDirectory singleton.
static void addAMDGPUUkernelBitcodeToGlobalEmbeddedDataDirectory() {
  EmbeddedDataDirectory::withGlobal([](EmbeddedDataDirectory &dir) {
    const iree_file_toc_t *toc = iree_uk_amdgpu_bitcode_create();
    for (size_t i = 0; i < iree_uk_amdgpu_bitcode_size(); ++i) {
      dir.addFile(toc[i].name, llvm::StringRef{toc[i].data, toc[i].size});
    }
  });
}

} // namespace
} // namespace mlir::iree_compiler::IREE::HAL

// C entry point for the ROCM HAL target plugin: registers the session type
// under the name "hal_target_rocm", then adds the AMDGPU ukernel bitcode files
// to the global embedded data directory. Always returns true.
extern "C" bool iree_register_compiler_plugin_hal_target_rocm(
    mlir::iree_compiler::PluginRegistrar *registrar) {
  namespace rocm_hal = mlir::iree_compiler::IREE::HAL;
  registrar->registerPlugin<rocm_hal::ROCMSession>("hal_target_rocm");
  rocm_hal::addAMDGPUUkernelBitcodeToGlobalEmbeddedDataDirectory();
  return true;
}

// Expands the IREE option-flags macro for ROCMOptions.
// NOTE(review): presumably this provides the flag-backed ROCMOptions instance
// used to seed new sessions from the global CLI — confirm against the macro
// definition.
IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::IREE::HAL::ROCMOptions);
