[ROCM] Update Ukernel infra to handle InnerTiledOp/Multi_MMA_MFMA (#21759)
-- This commit updates the ROCM ukernel infrastructure to handle ukernel
lowering for InnerTiledOp (multi_mma_mfma). It also changes the
createAndReplaceWithUkernelOp interface to take inputs/outputs as
ArrayRef<Value> instead of SmallVectorImpl<Value>&.
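
For reference, the new lowering rewrites a data-tiled iree_codegen.inner_tiled op carrying a bitcode ukernel descriptor into an iree_codegen.ukernel.generic call, with the workgroup shared-memory buffer, its byte size, the outer K extent, and the intrinsics/subgroups counts appended as extra operands. A condensed before/after sketch (shapes and types elided; names taken from the lit test added in this patch):

  // Before: inner_tiled op tagged with a bitcode ukernel descriptor.
  %0 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%arg2) {
    iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", bitcode>,
    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>,
    ...
  } : ...

  // After: ukernel.generic call. %k is tensor.dim of the A operand's outer K
  // dimension; the trailing i32 constants are intrinsics_m, subgroups_m,
  // intrinsics_n, subgroups_n, intrinsics_k.
  %shmem = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<8192xi8>
  %1 = iree_codegen.ukernel.generic "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
         ins(%arg0, %arg1 : ...) outs(%arg2 : ...)
         (%shmem, %c8192_i32, %k, %c8_i32, %c1_i32, %c2_i32, %c4_i32, %c2_i32 : ...)
         fn_def_attrs {vm.import.module = "rocm"} strided_dims([[], [], [4], []])

When no iree_gpu.target is attached (so the shared-memory size cannot be evaluated), the shared-memory operand degenerates to an iree_codegen.null_pointer with a size of 0, as exercised by the second test case.
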
Signed-off-by: Abhishek Varma <abhvarma@amd.com>
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel
index da69584..ead5fdb 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel
@@ -60,10 +60,22 @@
"//compiler/plugins/target/ROCM/builtins/specialization:iree_specialization_patterns_amdgpu",
"//compiler/plugins/target/ROCM/builtins/tuning:iree_default_tuning_specs_amdgpu",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
+ "//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
+ "//compiler/src/iree/compiler/Codegen/Utils",
+ "//compiler/src/iree/compiler/Dialect/HAL/IR",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Utils",
+ "@llvm-project//llvm:ExecutionEngine",
+ "@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:BufferizationDialect",
"@llvm-project//mlir:DialectUtils",
+ "@llvm-project//mlir:GPUDialect",
+ "@llvm-project//mlir:GPUTransformOps",
+ "@llvm-project//mlir:GPUTransforms",
+ "@llvm-project//mlir:GPUUtils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Parser",
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt
index 1f8fafe..f6f893f 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt
@@ -28,12 +28,24 @@
DEPS
::ROCMAttrs
::ROCMDialectGen
+ LLVMExecutionEngine
+ LLVMIRReader
LLVMSupport
+ MLIRBufferizationDialect
+ MLIRGPUDialect
+ MLIRGPUTransformOps
+ MLIRGPUTransforms
+ MLIRGPUUtils
MLIRIR
MLIRLinalgDialect
MLIRParser
MLIRSupport
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+ iree::compiler::Codegen::Dialect::Codegen::Utils
+ iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
+ iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
+ iree::compiler::Codegen::Utils
+ iree::compiler::Dialect::HAL::IR
iree::compiler::Dialect::Util::IR
iree::compiler::Utils
iree::compiler::plugins::target::ROCM::builtins::mlir_ukernel::iree_mlir_ukernel_patterns_amdgpu
diff --git a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp
index b2817a1..2421461 100644
--- a/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp
+++ b/compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp
@@ -6,9 +6,22 @@
#include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.h"
#include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMDialect.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/ExecutionEngine/GenericValue.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
@@ -34,10 +47,11 @@
//===----------------------------------------------------------------------===//
/// Utility function to help create and replace argmax linalg with a ukernel.
-static LogicalResult handleArgmaxUkernel(
- RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
- Operation *contextualOp, SmallVectorImpl<Value> &inputs,
- SmallVectorImpl<Value> &outputs, SmallVectorImpl<Value> &otherOperands) {
+static LogicalResult
+handleArgmaxUkernel(RewriterBase &rewriter, StringRef name,
+ DictionaryAttr targetConfiguration, Operation *contextualOp,
+ ArrayRef<Value> inputs, ArrayRef<Value> outputs,
+ SmallVectorImpl<Value> &otherOperands) {
auto genericOp = dyn_cast<linalg::GenericOp>(contextualOp);
if (!genericOp) {
return rewriter.notifyMatchFailure(
@@ -79,16 +93,372 @@
return success();
}
+constexpr char executableObjectsAttrName[] = "hal.executable.objects";
+
+// Walks parent ops from `op` to return the nearest hal.executable.objects
+// array attribute. If the parent hal.executable.variant is reached, its objects
+// attribute is returned.
+// Adapted from ExecutableTargetAttr::lookup.
+static ArrayAttr lookUpExecutableObjects(Operation *op) {
+ MLIRContext *context = op->getContext();
+ auto attrId = StringAttr::get(context, executableObjectsAttrName);
+ while (op) {
+ // Take directly from the enclosing variant.
+ if (auto variantOp = dyn_cast<IREE::HAL::ExecutableVariantOp>(op)) {
+ if (std::optional<ArrayAttr> objects = variantOp.getObjects()) {
+ return *objects;
+ }
+ }
+ // Take from op attributes.
+ if (auto attr = op->getAttrOfType<ArrayAttr>(attrId)) {
+ return attr;
+ }
+ // Continue walk.
+ op = op->getParentOp();
+ }
+ return {};
+}
+
+static Value createSharedMemory(RewriterBase &rewriter, Location loc,
+ int64_t sharedMemoryBytes) {
+ RankedTensorType tensorType =
+ RankedTensorType::get({sharedMemoryBytes}, rewriter.getI8Type());
+ ValueRange dynSizes{};
+ if (!sharedMemoryBytes) {
+ IREE::Codegen::NullPointerType nullPointerType =
+ IREE::Codegen::NullPointerType::get(rewriter.getContext());
+ return rewriter.create<IREE::Codegen::NullPointerOp>(loc, nullPointerType);
+ }
+ auto allocOp =
+ rewriter.create<bufferization::AllocTensorOp>(loc, tensorType, dynSizes);
+ Attribute sharedAddrSpace = gpu::AddressSpaceAttr::get(
+ rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
+ allocOp.setMemorySpaceAttr(sharedAddrSpace);
+ return allocOp;
+}
+
+// Returns the index of the innermost CrossIntrinsic dimension of the C matrix,
+// if it is static, and std::nullopt if it is dynamic or if there are no
+// CrossIntrinsic dims.
+static std::optional<unsigned>
+getCInnermostStaticCrossIntrinsicDim(IREE::Codegen::InnerTiledOp op) {
+ auto outputType = dyn_cast<ShapedType>(op.getResultTypes()[0]);
+ if (!outputType) {
+ return std::nullopt;
+ }
+ auto mma = cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+ IREE::Codegen::TileSwizzle accSwizzle =
+ getSwizzle(mma, IREE::GPU::MMAFragment::Acc);
+ SmallVector<IREE::Codegen::TileSwizzle::Dim> swizzleDims;
+ for (IREE::Codegen::TileSwizzle::ExpandShapeDimVectorType group :
+ accSwizzle.expandShape) {
+ swizzleDims.append(group);
+ }
+ applyPermutationToVector(swizzleDims, accSwizzle.permutation);
+ int rankDiff = outputType.getRank() - swizzleDims.size();
+ auto crossIntrinsic = IREE::Codegen::TileSwizzle::Dim::Kind::CrossIntrinsic;
+ for (size_t e = swizzleDims.size(), swizzleIdx = e - 1; swizzleIdx < e;
+ --swizzleIdx) {
+ if (swizzleDims[swizzleIdx].kind != crossIntrinsic) {
+ continue;
+ }
+ int outputIdx = swizzleIdx + rankDiff;
+ if (outputType.isDynamicDim(outputIdx)) {
+ return std::nullopt;
+ }
+ return outputIdx;
+ }
+ return std::nullopt;
+}
+
+static int64_t getSharedMemoryBytes(IREE::GPU::TargetAttr gpuTarget) {
+ if (!gpuTarget) {
+ return 0;
+ }
+ IREE::GPU::TargetWgpAttr wgp = gpuTarget.getWgp();
+ if (!wgp) {
+ return 0;
+ }
+ return wgp.getMaxWorkgroupMemoryBytes();
+}
+
+// Returns an ExecutableObjectAttr carrying the bitcode for the given ukernel.
+//
+// First tries to find the bitcode in the input `sourceExecutableObjects`, which
+// must be an array of ExecutableObjectAttr and typically comes from a
+// hal.executable.objects array attribute in the source IR; this is the
+// mechanism by which source programs may provide their own ukernel bitcode.
+//
+// If no matching bitcode is found in `sourceExecutableObjects`, this function
+// falls back to the bitcode files that are embedded as static data.
+static IREE::HAL::ExecutableObjectAttr
+getUKernelBitcode(MLIRContext *context,
+ IREE::HAL::ExecutableTargetAttr execTarget,
+ ArrayAttr sourceExecutableObjects, StringRef filename) {
+ // Early-return if the source executable.objects already contain an object
+ // with the expected file name. This happens with user-provided bitcode in the
+ // source IR.
+ if (sourceExecutableObjects) {
+ for (Attribute a : sourceExecutableObjects) {
+ if (auto object = dyn_cast<IREE::HAL::ExecutableObjectAttr>(a)) {
+ if (object.getPath() == filename) {
+ return object;
+ }
+ }
+ }
+ }
+
+ // No user-provided bitcode, so we search our embedded bitcode files in the
+ // EmbeddedDataDirectory singleton.
+ std::optional<StringRef> bitcode;
+ EmbeddedDataDirectory::withGlobal(
+ [&](EmbeddedDataDirectory &dir) { bitcode = dir.getFile(filename); });
+ if (!bitcode) {
+ return {};
+ }
+ AsmResourceBlob blob = HeapAsmResourceBlob::allocateAndCopyInferAlign(
+ ArrayRef<char>(bitcode->data(), bitcode->size()));
+ auto bitcodeDenseAttr = DenseI8ResourceElementsAttr::get(
+ VectorType::get({static_cast<int64_t>(bitcode->size())},
+ IntegerType::get(context, 8)),
+ filename, std::move(blob));
+ return IREE::HAL::ExecutableObjectAttr::get(
+ context, StringAttr::get(context, filename),
+ cast<IREE::Util::SerializableAttrInterface>(bitcodeDenseAttr));
+}
+
+static std::string getBitcodeFilename(IREE::GPU::TargetAttr gpuTarget,
+ StringRef name) {
+ return llvm::formatv("{}.{}.bc", name, gpuTarget.getArch());
+}
+
+// Helper for getSharedMemoryBytes. Typical latency: 2 ms.
+// Evaluates the shared memory size required by the multi_mma microkernel by
+// interpreting a bitcode function with a specific name.
+// On failure, an op warning is emitted and {} is returned.
+static std::optional<int64_t> expensivelyEvaluateSharedMemoryBytes(
+ MLIRContext *context, IREE::Codegen::InnerTiledOp op, StringRef ukernelName,
+ IREE::GPU::TargetAttr gpuTarget) {
+ auto target = IREE::HAL::ExecutableTargetAttr::lookup(op);
+ std::string filename = getBitcodeFilename(gpuTarget, ukernelName);
+ ArrayAttr sourceExecutableObjects = lookUpExecutableObjects(op);
+ IREE::HAL::ExecutableObjectAttr bitcodeObject =
+ getUKernelBitcode(context, target, sourceExecutableObjects, filename);
+ if (!bitcodeObject) {
+ op.emitWarning("Could not find ukernel bitcode for this target.");
+ return {};
+ }
+
+ auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+
+ IREE::Util::SerializableAttrInterface bitcodeData = bitcodeObject.getData();
+ std::string buffer;
+ buffer.resize(bitcodeData.getStorageSize());
+ if (failed(bitcodeData.serializeToBuffer(
+ op->getLoc(), llvm::endianness::native,
+ ArrayRef<char>{buffer.data(), buffer.size()}))) {
+ op.emitWarning("Failed to serialize bitcode.");
+ return {};
+ }
+ llvm::LLVMContext llvmContext;
+ llvm::Expected<std::unique_ptr<llvm::Module>> moduleOp =
+ llvm::getLazyBitcodeModule(llvm::MemoryBufferRef{buffer, ukernelName},
+ llvmContext,
+ /*ShouldLazyLoadMetadata=*/true);
+ if (!moduleOp) {
+ op.emitWarning("Failed to parse bitcode module.");
+ return {};
+ }
+ llvm::EngineBuilder builder(std::move(moduleOp.get()));
+ std::string builderError;
+ builder.setEngineKind(llvm::EngineKind::Interpreter)
+ .setErrorStr(&builderError);
+ std::unique_ptr<llvm::ExecutionEngine> interpreter{builder.create()};
+ if (!interpreter) {
+ op.emitWarning("Failed to create the interpreter.");
+ return {};
+ }
+ std::string queryFuncName =
+ llvm::formatv("{}_query_shared_memory_bytes", ukernelName);
+ llvm::Function *func = interpreter->FindFunctionNamed(queryFuncName);
+ if (!func) {
+ op.emitWarning(llvm::formatv(
+ "Bitcode does not contain a function named {}.", queryFuncName));
+ return {};
+ }
+ auto constI32 = [](int32_t val) {
+ llvm::GenericValue v;
+ v.IntVal = APInt(32, val);
+ return v;
+ };
+ llvm::GenericValue args[] = {
+ constI32(mma.getIntrinsicsM()), constI32(mma.getSubgroupsM()),
+ constI32(mma.getIntrinsicsN()), constI32(mma.getSubgroupsN()),
+ constI32(mma.getIntrinsicsK())};
+ if (func->arg_size() != /*total elements in 'args'=*/5) {
+ op.emitWarning(llvm::formatv(
+ "Bitcode function {} takes {} arguments. Expected {}.", queryFuncName,
+ func->arg_size(), /*total elements in 'args'=*/5));
+ return {};
+ }
+ llvm::GenericValue interpreterResult = interpreter->runFunction(func, args);
+ if (interpreter->hasError()) {
+ op.emitWarning(llvm::formatv("Error while interpreting bitcode: {}.",
+ interpreter->getErrorMessage()));
+ return {};
+ }
+ int64_t sharedMemoryBytes = interpreterResult.IntVal.getSExtValue();
+
+ // Reject a ukernel that would consume too much shared memory, which we need
+ // to save for other purposes. This threshold can always be adjusted but we
+ // default to a low threshold to get an early signal.
+ int64_t maxSharedMemoryBytes = getSharedMemoryBytes(gpuTarget) / 4;
+ if (sharedMemoryBytes > maxSharedMemoryBytes) {
+ op.emitWarning(llvm::formatv("The shared memory size {} required by the "
+ "ukernel exceeds the maximum allowed size {}.",
+ sharedMemoryBytes, maxSharedMemoryBytes));
+ return {};
+ }
+ return sharedMemoryBytes;
+}
+
+// Returns the shared memory size required by the multi_mma ukernel.
+// On failure, an op warning is emitted and {} is returned.
+// Uses a static cache to avoid calling expensivelyEvaluateSharedMemoryBytes
+// more than once per DataTiledMMAAttr value.
+static std::optional<int64_t>
+getSharedMemoryBytes(MLIRContext *context, IREE::Codegen::InnerTiledOp op,
+ StringRef ukernelName,
+ DictionaryAttr targetConfiguration) {
+ auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+ IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(targetConfiguration);
+ if (!gpuTarget) {
+ return {};
+ }
+ // We use the stringification of the attributes, rather than the attributes
+ // themselves, as the key, to ensure that the key is self-contained and does
+ // not hold pointers to other objects, such as an `MLIRContext *`, which
+ // could go dangling.
+ std::string key = llvm::formatv("mma = {}, gpuTarget = {}", mma, gpuTarget);
+
+ struct CacheEntry {
+ std::optional<int64_t> sharedMemoryBytes;
+ std::mutex mutex;
+ bool evaluated = false;
+ };
+
+ // The cache and the mutex guarding it.
+ // We store the CacheEntry objects by pointer so that we don't need to worry
+ // about `entryPtr` being invalidated.
+ static llvm::StringMap<std::unique_ptr<CacheEntry>> cache;
+ static std::mutex cacheMutex;
+
+ CacheEntry *entryPtr = nullptr;
+
+ {
+ // Critical section on `cacheMutex`. This is the only place where we
+ // access `cache`. Any later update to a cache entry goes through
+ // `entryPtr`, independently of `cache`.
+ std::lock_guard<std::mutex> lock(cacheMutex);
+ auto iter = cache.find(key);
+ if (iter != cache.end()) {
+ // Cache hit. Early return.
+ return iter->second->sharedMemoryBytes;
+ }
+ // Cache miss. Create a new cache entry and acquire its mutex.
+ entryPtr =
+ cache.insert({key, std::make_unique<CacheEntry>()}).first->second.get();
+ entryPtr->mutex.lock();
+ }
+
+ // If the entry still isn't evaluated after we have acquired its mutex,
+ // perform the evaluation now.
+ if (!entryPtr->evaluated) {
+ entryPtr->sharedMemoryBytes = expensivelyEvaluateSharedMemoryBytes(
+ context, op, ukernelName, gpuTarget);
+ entryPtr->evaluated = true;
+ }
+
+ entryPtr->mutex.unlock();
+ return entryPtr->sharedMemoryBytes;
+}
+
+/// Utility function to help create and replace inner_tiled with a ukernel.
+static LogicalResult handleInnerTiledMmaUkernel(
+ RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
+ Operation *contextualOp, ArrayRef<Value> inputs, ArrayRef<Value> outputs,
+ SmallVectorImpl<Value> &otherOperands) {
+ auto op = dyn_cast<IREE::Codegen::InnerTiledOp>(contextualOp);
+ if (!op) {
+ return rewriter.notifyMatchFailure(
+ contextualOp, "expected a codegen.inner_tiled op for multi_mma");
+ }
+ auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+ if (!mma) {
+ return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
+ }
+ std::optional<int64_t> innerCrossIntrinsicDim =
+ getCInnermostStaticCrossIntrinsicDim(op);
+ if (!innerCrossIntrinsicDim) {
+ return rewriter.notifyMatchFailure(
+ op, "inner cross-intrinsic dim is dynamic or not found");
+ }
+ Location loc = op->getLoc();
+ Type I32Type = rewriter.getI32Type();
+ auto castIndexToI32 = [&](Value val) {
+ return rewriter.create<arith::IndexCastOp>(loc, I32Type, val);
+ };
+ auto constI32 = [&](int val) {
+ return rewriter.create<arith::ConstantIntOp>(loc, I32Type, val);
+ };
+ MLIRContext *context = rewriter.getContext();
+ std::optional<int64_t> maybeSharedMemoryBytes =
+ getSharedMemoryBytes(context, op, name, targetConfiguration);
+ int64_t sharedMemoryBytes =
+ (!maybeSharedMemoryBytes) ? 0 : maybeSharedMemoryBytes.value();
+ Value sharedMemory = createSharedMemory(rewriter, loc, sharedMemoryBytes);
+ Value k = castIndexToI32(
+ rewriter.create<tensor::DimOp>(op.getLoc(), op.getInputs()[0], 1));
+ Value intrinsicsM = constI32(mma.getIntrinsicsM());
+ Value subgroupsM = constI32(mma.getSubgroupsM());
+ Value intrinsicsN = constI32(mma.getIntrinsicsN());
+ Value subgroupsN = constI32(mma.getSubgroupsN());
+ Value intrinsicsK = constI32(mma.getIntrinsicsK());
+ // There are 3 shaped input/output operands (A/B/C matrices).
+ SmallVector<SmallVector<int64_t>> stridedDims(3, {});
+ // Only the C matrix gets strides, and we pass the stride of the innermost
+ // CrossIntrinsic dim, because the ukernel needs to know where to store the
+ // result vector from each unrolled intrinsic. Offsets into all other
+ // dimensions are handled by the compiler, and passed as part of the base
+ // pointer + offset. The A and B matrices don't get strides, because we
+ // expect them to always be passed as global memory pointers, and the
+ // strides can be inferred by the ukernel implementation.
+ stridedDims[2].push_back(innerCrossIntrinsicDim.value());
+ // The only additional shaped operand is the shared memory buffer. Only
+ // create a stride list for it if we have shared memory. Otherwise, the
+ // operand is an iree_codegen.null_pointer op.
+ if (sharedMemoryBytes != 0) {
+ // Shared memory does not need strides.
+ stridedDims.push_back({});
+ }
+ auto fnDefAttrs = DictionaryAttr::get(
+ context, {{"vm.import.module", StringAttr::get(context, "rocm")}});
+ rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
+ op, op.getOutputs().getTypes(), name, inputs, outputs,
+ ValueRange{sharedMemory, constI32(sharedMemoryBytes), k, intrinsicsM,
+ subgroupsM, intrinsicsN, subgroupsN, intrinsicsK},
+ fnDefAttrs, stridedDims);
+ return success();
+}
+
std::optional<LogicalResult> UKernelProviderAttr::createAndReplaceWithUkernelOp(
RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
- Operation *contextualOp, SmallVectorImpl<Value> &inputs,
- SmallVectorImpl<Value> &outputs,
+ Operation *contextualOp, ArrayRef<Value> inputs, ArrayRef<Value> outputs,
SmallVectorImpl<Value> &otherOperands) const {
if (name.contains("argmax")) {
return handleArgmaxUkernel(rewriter, name, targetConfiguration,
contextualOp, inputs, outputs, otherOperands);
+ } else if (name.contains("multi_mma_mfma")) {
+ return handleInnerTiledMmaUkernel(rewriter, name, targetConfiguration,
+ contextualOp, inputs, outputs,
+ otherOperands);
}
- // TODO(avarma): Add multi_mfma ukernel support via descriptors.
return std::nullopt;
}
diff --git a/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir b/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir
index ba1c76b..aceb7b0 100644
--- a/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir
+++ b/compiler/plugins/target/ROCM/test/lower_rocm_ukernel_descriptor.mlir
@@ -67,3 +67,81 @@
return %0#0, %0#1 : tensor<f32>, tensor<i64>
}
}
+
+// -----
+
+// CHECK-LABEL: @multi_mma_mfma_i32_16x16x32_i8_with_gpu_arch
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x2x8x4x16x2x8xi8>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x2x4x2x4x16x2x8xi8>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x8x4x2x4x16x4xi32>
+// CHECK: %[[ALLOC:.*]] = bufferization.alloc_tensor() {memory_space = #gpu.address_space<workgroup>} : tensor<8192xi8>
+// CHECK: %[[C1_INDEX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1_INDEX]] : tensor<1x2x8x4x16x2x8xi8>
+// CHECK: %[[DIM_CAST:.*]] = arith.index_cast %[[DIM]] : index to i32
+// CHECK: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK: %[[C2_1:.*]] = arith.constant 2 : i32
+// CHECK: %[[C8192:.*]] = arith.constant 8192 : i32
+// CHECK: %[[UK_GENERIC:.*]] = iree_codegen.ukernel.generic "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<1x1x8x4x2x4x16x4xi32>)
+// CHECK-SAME: (%[[ALLOC]], %[[C8192]], %[[DIM_CAST]], %[[C8]], %[[C1]], %[[C2]], %[[C4]], %[[C2_1]] : tensor<8192xi8>, i32, i32, i32, i32, i32, i32, i32)
+// CHECK-SAME: fn_def_attrs {vm.import.module = "rocm"}
+// CHECK-SAME{LITERAL}: strided_dims([[], [], [4], []])
+// CHECK: return %[[UK_GENERIC]] : tensor<1x1x8x4x2x4x16x4xi32>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.ukernel_provider = #rocm.ukernel_provider, ukernels = "all"}>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+ func.func @multi_mma_mfma_i32_16x16x32_i8_with_gpu_arch(%arg0: tensor<1x2x8x4x16x2x8xi8>, %arg1: tensor<1x2x4x2x4x16x2x8xi8>, %arg2: tensor<1x1x8x4x2x4x16x4xi32>) -> tensor<1x1x8x4x2x4x16x4xi32> {
+ %0 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%arg2) {
+ indexing_maps = [#map, #map1, #map2],
+ iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", bitcode>,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>
+ } : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8> into tensor<1x1x8x4x2x4x16x4xi32>
+ return %0 : tensor<1x1x8x4x2x4x16x4xi32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @multi_mma_mfma_i32_16x16x32_i8_without_gpu_arch
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x2x8x4x16x2x8xi8>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x2x4x2x4x16x2x8xi8>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x8x4x2x4x16x4xi32>
+// CHECK: %[[NULL:.*]] = iree_codegen.null_pointer
+// CHECK: %[[C1_INDEX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1_INDEX]] : tensor<1x2x8x4x16x2x8xi8>
+// CHECK: %[[DIM_CAST:.*]] = arith.index_cast %[[DIM]] : index to i32
+// CHECK: %[[C8:.*]] = arith.constant 8 : i32
+// CHECK: %[[C1:.*]] = arith.constant 1 : i32
+// CHECK: %[[C2:.*]] = arith.constant 2 : i32
+// CHECK: %[[C4:.*]] = arith.constant 4 : i32
+// CHECK: %[[C2_1:.*]] = arith.constant 2 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[UK_GENERIC:.*]] = iree_codegen.ukernel.generic "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<1x1x8x4x2x4x16x4xi32>)
+// CHECK-SAME: (%[[NULL]], %[[C0]], %[[DIM_CAST]], %[[C8]], %[[C1]], %[[C2]], %[[C4]], %[[C2_1]] : !iree_codegen.null_pointer, i32, i32, i32, i32, i32, i32, i32)
+// CHECK-SAME: fn_def_attrs {vm.import.module = "rocm"}
+// CHECK-SAME{LITERAL}: strided_dims([[], [], [4]])
+// CHECK: return %[[UK_GENERIC]] : tensor<1x1x8x4x2x4x16x4xi32>
+#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree_codegen.ukernel_provider = #rocm.ukernel_provider, ukernels = "all"}>
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
+ func.func @multi_mma_mfma_i32_16x16x32_i8_without_gpu_arch(%arg0: tensor<1x2x8x4x16x2x8xi8>, %arg1: tensor<1x2x4x2x4x16x2x8xi8>, %arg2: tensor<1x1x8x4x2x4x16x4xi32>) -> tensor<1x1x8x4x2x4x16x4xi32> {
+ %0 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%arg2) {
+ indexing_maps = [#map, #map1, #map2],
+ iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", bitcode>,
+ iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
+ kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 4, intrinsics_k = 2>
+ } : tensor<1x2x8x4x16x2x8xi8>, tensor<1x2x4x2x4x16x2x8xi8> into tensor<1x1x8x4x2x4x16x4xi32>
+ return %0 : tensor<1x1x8x4x2x4x16x4xi32>
+ }
+}
diff --git a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
index 0d94832..f35fd48 100644
--- a/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp
@@ -13,6 +13,8 @@
#include "iree/compiler/Codegen/Interfaces/UKernelOpInterface.h"
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "llvm/Support/FormatVariadic.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -31,6 +33,9 @@
namespace {
struct LowerBitcodeUKernelsPass final
: impl::LowerBitcodeUKernelsPassBase<LowerBitcodeUKernelsPass> {
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<bufferization::BufferizationDialect, gpu::GPUDialect>();
+ }
void runOnOperation() override;
};
struct LowerMemrefUKernelsPass final
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
index d2211d1..c07841b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td
@@ -567,8 +567,8 @@
"::mlir::StringRef":$name,
"::mlir::DictionaryAttr":$target_configuration,
"::mlir::Operation *":$contextual_op,
- "::llvm::SmallVectorImpl<::mlir::Value>&":$inputs,
- "::llvm::SmallVectorImpl<::mlir::Value>&":$outputs,
+ "::llvm::ArrayRef<::mlir::Value>":$inputs,
+ "::llvm::ArrayRef<::mlir::Value>":$outputs,
"::llvm::SmallVectorImpl<::mlir::Value>&":$other_operands),
/*methodBody=*/"",
/*defaultImplementation=*/[{