// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
namespace {
// Computes a reduction along the rows of a 2d tensor of shape MxN
// to produce a tensor of shape M, combining elements with the binary op T.
template <typename T>
static Value computeRowwiseReduction(Value a, Value output, Location loc,
OpBuilder &builder) {
SmallVector<utils::IteratorType> iteratorTypes{
utils::IteratorType::parallel, utils::IteratorType::reduction};
AffineMap id = AffineMap::getMultiDimIdentityMap(2, builder.getContext());
AffineExpr d0, d1;
bindDims(builder.getContext(), d0, d1);
// (d0, d1) -> (d0)
auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
SmallVector<AffineMap> indexingMaps{id, rowMap};
auto genericOp = builder.create<linalg::GenericOp>(
loc, output.getType(), a, output, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value result = b.create<T>(loc, args[0], args[1]);
b.create<linalg::YieldOp>(loc, result);
});
return genericOp.getResult(0);
}
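// Updates the softmax running sum for a tile. For 1d statistics tensors of
// shape M this computes, elementwise:
//   scaledOldSum = exp(oldMax - newMax) * oldSum
//   newSum       = scaledOldSum + currentSum
// and returns {newSum, scaledOldSum}.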
static std::tuple<Value, Value> computeNewSum(Value oldMax, Value newMax,
Value oldSum, Value currentSum,
Value output, Location loc,
OpBuilder &builder) {
SmallVector<utils::IteratorType> iteratorTypes{utils::IteratorType::parallel};
auto identityMap = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
SmallVector<AffineMap> indexingMaps(6, identityMap);
SmallVector<Type> resultTypes(2, currentSum.getType());
auto genericOp = builder.create<linalg::GenericOp>(
loc, resultTypes, ValueRange{oldMax, newMax, oldSum, currentSum},
ValueRange{output, output}, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
Value weight = b.create<math::ExpOp>(loc, diff);
Value scaledOldSum = b.create<arith::MulFOp>(loc, weight, args[2]);
Value result = b.create<arith::AddFOp>(loc, scaledOldSum, args[3]);
b.create<linalg::YieldOp>(loc, ValueRange{result, scaledOldSum});
});
return std::make_tuple(genericOp.getResult(0), genericOp.getResult(1));
}
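// Computes the unnormalized softmax of a 2d tensor of shape MxN:
//   exp(qkTranspose - currentMax)
// where currentMax holds the per-row maximum (shape M) and is broadcast
// along each row.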
static Value computePartialSoftmax(Value qkTranspose, Value currentMax,
Value output, Location loc,
OpBuilder &builder) {
AffineMap identityMap =
AffineMap::getMultiDimIdentityMap(2, builder.getContext());
AffineExpr d0, d1;
bindDims(builder.getContext(), d0, d1);
// (d0, d1) -> (d0)
auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
SmallVector<AffineMap> indexingMaps{identityMap, rowMap, identityMap};
SmallVector<utils::IteratorType> iteratorTypes(2,
utils::IteratorType::parallel);
auto genericOp = builder.create<linalg::GenericOp>(
loc, qkTranspose.getType(), ValueRange{qkTranspose, currentMax}, output,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
Value result = b.create<math::ExpOp>(loc, diff);
b.create<linalg::YieldOp>(loc, result);
});
return genericOp.getResult(0);
}
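// Divides every element of the 2d partial softmax (shape MxN) by the
// corresponding per-row scaling value (shape M).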
static Value scalePartialSoftmax(Value softmax, Value scaledOldSum,
Value output, Location loc,
OpBuilder &builder) {
AffineMap identityMap =
AffineMap::getMultiDimIdentityMap(2, builder.getContext());
AffineExpr d0, d1;
bindDims(builder.getContext(), d0, d1);
// (d0, d1) -> (d0)
auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
SmallVector<AffineMap> indexingMaps{identityMap, rowMap, identityMap};
SmallVector<utils::IteratorType> iteratorTypes(2,
utils::IteratorType::parallel);
auto genericOp = builder.create<linalg::GenericOp>(
loc, softmax.getType(), ValueRange{softmax, scaledOldSum}, output,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
b.create<linalg::YieldOp>(loc, result);
});
return genericOp.getResult(0);
}
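// Rescales the 2d accumulator (shape MxN) by the per-row factor
// scaledOldSum / newSum (both of shape M).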
static Value scaleAccumulator(Value accumulator, Value scaledOldSum,
Value newSum, Value output, Location loc,
OpBuilder &builder) {
AffineMap identityMap =
AffineMap::getMultiDimIdentityMap(2, builder.getContext());
AffineExpr d0, d1;
bindDims(builder.getContext(), d0, d1);
// (d0, d1) -> (d0)
auto rowMap = AffineMap::get(2, 0, {d0}, builder.getContext());
SmallVector<AffineMap> indexingMaps{identityMap, rowMap, rowMap, identityMap};
SmallVector<utils::IteratorType> iteratorTypes(2,
utils::IteratorType::parallel);
auto genericOp = builder.create<linalg::GenericOp>(
loc, accumulator.getType(), ValueRange{accumulator, scaledOldSum, newSum},
output, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value prod = b.create<arith::DivFOp>(loc, args[1], args[2]);
Value result = b.create<arith::MulFOp>(loc, prod, args[0]);
b.create<linalg::YieldOp>(loc, result);
});
return genericOp.getResult(0);
}
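// Computes matmul(query, transpose(key)). The key is transposed into
// `transposedOutput`, and the matmul accumulator is formed by filling
// `output` with `zero`.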
static Value computeQKTranspose(Value query, Value key, Value transposedOutput,
Value output, Value zero,
RankedTensorType tensorType, Location loc,
OpBuilder &builder) {
SmallVector<int64_t> perm{1, 0};
auto transposeOp =
builder.create<linalg::TransposeOp>(loc, key, transposedOutput, perm);
Value acc =
builder.create<linalg::FillOp>(loc, ValueRange{zero}, output).result();
auto matmulOp = builder.create<linalg::MatmulOp>(
loc, tensorType, ValueRange{query, transposeOp.getResult()[0]}, acc);
return matmulOp.getResult(0);
}
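// Extracts the 2d tiles of the key, value, query and output tensors for the
// current iteration: key/value at batch offset ivs[0] and sequence offset
// ivs[1], query/output at batch offset ivs[0]. All slices are rank-reduced
// to shape sequenceTileLength x headDimension.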
static std::tuple<Value, Value, Value, Value>
extractSlices(Value key, Value value, Value query, Value output,
ArrayRef<int64_t> queryShape, ArrayRef<Value> ivs,
Value sequenceTileLength, Type elementType, Location loc,
OpBuilder &builder) {
auto one = builder.getIndexAttr(1);
auto zero = builder.getIndexAttr(0);
auto headDimension = builder.getIndexAttr(queryShape.back());
SmallVector<OpFoldResult> strides(queryShape.size(), one);
SmallVector<OpFoldResult> sizes(queryShape.size(), one);
SmallVector<OpFoldResult> offsets(queryShape.size(), zero);
sizes[1] = sequenceTileLength;
sizes[2] = headDimension;
offsets[0] = ivs[0];
offsets[1] = ivs[1];
SmallVector<int64_t> tensorShape{ShapedType::kDynamic, queryShape.back()};
auto tensorType = RankedTensorType::get(tensorShape, elementType);
Value keySlice = builder.create<tensor::ExtractSliceOp>(
loc, tensorType, key, offsets, sizes, strides);
Value valueSlice = builder.create<tensor::ExtractSliceOp>(
loc, tensorType, value, offsets, sizes, strides);
offsets = SmallVector<OpFoldResult>(queryShape.size(), zero);
offsets[0] = ivs[0];
Value querySlice = builder.create<tensor::ExtractSliceOp>(
loc, tensorType, query, offsets, sizes, strides);
Value outputSlice = builder.create<tensor::ExtractSliceOp>(
loc, tensorType, output, offsets, sizes, strides);
return std::make_tuple(keySlice, valueSlice, querySlice, outputSlice);
}
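// Inserts the updated accumulator tile back into the full result tensor at
// batch offset ivs[0], and overwrites the 1d max and sum statistics (of
// length sequenceTileLength) with their updated values.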
static std::tuple<Value, Value, Value>
insertSlices(Value newResult, Value result, Value newMax, Value max,
Value newSum, Value sum, ArrayRef<int64_t> queryShape,
ArrayRef<Value> ivs, Value sequenceTileLength, Location loc,
OpBuilder &builder) {
auto one = builder.getIndexAttr(1);
auto zero = builder.getIndexAttr(0);
auto headDimension = builder.getIndexAttr(queryShape.back());
SmallVector<OpFoldResult> strides(queryShape.size(), one);
SmallVector<OpFoldResult> sizes(queryShape.size(), one);
SmallVector<OpFoldResult> offsets(queryShape.size(), zero);
sizes[1] = sequenceTileLength;
sizes[2] = headDimension;
offsets[0] = ivs[0];
Value updatedAcc = builder.create<tensor::InsertSliceOp>(
loc, newResult, result, offsets, sizes, strides);
offsets = SmallVector<OpFoldResult>{zero};
sizes = SmallVector<OpFoldResult>{sequenceTileLength};
strides = SmallVector<OpFoldResult>{one};
Value updatedMax = builder.create<tensor::InsertSliceOp>(
loc, newMax, max, offsets, sizes, strides);
Value updatedSum = builder.create<tensor::InsertSliceOp>(
loc, newSum, sum, offsets, sizes, strides);
return std::make_tuple(updatedAcc, updatedMax, updatedSum);
}
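// Builds an scf.for loop from lb to ub with the given step, threading `args`
// through as iteration arguments (the generated body just forwards them; the
// caller fills it in afterwards). Induction variables are appended to `ivs`.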
static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
Value step, Value ub, ValueRange args,
Location loc, OpBuilder &builder) {
SmallVector<Value> lbs{lb};
SmallVector<Value> steps{step};
SmallVector<Value> ubs{ub};
scf::LoopNest loopNest = scf::buildLoopNest(
builder, loc, lbs, ubs, steps, args,
[&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs,
ValueRange iterArgs) -> scf::ValueVector { return iterArgs; });
for (scf::ForOp loop : loopNest.loops)
ivs.push_back(loop.getInductionVar());
return loopNest;
}
/// This is an implementation of flash attention: a tiled and fused
/// formulation of the attention operator.
/// The attention operator computes:
/// matmul(softmax(matmul(Q, transpose(K))), V)
/// where: Q is the query matrix [B x N x d]
/// K is the key matrix [B x N x d]
/// V is the value matrix [B x N x d]
///
/// The core algorithm is as follows:
/// For each element in B,
/// 1. Load a tile from the Q matrix of size T x d -> q
/// 2. Initialize statistics: running_sum, running_max
/// 3. for i = 0 to N with step T
///     a. Load a tile from the K matrix of size T x d -> k
///     b. Load a tile from the V matrix of size T x d -> v
///     c. Transpose(k) -> kT
///     d. Compute matmul(q, kT) -> qkT
///     e. Compute new max: new_max = max(rowwise_max(qkT), running_max)
///     f. Compute partial softmax: s = exp(qkT - new_max)
///     g. Compute current sum: current_sum = rowwise_sum(s)
///     h. Compute new sum:
///        new_sum = exp(running_max - new_max) * running_sum + current_sum
///     i. Normalize the partial softmax: p = s / new_sum
///     j. Rescale the accumulator by
///        exp(running_max - new_max) * running_sum / new_sum
///     k. Compute matmul(p, v) and add it to the rescaled accumulator
///
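/// In recurrence form, writing m for running_max, l for running_sum and acc
/// for the output accumulator, each iteration of the inner loop below
/// computes:
///   m_new   = max(m, rowwise_max(qkT))
///   s       = exp(qkT - m_new)
///   l_new   = exp(m - m_new) * l + rowwise_sum(s)
///   acc_new = acc * exp(m - m_new) * l / l_new + matmul(s / l_new, v)
///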
LogicalResult reifyAttentionTransform(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
funcOp.walk([&](IREE::LinalgExt::AttentionOp attnOp) {
Location loc = attnOp.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(attnOp);
Value query = attnOp.getQuery();
ShapedType queryType = attnOp.getQueryType();
Type elementType = queryType.getElementType();
ArrayRef<int64_t> queryShape = queryType.getShape();
SmallVector<OpFoldResult> queryDimValues =
tensor::createDimValues(rewriter, loc, query);
Value sequenceTileLength =
getValueOrCreateConstantIndexOp(rewriter, loc, queryDimValues[1]);
Value batchTileLength =
getValueOrCreateConstantIndexOp(rewriter, loc, queryDimValues[0]);
Value key = attnOp.getKey();
Value value = attnOp.getValue();
SmallVector<OpFoldResult> keyDimValues =
tensor::createDimValues(rewriter, loc, key);
Value sequenceLength =
getValueOrCreateConstantIndexOp(rewriter, loc, keyDimValues[1]);
    // Construct the outer loop over the batch dimension.
Value zeroValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value oneValue = rewriter.create<arith::ConstantIndexOp>(loc, 1);
SmallVector<Value> ivs;
Value output = attnOp.getOutput();
scf::LoopNest firstLoopNest =
createLoopNest(ivs, zeroValue, oneValue, batchTileLength,
ValueRange({output}), loc, rewriter);
Value iterArg = firstLoopNest.loops.back().getRegionIterArg(0);
OpBuilder::InsertionGuard guardFirstLoop(rewriter);
rewriter.setInsertionPointToStart(firstLoopNest.loops.back().getBody());
    // Create the running max and sum statistics, initialized to a large
    // negative value and to zero respectively.
SmallVector<OpFoldResult> dims{sequenceTileLength};
Value zeroF32 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value largeNegativeF32 = rewriter.create<arith::ConstantOp>(
loc, rewriter.getFloatAttr(elementType, -1e30));
Value max = rewriter.create<tensor::EmptyOp>(loc, dims, elementType);
Value negativeMax =
rewriter.create<linalg::FillOp>(loc, ValueRange{largeNegativeF32}, max)
.result();
Value sum = rewriter.create<tensor::EmptyOp>(loc, dims, elementType);
Value zeroSum =
rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, sum).result();
    // Construct the inner loop over the key/value sequence dimension.
scf::LoopNest secondLoopNest = createLoopNest(
ivs, zeroValue, sequenceTileLength, sequenceLength,
ValueRange({iterArg, negativeMax, zeroSum}), loc, rewriter);
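    // The inner loop carries the output accumulator tile and the running
    // max/sum statistics as iteration arguments.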
Value iterArgResult = secondLoopNest.loops.back().getRegionIterArg(0);
Value iterArgMax = secondLoopNest.loops.back().getRegionIterArg(1);
Value iterArgSum = secondLoopNest.loops.back().getRegionIterArg(2);
OpBuilder::InsertionGuard guardSecondLoop(rewriter);
rewriter.setInsertionPointToStart(secondLoopNest.loops.back().getBody());
auto [keySlice, valueSlice, querySlice, outputSlice] =
extractSlices(key, value, query, iterArgResult, queryShape, ivs,
sequenceTileLength, elementType, loc, rewriter);
// Compute matmul(q, transpose(k))
auto headDimension = rewriter.getIndexAttr(queryShape.back());
SmallVector<OpFoldResult> transposedShape{headDimension,
sequenceTileLength};
Value empty =
rewriter.create<tensor::EmptyOp>(loc, transposedShape, elementType);
SmallVector<OpFoldResult> resultShape{sequenceTileLength,
sequenceTileLength};
Value emptySquare =
rewriter.create<tensor::EmptyOp>(loc, resultShape, elementType);
auto tensorType = RankedTensorType::get(
SmallVector<int64_t>(2, ShapedType::kDynamic), elementType);
Value qkTranspose =
computeQKTranspose(querySlice, keySlice, empty, emptySquare, zeroF32,
tensorType, loc, rewriter);
empty = rewriter.create<tensor::EmptyOp>(
loc, SmallVector<OpFoldResult>{sequenceTileLength}, elementType);
// Compute current statistics
Value newMax = computeRowwiseReduction<arith::MaxFOp>(
qkTranspose, iterArgMax, loc, rewriter);
Value partialSoftmax =
computePartialSoftmax(qkTranspose, newMax, emptySquare, loc, rewriter);
Value currentSum = computeRowwiseReduction<arith::AddFOp>(
partialSoftmax, zeroSum, loc, rewriter);
auto [newSum, scaledOldSum] = computeNewSum(
iterArgMax, newMax, iterArgSum, currentSum, empty, loc, rewriter);
    // Normalize the partial softmax by the new running sum.
Value softmax =
scalePartialSoftmax(partialSoftmax, newSum, emptySquare, loc, rewriter);
    // Rescale the current accumulator by scaledOldSum / newSum.
empty = rewriter.create<tensor::EmptyOp>(
loc, SmallVector<OpFoldResult>{sequenceLength, headDimension},
elementType);
Value scaledAcc = scaleAccumulator(outputSlice, scaledOldSum, newSum, empty,
loc, rewriter);
// Compute matmul(softmax, v)
Value result = rewriter
.create<linalg::MatmulOp>(
loc, outputSlice.getType(),
ValueRange{softmax, valueSlice}, scaledAcc)
.getResult(0);
// Insert slices
auto [updatedAcc, updatedMax, updatedSum] = insertSlices(
result, iterArgResult, newMax, iterArgMax, newSum, iterArgSum,
queryShape, ivs, sequenceTileLength, loc, rewriter);
if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
secondLoopNest.loops.back().getBody()->getTerminator())) {
rewriter.replaceOpWithNewOp<scf::YieldOp>(
yieldOp, ValueRange{updatedAcc, updatedMax, updatedSum});
}
if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
firstLoopNest.loops.back().getBody()->getTerminator())) {
OpBuilder::InsertionGuard yieldGuard(rewriter);
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<scf::YieldOp>(
yieldOp, ValueRange{secondLoopNest.results[0]});
}
attnOp.getResults()[0].replaceAllUsesWith(firstLoopNest.results[0]);
return WalkResult::advance();
});
return success();
}
} // namespace
namespace {
struct TileAndDecomposeAttentionPass
: public TileAndDecomposeAttentionBase<TileAndDecomposeAttentionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
linalg::LinalgDialect, scf::SCFDialect,
tensor::TensorDialect>();
}
void runOnOperation() override;
};
} // namespace
void TileAndDecomposeAttentionPass::runOnOperation() {
MLIRContext *context = &getContext();
IRRewriter rewriter(context);
if (failed(reifyAttentionTransform(getOperation())))
return signalPassFailure();
}
std::unique_ptr<Pass> createTileAndDecomposeAttentionPass() {
return std::make_unique<TileAndDecomposeAttentionPass>();
}
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir