// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir::iree_compiler::IREE::LinalgExt {
static bool isaTranspose(linalg::LinalgOp linalgOp) {
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
return false;
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return false;
SmallVector<AffineMap> mapRange = linalgOp.getIndexingMapsArray();
if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
!mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
return false;
}
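  // The body must contain only the terminator (linalg.yield); any other op
  // would make this an elementwise computation rather than a pure transpose.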
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}

static SmallVector<int64_t> getPermutation(linalg::LinalgOp linalgOp) {
assert(isaTranspose(linalgOp) && "linalgOp must be a transpose");
SmallVector<AffineMap> mapRange = linalgOp.getIndexingMapsArray();
AffineMap outMap = mapRange.back();
AffineMap inMap = mapRange.front();
// To get the permutation, look at each output index and find which
// dimension in the input we're reading from for that index.
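  // For example, inMap = (d0, d1, d2) -> (d0, d1, d2) and
  // outMap = (d0, d1, d2) -> (d1, d2, d0) yield the permutation [1, 2, 0].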
return llvm::map_to_vector(outMap.getResults(), [&](AffineExpr expr) {
return static_cast<int64_t>(inMap.getResultPosition(expr).value());
});
}

namespace {
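// Fuses a transpose producer into an attention consumer by folding the
// transpose's permutation into the matching indexing map, so the attention op
// reads the transpose's input directly. Illustrative sketch (operand names
// and the permutation are hypothetical):
//
//   %t = linalg.transpose ins(%q : ...) outs(%e : ...) permutation = [1, 0]
//   %o = iree_linalg_ext.attention ins(%t, %k, %v, ...) ...
//
// becomes an attention op that consumes %q through a permuted indexing map.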
struct FuseTransposeWithAttentionOp final
: public OpRewritePattern<LinalgExt::AttentionOp> {
FuseTransposeWithAttentionOp(MLIRContext *context,
linalg::ControlFusionFn controlFn,
PatternBenefit benefit = 1)
: OpRewritePattern<LinalgExt::AttentionOp>(context, benefit),
        controlFn(controlFn) {}

  LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
PatternRewriter &rewriter) const override {
OpOperand *transposeOperand = nullptr;
linalg::LinalgOp transposeOp;
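    // Find the first fusable input: a transpose-like producer with a single
    // use that the control function (if provided) accepts.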
for (OpOperand *input : attentionOp.getDpsInputOperands()) {
if (controlFn && !controlFn(input)) {
continue;
}
auto maybeTransposeOp = input->get().getDefiningOp<linalg::LinalgOp>();
if (maybeTransposeOp && isaTranspose(maybeTransposeOp) &&
maybeTransposeOp->hasOneUse()) {
transposeOp = maybeTransposeOp;
transposeOperand = input;
break;
}
}
if (!transposeOperand) {
return rewriter.notifyMatchFailure(attentionOp, "no transpose operand");
}
int64_t inputIndex = transposeOperand->getOperandNumber();
SmallVector<int64_t> perm = getPermutation(transposeOp);
auto invPerm = invertPermutationVector(perm);
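    // Folding the transpose amounts to permuting the operand's indexing map
    // results by the inverse permutation and reading the untransposed value.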
rewriter.modifyOpInPlace(attentionOp, [&]() {
SmallVector<AffineMap> newIndexingMaps =
attentionOp.getIndexingMapsArray();
AffineMap inputMap = attentionOp.getMatchingIndexingMap(transposeOperand);
SmallVector<AffineExpr> newExprs =
applyPermutation(inputMap.getResults(), invPerm);
AffineMap transposedMap =
AffineMap::get(inputMap.getNumDims(), inputMap.getNumSymbols(),
newExprs, rewriter.getContext());
newIndexingMaps[inputIndex] = transposedMap;
attentionOp.setIndexingMapsAttr(
rewriter.getAffineMapArrayAttr(newIndexingMaps));
attentionOp.setOperand(inputIndex, transposeOp.getDpsInputs()[0]);
});
return success();
  }

private:
linalg::ControlFusionFn controlFn;
};

// Bubbles transpose-V out of attention to expose the more performant
// attention-transposeV.
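// Illustrative sketch (operand names, shapes, and the permutation are
// hypothetical):
//
//   %o = iree_linalg_ext.attention ins(%q, %k, %v, ...) ...
//
// becomes
//
//   %vT = linalg.transpose ins(%v : ...) outs(%e : ...) permutation = [1, 0]
//   %o  = iree_linalg_ext.attention ins(%q, %k, %vT, ...) ...
//
// with the V indexing map permuted to match, exchanging the k2 and n
// dimensions of V.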
struct BubbleTransposeVFromAttentionOp
: public OpRewritePattern<LinalgExt::AttentionOp> {
BubbleTransposeVFromAttentionOp(MLIRContext *context,
linalg::ControlFusionFn controlFn,
PatternBenefit benefit = 1)
: OpRewritePattern<LinalgExt::AttentionOp>(context, benefit),
        controlFn(controlFn) {}

  LogicalResult matchAndRewrite(LinalgExt::AttentionOp attentionOp,
PatternRewriter &rewriter) const override {
// Only checking for V because we are only bubbling transpose-V.
OpOperand *valueOpOperand = &attentionOp.getValueMutable();
if (controlFn && !controlFn(valueOpOperand)) {
      return rewriter.notifyMatchFailure(
          attentionOp,
          "control function rejected bubbling the transpose of V");
}
// Extract Attention indexing information.
AffineMap qMap = attentionOp.getQueryMap();
AffineMap kMap = attentionOp.getKeyMap();
AffineMap vMap = attentionOp.getValueMap();
AffineMap oMap = attentionOp.getOutputMap();
FailureOr<AttentionOpDetail> maybeOpInfo =
AttentionOpDetail::get(qMap, kMap, vMap, oMap);
if (failed(maybeOpInfo)) {
return failure();
}
// Only handle single dim for K2 and N for now.
if (maybeOpInfo->getK2Dims().size() != 1 ||
maybeOpInfo->getNDims().size() != 1) {
return failure();
}
    // Check that V has the standard (non-transposed) layout, i.e. the map's
    // trailing results are (k2, n).
AffineExpr k2Dim =
rewriter.getAffineDimExpr(maybeOpInfo->getK2Dims().back());
AffineExpr nDim = rewriter.getAffineDimExpr(maybeOpInfo->getNDims().back());
int64_t vRank = vMap.getNumResults();
// TODO: This check is quite conservative, in the future we should simply
// do vMap.getResultPosition(k2Dim) > vMap.getResultPosition(nDim).
if (vMap.getResult(vRank - 1) != nDim ||
vMap.getResult(vRank - 2) != k2Dim) {
return failure();
}
// Get dimension positions to prepare for transpose.
std::optional<int64_t> maybeK2Pos = vMap.getResultPosition(k2Dim);
std::optional<int64_t> maybeNPos = vMap.getResultPosition(nDim);
assert(maybeK2Pos.has_value() && maybeNPos.has_value() &&
"Expected K2 dim and N dim to be in V-map.");
int64_t k2Pos = maybeK2Pos.value();
int64_t nPos = maybeNPos.value();
SmallVector<int64_t> perm = llvm::to_vector(llvm::seq<int64_t>(0, vRank));
std::swap(perm[k2Pos], perm[nPos]);
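    // Given the layout check above, this swaps the trailing (k2, n) pair,
    // e.g. perm = [0, 2, 1] for a rank-3 V.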
// Expose transposeOp for V.
Location loc = attentionOp.getLoc();
Value value = attentionOp.getValue();
    // V is always a shaped tensor operand, so the cast cannot fail.
    auto valueType = cast<ShapedType>(value.getType());
    auto valueElType = valueType.getElementType();
SmallVector<OpFoldResult> transVShape =
tensor::getMixedSizes(rewriter, loc, value);
applyPermutationToVector(transVShape, perm);
Value initTransV =
rewriter.create<tensor::EmptyOp>(loc, transVShape, valueElType)
.getResult();
Value transposeV =
rewriter.create<linalg::TransposeOp>(loc, value, initTransV, perm)
->getResult(0);
// Generate transpose V map.
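    // E.g., vMap = (d0, d1, d2) -> (d0, d1, d2) with perm = [0, 2, 1] yields
    // (d0, d2, d1): the k2 and n results swap places.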
SmallVector<AffineExpr> newExprs =
applyPermutation(vMap.getResults(), perm);
AffineMap transposedVMap =
AffineMap::get(vMap.getNumDims(), vMap.getNumSymbols(), newExprs,
rewriter.getContext());
// Modify attention to have transposed V inputs and mapping.
int64_t valueIndex = valueOpOperand->getOperandNumber();
rewriter.modifyOpInPlace(attentionOp, [&]() {
SmallVector<AffineMap> newIndexingMaps =
attentionOp.getIndexingMapsArray();
newIndexingMaps[valueIndex] = transposedVMap;
attentionOp.setIndexingMapsAttr(
rewriter.getAffineMapArrayAttr(newIndexingMaps));
attentionOp.setOperand(valueIndex, transposeV);
});
return success();
  }

private:
linalg::ControlFusionFn controlFn;
};
} // namespace

void populateFuseLinalgExtOpsWithTransposes(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn) {
patterns.add<FuseTransposeWithAttentionOp>(patterns.getContext(),
controlFusionFn);
}

void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn) {
patterns.add<BubbleTransposeVFromAttentionOp>(patterns.getContext(),
controlFusionFn);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt