[ROCM] Update Ukernel infra to handle InnerTiledOp/Multi_MMA_MFMA (#21759)

This commit updates the ROCM ukernel infrastructure to handle ukernel lowering
for InnerTiledOp (multi_mma_mfma).
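
The lowering goes through the #rocm.ukernel_provider attribute's
createAndReplaceWithUkernelOp hook: an iree_codegen.inner_tiled op whose kind
is an #iree_gpu.data_tiled_mma_layout is rewritten into an
iree_codegen.ukernel.generic call, passing shared memory, the K extent, and
the intrinsic/subgroup counts as extra operands. The shared memory size is
obtained by interpreting a query function in the ukernel bitcode and cached
per (DataTiledMMAAttr, GPU target) pair. Schematically (a sketch distilled
from the lit test added below, not exact compiler output):

  %uk = iree_codegen.ukernel.generic "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
          ins(%a, %b) outs(%c)
          (%shared_memory, %shared_memory_bytes, %k, %intrinsics_m, %subgroups_m,
           %intrinsics_n, %subgroups_n, %intrinsics_k : ...)
          fn_def_attrs {vm.import.module = "rocm"}
          strided_dims([[], [], [4], []])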

Signed-off-by: Abhishek Varma <abhvarma@amd.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel
index da69584..ead5fdb 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel
@@ -60,10 +60,22 @@
         "//compiler/plugins/target/ROCM/builtins/specialization:iree_specialization_patterns_amdgpu",
         "//compiler/plugins/target/ROCM/builtins/tuning:iree_default_tuning_specs_amdgpu",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+        "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+        "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
+        "//compiler/src/iree/compiler/Codegen/Utils",
+        "//compiler/src/iree/compiler/Dialect/HAL/IR",
         "//compiler/src/iree/compiler/Dialect/Util/IR",
         "//compiler/src/iree/compiler/Utils",
+        "@llvm-project//llvm:ExecutionEngine",
+        "@llvm-project//llvm:IRReader",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:BufferizationDialect",
         "@llvm-project//mlir:DialectUtils",
+        "@llvm-project//mlir:GPUDialect",
+        "@llvm-project//mlir:GPUTransformOps",
+        "@llvm-project//mlir:GPUTransforms",
+        "@llvm-project//mlir:GPUUtils",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",
         "@llvm-project//mlir:Parser",
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt
index 1f8fafe..f6f893f 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt
@@ -28,12 +28,24 @@
   DEPS
     ::ROCMAttrs
     ::ROCMDialectGen
+    LLVMExecutionEngine
+    LLVMIRReader
     LLVMSupport
+    MLIRBufferizationDialect
+    MLIRGPUDialect
+    MLIRGPUTransformOps
+    MLIRGPUTransforms
+    MLIRGPUUtils
     MLIRIR
     MLIRLinalgDialect
     MLIRParser
     MLIRSupport
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+    iree::compiler::Codegen::Dialect::Codegen::Utils
+    iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+    iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
+    iree::compiler::Codegen::Utils
+    iree::compiler::Dialect::HAL::IR
     iree::compiler::Dialect::Util::IR
     iree::compiler::Utils
     iree::compiler::plugins::target::ROCM::builtins::mlir_ukernel::iree_mlir_ukernel_patterns_amdgpu
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp
index b2817a1..2421461 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp
@@ -6,9 +6,22 @@
 
 #include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.h"
 #include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMDialect.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/ExecutionEngine/GenericValue.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpDefinition.h"
@@ -34,10 +47,11 @@
 //===----------------------------------------------------------------------===//
 
 /// Utility function to help create and replace argmax linalg with a ukernel.
-static LogicalResult handleArgmaxUkernel(
-    RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
-    Operation *contextualOp, SmallVectorImpl<Value> &inputs,
-    SmallVectorImpl<Value> &outputs, SmallVectorImpl<Value> &otherOperands) {
+static LogicalResult
+handleArgmaxUkernel(RewriterBase &rewriter, StringRef name,
+                    DictionaryAttr targetConfiguration, Operation *contextualOp,
+                    ArrayRef<Value> inputs, ArrayRef<Value> outputs,
+                    SmallVectorImpl<Value> &otherOperands) {
   auto genericOp = dyn_cast<linalg::GenericOp>(contextualOp);
   if (!genericOp) {
     return rewriter.notifyMatchFailure(
@@ -79,16 +93,372 @@
   return success();
 }
 
+constexpr char executableObjectsAttrName[] = "hal.executable.objects";
+
+// Walks parent ops from `op` to return the nearest hal.executable.objects
+// array attribute. If the parent hal.executable.variant is reached, its objects
+// attribute is returned.
+// Adapted from ExecutableTargetAttr::lookup.
+static ArrayAttr lookUpExecutableObjects(Operation *op) {
+  MLIRContext *context = op->getContext();
+  auto attrId = StringAttr::get(context, executableObjectsAttrName);
+  while (op) {
+    // Take directly from the enclosing variant.
+    if (auto variantOp = dyn_cast<IREE::HAL::ExecutableVariantOp>(op)) {
+      if (std::optional<ArrayAttr> objects = variantOp.getObjects()) {
+        return *objects;
+      }
+    }
+    // Take from op attributes.
+    if (auto attr = op->getAttrOfType<ArrayAttr>(attrId)) {
+      return attr;
+    }
+    // Continue walk.
+    op = op->getParentOp();
+  }
+  return {};
+}
+
+static Value createSharedMemory(RewriterBase &rewriter, Location loc,
+                                int64_t sharedMemoryBytes) {
+  RankedTensorType tensorType =
+      RankedTensorType::get({sharedMemoryBytes}, rewriter.getI8Type());
+  ValueRange dynSizes{};
+  if (!sharedMemoryBytes) {
+    IREE::Codegen::NullPointerType nullPointerType =
+        IREE::Codegen::NullPointerType::get(rewriter.getContext());
+    return rewriter.create<IREE::Codegen::NullPointerOp>(loc, nullPointerType);
+  }
+  auto allocOp =
+      rewriter.create<bufferization::AllocTensorOp>(loc, tensorType, dynSizes);
+  Attribute sharedAddrSpace = gpu::AddressSpaceAttr::get(
+      rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+  allocOp.setMemorySpaceAttr(sharedAddrSpace);
+  return allocOp;
+}
+
+// Returns the index of the innermost CrossIntrinsic dimension of the C matrix,
+// if it is static, and std::nullopt if it is dynamic or if there are no
+// CrossIntrinsic dims.
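+//
+// For illustration: for the inner_tiled op in the lit test added below, with C
+// of type tensor<1x1x8x4x2x4x16x4xi32>, this returns 4 (the intrinsics_n = 2
+// dimension), matching the `[4]` entry of strided_dims in the CHECK lines.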
+static std::optional<unsigned>
+getCInnermostStaticCrossIntrinsicDim(IREE::Codegen::InnerTiledOp op) {
+  auto outputType = dyn_cast<ShapedType>(op.getResultTypes()[0]);
+  if (!outputType) {
+    return std::nullopt;
+  }
+  auto mma = cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+  IREE::Codegen::TileSwizzle accSwizzle =
+      getSwizzle(mma, IREE::GPU::MMAFragment::Acc);
+  SmallVector<IREE::Codegen::TileSwizzle::Dim> swizzleDims;
+  for (IREE::Codegen::TileSwizzle::ExpandShapeDimVectorType group :
+       accSwizzle.expandShape) {
+    swizzleDims.append(group);
+  }
+  applyPermutationToVector(swizzleDims, accSwizzle.permutation);
+  int rankDiff = outputType.getRank() - swizzleDims.size();
+  auto crossIntrinsic = IREE::Codegen::TileSwizzle::Dim::Kind::CrossIntrinsic;
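+  // Walk the (permuted) swizzle dims from innermost to outermost and return
+  // the first CrossIntrinsic dim found; the unsigned wrap-around of
+  // `swizzleIdx` past zero terminates the loop after index 0.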
+  for (size_t e = swizzleDims.size(), swizzleIdx = e - 1; swizzleIdx < e;
+       --swizzleIdx) {
+    if (swizzleDims[swizzleIdx].kind != crossIntrinsic) {
+      continue;
+    }
+    int outputIdx = swizzleIdx + rankDiff;
+    if (outputType.isDynamicDim(outputIdx)) {
+      return std::nullopt;
+    }
+    return outputIdx;
+  }
+  return std::nullopt;
+}
+
+static int64_t getSharedMemoryBytes(IREE::GPU::TargetAttr gpuTarget) {
+  if (!gpuTarget) {
+    return 0;
+  }
+  IREE::GPU::TargetWgpAttr wgp = gpuTarget.getWgp();
+  if (!wgp) {
+    return 0;
+  }
+  return wgp.getMaxWorkgroupMemoryBytes();
+}
+
+// Returns an ExecutableObjectAttr carrying the bitcode for the given ukernel.
+//
+// First tries finding the bitcode in the input `sourceExecutableObjects`, which
+// must be an array of ExecutableObjectAttrs and typically comes from a
+// hal.executable.objects array attribute in the source IR; this is the
+// mechanism by which source programs may provide their own ukernel bitcode.
+//
+// If no matching bitcode is found in `sourceExecutableObjects`, this function
+// then searches the bitcode files that we have embedded as static data.
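+//
+// As an illustration only (not tied to a particular test), user-provided
+// bitcode would be attached in the source IR along these lines, with `path`
+// matching the filename computed by getBitcodeFilename:
+//
+//   hal.executable.objects = [
+//     #hal.executable.object<{
+//       path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
+//       data = dense<...> : vector<...xi8>}>
+//   ]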
+static IREE::HAL::ExecutableObjectAttr
+getUKernelBitcode(MLIRContext *context,
+                  IREE::HAL::ExecutableTargetAttr execTarget,
+                  ArrayAttr sourceExecutableObjects, StringRef filename) {
+  // Early-return if the source executable.objects already contain an object
+  // with the expected file name. This happens with user-provided bitcode in the
+  // source IR.
+  if (sourceExecutableObjects) {
+    for (Attribute a : sourceExecutableObjects) {
+      if (auto object = dyn_cast<IREE::HAL::ExecutableObjectAttr>(a)) {
+        if (object.getPath() == filename) {
+          return object;
+        }
+      }
+    }
+  }
+
+  // No user-provided bitcode, so we search our embedded bitcode files in the
+  // EmbeddedDataDirectory singleton.
+  std::optional<StringRef> bitcode;
+  EmbeddedDataDirectory::withGlobal(
+      [&](EmbeddedDataDirectory &dir) { bitcode = dir.getFile(filename); });
+  if (!bitcode) {
+    return {};
+  }
+  AsmResourceBlob blob = HeapAsmResourceBlob::allocateAndCopyInferAlign(
+      ArrayRef<char>(bitcode->data(), bitcode->size()));
+  auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get(
+      VectorType::get({static_cast<int64_t>(bitcode->size())},
+                      IntegerType::get(context, 8)),
+      filename, std::move(blob));
+  return IREE::HAL::ExecutableObjectAttr::get(
+      context, StringAttr::get(context, filename),
+      cast<IREE::Util::SerializableAttrInterface>(bitcodeDenseAttr));
+}
+
+static std::string getBitcodeFilename(IREE::GPU::TargetAttr gpuTarget,
+                                      StringRef name) {
+  return llvm::formatv("{}.{}.bc", name, gpuTarget.getArch());
+}
+
+// Helper for getSharedMemoryBytes. Typical latency: 2 ms.
+// Evaluates the shared memory size required by the multi_mma microkernel by
+// interpreting a bitcode function with a specific name.
+// On failure, an op warning is emitted and {} is returned.
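+//
+// The query function's shape, for illustration only (the exact signature lives
+// with the ukernel bitcode; parameter names here are merely descriptive):
+//
+//   int32_t <ukernel_name>_query_shared_memory_bytes(
+//       int32_t intrinsics_m, int32_t subgroups_m, int32_t intrinsics_n,
+//       int32_t subgroups_n, int32_t intrinsics_k);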
+static std::optional<int64_t> expensivelyEvaluateSharedMemoryBytes(
+    MLIRContext *context, IREE::Codegen::InnerTiledOp op, StringRef ukernelName,
+    IREE::GPU::TargetAttr gpuTarget) {
+  auto target = IREE::HAL::ExecutableTargetAttr::lookup(op);
+  std::string filename = getBitcodeFilename(gpuTarget, ukernelName);
+  ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op);
+  IREE::HAL::ExecutableObjectAttr bitcodeObject =
+      getUKernelBitcode(context, target, sourceExecutableObjects, filename);
+  if (!bitcodeObject) {
+    op.emitWarning("Failed to find ukernel bitcode.");
+    return {};
+  }
+
+  auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+
+  IREE::Util::SerializableAttrInterface bitcodeData = bitcodeObject.getData();
+  std::string buffer;
+  buffer.resize(bitcodeData.getStorageSize());
+  if (failed(bitcodeObject.getData().serializeToBuffer(
+          op->getLoc(), llvm::endianness::native,
+          ArrayRef<char>{buffer.data(), buffer.size()}))) {
+    op.emitWarning("Failed to serialize bitcode.");
+    return {};
+  }
+  llvm::LLVMContext llvmContext;
+  llvm::Expected<std::unique_ptr<llvm::Module>> moduleOp =
+      llvm::getLazyBitcodeModule(llvm::MemoryBufferRef{buffer, ukernelName},
+                                 llvmContext,
+                                 /*ShouldLazyLoadMetadata=*/true);
+  if (!moduleOp) {
+    op.emitWarning("Failed to parse bitcode module.");
+    return {};
+  }
+  llvm::EngineBuilder builder(std::move(moduleOp.get()));
+  std::string builderError;
+  builder.setEngineKind(llvm::EngineKind::Interpreter)
+      .setErrorStr(&builderError);
+  std::unique_ptr<llvm::ExecutionEngine> interpreter{builder.create()};
+  if (!interpreter) {
+    op.emitWarning("Failed to create the interpreter.");
+    return {};
+  }
+  std::string queryFuncName =
+      llvm::formatv("{}_query_shared_memory_bytes", ukernelName);
+  llvm::Function *func = interpreter->FindFunctionNamed(queryFuncName);
+  if (!func) {
+    op.emitWarning(llvm::formatv(
+        "Bitcode does not contain a function named {}.", queryFuncName));
+    return {};
+  }
+  auto constI32 = [](int32_t val) {
+    llvm::GenericValue v;
+    v.IntVal = APInt(32, val);
+    return v;
+  };
+  llvm::GenericValue args[] = {
+      constI32(mma.getIntrinsicsM()), constI32(mma.getSubgroupsM()),
+      constI32(mma.getIntrinsicsN()), constI32(mma.getSubgroupsN()),
+      constI32(mma.getIntrinsicsK())};
+  if (func->arg_size() != /*total elements in 'args'=*/5) {
+    op.emitWarning(llvm::formatv(
+        "Bitcode function {} takes {} arguments. Expected {}.", queryFuncName,
+        func->arg_size(), /*total elements in 'args'=*/5));
+    return {};
+  }
+  llvm::GenericValue interpreterResult = interpreter->runFunction(func, args);
+  if (interpreter->hasError()) {
+    op.emitWarning(llvm::formatv("Error while interpreting bitcode: {}.",
+                                 interpreter->getErrorMessage()));
+    return {};
+  }
+  int64_t sharedMemoryBytes = interpreterResult.IntVal.getSExtValue();
+
+  // Reject a ukernel that would consume too much shared memory, which we need
+  // to save for other purposes. This threshold can always be adjusted but we
+  // default to a low threshold to get an early signal.
+  int64_t maxSharedMemoryBytes = getSharedMemoryBytes(gpuTarget) / 4;
+  if (sharedMemoryBytes > maxSharedMemoryBytes) {
+    op.emitWarning(llvm::formatv("The shared memory size {} required by the "
+                                 "ukernel exceeds the maximum allowed size {}.",
+                                 sharedMemoryBytes, maxSharedMemoryBytes));
+    return {};
+  }
+  return sharedMemoryBytes;
+}
+
+// Returns the shared memory size required by the multi_mma ukernel.
+// On failure, an op warning is emitted and {} is returned.
+// Uses a static cache to avoid calling expensivelyEvaluateSharedMemoryBytes
+// more than once per (DataTiledMMAAttr, GPU target) pair.
+static std::optional<int64_t>
+getSharedMemoryBytes(MLIRContext *context, IREE::Codegen::InnerTiledOp op,
+                     StringRef ukernelName,
+                     DictionaryAttr targetConfiguration) {
+  auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+  IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(targetConfiguration);
+  if (!gpuTarget) {
+    return {};
+  }
+  // We use the stringification of the attributes, rather than the
+  // attributes themselves, as the key, to ensure it's self-contained and does
+  // not contain pointers to other objects, such as a `MLIRContext*`, which
+  // could go dangling.
+  std::string key = llvm::formatv("mma = {}, gpuTarget = {}", mma, gpuTarget);
+
+  struct CacheEntry {
+    std::optional<int64_t> sharedMemoryBytes;
+    std::mutex mutex;
+    bool evaluated = false;
+  };
+
+  // The cache and the mutex guarding it.
+  // We store the CacheEntry objects by pointer, so that `entryPtr` below stays
+  // valid even as the map grows.
+  static llvm::StringMap<std::unique_ptr<CacheEntry>> cache;
+  static std::mutex cacheMutex;
+
+  CacheEntry *entryPtr = nullptr;
+
+  {
+    // Critical section on `cacheMutex`. This is the only place where we
+    // access `cache`. Any later read or update of a cache entry goes through
+    // `entryPtr` and is guarded by that entry's own mutex.
+    std::lock_guard<std::mutex> lock(cacheMutex);
+    auto iter = cache.find(key);
+    if (iter != cache.end()) {
+      // Cache hit. The entry may still be under evaluation by another thread;
+      // the entry mutex acquired below synchronizes with that.
+      entryPtr = iter->second.get();
+    } else {
+      // Cache miss. Create a new, not-yet-evaluated cache entry.
+      entryPtr = cache.insert({key, std::make_unique<CacheEntry>()})
+                     .first->second.get();
+    }
+  }
+
+  // Lock the entry's mutex both to read its value and, if needed, to perform
+  // the evaluation. Concurrent lookups of a not-yet-evaluated entry thus wait
+  // for the evaluating thread instead of racing on a partially-written value.
+  std::lock_guard<std::mutex> entryLock(entryPtr->mutex);
+  if (!entryPtr->evaluated) {
+    entryPtr->sharedMemoryBytes = expensivelyEvaluateSharedMemoryBytes(
+        context, op, ukernelName, gpuTarget);
+    entryPtr->evaluated = true;
+  }
+  return entryPtr->sharedMemoryBytes;
+}
+
+/// Utility function to help create and replace inner_tiled with a ukernel.
+static LogicalResult handleInnerTiledMmaUkernel(
+    RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
+    Operation *contextualOp, ArrayRef<Value> inputs, ArrayRef<Value> outputs,
+    SmallVectorImpl<Value> &otherOperands) {
+  auto op = dyn_cast<IREE::Codegen::InnerTiledOp>(contextualOp);
+  if (!op) {
+    return rewriter.notifyMatchFailure(
+        contextualOp, "expected a codegen.inner_tiled op for multi_mma");
+  }
+  auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+  if (!mma) {
+    return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
+  }
+  std::optional<int64_t> innerCrossIntrinsicDim =
+      getCInnermostStaticCrossIntrinsicDim(op);
+  if (!innerCrossIntrinsicDim) {
+    return rewriter.notifyMatchFailure(
+        op, "inner cross-intrinsic dim is dynamic or not found");
+  }
+  Location loc = op->getLoc();
+  Type I32Type = rewriter.getI32Type();
+  auto castIndexToI32 = [&](Value val) {
+    return rewriter.create<arith::IndexCastOp>(loc, I32Type, val);
+  };
+  auto constI32 = [&](int val) {
+    return rewriter.create<arith::ConstantIntOp>(loc, I32Type, val);
+  };
+  MLIRContext *context = rewriter.getContext();
+  std::optional<int64_t> maybeSharedMemoryBytes =
+      getSharedMemoryBytes(context, op, name, targetConfiguration);
+  int64_t sharedMemoryBytes =
+      (!maybeSharedMemoryBytes) ? 0 : maybeSharedMemoryBytes.value();
+  Value sharedMemory = createSharedMemory(rewriter, loc, sharedMemoryBytes);
+  Value k = castIndexToI32(
+      rewriter.create<tensor::DimOp>(op.getLoc(), op.getInputs()[0], 1));
+  Value intrinsicsM = constI32(mma.getIntrinsicsM());
+  Value subgroupsM = constI32(mma.getSubgroupsM());
+  Value intrinsicsN = constI32(mma.getIntrinsicsN());
+  Value subgroupsN = constI32(mma.getSubgroupsN());
+  Value intrinsicsK = constI32(mma.getIntrinsicsK());
+  // There are 3 shaped input/output operands (A/B/C matrices).
+  SmallVector<SmallVector<int64_t>> stridedDims(3, {});
+  // Only the C matrix gets strides, and we pass the stride of the innermost
+  // CrossIntrinsic dim, because the ukernel needs to know where to store the
+  // result vector from each unrolled intrinsic. Offsets into all other
+  // dimensions are handled by the compiler, and passed as part of the base
+  // pointer + offset. The A and B matrices don't get strides, because we
+  // expect them to always be passed as global memory pointers, and the
+  // strides can be inferred by the ukernel implementation.
+  stridedDims[2].push_back(innerCrossIntrinsicDim.value());
+  // The only additional shaped operand is the shared memory buffer. Only
+  // create a stride list for it if we have shared memory. Otherwise, the
+  // operand is an iree_codegen.null_pointer op.
+  if (sharedMemoryBytes != 0) {
+    // Shared memory does not need strides.
+    stridedDims.push_back({});
+  }
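+  // For the lowerings exercised by the lit tests below, this yields
+  // strided_dims([[], [], [4], []]) when shared memory is used and
+  // strided_dims([[], [], [4]]) otherwise.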
+  auto fnDefAttrs = DictionaryAttr::get(
+      context, {{"vm.import.module", StringAttr::get(context, "rocm")}});
+  rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
+      op, op.getOutputs().getTypes(), name, inputs, outputs,
+      ValueRange{sharedMemory, constI32(sharedMemoryBytes), k, intrinsicsM,
+                 subgroupsM, intrinsicsN, subgroupsN, intrinsicsK},
+      fnDefAttrs, stridedDims);
+  return success();
+}
+
 std::optional<LogicalResult> UKernelProviderAttr::createAndReplaceWithUkernelOp(
     RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
-    Operation *contextualOp, SmallVectorImpl<Value> &inputs,
-    SmallVectorImpl<Value> &outputs,
+    Operation *contextualOp, ArrayRef<Value> inputs, ArrayRef<Value> outputs,
     SmallVectorImpl<Value> &otherOperands) const {
   if (name.contains("argmax")) {
     return handleArgmaxUkernel(rewriter, name, targetConfiguration,
                                contextualOp, inputs, outputs, otherOperands);
+  } else if (name.contains("multi_mma_mfma")) {
+    return handleInnerTiledMmaUkernel(rewriter, name, targetConfiguration,
+                                      contextualOp, inputs, outputs,
+                                      otherOperands);
   }
-  // TODO(avarma): Add multi_mfma ukernel support via descriptors.
   return std::nullopt;
 }
 
diff --git a/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir b/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir
index ba1c76b..aceb7b0 100644
--- a/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir
+++ b/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir
@@ -67,3 +67,81 @@
     return %0#0, %0#1 : tensor<f32>, tensor<i64>
   }
 }
+
+// -----
+
+// CHECK-LABEL:       @multi_mma_mfma_i32_16x16x32_i8_with_gpu_arch
+// CHECK-SAME:          %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x2x8x4x16x2x8xi8>
+// CHECK-SAME:          %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x2x4x2x4x16x2x8xi8>
+// CHECK-SAME:          %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x8x4x2x4x16x4xi32>
+// CHECK:               %[[ALLOC:.*]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<8192xi8>
+// CHECK:               %[[C1_INDEX:.*]] = arith.constant 1 : index
+// CHECK:               %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1_INDEX]] : tensor<1x2x8x4x16x2x8xi8>
+// CHECK:               %[[DIM_CAST:.*]] = arith.index_cast %[[DIM]] : index to i32
+// CHECK:               %[[C8:.*]] = arith.constant 8 : i32
+// CHECK:               %[[C1:.*]] = arith.constant 1 : i32
+// CHECK:               %[[C2:.*]] = arith.constant 2 : i32
+// CHECK:               %[[C4:.*]] = arith.constant 4 : i32
+// CHECK:               %[[C2_1:.*]] = arith.constant 2 : i32
+// CHECK:               %[[C8192:.*]] = arith.constant 8192 : i32
+// CHECK:               %[[UK_GENERIC:.*]] = iree_codegen.ukernel.generic "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
+// CHECK-SAME:            ins(%[[ARG0]], %[[ARG1]] : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8>)
+// CHECK-SAME:            outs(%[[ARG2]] : tensor<1x1x8x4x2x4x16x4xi32>)
+// CHECK-SAME:            (%[[ALLOC]], %[[C8192]], %[[DIM_CAST]], %[[C8]], %[[C1]], %[[C2]], %[[C4]], %[[C2_1]] : tensor<8192xi8>, i32, i32, i32, i32, i32, i32, i32)
+// CHECK-SAME:            fn_def_attrs {vm.import.module = "rocm"}
+// CHECK-SAME{LITERAL}:   strided_dims([[], [], [4], []])
+// CHECK:               return %[[UK_GENERIC]] : tensor<1x1x8x4x2x4x16x4xi32>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.ukernel_provider = #rocm.ukernel_provider, ukernels = "all"}>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+  func.func @multi_mma_mfma_i32_16x16x32_i8_with_gpu_arch(%arg0: tensor<1x2x8x4x16x2x8xi8>, %arg1: tensor<1x2x4x2x4x16x2x8xi8>, %arg2: tensor<1x1x8x4x2x4x16x4xi32>) -> tensor<1x1x8x4x2x4x16x4xi32> {
+    %0 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%arg2) {
+      indexing_maps = [#map, #map1, #map2],
+      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", bitcode>,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>
+    } : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8> into tensor<1x1x8x4x2x4x16x4xi32>
+    return %0 : tensor<1x1x8x4x2x4x16x4xi32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL:       @multi_mma_mfma_i32_16x16x32_i8_without_gpu_arch
+// CHECK-SAME:          %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x2x8x4x16x2x8xi8>
+// CHECK-SAME:          %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x2x4x2x4x16x2x8xi8>
+// CHECK-SAME:          %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x8x4x2x4x16x4xi32>
+// CHECK:               %[[NULL:.*]] = iree_codegen.null_pointer
+// CHECK:               %[[C1_INDEX:.*]] = arith.constant 1 : index
+// CHECK:               %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1_INDEX]] : tensor<1x2x8x4x16x2x8xi8>
+// CHECK:               %[[DIM_CAST:.*]] = arith.index_cast %[[DIM]] : index to i32
+// CHECK:               %[[C8:.*]] = arith.constant 8 : i32
+// CHECK:               %[[C1:.*]] = arith.constant 1 : i32
+// CHECK:               %[[C2:.*]] = arith.constant 2 : i32
+// CHECK:               %[[C4:.*]] = arith.constant 4 : i32
+// CHECK:               %[[C2_1:.*]] = arith.constant 2 : i32
+// CHECK:               %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:               %[[UK_GENERIC:.*]] = iree_codegen.ukernel.generic "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
+// CHECK-SAME:            ins(%[[ARG0]], %[[ARG1]] : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8>)
+// CHECK-SAME:            outs(%[[ARG2]] : tensor<1x1x8x4x2x4x16x4xi32>)
+// CHECK-SAME:            (%[[NULL]], %[[C0]], %[[DIM_CAST]], %[[C8]], %[[C1]], %[[C2]], %[[C4]], %[[C2_1]] : !iree_codegen.null_pointer, i32, i32, i32, i32, i32, i32, i32)
+// CHECK-SAME:            fn_def_attrs {vm.import.module = "rocm"}
+// CHECK-SAME{LITERAL}:   strided_dims([[], [], [4]])
+// CHECK:               return %[[UK_GENERIC]] : tensor<1x1x8x4x2x4x16x4xi32>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree_codegen.ukernel_provider = #rocm.ukernel_provider, ukernels = "all"}>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+  func.func @multi_mma_mfma_i32_16x16x32_i8_without_gpu_arch(%arg0: tensor<1x2x8x4x16x2x8xi8>, %arg1: tensor<1x2x4x2x4x16x2x8xi8>, %arg2: tensor<1x1x8x4x2x4x16x4xi32>) -> tensor<1x1x8x4x2x4x16x4xi32> {
+    %0 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%arg2) {
+      indexing_maps = [#map, #map1, #map2],
+      iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", bitcode>,
+      iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+      kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>
+    } : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8> into tensor<1x1x8x4x2x4x16x4xi32>
+    return %0 : tensor<1x1x8x4x2x4x16x4xi32>
+  }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
index 0d94832..f35fd48 100644
--- a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
@@ -13,6 +13,8 @@
 #include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -31,6 +33,9 @@
 namespace {
 struct LowerBitcodeUKernelsPass final
     : impl::LowerBitcodeUKernelsPassBase<LowerBitcodeUKernelsPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<bufferization::BufferizationDialect, gpu::GPUDialect>();
+  }
   void runOnOperation() override;
 };
 struct LowerMemrefUKernelsPass final
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
index d2211d1..c07841b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
@@ -567,8 +567,8 @@
                     "::mlir::StringRef":$name,
                     "::mlir::DictionaryAttr":$target_configuration,
                     "::mlir::Operation *":$contextual_op,
-                    "::llvm::SmallVectorImpl<::mlir::Value>&":$inputs,
-                    "::llvm::SmallVectorImpl<::mlir::Value>&":$outputs,
+                    "::llvm::ArrayRef<::mlir::Value>":$inputs,
+                    "::llvm::ArrayRef<::mlir::Value>":$outputs,
                     "::llvm::SmallVectorImpl<::mlir::Value>&":$other_operands),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{