// blob: fb551e505aa9b8a13cd5efe35ecf1b7861b75a7f [file] [log] [blame]
// Copyright 2024 The IREE Authors
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Codegen/LLVMGPU/ROCDLKernelConfig.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir::iree_compiler {
namespace {
using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
// Warp Reduction Configuration
// Returns true if `linalgOp` is a non-batched contraction that looks like a
// matrix-vector product:
//   * exactly two parallel loops and one reduction loop,
//   * the inferred contraction has no batch dims and exactly one loop in each
//     of the M, N and K groups,
//   * the M loop has a static bound of 1.
static bool isMatvecLike(linalg::LinalgOp linalgOp) {
  if (linalgOp.getNumParallelLoops() != 2)
    return false;

  if (linalgOp.getNumReductionLoops() != 1)
    return false;

  // TODO: Allow for matvec with fused dequantization.
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(linalgOp);
  if (failed(dims))
    return false;

  // TODO: Support batch matvec.
  if (!dims->batch.empty())
    return false;

  // Each of M, N and K must map to exactly one loop. Checked directly instead
  // of iterating a braced list, which would copy every dimension vector.
  if (!llvm::hasSingleElement(dims->m) || !llvm::hasSingleElement(dims->n) ||
      !llvm::hasSingleElement(dims->k))
    return false;

  // The M dimension having bound 1 indicates we found a vector shape.
  SmallVector<int64_t, 4> bounds = linalgOp.getStaticLoopRanges();
  return bounds[dims->m.front()] == 1;
}
// Tries to configure `op` for the LLVMGPUWarpReduction pipeline.
//
// It computes workgroup-level and reduction-level tile sizes and a workgroup
// size, then attaches them through setOpConfigAndEntryPointFnTranslation.
// It returns failure when `op` is not a reduction this target/pipeline
// supports (see the individual checks below).
//
// NOTE(review): this copy of the file is missing lines (closing braces and
// several statements, each flagged below). The comments describe only what
// the visible lines do; confirm against upstream before relying on them.
static LogicalResult
setWarpReductionConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
linalg::LinalgOp op) {
// Bail out on targets without subgroup shuffle support.
if (!target.supportsSubgroupShuffle())
return failure();
// NOTE(review): nothing visible here fills parallelDims/reductionDims. As
// written, reductionDims.empty() below is always true and this function always
// fails. The statements populating them appear to be missing from this copy.
SmallVector<unsigned> parallelDims;
SmallVector<unsigned> reductionDims;
SmallVector<int64_t, 4> bounds = op.getStaticLoopRanges();
int64_t numParallelDims = op.getNumParallelLoops();
if (reductionDims.empty())
return failure();
// Make sure reduction dimensions are innermost: any reduction dim whose index
// is below numParallelDims is rejected. At most one reduction dim may be
// dynamic (checked below).
// NOTE(review): no visible statement increments numDynamicReductionDims, and
// the closing braces of the checks below are missing. Presumably the isDynamic
// branch counts dynamic reduction dims; confirm against upstream.
int64_t numDynamicReductionDims = 0;
for (unsigned dim : reductionDims) {
if (ShapedType::isDynamic(bounds[dim])) {
if (dim < numParallelDims) {
return failure();
// Distribution of multi-dim masked writes currently isn't fully supported.
if (numDynamicReductionDims > 1) {
return failure();
// Only a single region output argument is supported.
if (op.getRegionOutputArgs().size() != 1)
return failure();
// Only support projected permutations; this could be extended to projected
// permutations with broadcast.
// NOTE(review): the closing `}))` of this any_of call is missing in this copy.
if (llvm::any_of(op.getDpsInputOperands(), [&](OpOperand *input) {
return !op.getMatchingIndexingMap(input).isProjectedPermutation();
return failure();
// Exactly one init operand must be a reduction with a single combiner op.
// The identity-indexing-map check at the end of the loop body is presumably
// meant for the other (non-reduction) inits.
// NOTE(review): no skip (e.g. `continue`) after marking the reduction output
// is visible in this copy; without one, the identity check would also run on
// the reduction output. Confirm against upstream.
bool foundSingleReductionOutput = false;
for (auto [index, initOpOperand] : llvm::enumerate(op.getDpsInitsMutable())) {
// Only single combiner operations are supported for now.
SmallVector<Operation *> combinerOps;
if (matchReduction(op.getRegionOutputArgs(), index, combinerOps) &&
combinerOps.size() == 1) {
if (foundSingleReductionOutput)
return failure();
foundSingleReductionOutput = true;
if (!op.getMatchingIndexingMap(&initOpOperand).isIdentity())
return failure();
if (!foundSingleReductionOutput)
return failure();
// Tile all the parallel dimensions to 1. There is one workgroup tile size per
// loop, up to and including the last partitioned loop.
// NOTE(review): the initializer expression for partitionedLoops is missing in
// this copy.
SmallVector<unsigned> partitionedLoops =
size_t numLoops = partitionedLoops.empty() ? 0 : partitionedLoops.back() + 1;
SmallVector<int64_t> workgroupTileSizes(numLoops, 1);
// Without any bounds on dynamic reduction dims, we need specialization to
// get peak performance. For now, just use the warp size: tile the first
// reduction dim by the target's preferred subgroup size and launch a
// {preferredSubgroupSize, 1, 1} workgroup.
if (numDynamicReductionDims) {
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
int64_t preferredSubgroupSize = target.getPreferredSubgroupSize();
reductionTileSizes[reductionDims[0]] = preferredSubgroupSize;
TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
std::array<int64_t, 3> workgroupSize = {preferredSubgroupSize, 1, 1};
if (failed(setOpConfigAndEntryPointFnTranslation(
entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
workgroupSize))) {
return failure();
return success();
// Static case: reductionSize is the product of all reduction-loop bounds.
int64_t reductionSize = 1;
for (int64_t dim : reductionDims)
reductionSize *= bounds[dim];
// Pick a subgroup size from the target's choices that evenly divides
// reductionSize; fail if none does.
// NOTE(review): no `break` is visible after the assignment, so as written the
// LAST matching choice would win. Confirm which choice upstream intends.
int64_t subgroupSize = 0;
for (int s : target.getWgp().getSubgroupSizeChoices().asArrayRef()) {
if (reductionSize % s == 0) {
subgroupSize = s;
if (subgroupSize == 0)
return failure();
// NOTE(review): the initializer expression for elementType is missing in this
// copy.
const Type elementType =
if (!elementType.isIntOrFloat())
return failure();
unsigned bitWidth = elementType.getIntOrFloatBitWidth();
// Reduction distribution only supports 8/16/32 bit types now.
if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8)
return failure();
// Start from the widest per-thread vector that fits one 128-bit load.
// Halve it until the vector count is a multiple of the subgroup size. This
// stops at vectorSize == 1 at the latest, because subgroupSize was chosen to
// divide reductionSize.
const unsigned largestLoadSizeInBits = 128;
unsigned vectorSize = largestLoadSizeInBits / bitWidth;
while ((reductionSize / vectorSize) % subgroupSize != 0)
vectorSize /= 2;
// Deduce the workgroup size we should use for reduction. Currently a
// workgroup processes all elements in reduction dimensions. Need to make sure
// the workgroup size we use can divide the total reduction size, and it's
// also within hardware limitations.
const int64_t maxWorkgroupSize = 1024;
int64_t groupSize = reductionSize / vectorSize;
// If too large, clamp to gcd(groupSize, maxWorkgroupSize) so the result still
// divides groupSize evenly.
// NOTE(review): the APInt returned by GreatestCommonDivisor is not converted
// to int64_t and the statement is unterminated in this copy. A trailing
// conversion (e.g. `.getZExtValue();`) appears to be missing.
if (groupSize > maxWorkgroupSize) {
groupSize = llvm::APIntOps::GreatestCommonDivisor(
{64, uint64_t(groupSize)}, {64, uint64_t(maxWorkgroupSize)})
// Then we need to strike a balance--
// 1) parallel dimensions are distributed to workgroups. If there are many
// workgroups dispatched, we'd want to have each GPU core hosting multiple
// of them for occupancy.
// 2) we want each thread to read quite a few 128-bit vectors for better
// memory cache behavior.
// Both means we cannot use a too large workgroup size.
//
// parallelSize is the product of the static parallel bounds, or nullopt if any
// parallel bound is dynamic.
// NOTE(review): after `parallelSize = std::nullopt;` the loop must exit before
// `*parallelSize *= ...` dereferences the empty optional. That exit (e.g. a
// `break`) is not visible in this copy.
std::optional<int64_t> parallelSize = 1;
for (int64_t dim : parallelDims) {
if (ShapedType::isDynamic(bounds[dim])) {
parallelSize = std::nullopt;
*parallelSize *= bounds[dim];
// Total parallel size that can fill the GPU with enough workgroups.
// TODO: query from the target device; roughly 2x hardware compute unit.
const int parallelThreshold = 256;
// How many 128-bit vectors each thread should at least read.
const int targetVectorCount = 8;
// Halve the workgroup size while all of these hold:
//   * the parallel size is known and above the threshold,
//   * the halved group is still a whole number of subgroups,
//   * each thread reads fewer than targetVectorCount vectors.
// NOTE(review): if a static reduction bound is 0, reductionSize and groupSize
// are 0. Once the earlier operands hold, the last operand then divides by
// zero. Consider rejecting reductionSize == 0 up front.
while (parallelSize && *parallelSize > parallelThreshold &&
(groupSize / 2) % subgroupSize == 0 &&
reductionSize / (groupSize * vectorSize) < targetVectorCount) {
// Use fewer subgroups per workgroup...
groupSize /= 2;
// ...in order to host more workgroups per hardware compute unit.
*parallelSize /= 2;
// Current warp reduction pattern is a two step butterfly warp reduce.
// First, do warp reductions along multiple subgroups.
// Second, reduce results from multiple subgroups using single warp reduce.
// The final warp reduce requires subgroup count <= subgroup size to work.
if ((groupSize / subgroupSize) > subgroupSize)
return failure();
// With just one subgroup per workgroup, make each subgroup do more work and
// process a few reductions (rows) along the last parallel dimension.
// For fully static matvec-like ops, pick the largest power-of-two factor
// (starting at 2) that satisfies all of:
//   * it is below groupSize / 4,
//   * it divides the last parallel bound,
//   * it is smaller than the last parallel bound.
// Use that factor as the last workgroup tile size (1 if none qualifies).
// NOTE(review): the closing braces of the loop and this `if` are missing in
// this copy. The description assumes the assignment to
// workgroupTileSizes.back() follows the loop.
if (llvm::none_of(bounds, ShapedType::isDynamic) && isMatvecLike(op)) {
int64_t lastParallelBound = bounds[parallelDims.back()];
int64_t numParallelReductions = 1;
const int64_t maxParallelFactor = groupSize / 4;
for (int64_t parallelFactor = 2;
(parallelFactor < maxParallelFactor) &&
(lastParallelBound % parallelFactor == 0) &&
(lastParallelBound > parallelFactor);
parallelFactor *= 2) {
numParallelReductions = parallelFactor;
workgroupTileSizes.back() = numParallelReductions;
// The workgroup is one-dimensional: groupSize threads along x.
std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
// Distribute the groupSize threads over the reduction dims, walking from the
// last reduction dim to the first. Each dim gets gcd(remaining threads,
// bound) threads. For the last dim, the bound is first divided by vectorSize
// and its tile size is then scaled back up by vectorSize.
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
int64_t remainingGroupSize = groupSize;
for (int i = reductionDims.size() - 1; i >= 0; --i) {
int64_t dim = reductionDims[i];
int64_t bound = bounds[dim];
if (i == reductionDims.size() - 1)
bound /= vectorSize;
APInt size = llvm::APIntOps::GreatestCommonDivisor(
{64, uint64_t(remainingGroupSize)}, {64, uint64_t(bound)});
reductionTileSizes[dim] = size.getSExtValue();
if (i == reductionDims.size() - 1)
reductionTileSizes[dim] *= vectorSize;
remainingGroupSize /= size.getSExtValue();
TileSizesListType tileSizes;
tileSizes.emplace_back(std::move(workgroupTileSizes)); // Workgroup level
tileSizes.emplace_back(std::move(reductionTileSizes)); // Reduction level
return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
workgroupSize, subgroupSize);
// NOTE(review): unreachable; the statement above already returns.
return success();
// Root Configuration
// Picks and attaches a lowering configuration for the root op `computeOp`.
// Strategies are tried in this order, returning success at the first one that
// succeeds:
//   1. setDataTiledMultiMmaLoweringConfig,
//   2. for LinalgOps: IREE::GPU::setMatmulLoweringConfig, then
//      setWarpReductionConfig,
//   3. IREE::GPU::setTileAndFuseLoweringConfig.
// Returns failure if none of them succeeds.
//
// NOTE(review): closing braces are missing in this copy. It cannot be
// determined here whether the TileAndFuse attempt sits inside the LinalgOp
// branch (and is thus skipped for non-Linalg ops); confirm against upstream.
static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPointFn,
Operation *computeOp) {
if (succeeded(setDataTiledMultiMmaLoweringConfig(target, entryPointFn,
computeOp))) {
return success();
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(computeOp)) {
if (succeeded(IREE::GPU::setMatmulLoweringConfig(target, entryPointFn,
linalgOp))) {
return success();
if (succeeded(setWarpReductionConfig(target, entryPointFn, linalgOp))) {
return success();
// TODO: Add configurations for matmul here too.
// NOTE(review): matmul configuration is already attempted above via
// IREE::GPU::setMatmulLoweringConfig, so this TODO may be stale.
if (succeeded(IREE::GPU::setTileAndFuseLoweringConfig(target, entryPointFn,
computeOp))) {
return success();
return failure();
// Propagates the lowering config attached to `rootOp`, if any, to every other
// op in `computeOps`. Does nothing when `rootOp` has no lowering config.
static void propagateLoweringConfig(Operation *rootOp,
                                    ArrayRef<Operation *> computeOps) {
  IREE::Codegen::LoweringConfigAttrInterface config =
      getLoweringConfig(rootOp);
  if (!config)
    return;
  for (Operation *op : computeOps) {
    // `rootOp` is where `config` came from, so it needs no update.
    if (op != rootOp)
      setLoweringConfig(op, config);
  }
}
} // namespace
// Entry Point
// Entry point for ROCDL kernel configuration of `funcOp`. It does three
// things:
//   * chooses a root compute op,
//   * sets the root op's lowering configuration and the function's
//     translation info (via setRootConfig),
//   * propagates the root config to the other compute ops unless the chosen
//     pipeline is LLVMGPUTileAndFuse.
//
// Emits an error and fails if `funcOp` has no GPU target. Also fails if
// setting translation info or the root config fails.
//
// NOTE(review): this copy of the file is missing lines (closing braces, loop
// exits and one call's arguments, flagged below). Confirm the control flow
// against upstream.
LogicalResult initROCDLLaunchConfig(FunctionOpInterface funcOp) {
IREE::GPU::TargetAttr target = getGPUTargetAttr(funcOp);
// NOTE(review): the error text below reads as truncated ("missing GPU target
// in ..."); the runtime string is intentionally left unchanged here.
if (!target)
return funcOp.emitError("missing GPU target in");
// First check whether we already have workgroup count set--it's a
// "contract" to indicate that we should bypass all tiling and
// distribution to go down just the most basic lowering flow.
if (auto exportOp = getEntryPoint(funcOp)) {
if (Block *body = exportOp->getWorkgroupCountBody()) {
auto retOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
// For scalar dispatch cases--using just one thread of one workgroup.
// This applies when every returned workgroup count is the constant 1.
auto isOne = [](Value value) { return matchPattern(value, m_One()); };
if (llvm::all_of(retOp.getOperands(), isOne)) {
std::array<int64_t, 3> workgroupSize = {1, 1, 1};
// NOTE(review): the arguments of this TranslationInfoAttr::get call are
// missing in this copy.
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
if (failed(setTranslationInfo(funcOp, translationInfo))) {
return failure();
return success();
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
// Translation info is already present on `funcOp`. In that case, propagate an
// existing lowering config found on one of the compute ops.
// NOTE(review): no `break` after the propagation call and no `return` ending
// this branch are visible in this copy. Without that return, control would
// fall through to root-op selection below; confirm against upstream.
if (IREE::Codegen::TranslationInfoAttr translationInfo =
getTranslationInfo(funcOp)) {
// Currently ROCDL requires propagation of user lowering configs for
// all pipelines except TileAndFuse.
if (translationInfo.getDispatchLoweringPassPipeline() !=
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
for (auto op : computeOps) {
if (getLoweringConfig(op)) {
propagateLoweringConfig(op, computeOps);
Operation *rootOp = nullptr;
// Find the root operation. linalg.generic and linalg.fill are not root
// operations if there are other compute operations present.
// NOTE(review): no `break` follows the `rootOp = op;` assignments in this
// copy. Without one, later iterations overwrite rootOp, so (because
// iteration is reversed) the earliest matching op would win instead of the
// last one.
for (Operation *op : llvm::reverse(computeOps)) {
if (!isa<linalg::GenericOp, linalg::FillOp>(op)) {
rootOp = op;
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
// linalg.generic with `reduction` iterator types are roots as well.
if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
rootOp = op;
// Fallback: take a linalg.generic or linalg.fill op as the root.
if (!rootOp) {
for (Operation *op : llvm::reverse(computeOps)) {
if (isa<linalg::GenericOp, linalg::FillOp>(op)) {
rootOp = op;
if (!rootOp) {
// No root operation found, set it to none.
auto translationInfo = IREE::Codegen::TranslationInfoAttr::get(
funcOp.getContext(), CodeGenPipeline::None);
if (failed(setTranslationInfo(funcOp, translationInfo))) {
return failure();
return success();
if (failed(setRootConfig(target, funcOp, rootOp)))
return failure();
// As with pre-set translation info, the root config is propagated for every
// pipeline except LLVMGPUTileAndFuse.
if (getTranslationInfo(funcOp).getDispatchLoweringPassPipeline() !=
IREE::Codegen::DispatchLoweringPassPipeline::LLVMGPUTileAndFuse) {
propagateLoweringConfig(rootOp, computeOps);
return success();
} // namespace mlir::iree_compiler