// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
#include <optional>
namespace mlir::iree_compiler::IREE::LinalgExt {
//===----------------------------------------------------------------------===//
// Utils.
//===----------------------------------------------------------------------===//
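/// Returns the element type of `ty` if it is a ComplexType, otherwise returns
/// `ty` itself.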
static Type getComplexElementTypeOrSelf(Type ty) {
if (auto complex = dyn_cast_or_null<ComplexType>(ty)) {
return complex.getElementType();
}
return ty;
}
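/// Registers memory effects for the given operands: buffer (memref) inputs
/// are read; buffer outputs are both read and written. Tensor-typed operands
/// carry no side effects.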
static void getEffectsImpl(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects,
ArrayRef<OpOperand *> inputOperands, MutableOperandRange outputOperands) {
for (OpOperand *operand : inputOperands) {
if (!llvm::isa<MemRefType>(operand->get().getType())) {
continue;
}
effects.emplace_back(MemoryEffects::Read::get(), operand,
SideEffects::DefaultResource::get());
}
for (OpOperand &operand : outputOperands) {
if (!llvm::isa<MemRefType>(operand.get().getType())) {
continue;
}
effects.emplace_back(MemoryEffects::Read::get(), &operand,
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), &operand,
SideEffects::DefaultResource::get());
}
}
/// Return true if `dimsPos` is invalid. It is invalid when: a) it contains
/// duplicates. b) At least one dimension is out of bounds (a valid `dimPos`
/// satisfies 0 <= dimPos < rank). c) `dimsPos` has more elements than `rank`.
static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
// Early exit: more positions than dimensions.
if (dimsPos.size() > rank) {
return true;
}
DenseSet<int64_t> uniqued(dimsPos.begin(), dimsPos.end());
if (dimsPos.size() != uniqued.size()) {
return true;
}
return llvm::any_of(
dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; });
}
/// Emit an error and return failure when `seq` is invalid. It is only valid
/// when it is a permutation of the sequence 0...length(seq) - 1.
static LogicalResult
isPermSequence(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> seq) {
BitVector seen(seq.size(), false);
for (auto [idx, dim] : llvm::enumerate(seq)) {
if (dim < 0 || dim >= seq.size()) {
return emitError().attachNote() << "element (" << dim << ") at index#"
<< idx << " is out of bounds";
}
if (seen.test(dim)) {
return emitError().attachNote()
<< "element (" << dim << ") at index#" << idx << " is a duplicate";
}
seen.set(dim);
}
return success();
}
/// Returns true if every dimension of `sourceShape` is smaller than or equal
/// to the corresponding dimension of `limitShape`; dynamic dimensions are
/// treated as compatible.
static bool isSmallerThan(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> limitShape) {
assert(
sourceShape.size() == limitShape.size() &&
"expected source shape rank, and limit of the shape to have same rank");
return llvm::all_of(llvm::zip_equal(sourceShape, limitShape),
[](std::tuple<int64_t, int64_t> it) {
int64_t sourceExtent = std::get<0>(it);
int64_t limit = std::get<1>(it);
return ShapedType::isDynamic(sourceExtent) ||
ShapedType::isDynamic(limit) ||
sourceExtent <= limit;
});
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
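// Verifies the scatter op invariants: rank/shape agreement between the
// update, indices, and original operands, a valid dimension map, and a
// combiner region operating on matching scalar types. An illustrative use
// (shapes chosen for this sketch, not normative) looks like:
//
//   %r = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
//       ins(%updates, %indices : tensor<4x8xf32>, tensor<4x1xi32>)
//       outs(%original : tensor<16x8xf32>) {
//     ^bb0(%update: f32, %orig: f32):
//       iree_linalg_ext.yield %update : f32
//   } -> tensor<16x8xf32>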
LogicalResult ScatterOp::verify() {
Operation *op = getOperation();
if (getInputs().size() != 2) {
return op->emitOpError("expected two input operands");
}
if (getOutputs().size() != 1) {
return op->emitOpError("expected one output operand");
}
auto indicesType = getIndicesType();
if (indicesType.getRank() < 1 ||
!isa<IntegerType>(indicesType.getElementType())) {
return op->emitOpError("expected indices to be of rank 1 or greater and of "
"integer element type");
}
ArrayRef<int64_t> dimMap = getDimensionMap();
if (failed(isPermSequence(
[&]() { return this->emitOpError("dimension map is invalid."); },
dimMap))) {
return failure();
}
if (dimMap.empty()) {
return op->emitOpError("dimension map must have at least one element");
}
const int64_t indexDepth = getIndexDepth();
auto originalType = getOriginalType();
auto updateType = getUpdateType();
const int64_t originalSliceRank = originalType.getRank() - indexDepth;
if (originalSliceRank < 0) {
return op->emitOpError(
"expected original rank to be greater or equal to index depth");
}
if (updateType.getRank() < originalSliceRank) {
return op->emitOpError(
"expected update to be at least the rank of non indexed original dims");
}
const int64_t batchRank = updateType.getRank() - originalSliceRank;
if (updateType.getRank() - batchRank != originalSliceRank) {
return op->emitOpError("expected rank of update value - batch rank to be "
"equal to rank of original value - index depth");
}
if ((indicesType.getRank() != batchRank || indexDepth != 1) &&
indicesType.getRank() != batchRank + 1) {
return op->emitOpError("expected indices to be equal to batch rank "
"or batch rank + 1");
}
{
// Validate the shape of indices and update value match for the first
// `batchRank` dims.
auto [indicesIt, updateIt] =
llvm::mismatch(indicesType.getShape().take_front(batchRank),
updateType.getShape().take_front(batchRank));
if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
return op->emitOpError(
"mismatch in shape of indices and update value at dim#")
<< (indicesIt - indicesType.getShape().begin());
}
}
if (batchRank + 1 == indicesType.getRank() &&
dimMap.size() != indicesType.getShape().back()) {
return op->emitOpError(
"size of dimension map must match the last dimension of indices");
}
// updateSlice[0..indexDepth] <= original[0..indexDepth]
// updateSlice[indexDepth..] == original[indexDepth..]
{
auto [updateIt, originalIt] = llvm::mismatch(
getUpdateSliceShape(), originalType.getShape().drop_front(indexDepth));
if (updateIt != getUpdateSliceShape().end()) {
return op->emitOpError("shape of update value dim#")
<< (updateIt - updateType.getShape().begin())
<< " must match original value at dim#"
<< (originalIt - originalType.getShape().begin());
}
}
Region &region = this->getRegion();
Block *body = &region.front();
if (body->getNumArguments() != 2) {
return op->emitOpError("expected region to have two arguments");
}
Type arg0Type = body->getArgument(0).getType();
Type arg1Type = body->getArgument(1).getType();
if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() ||
!getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) {
return op->emitOpError(
"expected region to have scalar argument of integer or float types");
}
if (arg0Type != updateType.getElementType()) {
return op->emitOpError("mismatch in argument 0 of region ")
<< arg0Type << " and element type of update value "
<< updateType.getElementType();
}
if (arg1Type != originalType.getElementType()) {
return op->emitOpError("mismatch in argument 1 of region ")
<< arg1Type << " and element type of original value "
<< originalType.getElementType();
}
if (arg0Type != arg1Type) {
return op->emitOpError("mismatch in region argument types ")
<< arg0Type << " and " << arg1Type;
}
auto yieldOp = cast<IREE::LinalgExt::YieldOp>(body->getTerminator());
if (yieldOp->getNumOperands() != 1) {
return yieldOp.emitOpError("expected region to yield a single value");
}
auto yieldedType = yieldOp->getOperand(0).getType();
if (yieldedType != arg0Type) {
return yieldOp.emitOpError("mismatch in type of yielded value ")
<< yieldedType << " and argument of the region " << arg0Type;
}
return success();
}
LogicalResult
ScatterOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
FailureOr<SmallVector<int64_t>> ScatterOp::getStaticLoopRanges() {
// The scatter loop ranges are the loop ranges of the update tensor.
return SmallVector<int64_t>(getUpdateType().getShape());
}
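// The update and indices operands are accessed with identity maps over their
// own iteration spaces; the original operand is accessed indirectly through
// the indices, so it carries no static indexing map (null).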
SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
Builder builder(getContext());
return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
builder.getMultiDimIdentityMap(getIndicesType().getRank()),
/*output=*/AffineMap(nullptr)};
}
SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
return {AffineMap(nullptr)};
}
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
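// The sort comparator region receives two arguments per output operand (the
// pair of candidate elements being compared) and yields a single i1 that
// decides their relative order.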
LogicalResult SortOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs()) {
return op->emitOpError("does not expect to take any inputs");
}
if (getNumDpsInits() == 0) {
return op->emitOpError("expected at least one `outs` operand");
}
Block &block = getRegion().front();
size_t numOutputs = getNumDpsInits();
if (block.getNumArguments() != 2 * numOutputs) {
return op->emitOpError("region block should have ")
<< 2 * numOutputs << " arguments";
}
int64_t rank = getOperandRank();
int sortDim = getDimension();
if (sortDim < 0 || sortDim >= rank) {
return op->emitOpError("dimension must be within (0, ") << rank << "]";
}
ArrayRef<int64_t> shape = getOperandShape();
for (auto [index, operand] : llvm::enumerate(getOutputs())) {
auto operandType = getOperandType(index);
if (operandType.getRank() != rank) {
return op->emitOpError("expected operand ")
<< index << " to be rank " << rank << ", same as other operands";
}
if (operandType.getShape() != shape) {
return op->emitOpError("expected operand ")
<< index << " to have same shape as other operands";
}
Type elemType = operandType.getElementType();
for (int i : {2 * index, 2 * index + 1}) {
Type argType = block.getArgument(i).getType();
if (argType != elemType) {
return op->emitOpError("region block argument #")
<< i << " should be of type " << elemType << " but got "
<< argType;
}
}
}
auto yieldOp = cast<YieldOp>(block.getTerminator());
if (yieldOp.getNumOperands() != 1) {
return op->emitOpError("should yield exactly one operand");
}
auto ty = dyn_cast<IntegerType>(yieldOp.getOperand(0).getType());
if (!ty || ty.getWidth() != 1) {
return op->emitOpError("should yield i1 type");
}
return success();
}
LogicalResult
SortOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
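// The FFT op carries a scalar `stage` input, optionally followed by real and
// imaginary coefficient inputs, and produces exactly two outputs holding the
// real and imaginary parts.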
LogicalResult FftOp::verify() {
Operation *op = getOperation();
auto length = getFftLength();
// After tiling, the length could be dynamic, because subview/subtensor does
// not infer the type correctly for (1 << x) cases.
if (ShapedType::isDynamic(length))
return success();
if (length & (length - 1)) {
return op->emitOpError("only powers of 2 are handled currently");
}
if (!getNumDpsInputs() || !isScalar(getDpsInputOperand(0))) {
return op->emitOpError("expected to carry `stage` input");
}
if (getNumDpsInputs() != 1) {
if (getNumDpsInputs() != 3 || isScalar(getDpsInputOperand(1)) ||
isScalar(getDpsInputOperand(2))) {
return op->emitOpError("expected to carry real and imag coeff inputs");
}
}
if (getNumDpsInits() != 2) {
return op->emitOpError(
"expected outputs to be real and imag tensor/memref");
}
return success();
}
LogicalResult
FftOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// ScanOp
//===----------------------------------------------------------------------===//
LogicalResult ScanOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
return op->emitOpError("expected one input operands");
}
if (getNumDpsInits() != 2) {
return op->emitOpError("expected two output operands");
}
if (!isa<ShapedType>(getInput().getType())) {
return op->emitOpError("expected first input element type to be shaped");
}
auto accumulatorType = cast<ShapedType>(getAccumulator().getType());
auto inputType = cast<ShapedType>(getInput().getType());
auto outputType = cast<ShapedType>(getOutput().getType());
ArrayRef<int64_t> inputShapes = inputType.getShape();
ArrayRef<int64_t> outputShapes = outputType.getShape();
if (accumulatorType.getElementType() != inputType.getElementType()) {
return op->emitOpError(
"expected input/accumulator element types to be identical");
}
ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape();
int64_t accumulatorRank = accumulatorType.getRank();
if (accumulatorRank != inputType.getRank() - 1) {
return op->emitOpError(
"expected accumulator rank to be equal to input rank - 1");
}
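// The accumulator must match the input shape with the scanned dimension
// removed.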
SmallVector<int64_t> expectedAccumulatorShape;
for (int i = 0; i < inputType.getRank(); i++) {
if (i != getDimension()) {
expectedAccumulatorShape.push_back(inputShapes[i]);
}
}
if (llvm::any_of(llvm::zip_equal(expectedAccumulatorShape, accumulatorShape),
[](std::tuple<int64_t, int64_t> s) {
return !ShapedType::isDynamic(std::get<0>(s)) &&
!ShapedType::isDynamic(std::get<1>(s)) &&
std::get<0>(s) != std::get<1>(s);
})) {
return op->emitOpError("incompatible input/accumulator shapes");
}
if (inputType.getElementType() != outputType.getElementType()) {
return op->emitOpError(
"expected input/output element types to be identical");
}
if (inputShapes.size() != outputShapes.size()) {
return op->emitOpError("expected input/output to have identical ranks");
}
if (llvm::any_of(llvm::zip_equal(inputShapes, outputShapes),
[](std::tuple<int64_t, int64_t> s) {
return !ShapedType::isDynamic(std::get<0>(s)) &&
!ShapedType::isDynamic(std::get<1>(s)) &&
std::get<0>(s) != std::get<1>(s);
})) {
return op->emitOpError("incompatible input/output shapes");
}
return success();
}
LogicalResult ScanOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
LogicalResult
ScanOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
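// Top-k takes input values (plus optional i32 input indices), produces value
// and index outputs of matching shape, and carries a comparator region that
// yields an i1.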
LogicalResult TopkOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1 && getNumDpsInputs() != 2) {
return op->emitOpError("expected one or two input operands");
}
if (getNumDpsInits() != 2) {
return op->emitOpError("expected two output operands");
}
if (getDimension() >= getInputRank()) {
return op->emitOpError("dimension exceeds rank");
}
// Ensure input/output element types match
auto inputValuesType = cast<ShapedType>(getValues().getType());
auto outputValuesType = cast<ShapedType>(outputValues().getType());
if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
return op->emitOpError("expected input/output value types to be identical");
}
// Indices must be int if provided
auto outputIndicesType = cast<ShapedType>(outputIndices().getType());
if (auto inputIndices = getIndices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (!inputIndicesType.getElementType().isInteger(32) ||
!outputIndicesType.getElementType().isInteger(32)) {
return op->emitOpError("expected input/output indices types to be int32");
}
}
// Ranks must match
if (inputValuesType.getRank() != outputValuesType.getRank()) {
return op->emitOpError("expected input/output to have the same rank");
}
if (auto inputIndices = getIndices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (inputIndicesType.getRank() != outputIndicesType.getRank()) {
return op->emitOpError("expected input/output to have the same rank");
}
}
// Input indices and values must have the same shape.
if (auto inputIndices = getIndices()) {
auto inputIndicesType = cast<ShapedType>(inputIndices->getType());
if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) {
return op->emitOpError("input indices/values shape must match");
}
}
// Output indices and values must have the same shape.
if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType))) {
return op->emitOpError("output indices/values shape must match");
}
// Input shape must match the output shape, except along dimension().
uint64_t dim = getDimension();
if (!llvm::all_of(
llvm::enumerate(llvm::zip_equal(inputValuesType.getShape(),
outputValuesType.getShape())),
[dim](auto e) {
if (e.index() == dim) {
return true;
}
std::tuple<int64_t, int64_t> s = e.value();
return succeeded(
verifyCompatibleShape(std::get<0>(s), std::get<1>(s)));
})) {
return op->emitOpError("incompatible input/output shapes");
}
// Check region compatibility
Block &block = getRegion().front();
if (block.getNumArguments() != 2) {
return op->emitOpError("region block should have 2 arguments");
}
if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
block.getArgument(1).getType() != inputValuesType.getElementType()) {
return op->emitOpError("region block types must match input");
}
auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator());
if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) {
return op->emitOpError("region block must end with a linalg_ext.yield i1!");
}
return success();
}
LogicalResult
TopkOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// PackOp and UnPackOp utils
//===----------------------------------------------------------------------===//
/// Return true if at least one element in `tiles` is zero.
static bool hasZeros(ArrayRef<OpFoldResult> tiles) {
return llvm::any_of(
tiles, [&](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
}
/// Return true when there is enough static information to prove that at least
/// one tile size does not evenly divide its corresponding input dimension;
/// such a partial tile would be undefined behavior.
static bool
areNotFullTiles(ArrayRef<int64_t> inputShape,
DenseMap<int64_t, OpFoldResult> const &dimAndTileMapping) {
int64_t rank = inputShape.size();
for (int64_t dim = 0; dim < rank; dim++) {
if (ShapedType::isDynamic(inputShape[dim]))
continue;
auto it = dimAndTileMapping.find(dim);
if (it != dimAndTileMapping.end()) {
std::optional<int64_t> constantTile = getConstantIntValue(it->second);
if (!constantTile)
continue;
if (inputShape[dim] % (*constantTile) != 0)
return true;
}
}
return false;
}
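/// Wrapper around mlir::getMixedValues that constructs the builder from
/// `context`.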
static SmallVector<OpFoldResult> getMixedValues(MLIRContext *context,
ArrayRef<int64_t> staticValues,
OperandRange dynamicValues) {
OpBuilder b(context);
return mlir::getMixedValues(staticValues, dynamicValues, b);
}
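/// Splits `mixedValues` into its static values, substituting the dynamic
/// sentinel for entries that are not compile-time constants.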
static SmallVector<int64_t>
getStaticValues(SmallVector<OpFoldResult> mixedValues) {
SmallVector<Value> dynamicTiles;
SmallVector<int64_t> staticTiles;
dispatchIndexOpFoldResults(mixedValues, dynamicTiles, staticTiles);
return staticTiles;
}
/// Utility function shared between Pack and UnPack to get the tile sizes as
/// OpFoldResults.
// TODO: interface or base class in .td
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTiles(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
return LinalgExt::getMixedValues(op.getContext(), op.getStaticInnerTiles(),
op.getInnerTiles());
}
/// Return the tile sizes as `int64_t`. If a tile size is dynamic, the
/// sentinel `ShapedType::kDynamic` is used at that position in the returned
/// vector.
template <typename OpTy>
static SmallVector<int64_t> getStaticTiles(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
return getStaticValues(op.getMixedTiles());
}
/// Utility function shared between Pack and UnPack to get a map between
/// `dim_pos` and `inner_tiles`.
// TODO: interface or base class in .td
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMapping(OpTy op) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
ArrayRef<int64_t> dimsToBlock = op.getInnerDimsPos();
SmallVector<OpFoldResult> tiles = op.getMixedTiles();
assert(tiles.size() == dimsToBlock.size() &&
"tiles must match indices of dimension to block");
// Bind each blocked dimension to its tile factor.
for (auto [dim, tile] : llvm::zip_equal(dimsToBlock, tiles)) {
dimAndTileMapping[dim] = tile;
}
return dimAndTileMapping;
}
/// Common verifier for `PackOp` and `UnPackOp`.
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
"applies to only pack or unpack operations");
Operation *op = packOrUnPack.getOperation();
ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getInputType()
: packOrUnPack.getOutputType();
int64_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
// Verify tiles. Make sure each provided tile is non-zero.
SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
if (hasZeros(mixedTiles)) {
return op->emitError("invalid tile factor");
}
if (isInvalid(innerDimsPos, unpackedRank)) {
return op->emitError("invalid inner_dims_pos vector");
}
if (isInvalid(outerDimPerm, unpackedRank)) {
return op->emitError("invalid outer_dims_perm vector");
}
if (mixedTiles.size() != innerDimsPos.size()) {
return op->emitError(
"blocking factors must equal the number of dimensions to block");
}
// The number of blocking factors must be less than or equal to the input
// rank.
if (mixedTiles.size() > unpackedRank) {
return op->emitError(
"blocking factors must be less than or equal to the input rank");
}
ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getOutputType()
: packOrUnPack.getInputType();
int64_t packedRank = packedType.getRank();
// Require output rank to match input rank + number of blocking factors.
if (unpackedRank + mixedTiles.size() != packedRank) {
return op->emitError(
"packed rank must equal unpacked rank + blocking factors");
}
// Verify that the result shape is at least as large as the minimum shape
// expected by the pack operation, and that the output shape represents full
// tiles.
ShapedType expectedPackedType = PackOp::getPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!isSmallerThan(expectedPackedType.getShape(), packedType.getShape())) {
return op->emitError("the shape of output is not large enough to hold the "
"packed data. Expected at least ")
<< expectedPackedType << ", got " << packedType;
}
if (!llvm::all_of(
llvm::zip_equal(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
[](std::tuple<int64_t, OpFoldResult> it) {
std::optional<int64_t> constTileSize =
getConstantIntValue(std::get<1>(it));
int64_t shape = std::get<0>(it);
if (!constTileSize) {
// If specified tile size is dynamic, output shape should
// be dynamic too.
return ShapedType::isDynamic(shape);
} else {
if (ShapedType::isDynamic(shape)) {
// For the shape being dynamic when tile size is
// specified, return true. In canonical form a constant
// tile size should lead to constant shape of the tiled
// dimension, but not needed for verification.
return true;
}
return shape == constTileSize.value();
}
})) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
return success();
}
//===----------------------------------------------------------------------===//
// PackOp
//===----------------------------------------------------------------------===//
/// Custom builder methods for pack ops.
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
Value output, ArrayRef<int64_t> innerDimsPos,
ArrayRef<OpFoldResult> innerTiles,
std::optional<Value> paddingValue,
ArrayRef<int64_t> outerDimsPerm) {
assert(innerDimsPos.size() == innerTiles.size() &&
"number of tile sizes specified must match the specified number of "
"original dimensions to be tiled");
SmallVector<int64_t> staticTileSizes;
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
SmallVector<Type> resultType;
auto outputType = output.getType();
if (isa<RankedTensorType>(outputType)) {
resultType.push_back(outputType);
}
build(builder, state, resultType, source, output,
outerDimsPerm.empty() ? nullptr
: builder.getDenseI64ArrayAttr(outerDimsPerm),
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
builder.getDenseI64ArrayAttr(staticTileSizes),
(paddingValue ? paddingValue.value() : nullptr));
}
LogicalResult PackOp::verify() {
if (failed(commonVerifierPackAndUnPackOp(*this))) {
return failure();
}
// Bail out if the tile does not divide the dimension fully. In the case of
// dynamic tile factors or dimensions, having a partial tile is undefined
// behavior.
auto dimAndTileMapping = getDimAndTileMapping();
if (!getPaddingValue() &&
areNotFullTiles(getInputShape(), dimAndTileMapping)) {
return emitOpError("invalid tile factor provided. Only full tiles are "
"supported when padding_value is not set");
}
if (auto paddingValue = getPaddingValue()) {
if (paddingValue.getType() != getInputType().getElementType()) {
return emitOpError("expected padding_value has ")
<< getInputType().getElementType()
<< " but got: " << paddingValue.getType();
}
}
return success();
}
SmallVector<OpFoldResult> PackOp::getMixedTiles() {
return LinalgExt::getMixedTiles(*this);
}
SmallVector<int64_t> PackOp::getStaticTiles() {
return LinalgExt::getStaticTiles(*this);
}
// Helper for PackOp::{getResultShape,getPackedType}. Returns the shape of the
// packed type. Having a shared helper helps implement these two methods in a
// way that ensures that they agree on which dimensions are dynamic.
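// For example (illustrative shapes): packing a 128x256 source with
// inner_dims_pos = [0, 1] and inner_tiles = [32, 16] (no outer_dims_perm)
// yields [128/32, 256/16, 32, 16] = [4, 16, 32, 16].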
static SmallVector<int64_t> getPackOpResultTypeShape(
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
for (auto [idx, tiledDim] : llvm::enumerate(innerDimsPos)) {
if (ShapedType::isDynamic(resultShape[tiledDim])) {
continue;
}
if (ShapedType::isDynamic(innerTileSizes[idx])) {
resultShape[tiledDim] = ShapedType::kDynamic;
continue;
}
resultShape[tiledDim] =
llvm::divideCeil(resultShape[tiledDim], innerTileSizes[idx]);
}
// Swap tile loops if outer_dims_perm is available.
resultShape = interchange<int64_t>(resultShape, outerDimsPerm, /*offset=*/0);
// Append the inner tile dimensions.
resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
return resultShape;
}
SmallVector<OpFoldResult> PackOp::getResultShape(
OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
AffineExpr s0, s1;
bindSymbols(builder.getContext(), s0, s1);
AffineExpr ceilDivExpr = s0.ceilDiv(s1);
for (auto [idx, tiledDim] : llvm::enumerate(innerDimsPos)) {
resultDims[tiledDim] = affine::makeComposedFoldedAffineApply(
builder, loc, ceilDivExpr, {resultDims[tiledDim], innerTileSizes[idx]});
}
if (!outerDimsPerm.empty()) {
resultDims =
interchange<OpFoldResult>(resultDims, outerDimsPerm, /*offset=*/0);
}
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
SmallVector<int64_t> resultTypeShape =
getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
asShapeWithAnyValueAsDynamic(innerTileSizes),
innerDimsPos, outerDimsPerm);
// Fix-up `resultDims` to ensure that they are Value's if and only if the
// result type shape says it's a dynamic dim. This is needed as callers may
// use dispatchIndexOpFoldResults on the result, and rely on exact number of
// dynamic dims returned by that.
for (unsigned i = 0; i < resultDims.size(); ++i) {
if (!ShapedType::isDynamic(resultTypeShape[i])) {
continue;
}
resultDims[i] =
getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
}
return resultDims;
}
SmallVector<OpFoldResult> PackOp::getResultShape(OpBuilder &builder) {
return tensor::getMixedSizes(builder, getLoc(), getOutput());
}
ShapedType PackOp::getPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultTypeShape = getPackOpResultTypeShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
return TypeSwitch<ShapedType, ShapedType>(sourceType)
.Case<RankedTensorType>([&](auto shapedType) {
return RankedTensorType::get(resultTypeShape,
shapedType.getElementType());
})
.Case<MemRefType>([&](auto shapedType) {
return MemRefType::get(resultTypeShape, shapedType.getElementType());
})
.Default([&](Type t) {
assert(false && "unexpected type");
return nullptr;
});
}
DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
return LinalgExt::getDimAndTileMapping(*this);
}
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(builder, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// UnPackOp
//===----------------------------------------------------------------------===//
/// Custom builder methods for unpack ops.
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
Value output, ArrayRef<int64_t> innerDimsPos,
ArrayRef<OpFoldResult> innerTiles,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> staticTileSizes;
SmallVector<Value> dynamicTileSizes;
dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
SmallVector<Type> resultType;
auto outputType = output.getType();
if (isa<RankedTensorType>(outputType)) {
resultType.push_back(outputType);
}
build(builder, state, resultType, source, output,
outerDimsPerm.empty() ? nullptr
: builder.getDenseI64ArrayAttr(outerDimsPerm),
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
builder.getDenseI64ArrayAttr(staticTileSizes));
}
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
return LinalgExt::getMixedTiles(*this);
}
SmallVector<int64_t> UnPackOp::getStaticTiles() {
return LinalgExt::getStaticTiles(*this);
}
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
return LinalgExt::getDimAndTileMapping(*this);
}
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(builder, reifiedReturnShapes);
}
LogicalResult UnPackOp::verify() {
if (failed(commonVerifierPackAndUnPackOp(*this))) {
return failure();
}
return success();
}
//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//
LogicalResult WinogradInputTransformOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
return op->emitOpError("expected one input operand");
}
if (getNumDpsInits() != 1) {
return op->emitOpError("expected one output operand");
}
auto inputType = getInputType();
auto outputType = getOutputType();
if (outputType.getElementType() != inputType.getElementType()) {
return op->emitOpError(
"expected input/output element types to be identical");
}
unsigned inputRank = inputType.getRank();
unsigned outputRank = outputType.getRank();
if (inputRank != 2 && inputRank != 4) {
return op->emitOpError("expected input operand to have rank either 2 or 4");
}
if (inputRank == 2) {
if (outputRank != 2) {
return op->emitOpError(
"expected output operand to have rank 2 if input is of rank 2");
}
if ((!inputType.isDynamicDim(0) &&
inputType.getDimSize(0) > getInputTileSize()) ||
(!inputType.isDynamicDim(1) &&
inputType.getDimSize(1) > getInputTileSize())) {
return op->emitOpError("expected input dims not greater than input tile "
"size if input is of rank 2");
}
SmallVector<int64_t> expectedOutputShape(2, getInputTileSize());
if (failed(verifyCompatibleShape(expectedOutputShape,
outputType.getShape()))) {
return op->emitOpError(
"expected output dims equal to inputTileSize if input is of rank 2");
}
return success();
}
if (getOutputRank() != getInputRank() + 2) {
return op->emitOpError(
"expected output rank to be equal to input rank + 2");
}
ArrayRef<int64_t> imageDims = getImageDimensions();
llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
imageDims.end());
if (imageDims.size() != 2) {
return op->emitOpError("expected only 2 image dimensions");
}
if (!isNchw() && !isNhwc()) {
return op->emitOpError(
"expect image dimensions to be either [1, 2] or [2, 3]");
}
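// The expected output prepends two dimensions of extent inputTileSize; each
// image dimension is split into ceil((size - kernelSize + 1) / outputTileSize)
// tiles, and the remaining dimensions pass through unchanged.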
SmallVector<int64_t> expectedOutputShape(getOutputRank(), getInputTileSize());
int outputIndex;
ArrayRef<int64_t> inputShape = inputType.getShape();
for (int i = 0; i < inputShape.size(); i++) {
outputIndex = i + imageDims.size();
if (ShapedType::isDynamic(inputShape[i])) {
expectedOutputShape[outputIndex] = inputShape[i];
continue;
}
if (!imageDimsSet.contains(i)) {
expectedOutputShape[outputIndex] = inputShape[i];
} else {
// Integer ceil-div avoids floating-point rounding error for large extents.
expectedOutputShape[outputIndex] =
llvm::divideCeil(inputShape[i] - getKernelSize() + 1,
getOutputTileSize());
}
}
if (isNchw()) {
permute<Permutation::TTNCHW_TO_TTNHWC>(expectedOutputShape);
}
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
}
LogicalResult WinogradInputTransformOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
LogicalResult WinogradInputTransformOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//
LogicalResult WinogradFilterTransformOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
return op->emitOpError("expected one input operand");
}
if (getNumDpsInits() != 1) {
return op->emitOpError("expected one output operand");
}
auto inputType = getInputType();
auto outputType = getOutputType();
if (outputType.getElementType() != inputType.getElementType()) {
return op->emitOpError(
"expected input/output element types to be identical");
}
unsigned inputRank = inputType.getRank();
unsigned outputRank = outputType.getRank();
if (inputRank != 2 && inputRank != 4) {
return op->emitOpError("expected input operand to have rank either 2 or 4");
}
if (inputRank == 2) {
if (outputRank != 2) {
return op->emitOpError(
"expected output operand to have rank 2 if input is of rank 2");
}
SmallVector<int64_t> expectedInputShape(2, getKernelSize());
if (failed(
verifyCompatibleShape(expectedInputShape, inputType.getShape()))) {
return op->emitOpError("expected input dims to be equal to kernel size "
"if input is of rank 2");
}
SmallVector<int64_t> expectedOutputShape(2, getInputTileSize());
if (failed(verifyCompatibleShape(expectedOutputShape,
outputType.getShape()))) {
return op->emitOpError("expected output dims equal to input tile size if "
"input is of rank 2");
}
return success();
}
if (getOutputRank() != getInputRank()) {
return op->emitOpError("expected output rank to be equal to input rank");
}
const ArrayRef<int64_t> kernelDims = getKernelDimensions();
if (kernelDims.size() != 2) {
return op->emitOpError("expected only 2 kernel dimensions");
}
if (!isHwcf() && !isFchw()) {
return op->emitOpError(
"expect kernel dimensions to be either [0, 1] or [2, 3]");
}
const int64_t kernelSize = getKernelSize();
for (auto kernelDim : kernelDims) {
if (inputType.getDimSize(kernelDim) != kernelSize) {
return op->emitOpError(
"expect all kernel dimensions to have the kernel size");
}
}
const int64_t inputTileSize = getInputTileSize();
SmallVector<int64_t> expectedOutputShape(kernelDims.size(), inputTileSize);
llvm::SmallSetVector<int64_t, 2> kernelDimsSet(kernelDims.begin(),
kernelDims.end());
for (int i = 0; i < inputType.getRank(); i++) {
if (!kernelDimsSet.contains(i)) {
expectedOutputShape.push_back(inputType.getDimSize(i));
}
}
if (isFchw()) {
permute<Permutation::TTFC_TO_TTCF>(expectedOutputShape);
}
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
}
LogicalResult WinogradFilterTransformOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
LogicalResult WinogradFilterTransformOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//
LogicalResult WinogradOutputTransformOp::verify() {
Operation *op = getOperation();
if (getNumDpsInputs() != 1) {
return op->emitOpError("expected one input operand");
}
if (getNumDpsInits() != 1) {
return op->emitOpError("expected one output operand");
}
auto inputType = getInputType();
auto outputType = getOutputType();
unsigned inputRank = inputType.getRank();
unsigned outputRank = outputType.getRank();
if (inputRank != 2 && inputRank != 6) {
return op->emitOpError("expected input operand to have rank either 2 or 6");
}
if (inputRank == 2) {
if (outputRank != 2) {
return op->emitOpError(
"expected output operand to have rank 2 if input is of rank 2");
}
SmallVector<int64_t> expectedInputShape(2, getInputTileSize());
if (failed(
verifyCompatibleShape(expectedInputShape, inputType.getShape()))) {
return op->emitOpError("expected input dims to be equal to input tile "
"size if input is of rank 2");
}
SmallVector<int64_t> expectedOutputShape(2, getOutputTileSize());
if (failed(verifyCompatibleShape(expectedOutputShape,
outputType.getShape()))) {
return op->emitOpError("expected output dims equal to output tile size "
"if input is of rank 2");
}
return success();
}
ArrayRef<int64_t> outputShape = outputType.getShape();
if (outputType.getElementType() != inputType.getElementType()) {
return op->emitOpError(
"expected input/output element types to be identical");
}
if (outputRank != inputRank - 2) {
return op->emitOpError(
"expected output rank to be equal to input rank - 2");
}
ArrayRef<int64_t> imageDims = getImageDimensions();
llvm::SmallSetVector<int64_t, 2> imageDimsSet(imageDims.begin(),
imageDims.end());
if (imageDims.size() != 2) {
return op->emitOpError("expected only 2 image dimensions");
}
if (!isNchw() && !isNhwc()) {
return op->emitOpError(
"expect image dimensions to be either [1, 2] or [2, 3]");
}
SmallVector<int64_t> inputShape(inputType.getShape());
if (isNchw()) {
permute<Permutation::TTNHWC_TO_TTNCHW>(inputShape);
}
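// Each image dimension of the output expands to outputTileSize times the
// number of tiles; the remaining dimensions pass through unchanged.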
SmallVector<int64_t> expectedOutputShape(getOutputRank(), 1);
int outputIndex;
for (int i = imageDims.size(); i < inputShape.size(); i++) {
outputIndex = i - imageDims.size();
if (ShapedType::isDynamic(inputShape[i])) {
expectedOutputShape[outputIndex] = inputShape[i];
continue;
}
if (!imageDimsSet.contains(outputIndex)) {
expectedOutputShape[outputIndex] = inputShape[i];
} else {
expectedOutputShape[outputIndex] = getOutputTileSize() * inputShape[i];
}
}
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
}
LogicalResult WinogradOutputTransformOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
LogicalResult WinogradOutputTransformOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//
void AttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
TypeRange results, Value query, Value key, Value value,
Value scale, Value output, ArrayAttr indexingMaps,
std::optional<Value> mask) {
Value maskIn = mask.value_or(Value());
build(odsBuilder, odsState, results, query, key, value, scale, maskIn, output,
indexingMaps, DictionaryAttr());
}
LogicalResult AttentionOp::verify() {
AttentionOp attnOp = *this;
// Check if indexing maps can represent attention.
SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
if (indexingMaps.size() != getOperation()->getNumOperands()) {
return attnOp->emitOpError("expected an indexing map for each operand");
}
FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
if (failed(maybeOpInfo)) {
return attnOp->emitOpError("failed to verify op's indexing maps");
}
FloatType scaleType = dyn_cast<FloatType>(getScale().getType());
if (!scaleType) {
return attnOp->emitOpError("expected scale to be of floating point type");
}
// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
SmallVector<bool> foundDims(getIterationDomainRank(), false);
auto checkShape = [&shape, &foundDims,
&attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
AffineMap indexingMap) -> LogicalResult {
if (indexingMap.getNumResults() != valShape.size()) {
return attnOp->emitError("Rank Mismatch for ")
<< operandName << ". Expected: " << indexingMap.getNumResults()
<< " Got: " << valShape.size();
}
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
}
if (shape[pos] != valShape[i]) {
return attnOp->emitError("Shape Mismatch for ")
<< operandName << " at position " << i
<< ". Expected: " << shape[pos] << " Got: " << valShape[i];
}
}
return success();
};
if (failed(checkShape("Query", getQuery().getType().getShape(),
getQueryMap())) ||
failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
failed(checkShape("Value", getValue().getType().getShape(),
getValueMap())) ||
failed(checkShape("Output", getOutput().getType().getShape(),
getOutputMap()))) {
return failure();
}
// Additional check case if mask exists
if (auto maskMap = getMaskMap()) {
if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap)))
return failure();
}
int expectedSymbols = getQueryMap().getNumInputs();
auto checkDomain =
[&attnOp, &expectedSymbols](StringRef operandName,
AffineMap indexingMap) -> LogicalResult {
if (expectedSymbols != indexingMap.getNumInputs()) {
return attnOp->emitError("Mismatched map domain for ")
<< operandName << ". Expected: " << expectedSymbols
<< " Got: " << indexingMap.getNumInputs();
}
return success();
};
if (failed(checkDomain("Query", getQueryMap())) ||
failed(checkDomain("Key", getKeyMap())) ||
failed(checkDomain("Value", getValueMap())) ||
failed(checkDomain("Scale", getScaleMap())) ||
failed(checkDomain("Output", getOutputMap()))) {
return failure();
}
// Additional check case if mask exists
if (auto maskMap = getMaskMap()) {
if (failed(checkDomain("Mask", *maskMap)))
return failure();
}
auto &block = getRegion().front();
auto blockTys = block.getArgumentTypes();
if (blockTys.size() != 1) {
return attnOp->emitOpError("expects single block argument for score");
}
if (!isa<FloatType>(blockTys[0])) {
return attnOp->emitOpError("block argument 0 should be float");
}
auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator());
if (!yieldOp) {
return attnOp->emitOpError("expected linalg_ext.yield");
}
if (yieldOp->getNumOperands() != 1) {
return attnOp->emitOpError("expected region to yield a single value");
}
return success();
}
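// Operands are ordered {query, key, value, scale[, mask], output}; the single
// DPS init (the output) follows the 4 inputs, or 5 when a mask is present.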
MutableOperandRange AttentionOp::getDpsInitsMutable() {
return MutableOperandRange(*this, /*numInputs=*/getMask() ? 5 : 4,
/*numInits=*/1);
}
LogicalResult AttentionOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
SmallVector<AffineMap> AttentionOp::getIndexingMapsArray() {
return SmallVector<AffineMap>(
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}
FailureOr<SmallVector<int64_t>> AttentionOp::getStaticLoopRanges() {
SmallVector<int64_t> bounds(getIterationDomainRank());
SmallVector<bool> dimsFound(getIterationDomainRank(), false);
// batch(s), m, k1
ArrayRef<int64_t> queryShape = getQuery().getType().getShape();
ArrayRef<AffineExpr> queryDims = getQueryMap().getResults();
// batch(s), k2, n
ArrayRef<int64_t> valueShape = getValue().getType().getShape();
ArrayRef<AffineExpr> valueDims = getValueMap().getResults();
auto fillSizes = [&](ArrayRef<int64_t> sizes, ArrayRef<AffineExpr> dims) {
for (auto [size, dim] : llvm::zip_equal(sizes, dims)) {
int pos = cast<AffineDimExpr>(dim).getPosition();
if (dimsFound[pos]) {
continue;
}
bounds[pos] = size;
dimsFound[pos] = true;
}
};
fillSizes(queryShape, queryDims);
fillSizes(valueShape, valueDims);
return bounds;
}
SmallVector<AffineMap> AttentionOp::getIndexingMapsForOperands() {
auto maps = getIndexingMapsArray();
maps.resize(getNumDpsInputs());
return maps;
}
SmallVector<AffineMap> AttentionOp::getIndexingMapsForResults() {
auto maps = getIndexingMapsArray();
return SmallVector<AffineMap>(maps.begin() + getNumDpsInputs(), maps.end());
}
//===----------------------------------------------------------------------===//
// OnlineAttentionOp
//===----------------------------------------------------------------------===//
void OnlineAttentionOp::build(OpBuilder &odsBuilder, OperationState &odsState,
TypeRange results, Value query, Value key,
Value value, Value scale, Value output, Value max,
Value sum, ArrayAttr indexingMaps,
std::optional<Value> mask) {
Value maskIn = mask.value_or(Value());
build(odsBuilder, odsState, results, query, key, value, maskIn, scale, output,
max, sum, indexingMaps, DictionaryAttr());
}
LogicalResult OnlineAttentionOp::verify() {
OnlineAttentionOp attnOp = *this;
SmallVector<AffineMap> indexingMaps = attnOp.getIndexingMapsArray();
if (indexingMaps.size() != getOperation()->getNumOperands()) {
return attnOp->emitOpError("expected an indexing map for each operand");
}
// Check if indexing maps can represent attention.
FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
if (failed(maybeOpInfo)) {
return attnOp->emitOpError("failed to verify op's indexing maps");
}
// Check shape compatibility based on indexing maps.
SmallVector<int64_t> shape(getIterationDomainRank());
SmallVector<bool> foundDims(getIterationDomainRank(), false);
auto checkShape = [&shape, &foundDims,
&attnOp](StringRef operandName, ArrayRef<int64_t> valShape,
AffineMap indexingMap) -> LogicalResult {
if (indexingMap.getNumResults() != valShape.size()) {
return attnOp->emitError("Rank Mismatch for ")
<< operandName << ". Expected: " << indexingMap.getNumResults()
<< " Got: " << valShape.size();
}
for (auto [i, dimExpr] : llvm::enumerate(indexingMap.getResults())) {
AffineDimExpr dim = cast<AffineDimExpr>(dimExpr);
int64_t pos = dim.getPosition();
if (ShapedType::isDynamic(valShape[i])) {
continue;
}
if (!foundDims[pos]) {
foundDims[pos] = true;
shape[pos] = valShape[i];
}
if (shape[pos] != valShape[i]) {
return attnOp->emitError("Shape Mismatch for ")
<< operandName << ". Expected: " << shape[pos]
<< " Got: " << valShape[i];
}
}
return success();
};
if (failed(checkShape("Query", getQuery().getType().getShape(),
getQueryMap())) ||
failed(checkShape("Key", getKey().getType().getShape(), getKeyMap())) ||
failed(checkShape("Value", getValue().getType().getShape(),
getValueMap())) ||
failed(checkShape("Output", getOutput().getType().getShape(),
getOutputMap())) ||
failed(checkShape("Max", getMax().getType().getShape(), getMaxMap())) ||
failed(checkShape("Sum", getSum().getType().getShape(), getSumMap()))) {
return failure();
}
// Additional check case if mask exists
if (auto maskMap = getMaskMap()) {
if (failed(checkShape("Mask", getMask().getType().getShape(), *maskMap)))
return failure();
}
int expectedSymbols = getQueryMap().getNumInputs();
auto checkDomain =
[&attnOp, &expectedSymbols](StringRef operandName,
AffineMap indexingMap) -> LogicalResult {
if (expectedSymbols != indexingMap.getNumInputs()) {
return attnOp->emitError("Mismatched map domain for ")
<< operandName << ". Expected: " << expectedSymbols
<< " Got: " << indexingMap.getNumInputs();
}
return success();
};
if (failed(checkDomain("Query", getQueryMap())) ||
failed(checkDomain("Key", getKeyMap())) ||
failed(checkDomain("Value", getValueMap())) ||
failed(checkDomain("Scale", getScaleMap())) ||
failed(checkDomain("Output", getOutputMap())) ||
failed(checkDomain("Max", getMaxMap())) ||
failed(checkDomain("Sum", getSumMap()))) {
return failure();
}
// Additional check case if mask exists
if (auto maskMap = getMaskMap()) {
if (failed(checkDomain("Mask", *maskMap)))
return failure();
}
Block &block = attnOp.getRegion().front();
auto blockTys = block.getArgumentTypes();
if (blockTys.size() != 1) {
return attnOp->emitOpError("expects single block argument for score");
}
if (!isa<FloatType>(blockTys[0])) {
return attnOp->emitOpError("block argument 0 should be float");
}
auto yieldOp = dyn_cast<IREE::LinalgExt::YieldOp>(block.getTerminator());
if (!yieldOp) {
return attnOp->emitOpError("expected linalg_ext.yield");
}
if (yieldOp->getNumOperands() != 1) {
return attnOp->emitOpError("expected region to yield a single value");
}
return success();
}
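// Operands are ordered {query, key, value[, mask], scale, output, max, sum};
// the three DPS inits {output, max, sum} follow the 4 inputs, or 5 when a
// mask is present.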
MutableOperandRange OnlineAttentionOp::getDpsInitsMutable() {
return MutableOperandRange(*this, /*numInputs=*/getMask() ? 5 : 4,
/*numInits=*/3);
}
LogicalResult OnlineAttentionOp::reifyResultShapes(
OpBuilder &b, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
SmallVector<AffineMap> OnlineAttentionOp::getIndexingMapsArray() {
return SmallVector<AffineMap>(
getIndexingMaps().getAsValueRange<AffineMapAttr>());
}
//===----------------------------------------------------------------------===//
// Im2colOp
//===----------------------------------------------------------------------===//
/// Return all static and dynamic kernel_size as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedKernelSize() {
return LinalgExt::getMixedValues(getContext(), getStaticKernelSize(),
getKernelSize());
}
/// Return all static and dynamic k_offset as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedKOffset() {
return LinalgExt::getMixedValues(getContext(), getStaticKOffset(),
getKOffset());
}
/// Return all static and dynamic m_offset as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedMOffset() {
return LinalgExt::getMixedValues(getContext(), getStaticMOffset(),
getMOffset());
}
/// Return all static and dynamic k_strides as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedKStrides() {
return LinalgExt::getMixedValues(getContext(), getStaticKStrides(),
getKStrides());
}
/// Return all static and dynamic m_strides as OpFoldResults.
SmallVector<OpFoldResult> Im2colOp::getMixedMStrides() {
return LinalgExt::getMixedValues(getContext(), getStaticMStrides(),
getMStrides());
}
void Im2colOp::setMixedKOffset(SmallVector<OpFoldResult> kOffset) {
SmallVector<int64_t> staticKOffset;
SmallVector<Value> dynamicKOffset;
dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset);
setStaticKOffset(staticKOffset);
getKOffsetMutable().assign(dynamicKOffset);
}
void Im2colOp::setMixedMOffset(SmallVector<OpFoldResult> mOffset) {
SmallVector<int64_t> staticMOffset;
SmallVector<Value> dynamicMOffset;
dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset);
setStaticMOffset(staticMOffset);
getMOffsetMutable().assign(dynamicMOffset);
}
void Im2colOp::setMixedKStrides(SmallVector<OpFoldResult> kStrides) {
SmallVector<int64_t> staticKStrides;
SmallVector<Value> dynamicKStrides;
dispatchIndexOpFoldResults(kStrides, dynamicKStrides, staticKStrides);
setStaticKStrides(staticKStrides);
getKStridesMutable().assign(dynamicKStrides);
}
void Im2colOp::setMixedMStrides(SmallVector<OpFoldResult> mStrides) {
SmallVector<int64_t> staticMStrides;
SmallVector<Value> dynamicMStrides;
dispatchIndexOpFoldResults(mStrides, dynamicMStrides, staticMStrides);
setStaticMStrides(staticMStrides);
getMStridesMutable().assign(dynamicMStrides);
}
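// The output dimensions are laid out as [batch..., m..., k...]; the helpers
// below return the positions of each group.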
SmallVector<int64_t> Im2colOp::getBatchOutputDims() {
return llvm::to_vector(llvm::seq<int64_t>(0, getBatchPos().size()));
}
SmallVector<int64_t> Im2colOp::getMOutputDims() {
int64_t begin = getBatchPos().size();
int64_t end = begin + getMixedMOffset().size();
return llvm::to_vector(llvm::seq<int64_t>(begin, end));
}
SmallVector<int64_t> Im2colOp::getKOutputDims() {
int64_t begin = getBatchPos().size() + getMixedMOffset().size();
int64_t end = begin + getMixedKOffset().size();
return llvm::to_vector(llvm::seq<int64_t>(begin, end));
}
/// Custom builder methods for im2col op.
void Im2colOp::build(
OpBuilder &builder, OperationState &state, Value input, Value output,
ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations,
ArrayRef<OpFoldResult> kernelSize, ArrayRef<OpFoldResult> mOffset,
ArrayRef<OpFoldResult> mStrides, ArrayRef<OpFoldResult> kOffset,
ArrayRef<OpFoldResult> kStrides, ArrayRef<int64_t> batchPos,
ArrayRef<int64_t> mPos, ArrayRef<int64_t> kPos) {
assert(strides.size() == kernelSize.size() &&
dilations.size() == kernelSize.size() &&
mPos.size() == kernelSize.size() &&
"strides, dilations, m_pos, and kernel expected to be the same rank");
SmallVector<int64_t> staticKernelSize, staticMOffset, staticKOffset,
staticMStrides, staticKStrides;
SmallVector<Value> dynamicKernelSize, dynamicMOffset, dynamicKOffset,
dynamicMStrides, dynamicKStrides;
dispatchIndexOpFoldResults(kernelSize, dynamicKernelSize, staticKernelSize);
dispatchIndexOpFoldResults(mOffset, dynamicMOffset, staticMOffset);
dispatchIndexOpFoldResults(mStrides, dynamicMStrides, staticMStrides);
dispatchIndexOpFoldResults(kOffset, dynamicKOffset, staticKOffset);
dispatchIndexOpFoldResults(kStrides, dynamicKStrides, staticKStrides);
SmallVector<Type> resultType;
auto outputType = output.getType();
if (isa<RankedTensorType>(outputType)) {
resultType.push_back(outputType);
}
build(builder, state, resultType, input, output,
builder.getDenseI64ArrayAttr(strides),
builder.getDenseI64ArrayAttr(dilations), dynamicKernelSize,
builder.getDenseI64ArrayAttr(staticKernelSize), dynamicMOffset,
builder.getDenseI64ArrayAttr(staticMOffset), dynamicMStrides,
builder.getDenseI64ArrayAttr(staticMStrides), dynamicKOffset,
builder.getDenseI64ArrayAttr(staticKOffset), dynamicKStrides,
builder.getDenseI64ArrayAttr(staticKStrides),
builder.getDenseI64ArrayAttr(batchPos),
builder.getDenseI64ArrayAttr(mPos), builder.getDenseI64ArrayAttr(kPos));
}
LogicalResult Im2colOp::verify() {
Operation *op = getOperation();
if (llvm::count_if(getDpsInputs(), [](Value v) {
return isa<ShapedType>(v.getType());
}) != 1) {
return op->emitOpError("expected only one ShapedType operand");
}
if (getNumDpsInits() != 1) {
return op->emitOpError("expected one output operand");
}
// Verify offsets and strides.
SmallVector<OpFoldResult> kOffset = getMixedKOffset();
SmallVector<OpFoldResult> mOffset = getMixedMOffset();
SmallVector<OpFoldResult> kStrides = getMixedKStrides();
SmallVector<OpFoldResult> mStrides = getMixedMStrides();
if (kOffset.size() < 1) {
return op->emitOpError("expected at least one k_offset");
}
if (mOffset.size() < 1) {
return op->emitOpError("expected at least one m_offset");
}
if (kOffset.size() != kStrides.size()) {
return op->emitOpError("expected the same size k_offset and k_strides");
}
if (mOffset.size() != mStrides.size()) {
return op->emitOpError("expected the same size m_offset and m_strides");
}
std::optional<int64_t> constInnerKStrides =
getConstantIntValue(kStrides.back());
if (!constInnerKStrides.has_value() || constInnerKStrides.value() != 1) {
return op->emitOpError("expected inner k_strides to be 1");
}
std::optional<int64_t> constInnerMStrides =
getConstantIntValue(mStrides.back());
if (!constInnerMStrides.has_value() || constInnerMStrides.value() != 1) {
return op->emitOpError("expected inner m_strides to be 1");
}
// Verify operand ranks and dim position sizes.
auto inputType = getInputType();
unsigned inputRank = inputType.getRank();
ArrayRef<int64_t> batchPos = getBatchPos();
ArrayRef<int64_t> mPos = getMPos();
ArrayRef<int64_t> kPos = getKPos();
if (inputRank != batchPos.size() + mPos.size() + kPos.size()) {
return op->emitOpError(
"expected input rank to be the sum of batch, m, and k ranks");
}
auto outputType = getOutputType();
unsigned outputRank = outputType.getRank();
if (outputRank != batchPos.size() + kOffset.size() + mOffset.size()) {
return op->emitOpError("expected output rank to be the sum of "
"batch_pos, k_offset, and m_offset ranks");
}
// Verify convolution metadata.
ArrayRef<int64_t> strides = getStrides();
ArrayRef<int64_t> dilations = getDilations();
SmallVector<OpFoldResult> kernelSize = getMixedKernelSize();
if (kernelSize.size() != mPos.size()) {
return op->emitOpError(
"expected kernel rank to be equal to the m_pos rank");
}
if (strides.size() != kernelSize.size()) {
return op->emitOpError(
"expected strides rank to be equal to the kernel rank");
}
if (dilations.size() != kernelSize.size()) {
return op->emitOpError(
"expected dilations rank to be equal to the kernel rank");
}
// Verify input and output shapes.
ArrayRef<int64_t> inputShape = inputType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
// When the op is tiled, the m and k dimensions of the output are tiled, but
// they are not tiled in the input, so we cannot verify the output size of
// these dimensions. Only verify the shape of the batch dimensions.
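// For example (illustrative), with batch_pos = [0] and an input of shape
// 2x34x34x640, the output's leading dimension must be compatible with 2.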
SmallVector<int64_t> expectedOutputShape(outputShape);
for (auto [idx, pos] : llvm::enumerate(batchPos)) {
expectedOutputShape[idx] = inputShape[pos];
}
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
return success();
}
LogicalResult Im2colOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}
LogicalResult
Im2colOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}
//===---------------------------------------------------------------------===//
// Custom Op
//===---------------------------------------------------------------------===//
unsigned CustomOp::getNumLoops() { return getIteratorTypesAttr().size(); }
int64_t CustomOp::getRank(Value v) {
Type type = v.getType();
if (type.isIntOrIndexOrFloat()) {
return 0;
}
return cast<RankedTensorType>(type).getRank();
}
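// Non-loop dimensions are encoded as symbols of the indexing maps, so the
// symbol count of any non-empty map gives their number. For example
// (illustrative), affine_map<(d0)[s0] -> (d0, s0)> describes an operand
// indexed by one loop dim and one non-loop dim.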
unsigned CustomOp::getNumNonLoopDimensions() {
for (auto map : getIndexingMaps().getAsValueRange<AffineMapAttr>()) {
if (map.isEmpty()) {
continue;
}
return map.getNumSymbols();
}
return 0;
}
LogicalResult CustomOp::verify() {
// All inputs/outputs must have indexing maps.
if (static_cast<int64_t>(getIndexingMapsAttr().size()) != getNumOperands()) {
return emitOpError("expected number of indexing maps (")
<< getIndexingMapsAttr().size()
<< ") to be same as the "
"number of input/output operands ("
<< getNumOperands() << ")";
}
// Check the form of the indexing maps.
std::optional<unsigned> numSymbolDims;
for (auto [index, indexingMapAttr, operand] :
llvm::enumerate(getIndexingMapsAttr(), getOperands())) {
auto indexingMap = cast<AffineMapAttr>(indexingMapAttr).getValue();
if (indexingMap.isEmpty()) {
continue;
}
// Domain must be consistent.
unsigned numLoops = getNumLoops();
if (indexingMap.getNumDims() != numLoops) {
return emitOpError("expected indexing_map #")
<< index << " to have " << numLoops
<< " dim(s) to match the number of loops or be zero";
}
// Check that number of symbols is consistent.
if (numSymbolDims) {
if (indexingMap.getNumSymbols() != numSymbolDims.value()) {
return emitOpError(
"inconsistent number of symbol dimensions in indexing_map #")
<< index << ", expected " << numSymbolDims.value()
<< " instead of " << indexingMap.getNumSymbols();
}
} else {
numSymbolDims = indexingMap.getNumSymbols();
}
// Range must match the rank of the operands.
int64_t rank = getRank(operand);
if (indexingMap.getNumResults() != rank) {
return emitOpError("expected operand rank(")
<< rank << ") to match the result rank of indexing map #" << index;
}
}
// Check that the number of basic block arguments matches the number of
// operands.
Block *body = getBody();
if (body->getNumArguments() != getNumOperands()) {
return emitOpError("expected as many basic block arguments (")
<< body->getNumArguments() << ") as the number of operands ("
<< getNumOperands() << ")";
}
// Check that type of the basic block argument matches the type of the
// operands.
for (auto [index, bbArg, operand] :
llvm::enumerate(body->getArguments(), getOperands())) {
Type operandType = operand.getType();
Type bbArgType = bbArg.getType();
if (operandType.isIntOrIndexOrFloat()) {
if (operandType != bbArgType) {
return emitOpError("for (scalar) operand #")
<< index
<< " expected corresponding basic block argument to be of the "
"same type";
}
continue;
}
auto operandTensorType = cast<RankedTensorType>(operandType);
auto bbArgTensorType = dyn_cast<RankedTensorType>(bbArgType);
if (!bbArgTensorType) {
return emitOpError("for (tensor) operand #")
<< index
<< " expected corresponding basic block argument to be tensor as "
"well";
}
// Check that the basic block arg has the same rank and element type, but
// with all shapes dynamic.
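// For example (illustrative), a tensor<10x20xf32> operand must correspond to
// a tensor<?x?xf32> block argument.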
auto expectedBBArgType = RankedTensorType::get(
SmallVector<int64_t>(operandTensorType.getRank(), ShapedType::kDynamic),
operandTensorType.getElementType(), operandTensorType.getEncoding());
if (bbArgTensorType != expectedBBArgType) {
return emitOpError("expected basic block argument corresponding to "
"(tensor) operand #")
<< index << " to be " << expectedBBArgType << " instead of "
<< bbArgTensorType;
}
}
// Check yield operation operand types.
auto yieldOp = cast<IREE::LinalgExt::YieldOp>(body->getTerminator());
if (yieldOp->getNumOperands() != getOutputs().size()) {
return emitOpError(
"expected as many yield operands as the number of `outs` operands");
}
for (auto [index, yieldVal, bbOperand] :
llvm::enumerate(yieldOp.getOperands(),
body->getArguments().take_back(getOutputs().size()))) {
if (yieldVal.getType() != bbOperand.getType()) {
return emitOpError("expected type of ")
<< index
<< "-th operand of yield to match the corresponding output basic "
"block argument";
}
}
return success();
}
/// Start `LinalgFusionInterface` implementation.
SmallVector<AffineMap> CustomOp::getIndexingMapsForOperands() {
return llvm::map_to_vector(
getIndexingMaps().getValue().take_front(getNumDpsInputs()),
[](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
}
SmallVector<AffineMap> CustomOp::getIndexingMapsForResults() {
return llvm::map_to_vector(
getIndexingMaps().getValue().take_back(getNumDpsInits()),
[](Attribute attr) { return cast<AffineMapAttr>(attr).getValue(); });
}
/// End `LinalgFusionInterface` implementation
/// Start `ReifyRankedShapedTypeOpInterface` implementation
LogicalResult
CustomOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
for (auto init : getOutputs()) {
SmallVector<OpFoldResult> sizes =
tensor::getMixedSizes(builder, getLoc(), init);
reifiedReturnShapes.emplace_back(std::move(sizes));
}
return success();
}
/// End `ReifyRankedShapedTypeOpInterface` implementation
//===---------------------------------------------------------------------===//
// IndexOp
//===---------------------------------------------------------------------===//
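// An index op is only valid when nested inside a custom_op, where it returns
// the iteration index of one of the enclosing loops; e.g. (illustrative
// syntax) `%i = iree_linalg_ext.index 0 : index` inside a custom_op with at
// least one loop.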
LogicalResult IREE::LinalgExt::IndexOp::verify() {
auto customOp = dyn_cast<CustomOp>(getOperation()->getParentOp());
if (!customOp) {
return emitOpError("expected parent op to be `iree_linalg_ext.custom_op`");
}
if (customOp.getNumLoops() <= getDim()) {
return emitOpError("expected dim (")
<< getDim() << ") to be lower than the number of loops ("
<< customOp.getNumLoops() << ") of the enclosing CustomOp";
}
return success();
}
//===---------------------------------------------------------------------===//
// End operation definitions
//===---------------------------------------------------------------------===//
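// Each expansion defines a `getEffects` method for the given op, delegating
// to `getEffectsImpl` with the op's DPS input operands and mutable inits.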
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
&effects) { \
getEffectsImpl(effects, getDpsInputOperands(), getDpsInitsMutable()); \
}
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(FftOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(TopkOp)
DEFINE_OP_GET_EFFECTS(PackOp)
DEFINE_OP_GET_EFFECTS(UnPackOp)
DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)
DEFINE_OP_GET_EFFECTS(OnlineAttentionOp)
DEFINE_OP_GET_EFFECTS(Im2colOp)
DEFINE_OP_GET_EFFECTS(CustomOp)
} // namespace mlir::iree_compiler::IREE::LinalgExt
// clang-format off
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" // IWYU pragma: keep
// clang-format on