blob: 02bda2358b8c6f7e5c89b8b7b00ac27faf8f4fc0 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;
namespace IREE = mlir::iree_compiler::IREE;
//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//
/// Returns the element type of `ty` when `ty` is a `complex<...>` type, and
/// `ty` unchanged otherwise (a null `ty` is also returned unchanged).
static Type getComplexElementTypeOrSelf(Type ty) {
  auto complexTy = dyn_cast_or_null<ComplexType>(ty);
  if (!complexTy)
    return ty;
  return complexTy.getElementType();
}
/// Appends to `effects` the memory effects of the given operands. Only
/// memref-typed operands contribute: every memref input is read, and every
/// memref output is both read and written. Effects for all inputs are
/// appended before those of the outputs.
static void getEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange inputOperands, ValueRange outputOperands) {
  auto hasMemRefType = [](Value v) {
    return llvm::isa<MemRefType>(v.getType());
  };
  for (Value input : llvm::make_filter_range(inputOperands, hasMemRefType)) {
    effects.emplace_back(MemoryEffects::Read::get(), input,
                         SideEffects::DefaultResource::get());
  }
  for (Value output :
       llvm::make_filter_range(outputOperands, hasMemRefType)) {
    // Outputs are modelled as read-modify-write: a read followed by a write.
    effects.emplace_back(MemoryEffects::Read::get(), output,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), output,
                         SideEffects::DefaultResource::get());
  }
}
/// Slices `source` with the given offsets, sizes and strides.
/// - A ranked-tensor `source` produces a `tensor.extract_slice`.
/// - A memref `source` produces a `memref.subview`.
/// - Any other type of `source` produces a null Value.
static Value getSlice(OpBuilder &b, Location loc, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides) {
  Type sourceType = source.getType();
  if (llvm::isa<RankedTensorType>(sourceType)) {
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);
  }
  if (llvm::isa<MemRefType>(sourceType)) {
    return b.create<memref::SubViewOp>(loc, source, offsets, sizes, strides);
  }
  return Value();
}
/// Returns true if `dimsPos` is not a valid list of dimension positions for a
/// value of rank `rank`. The list is invalid when:
///   a) it has more entries than `rank`,
///   b) it contains a duplicate entry, or
///   c) at least one entry is out of bounds, i.e. not in `[0, rank)`.
static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
  // Early exit. Compare in the signed domain: mixing `size_t` with a signed
  // `rank` would otherwise be a signed/unsigned comparison.
  if (static_cast<int64_t>(dimsPos.size()) > rank)
    return true;
  DenseSet<int64_t> uniqued;
  for (int64_t dim : dimsPos) {
    if (dim < 0 || dim >= rank)
      return true;
    // `insert(...).second` is false when `dim` was already seen.
    if (!uniqued.insert(dim).second)
      return true;
  }
  return false;
}
/// Returns true if every extent of `sourceShape` fits within the corresponding
/// extent of `limitShape`, i.e. `source <= limit` dimension by dimension.
/// A pair where either extent is dynamic is treated as fitting. Both shapes
/// must have the same rank.
static bool isSmallerThan(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  for (auto [extent, limit] : llvm::zip(sourceShape, limitShape)) {
    bool anyDynamic =
        extent == ShapedType::kDynamic || limit == ShapedType::kDynamic;
    if (!anyDynamic && extent > limit)
      return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
/// Verifies the scatter op. The checks run in this order, and the first
/// failing one is reported:
///   - exactly two inputs and one output;
///   - indices have rank 2 and i32 elements;
///   - the index depth is static and equals the number of dimension-map
///     entries, and the dimension map is valid for the rank of the original
///     value (see `isInvalid`);
///   - the update value has rank >= 1, and its dim #0 matches dim #0 of the
///     indices;
///   - update rank - 1 does not exceed the original rank;
///   - index depth plus the non-batch update dims cover every original dim;
///   - update extents do not exceed the matching original extents wherever the
///     original extent is static;
///   - the region has two scalar int/float (or complex of those) arguments
///     typed like the update and original elements, which must agree, and it
///     yields exactly one value of that type.
LogicalResult ScatterOp::verify() {
Operation *op = getOperation();
if (getInputs().size() != 2) {
return op->emitOpError("expected two input operands");
}
if (getOutputs().size() != 1) {
return op->emitOpError("expected one output operand");
}
// True when `t1` and `t2` have the same (static or dynamic) extent at `dim`.
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
return t1.getShape()[dim] == t2.getShape()[dim];
};
auto indicesType = getIndicesType();
if (indicesType.getRank() != 2 ||
!indicesType.getElementType().isInteger(32)) {
return op->emitOpError(
"expected indices to be of rank 2 of i32 element type");
}
auto indexDepth = getIndexDepth();
if (indexDepth == ShapedType::kDynamic) {
return op->emitOpError("expected index depth is static");
}
// Each of the `indexDepth` index components needs exactly one target dim.
ArrayRef<int64_t> dimMap = getDimensionMap();
if (dimMap.size() != indexDepth) {
return op->emitOpError("invalid number of dimension map entries ");
}
auto originalType = getOriginalType();
if (isInvalid(dimMap, originalType.getRank()))
return op->emitOpError("dimension map is invalid");
// The first dimension of the indices must match the first dimension of the
// update value; it is the number of updates.
auto updateType = getUpdateType();
if (updateType.getRank() < 1) {
return op->emitOpError("expected update value to be at least rank 1");
}
if (!checkDimensionsMatch(indicesType, updateType, 0)) {
return op->emitOpError(
"mismatch in shape of indices and update value at dim#0");
}
if (updateType.getRank() - 1 > originalType.getRank()) {
return op->emitOpError(
"update value rank exceeds the rank of the original value");
}
// indexDepth + update dims should cover the original dims. The first dim of
// update is the number of updates.
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
return op->emitOpError(
"index depth and update value does not cover rank of original value");
}
// Validate the non-indexed update dims cover the full slice size of the
// original tensor. The trailing `fullSliceDims` original dims
// [indexDepth, originalRank) pair up with the trailing update dims.
int64_t fullSliceDims = originalType.getRank() - indexDepth;
for (auto it :
llvm::zip(llvm::seq<unsigned>(indexDepth, originalType.getRank()),
llvm::seq<unsigned>(updateType.getRank() - fullSliceDims,
updateType.getRank()))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return op->emitOpError("shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim;
}
}
// Check the remaining non-batch update dims, i.e. update dims
// [1, updateRank - fullSliceDims), against original dims
// [insertDims, indexDepth): none may exceed a static original extent.
int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
for (auto it : llvm::zip(
llvm::seq<unsigned>(insertDims, indexDepth),
llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
originalType.getDimSize(originalDim)) {
return op->emitOpError("indexed shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim
<< " " << updateType.getDimSize(updateDim) << " "
<< originalType.getDimSize(originalDim);
}
}
// Region checks: arg 0 is the update element, arg 1 the original element.
Region &region = this->getRegion();
Block *body = &region.front();
if (body->getNumArguments() != 2) {
return op->emitOpError("expected region to have two arguments");
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
!getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
return op->emitOpError(
"expected region to have scalar argument of integer or float types");
}
if (arg0Type != updateType.getElementType()) {
return op->emitOpError("mismatch in argument 0 of region ")
<< arg0Type << " and element type of update value "
<< updateType.getElementType();
}
if (arg1Type != originalType.getElementType()) {
return op->emitOpError("mismatch in argument 1 of region ")
<< arg1Type << " and element type of original value "
<< originalType.getElementType();
}
if (arg0Type != arg1Type) {
return op->emitOpError("mismatch in region argument types ")
<< arg0Type << " and " << arg1Type;
}
// The region must yield exactly one value of the argument type.
auto yieldOp = cast<IREE::LinalgExt::YieldOp>(body->getTerminator());
if (yieldOp->getNumOperands() != 1) {
return yieldOp.emitOpError("expected region to yield a single value");
}
auto yieldedType = yieldOp->getOperand(0).getType();
if (yieldedType != arg0Type) {
return yieldOp.emitOpError("mismatch in type of yielded value ")
<< yieldedType << " and argument of the region " << arg0Type;
}
return success();
}
/// Returns one loop per dimension of the update value, all parallel.
/// Exception: the outermost loop, which walks the individual updates, is a
/// reduction when the indices are not marked unique, because two updates may
/// then target the same element of the original value.
SmallVector<utils::IteratorType> ScatterOp::getLoopIteratorTypes() {
  int64_t updateRank = getUpdateType().getRank();
  SmallVector<utils::IteratorType> loopTypes(updateRank,
                                             utils::IteratorType::parallel);
  if (!getUniqueIndices())
    loopTypes.front() = utils::IteratorType::reduction;
  return loopTypes;
}
/// Returns the iteration domain of the scatter: one `[0, extent)` range with
/// unit stride per dimension of the update value.
SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  int64_t updateRank = getUpdateType().getRank();
  SmallVector<Range> domain;
  domain.reserve(updateRank);
  for (int64_t dim = 0; dim < updateRank; ++dim) {
    Value extent = getDimValue(builder, loc, updates(), dim);
    domain.push_back(Range{zero, extent, one});
  }
  return domain;
}
/// Tiles the scatter along its iteration domain, i.e. the dimensions of the
/// update value. `offsets` and `sizes` are indexed by update dimension.
/// - The update tile uses `offsets`/`sizes` directly.
/// - The indices tile keeps the rows selected by the first (batch) update
///   dimension and takes every other indices dimension in full.
/// - The original tile is obtained from `getResultTilePosition`.
/// Returns failure if that position cannot be computed.
FailureOr<TilingResult>
ScatterOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  assert(offsets.size() >= 1 && sizes.size() >= 1);
  Location loc = getLoc();
  Attribute zero = builder.getI64IntegerAttr(0);
  Attribute one = builder.getI64IntegerAttr(1);

  // Tile of the update value.
  SmallVector<OpFoldResult> updateStrides(getUpdateType().getRank(), one);
  Value tiledUpdate =
      getSlice(builder, loc, updates(), offsets, sizes, updateStrides);
  assert(tiledUpdate && "failed to get slice of update");

  // Tile of the indices: rows follow the batch dimension, columns are full.
  int64_t indicesRank = getIndicesType().getRank();
  SmallVector<OpFoldResult> indicesOffsets(indicesRank, zero);
  SmallVector<OpFoldResult> indicesSizes(indicesRank);
  indicesOffsets.front() = offsets.front();
  indicesSizes.front() = sizes.front();
  for (int64_t dim = 1; dim < indicesRank; ++dim)
    indicesSizes[dim] = getDim(builder, loc, indices(), dim);
  SmallVector<OpFoldResult> indicesStrides(indicesRank, one);
  Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets,
                                indicesSizes, indicesStrides);
  assert(tiledIndices && "failed to get slice of indices");

  // Tile of the original value, as reported for result #0.
  SmallVector<OpFoldResult> originalOffsets, originalSizes;
  if (failed(getResultTilePosition(builder, 0, offsets, sizes, originalOffsets,
                                   originalSizes)))
    return failure();
  SmallVector<OpFoldResult> originalStrides(getOriginalType().getRank(), one);
  Value tiledOriginal = getSlice(builder, loc, original(), originalOffsets,
                                 originalSizes, originalStrides);
  assert(tiledOriginal && "failed to get slice of original tensor");

  // When the op has a result, the tiled op's result takes the type of the
  // original tile.
  SmallVector<Type> resultTypes;
  if (getNumResults())
    resultTypes.push_back(tiledOriginal.getType());
  Operation *tiledScatterOp =
      mlir::clone(builder, getOperation(), resultTypes,
                  ValueRange{tiledUpdate, tiledIndices, tiledOriginal});
  return TilingResult{{tiledScatterOp},
                      SmallVector<Value>(tiledScatterOp->getResults())};
}
/// Computes the tile of the original value (result #0) touched by the
/// iteration-domain tile `offsets`/`sizes`.
/// - The leading `originalRank - updateRank + 1` original dimensions are taken
///   in full, from offset 0, because scattered indices may land anywhere in
///   them.
/// - Each remaining trailing original dimension takes the offset and size of
///   the matching trailing (non-batch) update dimension.
/// `resultOffsets` and `resultSizes` are fully overwritten.
LogicalResult ScatterOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  OpFoldResult zero = builder.getI64IntegerAttr(0);
  int64_t originalRank = getOriginalType().getRank();
  int64_t updateRank = getUpdateType().getRank();
  // Number of leading original dims that are taken in full.
  int64_t numFullDims = originalRank - updateRank + 1;
  // Use `assign` rather than `resize` so that entries left in a reused output
  // vector cannot leak into the leading full-size dimensions.
  resultOffsets.assign(originalRank, zero);
  resultSizes.assign(originalRank, OpFoldResult());
  Location loc = getLoc();
  for (int64_t dim = 0; dim < numFullDims; ++dim)
    resultSizes[dim] = getDim(builder, loc, original(), dim);
  // Original dim `dim` corresponds to update dim `dim - numFullDims + 1`.
  for (int64_t dim = numFullDims; dim < originalRank; ++dim) {
    resultOffsets[dim] = offsets[dim - numFullDims + 1];
    resultSizes[dim] = sizes[dim - numFullDims + 1];
  }
  return success();
}
/// Emits the scalar form of the scatter for the iteration point `ivs`, on
/// memref operands:
///   1. Load the update element at `ivs`.
///   2. Build the coordinate into the original value. Its trailing dims come
///      from the non-batch update ivs. Then, for each `i < indexDepth`, the
///      index `indices[ivs[0], i]` is added into dim `dimensionMap[i]`.
///   3. Load the current original element, inline the region with
///      (update, original) as its arguments, and store the yielded value back.
LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
                                                      Location loc,
                                                      ValueRange ivs) {
  auto indexDepth = getIndexDepth();
  Value update = b.create<memref::LoadOp>(loc, updates(), ivs);

  // Coordinates into the original value. Entries stay null until something
  // is written into them below.
  auto originalTy = original().getType().cast<ShapedType>();
  SmallVector<Value> starts(originalTy.getRank(), Value());
  auto updateIvs = ivs.drop_front(1);
  int64_t offset = starts.size() - updateIvs.size();
  for (size_t pos = 0, e = updateIvs.size(); pos < e; ++pos)
    starts[offset + pos] = updateIvs[pos];

  // Row `ivs[0]` of the indices holds this update's coordinates; column `i`
  // targets original dim `dimMap[i]`.
  ArrayRef<int64_t> dimMap = getDimensionMap();
  SmallVector<Value> loadIndices = {ivs.front(), Value()};
  for (unsigned i = 0; i < indexDepth; ++i) {
    loadIndices[1] = b.create<arith::ConstantIndexOp>(loc, i);
    Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices);
    Value coord = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx);
    Value &start = starts[dimMap[i]];
    if (start)
      coord = b.create<arith::AddIOp>(loc, coord, start);
    start = coord;
  }
  Value init = b.create<memref::LoadOp>(loc, original(), starts);

  // Inline the combiner region: arg 0 <- update, arg 1 <- current value.
  Block &body = getRegion().front();
  IRMapping mapping;
  mapping.map(body.getArgument(0), update);
  mapping.map(body.getArgument(1), init);
  for (Operation &nested : body.without_terminator())
    b.clone(nested, mapping);

  // The terminator is a linalg_ext.yield; store its operand to the
  // destination.
  Value combined =
      mapping.lookupOrDefault(body.getTerminator()->getOperand(0));
  b.create<memref::StoreOp>(loc, combined, original(), starts);
  return success();
}
/// Reifies result shapes by delegating to the `LinalgExtOp` interface
/// implementation.
LogicalResult
ScatterOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
/// Verifies the sort op:
///   - it takes no `ins` operands and at least one `outs` operand;
///   - the comparator block has two arguments per `outs` operand, and
///     arguments #2i and #2i+1 have the element type of operand #i;
///   - the sort dimension lies in [0, rank);
///   - all operands share the same rank and shape;
///   - the comparator yields exactly one i1 value.
LogicalResult SortOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs()) {
    return op->emitOpError("does not expect to take any inputs");
  }
  if (getNumDpsInits() == 0) {
    return op->emitOpError("expected at least one `outs` operand");
  }
  Block &block = getRegion().front();
  size_t numOutputs = getNumDpsInits();
  if (block.getNumArguments() != 2 * numOutputs) {
    return op->emitOpError("region block should have ")
           << 2 * numOutputs << " arguments";
  }
  int64_t rank = getOperandRank();
  // Keep the full 64-bit value: narrowing to `int` could wrap a huge attribute
  // value into the valid range and let it pass the check below.
  int64_t sortDim = getDimension();
  if (sortDim < 0 || sortDim >= rank) {
    // Valid dimensions form the half-open interval [0, rank).
    return op->emitOpError("dimension must be within [0, ") << rank << ")";
  }
  ArrayRef<int64_t> shape = getOperandShape();
  for (auto indexedOperand : llvm::enumerate(getOutputs())) {
    int index = indexedOperand.index();
    auto operandType = getOperandType(index);
    if (operandType.getRank() != rank) {
      return op->emitOpError("expected operand ")
             << index << " to be rank " << rank << ", same as other operands";
    }
    if (operandType.getShape() != shape) {
      return op->emitOpError("expected operand ")
             << index << " to have same shape as other operands";
    }
    // Operand #index feeds the comparator arguments #2*index and #2*index+1.
    Type elemType = operandType.getElementType();
    for (int i : {2 * index, 2 * index + 1}) {
      Type argType = block.getArgument(i).getType();
      if (argType != elemType) {
        return op->emitOpError("region block argument #")
               << i << " should be of type " << elemType << " but got "
               << argType;
      }
    }
  }
  auto yieldOp = cast<YieldOp>(block.getTerminator());
  if (yieldOp.getNumOperands() != 1) {
    return op->emitOpError("should yield exactly one operand");
  }
  auto ty = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>();
  if (!ty || ty.getWidth() != 1) {
    return op->emitOpError("should yield i1 type");
  }
  return success();
}
/// Every loop is parallel except the one along the sort dimension, which is a
/// reduction.
SmallVector<utils::IteratorType> SortOp::getLoopIteratorTypes() {
  int64_t rank = getOperandRank();
  int64_t sortDim = getDimension();
  SmallVector<utils::IteratorType> loopTypes;
  loopTypes.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    loopTypes.push_back(dim == sortDim ? utils::IteratorType::reduction
                                       : utils::IteratorType::parallel);
  }
  return loopTypes;
}
/// Returns one `[0, extent)` unit-stride range per dimension of operand #0.
/// The verifier requires all operands to share that shape.
SmallVector<Range> SortOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = operand(0);
  int64_t operandRank = getOperandRank();
  SmallVector<Range> loopBounds;
  loopBounds.reserve(operandRank);
  for (int64_t dim = 0; dim < operandRank; ++dim) {
    Value extent = getDimValue(builder, loc, source, dim);
    loopBounds.push_back(Range{zero, extent, one});
  }
  return loopBounds;
}
/// Tiles the sort by slicing every `outs` operand with the same
/// `offsets`/`sizes` (unit strides) and cloning the op onto the slices. When
/// the op has results, their types follow the slice types.
FailureOr<TilingResult>
SortOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getOperandRank();
  assert(offsets.size() == static_cast<size_t>(rank) &&
         sizes.size() == static_cast<size_t>(rank));
  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  tiledOperands.reserve(getOutputs().size());
  for (Value output : getOutputs()) {
    Value slice = getSlice(builder, loc, output, offsets, sizes, strides);
    assert(slice && "failed to get slice of operand");
    tiledOperands.push_back(slice);
  }
  SmallVector<Type, 4> resultTypes;
  if (getNumResults()) {
    for (Value tiled : tiledOperands)
      resultTypes.push_back(tiled.getType());
  }
  Operation *tiledSortOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledSortOp},
                      SmallVector<Value>(tiledSortOp->getResults())};
}
/// The result tile coincides with the iteration-domain tile, so the given
/// offsets and sizes are copied through unchanged.
LogicalResult SortOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  return success();
}
/// Emits the scalar form of the sort for the iteration point `ivs`, on memref
/// operands. It generates one bubble-sort pass along the sort dimension:
///   for iv in [0, size(sortDim) - 1):
///     load the elements at `iv` and `iv + 1` of every `outs` operand,
///     inline the comparator region on those pairs, and
///     swap every pair when the comparator yields false.
/// The iteration domain also includes `sortDim`, so the enclosing loop nest
/// presumably repeats this pass once per element along it (confirm against
/// the loop-generation caller).
LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                   ValueRange ivs) {
  auto sortDim = getDimension();
  // Comparator arguments loaded inside the loop: for each `outs` operand, in
  // order, its element at `iv` followed by its element at `iv + 1`.
  SmallVector<Value> sortBlkArgs;
  // Bubble sort innermost loop.
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value ub;
  if (getOperandType(0).isDynamicDim(sortDim)) {
    ub = b.create<memref::DimOp>(loc, operand(0), sortDim);
  } else {
    ub = b.create<arith::ConstantIndexOp>(
        loc, getOperandType(0).getDimSize(sortDim));
  }
  // Adjacent pairs (iv, iv + 1) exist for iv in [0, size - 1).
  ub = b.create<arith::SubIOp>(loc, ub, one);
  auto scfFor = b.create<scf::ForOp>(
      loc, zero, ub, one, ValueRange{},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) {
        SmallVector<Value> indices(ivs);
        Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one);
        for (auto output : getDpsInits()) {
          indices[sortDim] = iv;
          sortBlkArgs.push_back(b.create<memref::LoadOp>(loc, output, indices));
          indices[sortDim] = ivPlusOne;
          sortBlkArgs.push_back(b.create<memref::LoadOp>(loc, output, indices));
        }
        // The body is completed below: the comparator, the conditional swap
        // and the scf.yield terminator are appended after this builder runs.
      });
  auto &srcBlock = getRegion().front();
  Region &region = scfFor.getRegion();
  IRMapping bvm;
  {
    // Inline the comparator at the end of the loop body, feeding it the
    // loaded pairs as block arguments.
    OpBuilder::InsertionGuard guard(b);
    auto &block = region.front();
    b.setInsertionPointToEnd(&block);
    for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) {
      bvm.map(std::get<0>(it), std::get<1>(it));
    }
    for (auto &blockOp : srcBlock.without_terminator()) {
      b.clone(blockOp, bvm);
    }
  }
  // The comparator's yielded i1: true means the pair is already in order.
  Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0));
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPointToEnd(&region.front());
  b.create<scf::IfOp>(
      loc, cond,
      [&](OpBuilder &b, Location loc) {
        // Do not swap the pairs if true.
        b.create<scf::YieldOp>(loc);
      },
      [&](OpBuilder &b, Location loc) {
        // Swap the pairs if false.
        SmallVector<Value> indices(ivs.begin(), ivs.end());
        Value ivPlusOne =
            b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one);
        for (int i = 0, e = getNumDpsInits(); i < e; ++i) {
          Value v1 = sortBlkArgs[i * 2];
          Value v2 = sortBlkArgs[i * 2 + 1];
          indices[sortDim] = scfFor.getInductionVar();
          b.create<memref::StoreOp>(loc, v2, getDpsInits()[i], indices);
          indices[sortDim] = ivPlusOne;
          b.create<memref::StoreOp>(loc, v1, getDpsInits()[i], indices);
        }
        b.create<scf::YieldOp>(loc);
      });
  // Terminate the loop body built above.
  b.create<scf::YieldOp>(loc);
  return success();
}
/// Reifies result shapes by delegating to the `LinalgExtOp` interface
/// implementation.
LogicalResult
SortOp::reifyResultShapes(OpBuilder &b,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
/// Verifies the fft op:
///   - a static FFT length must be a power of two (a dynamic length is
///     accepted here, since it can appear after tiling);
///   - the first input must be the scalar `stage`;
///   - the inputs are either just `stage`, or `stage` followed by two
///     non-scalar (real, imag) coefficient inputs;
///   - there are exactly two outputs (real and imag).
LogicalResult FftOp::verify() {
  Operation *op = getOperation();
  auto length = getFftLength();
  // After tiling, the length can be dynamic, because subview/subtensor does
  // not infer the type correctly in the (1 << x) cases. Only the power-of-two
  // check needs a static length; the operand checks below still apply.
  if (length != ShapedType::kDynamic && (length & (length - 1))) {
    return op->emitOpError("only powers of 2 are handled currently");
  }
  if (!getNumDpsInputs() || !isScalar(getDpsInputOperand(0))) {
    return op->emitOpError("expected to carry `stage` input");
  }
  if (getNumDpsInputs() != 1) {
    if (getNumDpsInputs() != 3 || isScalar(getDpsInputOperand(1)) ||
        isScalar(getDpsInputOperand(2))) {
      return op->emitOpError("expected to carry real and imag coeff inputs");
    }
  }
  if (getNumDpsInits() != 2) {
    return op->emitOpError(
        "expected outputs to be real and imag tensor/memref");
  }
  return success();
}
/// The fft has one loop per operand dimension, and all of them are parallel.
/// The first `rank - 1` loops cover the leading dimensions. The innermost loop
/// steps by `1 << stage` along the last dimension, and the merge step of the
/// stage, which combines two half-size sequences into one, is emitted inside
/// each of its iterations by the scalar implementation.
SmallVector<utils::IteratorType> FftOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(getOperandRank(),
                                          utils::IteratorType::parallel);
}
/// Returns the iteration domain of the fft:
/// - each leading dimension is `[0, extent)` with unit stride;
/// - the last dimension is `[0, extent)` with stride `1 << stage`, so each
///   innermost iteration addresses one group of `2^stage` elements.
SmallVector<Range> FftOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  ArrayRef<int64_t> outerShape = getOperandShape().drop_back();
  SmallVector<Range> domain;
  for (size_t dim = 0, e = outerShape.size(); dim < e; ++dim) {
    Value extent;
    if (outerShape[dim] != ShapedType::kDynamic)
      extent = builder.create<arith::ConstantIndexOp>(loc, outerShape[dim]);
    else
      extent = getDimValue(builder, loc, getReal(), dim);
    domain.push_back(Range{/*offset=*/zero, extent, /*stride=*/one});
  }
  // The last dimension advances one whole butterfly group (2^stage) at a time.
  Value lastExtent = getDimValue(builder, loc, getReal(), getOperandRank() - 1);
  Value groupStride = builder.create<arith::ShLIOp>(loc, one, getStage());
  domain.push_back(Range{/*offset=*/zero, lastExtent, /*stride=*/groupStride});
  return domain;
}
/// Emits one FFT butterfly stage as a `linalg.generic` over the four subviews
/// in `operands` (lhs real, lhs imag, rhs real, rhs imag, in that order). It
/// is used when the op has no precomputed coefficient buffers, so the twiddle
/// factor for position `j` along the last dimension is computed inline:
///   theta = -2 * PI * j / wholeSize,  w = cos(theta) + i * sin(theta).
/// The operands are updated as  lhs <- lhs + w * rhs,  rhs <- lhs - w * rhs.
void FftOp::generateScalarImplWithoutCoeffBuf(OpBuilder &b, Location loc,
                                              ArrayRef<Value> operands,
                                              Value wholeSize) {
  auto rank = getOperandRank();
  SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank));
  // Do the twiddle arithmetic in the element type of the buffers, so that the
  // float ops below type-check against f16/f64 operands as well as f32.
  // Non-float element types fall back to f32.
  Type floatType = getElementTypeOrSelf(operands.front().getType());
  if (!llvm::isa<FloatType>(floatType))
    floatType = b.getF32Type();
  auto indexToFloat = [floatType](OpBuilder &builder, Location loc,
                                  Value v) -> Value {
    v = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(), v);
    return builder.create<arith::SIToFPOp>(loc, floatType, v);
  };
  // We will need exp(-2 * PI * j / m * I); compute "-2 * PI / m" for the
  // imaginary part first.
  Value coeff = b.create<arith::ConstantOp>(
      loc, b.getFloatAttr(floatType, -2 * acos(-1)));
  coeff = b.create<arith::DivFOp>(loc, coeff, indexToFloat(b, loc, wholeSize));
  b.create<linalg::GenericOp>(
      loc, TypeRange{}, ValueRange{}, operands, maps, getLoopIteratorTypes(),
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value lhsReal = args[0];
        Value lhsImag = args[1];
        Value rhsReal = args[2];
        Value rhsImag = args[3];
        // Compute "-2 * PI / m * j"
        Value w = b.create<arith::MulFOp>(
            loc, coeff,
            indexToFloat(b, loc, b.create<linalg::IndexOp>(loc, rank - 1)));
        Value wReal = b.create<math::CosOp>(loc, w);
        Value wImag = b.create<math::SinOp>(loc, w);
        // t = w * a[k + j + mh];
        // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i
        Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal);
        Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag);
        Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag);
        Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal);
        Value tReal = b.create<arith::SubFOp>(loc, xu, yv);
        Value tImag = b.create<arith::AddFOp>(loc, xv, yu);
        // cplx u = a[k + j];
        // a[k + j] = u + t;
        // a[k + j + mh] = u - t;
        Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal);
        Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag);
        Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal);
        Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag);
        b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4});
      });
}
/// Emits one FFT butterfly stage as a `linalg.generic` using the precomputed
/// real/imag coefficient buffers as twiddle factors. `operands` holds the four
/// subviews (lhs real, lhs imag, rhs real, rhs imag, in that order), which are
/// updated as  lhs <- lhs + w * rhs,  rhs <- lhs - w * rhs.
void FftOp::generateScalarImplWithCoeffBuf(OpBuilder &b, Location loc,
                                           ArrayRef<Value> operands) {
  auto rank = getOperandRank();
  // The coefficient buffers are expected to have `2^(stage-1)` entries, which
  // equals the last dim of the operands, so both are indexed by the innermost
  // loop only.
  AffineMap coeffMap =
      AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext());
  SmallVector<AffineMap> maps = {coeffMap, coeffMap};
  maps.append(operands.size(), b.getMultiDimIdentityMap(rank));
  b.create<linalg::GenericOp>(
      loc, TypeRange{}, ValueRange{getRealCoeff(), getImagCoeff()}, operands,
      maps, getLoopIteratorTypes(),
      [&](OpBuilder &nb, Location nl, ValueRange args) {
        Value wReal = args[0], wImag = args[1];
        Value lhsReal = args[2], lhsImag = args[3];
        Value rhsReal = args[4], rhsImag = args[5];
        // t = w * a[k + j + mh], using
        // (x + yi)(u + vi) = (xu - yv) + (xv + yu)i.
        Value xu = nb.create<arith::MulFOp>(nl, wReal, rhsReal);
        Value yv = nb.create<arith::MulFOp>(nl, wImag, rhsImag);
        Value xv = nb.create<arith::MulFOp>(nl, wReal, rhsImag);
        Value yu = nb.create<arith::MulFOp>(nl, wImag, rhsReal);
        Value tReal = nb.create<arith::SubFOp>(nl, xu, yv);
        Value tImag = nb.create<arith::AddFOp>(nl, xv, yu);
        // With u = a[k + j]: a[k + j] = u + t and a[k + j + mh] = u - t.
        Value sumReal = nb.create<arith::AddFOp>(nl, lhsReal, tReal);
        Value sumImag = nb.create<arith::AddFOp>(nl, lhsImag, tImag);
        Value diffReal = nb.create<arith::SubFOp>(nl, lhsReal, tReal);
        Value diffImag = nb.create<arith::SubFOp>(nl, lhsImag, tImag);
        nb.create<linalg::YieldOp>(
            nl, ValueRange{sumReal, sumImag, diffReal, diffImag});
      });
}
// Generates the scalar implementation of one FFT stage, following the
// Cooley-Tukey algorithm. Pseudo reference code:
//   let s <- stage of linalg_ext.fft
//   int m = 1 << s;
//   int mh = m >> 1;
//   for (int k = 0; k < n; k += m) {
//     for (int j = 0; j < mh; ++j) {
//       cplx w = exp(-2 * PI * j / m * I);
//       cplx t = w * a[k + j + mh];
//       cplx u = a[k + j];
//       a[k + j] = u + t;
//       a[k + j + mh] = u - t;
//     }
//   }
// Here `ivs.back()` plays the role of `k`. The `j` loop is emitted as a
// linalg.generic over two `mh`-wide subviews of each of the real and imag
// buffers: lhs = a[k, k + mh) and rhs = a[k + mh, k + m).
LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc,
                                                  ValueRange ivs) {
  Value real = getReal();
  Value imag = getImag();
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value wholeSize = b.create<arith::ShLIOp>(loc, one, getStage());
  Value halfSize = b.create<arith::ShRSIOp>(loc, wholeSize, one);

  // Every subview has extent 1 in the leading dims and `halfSize` in the last.
  auto rank = getOperandRank();
  SmallVector<OpFoldResult> unitStrides(rank, b.getIndexAttr(1));
  SmallVector<OpFoldResult> sliceSizes(rank, b.getIndexAttr(1));
  sliceSizes.back() = halfSize;

  SmallVector<Value> operands;
  SmallVector<OpFoldResult> lhsOffsets(ivs.begin(), ivs.end());
  for (Value buffer : {real, imag}) {
    operands.push_back(b.create<memref::SubViewOp>(loc, buffer, lhsOffsets,
                                                   sliceSizes, unitStrides));
  }
  SmallVector<OpFoldResult> rhsOffsets(ivs.begin(), ivs.end());
  rhsOffsets.back() =
      b.create<arith::AddIOp>(loc, ivs.back(), halfSize).getResult();
  for (Value buffer : {real, imag}) {
    operands.push_back(b.create<memref::SubViewOp>(loc, buffer, rhsOffsets,
                                                   sliceSizes, unitStrides));
  }

  if (hasCoeff())
    generateScalarImplWithCoeffBuf(b, loc, operands);
  else
    generateScalarImplWithoutCoeffBuf(b, loc, operands, wholeSize);
  return success();
}
/// Tiles the fft by slicing both outputs (real, imag) with `offsets`/`sizes`
/// and unit strides, and cloning the op onto the slices. The inputs (`stage`
/// and, when present, the coefficient buffers) are forwarded untiled. Result
/// types follow the slice types when the op has tensor semantics.
FailureOr<TilingResult>
FftOp::getTiledImplementation(OpBuilder &builder,
                              ArrayRef<OpFoldResult> offsets,
                              ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getOperandRank();
  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
  // Rebuild the input list exactly as the op carries it: `stage`, followed by
  // the real/imag coefficients only when the op has them. Always forwarding
  // `getRealCoeff()`/`getImagCoeff()` would give a coefficient-less op extra
  // inputs; presumably those are null Values in that case (confirm against
  // the op definition). The cloned op would then no longer match the original
  // operand structure, whose attributes `mlir::clone` copies.
  SmallVector<Value> tiledOperands;
  tiledOperands.push_back(getStage());
  if (hasCoeff()) {
    tiledOperands.push_back(getRealCoeff());
    tiledOperands.push_back(getImagCoeff());
  }
  SmallVector<Type, 4> resultTypes;
  for (auto out : getOutputs()) {
    tiledOperands.push_back(
        getSlice(builder, getLoc(), out, offsets, sizes, strides));
    if (hasTensorSemantics()) {
      resultTypes.push_back(tiledOperands.back().getType());
    }
  }
  Operation *tiledFftOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledFftOp},
                      SmallVector<Value>(tiledFftOp->getResults())};
}
/// The result tile is the iteration-domain tile itself: the offsets and sizes
/// are copied through unchanged.
LogicalResult FftOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  resultOffsets = llvm::to_vector(offsets);
  resultSizes = llvm::to_vector(sizes);
  return success();
}
/// Reifies result shapes by delegating to the `LinalgExtOp` interface
/// implementation.
LogicalResult
FftOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
/// Verifies the scan op:
///   - one input and two outputs (output, accumulator);
///   - the input is shaped, and input, accumulator and output share an element
///     type;
///   - the accumulator has rank `inputRank - 1`, and the scan dimension lies
///     in [0, inputRank);
///   - the accumulator shape equals the input shape with the scan dimension
///     removed, and the output shape equals the input shape (dynamic extents
///     are compatible with anything).
LogicalResult ScanOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs() != 1) {
    return op->emitOpError("expected one input operands");
  }
  if (getNumDpsInits() != 2) {
    return op->emitOpError("expected two output operands");
  }
  if (!input().getType().isa<ShapedType>()) {
    return op->emitOpError("expected first input element type to be shaped");
  }
  auto accumulatorType = accumulator().getType().cast<ShapedType>();
  auto inputType = input().getType().cast<ShapedType>();
  auto outputType = output().getType().cast<ShapedType>();
  ArrayRef<int64_t> inputShapes = inputType.getShape();
  ArrayRef<int64_t> outputShapes = outputType.getShape();
  if (accumulatorType.getElementType() != inputType.getElementType()) {
    return op->emitOpError(
        "expected input/accumulator element types to be identical");
  }
  ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
  int64_t accumulatorRank = accumulatorType.getRank();
  if (accumulatorRank != inputType.getRank() - 1) {
    return op->emitOpError(
        "expected accumulator rank to be equal to input rank - 1");
  }
  // The scan dimension must name a dimension of the input: the iterator-type
  // computation and the scalar lowering index with it unchecked.
  int64_t scanDim = getDimension();
  if (scanDim < 0 || scanDim >= inputType.getRank()) {
    return op->emitOpError("dimension must be within [0, ")
           << inputType.getRank() << ")";
  }
  SmallVector<int64_t> expectedAccumulatorShape;
  for (int64_t i = 0; i < inputType.getRank(); i++) {
    if (i != scanDim)
      expectedAccumulatorShape.push_back(inputShapes[i]);
  }
  if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return op->emitOpError("incompatible input/accumulator shapes");
  }
  if (inputType.getElementType() != outputType.getElementType()) {
    return op->emitOpError(
        "expected input/output element types to be identical");
  }
  if (inputShapes.size() != outputShapes.size()) {
    return op->emitOpError("expected input/output to have identical ranks");
  }
  if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return op->emitOpError("incompatible input/output shapes");
  }
  return success();
}
/// Returns one `[0, extent)` unit-stride range per dimension of the input.
SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = input();
  SmallVector<Range> loopBounds;
  for (int64_t dim = 0, rank = getOperandRank(); dim < rank; ++dim) {
    Value extent = getDimValue(builder, loc, source, dim);
    loopBounds.push_back(Range{zero, extent, one});
  }
  return loopBounds;
}
/// All loops are parallel except the one along the scan dimension. That loop
/// carries the running value from one element to the next, so it is a
/// reduction.
SmallVector<utils::IteratorType> ScanOp::getLoopIteratorTypes() {
  int64_t rank = getOperandRank();
  int64_t scanDim = getDimension();
  SmallVector<utils::IteratorType> loopTypes;
  loopTypes.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    if (dim == scanDim)
      loopTypes.push_back(utils::IteratorType::reduction);
    else
      loopTypes.push_back(utils::IteratorType::parallel);
  }
  return loopTypes;
}
// Generates naive scalar implementation of scan for a given operator f.
// `f` is the op's single-block region. Its two block arguments are bound to
// the operands shown below, and the first operand of its terminator is the
// combined value. `ivs` indexes one element of the full iteration domain; the
// scan dimension is `getDimension()`.
//
// For inclusive,
// output[0] = input[0]
// output[i] = f(output[i-1], input[i])
//
// For exclusive,
// output[0] = accumulator  (the accumulator value at the non-scan indices is
//                           copied into the first position)
// output[i] = f(output[i-1], input[i-1])
//
// For i > 0 the computed value is also stored into the accumulator at the
// non-scan indices. Assuming the scan dimension is iterated in increasing
// order, the accumulator therefore ends up holding output[n-1].
// NOTE(review): for i == 0 the accumulator is never written, so an inclusive
// scan over a dimension of extent 1 leaves the accumulator unchanged. Confirm
// this is intended.
LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc,
ValueRange ivs) {
SmallVector<Value> indices, scanBlkArgs;
indices.append(ivs.begin(), ivs.end());
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
auto scanDim = getDimension();
// True only for the first element along the scan dimension.
auto cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
indices[scanDim], zero);
bool isInclusive = getInclusive();
// The accumulator is indexed by every induction variable except the scan one.
// NOTE(review): `int i` is compared against the unsigned `indices.size()`
// (and possibly an unsigned `scanDim`); harmless for realistic ranks.
SmallVector<Value> accIndices;
for (int i = 0; i < indices.size(); i++) {
if (i != scanDim)
accIndices.push_back(indices[i]);
}
auto scfIf = b.create<scf::IfOp>(
loc, cond,
// Then-branch (first element): seed the output and terminate the region.
[&](OpBuilder &b, Location loc) {
if (isInclusive) {
auto value = b.create<memref::LoadOp>(loc, input(), indices);
b.create<memref::StoreOp>(loc, value, output(), indices);
} else {
auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices);
b.create<memref::StoreOp>(loc, value, output(), indices);
}
b.create<scf::YieldOp>(loc);
},
// Else-branch (later elements): only gather the two operands of `f` into
// `scanBlkArgs`. The cloned region body and the terminating yield are
// appended after this builder returns (see below).
[&](OpBuilder &b, Location loc) {
// Local copy that shadows the outer `indices`, so the outer vector keeps
// the unmodified `ivs` for the stores emitted later.
SmallVector<Value> indices(ivs.begin(), ivs.end());
Value iv = indices[scanDim];
Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one);
indices[scanDim] = ivMinusOne;
scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices));
// Exclusive reads input[i-1]; inclusive reads input[i].
Value i0;
if (!isInclusive)
i0 = b.create<memref::LoadOp>(loc, input(), indices);
indices[scanDim] = iv;
if (isInclusive)
i0 = b.create<memref::LoadOp>(loc, input(), indices);
scanBlkArgs.push_back(i0);
});
auto &srcBlock = getRegion().front();
Region &region = scfIf.getElseRegion();
IRMapping bvm;
{
// Inline the combiner region at the end of the else-block, with its block
// arguments remapped to the values loaded above.
OpBuilder::InsertionGuard guard(b);
auto &block = region.front();
b.setInsertionPointToEnd(&block);
for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) {
bvm.map(std::get<0>(it), std::get<1>(it));
}
for (auto &blockOp : srcBlock.without_terminator()) {
b.clone(blockOp, bvm);
}
// Write the combined value to output[ivs] and to the accumulator.
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
output(), indices);
b.create<memref::StoreOp>(
loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)),
accumulator(), accIndices);
b.create<scf::YieldOp>(loc);
}
return success();
}
/// Tiles the scan by slicing its input and output with the given tile, and
/// its accumulator with the tile projected away from the scanned dimension.
/// The op is then cloned onto those slices.
FailureOr<TilingResult>
ScanOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getOperandRank();
  assert(offsets.size() == static_cast<size_t>(rank) &&
         sizes.size() == static_cast<size_t>(rank));
  Location loc = getLoc();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> unitStrides(rank, oneAttr);

  // Input and output (result 0) both cover exactly the iteration tile.
  Value inputTile =
      getSlice(builder, loc, input(), offsets, sizes, unitStrides);
  Value outputTile =
      getSlice(builder, loc, getOutputs()[0], offsets, sizes, unitStrides);

  // The accumulator (result 1) lacks the scanned dimension. A rank-1 scan has
  // nothing left to tile, so the accumulator is passed through whole.
  Value accumulatorTile = getOutputs()[1];
  if (rank > 1) {
    SmallVector<OpFoldResult> accOffsets, accSizes;
    if (failed(getResultTilePosition(builder, 1, offsets, sizes, accOffsets,
                                     accSizes)))
      return {};
    SmallVector<OpFoldResult> accStrides(rank - 1, oneAttr);
    accumulatorTile = getSlice(builder, loc, accumulatorTile, accOffsets,
                               accSizes, accStrides);
  }

  SmallVector<Value> tiledOperands = {inputTile, outputTile, accumulatorTile};
  SmallVector<Type, 4> resultTypes;
  if (hasTensorSemantics()) {
    resultTypes.push_back(outputTile.getType());
    resultTypes.push_back(accumulatorTile.getType());
  }
  Operation *tiledScanOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledScanOp},
                      SmallVector<Value>(tiledScanOp->getResults())};
}
/// Maps an iteration-space tile to the tile of result `resultNumber`.
/// - Result 0 (the scanned output) uses the iteration tile unchanged.
/// - Result 1 (the accumulator) drops the scanned dimension. For a rank-1 scan
///   nothing is appended.
/// - Any other result number fails.
LogicalResult ScanOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  switch (resultNumber) {
  case 0:
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  case 1: {
    int64_t rank = getOperandRank();
    if (rank <= 1)
      return success();
    int64_t scanDim = getDimension();
    for (int64_t d = 0; d < rank; ++d) {
      if (d == scanDim)
        continue;
      resultOffsets.push_back(offsets[d]);
      resultSizes.push_back(sizes[d]);
    }
    return success();
  }
  default:
    return failure();
  }
}
/// Folds `memref.cast` producers of the operands into this op.
LogicalResult ScanOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  Operation *self = getOperation();
  return memref::foldMemRefCast(self);
}
/// Delegates result-shape reification to the shared LinalgExtOp interface.
LogicalResult
ScanOp::reifyResultShapes(OpBuilder &b,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//
/// Verifies a reverse op. It checks that:
/// - there is exactly one input and one output;
/// - their element types, ranks and static extents agree;
/// - every reversed dimension is in [0, rank) and appears only once.
LogicalResult ReverseOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs() != 1) {
    return op->emitOpError("expected exactly one input");
  }
  if (getNumDpsInits() != 1) {
    return op->emitOpError("expected exactly one output");
  }
  auto inputType = input().getType().cast<ShapedType>();
  auto outputType = output().getType().cast<ShapedType>();
  if (inputType.getElementType() != outputType.getElementType()) {
    return op->emitOpError(
        "expected input/output element types to be identical");
  }
  ArrayRef<int64_t> inputShapes = inputType.getShape();
  ArrayRef<int64_t> outputShapes = outputType.getShape();
  if (inputShapes.size() != outputShapes.size()) {
    return op->emitOpError("expected input/output to have identical ranks");
  }
  // Static extents must match; a dynamic extent on either side is accepted.
  if (llvm::any_of(llvm::zip(inputShapes, outputShapes),
                   [](std::tuple<int64_t, int64_t> s) {
                     return std::get<0>(s) != ShapedType::kDynamic &&
                            std::get<1>(s) != ShapedType::kDynamic &&
                            std::get<0>(s) != std::get<1>(s);
                   })) {
    return op->emitOpError("incompatible input/output shapes");
  }
  int64_t rank = getOperandRank();
  llvm::SmallSetVector<int64_t, 4> seenDims;
  for (auto dim : dims()) {
    if (dim < 0 || dim >= rank) {
      return op->emitOpError("all the dimensions must be within [0, ")
             << rank << ")";
    }
    // SetVector::insert returns false when the element is already present.
    if (!seenDims.insert(dim)) {
      return op->emitOpError("expected dimensions numbers are all unique");
    }
  }
  return success();
}
/// Each element is written independently of the others, so every loop is
/// parallel.
SmallVector<utils::IteratorType> ReverseOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(getOperandRank(),
                                          utils::IteratorType::parallel);
}
/// One unit-stride loop per input dimension, spanning [0, dim size).
SmallVector<Range> ReverseOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
  int64_t rank = getOperandRank();
  SmallVector<Range> ranges;
  ranges.reserve(rank);
  for (int64_t d = 0; d < rank; ++d) {
    Value extent = getDimValue(builder, loc, input(), d);
    ranges.push_back(Range{lowerBound, extent, step});
  }
  return ranges;
}
/// Copies input[ivs] to the mirrored output position. Every reversed dimension
/// `d` maps its index idx to (size(d) - 1) - idx; other dimensions keep their
/// index.
LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b,
                                                      Location loc,
                                                      ValueRange ivs) {
  SmallVector<Value> dstIndices = llvm::to_vector(ivs);
  for (auto dim : dims()) {
    Value extent = getDimValue(b, loc, input(), dim);
    Value lastIndex = b.create<arith::SubIOp>(
        loc, extent, b.create<arith::ConstantIndexOp>(loc, 1));
    dstIndices[dim] = b.create<arith::SubIOp>(loc, lastIndex, dstIndices[dim]);
  }
  Value element = b.create<memref::LoadOp>(loc, input(), ivs);
  b.create<memref::StoreOp>(loc, element, output(), dstIndices);
  return success();
}
/// Tiles the reverse op. The input is sliced at the requested tile. The
/// output is sliced at the mirrored offsets computed by getResultTilePosition,
/// using the same sizes. The op is then cloned onto both slices; it produces a
/// result only with tensor semantics.
FailureOr<TilingResult>
ReverseOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  SmallVector<OpFoldResult> unitStrides(getOperandRank(),
                                        builder.getI64IntegerAttr(1));
  Location loc = getLoc();
  SmallVector<OpFoldResult> mirroredOffsets, mirroredSizes;
  if (failed(getResultTilePosition(builder, 0, offsets, sizes, mirroredOffsets,
                                   mirroredSizes)))
    return {};

  Value inputTile =
      getSlice(builder, loc, input(), offsets, sizes, unitStrides);
  // Reversal preserves tile sizes; only the offsets are mirrored.
  Value outputTile =
      getSlice(builder, loc, output(), mirroredOffsets, sizes, unitStrides);
  SmallVector<Value> tiledOperands = {inputTile, outputTile};

  SmallVector<Type, 4> resultTypes;
  if (hasTensorSemantics())
    resultTypes.push_back(outputTile.getType());
  Operation *tiledRevOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledRevOp},
                      SmallVector<Value>(tiledRevOp->getResults())};
}
/// Returns where an input tile lands in the output. Sizes are unchanged. For
/// each reversed dimension with extent `size`, a tile starting at `offset`
/// with `tileSize` elements starts at `size - offset - tileSize` in the output.
LogicalResult ReverseOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  AffineExpr extentSym, offsetSym, tileSizeSym;
  bindSymbols(builder.getContext(), extentSym, offsetSym, tileSizeSym);
  AffineMap mirrorMap =
      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3,
                     {extentSym - offsetSym - tileSizeSym});
  resultOffsets.assign(offsets.begin(), offsets.end());
  Location loc = getLoc();
  for (auto dim : dims()) {
    Value extent = getDimValue(builder, loc, input(), dim);
    Value tileOffset =
        getValueOrCreateConstantIndexOp(builder, loc, resultOffsets[dim]);
    Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]);
    auto mirrored = builder.create<affine::AffineApplyOp>(
        loc, mirrorMap, ValueRange{extent, tileOffset, tileSize});
    resultOffsets[dim] = mirrored.getResult();
  }
  resultSizes.assign(sizes.begin(), sizes.end());
  return success();
}
/// Delegates result-shape reification to the shared LinalgExtOp interface.
LogicalResult
ReverseOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
/// Verifies a top-k op. It checks that:
/// - operand counts are correct and the dimension is in range;
/// - values and index element types are consistent;
/// - ranks and shapes agree, except along the top-k dimension;
/// - the comparator region takes two elements of the value type and yields an
///   i1.
LogicalResult TopkOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs() != 1 && getNumDpsInputs() != 2) {
    return op->emitOpError("expected one or two input operands");
  }
  if (getNumDpsInits() != 2) {
    return op->emitOpError("expected two output operands");
  }
  if (getDimension() >= getInputRank()) {
    return op->emitOpError("dimension exceeds rank");
  }
  // Ensure input/output element types match.
  auto inputValuesType = values().getType().cast<ShapedType>();
  auto outputValuesType = outputValues().getType().cast<ShapedType>();
  if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
    return op->emitOpError("expected input/output value types to be identical");
  }
  // Indices must be i32 if provided.
  // NOTE(review): the output indices element type is only checked when input
  // indices are present, although the scalar lowering always stores i32 into
  // it. Consider checking it unconditionally.
  auto outputIndicesType = outputIndices().getType().cast<ShapedType>();
  if (auto inputIndices = indices()) {
    auto inputIndicesType = inputIndices->getType().cast<ShapedType>();
    if (!inputIndicesType.getElementType().isInteger(32) ||
        !outputIndicesType.getElementType().isInteger(32)) {
      return op->emitOpError("expected input/output indices types to be int32");
    }
  }
  // Ranks must match.
  if (inputValuesType.getRank() != outputValuesType.getRank()) {
    return op->emitOpError("expected input/output to have the same rank");
  }
  if (auto inputIndices = indices()) {
    auto inputIndicesType = inputIndices->getType().cast<ShapedType>();
    if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
      return op->emitOpError("expected input/output to have the same rank");
    }
  }
  // Input indices and values must have the same shape.
  if (auto inputIndices = indices()) {
    auto inputIndicesType = inputIndices->getType().cast<ShapedType>();
    if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType)))
      return op->emitOpError("input indices/values shape must match");
  }
  // Output indices and values must have the same shape.
  if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType)))
    return op->emitOpError("output indices/values shape must match");
  // Input shape must match the output shape except for the dimension().
  uint64_t dim = getDimension();
  if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(),
                                              outputValuesType.getShape())),
                    [dim](auto e) {
                      if (e.index() == dim) {
                        return true;
                      }
                      std::tuple<int64_t, int64_t> s = e.value();
                      return succeeded(verifyCompatibleShape(std::get<0>(s),
                                                             std::get<1>(s)));
                    })) {
    return op->emitOpError("incompatible input/output shapes");
  }
  // Check region compatibility.
  Block &block = getRegion().front();
  if (block.getNumArguments() != 2) {
    return op->emitOpError("region block should have 2 arguments");
  }
  if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
      block.getArgument(1).getType() != inputValuesType.getElementType()) {
    return op->emitOpError("region block types must match input");
  }
  auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
  // Check the operand count before reading operand 0. Without this, a
  // `linalg_ext.yield` with no operands hits an out-of-range access inside the
  // verifier instead of producing this diagnostic.
  if (!terminatorOp || terminatorOp->getNumOperands() == 0 ||
      !terminatorOp->getOperand(0).getType().isInteger(1)) {
    return op->emitOpError("region block must end with a linalg_ext.yield i1!");
  }
  return success();
}
/// One unit-stride loop per dimension of the input values, spanning
/// [0, dim size).
SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zeroIdx = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value oneIdx = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = values();
  int64_t rank = getInputRank();
  SmallVector<Range> loopBounds;
  loopBounds.reserve(rank);
  for (int64_t d = 0; d < rank; ++d) {
    Value extent = getDimValue(builder, loc, source, d);
    loopBounds.push_back(Range{zeroIdx, extent, oneIdx});
  }
  return loopBounds;
}
/// All loops are parallel except the top-k dimension, which is reduced.
SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() {
  const int64_t kDim = getDimension();
  const int64_t rank = getInputRank();
  SmallVector<utils::IteratorType> iteratorKinds;
  iteratorKinds.reserve(rank);
  for (int64_t d = 0; d < rank; ++d) {
    iteratorKinds.push_back(d == kDim ? utils::IteratorType::reduction
                                      : utils::IteratorType::parallel);
  }
  return iteratorKinds;
}
// Emits the scalar body for the input element at `ivs`. An scf.for walks the
// top-k dimension of the outputs (extent taken from `outputValues()`). At
// each position it compares the carried (value, index) pair with the stored
// (kValue, kIndex) using the op's comparator region. The "winner" is stored
// back into the outputs and the other pair is carried to the next position.
// This presumably performs one insertion step into a top-k buffer that the
// caller has pre-initialized; TODO confirm. The emitted loads, stores and
// memref.dim mean this expects buffer (memref) operands.
LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc,
ValueRange ivs) {
uint64_t kDim = getDimension();
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs);
// If the indices tensor is not provided, the value index is derived from the
// loop induction variables.
Value initialIndex;
if (indices()) {
initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs);
} else {
Value rawInitialIndex = ivs[kDim];
initialIndex =
b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex);
}
// Compute K (ub) from the selected dim of the output
Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension());
// Inner K loop functions:
// Load current K value and index
// Compare N/K using inserted block compare
// Check if N == K using strict weak ordering, select which index came first
// Select new K value from N/K comparison
// Select new K index from N/K comparison or which index came first
// Store new k value and index
// Yield loop carry values after K selection
Value kValue, kIndex;
// The body builder only loads the current K entry into kValue/kIndex and does
// not terminate the body. The comparator ops, stores and the scf.yield are
// appended to the loop body in the scoped block further below. (The
// `loopCarryValues` parameter here is unused; the iter args are read from
// `scfFor` instead.)
auto scfFor = b.create<scf::ForOp>(
loc, zero, ub, one, ValueRange{initialValue, initialIndex},
[&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) {
SmallVector<Value> indices(ivs);
indices[kDim] = iv;
kValue = b.create<memref::LoadOp>(loc, outputValues(), indices);
kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices);
});
SmallVector<Value> indices(ivs);
indices[kDim] = scfFor.getInductionVar();
// loopCarryValues[0] is the carried value, loopCarryValues[1] its i32 index.
auto loopCarryValues = scfFor.getRegionIterArgs();
// Retrieve region as black box comparison function f(x,y). Plug into op.
auto &srcBlock = getRegion().front();
IRMapping bvmF; // f(x,y)
IRMapping bvmR; // f(y,x)
{
// Save previous insertion point. Continue within loop body.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToEnd(&scfFor.getRegion().front());
// Clone the comparator twice: once as f(carry, k) and once with the
// arguments swapped, f(k, carry).
SmallVector<Value> forwardValues{loopCarryValues[0], kValue};
SmallVector<Value> reverseValues{kValue, loopCarryValues[0]};
for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) {
bvmF.map(std::get<0>(it), std::get<1>(it));
}
for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) {
bvmR.map(std::get<0>(it), std::get<1>(it));
}
for (auto &blockOp : srcBlock.without_terminator()) {
b.clone(blockOp, bvmF);
b.clone(blockOp, bvmR);
}
Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0));
Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0));
// Check value equality using strictly weak ordering from the region:
// f(x,y) --> forwardCmpRes
// f(y,x) --> reverseCmpRes
// if forwardCmpRes == reverseCmpRes then select which came first
Value cmpValuesEqual = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes);
Value cmpFirstIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex);
Value combinedCmpEqRes =
b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex);
// True if N > K or N came before K
Value indexCmpRes =
b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes);
// Select results for K based on comparisons
// NOTE(review): the value select uses forwardCmpRes while the index select
// uses indexCmpRes. For comparator-equal values where the carried index is
// smaller, the stored index is the carried one but the stored value stays
// kValue. The two values compare equal under f, so this is presumably
// intended; confirm.
Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes,
loopCarryValues[0], kValue);
Value resultKIndex =
b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex);
b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices);
b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices);
// Select loop carry, opposite of K results
Value resultCarryValue = b.create<arith::SelectOp>(
loc, forwardCmpRes, kValue, loopCarryValues[0]);
Value resultCarryIndex =
b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]);
// Terminates the scf.for body (the body builder above did not).
b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex});
}
return success();
}
/// Tiles the top-k op.
/// - Input values (and optional input indices) are sliced at the requested
///   tile.
/// - Both outputs are sliced at the same offsets, but with the size along the
///   top-k dimension replaced by the full K extent of the outputs.
/// The op is then cloned onto the slices; it produces results only with
/// tensor semantics.
FailureOr<TilingResult>
TopkOp::getTiledImplementation(OpBuilder &builder,
                               ArrayRef<OpFoldResult> offsets,
                               ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputRank();
  assert(offsets.size() == static_cast<size_t>(rank) &&
         sizes.size() == static_cast<size_t>(rank));
  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
  Location loc = getLoc();
  // getResultTilePosition already replaces the top-k dimension's size with the
  // output's K extent. `outputSizes` is therefore used as-is, instead of
  // re-querying the extent and emitting a duplicate dim/constant op.
  SmallVector<OpFoldResult> outputOffsets, outputSizes;
  if (failed(getResultTilePosition(builder, 0, offsets, sizes, outputOffsets,
                                   outputSizes))) {
    return {};
  }
  SmallVector<Value> tiledOperands;
  tiledOperands.emplace_back(
      getSlice(builder, loc, values(), offsets, sizes, strides));
  if (indices()) {
    tiledOperands.emplace_back(
        getSlice(builder, loc, *indices(), offsets, sizes, strides));
  }
  tiledOperands.emplace_back(getSlice(builder, loc, getOutputs()[0],
                                      outputOffsets, outputSizes, strides));
  tiledOperands.emplace_back(getSlice(builder, loc, getOutputs()[1],
                                      outputOffsets, outputSizes, strides));
  // The two outputs are always the last two tiled operands.
  SmallVector<Type, 2> resultTypes;
  if (hasTensorSemantics()) {
    resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType());
    resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType());
  }
  Operation *tiledTopkOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledTopkOp},
                      SmallVector<Value>(tiledTopkOp->getResults())};
}
/// A result tile matches the iteration tile, except along the top-k dimension.
/// There it spans the full extent of the selected output operand.
LogicalResult TopkOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  int64_t kDim = getDimension();
  resultOffsets.assign(offsets.begin(), offsets.end());
  resultSizes.assign(sizes.begin(), sizes.end());
  Value kExtent =
      getDimValue(builder, getLoc(), getDpsInits()[resultNumber], kDim);
  resultSizes[kDim] = getAsOpFoldResult(kExtent);
  return success();
}
/// Delegates result-shape reification to the shared LinalgExtOp interface.
LogicalResult
TopkOp::reifyResultShapes(OpBuilder &b,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// PackOp and UnPackOp utils
//===----------------------------------------------------------------------===//
/// Returns true if any tile in `tiles` is the constant 0.
static bool hasZeros(ArrayRef<OpFoldResult> tiles) {
  for (OpFoldResult tile : tiles) {
    if (isConstantIntValue(tile, 0))
      return true;
  }
  return false;
}
/// Returns true if some statically-sized dimension is tiled by a constant tile
/// size that does not divide it evenly. Dynamic dimensions, untiled dimensions
/// and dynamic tile sizes are skipped, since nothing can be proven about them.
/// Callers must have rejected zero tile sizes beforehand.
static bool
areNotFullTiles(ArrayRef<int64_t> inputShape,
                DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
  return llvm::any_of(
      llvm::seq<int64_t>(0, inputShape.size()), [&](int64_t dim) {
        if (ShapedType::isDynamic(inputShape[dim]))
          return false;
        auto tileIt = dimAndTileMapping.find(dim);
        if (tileIt == dimAndTileMapping.end())
          return false;
        std::optional<int64_t> tileSize = getConstantIntValue(tileIt->second);
        return tileSize.has_value() && inputShape[dim] % *tileSize != 0;
      });
}
/// Returns the inner tile sizes of a Pack/UnPack op as OpFoldResults. Static
/// sizes become index attributes. Each kDynamic placeholder is filled, in
/// order, from the op's dynamic `inner_tiles` operands.
// TODO: interface or base class in .td
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTiles(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder attrBuilder(op.getContext());
  unsigned nextDynamic = 0;
  SmallVector<OpFoldResult> tiles;
  for (int64_t staticSize : op.getStaticInnerTiles()) {
    if (ShapedType::isDynamic(staticSize)) {
      tiles.push_back(op.getInnerTiles()[nextDynamic++]);
      continue;
    }
    tiles.push_back(attrBuilder.getIndexAttr(staticSize));
  }
  return tiles;
}
/// Returns the inner tile sizes as integers. Dynamic positions hold the
/// `kDynamic` sentinel, which comes from splitting the mixed tiles with
/// dispatchIndexOpFoldResults and keeping only the static half.
template <typename OpTy>
static SmallVector<int64_t> getStaticTiles(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  SmallVector<Value> unusedDynamic;
  SmallVector<int64_t> staticSizes;
  dispatchIndexOpFoldResults(mixedTiles, unusedDynamic, staticSizes);
  return staticSizes;
}
/// Maps each position in `inner_dims_pos` to its tile size. Pack and UnPack
/// share this helper.
// TODO: interface or base class in .td
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMapping(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ArrayRef<int64_t> blockedDims = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tileSizes = op.getMixedTiles();
  assert(tileSizes.size() == blockedDims.size() &&
         "tiles must match indices of dimension to block");
  DenseMap<int64_t, OpFoldResult> mapping;
  for (auto it : llvm::zip(blockedDims, tileSizes))
    mapping[std::get<0>(it)] = std::get<1>(it);
  return mapping;
}
/// Builds the iteration domain for a Pack or UnPack op. There is one
/// unit-stride loop per dimension of the unpacked shape (pack input, unpack
/// output). The upper bounds are the leading extents of the first reified
/// result shape.
template <typename OpTy>
static SmallVector<Range> getIterationDomain(OpTy op, OpBuilder &builder) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  OpBuilder::InsertionGuard guard(builder);
  Location loc = op.getLoc();
  int64_t numLoops = std::is_same<OpTy, PackOp>::value ? op.getInputRank()
                                                       : op.getOutputRank();
  Value lowerBound = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value step = builder.create<arith::ConstantIndexOp>(loc, 1);
  ReifiedRankedShapedTypeDims reifiedShapes;
  // NOTE(review): failure of reifyResultShapes is ignored; indexing
  // reifiedShapes[0] below then relies on it having been populated.
  (void)op.reifyResultShapes(builder, reifiedShapes);
  auto &extents = reifiedShapes[0];
  SmallVector<Range> loopBounds;
  loopBounds.reserve(numLoops);
  for (int64_t d = 0; d < numLoops; ++d)
    loopBounds.push_back(Range{lowerBound, extents[d], step});
  return loopBounds;
}
/// Common verifier for `PackOp` and `UnPackOp`.
/// "Unpacked" is the pack input / unpack output; "packed" is the other side.
/// Checks, in order:
/// - no zero tile sizes;
/// - inner_dims_pos / outer_dims_perm are valid;
/// - one tile per inner dim, and no more tiles than the unpacked rank;
/// - packed rank == unpacked rank + number of tiles;
/// - the packed shape is at least the expected packed shape;
/// - the trailing packed dims agree with the tile sizes.
/// Returns failure after emitting an error on the first violated check.
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
Operation *op = packOrUnPack.getOperation();
ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getInputType()
: packOrUnPack.getOutputType();
int64_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
// Verify tiles. Make sure each provided tile is non-zero.
SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
if (hasZeros(mixedTiles))
return op->emitError("invalid tile factor");
// `isInvalid` is defined elsewhere; presumably it rejects out-of-range or
// duplicated positions relative to the unpacked rank. TODO confirm.
if (isInvalid(innerDimsPos, unpackedRank))
return op->emitError("invalid inner_dims_pos vector");
if (isInvalid(outerDimPerm, unpackedRank))
return op->emitError("invalid outer_dims_perm vector");
// Blocking factors must match the number of `inner_dims_pos` entries.
if (mixedTiles.size() != innerDimsPos.size()) {
return op->emitError(
"blocking factors must equal the number of dimensions to block");
}
// Blocking factors must be less or equal than the unpacked rank.
// NOTE(review): compares the unsigned size() with the signed int64_t rank;
// fine for non-negative ranks.
if (mixedTiles.size() > unpackedRank) {
return op->emitError(
"blocking factors must be less or equal than the input rank");
}
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getOutputType()
: packOrUnPack.getInputType();
int64_t packedRank = packedType.getRank();
// Require packed rank to match unpacked rank + number of blocking factors.
// (The sum is computed in unsigned arithmetic because of size().)
if (unpackedRank + mixedTiles.size() != packedRank) {
return op->emitError(
"packed rank must equal unpacked rank + blocking factors");
}
// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
// `isSmallerThan` is defined elsewhere; presumably an element-wise <= that
// tolerates dynamic dims. TODO confirm.
ShapedType expectedPackedType = PackOp::getPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!isSmallerThan(expectedPackedType.getShape(), packedType.getShape())) {
return op->emitError("the shape of output is not large enough to hold the "
"packed data. Expected at least ")
<< expectedPackedType << ", got " << packedType;
}
// The last `mixedTiles.size()` packed dims are the point (inner tile) dims
// and must agree with the tile sizes.
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
[](std::tuple<int64_t, OpFoldResult> it) {
std::optional<int64_t> constTileSize =
getConstantIntValue(std::get<1>(it));
int64_t shape = std::get<0>(it);
if (!constTileSize) {
// If specified tile size is dynamic, output shape should
// be dynamic too.
return shape == ShapedType::kDynamic;
} else {
if (shape == ShapedType::kDynamic) {
// For the shape being dynamic when tile size is
// specified, return true. In canonical form a constant
// tile size should lead to constant shape of the tiled
// dimension, but not needed for verification.
return true;
}
return shape == constTileSize.value();
}
})) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
return success();
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
/// Custom builder methods for pack ops.
/// `innerTiles` may mix static and dynamic sizes; they are split into static
/// sizes (with kDynamic placeholders) and SSA operands. The op gets a result
/// type only when `output` is a ranked tensor. An empty `outerDimsPerm` is
/// encoded as an absent attribute, and a missing `paddingValue` as a null
/// operand.
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value output, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);

  Type destType = output.getType();
  SmallVector<Type> resultType;
  if (destType.isa<RankedTensorType>())
    resultType.push_back(destType);

  DenseI64ArrayAttr outerDimsPermAttr;
  if (!outerDimsPerm.empty())
    outerDimsPermAttr = builder.getDenseI64ArrayAttr(outerDimsPerm);
  Value paddingOperand = paddingValue.value_or(Value());

  build(builder, state, resultType, source, output, outerDimsPermAttr,
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes), paddingOperand);
}
/// Runs the shared pack/unpack checks, then the pack-specific ones:
/// - With a padding value, its type must be the input element type (partial
///   tiles are then allowed).
/// - Without one, a constant tile that does not evenly divide a static
///   dimension is rejected.
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  if (auto paddingValue = getPaddingValue()) {
    Type expectedType = getInputType().getElementType();
    if (paddingValue.getType() != expectedType) {
      return emitOpError("expected padding_value has ")
             << expectedType << " but got: " << paddingValue.getType();
    }
    return success();
  }

  // In the case of dynamic tile factors or dimensions, having a partial tile
  // is undefined behavior; only statically provable mismatches are reported.
  if (areNotFullTiles(getInputShape(), getDimAndTileMapping())) {
    return emitOpError("invalid tile factor provided. Only full tiles are "
                       "supported when padding_value is not set");
  }
  return success();
}
/// Inner tile sizes as OpFoldResults (attributes for static, values for
/// dynamic).
SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return ::getMixedTiles<PackOp>(*this);
}
/// Inner tile sizes as integers, with kDynamic for dynamic tiles.
SmallVector<int64_t> PackOp::getStaticTiles() {
  return ::getStaticTiles<PackOp>(*this);
}
// Helper for PackOp::{getResultShape,getPackedType}. Returns the static shape
// of the packed type, so that both methods agree on which dimensions are
// dynamic.
// - Each tiled outer dim becomes ceilDiv(dim, tile). It becomes dynamic when
//   the tile is dynamic, and stays dynamic when the source dim already was.
// - The outer dims are then permuted by `outerDimsPerm`.
// - Finally the inner tile sizes are appended.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> shape(sourceShape.begin(), sourceShape.end());
  for (size_t i = 0, e = innerDimsPos.size(); i < e; ++i) {
    int64_t &outerExtent = shape[innerDimsPos[i]];
    if (ShapedType::isDynamic(outerExtent))
      continue;
    int64_t tile = innerTileSizes[i];
    outerExtent = ShapedType::isDynamic(tile) ? ShapedType::kDynamic
                                              : ceilDiv(outerExtent, tile);
  }
  shape = interchange<int64_t>(shape, outerDimsPerm, /*offset=*/0);
  shape.append(innerTileSizes.begin(), innerTileSizes.end());
  return shape;
}
/// Computes the packed shape as OpFoldResults from the source dims and tile
/// sizes:
/// - each tiled outer dim becomes ceilDiv(dim, tile), folded where possible;
/// - the outer dims are permuted by `outerDimsPerm` when it is non-empty;
/// - the inner tile sizes are appended.
/// Entries that getPackOpResultTypeShape reports as dynamic are forced to be
/// Values.
SmallVector<OpFoldResult> PackOp::getResultShape(
OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
// s0 = dimension extent, s1 = tile size.
AffineExpr s0, s1;
bindSymbols(builder.getContext(), s0, s1);
AffineExpr ceilDivExpr = s0.ceilDiv(s1);
for (auto tiledDim : llvm::enumerate(innerDimsPos)) {
resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
builder, loc, ceilDivExpr,
{resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
}
if (!outerDimsPerm.empty()) {
resultDims =
interchange<OpFoldResult>(resultDims, outerDimsPerm, /*offset=*/0);
}
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
SmallVector<int64_t> resultTypeShape =
getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
asShapeWithAnyValueAsDynamic(innerTileSizes),
innerDimsPos, outerDimsPerm);
// Fix-up `resultDims` to ensure that they are Value's if and only if the
// result type shape says it's a dynamic dim. This is needed as callers may
// use dispatchIndexOpFoldResults on the result, and rely on exact number of
// dynamic dims returned by that.
for (unsigned i = 0; i < resultDims.size(); ++i) {
if (!ShapedType::isDynamic(resultTypeShape[i]))
continue;
resultDims[i] =
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
}
return resultDims;
}
/// Returns the sizes of the destination operand as OpFoldResults.
SmallVector<OpFoldResult> PackOp::getResultShape(OpBuilder &builder) {
  Value dest = getOutput();
  return tensor::getMixedSizes(builder, getLoc(), dest);
}
/// Returns the packed counterpart of `sourceType`: a ranked tensor or memref
/// of the same element type with the packed static shape. Other shaped types
/// trigger an assertion and yield a null type.
ShapedType PackOp::getPackedType(ShapedType sourceType,
                                 ArrayRef<int64_t> innerTileSizes,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> packedShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  Type elementType = sourceType.getElementType();
  if (isa<RankedTensorType>(sourceType))
    return RankedTensorType::get(packedShape, elementType);
  if (isa<MemRefType>(sourceType))
    return MemRefType::get(packedShape, elementType);
  assert(false && "unexpected type");
  return nullptr;
}
/// Maps each blocked dimension position to its tile size.
DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return ::getDimAndTileMapping<PackOp>(*this);
}
/// Iteration domain shared with UnPackOp; see the file-local helper.
SmallVector<Range> PackOp::getIterationDomain(OpBuilder &builder) {
  return ::getIterationDomain<PackOp>(*this, builder);
}
/// Generate the body of the innermost loop of the scalar implementation
/// of `pack` operation.
///
/// `ivs` is the full list of induction variables used to index the packed
/// output. The final `memref.store` writes to `packOp.getOutput()` at `ivs`.
/// The steps are:
///  1. Permute `ivs` back to canonical blocking order.
///  2. Compute one source index per input dim.
///  3. Load the source element. When a padding value is present, the load is
///     guarded and the padding value is used if any source index is not less
///     than the corresponding source dim.
///  4. Store the element into the output.
/// Uses memref load/store ops, so the operands are expected to be memrefs
/// here.
static void generatePackOpScalarImplementationBody(PackOp packOp,
                                                   OpBuilder &builder,
                                                   Location loc,
                                                   ValueRange ivs) {
  // Note: `ivs` are already in the correct order, possibly interchanged based
  // on `dims_pos`. However, connecting the loops with the access patterns is
  // difficult - What is the relation between the position of the tile loop and
  // the point loop? However, if we interchange `ivs` once more to go to the
  // canonical blocking format: ABCabc, this connection becomes trivial: Each
  // point loop is pointLoopsOffset + inputRank away from the tiled loop.
  ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
  ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
  SmallVector<Value> interchangedIvs = ivs;
  // Undo the inner_dims_pos interchange on the point loops (the ivs after the
  // first inputRank ones).
  SmallVector<int64_t> interchangeVector =
      computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getInputRank());
  interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
                                       /*offset=*/packOp.getInputRank());
  // Undo the outer_dims_perm interchange on the outer (tile) loops.
  if (!dimsToOuterBlock.empty()) {
    interchangeVector =
        computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getInputRank());
    interchangedIvs =
        interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
  }
  // NOTE(review): `tiles` is not used below.
  SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  SmallVector<OpFoldResult> sourceIndices;
  // Number of tiled dims seen so far; selects the matching point loop, which
  // lives after the first inputRank ivs.
  size_t pointLoopsOffset = 0;
  int64_t inputRank = packOp.getInputRank();
  for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
    if (dimAndTileMapping.count(dim)) {
      // Tiled dim: source index = tileLoopIv * tileSize + pointLoopIv.
      AffineExpr i, j, tile;
      bindDims(builder.getContext(), i, j);
      bindSymbols(builder.getContext(), tile);
      OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
          builder, loc, i * tile + j,
          ArrayRef<OpFoldResult>{
              interchangedIvs[dim],
              interchangedIvs[pointLoopsOffset + packOp.getInputRank()],
              dimAndTileMapping[dim]});
      sourceIndices.push_back(sourceIndex);
      ++pointLoopsOffset;
    } else {
      // Untiled dim: the outer iv indexes the source directly.
      sourceIndices.push_back(interchangedIvs[dim]);
    }
  }
  // Loads the source element at `sourceIndices` at the builder's current
  // insertion point. Inside the scf.if callbacks below, `b` is presumably the
  // same builder repositioned into the region (see scf::IfOp::build), so the
  // load lands in the then-block. NOTE(review): confirm.
  auto createLoad = [&]() -> Value {
    return builder.create<memref::LoadOp>(
        loc, packOp.getInput(),
        getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
  };
  Value scalar;
  if (auto paddingValue = packOp.getPaddingValue()) {
    ArithBuilder arithBuilder(builder, loc);
    // Conjunction over all source dims of `index < dim(input, d)`. Only the
    // upper bound is checked.
    // NOTE(review): for a rank-0 input `isInBounds` would stay null.
    Value isInBounds;
    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
      Value idx =
          getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
      Value cond = arithBuilder.slt(
          idx, getDimValue(builder, loc, packOp.getInput(), dim));
      isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
    }
    // In bounds: load from the source. Otherwise: yield the padding value.
    scalar = builder
                 .create<scf::IfOp>(
                     loc, isInBounds, /*thenBuilder=*/
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, createLoad());
                     },
                     /*elseBuilder=*/
                     [&](OpBuilder &b, Location l) {
                       b.create<scf::YieldOp>(l, paddingValue);
                     })
                 .getResult(0);
  } else {
    scalar = createLoad();
  }
  builder.create<memref::StoreOp>(loc, scalar, packOp.getOutput(), ivs);
}
/// Emits the scalar (loop-based) implementation of the pack for one position
/// of the outer loops.
///
/// `ivs` index the non-data-tile dims of the output. This function:
///  - reifies the output shape;
///  - builds one `scf.for` (from 0 to the reified size, step 1) per output
///    dim in [getInputRank(), getOutputRank());
///  - calls generatePackOpScalarImplementationBody in the innermost loop with
///    `ivs` followed by the new induction variables.
/// The builder's insertion point is restored on return.
///
/// NOTE(review): this appears to assume the output has at least one data-tile
/// dim (getOutputRank() > getInputRank()). Confirm the verifier guarantees it.
LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder,
                                                   Location loc,
                                                   ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  // The `ivs` already represent the position into the output tensor for the
  // non data-tile dimensions.
  SmallVector<Value> ivVec = llvm::to_vector(ivs);
  ReifiedRankedShapedTypeDims outputShape;
  if (failed(reifyResultShapes(builder, outputShape)))
    return getOperation()->emitOpError("failed to reify result shape");
  if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
    // Trailing space keeps the rank from being glued onto the word "rank".
    return getOperation()->emitOpError(
               "expected shape of one result value of rank ")
           << getOutputRank();
  }
  // Generate the loops that iterate over the data tile.
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  // All loops except the innermost are simple loops that just iterate
  // over the tile dimensions.
  for (auto dataTileDim :
       llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) {
    Value ub = getValueOrCreateConstantIndexOp(builder, loc,
                                               outputShape[0][dataTileDim]);
    scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one);
    builder.setInsertionPointToStart(loop.getBody());
    ivVec.push_back(loop.getInductionVar());
  }
  // The body of the innermost loops does the actual data movement.
  builder.create<scf::ForOp>(
      loc, zero,
      getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()), one,
      ValueRange{},
      [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
          ValueRange regionIterArgs) {
        ivVec.push_back(iv);
        generatePackOpScalarImplementationBody(*this, bodyBuilder, bodyLoc,
                                               ivVec);
        bodyBuilder.create<scf::YieldOp>(bodyLoc);
      });
  return success();
}
/// Reifies the result shape by forwarding to the LinalgExtOp interface
/// implementation, as the other LinalgExt ops in this file do.
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(builder, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
/// Custom builder for unpack ops.
///
/// `innerTiles` is split into static tile sizes (as an attribute) and dynamic
/// tile-size values. An empty `outerDimsPerm` is encoded as a null attribute.
/// The op gets a result type only when `output` is a ranked tensor; otherwise
/// no result type is added.
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value output, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> staticTiles;
  SmallVector<Value> dynamicTiles;
  dispatchIndexOpFoldResults(innerTiles, dynamicTiles, staticTiles);

  SmallVector<Type> resultTypes;
  Type destType = output.getType();
  if (destType.isa<RankedTensorType>())
    resultTypes.push_back(destType);

  DenseI64ArrayAttr outerPermAttr;
  if (!outerDimsPerm.empty())
    outerPermAttr = builder.getDenseI64ArrayAttr(outerDimsPerm);
  DenseI64ArrayAttr innerPosAttr = builder.getDenseI64ArrayAttr(innerDimsPos);
  DenseI64ArrayAttr staticTilesAttr = builder.getDenseI64ArrayAttr(staticTiles);

  build(builder, state, resultTypes, source, output, outerPermAttr,
        innerPosAttr, dynamicTiles, staticTilesAttr);
}
/// Returns the inner tile sizes as OpFoldResults, merging static and dynamic
/// tiles via the file-level helper.
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  SmallVector<OpFoldResult> mixedTiles = ::getMixedTiles(*this);
  return mixedTiles;
}
/// Returns the inner tile sizes as integers, computed by the file-level
/// helper.
SmallVector<int64_t> UnPackOp::getStaticTiles() {
  SmallVector<int64_t> staticTiles = ::getStaticTiles(*this);
  return staticTiles;
}
/// Returns a map from each tiled dimension to its tile size. Computed by the
/// file-level `::getDimAndTileMapping` helper, which PackOp also uses.
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  DenseMap<int64_t, OpFoldResult> tileSizeByDim = ::getDimAndTileMapping(*this);
  return tileSizeByDim;
}
/// Emits the scalar implementation of the unpack for the output position
/// `ivs` (one iv per output dim, asserted below).
///
/// Steps:
///  - For each tiled output dim, split its iv into quotient (tile index) and
///    remainder (point index) by the tile size.
///  - Permute the point indices by `inner_dims_pos`.
///  - Permute the outer indices by `outer_dims_perm` (when present).
///  - Load from the input at [outer..., point...] and store into the output
///    at `ivs` using memref load/store ops.
LogicalResult UnPackOp::generateScalarImplementation(OpBuilder &builder,
                                                     Location loc,
                                                     ValueRange ivs) {
  assert(ivs.size() == getOutputRank() &&
         "number of ivs must match the rank of the output tensor");
  OpBuilder::InsertionGuard g(builder);
  // The reified shape is only used to sanity-check the result rank.
  ReifiedRankedShapedTypeDims outputShape;
  if (failed(reifyResultShapes(builder, outputShape)))
    return getOperation()->emitOpError("failed to reify result shape");
  if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
    // Trailing space keeps the rank from being glued onto the word "rank".
    return getOperation()->emitOpError(
               "expected shape of one result value of rank ")
           << getOutputRank();
  }
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping();
  // untiled loops and tile loops induction variables.
  SmallVector<Value> inputIvs;
  // point loops induction variables.
  SmallVector<Value> inputIvsPointLoops;
  inputIvs.reserve(getOutputRank());
  inputIvsPointLoops.reserve(dimAndTileMapping.size());
  for (auto dim : llvm::seq<int64_t>(0, getOutputRank())) {
    if (dimAndTileMapping.count(dim)) {
      // iv = quotient * tileSize + remainder.
      affine::DivModValue divMod =
          affine::getDivMod(builder, loc, ivs[dim],
                            getValueOrCreateConstantIndexOp(
                                builder, loc, dimAndTileMapping[dim]));
      inputIvsPointLoops.push_back(divMod.remainder);
      inputIvs.push_back(divMod.quotient);
    } else {
      inputIvs.push_back(ivs[dim]);
    }
  }
  // TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
  // `inputIvsPointLoops` and `inputIvs`.
  assert(inputIvsPointLoops.size() + inputIvs.size() == getInputRank() &&
         "expect same number of iduction variables equals to input rank");
  // interchange the point loops induction variables based on `inner_dim_pos`.
  ArrayRef<int64_t> innerDims = getInnerDimsPos();
  SmallVector<int64_t> interchangeVector =
      computeInterchangeFromDimPos(innerDims, getOutputRank());
  SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
  interchangedInputIvsPointLoops = interchange<Value>(
      interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0);
  // interchange the tiled loops induction variables based on `outer_dims_perm`.
  ArrayRef<int64_t> outerDims = getOuterDimsPerm();
  if (!outerDims.empty()) {
    inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
  }
  llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
  Value scalar = builder.create<memref::LoadOp>(loc, getInput(), inputIvs);
  builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs);
  return success();
}
/// Reifies the result shape by forwarding to the LinalgExtOp interface
/// implementation.
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(builder, reifiedReturnShapes);
}
/// Returns the loop ranges of the unpack's iteration domain. Computed by the
/// file-level `::getIterationDomain` helper, which PackOp also uses.
SmallVector<Range> UnPackOp::getIterationDomain(OpBuilder &builder) {
  SmallVector<Range> domain = ::getIterationDomain(*this, builder);
  return domain;
}
/// Verifies the unpack op using only the common pack/unpack verifier.
LogicalResult UnPackOp::verify() {
  return failed(commonVerifierPackAndUnPackOp(*this)) ? failure() : success();
}
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
/// Verifies the winograd input transform op.
///
/// Checks:
///  - Exactly one input and one output (init) operand, with identical
///    element types.
///  - Either input rank 2 (the output must then be rank 2; nothing else is
///    checked), or input rank 4 with output rank == input rank + 2.
///  - Exactly two image dims, in an NCHW or NHWC layout.
///  - The output shape is compatible with the shape derived from the input:
///    two leading dims of size inputTileSize, then the input dims. Each
///    static image dim `d` becomes ceilDiv(d - kernelSize + 1,
///    outputTileSize). For NCHW inputs the result is permuted
///    TTNCHW -> TTNHWC.
LogicalResult WinogradInputTransformOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs() != 1) {
    return op->emitOpError("expected one input operand");
  }
  if (getNumDpsInits() != 1) {
    return op->emitOpError("expected one output operand");
  }
  auto inputType = input().getType().cast<ShapedType>();
  auto outputType = output().getType().cast<ShapedType>();
  if (outputType.getElementType() != inputType.getElementType()) {
    return op->emitOpError(
        "expected input/output element types to be identical");
  }
  unsigned inputRank = inputType.getRank();
  unsigned outputRank = outputType.getRank();
  if (inputRank != 2 && inputRank != 4) {
    return op->emitOpError("expected input operand to have rank either 2 or 4");
  }
  if (inputRank == 2) {
    if (outputRank != 2) {
      return op->emitOpError(
          "expected output operand to have rank 2 if input is of rank 2");
    }
    return success();
  }
  if (getOutputOperandRank() != getInputOperandRank() + 2) {
    return op->emitOpError(
        "expected output rank to be equal to input rank + 2");
  }
  const SmallVector<int64_t> imageDims = imageDimensions();
  const int64_t numImageDims = imageDims.size();
  llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
                                                imageDims.end());
  if (imageDims.size() != 2) {
    return op->emitOpError("expected only 2 image dimensions");
  }
  if (!isNchw() && !isNhwc()) {
    return op->emitOpError(
        "expect image dimensions to be either [1, 2] or [2, 3]");
  }
  const int64_t outputTileSize = getOutputTileSize();
  const int64_t kernelSize = getKernelSize();
  const int64_t inputTileSize = getInputTileSize();
  SmallVector<int64_t> expectedOutputShape(getOutputOperandRank(),
                                           inputTileSize);
  ArrayRef<int64_t> inputShape = inputType.getShape();
  for (int64_t i = 0, e = inputShape.size(); i < e; ++i) {
    // The output has two extra leading (tile) dims.
    const int64_t outputIndex = i + numImageDims;
    const int64_t inputDim = inputShape[i];
    // Dynamic dims and non-image dims are carried over unchanged.
    if (ShapedType::isDynamic(inputDim) || !imageDimsSet.contains(i)) {
      expectedOutputShape[outputIndex] = inputDim;
      continue;
    }
    // Number of output tiles along a static image dim. Integer
    // ceil-division avoids the precision loss of rounding through `float`.
    expectedOutputShape[outputIndex] =
        mlir::ceilDiv(inputDim - kernelSize + 1, outputTileSize);
  }
  if (isNchw()) {
    permute<Permutation::TTNCHW_TO_TTNHWC>(expectedOutputShape);
  }
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return op->emitOpError("incompatible output shape");
  }
  return success();
}
/// Returns the iteration domain of the input transform. It has one
/// [0, dim(input, d)) range with unit stride for every non-image input dim
/// `d`, in dim order.
SmallVector<Range>
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = input();
  SmallVector<int64_t> imageDims = imageDimensions();
  llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
                                                imageDims.end());
  // Grow the result by the non-image dims actually visited. Pre-sizing it by
  // the number of *image* dims is only correct when the two counts happen to
  // match (rank-4 input). Otherwise it writes out of bounds or leaves
  // default-constructed Ranges.
  SmallVector<Range> loopBounds;
  for (auto dim : llvm::seq<int64_t>(0, getInputOperandRank())) {
    if (imageDimsSet.contains(dim))
      continue;
    Range bound;
    bound.offset = zero;
    bound.size = getDimValue(builder, loc, source, dim);
    bound.stride = one;
    loopBounds.push_back(bound);
  }
  return loopBounds;
}
/// Marks every loop of the input transform's iteration domain as parallel.
SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(getIterationDomainRank(),
                                          utils::IteratorType::parallel);
}
/// Tiles the input transform along its two iteration dims.
///
/// offsets/sizes[0] apply to input dim 0 (N in the NHWC/NCHW layouts accepted
/// by verify()). offsets/sizes[1] apply to the input channel dim
/// `channelDim()`. The output appears to be laid out TTNHWC (see the
/// TTNCHW_TO_TTNHWC permutation in verify()), so the same offsets/sizes are
/// applied to output dims 2 and 5. Every other dim takes offset 0 and its full
/// static extent. The op is cloned on the input/output slices. With tensor
/// semantics the clone's result type is the output slice type.
///
/// NOTE(review): untiled sizes come from the *static* shape via
/// getIndexArrayAttr. A dynamic dim would be materialized as the kDynamic
/// sentinel, so this presumably assumes static shapes. Confirm.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  auto one = builder.getIndexAttr(1);
  auto zero = builder.getIndexAttr(0);
  const int cDim = channelDim();
  assert(offsets.size() == 2);
  SmallVector<OpFoldResult> inputOffsets(getInputOperandRank(), zero);
  SmallVector<OpFoldResult> outputOffsets(getOutputOperandRank(), zero);
  outputOffsets[2] = inputOffsets[0] = offsets[0];
  outputOffsets[5] = inputOffsets[cDim] = offsets[1];
  SmallVector<OpFoldResult> inputStrides(getInputOperandRank(), one);
  SmallVector<OpFoldResult> outputStrides(getOutputOperandRank(), one);
  assert(sizes.size() == 2);
  auto inputShape = input().getType().cast<ShapedType>().getShape();
  auto outputShape = output().getType().cast<ShapedType>().getShape();
  SmallVector<OpFoldResult> inputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(inputShape));
  SmallVector<OpFoldResult> outputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
  outputSizes[2] = inputSizes[0] = sizes[0];
  outputSizes[5] = inputSizes[cDim] = sizes[1];
  SmallVector<Value> tiledOperands;
  tiledOperands.emplace_back(
      getSlice(builder, loc, input(), inputOffsets, inputSizes, inputStrides));
  tiledOperands.emplace_back(getSlice(builder, loc, output(), outputOffsets,
                                      outputSizes, outputStrides));
  // Only tensor-semantics ops produce a result (the tiled output slice type).
  SmallVector<Type, 4> resultTypes;
  if (hasTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
/// Maps an iteration-space tile to the corresponding tile of result 0.
///
/// Output dim 2 takes offsets/sizes[0] and output dim 5 takes
/// offsets/sizes[1]. Every other dim starts at 0 and spans its full static
/// extent. Any result number other than 0 yields failure, and the out-params
/// are left untouched.
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber != 0)
    return failure();
  ArrayRef<int64_t> outputShape =
      output().getType().cast<ShapedType>().getShape();
  OpFoldResult zeroAttr = builder.getIndexAttr(0);
  resultOffsets = SmallVector<OpFoldResult>(getOutputOperandRank(), zeroAttr);
  resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
  resultOffsets[2] = offsets[0];
  resultSizes[2] = sizes[0];
  resultOffsets[5] = offsets[1];
  resultSizes[5] = sizes[1];
  return success();
}
/// Folding delegates to memref::foldMemRefCast. Per its MLIR documentation,
/// that helper folds memref.cast producers into this op's operands.
LogicalResult WinogradInputTransformOp::fold(FoldAdaptor,
                                             SmallVectorImpl<OpFoldResult> &) {
  Operation *self = getOperation();
  return memref::foldMemRefCast(self);
}
/// Reifies the result shape by forwarding to the LinalgExtOp interface
/// implementation.
LogicalResult WinogradInputTransformOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(builder, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
/// Verifies the winograd output transform op.
///
/// Checks:
///  - Exactly one input and one output (init) operand, with identical
///    element types. This is checked for every rank, consistent with
///    WinogradInputTransformOp::verify.
///  - Either input rank 2 (the output must then be rank 2; nothing else is
///    checked), or input rank 6 with output rank == input rank - 2.
///  - Exactly two image dims, in an NCHW or NHWC layout.
///  - The output shape is compatible with the shape derived from the input:
///    drop the two leading tile dims, then scale each static image dim by
///    the output tile size. For NCHW the input is permuted TTNHWC -> TTNCHW
///    first.
LogicalResult WinogradOutputTransformOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs() != 1) {
    return op->emitOpError("expected one input operand");
  }
  if (getNumDpsInits() != 1) {
    return op->emitOpError("expected one output operand");
  }
  auto inputType = input().getType().cast<ShapedType>();
  auto outputType = output().getType().cast<ShapedType>();
  // Check element types before the rank-2 early exit so that form is
  // covered too.
  if (outputType.getElementType() != inputType.getElementType()) {
    return op->emitOpError(
        "expected input/output element types to be identical");
  }
  unsigned inputRank = inputType.getRank();
  unsigned outputRank = outputType.getRank();
  if (inputRank != 2 && inputRank != 6) {
    return op->emitOpError("expected input operand to have rank either 2 or 6");
  }
  if (inputRank == 2) {
    if (outputRank != 2) {
      return op->emitOpError(
          "expected output operand to have rank 2 if input is of rank 2");
    }
    return success();
  }
  if (outputRank != inputRank - 2) {
    return op->emitOpError(
        "expected output rank to be equal to input rank - 2");
  }
  const SmallVector<int64_t> imageDims = imageDimensions();
  const int64_t numImageDims = imageDims.size();
  llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
                                                imageDims.end());
  if (imageDims.size() != 2) {
    return op->emitOpError("expected only 2 image dimensions");
  }
  if (!isNchw() && !isNhwc()) {
    return op->emitOpError(
        "expect image dimensions to be either [1, 2] or [2, 3]");
  }
  SmallVector<int64_t> inputShape(inputType.getShape());
  if (isNchw()) {
    permute<Permutation::TTNHWC_TO_TTNCHW>(inputShape);
  }
  const int64_t outputTileSize = getOutputTileSize();
  SmallVector<int64_t> expectedOutputShape(getOutputOperandRank(), 1);
  for (int64_t i = numImageDims, e = inputShape.size(); i < e; ++i) {
    // Skip the two leading tile dims of the input.
    const int64_t outputIndex = i - numImageDims;
    const int64_t inputDim = inputShape[i];
    // Non-image dims are carried over. Dynamic image dims must stay dynamic:
    // multiplying the kDynamic sentinel would overflow (UB) and produce a
    // bogus static size.
    if (!imageDimsSet.contains(outputIndex) ||
        ShapedType::isDynamic(inputDim)) {
      expectedOutputShape[outputIndex] = inputDim;
      continue;
    }
    expectedOutputShape[outputIndex] = outputTileSize * inputDim;
  }
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return op->emitOpError("incompatible output shape");
  }
  return success();
}
/// Returns the iteration domain of the output transform. It has one
/// [0, dim(output, d)) range with unit stride for every non-image output dim
/// `d`, in dim order.
SmallVector<Range>
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = output();
  SmallVector<int64_t> imageDims = imageDimensions();
  llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
                                                imageDims.end());
  // Grow the result by the non-image dims actually visited. Pre-sizing it by
  // the number of *image* dims is only correct when the two counts happen to
  // match (rank-4 output). Otherwise it writes out of bounds or leaves
  // default-constructed Ranges.
  SmallVector<Range> loopBounds;
  for (auto dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (imageDimsSet.contains(dim))
      continue;
    Range bound;
    bound.offset = zero;
    bound.size = getDimValue(builder, loc, source, dim);
    bound.stride = one;
    loopBounds.push_back(bound);
  }
  return loopBounds;
}
/// Marks every loop of the output transform's iteration domain as parallel.
SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(getIterationDomainRank(),
                                          utils::IteratorType::parallel);
}
/// Tiles the output transform along its two iteration dims.
///
/// offsets/sizes[0] apply to output dim 0. offsets/sizes[1] apply to the
/// output channel dim `channelDim()`. On the rank-6 input (two leading tile
/// dims, TTNHWC per verify()), the same offsets/sizes go to input dims 2 and
/// 5. Every other dim takes offset 0 and its full static extent. The op is
/// cloned on the input/output slices. With tensor semantics the clone's
/// result type is the output slice type.
///
/// NOTE(review): untiled sizes come from the *static* shape via
/// getIndexArrayAttr. A dynamic dim would be materialized as the kDynamic
/// sentinel, so this presumably assumes static shapes. Confirm.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  auto one = builder.getIndexAttr(1);
  auto zero = builder.getIndexAttr(0);
  const int cDim = channelDim();
  assert(offsets.size() == 2);
  SmallVector<OpFoldResult> inputOffsets(getInputOperandRank(), zero);
  SmallVector<OpFoldResult> outputOffsets(getOutputOperandRank(), zero);
  inputOffsets[2] = outputOffsets[0] = offsets[0];
  inputOffsets[5] = outputOffsets[cDim] = offsets[1];
  SmallVector<OpFoldResult> inputStrides(getInputOperandRank(), one);
  SmallVector<OpFoldResult> outputStrides(getOutputOperandRank(), one);
  assert(sizes.size() == 2);
  auto inputShape = input().getType().cast<ShapedType>().getShape();
  auto outputShape = output().getType().cast<ShapedType>().getShape();
  SmallVector<OpFoldResult> inputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(inputShape));
  SmallVector<OpFoldResult> outputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
  inputSizes[2] = outputSizes[0] = sizes[0];
  inputSizes[5] = outputSizes[cDim] = sizes[1];
  SmallVector<Value> tiledOperands;
  tiledOperands.emplace_back(
      getSlice(builder, loc, input(), inputOffsets, inputSizes, inputStrides));
  tiledOperands.emplace_back(getSlice(builder, loc, output(), outputOffsets,
                                      outputSizes, outputStrides));
  // Only tensor-semantics ops produce a result (the tiled output slice type).
  SmallVector<Type, 4> resultTypes;
  if (hasTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
  }
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
/// Maps an iteration-space tile to the corresponding tile of result 0.
///
/// Output dim 0 takes offsets/sizes[0] and the channel dim `channelDim()`
/// takes offsets/sizes[1]. Every other dim starts at 0 and spans its full
/// static extent. Any result number other than 0 yields failure, and the
/// out-params are left untouched.
LogicalResult WinogradOutputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber != 0)
    return failure();
  ArrayRef<int64_t> outputShape =
      output().getType().cast<ShapedType>().getShape();
  resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
  OpFoldResult zeroAttr = builder.getIndexAttr(0);
  resultOffsets = SmallVector<OpFoldResult>(getOutputOperandRank(), zeroAttr);
  const int channel = channelDim();
  resultOffsets[0] = offsets[0];
  resultSizes[0] = sizes[0];
  resultOffsets[channel] = offsets[1];
  resultSizes[channel] = sizes[1];
  return success();
}
/// Folding delegates to memref::foldMemRefCast. Per its MLIR documentation,
/// that helper folds memref.cast producers into this op's operands.
LogicalResult WinogradOutputTransformOp::fold(FoldAdaptor,
                                              SmallVectorImpl<OpFoldResult> &) {
  Operation *self = getOperation();
  return memref::foldMemRefCast(self);
}
/// Reifies the result shape by forwarding to the LinalgExtOp interface
/// implementation.
LogicalResult WinogradOutputTransformOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(builder, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//
/// Checks that `shapedType` has rank `rankToCompareWith`.
/// On mismatch, emits "expected <operandName> to have rank <expected> but
/// found <actual>" on `op` and returns failure. Otherwise returns success.
static LogicalResult checkShapeRank(Operation *op, StringRef operandName,
                                    ShapedType shapedType,
                                    unsigned rankToCompareWith) {
  const unsigned actualRank = shapedType.getRank();
  if (actualRank == rankToCompareWith)
    return success();
  return op->emitOpError("expected ")
         << operandName << " to have rank " << rankToCompareWith
         << " but found " << actualRank;
}
/// Verifies the attention op. Two forms are accepted.
///
/// 4 operands: query, key, value and output are all rank 3. The output
/// element type must equal the query element type, and the query/key dim 2
/// must agree.
///
/// 6 operands: query, key, value and output are rank 2, plus rank-1 max and
/// sum operands. Output, max and sum share an element type. Max and sum have
/// compatible shapes, and max dim 0 agrees with query dim 0.
///
/// In both forms:
///  - query, key and value share an element type;
///  - key and value shapes are compatible;
///  - query and output shapes are compatible.
///
/// Individual dims are compared only when both are static, matching how
/// verifyCompatibleShape treats dynamic dims.
LogicalResult AttentionOp::verify() {
  Operation *op = getOperation();
  unsigned numOperands = getNumOperands();
  unsigned rankToCompareWith = 3;
  if (numOperands == 6)
    rankToCompareWith = 2;
  else if (numOperands != 4)
    return op->emitOpError("expected operand count 4 or 6, but got ")
           << numOperands;
  // Two dims conflict only when both are static and differ.
  auto staticDimsMismatch = [](int64_t lhs, int64_t rhs) {
    return !ShapedType::isDynamic(lhs) && !ShapedType::isDynamic(rhs) &&
           lhs != rhs;
  };
  ShapedType queryType = getQueryType();
  ShapedType keyType = getKeyType();
  ShapedType valueType = getValueType();
  ShapedType outputType = getOutputType();
  Type queryElementType = queryType.getElementType();
  Type keyElementType = keyType.getElementType();
  Type valueElementType = valueType.getElementType();
  Type outputElementType = outputType.getElementType();
  if (failed(checkShapeRank(op, "query", queryType, rankToCompareWith)))
    return failure();
  if (failed(checkShapeRank(op, "key", keyType, rankToCompareWith)))
    return failure();
  if (failed(checkShapeRank(op, "value", valueType, rankToCompareWith)))
    return failure();
  if (failed(checkShapeRank(op, "output", outputType, rankToCompareWith)))
    return failure();
  ArrayRef<int64_t> queryShape = queryType.getShape();
  ArrayRef<int64_t> keyShape = keyType.getShape();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(keyShape, valueShape)))
    return op->emitOpError("incompatible value shape");
  if (failed(verifyCompatibleShape(queryShape, outputShape)))
    return op->emitOpError("incompatible output shape");
  if (queryElementType != keyElementType || keyElementType != valueElementType)
    return op->emitOpError(
        "element types of (Q)uery, (K)ey and (V)value should be same");
  if (numOperands == 4) {
    // Vanilla attention.
    if (queryElementType != outputElementType)
      return op->emitOpError("expected element type for Output ")
             << queryElementType << " but found " << outputElementType
             << " instead";
    if (staticDimsMismatch(keyShape[2], queryShape[2]))
      return op->emitOpError("query and key head dimension mismatch");
  }
  if (numOperands == 6) {
    // Tiled/Flash attention.
    ShapedType maxType = *getMaxType();
    ShapedType sumType = *getSumType();
    if (failed(checkShapeRank(op, "max", maxType, 1)))
      return failure();
    if (failed(checkShapeRank(op, "sum", sumType, 1)))
      return failure();
    Type maxElementType = maxType.getElementType();
    Type sumElementType = sumType.getElementType();
    ArrayRef<int64_t> maxShape = maxType.getShape();
    ArrayRef<int64_t> sumShape = sumType.getShape();
    if (outputElementType != maxElementType || maxElementType != sumElementType)
      return op->emitOpError(
          "element types of tiled output, max and sum should be same");
    if (failed(verifyCompatibleShape(maxShape, sumShape)))
      return op->emitOpError("incompatible sum shape");
    if (staticDimsMismatch(maxShape[0], queryShape[0]))
      return op->emitOpError("Query and max dimension-0 mismatch");
  }
  return success();
}
/// Returns the attention iteration domain. Loop `d`, for each of the
/// getIterationDomainRank() loops, spans [0, dim(query, d)) with unit stride.
SmallVector<Range> AttentionOp::getIterationDomain(OpBuilder &builder) {
  const int64_t domainRank = getIterationDomainRank();
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value query = getQuery();
  SmallVector<Range> loopBounds;
  loopBounds.reserve(domainRank);
  for (int64_t d = 0; d < domainRank; ++d) {
    Range bound;
    bound.offset = zero;
    bound.size = getDimValue(builder, loc, query, d);
    bound.stride = one;
    loopBounds.push_back(bound);
  }
  return loopBounds;
}
/// Marks every loop of the attention iteration domain as parallel.
SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
  return SmallVector<utils::IteratorType>(getIterationDomainRank(),
                                          utils::IteratorType::parallel);
}
/// Tiles the attention op.
///
/// Query and output are sliced with `offsets`/`sizes` on their leading
/// getIterationDomainRank() dims. Their remaining dims start at 0 with full
/// static extent. Key and value are sliced only on dim 0 (offsets[0] /
/// sizes[0]) and otherwise keep their full extent. Value reuses the key's
/// static shape for its sizes, relying on verify() requiring compatible
/// key/value shapes. With tensor semantics the clone's result type is the
/// output slice type.
///
/// NOTE(review): only four operands (query, key, value, output) are re-sliced
/// and passed to mlir::clone. The 6-operand (max/sum) form does not appear to
/// be supported here. Confirm.
/// NOTE(review): untiled sizes come from the *static* shapes via
/// getIndexArrayAttr. A dynamic dim would be materialized as the kDynamic
/// sentinel. Confirm only static shapes reach this path.
FailureOr<TilingResult>
AttentionOp::getTiledImplementation(OpBuilder &builder,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes) {
  assert(offsets.size() == getIterationDomainRank());
  assert(sizes.size() == getIterationDomainRank());
  Location loc = getLoc();
  auto one = builder.getIndexAttr(1);
  auto zero = builder.getIndexAttr(0);
  // Query and output share offsets/sizes/strides.
  SmallVector<OpFoldResult> queryOutputOffsets(getQueryRank(), zero);
  SmallVector<OpFoldResult> queryOutputStrides(getQueryRank(), one);
  ArrayRef<int64_t> queryShape = getQueryType().getShape();
  SmallVector<OpFoldResult> queryOutputSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(queryShape));
  for (auto info : llvm::enumerate(llvm::zip(offsets, sizes))) {
    queryOutputOffsets[info.index()] = std::get<0>(info.value());
    queryOutputSizes[info.index()] = std::get<1>(info.value());
  }
  // Key and value share offsets/sizes/strides; only dim 0 is tiled.
  SmallVector<OpFoldResult> keyValueOffsets(getKeyRank(), zero);
  SmallVector<OpFoldResult> keyValueStrides(getKeyRank(), one);
  ArrayRef<int64_t> keyShape = getKeyType().getShape();
  SmallVector<OpFoldResult> keyValueSizes =
      getAsOpFoldResult(builder.getIndexArrayAttr(keyShape));
  keyValueSizes[0] = sizes[0];
  keyValueOffsets[0] = offsets[0];
  SmallVector<Value> tiledOperands;
  tiledOperands.emplace_back(getSlice(builder, loc, getQuery(),
                                      queryOutputOffsets, queryOutputSizes,
                                      queryOutputStrides));
  tiledOperands.emplace_back(getSlice(builder, loc, getKey(), keyValueOffsets,
                                      keyValueSizes, keyValueStrides));
  tiledOperands.emplace_back(getSlice(builder, loc, getValue(), keyValueOffsets,
                                      keyValueSizes, keyValueStrides));
  tiledOperands.emplace_back(getSlice(builder, loc, getOutput(),
                                      queryOutputOffsets, queryOutputSizes,
                                      queryOutputStrides));
  // Only tensor-semantics ops produce a result (the tiled output slice type).
  SmallVector<Type> resultTypes;
  if (hasTensorSemantics())
    resultTypes.push_back(tiledOperands[3].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
/// Maps an iteration-space tile to the tile of result 0.
///
/// The leading dims take the given offsets/sizes. Any remaining output dims
/// start at 0 with their full static extent. Any result number other than 0
/// yields failure, and the out-params are left untouched.
LogicalResult AttentionOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber != 0)
    return failure();
  ArrayRef<int64_t> outputShape = getOutputType().getShape();
  resultSizes = getAsOpFoldResult(builder.getIndexArrayAttr(outputShape));
  OpFoldResult zeroAttr = builder.getIndexAttr(0);
  resultOffsets = SmallVector<OpFoldResult>(getOutputRank(), zeroAttr);
  for (auto indexedTile : llvm::enumerate(llvm::zip(offsets, sizes))) {
    const size_t dim = indexedTile.index();
    resultOffsets[dim] = std::get<0>(indexedTile.value());
    resultSizes[dim] = std::get<1>(indexedTile.value());
  }
  return success();
}
/// Folding delegates to memref::foldMemRefCast. Per its MLIR documentation,
/// that helper folds memref.cast producers into this op's operands.
LogicalResult AttentionOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  Operation *self = getOperation();
  return memref::foldMemRefCast(self);
}
/// Reifies the result shape by forwarding to the LinalgExtOp interface
/// implementation.
LogicalResult AttentionOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(builder, reifiedReturnShapes);
}
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
&effects) { \
getEffectsImpl(effects, getDpsInputs(), getDpsInits()); \
}
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(FftOp)
DEFINE_OP_GET_EFFECTS(ReverseOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(TopkOp)
DEFINE_OP_GET_EFFECTS(PackOp)
DEFINE_OP_GET_EFFECTS(UnPackOp)
DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
//===----------------------------------------------------------------------===//
// iree_linalg_ext.set_encoding
//===----------------------------------------------------------------------===//
/// Verifies set_encoding. The source tensor must carry no encoding, the
/// result must carry an EncodingAttr, and the rank and logical shape must be
/// preserved.
LogicalResult SetEncodingOp::verify() {
  auto sourceType = getSourceType();
  auto resultType = getResultType();
  if (sourceType.getEncoding())
    return emitOpError(
        "source of set_encoding op cannot have a tensor encoding");
  if (!resultType.getEncoding().isa_and_nonnull<EncodingAttr>())
    return emitOpError(
        "result of set_encoding op expected to have a valid tensor encoding");
  // Only the encoding may change: rank and (compatible) dims must match.
  if (resultType.getRank() != sourceType.getRank())
    return emitOpError("cannot change the rank of the tensor");
  if (failed(verifyCompatibleShape(resultType, sourceType)))
    return emitOpError("expected to preserve the logical shape of the tensor");
  return success();
}
/// Reifies the single result's shape as the dims of the source. Any needed
/// ops are created right before this op, and the caller's insertion point is
/// restored.
LogicalResult SetEncodingOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(getOperation());
  auto sourceDims = getDims(builder, getLoc(), getSource());
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes.front() = std::move(sourceDims);
  return success();
}
//===----------------------------------------------------------------------===//
// iree_linalg_ext.unset_encoding
//===----------------------------------------------------------------------===//
/// Verifies unset_encoding. The result tensor must carry no encoding, the
/// source must carry an EncodingAttr, and the rank and logical shape must be
/// preserved.
LogicalResult UnsetEncodingOp::verify() {
  auto sourceType = getSourceType();
  auto resultType = getResultType();
  if (resultType.getEncoding())
    return emitOpError(
        "result of unset_encoding op cannot have a tensor encoding");
  if (!sourceType.getEncoding().isa_and_nonnull<EncodingAttr>())
    return emitOpError(
        "source of unset_encoding op expected to have a valid tensor encoding");
  // Only the encoding may change: rank and (compatible) dims must match.
  if (resultType.getRank() != sourceType.getRank())
    return emitOpError("cannot change the rank of the tensor");
  if (failed(verifyCompatibleShape(resultType, sourceType)))
    return emitOpError("expected to preserve the logical shape of the tensor");
  return success();
}
/// Reifies the single result's shape as the dims of the source. Any needed
/// ops are created right before this op, and the caller's insertion point is
/// restored.
LogicalResult UnsetEncodingOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(getOperation());
  auto sourceDims = getDims(builder, getLoc(), getSource());
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes.front() = std::move(sourceDims);
  return success();
}
// clang-format off
#define GET_OP_CLASSES
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" // IWYU pragma: keep
// clang-format on