blob: 93c7919a3ab577ff356e51fead7900a536385886 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Debug.h"
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace LinalgExt {
namespace {
// Computes a reduction along the rows of a 2d tensor of shape MxN
// to produce a tensor of shape M
template <typename T>
static Value computeRowwiseReduction(Value a, Value output, Location loc,
                                     OpBuilder &builder,
                                     SmallVectorImpl<Operation *> &ops) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // Input reads all of (d0, d1); the output drops the reduced column
  // dimension: (d0, d1) -> (d0).
  SmallVector<AffineMap> maps = {AffineMap::getMultiDimIdentityMap(2, ctx),
                                 AffineMap::get(2, 0, {d0}, ctx)};
  SmallVector<utils::IteratorType> iterators = {
      utils::IteratorType::parallel, utils::IteratorType::reduction};
  auto reductionOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), a, output, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Combine the incoming element with the running reduction value
        // using the templated binary op (e.g. arith::MaximumFOp/AddFOp).
        Value combined = b.create<T>(nestedLoc, args[0], args[1]);
        b.create<linalg::YieldOp>(nestedLoc, combined);
      });
  ops.push_back(reductionOp);
  return reductionOp.getResult(0);
}
/// Computes exp2(qkTranspose - currentMax) elementwise, broadcasting the
/// 1-d row-wise maximum over the columns of the 2-d qkTranspose tensor.
static Value computePartialSoftmax(Value qkTranspose, Value currentMax,
                                   Location loc, OpBuilder &builder,
                                   SmallVectorImpl<Operation *> &ops) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // currentMax is broadcast along rows: (d0, d1) -> (d0).
  SmallVector<AffineMap> maps{AffineMap::get(2, 0, {d0}, ctx),
                              AffineMap::getMultiDimIdentityMap(2, ctx)};
  SmallVector<utils::IteratorType> iterators(2, utils::IteratorType::parallel);
  auto softmaxOp = builder.create<linalg::GenericOp>(
      loc, qkTranspose.getType(), ValueRange{currentMax}, qkTranspose, maps,
      iterators, [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Shift by the row maximum before exponentiating (base-2 form).
        Value shifted = b.create<arith::SubFOp>(nestedLoc, args[1], args[0]);
        Value exponentiated = b.create<math::Exp2Op>(nestedLoc, shifted);
        b.create<linalg::YieldOp>(nestedLoc, exponentiated);
      });
  ops.push_back(softmaxOp);
  return softmaxOp.getResult(0);
}
/// Return the scale factor for the new softmax maximum and add the generic to
/// the provided list of operations.
static Value computeScaleFactor(Value oldMax, Value newMax, Location loc,
                                OpBuilder &builder,
                                SmallVectorImpl<Operation *> &ops) {
  auto id1d = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
  SmallVector<AffineMap> maps(2, id1d);
  SmallVector<utils::IteratorType> iterators{utils::IteratorType::parallel};
  auto scaleOp = builder.create<linalg::GenericOp>(
      loc, oldMax.getType(), newMax, oldMax, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // exp2(oldMax - newMax): factor to rescale state that was computed
        // against the stale maximum.
        Value delta = b.create<arith::SubFOp>(nestedLoc, args[1], args[0]);
        Value factor = b.create<math::Exp2Op>(nestedLoc, delta);
        b.create<linalg::YieldOp>(nestedLoc, factor);
      });
  ops.push_back(scaleOp);
  return scaleOp.getResult(0);
}
/// Multiplies the running softmax sum elementwise by the rescale factor so it
/// is expressed relative to the new maximum.
static Value updateAndScale(Value scaleFactor, Value oldSum, Location loc,
                            OpBuilder &builder,
                            SmallVectorImpl<Operation *> &ops) {
  auto id1d = AffineMap::getMultiDimIdentityMap(1, builder.getContext());
  SmallVector<AffineMap> maps(2, id1d);
  SmallVector<utils::IteratorType> iterators{utils::IteratorType::parallel};
  auto rescaleSumOp = builder.create<linalg::GenericOp>(
      loc, oldSum.getType(), scaleFactor, oldSum, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value rescaled = b.create<arith::MulFOp>(nestedLoc, args[0], args[1]);
        b.create<linalg::YieldOp>(nestedLoc, rescaled);
      });
  ops.push_back(rescaleSumOp);
  return rescaleSumOp.getResult(0);
}
/// Scales every row of the 2-d `softmax` tensor by the corresponding entry of
/// the 1-d `inverseNewSum` tensor.
static Value scalePartialSoftmax(Value softmax, Value inverseNewSum,
                                 Location loc, OpBuilder &builder,
                                 SmallVectorImpl<Operation *> &ops) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // inverseNewSum is broadcast along columns: (d0, d1) -> (d0).
  SmallVector<AffineMap> maps{AffineMap::get(2, 0, {d0}, ctx),
                              AffineMap::getMultiDimIdentityMap(2, ctx)};
  SmallVector<utils::IteratorType> iterators(2, utils::IteratorType::parallel);
  auto scaledOp = builder.create<linalg::GenericOp>(
      loc, softmax.getType(), ValueRange{inverseNewSum}, softmax, maps,
      iterators, [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value scaled = b.create<arith::MulFOp>(nestedLoc, args[1], args[0]);
        b.create<linalg::YieldOp>(nestedLoc, scaled);
      });
  ops.push_back(scaledOp);
  return scaledOp.getResult(0);
}
/// Divides each row of `result` by the corresponding entry of the 1-d final
/// softmax sum `newSum`, normalizing the attention output.
static Value applyFinalScaling(Value result, Value newSum, Location loc,
                               OpBuilder &builder,
                               SmallVectorImpl<Operation *> &ops) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // newSum is broadcast along columns: (d0, d1) -> (d0).
  SmallVector<AffineMap> maps{AffineMap::get(2, 0, {d0}, ctx),
                              AffineMap::getMultiDimIdentityMap(2, ctx)};
  SmallVector<utils::IteratorType> iterators(2, utils::IteratorType::parallel);
  auto normalizeOp = builder.create<linalg::GenericOp>(
      loc, result.getType(), newSum, result, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Normalize by multiplying with the reciprocal of the row sum.
        Value one = b.create<arith::ConstantOp>(
            nestedLoc, b.getFloatAttr(args[0].getType(), 1.0));
        Value inverse = b.create<arith::DivFOp>(nestedLoc, one, args[0]);
        Value normalized =
            b.create<arith::MulFOp>(nestedLoc, inverse, args[1]);
        b.create<linalg::YieldOp>(nestedLoc, normalized);
      });
  ops.push_back(normalizeOp);
  return normalizeOp.getResult(0);
}
/// Rescales each row of the partial accumulator by the 1-d `scaleFactor` so
/// previously accumulated contributions match the updated softmax maximum.
static Value scaleAccumulator(Value accumulator, Value scaleFactor,
                              Location loc, OpBuilder &builder,
                              SmallVectorImpl<Operation *> &ops) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, d1;
  bindDims(ctx, d0, d1);
  // scaleFactor is broadcast along columns: (d0, d1) -> (d0).
  SmallVector<AffineMap> maps{AffineMap::get(2, 0, {d0}, ctx),
                              AffineMap::getMultiDimIdentityMap(2, ctx)};
  SmallVector<utils::IteratorType> iterators(2, utils::IteratorType::parallel);
  auto rescaleOp = builder.create<linalg::GenericOp>(
      loc, accumulator.getType(), scaleFactor, accumulator, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value rescaled = b.create<arith::MulFOp>(nestedLoc, args[0], args[1]);
        b.create<linalg::YieldOp>(nestedLoc, rescaled);
      });
  ops.push_back(rescaleOp);
  return rescaleOp.getResult(0);
}
/// Computes matmul(query, transpose(key)) into `output`, first filling the
/// accumulator with `zero`.
static Value computeQKTranspose(Value query, Value key, Value output,
                                Value zero, Location loc, OpBuilder &builder,
                                SmallVectorImpl<Operation *> &ops) {
  auto fill = builder.create<linalg::FillOp>(loc, ValueRange{zero}, output);
  ops.push_back(fill);
  auto matmul = builder.create<linalg::MatmulTransposeBOp>(
      loc, output.getType(), ValueRange{query, key}, fill.result());
  ops.push_back(matmul);
  return matmul.getResult(0);
}
/// Extracts a rank-reduced 2-d [keyValueTileLength x headDimension] tile
/// from a 3-d key/value tensor, offset along the sequence dimension by the
/// loop induction variable when one is provided.
static Value extractSlice(Value key, ArrayRef<int64_t> keyShape,
                          ArrayRef<Value> ivs, OpFoldResult keyValueTileLength,
                          OpFoldResult headDimension, Type elementType,
                          Location loc, OpBuilder &builder) {
  int64_t rank = keyShape.size();
  SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes(rank, builder.getIndexAttr(1));
  SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
  sizes[1] = keyValueTileLength;
  sizes[2] = headDimension;
  // The induction variable (if present) selects the tile along dim 1.
  if (!ivs.empty())
    offsets[1] = ivs[0];
  // Result type drops the leading unit batch dimension (rank-reducing
  // extract).
  auto sliceType =
      RankedTensorType::get({keyShape[1], keyShape[2]}, elementType);
  Value slice = builder.create<tensor::ExtractSliceOp>(
      loc, sliceType, key, offsets, sizes, strides);
  return slice;
}
/// Builds a single sequential scf.for loop over [lb, ub) with the given step
/// and iteration arguments, records its induction variable in `ivs`, and
/// returns the loop nest. The loop body initially just forwards the iter
/// args; callers populate it afterwards.
static scf::LoopNest createLoopNest(SmallVectorImpl<Value> &ivs, Value lb,
                                    Value step, Value ub, ValueRange args,
                                    Location loc, OpBuilder &builder) {
  scf::LoopNest loopNest = scf::buildLoopNest(
      builder, loc, ValueRange{lb}, ValueRange{ub}, ValueRange{step}, args,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange outputIvs,
          ValueRange iterArgs) -> scf::ValueVector { return iterArgs; });
  for (scf::ForOp loop : loopNest.loops)
    ivs.push_back(loop.getInductionVar());
  return loopNest;
}
/// Truncates each element of the 2-d `input` tensor to f16, writing the
/// result into `output`.
static Value truncateToF16(Value input, Value output,
                           SmallVectorImpl<Operation *> &ops,
                           OpBuilder &builder, Location loc) {
  auto id2d = AffineMap::getMultiDimIdentityMap(2, builder.getContext());
  SmallVector<AffineMap> maps(2, id2d);
  SmallVector<utils::IteratorType> iterators(2, utils::IteratorType::parallel);
  auto truncOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), ValueRange{input}, output, maps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        Value truncated =
            b.create<arith::TruncFOp>(nestedLoc, b.getF16Type(), args[0]);
        b.create<linalg::YieldOp>(nestedLoc, truncated);
      });
  ops.push_back(truncOp);
  return truncOp.getResult(0);
}
/// Emits the body of one flash-attention step over a key/value tile:
/// computes q @ k^T, updates the running max/sum statistics, rescales the
/// accumulator, and adds this tile's contribution via matmul(softmax, v).
/// Returns {updated accumulator, new max, new sum}.
static std::tuple<Value, Value, Value>
createAttentionBody(Value keySlice, Value valueSlice, Value querySlice,
                    Value outputSlice, Value maxSlice, Value sumSlice,
                    OpFoldResult sequenceTileLength,
                    OpFoldResult keyValueTileLength, OpFoldResult headDimension,
                    Type elementType, SmallVectorImpl<Operation *> &ops,
                    Location loc, OpBuilder &builder) {
  // Statistics and the qk product are kept in f32 regardless of input type.
  Type f32Type = builder.getF32Type();
  // Compute matmul(q, transpose(k))
  Value zero =
      builder.create<arith::ConstantOp>(loc, builder.getZeroAttr(f32Type));
  SmallVector<OpFoldResult> resultShape{sequenceTileLength, keyValueTileLength};
  Value emptySquare =
      builder.create<tensor::EmptyOp>(loc, resultShape, f32Type);
  Value qkTranspose = computeQKTranspose(querySlice, keySlice, emptySquare,
                                         zero, loc, builder, ops);
  // Compute current statistics
  Value newMax = computeRowwiseReduction<arith::MaximumFOp>(
      qkTranspose, maxSlice, loc, builder, ops);
  Value partialSoftmax =
      computePartialSoftmax(qkTranspose, newMax, loc, builder, ops);
  // Rescale state that was computed against the previous (stale) maximum.
  Value scaleFactor = computeScaleFactor(maxSlice, newMax, loc, builder, ops);
  Value scaledOldSum = updateAndScale(scaleFactor, sumSlice, loc, builder, ops);
  Value newSum = computeRowwiseReduction<arith::AddFOp>(
      partialSoftmax, scaledOldSum, loc, builder, ops);
  if (elementType.isF16()) {
    // Truncate the softmax back to f16 so the value matmul runs in the
    // input precision.
    Value empty =
        builder.create<tensor::EmptyOp>(loc, resultShape, builder.getF16Type());
    partialSoftmax = truncateToF16(partialSoftmax, empty, ops, builder, loc);
  }
  // Update accumulator
  Value scaledAcc =
      scaleAccumulator(outputSlice, scaleFactor, loc, builder, ops);
  // Compute matmul(softmax, v)
  auto matmulOp = builder.create<linalg::MatmulOp>(
      loc, scaledAcc.getType(), ValueRange{partialSoftmax, valueSlice},
      scaledAcc);
  ops.push_back(matmulOp);
  Value result = matmulOp.getResult(0);
  return std::make_tuple(result, newMax, newSum);
}
/// Extracts (when `dst` is null) or inserts (otherwise) the 2-d output tile
/// of size [sequenceTileLength x headDimension] from/into a 3-d tensor with
/// a leading unit batch dimension.
static Value extractOrInsertOutputSlice(Value src, Value dst,
                                        ArrayRef<int64_t> queryShape,
                                        OpFoldResult sequenceTileLength,
                                        OpFoldResult headDimension,
                                        Location loc, OpBuilder &builder) {
  SmallVector<OpFoldResult> offsets(3, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(3, builder.getIndexAttr(1));
  SmallVector<OpFoldResult> sizes = {builder.getIndexAttr(1),
                                     sequenceTileLength, headDimension};
  if (dst) {
    // Insert path: write the (rank-reduced) tile back into the destination.
    Value updated = builder.create<tensor::InsertSliceOp>(
        loc, src, dst, offsets, sizes, strides);
    return updated;
  }
  // Extract path: rank-reducing extract that drops the unit batch dim.
  Type elementType = src.getType().cast<ShapedType>().getElementType();
  auto sliceType =
      RankedTensorType::get({queryShape[1], queryShape[2]}, elementType);
  Value tile = builder.create<tensor::ExtractSliceOp>(loc, sliceType, src,
                                                      offsets, sizes, strides);
  return tile;
}
/// Convenience wrapper: extracts the output tile (no destination supplied).
static Value extractOutputSlice(Value src, ArrayRef<int64_t> queryShape,
                                OpFoldResult sequenceTileLength,
                                OpFoldResult headDimension, Location loc,
                                OpBuilder &builder) {
  return extractOrInsertOutputSlice(src, /*dst=*/{}, queryShape,
                                    sequenceTileLength, headDimension, loc,
                                    builder);
}
/// Convenience wrapper: inserts `src` into `dst` (query shape is unused on
/// the insert path).
static Value insertOutputSlice(Value src, Value dst,
                               OpFoldResult sequenceTileLength,
                               OpFoldResult headDimension, Location loc,
                               OpBuilder &builder) {
  return extractOrInsertOutputSlice(src, dst, /*queryShape=*/{},
                                    sequenceTileLength, headDimension, loc,
                                    builder);
}
} // namespace
/// Tile iree_linalg_ext.attention.
/// TODO: Adopt getTiledImplementation with this.
///
/// Rewrites `attnOp` into an scf.for over the key/value sequence dimension.
/// The loop threads (f32 accumulator, running max, running sum) as iter args
/// and contains a smaller AttentionOp over one key/value tile, which is
/// returned (and appended to `ops`) so decomposeTiledAttention can expand it.
IREE::LinalgExt::AttentionOp tileAttention(IREE::LinalgExt::AttentionOp attnOp,
                                           SmallVectorImpl<Operation *> &ops,
                                           RewriterBase &rewriter,
                                           std::optional<uint64_t> tileSize) {
  Location loc = attnOp.getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(attnOp);
  Value query = attnOp.getQuery();
  ShapedType queryType = attnOp.getQueryType();
  Type elementType = queryType.getElementType();
  ArrayRef<int64_t> queryShape = queryType.getShape();
  SmallVector<OpFoldResult> queryDimValues =
      tensor::getMixedSizes(rewriter, loc, query);
  // Query is indexed [batch, sequence, head]: dim 1 is the sequence tile
  // length, dim 2 the head dimension.
  OpFoldResult headDimension = queryDimValues[2];
  OpFoldResult sequenceTileLength = queryDimValues[1];
  OpFoldResult keyValueTileLength = sequenceTileLength;
  SmallVector<int64_t> keyShape{queryShape};
  if (tileSize) {
    // An explicit tile size overrides the key/value sequence tile (dim 1);
    // the other dims keep the key's static shape.
    keyValueTileLength = rewriter.getIndexAttr(tileSize.value());
    for (auto it : llvm::enumerate(attnOp.getKeyType().getShape())) {
      keyShape[it.index()] = it.index() == 1 ? tileSize.value() : it.value();
    }
  }
  Value key = attnOp.getKey();
  Value value = attnOp.getValue();
  SmallVector<OpFoldResult> keyDimValues =
      tensor::getMixedSizes(rewriter, loc, key);
  OpFoldResult sequenceLength = keyDimValues[1];
  // Create output accumulator
  Value output = attnOp.getOutput();
  Type f32Type = rewriter.getF32Type();
  SmallVector<OpFoldResult> accShape{queryDimValues[1], queryDimValues[2]};
  // The accumulator is kept in f32 even when the inputs are f16.
  Value accumulatorF32 =
      rewriter.create<tensor::EmptyOp>(loc, accShape, f32Type);
  // Create accumulator, max and sum statistics
  Value outputSlice = extractOutputSlice(output, queryShape, sequenceTileLength,
                                         headDimension, loc, rewriter);
  Value zeroF32 =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(f32Type));
  auto accumulatorFill =
      rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, accumulatorF32);
  accumulatorF32 = accumulatorFill.result();
  ops.push_back(accumulatorFill);
  // Seed the running max with a large negative value (stand-in for -inf).
  Value largeNegativeF32 = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getFloatAttr(f32Type, -1.0e+30));
  SmallVector<OpFoldResult> dims{sequenceTileLength};
  Value max = rewriter.create<tensor::EmptyOp>(loc, dims, f32Type);
  auto maxFill =
      rewriter.create<linalg::FillOp>(loc, ValueRange{largeNegativeF32}, max);
  Value negativeMax = maxFill.result();
  ops.push_back(maxFill);
  Value sum = rewriter.create<tensor::EmptyOp>(loc, dims, f32Type);
  auto sumFill = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroF32}, sum);
  Value zeroSum = sumFill.result();
  ops.push_back(sumFill);
  // Construct sequential loop
  SmallVector<Value> ivs;
  Value zeroValue = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // Loop from 0 to the key sequence length in steps of the key/value tile,
  // threading (accumulator, max, sum) as iter args.
  scf::LoopNest loopNest = createLoopNest(
      ivs, zeroValue,
      getValueOrCreateConstantIndexOp(rewriter, loc, keyValueTileLength),
      getValueOrCreateConstantIndexOp(rewriter, loc, sequenceLength),
      ValueRange({accumulatorF32, negativeMax, zeroSum}), loc, rewriter);
  ops.push_back(loopNest.loops.back());
  Value iterArgResult = loopNest.loops.back().getRegionIterArg(0);
  Value iterArgMax = loopNest.loops.back().getRegionIterArg(1);
  Value iterArgSum = loopNest.loops.back().getRegionIterArg(2);
  OpBuilder::InsertionGuard guardSecondLoop(rewriter);
  rewriter.setInsertionPointToStart(loopNest.loops.back().getBody());
  // Extract slices
  Value keySlice = extractSlice(key, keyShape, ivs, keyValueTileLength,
                                headDimension, elementType, loc, rewriter);
  Value valueSlice = extractSlice(value, keyShape, ivs, keyValueTileLength,
                                  headDimension, elementType, loc, rewriter);
  Value querySlice = extractSlice(query, queryShape, {}, sequenceTileLength,
                                  headDimension, elementType, loc, rewriter);
  // Emit the tiled attention op inside the loop. It carries the statistics
  // as extra operands/results, matching the iter args.
  auto tiledAttentionOp = rewriter.create<IREE::LinalgExt::AttentionOp>(
      attnOp.getLoc(),
      SmallVector<Type>{accumulatorF32.getType(), sum.getType(), max.getType()},
      SmallVector<Value>{querySlice, keySlice, valueSlice},
      SmallVector<Value>{iterArgResult, iterArgMax, iterArgSum});
  Value tiledResult = tiledAttentionOp.getResult(0);
  Value newMax = tiledAttentionOp.getResult(1);
  Value newSum = tiledAttentionOp.getResult(2);
  // Rewire the loop terminator to yield the updated accumulator/statistics.
  if (scf::YieldOp yieldOp = dyn_cast<scf::YieldOp>(
          loopNest.loops.back().getBody()->getTerminator())) {
    OpBuilder::InsertionGuard yieldGuard(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    rewriter.replaceOpWithNewOp<scf::YieldOp>(
        yieldOp, ValueRange{tiledResult, newMax, newSum});
  }
  OpBuilder::InsertionGuard yieldGuard(rewriter);
  rewriter.setInsertionPointAfter(loopNest.loops.back());
  // After the loop: normalize by the final sum, truncate back to f16 if
  // needed, and insert the tile back into the original output tensor.
  loopNest.results[0] = applyFinalScaling(
      loopNest.results[0], loopNest.results[2], loc, rewriter, ops);
  if (elementType.isF16()) {
    loopNest.results[0] =
        truncateToF16(loopNest.results[0], outputSlice, ops, rewriter, loc);
  }
  loopNest.results[0] =
      insertOutputSlice(loopNest.results[0], output, sequenceTileLength,
                        headDimension, loc, rewriter);
  rewriter.replaceOp(attnOp, loopNest.results[0]);
  ops.push_back(tiledAttentionOp);
  return tiledAttentionOp;
}
/// Decompose tiled iree_linalg_ext.attention op.
/// TODO: Adopt decomposeOperation with this.
///
/// Replaces the (already tiled) attention op with explicit linalg ops that
/// implement one flash-attention step (see createAttentionBody).
void decomposeTiledAttention(IREE::LinalgExt::AttentionOp tiledAttnOp,
                             SmallVectorImpl<Operation *> &ops,
                             RewriterBase &rewriter,
                             std::optional<uint64_t> tileSize) {
  Location loc = tiledAttnOp.getLoc();
  Value keySlice = tiledAttnOp.getKey();
  Value valueSlice = tiledAttnOp.getValue();
  Value querySlice = tiledAttnOp.getQuery();
  Value tiledResult = tiledAttnOp.getOutput();
  // The tiled form produced by tileAttention carries the running max/sum
  // statistics as extra operands; they must be present here.
  Value max = *tiledAttnOp.getMax();
  Value sum = *tiledAttnOp.getSum();
  assert(max && "expected max statistic operand to be present");
  assert(sum && "expected sum statistic operand to be present");
  OpBuilder::InsertionGuard withinScfLoop(rewriter);
  rewriter.setInsertionPointAfter(tiledAttnOp);
  SmallVector<OpFoldResult> queryDimValues =
      tensor::getMixedSizes(rewriter, loc, querySlice);
  // The query slice is 2-d: [sequenceTileLength x headDimension].
  OpFoldResult headDimension = queryDimValues[1];
  OpFoldResult sequenceTileLength = queryDimValues[0];
  OpFoldResult keyValueTileLength =
      tileSize ? rewriter.getIndexAttr(tileSize.value()) : sequenceTileLength;
  Type elementType = tiledAttnOp.getQueryType().getElementType();
  auto [result, newMax, newSum] =
      createAttentionBody(keySlice, valueSlice, querySlice, tiledResult, max,
                          sum, sequenceTileLength, keyValueTileLength,
                          headDimension, elementType, ops, loc, rewriter);
  // Replace the tiled op's three results with the decomposed computation.
  rewriter.replaceOp(tiledAttnOp, ValueRange{result, newMax, newSum});
}
/// Utility function which tiles and then decomposes attention op via
/// FlashAttention algorithm.
void tileAndDecomposeAttention(IREE::LinalgExt::AttentionOp attnOp,
                               SmallVectorImpl<Operation *> &ops,
                               RewriterBase &rewriter, bool onlyTile,
                               std::optional<uint64_t> tileSize) {
  IREE::LinalgExt::AttentionOp tiledOp =
      tileAttention(attnOp, ops, rewriter, tileSize);
  // Optionally stop after tiling, leaving the inner attention op intact.
  if (!onlyTile)
    decomposeTiledAttention(tiledOp, ops, rewriter, tileSize);
}
namespace {
/// This is an implementation of flash attention which
/// is a tiled and fused implementation of the attention operator.
/// The attention operator computes:
/// matmul(softmax(matmul(Q, transpose(K))), V)
/// where: Q is the query matrix [B x N x d]
/// K is the key matrix [B x S x d]
/// V is the value matrix [B x S x d]
///
/// The core algorithm is as follows:
/// For each element in B,
/// 1. Load a tile from the Q matrix of size T x d -> q
/// 2. Initialize statistics: running_sum, running_max
/// 3. for i = 0 to S with step T
/// a. Load a tile from the K matrix of size T x d -> k
/// b. Load a tile from the V matrix of size T x d -> v
/// c. Compute matmul_transpose_b(q, k) -> qkT
/// d. Compute max(max(qkT) along rows, old_max) -> new_max
/// e. Compute current estimate of softmax: exp(qkT - current_max) -> s
/// f. Compute product of fixup and old_sum -> fsum
/// g. Compute sum(sum(qkT) along rows, fsum) -> new_sum
/// h. Compute 1.0 / new_sum -> inv_new_sum
/// i. Compute softmax = softmax * inv_new_sum
/// j. Truncate softmax to fp16
/// k. Compute fsum * inv_new_sum * accumulator -> new_accumulator
/// l. Compute matmul(s, v) and add new_accumulator
///
///
/// Walks `funcOp` and applies the tile-and-decompose rewrite to every
/// attention op found. The walk itself cannot fail, so this always returns
/// success.
LogicalResult reifyAttentionTransform(func::FuncOp funcOp, bool onlyTile,
                                      std::optional<uint64_t> tileSize) {
  IRRewriter rewriter(funcOp.getContext());
  funcOp.walk([&](IREE::LinalgExt::AttentionOp attnOp) {
    SmallVector<Operation *> generatedOps;
    tileAndDecomposeAttention(attnOp, generatedOps, rewriter, onlyTile,
                              tileSize);
    return WalkResult::advance();
  });
  return success();
}
} // namespace
namespace {
/// Pass that tiles (and optionally decomposes) every
/// iree_linalg_ext.attention op in a function. Options `onlyTile` and
/// `tileSize` come from the generated TileAndDecomposeAttentionBase.
struct TileAndDecomposeAttentionPass
    : public TileAndDecomposeAttentionBase<TileAndDecomposeAttentionPass> {
  // Dialects whose ops the rewrite creates must be pre-loaded.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<
        affine::AffineDialect, IREE::LinalgExt::IREELinalgExtDialect,
        linalg::LinalgDialect, scf::SCFDialect, tensor::TensorDialect>();
  }
  TileAndDecomposeAttentionPass() = default;
  // Programmatic constructor overriding the declared pass options.
  TileAndDecomposeAttentionPass(bool onlyTile, uint64_t tileSize) {
    this->onlyTile = onlyTile;
    this->tileSize = tileSize;
  }
  // Copy constructor required for pass cloning; copies option values.
  TileAndDecomposeAttentionPass(const TileAndDecomposeAttentionPass &pass) {
    onlyTile = pass.onlyTile;
    tileSize = pass.tileSize;
  }
  void runOnOperation() override;
};
} // namespace
/// Pass entry point: forwards the pass options to the reify transform.
void TileAndDecomposeAttentionPass::runOnOperation() {
  // Translate the pass option into an optional tile size; absent means the
  // full sequence length is used as the tile.
  std::optional<uint64_t> optionalTileSize{std::nullopt};
  if (tileSize.hasValue())
    optionalTileSize = tileSize.getValue();
  // NOTE: the previous version also constructed an unused IRRewriter here;
  // reifyAttentionTransform creates its own rewriter internally.
  if (failed(
          reifyAttentionTransform(getOperation(), onlyTile, optionalTileSize)))
    return signalPassFailure();
}
/// Factory returning the pass with its default-configured options.
std::unique_ptr<Pass> createTileAndDecomposeAttentionPass() {
  return std::make_unique<TileAndDecomposeAttentionPass>();
}
} // namespace LinalgExt
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir