blob: 21160d55664104c0d46fa20566c908030bbe1674 [file] [log] [blame]
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"
#define DEBUG_TYPE "linalg-ext-tiling"
namespace mlir::iree_compiler::IREE::LinalgExt {
//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//
/// Returns the (size, offset) pair for a dimension after scaling, clamped so
/// the resulting slice stays within `dimSize`: the offset becomes
/// `offset * offsetScale` and the size becomes
/// `min(dimSize - scaledOffset, size * sizeScale)`.
static std::pair<OpFoldResult, OpFoldResult>
getScaledSizeAndOffset(OpBuilder &builder, Location loc, OpFoldResult size,
                       OpFoldResult offset, OpFoldResult dimSize,
                       int64_t offsetScale, int64_t sizeScale) {
  MLIRContext *ctx = builder.getContext();
  AffineExpr d0, d1, d2;
  bindDims(ctx, d0, d1, d2);
  // Scaled offset: `offset * offsetScale`, folded if statically known.
  auto scaledOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, {d0 * offsetScale}, offset);
  auto dimSizeValue = getValueOrCreateConstantIndexOp(builder, loc, dimSize);
  // Clamp: min(dimSize - scaledOffset, size * sizeScale).
  AffineMap clampMap = AffineMap::get(3, 0, {d0 - d1, d2 * sizeScale}, ctx);
  auto scaledSize = affine::makeComposedFoldedAffineMin(
      builder, loc, clampMap, {dimSizeValue, scaledOffset, size});
  return {scaledSize, scaledOffset};
}
/// Retrieves the dimensions of `input`, preferring reification: if the
/// defining op implements ReifyRankedShapedTypeOpInterface, its reified
/// result shapes are used. Otherwise the static shape is returned, or
/// failure if the shape is not fully static. Dims are appended to
/// `reifiedInputDims`.
static LogicalResult
getStaticOrReifiedInputDims(OpBuilder &builder, Location loc, Value input,
                            ReifiedRankedShapedTypeDims &reifiedInputDims) {
  auto reifyOp = input.getDefiningOp<ReifyRankedShapedTypeOpInterface>();
  if (reifyOp) {
    return reifyOp.reifyResultShapes(builder, reifiedInputDims);
  }
  auto shapedTy = cast<ShapedType>(input.getType());
  if (!shapedTy.hasStaticShape()) {
    // No defining reifiable op and dynamic dims present: nothing to report.
    return failure();
  }
  reifiedInputDims.push_back(
      getAsIndexOpFoldResult(builder.getContext(), shapedTy.getShape()));
  return success();
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
SmallVector<utils::IteratorType> ScatterOp::getLoopIteratorTypes() {
  // Default every loop to parallel; the iteration domain is the update
  // tensor's dims.
  SmallVector<utils::IteratorType> types(getUpdateType().getRank(),
                                         utils::IteratorType::parallel);
  // Without the unique_indices guarantee, scatters into the same location may
  // collide, so the batch dims must be treated as reductions.
  if (!getUniqueIndices()) {
    for (int64_t dim = 0, e = getBatchRank(); dim < e; ++dim) {
      types[dim] = utils::IteratorType::reduction;
    }
  }
  return types;
}
SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
  // The iteration domain spans the update tensor: [0, dim) with stride 1 for
  // each of its dims.
  Location loc = getLoc();
  Value lb = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
  int64_t rank = getUpdateType().getRank();
  SmallVector<Range> domain;
  domain.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    OpFoldResult dimUb = getDim(builder, loc, getUpdates(), dim);
    domain.push_back(Range{lb, dimUb, step});
  }
  return domain;
}
/// Produces the tiled scatter for the iteration-domain tile given by
/// `offsets`/`sizes` (expressed over the update tensor's dims). Three slices
/// are materialized — updates, indices, and the destination (original)
/// tensor — and the op is cloned over them.
FailureOr<TilingResult>
ScatterOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  assert(offsets.size() >= 1 && sizes.size() >= 1);
  Location loc = getLoc();
  auto zeroAttr = builder.getI64IntegerAttr(0);
  auto oneAttr = builder.getI64IntegerAttr(1);
  // Slice-producing ops are returned so callers can fuse through them.
  SmallVector<Operation *> slices;
  // Slice of the updates. The iteration domain maps 1:1 onto the update
  // tensor (see getIterationDomain), so offsets/sizes apply directly.
  auto updateRank = getUpdateType().getRank();
  SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr);
  Operation *updateSlice =
      getSlice(builder, loc, getUpdates(), offsets, sizes, updateStrides);
  if (!updateSlice) {
    return emitOpError("failed to get updates slice");
  }
  Value tiledUpdate = updateSlice->getResult(0);
  slices.push_back(updateSlice);
  // Slice of indices. Only the leading batch dims are tiled; when indices
  // carry an extra trailing dim, the full index depth is taken at offset 0.
  auto indicesRank = getIndicesType().getRank();
  SmallVector<OpFoldResult> indicesOffsets(offsets.take_front(getBatchRank()));
  SmallVector<OpFoldResult> indicesSizes(sizes.take_front(getBatchRank()));
  if (getBatchRank() != getIndicesType().getRank()) {
    indicesOffsets.push_back(zeroAttr);
    indicesSizes.push_back(builder.getIndexAttr(getIndexDepth()));
  }
  SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
  Operation *indicesSlice = getSlice(builder, loc, getIndices(), indicesOffsets,
                                     indicesSizes, indicesStrides);
  if (!indicesSlice) {
    return emitOpError("failed to get indices slices");
  }
  Value tiledIndices = indicesSlice->getResult(0);
  slices.push_back(indicesSlice);
  // Slice of the original tensor, positioned via getResultTilePosition.
  SmallVector<OpFoldResult> originalOffsets, originalSizes;
  if (failed(getResultTilePosition(builder, 0, offsets, sizes, originalOffsets,
                                   originalSizes))) {
    return {};
  }
  auto originalRank = getOriginalType().getRank();
  SmallVector<OpFoldResult> originalStrides(originalRank, oneAttr);
  Operation *originalSlice =
      getSlice(builder, loc, getOriginal(), originalOffsets, originalSizes,
               originalStrides);
  if (!originalSlice) {
    return emitOpError("failed to get original tensor slice");
  }
  Value tiledOriginal = originalSlice->getResult(0);
  slices.push_back(originalSlice);
  // Tensor semantics: the clone yields the tiled original's type; with memref
  // semantics there are no results.
  SmallVector<Type> resultTypes;
  if (getNumResults()) {
    resultTypes.push_back(tiledOriginal.getType());
  }
  Operation *tiledScatterOp =
      mlir::clone(builder, getOperation(), resultTypes,
                  ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
  return TilingResult{{tiledScatterOp},
                      SmallVector<Value>(tiledScatterOp->getResults()),
                      slices};
}
/// Computes the tile (offsets/sizes) of the `original` tensor written by the
/// iteration-domain tile `offsets`/`sizes` (which are expressed over the
/// update tensor's dims).
LogicalResult ScatterOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  auto zeroAttr = builder.getI64IntegerAttr(0);
  // Slice of the original.
  auto originalRank = getOriginalType().getRank();
  resultOffsets.resize(originalRank, zeroAttr);
  resultSizes.resize(originalRank);
  auto updateRank = getUpdateType().getRank();
  Location loc = getLoc();
  // Leading dims of the original are not covered by the update slice: take
  // the full extent at offset 0 (the scattered index selects the position).
  for (auto dim : llvm::seq<int64_t>(0, originalRank - getUpdateSliceRank())) {
    resultSizes[dim] = getDim(builder, loc, getOriginal(), dim);
  }
  // Trailing dims map to the trailing (non-batch) dims of the update tensor.
  // The index shift `originalRank - updateRank` skips past the batch dims of
  // the iteration domain (presumably updateRank == batchRank +
  // updateSliceRank — confirm against the op's verifier).
  for (auto dim :
       llvm::seq<int64_t>(originalRank - getUpdateSliceRank(), originalRank)) {
    resultOffsets[dim] = offsets[dim - (originalRank - updateRank)];
    resultSizes[dim] = sizes[dim - (originalRank - updateRank)];
  }
  return success();
}
/// Maps a tile of the update operand back onto the iteration domain, enabling
/// producer fusion. Succeeds only for the update operand of an op with
/// `unique_indices`; the iteration domain equals the update tensor, so the
/// operand tile carries over unchanged.
LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
    OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes,
    SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
    SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
  // Fusion with producers is not possible in general if `unique_indices` is
  // not true, as reductions along the scattered indices are not tilable in
  // parallel.
  if (!getUniqueIndices()) {
    return failure();
  }
  // TODO: Support fusion along the index operand. For the index operand, the
  // offset + size must be the full size for the inner most dim.
  if (operandNumber != getInputs().getBeginOperandIndex()) {
    return failure();
  }
  // The iteration domain is defined in terms of the update tensor, so the
  // given offsets/sizes are already in the right coordinates.
  iterDomainOffsets.assign(offsets.begin(), offsets.end());
  iterDomainSizes.assign(sizes.begin(), sizes.end());
  return success();
}
/// Generates the tiled implementation corresponding to a tile of one operand:
/// the operand tile is first mapped to an iteration-domain tile, then the
/// regular tiled implementation is emitted for that tile.
FailureOr<TilingResult> ScatterOp::getTiledImplementationFromOperandTile(
    OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> iterOffsets;
  SmallVector<OpFoldResult> iterSizes;
  if (succeeded(getIterationDomainTileFromOperandTile(
          b, operandNumber, offsets, sizes, iterOffsets, iterSizes))) {
    return getTiledImplementation(b, iterOffsets, iterSizes);
  }
  return failure();
}
/// Emits the scalar (loop-body) form of scatter for one iteration-domain
/// point `ivs`: loads the update element, resolves the destination
/// coordinates from the indices tensor plus the update's intra-slice ivs,
/// applies the combiner region, and stores to the original buffer.
LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
                                                      Location loc,
                                                      ValueRange ivs) {
  auto indexDepth = getIndexDepth();
  Value update = b.create<memref::LoadOp>(loc, getUpdates(), ivs);
  // `starts` accumulates the destination coordinates into the original.
  SmallVector<Value> starts;
  // `loadIndices` addresses the indices tensor: batch ivs (+ index-depth
  // position, appended below when indices carry a trailing dim).
  SmallVector<Value> loadIndices;
  append_range(loadIndices, ivs.take_front(getBatchRank()));
  // Populate with empty values.
  auto originalTy = getOriginalType();
  starts.resize(originalTy.getRank(), Value());
  // The non-batch ivs address the inner update slice; they seed the trailing
  // destination coordinates.
  auto updateIvs = ivs.drop_front(getBatchRank());
  int64_t offset = starts.size() - updateIvs.size();
  for (auto [idx, iv] : llvm::enumerate(updateIvs)) {
    starts[idx + offset] = iv;
  }
  ArrayRef<int64_t> dimMap = getDimensionMap();
  if (getIndicesType().getRank() > getBatchRank()) {
    loadIndices.push_back(Value());
  }
  // For each index-depth component, load the scattered coordinate and add it
  // to any iv already seeded for that destination dim (per dimension_map).
  for (auto i : llvm::seq<unsigned>(0, indexDepth)) {
    if (getIndicesType().getRank() > getBatchRank()) {
      loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i);
    }
    Value idx = b.create<memref::LoadOp>(loc, getIndices(), loadIndices);
    Value ret = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
    auto dim = dimMap[i];
    if (starts[dim])
      ret = b.create<arith::AddIOp>(loc, ret, starts[dim]);
    starts[dim] = ret;
  }
  Value init = b.create<memref::LoadOp>(loc, getOriginal(), starts);
  // Inline the combiner region with (update, current) as its arguments.
  IRMapping bvm;
  Block &block = getRegion().front();
  bvm.map(block.getArgument(0), update);
  bvm.map(block.getArgument(1), init);
  for (auto &blockOp : block.without_terminator()) {
    b.clone(blockOp, bvm);
  }
  // The last op is linalg_ext.yield op. Store the operand to
  // destination.
  b.create<memref::StoreOp>(
      loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)),
      getOriginal(), starts);
  return success();
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
SmallVector<utils::IteratorType> SortOp::getLoopIteratorTypes() {
  // Every loop is parallel except the dimension being sorted, which must run
  // sequentially and is therefore marked as a reduction.
  SmallVector<utils::IteratorType> types(getOperandRank(),
                                         utils::IteratorType::parallel);
  types[getDimension()] = utils::IteratorType::reduction;
  return types;
}
SmallVector<Range> SortOp::getIterationDomain(OpBuilder &builder) {
  // Domain covers operand 0 entirely: [0, dim) with stride 1 per dim.
  Location loc = getLoc();
  Value lb = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value src = getOperand(0);
  int64_t rank = getOperandRank();
  SmallVector<Range> bounds(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    bounds[dim].offset = lb;
    bounds[dim].size = getDimValue(builder, loc, src, dim);
    bounds[dim].stride = step;
  }
  return bounds;
}
/// Tiles the sort by slicing every output with the given offsets/sizes and
/// cloning the op over the slices. The sorted dim must be tiled to its full
/// extent for the result to be correct (enforced by the iterator types).
FailureOr<TilingResult>
SortOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getOperandRank();
  assert(offsets.size() == static_cast<size_t>(rank) &&
         sizes.size() == static_cast<size_t>(rank));
  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
  SmallVector<Operation *> slices;
  SmallVector<Value> tiledOperands;
  tiledOperands.reserve(getOutputs().size());
  for (auto [idx, output] : llvm::enumerate(getOutputs())) {
    Operation *slice =
        getSlice(builder, getLoc(), output, offsets, sizes, strides);
    if (!slice) {
      return emitOpError("failed to get slice of operand ") << idx;
    }
    tiledOperands.push_back(slice->getResult(0));
    slices.push_back(slice);
  }
  // With tensor semantics, the clone returns one value per tiled output.
  SmallVector<Type, 4> resultTypes;
  if (getNumResults()) {
    for (Value tiled : tiledOperands) {
      resultTypes.push_back(tiled.getType());
    }
  }
  Operation *tiledSortOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledSortOp}, SmallVector<Value>{tiledSortOp->getResults()}, slices};
}
/// Result tiles coincide with iteration-domain tiles: pass offsets/sizes
/// through unchanged.
LogicalResult SortOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  return success();
}
/// Emits one bubble-sort sweep along `getDimension()` in scalar form: an
/// scf.for over adjacent pairs that loads (i, i+1) from every output, runs
/// the user comparator region, and conditionally swaps the pair.
LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  auto sortDim = getDimension();
  SmallVector<Value> indices, sortBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  // Bubble sort innermost loop.
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  // Upper bound is dimSize - 1 since each step touches elements (i, i + 1).
  Value ub;
  if (getOperandType(0).isDynamicDim(sortDim)) {
    ub = b.create<memref::DimOp>(loc, getOperand(0), sortDim);
  } else {
    ub = b.create<arith::ConstantIndexOp>(
        loc, getOperandType(0).getDimSize(sortDim));
  }
  ub = b.create<arith::SubIOp>(loc, ub, one);
  // Load the (i, i + 1) element pair of every output; these values become
  // the comparator region's block arguments.
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) {
        SmallVector<Value> indices(ivs);
        Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
        for (auto output : getDpsInits()) {
          indices[sortDim] = iv;
          sortBlkArgs.push_back(b.create<memref::LoadOp>(loc, output, indices));
          indices[sortDim] = ivPlusOne;
          sortBlkArgs.push_back(b.create<memref::LoadOp>(loc, output, indices));
        }
      });
  // Clone the comparator region into the loop body, mapping its arguments to
  // the loaded pair values.
  auto &srcBlock = getRegion().front();
  Region &region = scfFor.getRegion();
  IRMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = region.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip_equal(srcBlock.getArguments(), sortBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
  }
  // The comparator's yielded value decides whether the pair is in order.
  Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0));
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&region.front());
  b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        // Do not swap the pairs if true.
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        // Swap the pairs if false.
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value ivPlusOne =
            b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
        for (int i = 0, e = getNumDpsInits(); i < e; ++i) {
          // sortBlkArgs holds values pairwise: [2*i] = element i,
          // [2*i + 1] = element i + 1 of output i.
          Value v1 = sortBlkArgs[i * 2];
          Value v2 = sortBlkArgs[i * 2 + 1];
          indices[sortDim] = scfFor.getInductionVar();
          b.create<memref::StoreOp>(loc, v2, getDpsInits()[i], indices);
          indices[sortDim] = ivPlusOne;
          b.create<memref::StoreOp>(loc, v1, getDpsInits()[i], indices);
        }
        b.create<scf::YieldOp>(loc);
      });
  b.create<scf::YieldOp>(loc);
  return success();
}
//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
SmallVector<utils::IteratorType> FftOp::getLoopIteratorTypes() {
  // There are `rank-1` outer loops. The fft itself has one loop per stage,
  // which handles the merge step -- taking two half-size tensors and merging
  // them into one tensor. All of these are parallel.
  return SmallVector<utils::IteratorType>(getOperandRank(),
                                          utils::IteratorType::parallel);
}
SmallVector<Range> FftOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value lb = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value unit = builder.create<arith::ConstantIndexOp>(loc, 1);
  SmallVector<Range> domain;
  // Outer dims iterate with unit stride over the full extent.
  for (auto [dim, extent] : llvm::enumerate(getOperandShape().drop_back())) {
    Value extentVal =
        ShapedType::isDynamic(extent)
            ? getDimValue(builder, loc, getReal(), dim)
            : builder.create<arith::ConstantIndexOp>(loc, extent).getResult();
    domain.emplace_back(Range{/*offset=*/lb, extentVal, /*stride=*/unit});
  }
  // The innermost dim steps by `1 << stage`, i.e. one merge window per
  // iteration.
  Value innerSize = getDimValue(builder, loc, getReal(), getOperandRank() - 1);
  Value innerStride = builder.create<arith::ShLIOp>(loc, unit, getStage());
  domain.emplace_back(Range{/*offset=*/lb, innerSize, /*stride=*/innerStride});
  return domain;
}
/// Emits one FFT merge stage as a linalg.generic when no precomputed
/// twiddle-factor buffer is available: the twiddle w = exp(-2*PI*j/m * I) is
/// computed inline from the innermost loop index. `operands` are the four
/// half-window subviews (lhsReal, lhsImag, rhsReal, rhsImag) and `wholeSize`
/// is m = 1 << stage.
void FftOp::generateScalarImplWithoutCoeffBuf(OpBuilder &b, Location loc,
                                              ArrayRef<Value> operands,
                                              Value wholeSize) {
  auto rank = getOperandRank();
  SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank));
  auto f32Type = b.getF32Type();
  // Helper: index -> i32 -> f32, used to build the phase angle.
  auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value {
    v = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(), v);
    return builder.create<arith::SIToFPOp>(loc, builder.getF32Type(), v);
  };
  // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part
  // first.  (acos(-1) == PI.)
  Value coeff = b.create<arith::ConstantFloatOp>(
      loc, llvm::APFloat(static_cast<float>(-2 * acos(-1))), f32Type);
  coeff = b.create<arith::DivFOp>(loc, coeff, indexToF32(b, loc, wholeSize));
  b.create<linalg::GenericOp>(
      loc, TypeRange{}, ValueRange{}, operands, maps, getLoopIteratorTypes(),
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value lhsReal = args[0];
        Value lhsImag = args[1];
        Value rhsReal = args[2];
        Value rhsImag = args[3];
        // Compute "-2 * PI / m * j"
        Value w = b.create<arith::MulFOp>(
            loc, coeff,
            indexToF32(b, loc, b.create<linalg::IndexOp>(loc, rank - 1)));
        // w = cos(angle) + sin(angle) * I.
        Value wReal = b.create<math::CosOp>(loc, w);
        Value wImag = b.create<math::SinOp>(loc, w);
        // t = w * a[k + j + mh];
        // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
        Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal);
        Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag);
        Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag);
        Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal);
        Value tReal = b.create<arith::SubFOp>(loc, xu, yv);
        Value tImag = b.create<arith::AddFOp>(loc, xv, yu);
        // cplx u = a[k + j];
        // a[k + j] = u + t;
        // a[k + j + mh] = u - t;
        Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal);
        Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag);
        Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal);
        Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag);
        b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
      });
}
/// Emits one FFT merge stage as a linalg.generic using the precomputed
/// twiddle-factor buffers (real/imag coefficients) as extra inputs.
/// `operands` are the four half-window subviews (lhsReal, lhsImag, rhsReal,
/// rhsImag).
void FftOp::generateScalarImplWithCoeffBuf(OpBuilder &b, Location loc,
                                           ArrayRef<Value> operands) {
  auto rank = getOperandRank();
  SmallVector<AffineMap> maps;
  // The size of the coefficient buffer is expected to match `2^(stage-1)`,
  // which equals the last dim of operands, so the coeff inputs are indexed by
  // the innermost dim only.
  maps.append(
      2, AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext()));
  maps.append(operands.size(), b.getMultiDimIdentityMap(rank));
  b.create<linalg::GenericOp>(
      loc, TypeRange{}, ValueRange{getRealCoeff(), getImagCoeff()}, operands,
      maps, getLoopIteratorTypes(),
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value wReal = args[0];
        Value wImag = args[1];
        Value lhsReal = args[2];
        Value lhsImag = args[3];
        Value rhsReal = args[4];
        Value rhsImag = args[5];
        // t = w * a[k + j + mh];
        // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
        Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal);
        Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag);
        Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag);
        Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal);
        Value tReal = b.create<arith::SubFOp>(loc, xu, yv);
        Value tImag = b.create<arith::AddFOp>(loc, xv, yu);
        // cplx u = a[k + j];
        // a[k + j] = u + t;
        // a[k + j + mh] = u - t;
        Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal);
        Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag);
        Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal);
        Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag);
        b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
      });
}
// Generates the FFT stage scalar implementation, following the Cooley–Tukey
// FFT algorithm. The pseudo reference code is:
//   let s <- stage of linalg_ext.fft
//   int m = 1 << s;
//   int mh = m >> 1;
//   for (int k = 0; k < n; k += m) {
//     for (int j = 0; j < mh; ++j) {
//       cplx w = exp(-2 * PI * j / m * I);
//       cplx t = w * a[k + j + mh];
//       cplx u = a[k + j];
//       a[k + j] = u + t;
//       a[k + j + mh] = u - t;
//     }
//   }
LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                  ValueRange ivs) {
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  // m = 1 << stage; mh = m >> 1.
  Value wholeSize = b.create<arith::ShLIOp>(loc, one, getStage());
  Value halfSize = b.create<arith::ShRSIOp>(loc, wholeSize, one);
  auto rank = getOperandRank();
  // Each window is a subview of unit extent in every dim except the last,
  // which spans mh elements.
  SmallVector<OpFoldResult> unitAttrs(rank, b.getIndexAttr(1));
  SmallVector<OpFoldResult> windowSizes(rank, b.getIndexAttr(1));
  windowSizes.back() = halfSize;
  // lhs windows start at the current ivs; rhs windows are shifted by mh along
  // the innermost dim.
  SmallVector<OpFoldResult> lhsStarts(ivs.begin(), ivs.end());
  SmallVector<Value> windows;
  windows.push_back(b.create<memref::SubViewOp>(loc, getReal(), lhsStarts,
                                                windowSizes, unitAttrs));
  windows.push_back(b.create<memref::SubViewOp>(loc, getImag(), lhsStarts,
                                                windowSizes, unitAttrs));
  SmallVector<OpFoldResult> rhsStarts(ivs.begin(), ivs.end());
  rhsStarts.back() =
      b.create<arith::AddIOp>(loc, ivs.back(), halfSize).getResult();
  windows.push_back(b.create<memref::SubViewOp>(loc, getReal(), rhsStarts,
                                                windowSizes, unitAttrs));
  windows.push_back(b.create<memref::SubViewOp>(loc, getImag(), rhsStarts,
                                                windowSizes, unitAttrs));
  // Dispatch on whether precomputed twiddle-factor buffers are available.
  if (hasCoeff()) {
    generateScalarImplWithCoeffBuf(b, loc, windows);
  } else {
    generateScalarImplWithoutCoeffBuf(b, loc, windows, wholeSize);
  }
  return success();
}
/// Tiles the fft by slicing each output; the stage and coefficient operands
/// are scalars/whole buffers and are forwarded untiled.
FailureOr<TilingResult>
FftOp::getTiledImplementation(OpBuilder &builder,
                              ArrayRef<OpFoldResult> offsets,
                              ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> strides(getOperandRank(),
                                    builder.getI64IntegerAttr(1));
  SmallVector<Operation *> slices;
  // Non-shaped / whole-buffer operands come first, untouched.
  SmallVector<Value> tiledOperands = {getStage(), getRealCoeff(),
                                      getImagCoeff()};
  SmallVector<Type, 4> resultTypes;
  for (auto [index, out] : llvm::enumerate(getOutputs())) {
    Operation *slice = getSlice(builder, getLoc(), out, offsets, sizes, strides);
    if (!slice) {
      return emitOpError("failed to get slice of output ") << index;
    }
    tiledOperands.push_back(slice->getResult(0));
    slices.push_back(slice);
    // With tensor semantics the clone returns one value per tiled output.
    if (hasPureTensorSemantics()) {
      resultTypes.push_back(tiledOperands.back().getType());
    }
  }
  Operation *tiledFftOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledFftOp}, SmallVector<Value>(tiledFftOp->getResults()), slices};
}
/// Result tiles coincide with iteration-domain tiles: pass offsets/sizes
/// through unchanged.
LogicalResult FftOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets = llvm::to_vector(offsets);
  resultSizes = llvm::to_vector(sizes);
  return success();
}
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
  // Domain covers the input tensor entirely: [0, dim) stride 1 per dim.
  Location loc = getLoc();
  Value lb = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value input = getInput();
  int64_t rank = getOperandRank();
  SmallVector<Range> bounds(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    bounds[dim].offset = lb;
    bounds[dim].size = getDimValue(builder, loc, input, dim);
    bounds[dim].stride = step;
  }
  return bounds;
}
SmallVector<utils::IteratorType> ScanOp::getLoopIteratorTypes() {
  // Every loop is parallel except the scanned dimension, which carries a
  // loop-carried dependence and is marked as a reduction.
  SmallVector<utils::IteratorType> types(getOperandRank(),
                                         utils::IteratorType::parallel);
  types[getDimension()] = utils::IteratorType::reduction;
  return types;
}
// Generates naive scalar implementation of scan for a given operator f.
// For inclusive,
//   output[0] = input[0]
//   output[i] = f(output[i-1], input[i])
//
// For exclusive,
//   output[0] = 0
//   output[i] = f(output[i-1], input[i-1])
LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  SmallVector<Value> indices, scanBlkArgs;
  indices.append(ivs.begin(), ivs.end());
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  auto scanDim = getDimension();
  // First point along the scan dim is the base case; everything else combines
  // with the previously stored result.
  auto cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                      indices[scanDim], zero);
  bool isInclusive = getInclusive();
  // The accumulator indexes every dim except the scanned one.
  // NOTE: use an unsigned induction variable — the original `int i` mixed
  // signed/unsigned in both the bound and the `i != scanDim` comparison.
  SmallVector<Value> accIndices;
  for (uint64_t i = 0, e = indices.size(); i < e; i++) {
    if (i != scanDim) {
      accIndices.push_back(indices[i]);
    }
  }
  auto scfIf = b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        // Base case: seed output[0] with input[0] (inclusive) or with the
        // accumulator's incoming value (exclusive).
        if (isInclusive) {
          auto value = b.create<memref::LoadOp>(loc, getInput(), indices);
          b.create<memref::StoreOp>(loc, value, getOutput(), indices);
        } else {
          auto value =
              b.create<memref::LoadOp>(loc, getAccumulator(), accIndices);
          b.create<memref::StoreOp>(loc, value, getOutput(), indices);
        }
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        // Recursive case: load output[i-1] and input[i] (inclusive) or
        // input[i-1] (exclusive); these feed the combiner region below.
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value iv = indices[scanDim];
        Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
        indices[scanDim] = ivMinusOne;
        scanBlkArgs.push_back(
            b.create<memref::LoadOp>(loc, getOutput(), indices));
        Value i0;
        if (!isInclusive)
          i0 = b.create<memref::LoadOp>(loc, getInput(), indices);
        indices[scanDim] = iv;
        if (isInclusive)
          i0 = b.create<memref::LoadOp>(loc, getInput(), indices);
        scanBlkArgs.push_back(i0);
      });
  // Inline the combiner region into the else-branch and store the combined
  // value to both the output and the accumulator.
  auto &srcBlock = getRegion().front();
  Region &region = scfIf.getElseRegion();
  IRMapping bvm;
  {
    OpBuilder::InsertionGuard guard(b);
    auto &block = region.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip_equal(srcBlock.getArguments(), scanBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        getOutput(), indices);
    b.create<memref::StoreOp>(
        loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
        getAccumulator(), accIndices);
    b.create<scf::YieldOp>(loc);
  }
  return success();
}
/// Tiles the scan: the input and the scanned output take the iteration-domain
/// tile directly; the accumulator output drops the scanned dim (rank > 1) or
/// passes through whole (rank == 1).
FailureOr<TilingResult>
ScanOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getOperandRank();
  assert(offsets.size() == static_cast<size_t>(rank) &&
         sizes.size() == static_cast<size_t>(rank));
  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  // Tile the input.
  Operation *inputSlice =
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  if (!inputSlice) {
    return emitOpError("failed to get input slice");
  }
  tiledOperands.emplace_back(inputSlice->getResult(0));
  slices.push_back(inputSlice);
  // Tile the scanned output (output 0) identically to the input.
  Operation *output0Slice =
      getSlice(builder, getLoc(), getOutputs()[0], offsets, sizes, strides);
  if (!output0Slice) {
    return emitOpError("failed to get slice of output 0");
  }
  tiledOperands.emplace_back(output0Slice->getResult(0));
  slices.push_back(output0Slice);
  if (rank > 1) {
    // The accumulator's tile comes from getResultTilePosition(result 1),
    // which drops the scanned dim.
    SmallVector<OpFoldResult> accumOffsets, accumSizes;
    if (failed(getResultTilePosition(builder, 1, offsets, sizes, accumOffsets,
                                     accumSizes))) {
      return {};
    }
    SmallVector<OpFoldResult> accumStrides(rank - 1,
                                           builder.getI64IntegerAttr(1));
    Operation *output1Slice = getSlice(builder, getLoc(), getOutputs()[1],
                                       accumOffsets, accumSizes, accumStrides);
    if (!output1Slice) {
      return emitOpError("failed to get output1 slice");
    }
    tiledOperands.emplace_back(output1Slice->getResult(0));
    slices.push_back(output1Slice);
  } else {
    // Rank-1 scan: the accumulator has no tilable dims; forward it whole.
    tiledOperands.emplace_back(getOutputs()[1]);
  }
  // With tensor semantics the clone yields the two tiled outputs' types.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
    resultTypes.push_back(tiledOperands[2].getType());
  }
  Operation *tiledScanOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledScanOp}, SmallVector<Value>(tiledScanOp->getResults()), slices};
}
/// Result 0 (the scan output) tiles exactly like the iteration domain;
/// result 1 (the accumulator) keeps every dim except the scanned one.
LogicalResult ScanOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  switch (resultNumber) {
  case 0:
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  case 1: {
    int64_t rank = getOperandRank();
    if (rank > 1) {
      // Project out the scanned dim.
      for (int64_t i = 0; i < rank; ++i) {
        if (i == static_cast<int64_t>(getDimension())) {
          continue;
        }
        resultOffsets.push_back(offsets[i]);
        resultSizes.push_back(sizes[i]);
      }
    }
    return success();
  }
  default:
    return failure();
  }
}
//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
  // Domain covers the values input entirely: [0, dim) stride 1 per dim.
  Location loc = getLoc();
  Value lb = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getValues();
  int64_t rank = getInputRank();
  SmallVector<Range> bounds(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    bounds[dim].offset = lb;
    bounds[dim].size = getDimValue(builder, loc, source, dim);
    bounds[dim].stride = step;
  }
  return bounds;
}
SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
  // Every loop is parallel except the k-selection dimension, which is a
  // reduction.
  SmallVector<utils::IteratorType> types(getInputRank(),
                                         utils::IteratorType::parallel);
  types[getDimension()] = utils::IteratorType::reduction;
  return types;
}
/// Emits the scalar top-k update for one candidate element at `ivs`: an
/// insertion-style sweep over the K output slots that, per slot, keeps either
/// the incoming candidate or the stored value and carries the loser forward
/// to the next slot.
LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  uint64_t kDim = getDimension();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value initialValue = b.create<memref::LoadOp>(loc, getValues(), ivs);
  // If the indices tensor is not provided, the value index is derived from
  // the loop induction variables.
  Value initialIndex;
  if (getIndices()) {
    initialIndex = b.create<memref::LoadOp>(loc, *getIndices(), ivs);
  } else {
    Value rawInitialIndex = ivs[kDim];
    initialIndex =
        b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
  }
  // Compute K (ub) from the selected dim of the output.
  Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());
  // Inner K loop functions:
  //   Load current K value and index
  //   Compare N/K using inserted block compare
  //   Check if N == K using strict weak ordering, select which index came
  //   first
  //   Select new K value from N/K comparison
  //   Select new K index from N/K comparison or which index came first
  //   Store new k value and index
  //   Yield loop carry values after K selection
  Value kValue, kIndex;
  // Loop body is filled in below (after the loads) via the insertion guard;
  // the builder callback only emits the slot loads.
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{initialValue, initialIndex},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
        SmallVector<Value> indices(ivs);
        indices[kDim] = iv;
        kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
        kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
      });
  SmallVector<Value> indices(ivs);
  indices[kDim] = scfFor.getInductionVar();
  auto loopCarryValues = scfFor.getRegionIterArgs();
  // Retrieve region as black box comparison function f(x,y). Plug into op.
  // Two clones are made: forward f(candidate, slot) and reverse
  // f(slot, candidate), to detect equivalence under the ordering.
  auto &srcBlock = getRegion().front();
  IRMapping bvmF; // f(x,y)
  IRMapping bvmR; // f(y,x)
  {
    // Save previous insertion point. Continue within loop body.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToEnd(&scfFor.getRegion().front());
    SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
    SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
    for (auto it : llvm::zip_equal(srcBlock.getArguments(), forwardValues)) {
      bvmF.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto it : llvm::zip_equal(srcBlock.getArguments(), reverseValues)) {
      bvmR.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvmF);
      b.clone(blockOp, bvmR);
    }
    Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
    Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));
    // Check value equality using strictly weak ordering from the region:
    //   f(x,y) --> forwardCmpRes
    //   f(y,x) --> reverseCmpRes
    //   if forwardCmpRes == reverseCmpRes then select which came first
    Value cmpValuesEqual = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
    Value cmpFirstIndex = b.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
    Value combinedCmpEqRes =
        b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
    // True if N > K or N came before K
    Value indexCmpRes =
        b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
    // Select results for K based on comparisons
    Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
                                                   loopCarryValues[0], kValue);
    Value resultKIndex =
        b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
    b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
    b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
    // Select loop carry, opposite of K results: the displaced value/index
    // competes for the next slot.
    Value resultCarryValue = b.create<arith::SelectOp>(
        loc, forwardCmpRes, kValue, loopCarryValues[0]);
    Value resultCarryIndex =
        b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
    b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
  }
  return success();
}
FailureOr<TilingResult>
TopkOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
  // Tile topk by slicing the input operand(s) with the iteration-space tile
  // and the two outputs with the same offsets but with the k-dimension size
  // taken from the output (k <= input k-dim extent).
  int64_t rank = getInputRank();
  assert(offsets.size() == static_cast<size_t>(rank) &&
         sizes.size() == static_cast<size_t>(rank));
  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
  Location loc = getLoc();
  SmallVector<OpFoldResult> outputOffsets, outputSizes;
  if (failed(getResultTilePosition(builder, 0, offsets, sizes, outputOffsets,
                                   outputSizes))) {
    // Default-constructed FailureOr signals failure.
    return {};
  }
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  // Values
  {
    Operation *valuesSlice =
        getSlice(builder, loc, getValues(), offsets, sizes, strides);
    if (!valuesSlice) {
      return emitOpError("failed to get values slice");
    }
    tiledOperands.emplace_back(valuesSlice->getResult(0));
    slices.push_back(valuesSlice);
  }
  // The indices operand is optional; slice it only when present.
  if (getIndices()) {
    Operation *indicesSlice =
        getSlice(builder, loc, *getIndices(), offsets, sizes, strides);
    if (!indicesSlice) {
      return emitOpError("failed to get slices of indices");
    }
    tiledOperands.emplace_back(indicesSlice->getResult(0));
    slices.push_back(indicesSlice);
  }
  // Replace the tile size for the K dimension to use the output size instead of
  // the input size.
  Value kSize = getDimValue(builder, getLoc(), outputValues(), getDimension());
  outputSizes[getDimension()] = getAsOpFoldResult(kSize);
  // Output 0
  {
    Operation *output0Slice =
        getSlice(builder, loc, getOutputs()[0], offsets, outputSizes, strides);
    if (!output0Slice) {
      return emitOpError("failed to get output 0 slice");
    }
    tiledOperands.emplace_back(output0Slice->getResult(0));
    slices.push_back(output0Slice);
  }
  // Output 1
  {
    Operation *output1Slice =
        getSlice(builder, loc, getOutputs()[1], offsets, outputSizes, strides);
    if (!output1Slice) {
      return emitOpError("failed to get output1 slice");
    }
    tiledOperands.emplace_back(output1Slice->getResult(0));
    slices.push_back(output1Slice);
  }
  // With tensor semantics the tiled op returns the two sliced output types
  // (the last two entries of `tiledOperands`); with buffers it has no results.
  SmallVector<Type, 2> resultTypes;
  if (hasPureTensorSemantics()) {
    resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType());
    resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType());
  }
  Operation *tiledTopkOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledTopkOp}, SmallVector<Value>(tiledTopkOp->getResults()), slices};
}
LogicalResult TopkOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  // The result tile matches the iteration-space tile everywhere except along
  // the k dimension, whose extent comes from the corresponding init operand.
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  auto kDim = getDimension();
  Value kExtent =
      getDimValue(builder, getLoc(), getDpsInits()[resultNumber], kDim);
  resultSizes[kDim] = getAsOpFoldResult(kExtent);
  return success();
}
//===----------------------------------------------------------------------===//
// PackOp and UnPackOp utils
//===----------------------------------------------------------------------===//
/// Builds the iteration domain for `packOp` or `unPackOp`: one loop per
/// dimension, from zero to the reified result size with unit stride.
template <typename OpTy>
static SmallVector<Range> getIterationDomain(OpTy op, OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard guard(builder);
  Location loc = op.getLoc();
  // Pack iterates over its input rank; unpack over its output rank.
  int64_t numLoops = std::is_same<OpTy, PackOp>::value ? op.getInputRank()
                                                       : op.getOutputRank();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  ReifiedRankedShapedTypeDims reifiedShape;
  (void)op.reifyResultShapes(builder, reifiedShape);
  SmallVector<Range> bounds(numLoops);
  for (int64_t idx = 0; idx < numLoops; ++idx) {
    bounds[idx].offset = zero;
    bounds[idx].stride = one;
    bounds[idx].size = reifiedShape[0][idx];
  }
  return bounds;
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
SmallVector<Range> PackOp::getIterationDomain(OpBuilder &builder) {
  // Shared pack/unpack helper; for pack this yields one loop per input dim.
  return LinalgExt::getIterationDomain(*this, builder);
}
/// Generate the body of the innermost loop of the scalar implementation
/// of `pack` operation. `ivs` holds the outer (tiled) loop ivs followed by
/// the point-loop ivs over the data tile.
static void generatePackOpScalarImplementationBody(PackOp packOp,
                                                   OpBuilder &builder,
                                                   Location loc,
                                                   ValueRange ivs) {
  // Note: `ivs` are already in the correct order, possibly interchanged based
  // on `dims_pos`. However, connecting the loops with the access patterns is
  // difficult - What is the relation between the position of the tile loop and
  // the point loop? However, if we interchange `ivs` once more to go to the
  // canonical blocking format: ABCabc, this connection becomes trivial: Each
  // point loop is pointLoopsOffset + inputRank away from the tiled loop.
  ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
  ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();

  SmallVector<Value> interchangedIvs = ivs;
  SmallVector<int64_t> interchangeVector =
      computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getInputRank());
  interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
                                       /*offset=*/packOp.getInputRank());
  if (!dimsToOuterBlock.empty()) {
    interchangeVector =
        computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getInputRank());
    interchangedIvs =
        interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
  }

  // Map each input dim to its source index: tiled dims recombine the outer
  // (tile) iv and inner (point) iv as `outerIv * tileSize + pointIv`; untiled
  // dims use the interchanged iv directly.
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  SmallVector<OpFoldResult> sourceIndices;
  size_t pointLoopsOffset = 0;
  int64_t inputRank = packOp.getInputRank();
  for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
    if (dimAndTileMapping.count(dim)) {
      AffineExpr i, j, tile;
      bindDims(builder.getContext(), i, j);
      bindSymbols(builder.getContext(), tile);
      OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
          builder, loc, i * tile + j,
          ArrayRef<OpFoldResult>{
              interchangedIvs[dim],
              interchangedIvs[pointLoopsOffset + packOp.getInputRank()],
              dimAndTileMapping[dim]});
      sourceIndices.push_back(sourceIndex);
      ++pointLoopsOffset;
    } else {
      sourceIndices.push_back(interchangedIvs[dim]);
    }
  }

  auto createLoad = [&]() -> Value {
    return builder.create<memref::LoadOp>(
        loc, packOp.getInput(),
        getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
  };
  Value scalar;
  if (auto paddingValue = packOp.getPaddingValue()) {
    // With a padding value, guard the load: out-of-bounds source positions
    // yield the padding value instead.
    ArithBuilder arithBuilder(builder, loc);
    Value isInBounds;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      Value idx =
          getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
      Value cond = arithBuilder.slt(
          idx, getDimValue(builder, loc, packOp.getInput(), dim));
      isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
    }
    scalar = builder
                 .create<scf::IfOp>(
                     loc, isInBounds, /*thenBuilder=*/
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, createLoad());
                     },
                     /*elseBuilder=*/
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, paddingValue);
                     })
                 .getResult(0);
  } else {
    scalar = createLoad();
  }

  builder.create<memref::StoreOp>(loc, scalar, packOp.getOutput(), ivs);
}
LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder,
                                                   Location loc,
                                                   ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  // The `ivs` already represent the position into the output tensor for the
  // non data-tile dimensions.
  SmallVector<Value> ivVec = llvm::to_vector(ivs);
  ReifiedRankedShapedTypeDims outputShape;
  if (failed(reifyResultShapes(builder, outputShape))) {
    return getOperation()->emitOpError("failed to reify result shape");
  }
  if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
    return getOperation()->emitOpError(
               "expected shape of one result value of rank")
           << getOutputRank();
  }

  // Generate the loops that iterate over the data tile.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);

  // All loops except the innermost are simple loops that just iterate
  // over the tile dimensions. Each created loop nests the next one by moving
  // the insertion point into its body, and contributes its induction variable
  // to the accumulated output position `ivVec`.
  for (auto dataTileDim :
       llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) {
    Value ub = getValueOrCreateConstantIndexOp(builder, loc,
                                               outputShape[0][dataTileDim]);
    scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one);
    builder.setInsertionPointToStart(loop.getBody());
    ivVec.push_back(loop.getInductionVar());
  }
  // The body of the innermost loops does the actual data movement.
  builder.create<scf::ForOp>(
      loc, zero,
      getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()), one,
      ValueRange{},
      [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
          ValueRange regionIterArgs) {
        ivVec.push_back(iv);
        generatePackOpScalarImplementationBody(*this, bodyBuilder, bodyLoc,
                                               ivVec);
        bodyBuilder.create<scf::YieldOp>(bodyLoc);
      });
  return success();
}
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder,
                                                     Location loc,
                                                     ValueRange ivs) {
  assert(ivs.size() == getOutputRank() &&
         "number of ivs must match the rank of the output tensor");
  OpBuilder::InsertionGuard g(builder);
  ReifiedRankedShapedTypeDims outputShape;
  if (failed(reifyResultShapes(builder, outputShape))) {
    return getOperation()->emitOpError("failed to reify result shape");
  }
  if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
    return getOperation()->emitOpError(
               "expected shape of one result value of rank")
           << getOutputRank();
  }
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping();
  // untiled loops and tile loops induction variables.
  SmallVector<Value> inputIvs;
  // point loops induction variables.
  SmallVector<Value> inputIvsPointLoops;
  inputIvs.reserve(getOutputRank());
  inputIvsPointLoops.reserve(dimAndTileMapping.size());
  // For every tiled output dimension, split its iv into an outer (tile) index
  // and an inner (point) index via floordiv/mod by the tile size.
  for (auto dim : llvm::seq<int64_t>(0, getOutputRank())) {
    if (dimAndTileMapping.count(dim)) {
      affine::DivModValue divMod =
          affine::getDivMod(builder, loc, ivs[dim],
                            getValueOrCreateConstantIndexOp(
                                builder, loc, dimAndTileMapping[dim]));
      inputIvsPointLoops.push_back(divMod.remainder);
      inputIvs.push_back(divMod.quotient);
    } else {
      inputIvs.push_back(ivs[dim]);
    }
  }

  // TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
  // `inputIvsPointLoops` and `inputIvs`.
  assert(inputIvsPointLoops.size() + inputIvs.size() == getInputRank() &&
         "expected number of induction variables to equal the input rank");
  // interchange the point loops induction variables based on `inner_dim_pos`.
  ArrayRef<int64_t> innerDims = getInnerDimsPos();
  SmallVector<int64_t> interchangeVector =
      computeInterchangeFromDimPos(innerDims, getOutputRank());
  SmallVector<Value> interchangedInputIvsPointLoops = interchange<Value>(
      inputIvsPointLoops, interchangeVector, /*offset=*/0);
  // interchange the tiled loops induction variables based on `outer_dims_perm`.
  ArrayRef<int64_t> outerDims = getOuterDimsPerm();
  if (!outerDims.empty()) {
    inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
  }

  // Load the scalar from the packed input at (tile indices ++ point indices)
  // and store it to the untiled output position.
  llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
  Value scalar = builder.create<memref::LoadOp>(loc, getInput(), inputIvs);
  builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs);
  return success();
}
SmallVector<Range> UnPackOp::getIterationDomain(OpBuilder &builder) {
  // Shared pack/unpack helper; for unpack this yields one loop per output dim.
  return LinalgExt::getIterationDomain(*this, builder);
}
//===----------------------------------------------------------------------===//
// Im2colOp
//===----------------------------------------------------------------------===//
SmallVector<Range> Im2colOp::getIterationDomain(OpBuilder &builder) {
  // The iteration domain spans the full output shape: every dimension starts
  // at zero with unit stride and runs to the output's extent.
  Location loc = getLoc();
  Value output = getOutput();
  OpFoldResult zeroAttr = builder.getIndexAttr(0);
  OpFoldResult oneAttr = builder.getIndexAttr(1);
  SmallVector<Range> ranges;
  ranges.reserve(getOutputRank());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputRank())) {
    ranges.push_back(
        Range{zeroAttr, getDimValue(builder, loc, output, dim), oneAttr});
  }
  return ranges;
}
SmallVector<utils::IteratorType> Im2colOp::getLoopIteratorTypes() {
  // Every im2col loop is data-parallel.
  return SmallVector<utils::IteratorType>(getOutputRank(),
                                          utils::IteratorType::parallel);
}
FailureOr<TilingResult>
Im2colOp::getTiledImplementation(OpBuilder &builder,
                                 ArrayRef<OpFoldResult> offsets,
                                 ArrayRef<OpFoldResult> sizes) {
  // Tiles im2col over its (fully parallel) output iteration space: slice the
  // batch dims of the input, slice the output directly, and fold the tile
  // offsets of the M/K output dims into the op's m_offset/k_offset.
  Location loc = getLoc();
  OpFoldResult one = builder.getIndexAttr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  SmallVector<OpFoldResult> inputOffsets(getInputRank(), zero);
  SmallVector<OpFoldResult> inputSizes = getDims(builder, loc, getInput());
  // Set batch offsets and sizes for input
  for (auto [outDim, inDim] :
       llvm::zip_equal(getBatchOutputDims(), getBatchPos())) {
    inputOffsets[inDim] = offsets[outDim];
    inputSizes[inDim] = sizes[outDim];
  }
  SmallVector<OpFoldResult> inputStrides(getInputRank(), one);
  // Input
  Operation *inputSlice = getSlice(builder, loc, getInput(), inputOffsets,
                                   inputSizes, inputStrides);
  if (!inputSlice) {
    return emitOpError("failed to get slice of input");
  }
  SmallVector<OpFoldResult> outputStrides(getOutputRank(), one);
  Operation *outputSlice =
      getSlice(builder, loc, getOutput(), offsets, sizes, outputStrides);
  if (!outputSlice) {
    return emitOpError("failed to get outputSlice");
  }
  // With tensor semantics the tiled op returns the sliced output type; with
  // buffer semantics it has no results.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics()) {
    resultTypes.append(outputSlice->result_type_begin(),
                       outputSlice->result_type_end());
  }
  // Adjust m_offset and k_offset by adding the offsets from tiling.
  SmallVector<OpFoldResult> newKOffsets, newMOffsets;
  for (auto [outDim, kOffset] :
       llvm::zip_equal(getKOutputDims(), getMixedKOffset())) {
    OpFoldResult kTileOffset = offsets[outDim];
    newKOffsets.push_back(addOfrs(builder, loc, kTileOffset, kOffset));
  }
  for (auto [outDim, mOffset] :
       llvm::zip_equal(getMOutputDims(), getMixedMOffset())) {
    OpFoldResult mTileOffset = offsets[outDim];
    newMOffsets.push_back(addOfrs(builder, loc, mTileOffset, mOffset));
  }
  // Create the tiled op.
  SmallVector<Value> operands = {inputSlice->getResult(0),
                                 outputSlice->getResult(0)};
  // Copy all metadata operands from the untiled operation.
  operands.append(getOperation()->getOperands().begin() + 2,
                  getOperation()->getOperands().end());
  // Use `resultTypes` (empty for buffer semantics) rather than the output
  // slice's result types so the cloned op stays resultless on memrefs, matching
  // the other tiled ops in this file.
  Im2colOp tiledOp = mlir::clone(builder, *this, resultTypes, operands);
  // Set the new k_offset and m_offset, since they have changed with tiling.
  tiledOp.setMixedKOffset(newKOffsets);
  tiledOp.setMixedMOffset(newMOffsets);
  return TilingResult{{tiledOp},
                      SmallVector<Value>(tiledOp->getResults()),
                      {inputSlice, outputSlice}};
}
FailureOr<TilingResult>
Im2colOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  // The result tile maps 1:1 to the iteration-space tile (see
  // getResultTilePosition), so tiling the op directly yields the result tile.
  return getTiledImplementation(builder, offsets, sizes);
}
LogicalResult Im2colOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  // The output tile coincides exactly with the iteration-space tile.
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  return success();
}
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  // Loops cover the output dimensions past the leading image (tile)
  // dimensions; the transform handles the image dims internally.
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value dest = getOutput();
  SmallVector<Range> loopBounds;
  loopBounds.reserve(getIterationDomainRank());
  for (int64_t dim = getImageDimensions().size(); dim < getOutputRank();
       ++dim) {
    loopBounds.push_back(
        Range{zero, getDimValue(builder, loc, dest, dim), one});
  }
  return loopBounds;
}
SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  // All tiled loops of the input transform are parallel.
  return SmallVector<utils::IteratorType>(getIterationDomainRank(),
                                          utils::IteratorType::parallel);
}
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  auto one = builder.getIndexAttr(1);
  auto zero = builder.getIndexAttr(0);
  const int cDim = getChannelDim();
  // The iteration domain is 4-D; its offsets/sizes map onto output dims 2..5.
  assert(offsets.size() == 4);
  SmallVector<OpFoldResult> inputOffsets(getInputRank(), zero);
  SmallVector<OpFoldResult> outputOffsets(getOutputRank(), zero);
  const auto hDim = getImageDimensions()[0];
  const auto wDim = getImageDimensions()[1];
  // Dim 0 of the iteration space also offsets input dim 0, and dim 3 offsets
  // the input channel dim; the image dims are handled separately below.
  outputOffsets[2] = inputOffsets[0] = offsets[0];
  outputOffsets[3] = offsets[1];
  outputOffsets[4] = offsets[2];
  outputOffsets[5] = inputOffsets[cDim] = offsets[3];
  SmallVector<OpFoldResult> inputStrides(getInputRank(), one);
  SmallVector<OpFoldResult> outputStrides(getOutputRank(), one);
  ReifiedRankedShapedTypeDims reifiedResultShapes, reifiedInputShapes;
  if (failed(reifyResultShapes(builder, reifiedResultShapes))) {
    return failure();
  }
  SmallVector<OpFoldResult> outputSizes = reifiedResultShapes[0];
  if (failed(getStaticOrReifiedInputDims(builder, loc, getInput(),
                                         reifiedInputShapes))) {
    return failure();
  }
  SmallVector<OpFoldResult> inputSizes = reifiedInputShapes[0];
  assert(sizes.size() == 4);
  outputSizes[2] = inputSizes[0] = sizes[0];
  outputSizes[3] = sizes[1];
  outputSizes[4] = sizes[2];
  outputSizes[5] = inputSizes[cDim] = sizes[3];
  // The input image tile is the iteration tile scaled by the output tile size
  // (offset) and input tile size (size), clamped to the image extents.
  auto hSizeAndOffset = getScaledSizeAndOffset(
      builder, loc, sizes[1], offsets[1], inputSizes[hDim], getOutputTileSize(),
      getInputTileSize());
  auto wSizeAndOffset = getScaledSizeAndOffset(
      builder, loc, sizes[2], offsets[2], inputSizes[wDim], getOutputTileSize(),
      getInputTileSize());
  inputSizes[hDim] = hSizeAndOffset.first;
  inputSizes[wDim] = wSizeAndOffset.first;
  inputOffsets[hDim] = hSizeAndOffset.second;
  inputOffsets[wDim] = wSizeAndOffset.second;
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  // Input
  {
    Operation *inputSlice = getSlice(builder, loc, getInput(), inputOffsets,
                                     inputSizes, inputStrides);
    if (!inputSlice) {
      return emitOpError("failed to get input slice");
    }
    tiledOperands.emplace_back(inputSlice->getResult(0));
    slices.push_back(inputSlice);
  }
  // Output
  {
    Operation *outputSlice = getSlice(builder, loc, getOutput(), outputOffsets,
                                      outputSizes, outputStrides);
    if (!outputSlice) {
      return emitOpError("failed to get output slice");
    }
    tiledOperands.emplace_back(outputSlice->getResult(0));
    slices.push_back(outputSlice);
  }
  // Tensor semantics: return the sliced output type; buffers: no results.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
}
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber != 0) {
    return failure();
  }
  // Start from the full static output shape at zero offsets, then patch
  // output dims 2..5 with the 4-D iteration-space tile.
  auto resultShape = cast<ShapedType>(getOutput().getType()).getShape();
  resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(resultShape));
  resultOffsets.assign(getOutputRank(), builder.getIndexAttr(0));
  for (int64_t i = 0; i < 4; ++i) {
    resultOffsets[i + 2] = offsets[i];
    resultSizes[i + 2] = sizes[i];
  }
  return success();
}
//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//
SmallVector<Range>
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  // Loops iterate over the output dimensions past the leading kernel
  // dimensions, bounded by the output's extents.
  Location loc = getLoc();
  OpFoldResult zero = builder.getIndexAttr(0);
  OpFoldResult one = builder.getIndexAttr(1);
  Value source = getOutput();
  int64_t numKernelDims = getKernelDimensions().size();
  int64_t outRank = getOutputRank();
  SmallVector<Range> loopBounds;
  loopBounds.reserve(outRank - numKernelDims);
  for (int64_t dim = numKernelDims; dim < outRank; ++dim) {
    loopBounds.push_back(
        Range{zero, getDimValue(builder, loc, source, dim), one});
  }
  return loopBounds;
}
SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() {
  // All tiled loops of the filter transform are parallel.
  return SmallVector<utils::IteratorType>(getIterationDomainRank(),
                                          utils::IteratorType::parallel);
}
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  OpFoldResult one = builder.getIndexAttr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  const int cDim = getChannelDim();
  const int fDim = getFilterDim();
  // The iteration domain is 2-D (channel, filter); it maps onto output dims
  // 2 and 3 and onto the input's channel/filter dims.
  assert(offsets.size() == 2);
  SmallVector<OpFoldResult> inputOffsets(getInputRank(), zero);
  SmallVector<OpFoldResult> outputOffsets(getOutputRank(), zero);
  outputOffsets[2] = inputOffsets[cDim] = offsets[0];
  outputOffsets[3] = inputOffsets[fDim] = offsets[1];
  SmallVector<OpFoldResult> inputStrides(getInputRank(), one);
  SmallVector<OpFoldResult> outputStrides(getOutputRank(), one);
  assert(sizes.size() == 2);
  // Start from the full static shapes and replace only the tiled extents.
  ArrayRef<int64_t> inputShape = getInputType().getShape();
  ArrayRef<int64_t> outputShape = getOutputType().getShape();
  SmallVector<OpFoldResult> inputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(inputShape));
  SmallVector<OpFoldResult> outputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
  outputSizes[2] = inputSizes[cDim] = sizes[0];
  outputSizes[3] = inputSizes[fDim] = sizes[1];
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  // Input
  {
    Operation *inputSlice = getSlice(builder, loc, getInput(), inputOffsets,
                                     inputSizes, inputStrides);
    if (!inputSlice) {
      return emitOpError("failed to get input slice");
    }
    tiledOperands.emplace_back(inputSlice->getResult(0));
    slices.push_back(inputSlice);
  }
  // Output
  {
    Operation *outputSlice = getSlice(builder, loc, getOutput(), outputOffsets,
                                      outputSizes, outputStrides);
    if (!outputSlice) {
      return emitOpError("failed to get output slice");
    }
    tiledOperands.emplace_back(outputSlice->getResult(0));
    slices.push_back(outputSlice);
  }
  // Tensor semantics: return the sliced output type; buffers: no results.
  SmallVector<Type> resultTypes;
  if (hasPureTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
}
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber != 0) {
    return failure();
  }
  // Full static output shape at zero offsets, then patch dims 2 and 3 with
  // the 2-D iteration-space tile.
  ArrayRef<int64_t> resultShape = getOutputType().getShape();
  resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(resultShape));
  resultOffsets.assign(getOutputRank(), builder.getIndexAttr(0));
  for (int64_t i = 0; i < 2; ++i) {
    resultOffsets[i + 2] = offsets[i];
    resultSizes[i + 2] = sizes[i];
  }
  return success();
}
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  // Loops cover the input dimensions past the leading image (tile)
  // dimensions, bounded by the input's extents.
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  SmallVector<Range> loopBounds;
  loopBounds.reserve(getIterationDomainRank());
  for (int64_t dim = getImageDimensions().size(); dim < getInputRank();
       ++dim) {
    loopBounds.push_back(
        Range{zero, getDimValue(builder, loc, source, dim), one});
  }
  return loopBounds;
}
SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  // All tiled loops of the output transform are parallel.
  return SmallVector<utils::IteratorType>(getIterationDomainRank(),
                                          utils::IteratorType::parallel);
}
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  auto one = builder.getIndexAttr(1);
  auto zero = builder.getIndexAttr(0);
  const int cDim = getChannelDim();
  // The iteration domain is 4-D; its offsets/sizes map onto input dims 2..5.
  assert(offsets.size() == 4);
  const auto hDim = getImageDimensions()[0];
  const auto wDim = getImageDimensions()[1];
  SmallVector<OpFoldResult> inputOffsets(getInputRank(), zero);
  SmallVector<OpFoldResult> outputOffsets(getOutputRank(), zero);
  inputOffsets[2] = outputOffsets[0] = offsets[0];
  inputOffsets[3] = offsets[1];
  inputOffsets[4] = offsets[2];
  inputOffsets[5] = outputOffsets[cDim] = offsets[3];
  SmallVector<OpFoldResult> inputStrides(getInputRank(), one);
  SmallVector<OpFoldResult> outputStrides(getOutputRank(), one);
  ReifiedRankedShapedTypeDims reifiedResultShapes, reifiedInputShapes;
  if (failed(reifyResultShapes(builder, reifiedResultShapes))) {
    return failure();
  }
  SmallVector<OpFoldResult> outputSizes = reifiedResultShapes[0];
  if (failed(getStaticOrReifiedInputDims(builder, loc, getInput(),
                                         reifiedInputShapes))) {
    return failure();
  }
  SmallVector<OpFoldResult> inputSizes = reifiedInputShapes[0];
  assert(sizes.size() == 4);
  inputSizes[2] = outputSizes[0] = sizes[0];
  inputSizes[3] = sizes[1];
  inputSizes[4] = sizes[2];
  inputSizes[5] = outputSizes[cDim] = sizes[3];
  // The output image tile is the iteration tile scaled by the output tile
  // size and clamped to the output image extents.
  auto hSizeAndOffset = getScaledSizeAndOffset(
      builder, loc, sizes[1], offsets[1], outputSizes[hDim],
      getOutputTileSize(), getOutputTileSize());
  auto wSizeAndOffset = getScaledSizeAndOffset(
      builder, loc, sizes[2], offsets[2], outputSizes[wDim],
      getOutputTileSize(), getOutputTileSize());
  outputSizes[hDim] = hSizeAndOffset.first;
  outputSizes[wDim] = wSizeAndOffset.first;
  outputOffsets[hDim] = hSizeAndOffset.second;
  outputOffsets[wDim] = wSizeAndOffset.second;
  Operation *outputSlice = getSlice(builder, loc, getOutput(), outputOffsets,
                                    outputSizes, outputStrides);
  if (!outputSlice) {
    return emitOpError("failed to get output slice");
  }
  // The image dims of the winograd.output_transform result will always be a
  // multiple of the static output_tile_size, so insert a tensor.cast op to
  // maintain more static information in the IR.
  auto outSliceType = cast<ShapedType>(outputSlice->getResultTypes().front());
  SmallVector<int64_t> staticOutShape(outSliceType.getShape());
  auto constSizeH = getConstantIntValue(sizes[1]);
  if (constSizeH.has_value()) {
    staticOutShape[hDim] = constSizeH.value() * getOutputTileSize();
  }
  auto constSizeW = getConstantIntValue(sizes[2]);
  if (constSizeW.has_value()) {
    staticOutShape[wDim] = constSizeW.value() * getOutputTileSize();
  }
  Value staticOutputSlice = castValue(builder, loc, outputSlice->getResult(0),
                                      outSliceType.clone(staticOutShape));
  SmallVector<Value> tiledOperands;
  Operation *inputSlice = getSlice(builder, loc, getInput(), inputOffsets,
                                   inputSizes, inputStrides);
  if (!inputSlice) {
    return emitOpError("failed to get input slice");
  }
  tiledOperands.emplace_back(inputSlice->getResult(0));
  tiledOperands.emplace_back(staticOutputSlice);
  // Tensor semantics: return the sliced output type; buffers: no results.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  // Cast the result back to the (possibly more dynamic) slice type expected
  // by consumers of the original tile.
  SmallVector<Value> results(tiledOp->getResults());
  if (!results.empty()) {
    results.front() = castValue(builder, loc, results.front(), outSliceType);
  }
  return TilingResult{{tiledOp}, results, {inputSlice, outputSlice}};
}
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber != 0) {
    return failure();
  }
  // Start with the full static output shape at zero offsets.
  ArrayRef<int64_t> staticShape =
      cast<ShapedType>(getOutput().getType()).getShape();
  resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(staticShape));
  resultOffsets.assign(getOutputRank(), builder.getIndexAttr(0));
  const int cDim = getChannelDim();
  const auto hDim = getImageDimensions()[0];
  const auto wDim = getImageDimensions()[1];
  Location loc = getLoc();
  // Batch and channel map straight through from the iteration-space tile.
  resultOffsets[0] = offsets[0];
  resultOffsets[cDim] = offsets[3];
  resultSizes[0] = sizes[0];
  resultSizes[cDim] = sizes[3];
  SmallVector<SmallVector<OpFoldResult>> reifiedResultShapes;
  if (failed(reifyResultShapes(builder, reifiedResultShapes))) {
    return failure();
  }
  // The H/W tile is scaled by the output tile size and clamped to the
  // reified image extents.
  std::pair<OpFoldResult, OpFoldResult> scaledH = getScaledSizeAndOffset(
      builder, loc, sizes[1], offsets[1], reifiedResultShapes[0][hDim],
      getOutputTileSize(), getOutputTileSize());
  std::pair<OpFoldResult, OpFoldResult> scaledW = getScaledSizeAndOffset(
      builder, loc, sizes[2], offsets[2], reifiedResultShapes[0][wDim],
      getOutputTileSize(), getOutputTileSize());
  resultSizes[hDim] = scaledH.first;
  resultSizes[wDim] = scaledW.first;
  resultOffsets[hDim] = scaledH.second;
  resultOffsets[wDim] = scaledW.second;
  return success();
}
//===----------------------------------------------------------------------===//
// Attention Helpers
//===----------------------------------------------------------------------===//
static SmallVector<Range>
getAttentionIterationDomain(Location loc, OpBuilder &b, int64_t domainRank,
                            ArrayRef<Value> values,
                            ArrayRef<AffineMap> indexingMaps) {
  // Every loop starts at zero with unit stride; the size of each domain
  // dimension comes from the first operand whose indexing map references it.
  OpFoldResult zero = b.getIndexAttr(0);
  OpFoldResult one = b.getIndexAttr(1);
  SmallVector<Range> loopBounds(domainRank);
  for (Range &bound : loopBounds) {
    bound.offset = zero;
    bound.stride = one;
  }
  SmallVector<bool> seen(domainRank, false);
  for (auto [val, map] : llvm::zip_equal(values, indexingMaps)) {
    for (auto [operandDim, dimExpr] : llvm::enumerate(map.getResults())) {
      assert(isa<AffineDimExpr>(dimExpr));
      int64_t pos = cast<AffineDimExpr>(dimExpr).getPosition();
      if (seen[pos]) {
        continue;
      }
      seen[pos] = true;
      loopBounds[pos].size = getDim(b, loc, val, operandDim);
    }
  }
  return loopBounds;
}
static SmallVector<utils::IteratorType>
getAttentionIteratorTypes(int64_t domainRank, AffineMap qMap, AffineMap kMap,
                          AffineMap vMap, AffineMap oMap) {
  FailureOr<AttentionOpDetail> maybeOpInfo =
      AttentionOpDetail::get(qMap, kMap, vMap, oMap);
  assert(succeeded(maybeOpInfo) && "Failed to infer attention op details");
  AttentionOpDetail opInfo = maybeOpInfo.value();
  // Start fully parallel, then demote the k1 and k2 dims to reductions.
  SmallVector<utils::IteratorType> iteratorTypes(domainRank,
                                                 utils::IteratorType::parallel);
  for (int64_t dim : opInfo.getK1Dims()) {
    iteratorTypes[dim] = utils::IteratorType::reduction;
  }
  for (int64_t dim : opInfo.getK2Dims()) {
    iteratorTypes[dim] = utils::IteratorType::reduction;
  }
  return iteratorTypes;
}
/// Projects iteration-space `offsets`/`sizes` into operand order through a
/// projected-permutation `permutation` map, producing unit-stride ranges.
static SmallVector<Range> getPermutedRange(AffineMap permutation,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes) {
  auto unitStride =
      IntegerAttr::get(IndexType::get(permutation.getContext()), 1);
  assert(permutation.isProjectedPermutation() &&
         "Indexing map should be a projected permutation");
  SmallVector<Range> permuted;
  permuted.reserve(permutation.getNumResults());
  for (AffineExpr dimExpr : permutation.getResults()) {
    int loopDim = cast<AffineDimExpr>(dimExpr).getPosition();
    permuted.push_back(Range{offsets[loopDim], sizes[loopDim], unitStride});
  }
  return permuted;
}
/// Materializes a slice of `val` described by iteration-space
/// `offsets`/`sizes` permuted into operand order via `permutation`.
static Operation *getPermutedSlice(OpBuilder &b, Location loc, Value val,
                                   AffineMap permutation,
                                   ArrayRef<OpFoldResult> offsets,
                                   ArrayRef<OpFoldResult> sizes) {
  return getSlice(b, loc, val, getPermutedRange(permutation, offsets, sizes));
}
//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &b) {
  // Q, K and V shapes together cover every loop dimension, so the output
  // is not needed to derive the iteration domain.
  SmallVector<Value> shapedValues{getQuery(), getKey(), getValue()};
  SmallVector<AffineMap> valueMaps{getQueryMap(), getKeyMap(), getValueMap()};
  return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
                                     shapedValues, valueMaps);
}
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
  // Iterator kinds (parallel vs. k1/k2 reduction) are fully determined by
  // the four indexing maps.
  return getAttentionIteratorTypes(
      getIterationDomainRank(), getQueryMap(), getKeyMap(), getValueMap(),
      getOutputMap());
}
/// Tiles the attention op to the iteration-space tile described by
/// `offsets`/`sizes`: each operand is sliced through its indexing map, the
/// scalar scale is passed through, and the op is cloned over the slices.
/// Fix: the optional mask slice is now null-checked like every other
/// operand before its result is read.
FailureOr<TilingResult>
AttentionOp::getTiledImplementation(OpBuilder &builder,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes) {
  assert(offsets.size() == getIterationDomainRank());
  assert(sizes.size() == getIterationDomainRank());
  Location loc = getLoc();
  // Project the iteration-space tile onto each operand via its map.
  SmallVector<Range> querySlice =
      getPermutedRange(getQueryMap(), offsets, sizes);
  SmallVector<Range> keySlice = getPermutedRange(getKeyMap(), offsets, sizes);
  SmallVector<Range> valueSlice =
      getPermutedRange(getValueMap(), offsets, sizes);
  SmallVector<Range> outputSlice =
      getPermutedRange(getOutputMap(), offsets, sizes);
  Value scale = getScale();
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  // Query
  {
    Operation *querySliceOp = getSlice(builder, loc, getQuery(), querySlice);
    if (!querySliceOp) {
      return emitOpError("failed to get query slice");
    }
    tiledOperands.emplace_back(querySliceOp->getResult(0));
    slices.push_back(querySliceOp);
  }
  // Key
  {
    Operation *keySliceOp = getSlice(builder, loc, getKey(), keySlice);
    if (!keySliceOp) {
      return emitOpError("failed to get key slice");
    }
    tiledOperands.emplace_back(keySliceOp->getResult(0));
    slices.push_back(keySliceOp);
  }
  // Value
  {
    Operation *valueSliceOp = getSlice(builder, loc, getValue(), valueSlice);
    if (!valueSliceOp) {
      return emitOpError("failed to get value slice");
    }
    tiledOperands.emplace_back(valueSliceOp->getResult(0));
    slices.push_back(valueSliceOp);
  }
  // Scale is a scalar; pass it through untiled.
  tiledOperands.emplace_back(scale);
  // Mask (optional operand).
  Value attnMask = getMask();
  if (attnMask) {
    SmallVector<Range> maskSlice =
        getPermutedRange(*getMaskMap(), offsets, sizes);
    Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
    if (!maskSliceOp) {
      return emitOpError("failed to get mask slice");
    }
    tiledOperands.emplace_back(maskSliceOp->getResult(0));
    slices.push_back(maskSliceOp);
  }
  // Output
  {
    Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice);
    if (!outputSliceOp) {
      return emitOpError("failed to get output slice");
    }
    tiledOperands.emplace_back(outputSliceOp->getResult(0));
    slices.push_back(outputSliceOp);
  }
  SmallVector<Type> resultTypes;
  if (hasPureTensorSemantics()) {
    // Operand order is query, key, value, scale, [mask], output; the result
    // type follows the (sliced) output operand.
    int64_t baseIdx = attnMask ? 5 : 4;
    resultTypes.push_back(tiledOperands[baseIdx].getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
}
LogicalResult AttentionOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.clear();
  resultSizes.clear();
  // Only the attention output (result 0) is supported.
  if (resultNumber != 0) {
    return failure();
  }
  AffineMap outputMap = getOutputMap();
  // Project iteration-space offsets/sizes through the (projected
  // permutation) output map.
  for (AffineExpr dimExpr : outputMap.getResults()) {
    int loopDim = cast<AffineDimExpr>(dimExpr).getPosition();
    resultOffsets.push_back(offsets[loopDim]);
    resultSizes.push_back(sizes[loopDim]);
  }
  return success();
}
/// Produces the tile of result `resultNumber` for result-relative
/// `offsets`/`sizes`. The given offsets/sizes are expressed in output order;
/// they are scattered back into iteration-space order (untiled dims keep
/// offset 0 and their full domain size) before delegating to
/// getTiledImplementation.
/// Fix: replaces the signed `int` index loop compared against an unsigned
/// `size()` with `llvm::enumerate`.
FailureOr<TilingResult>
AttentionOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
                                     ArrayRef<OpFoldResult> offsets,
                                     ArrayRef<OpFoldResult> sizes) {
  // Input offsets and sizes here are from the POV of the outputMap. We need to
  // normalize these offsets and size for it to be useful.
  // Initialize normalized offsets with 0s and normalized sizes with original
  // size.
  SmallVector<Range> iterationDomain(getIterationDomain(builder));
  SmallVector<OpFoldResult> normalizedSizes =
      llvm::map_to_vector(iterationDomain, [](Range x) { return x.size; });
  SmallVector<OpFoldResult> normalizedOffsets(getIterationDomainRank(),
                                              builder.getIndexAttr(0));
  // The output map is a projected permutation, so each result position maps
  // to a unique loop dimension.
  for (auto [resultIdx, dimExpr] :
       llvm::enumerate(getOutputMap().getResults())) {
    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
    normalizedOffsets[dim] = offsets[resultIdx];
    normalizedSizes[dim] = sizes[resultIdx];
  }
  return getTiledImplementation(builder, normalizedOffsets, normalizedSizes);
}
//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//
SmallVector<Range> OnlineAttentionOp::getIterationDomain(OpBuilder &b) {
  // Q, K and V shapes together cover every loop dimension, so the outputs
  // are not needed to derive the iteration domain.
  SmallVector<Value> shapedValues{getQuery(), getKey(), getValue()};
  SmallVector<AffineMap> valueMaps{getQueryMap(), getKeyMap(), getValueMap()};
  return getAttentionIterationDomain(getLoc(), b, getIterationDomainRank(),
                                     shapedValues, valueMaps);
}
SmallVector<utils::IteratorType> OnlineAttentionOp::getLoopIteratorTypes() {
  // Iterator kinds (parallel vs. k1/k2 reduction) are fully determined by
  // the four indexing maps.
  return getAttentionIteratorTypes(
      getIterationDomainRank(), getQueryMap(), getKeyMap(), getValueMap(),
      getOutputMap());
}
/// Tiles the online-attention op to the iteration-space tile described by
/// `offsets`/`sizes`: each operand (query, key, value, optional mask) and
/// each init (output, running max, running sum) is sliced through its
/// indexing map and the op is cloned over the slices.
/// Fixes: removes a dead `std::optional` mask-range local that was computed
/// and then shadowed by a fresh local in the mask block, and adds the
/// missing null check on the mask slice op before reading its result.
FailureOr<TilingResult>
OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,
                                          ArrayRef<OpFoldResult> offsets,
                                          ArrayRef<OpFoldResult> sizes) {
  assert(offsets.size() == getIterationDomainRank());
  assert(sizes.size() == getIterationDomainRank());
  Location loc = getLoc();
  // Project the iteration-space tile onto each operand via its map.
  SmallVector<Range> querySlice =
      getPermutedRange(getQueryMap(), offsets, sizes);
  SmallVector<Range> keySlice = getPermutedRange(getKeyMap(), offsets, sizes);
  SmallVector<Range> valueSlice =
      getPermutedRange(getValueMap(), offsets, sizes);
  SmallVector<Range> outputSlice =
      getPermutedRange(getOutputMap(), offsets, sizes);
  SmallVector<Range> maxSlice = getPermutedRange(getMaxMap(), offsets, sizes);
  SmallVector<Range> sumSlice = getPermutedRange(getSumMap(), offsets, sizes);
  Value scale = getScale();
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  /// Query
  {
    Operation *querySliceOp = getSlice(builder, loc, getQuery(), querySlice);
    if (!querySliceOp) {
      return emitOpError("failed to get query slice");
    }
    tiledOperands.emplace_back(querySliceOp->getResult(0));
    slices.push_back(querySliceOp);
  }
  /// Key
  {
    Operation *keySliceOp = getSlice(builder, loc, getKey(), keySlice);
    if (!keySliceOp) {
      return emitOpError("failed to get key slice");
    }
    tiledOperands.emplace_back(keySliceOp->getResult(0));
    slices.push_back(keySliceOp);
  }
  /// Value
  {
    Operation *valueSliceOp = getSlice(builder, loc, getValue(), valueSlice);
    if (!valueSliceOp) {
      return emitOpError("failed to get value slice");
    }
    tiledOperands.emplace_back(valueSliceOp->getResult(0));
    slices.push_back(valueSliceOp);
  }
  // Scale is a scalar; pass it through untiled.
  tiledOperands.emplace_back(scale);
  // Mask (optional operand).
  Value attnMask = getMask();
  if (attnMask) {
    SmallVector<Range> maskSlice =
        getPermutedRange(*getMaskMap(), offsets, sizes);
    Operation *maskSliceOp = getSlice(builder, loc, attnMask, maskSlice);
    if (!maskSliceOp) {
      return emitOpError("failed to get mask slice");
    }
    tiledOperands.emplace_back(maskSliceOp->getResult(0));
    slices.push_back(maskSliceOp);
  }
  /// Output
  {
    Operation *outputSliceOp = getSlice(builder, loc, getOutput(), outputSlice);
    if (!outputSliceOp) {
      return emitOpError("failed to get output slice");
    }
    tiledOperands.emplace_back(outputSliceOp->getResult(0));
    slices.push_back(outputSliceOp);
  }
  /// Max
  {
    Operation *maxSliceOp = getSlice(builder, loc, getMax(), maxSlice);
    if (!maxSliceOp) {
      return emitOpError("failed to get max slice");
    }
    tiledOperands.emplace_back(maxSliceOp->getResult(0));
    slices.push_back(maxSliceOp);
  }
  /// Sum
  {
    Operation *sumSliceOp = getSlice(builder, loc, getSum(), sumSlice);
    if (!sumSliceOp) {
      return emitOpError("failed to get sum slice");
    }
    tiledOperands.emplace_back(sumSliceOp->getResult(0));
    slices.push_back(sumSliceOp);
  }
  // Result types follow the last three tiled operands: output, max, sum.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[tiledOperands.size() - 3].getType());
  resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType());
  resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
}
LogicalResult OnlineAttentionOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.clear();
  resultSizes.clear();
  // Result numbering: 0 = accumulator output, 1 = running max,
  // 2 = running sum.
  AffineMap resultIndexingMap;
  if (resultNumber == 0) {
    resultIndexingMap = getOutputMap();
  } else if (resultNumber == 1) {
    resultIndexingMap = getMaxMap();
  } else if (resultNumber == 2) {
    resultIndexingMap = getSumMap();
  } else {
    return failure();
  }
  // Project iteration-space offsets/sizes through the (projected
  // permutation) result map.
  for (AffineExpr dimExpr : resultIndexingMap.getResults()) {
    int loopDim = cast<AffineDimExpr>(dimExpr).getPosition();
    resultOffsets.push_back(offsets[loopDim]);
    resultSizes.push_back(sizes[loopDim]);
  }
  return success();
}
/// Extends a result indexing map for split-k2 partial results: the
/// un-reduced k2 dimensions are appended as trailing results.
static AffineMap getPartialResultMap(AffineMap map, AttentionOpDetail &opInfo) {
  MLIRContext *ctx = map.getContext();
  for (int k2Dim : opInfo.getK2Dims()) {
    map = map.insertResult(getAffineDimExpr(k2Dim, ctx), map.getNumResults());
  }
  return map;
}
/// Creates and fills the three init tensors (accumulator, running max,
/// running sum) for a split-k2 partial reduction. Each init keeps the k2
/// dims (appended by getPartialResultMap) and is filled with the identity
/// of its combiner: 0 for addf (acc/sum), the finite maximumf identity for
/// the max.
FailureOr<SmallVector<Value>>
OnlineAttentionOp::generateInitialTensorForPartialReduction(
    OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
    ArrayRef<int> reductionDim) {
  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
  if (failed(maybeOpInfo)) {
    return emitOpError("failed to verify op's indexing maps");
  }
  AttentionOpDetail &opInfo = maybeOpInfo.value();
  // Full iteration-domain sizes, one per loop dimension.
  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
      getIterationDomain(b), [](Range x) { return x.size; });
  // A zero tile size means "dimension not tiled": keep its full size.
  SmallVector<OpFoldResult> tiledShape;
  for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
    if (isZeroIndex(tileSize)) {
      tiledShape.push_back(dimSize);
    } else {
      tiledShape.push_back(tileSize);
    }
  }
  // Shapes of the partial results: result shape plus trailing k2 dims.
  SmallVector<OpFoldResult> accSize = applyPermutationMap<OpFoldResult>(
      getPartialResultMap(getOutputMap(), opInfo), tiledShape);
  SmallVector<OpFoldResult> maxSize = applyPermutationMap<OpFoldResult>(
      getPartialResultMap(getMaxMap(), opInfo), tiledShape);
  SmallVector<OpFoldResult> sumSize = applyPermutationMap<OpFoldResult>(
      getPartialResultMap(getSumMap(), opInfo), tiledShape);
  Type accElTy = getElementTypeOrSelf(getOutput().getType());
  Type maxElTy = getElementTypeOrSelf(getMax().getType());
  Type sumElTy = getElementTypeOrSelf(getSum().getType());
  Value partialAcc = b.create<tensor::EmptyOp>(loc, accSize, accElTy);
  Value partialMax = b.create<tensor::EmptyOp>(loc, maxSize, maxElTy);
  Value partialSum = b.create<tensor::EmptyOp>(loc, sumSize, sumElTy);
  // Combiner identities used to seed the partial reductions.
  Value accInit = arith::getIdentityValue(arith::AtomicRMWKind::addf, accElTy,
                                          b, loc, /*useOnlyFiniteValue=*/true);
  Value maxInit =
      arith::getIdentityValue(arith::AtomicRMWKind::maximumf, maxElTy, b, loc,
                              /*useOnlyFiniteValue=*/true);
  Value sumInit =
      arith::getIdentityValue(arith::AtomicRMWKind::addf, sumElTy, b, loc);
  Value accFill = b.create<linalg::FillOp>(loc, ValueRange{accInit}, partialAcc)
                      .getResult(0);
  Value maxFill = b.create<linalg::FillOp>(loc, ValueRange{maxInit}, partialMax)
                      .getResult(0);
  Value sumFill = b.create<linalg::FillOp>(loc, ValueRange{sumInit}, partialSum)
                      .getResult(0);
  return SmallVector<Value>{accFill, maxFill, sumFill};
}
/// Tiles the op into its split-k2 partial form: inputs are sliced at the
/// given offsets, while the inits (acc, max, sum — produced by
/// generateInitialTensorForPartialReduction) are sliced with k2 offsets
/// reset to 0 and the cloned op gets result indexing maps extended with the
/// trailing k2 dims.
FailureOr<TilingResult> OnlineAttentionOp::tileToPartialReduction(
    OpBuilder &b, Location loc, ValueRange init, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) {
  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
  if (failed(maybeOpInfo)) {
    return emitOpError("failed to verify op's indexing maps");
  }
  AttentionOpDetail &opInfo = maybeOpInfo.value();
  // Extend result maps, keeping everything else the same.
  AffineMap partialAccMap = getPartialResultMap(getOutputMap(), opInfo);
  AffineMap partialMaxMap = getPartialResultMap(getMaxMap(), opInfo);
  AffineMap partialSumMap = getPartialResultMap(getSumMap(), opInfo);
  SmallVector<AffineMap> indexingMaps = getIndexingMapsArray();
  // The three DPS inits (acc, max, sum) follow the inputs in map order.
  indexingMaps[getNumDpsInputs()] = partialAccMap;
  indexingMaps[getNumDpsInputs() + 1] = partialMaxMap;
  indexingMaps[getNumDpsInputs() + 2] = partialSumMap;
  SmallVector<Value> tiledOperands;
  SmallVector<Operation *> slices;
  // Slices `val` through `map` and records both the SSA result (as a tiled
  // operand) and the slice op (for TilingResult bookkeeping).
  auto appendSlice = [&](Value val, AffineMap map,
                         ArrayRef<OpFoldResult> offsets) -> LogicalResult {
    Operation *sliceOp = getPermutedSlice(b, loc, val, map, offsets, sizes);
    if (!sliceOp) {
      return emitOpError("failed to get slice");
    }
    tiledOperands.emplace_back(sliceOp->getResult(0));
    slices.push_back(sliceOp);
    return success();
  };
  if (failed(appendSlice(getQuery(), getQueryMap(), offsets))) {
    return failure();
  }
  if (failed(appendSlice(getKey(), getKeyMap(), offsets))) {
    return failure();
  }
  if (failed(appendSlice(getValue(), getValueMap(), offsets))) {
    return failure();
  }
  // Scale is a scalar; pass it through untiled.
  tiledOperands.emplace_back(getScale());
  if (Value mask = getMask()) {
    if (failed(appendSlice(mask, *getMaskMap(), offsets))) {
      return failure();
    }
  }
  // For results, we set offset of the iterated reduction dims to 0.
  SmallVector<OpFoldResult> initOffsets(offsets);
  for (int dim : opInfo.getK2Dims()) {
    initOffsets[dim] = b.getIndexAttr(0);
  }
  if (failed(appendSlice(init[0], partialAccMap, initOffsets))) {
    return failure();
  }
  if (failed(appendSlice(init[1], partialMaxMap, initOffsets))) {
    return failure();
  }
  if (failed(appendSlice(init[2], partialSumMap, initOffsets))) {
    return failure();
  }
  // Get the initial values.
  ValueRange slicedInits = ArrayRef(tiledOperands).take_back(3);
  auto tiledOp = cast<OnlineAttentionOp>(
      mlir::clone(b, getOperation(), slicedInits.getTypes(), tiledOperands));
  tiledOp.setIndexingMapsAttr(b.getAffineMapArrayAttr(indexingMaps));
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
}
/// Builds a linalg.reduce over the k2 dimensions of `partialResult` into
/// `init`, combining elements with `CombinerOp` (e.g. maximumf for the
/// running max, addf for the running sum/accumulator).
/// Fixes: removes the stray `;` after the function definition and replaces
/// the `llvm::find(...) != ...end()` test with `llvm::is_contained`.
template <typename CombinerOp>
static linalg::ReduceOp reduceOnK2(OnlineAttentionOp attn, AffineMap partialMap,
                                   AttentionOpDetail &opInfo, OpBuilder &b,
                                   Location loc, Value partialResult,
                                   Value init) {
  // linalg.reduce's iteration space is the result's iteration space (and
  // not the operations iteration space). To account for this, permute the
  // reduction dimensions based on the partial result map.
  SmallVector<int64_t> partialReductionDims;
  for (auto [resultNum, dimExpr] : llvm::enumerate(partialMap.getResults())) {
    unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
    if (llvm::is_contained(opInfo.getK2Dims(), dim)) {
      partialReductionDims.push_back(resultNum);
    }
  }
  return b.create<linalg::ReduceOp>(
      loc, partialResult, init, partialReductionDims,
      [&](OpBuilder &b, Location loc, ValueRange inputs) {
        Value reduced = b.create<CombinerOp>(loc, inputs[0], inputs[1]);
        b.create<linalg::YieldOp>(loc, reduced);
      });
}
/// Emits a linalg.generic computing `value <- T(scale, value)` elementwise,
/// broadcasting `scale` through `scaleMap` onto `value`'s `inputMap`
/// iteration space. The scale element is converted to `value`'s element
/// type before combining.
template <typename T>
static Value elementwiseValueInPlace(OpBuilder &builder, Location loc,
                                     AffineMap inputMap, AffineMap scaleMap,
                                     Value value, Value scale) {
  // Drop loop dims that neither map uses so the generic's domain is tight.
  SmallVector<AffineMap> compressedMaps =
      compressUnusedDims(SmallVector<AffineMap>{inputMap, scaleMap});
  AffineMap compressedInputMap = compressedMaps[0];
  AffineMap compressedScaleMap = compressedMaps[1];
  SmallVector<utils::IteratorType> parallelIters(
      compressedInputMap.getNumDims(), utils::IteratorType::parallel);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, value.getType(), scale, value,
      SmallVector<AffineMap>{compressedScaleMap, compressedInputMap},
      parallelIters, [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Convert scale to the same datatype as input.
        Value converted =
            convertScalarToDtype(b, nestedLoc, args[0], args[1].getType(),
                                 /*isUnsignedCast=*/false);
        Value combined = b.create<T>(nestedLoc, converted, args[1]);
        b.create<linalg::YieldOp>(nestedLoc, combined);
      });
  return genericOp.getResult(0);
}
/// Emits a linalg.generic computing `output <- exp2(output - input)`
/// elementwise, broadcasting `input` through `inputMap` onto `output`'s
/// `outputMap` iteration space.
static Value computeSubAndExp2(OpBuilder &builder, Location loc,
                               AffineMap inputMap, AffineMap outputMap,
                               Value input, Value output) {
  // Drop loop dims that neither map uses so the generic's domain is tight.
  SmallVector<AffineMap> compressedMaps =
      compressUnusedDims(SmallVector<AffineMap>{inputMap, outputMap});
  AffineMap compressedInputMap = compressedMaps[0];
  AffineMap compressedOutputMap = compressedMaps[1];
  SmallVector<utils::IteratorType> parallelIters(
      compressedInputMap.getNumDims(), utils::IteratorType::parallel);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output,
      SmallVector<AffineMap>{compressedInputMap, compressedOutputMap},
      parallelIters, [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
        // Convert input to the same datatype as output.
        Value in = convertScalarToDtype(b, nestedLoc, args[0],
                                        args[1].getType(),
                                        /*isUnsignedCast=*/false);
        Value diff = b.create<arith::SubFOp>(nestedLoc, args[1], in);
        Value weight = b.create<math::Exp2Op>(nestedLoc, diff);
        b.create<linalg::YieldOp>(nestedLoc, weight);
      });
  return genericOp.getResult(0);
}
/// Merges the split-k2 partial results {acc, max, sum} into the op's
/// original inits using the online-softmax rescale-then-reduce scheme:
/// first compute the new global max, rescale the partial sum/acc by
/// exp2(partialMax - newMax), then add-reduce them over k2.
FailureOr<MergeResult>
OnlineAttentionOp::mergeReductions(OpBuilder &b, Location loc,
                                   ValueRange partialReduce,
                                   ArrayRef<int> reductionDim) {
  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
  if (failed(maybeOpInfo)) {
    return emitOpError("failed to verify op's indexing maps");
  }
  AttentionOpDetail &opInfo = maybeOpInfo.value();
  AffineMap partialAccMap = getPartialResultMap(getOutputMap(), opInfo);
  AffineMap partialMaxMap = getPartialResultMap(getMaxMap(), opInfo);
  AffineMap partialSumMap = getPartialResultMap(getSumMap(), opInfo);
  // newMax = max(maxInit, rowMax(partialMax))
  linalg::ReduceOp reducedMax = reduceOnK2<arith::MaximumFOp>(
      *this, partialMaxMap, opInfo, b, loc, partialReduce[1], getMax());
  // norm = exp2(partialMax - newMax)
  Value norm = computeSubAndExp2(b, loc, getMaxMap(), partialMaxMap,
                                 reducedMax.getResult(0), partialReduce[1]);
  // normSum = norm * partialSum
  Value normSum = elementwiseValueInPlace<arith::MulFOp>(
      b, loc, partialSumMap, partialMaxMap, partialReduce[2], norm);
  // newSum = sumInit + rowSum(normSum)
  linalg::ReduceOp reducedSum = reduceOnK2<arith::AddFOp>(
      *this, partialSumMap, opInfo, b, loc, normSum, getSum());
  // normAcc = norm * partialAcc
  Value normAcc = elementwiseValueInPlace<arith::MulFOp>(
      b, loc, partialAccMap, partialMaxMap, partialReduce[0], norm);
  // newAcc = accInit + rowSum(normAcc)
  linalg::ReduceOp reducedAcc = reduceOnK2<arith::AddFOp>(
      *this, partialAccMap, opInfo, b, loc, normAcc, getOutput());
  return MergeResult{{reducedAcc, reducedMax, reducedSum},
                     {reducedAcc.getResult(0), reducedMax.getResult(0),
                      reducedSum.getResult(0)}};
}
/// Computes where a tile of partial result `resultNumber` lands in the
/// partial-result tensor. Partial results carry the k2 dims appended at the
/// end (see getPartialResultMap); those dims are being reduced in place and
/// always use offset 0.
/// Fixes: `llvm::find(...) != ...end()` → `llvm::is_contained`, and a typo
/// in the offset comment.
LogicalResult OnlineAttentionOp::getPartialResultTilePosition(
    OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes, ArrayRef<int> reductionDims) {
  FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
      getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
  if (failed(maybeOpInfo)) {
    return emitOpError("failed to verify op's indexing maps");
  }
  AttentionOpDetail &opInfo = maybeOpInfo.value();
  resultOffsets.clear();
  resultSizes.clear();
  // Result numbering: 0 = accumulator output, 1 = running max,
  // 2 = running sum.
  AffineMap resultIndexingMap;
  switch (resultNumber) {
  case 0:
    resultIndexingMap = getOutputMap();
    break;
  case 1:
    resultIndexingMap = getMaxMap();
    break;
  case 2:
    resultIndexingMap = getSumMap();
    break;
  default:
    return failure();
  }
  AffineMap partialMap = getPartialResultMap(resultIndexingMap, opInfo);
  for (AffineExpr dimExpr : partialMap.getResults()) {
    int dim = cast<AffineDimExpr>(dimExpr).getPosition();
    resultSizes.push_back(sizes[dim]);
    if (llvm::is_contained(opInfo.getK2Dims(), dim)) {
      // Reduction dims are reduced, and are always output in the same
      // place. So use offset 0 for them.
      resultOffsets.push_back(b.getIndexAttr(0));
    } else {
      resultOffsets.push_back(offsets[dim]);
    }
  }
  return success();
}
//===---------------------------------------------------------------------===//
// CustomOp
//===---------------------------------------------------------------------===//
/// These methods copied/modified from `TilingInterface` implementation of
/// `getIterationDomain` of `LinalgOp`s.
SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
  // Unwrap the iterator-type attributes stored on the op into plain enum
  // values.
  SmallVector<utils::IteratorType> iterators;
  for (Attribute attr : getIteratorTypes()) {
    iterators.push_back(
        cast<IREE::LinalgExt::IteratorTypeAttr>(attr).getValue());
  }
  return iterators;
}
/// Method similar to `LinalgOp`s that concatenates shapes of all operands.
/// Concatenates, operand by operand, every dimension size of every operand
/// of `customOp` (mirrors the LinalgOp helper of the same name).
static SmallVector<OpFoldResult>
createFlatListOfOperandDims(OpBuilder &builder, Location loc,
                            CustomOp customOp) {
  SmallVector<OpFoldResult> allDims;
  for (Value operand : customOp->getOperands()) {
    unsigned rank = customOp.getRank(operand);
    for (unsigned dim = 0; dim < rank; ++dim) {
      allDims.push_back(getDim(builder, loc, operand, dim));
    }
  }
  return allDims;
}
/// Returns the ranges of the requested loop `dims` and symbol dimensions of
/// a custom_op, derived by inverting the (symbol-renumbered) indexing maps
/// against the flat list of operand shapes.
/// Fix: symbol ranges were looked up at `concatMap.getResult(symbol +
/// numSymbols)`. The symbols are renumbered as dims [numDims, numDims +
/// numSymbols) before inversion, so the inverted map's symbol results start
/// at offset `numDims` — the two only coincide when numDims == numSymbols.
SmallVector<Range> CustomOp::getIterationDomainForDimensions(
    OpBuilder &builder, ArrayRef<unsigned> dims, ArrayRef<unsigned> symbols) {
  CustomOp customOp = *this;
  SmallVector<AffineMap> maps = customOp.getIndexingMapsArray();
  if (maps.empty()) {
    return SmallVector<Range>{};
  }
  Location loc = getLoc();
  // 1. Create a flat list of all the operand shapes (similar to Linalg)
  SmallVector<OpFoldResult> allShapesSizes =
      createFlatListOfOperandDims(builder, loc, customOp);
  // 2a. Next we need to get a map from shapes to loop. Since `custom_op`
  // has indexing maps that have symbols, to make this work correctly
  // compute new maps that replaces the symbols with "new" dims.
  unsigned numDims = getNumLoops();
  unsigned numSymbols = getNumNonLoopDimensions();
  MLIRContext *context = getContext();
  SmallVector<AffineMap> modifiedMaps =
      convertDimsToSymbols(context, maps, numDims, numSymbols);
  // 2b. Concat the affine maps.
  AffineMap concatMap =
      inversePermutation(concatAffineMaps(modifiedMaps, context));
  // TODO: Ideally we should bail if the map is invalid, i.e. we abort from
  // applying the transformation. We could add this to the verifier as well, but
  // it is unclear if this makes the op invalid. Revisit after more experience
  // of how this operation is used.
  assert(concatMap && "failure in inverting indexing maps");
  SmallVector<Range> ranges;
  OpFoldResult zero = builder.getIndexAttr(0), one = builder.getIndexAttr(1);
  auto getRange = [&](AffineExpr expr) {
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(builder, loc, expr,
                                                             allShapesSizes);
    return Range{zero, ofr, one};
  };
  ranges = llvm::map_to_vector(
      dims, [&](unsigned dim) { return getRange(concatMap.getResult(dim)); });
  // Symbols occupy results [numDims, numDims + numSymbols) of the inverted
  // map.
  ranges.append(llvm::map_to_vector(symbols, [&](unsigned symbol) {
    return getRange(concatMap.getResult(symbol + numDims));
  }));
  return ranges;
}
SmallVector<Range> CustomOp::getIterationDomain(OpBuilder &builder) {
  // The iteration domain covers every loop dimension; symbol dimensions are
  // excluded.
  SmallVector<unsigned> allLoops =
      llvm::to_vector(llvm::seq<unsigned>(0, getNumLoops()));
  return getIterationDomainForDimensions(builder, allLoops,
                                         /*symbols=*/ArrayRef<unsigned>{});
}
/// During tiling, the tiling implementation generates offsets and sizes
/// to use for the dimensions of the operation that correspond to loops.
/// The dimensions of the operation corresponding to symbols isnt tiled.
/// This method extends the offsets and sizes for a tile with the
/// offsets (which are zero) and sizes (which are same as the untiled sizes) for
/// the dimension represented by symbols.
/// Extends tile `offsets`/`sizes` (which only cover loop dimensions) with
/// entries for the symbol dimensions, which are never tiled: their offset
/// is 0 and their size is the full symbol range from the iteration domain.
static std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
appendOffsetAndSizeForSymbolDimensions(OpBuilder &builder, CustomOp customOp,
                                       ArrayRef<OpFoldResult> offsets,
                                       ArrayRef<OpFoldResult> sizes) {
  unsigned numSymbols = customOp.getNumNonLoopDimensions();
  SmallVector<OpFoldResult> fullOffsets(offsets);
  SmallVector<OpFoldResult> fullSizes(sizes);
  if (numSymbols == 0) {
    return {fullOffsets, fullSizes};
  }
  fullOffsets.append(numSymbols, builder.getIndexAttr(0));
  SmallVector<unsigned> symbolIndices =
      llvm::to_vector(llvm::seq<unsigned>(0, numSymbols));
  for (const Range &symbolRange : customOp.getIterationDomainForDimensions(
           builder, /*dims=*/ArrayRef<unsigned>{}, symbolIndices)) {
    fullSizes.push_back(symbolRange.size);
  }
  return {fullOffsets, fullSizes};
}
/// This method is adapted from the method `linalg::computeAllSliceParameters`,
/// adapted to work with map that have symbols, and empty maps.
/// Computes per-operand slice parameters for tiling a custom_op; a
/// `std::nullopt` entry means "pass the operand through untiled" (empty
/// indexing map or non-shaped operand).
static SmallVector<std::optional<linalg::SliceParameters>>
computeCustomOpAllSliceParameters(OpBuilder &builder, Location loc,
                                  CustomOp customOp, ValueRange valuesToTile,
                                  ArrayRef<OpFoldResult> ivs,
                                  SmallVector<OpFoldResult> tileSizes) {
  assert(ivs.size() == static_cast<size_t>(llvm::count_if(
                           llvm::make_range(tileSizes.begin(), tileSizes.end()),
                           [](OpFoldResult v) { return !isZeroIndex(v); })) &&
         "expected as many ivs as non-zero sizes");
  unsigned numDims = customOp.getNumLoops();
  unsigned numSymbols = customOp.getNumNonLoopDimensions();
  // Construct (potentially temporary) mins and maxes on which to apply maps
  // that define tile subshapes.
  SmallVector<OpFoldResult> lbs =
      linalg::computeTileOffsets(builder, loc, ivs, tileSizes);
  SmallVector<OpFoldResult> sizeBounds = tileSizes;
  // Symbol dimensions are untiled: extend lbs with zeros and sizeBounds
  // with the full symbol sizes.
  std::tie(lbs, sizeBounds) = appendOffsetAndSizeForSymbolDimensions(
      builder, customOp, lbs, sizeBounds);
  // A zero tile size marks the symbol dimensions as "not tiled" for the
  // linalg helpers below.
  tileSizes.append(numSymbols, builder.getIndexAttr(0));
  SmallVector<OpFoldResult> subShapeSizes =
      linalg::computeTileSizes(builder, loc, tileSizes, sizeBounds);
  // Dim expressions substituted for the indexing-map symbols, renumbering
  // them as trailing dims.
  SmallVector<AffineExpr> symbolReplacements;
  if (numSymbols != 0) {
    symbolReplacements =
        getDimExprsForSymbols(builder.getContext(), numDims, numSymbols);
  }
  assert(static_cast<int64_t>(valuesToTile.size()) <=
             customOp->getNumOperands() &&
         "more value to tile than operands.");
  SmallVector<std::optional<linalg::SliceParameters>> allSliceParams;
  allSliceParams.reserve(valuesToTile.size());
  for (auto [opOperand, val] :
       llvm::zip(customOp->getOpOperands(), valuesToTile)) {
    Value shapedOp = val;
    LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
    AffineMap map = customOp.getMatchingIndexingMap(&opOperand);
    if (numSymbols != 0) {
      map = convertDimsToSymbols(map, numDims, numSymbols, symbolReplacements);
    }
    // If the map is empty, we dont tile this operand.
    Type operandType = opOperand.get().getType();
    if (map.isEmpty() || !isa<ShapedType>(operandType)) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");
      continue;
    }
    allSliceParams.push_back(linalg::computeSliceParameters(
        builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
        /*omitPartialTileCheck=*/true));
  }
  return allSliceParams;
}
/// This method is same as `materializeTiledShape` method defined in
/// `mlir/Dialect/Linalg/Utils/Utils.[h|cpp]`.
/// Materializes the slice described by `sliceParams` as a
/// tensor.extract_slice of `valueToTile`.
// NOTE(review): only RankedTensorType is handled; any other shaped type
// (e.g. memref) hits llvm_unreachable. Confirm CustomOp tiling is
// tensor-only.
static Operation *
materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
                      const linalg::SliceParameters &sliceParams) {
  auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
  auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
                      .Case([&](RankedTensorType) {
                        return builder.create<tensor::ExtractSliceOp>(
                            loc, valueToTile, sliceParams.offsets,
                            sliceParams.sizes, sliceParams.strides);
                      })
                      .Default([](ShapedType) -> Operation * {
                        llvm_unreachable("Unexpected shaped type");
                      });
  return sliceOp;
}
/// This method is adapted from the method `linalg::makeTiledShapes`,
/// adapted to work with map that have symbols, and empty maps.
static SmallVector<Value>
makeCustomOpTiledShapes(OpBuilder &builder, Location loc, CustomOp customOp,
                        ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
                        ArrayRef<OpFoldResult> tileSizes) {
  // Per-operand slice parameters; std::nullopt means "pass the operand
  // through untiled".
  SmallVector<std::optional<linalg::SliceParameters>> sliceParamsList =
      computeCustomOpAllSliceParameters(builder, loc, customOp, valuesToTile,
                                        ivs, llvm::to_vector(tileSizes));
  SmallVector<Value> tiledShapes;
  tiledShapes.reserve(valuesToTile.size());
  for (auto [valueToTile, sliceParams] :
       llvm::zip_equal(valuesToTile, sliceParamsList)) {
    if (!sliceParams.has_value()) {
      tiledShapes.push_back(valueToTile);
      continue;
    }
    tiledShapes.push_back(
        materializeTiledShape(builder, loc, valueToTile, *sliceParams)
            ->getResult(0));
  }
  return tiledShapes;
}
/// Rewrites iree_linalg_ext.index ops inside a tiled custom_op body: after
/// tiling they yield tile-local indices, so the per-dimension tile offset
/// is added back to recover global indices.
static void offsetCustomOpIndices(OpBuilder &b, CustomOp customOp,
                                  ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  for (auto indexOp : customOp.getBody()->getOps<IREE::LinalgExt::IndexOp>()) {
    // Skip dimensions that have no recorded offset.
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
      continue;
    OpBuilder::InsertionGuard guard(b);
    rewriter.setInsertionPointAfter(indexOp);
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    // applied = index + offset, folded where possible.
    OpFoldResult applied = affine::makeComposedFoldedAffineApply(
        rewriter, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    // Replace all uses of the index except the one inside the materialized
    // add itself, which must keep reading the original index value.
    rewriter.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}
/// Tiles a custom_op: slices every shaped operand through its indexing map
/// (operands with empty maps pass through untiled), clones the op over the
/// slices, and rebases any iree_linalg_ext.index ops in the cloned body by
/// the tile offsets.
FailureOr<TilingResult>
CustomOp::getTiledImplementation(OpBuilder &builder,
                                 ArrayRef<OpFoldResult> offsets,
                                 ArrayRef<OpFoldResult> sizes) {
  CustomOp customOp = *this;
  // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
  // specified could lead to out of bounds accesses.
  Location loc = getLoc();
  SmallVector<Value> valuesToTile = getOperation()->getOperands();
  SmallVector<Value> tiledOperands = makeCustomOpTiledShapes(
      builder, loc, *this, valuesToTile, offsets, sizes);
  // Collect only the operands that actually got sliced.
  SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
      llvm::make_filter_range(
          tiledOperands,
          [](Value v) -> bool {
            return isa_and_nonnull<tensor::ExtractSliceOp, memref::SubViewOp>(
                v.getDefiningOp());
          }),
      [](Value v) -> Operation * { return v.getDefiningOp(); });
  // Result types follow the (possibly sliced) DPS init operand types.
  SmallVector<Type> resultTensorTypes =
      llvm::map_to_vector(getDpsInitsMutable(), [&](OpOperand &opOperand) {
        return tiledOperands[opOperand.getOperandNumber()].getType();
      });
  Operation *tiledOp =
      mlir::clone(builder, customOp, resultTensorTypes, tiledOperands);
  offsetCustomOpIndices(builder, cast<CustomOp>(tiledOp), offsets);
  return TilingResult{
      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), generatedSlices};
}
/// Methods copied/modified from `TilingInterface` implementation of
/// `getTiledImplementation` of `LinalgOp`s.
/// Computes the offsets/sizes of result `resultNumber`'s tile by running
/// the linalg slice-parameter computation on the matching DPS init operand.
LogicalResult CustomOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  CustomOp customOp = *this;
  Location loc = getLoc();
  AffineExpr d0 = builder.getAffineDimExpr(0);
  // Extend offsets/sizes with the untiled symbol dimensions.
  SmallVector<OpFoldResult> appendedOffsets, appendedSizes;
  std::tie(appendedOffsets, appendedSizes) =
      appendOffsetAndSizeForSymbolDimensions(builder, customOp, offsets, sizes);
  // subShapeSizes = size - 1 for each dimension (the convention the linalg
  // slice-parameter helpers use for inclusive upper bounds).
  SmallVector<OpFoldResult> subShapeSizes =
      llvm::map_to_vector(appendedSizes, [&](OpFoldResult ofr) {
        return affine::makeComposedFoldedAffineApply(builder, loc, d0 - 1, ofr);
      });
  OpOperand *outOperand = customOp.getDpsInitOperand(resultNumber);
  linalg::SliceParameters sliceParams = linalg::computeSliceParameters(
      builder, loc, outOperand->get(), sizes,
      customOp.getMatchingIndexingMap(outOperand), appendedOffsets,
      /*ubs*/ {}, subShapeSizes, /*omitPartialTileCheck=*/true);
  resultOffsets = sliceParams.offsets;
  resultSizes = sliceParams.sizes;
  return success();
}
} // namespace mlir::iree_compiler::IREE::LinalgExt