| // Copyright 2021 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h" |
| |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arithmetic/Utils/Utils.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/Dialect/Math/IR/Math.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/FunctionImplementation.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Support/MathExtras.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallSet.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/SMLoc.h" |
| |
| using namespace mlir; |
| using namespace mlir::iree_compiler::IREE::LinalgExt; |
| namespace IREE = mlir::iree_compiler::IREE; |
| |
| //===----------------------------------------------------------------------===// |
| // Utils. |
| //===----------------------------------------------------------------------===// |
| |
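/// Populates `effects` with the memory effects of the op: results are
/// allocated, input buffers are read, and output buffers are both read and
/// written.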
| static void getEffectsImpl( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects, |
| ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) { |
| for (Value value : results) { |
| effects.emplace_back(MemoryEffects::Allocate::get(), value, |
| SideEffects::DefaultResource::get()); |
| } |
| for (Value value : inputBuffers) { |
| effects.emplace_back(MemoryEffects::Read::get(), value, |
| SideEffects::DefaultResource::get()); |
| } |
| for (Value value : outputBuffers) { |
| effects.emplace_back(MemoryEffects::Read::get(), value, |
| SideEffects::DefaultResource::get()); |
| effects.emplace_back(MemoryEffects::Write::get(), value, |
| SideEffects::DefaultResource::get()); |
| } |
| } |
| |
| /// Returns a memref.subview or a tensor.extract_slice based on the type of the |
| /// `source`. |
| static Value getSlice(OpBuilder &b, Location loc, Value source, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, |
| ArrayRef<OpFoldResult> strides) { |
| return TypeSwitch<Type, Value>(source.getType()) |
| .Case<RankedTensorType>([&](RankedTensorType t) -> Value { |
| return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes, |
| strides); |
| }) |
| .Case<MemRefType>([&](MemRefType type) -> Value { |
| return b.create<memref::SubViewOp>(loc, source, offsets, sizes, |
| strides); |
| }) |
| .Default([&](Type t) { return nullptr; }); |
| } |
| |
/// Returns true if both dimensions are static and provably not equal, i.e.,
/// the shapes are known to mismatch at this dimension.
| static bool isShapedTypeDimEqual(int64_t lhs, int64_t rhs) { |
| return lhs != ShapedType::kDynamicSize && rhs != ShapedType::kDynamicSize && |
| lhs != rhs; |
| } |
| |
| Value IREE::LinalgExt::getDimValue(OpBuilder &builder, Location loc, Value v, |
| int64_t dim) { |
| return TypeSwitch<Type, Value>(v.getType()) |
| .Case<RankedTensorType>([&](RankedTensorType t) -> Value { |
| return builder.create<tensor::DimOp>(loc, v, dim); |
| }) |
| .Case<MemRefType>([&](MemRefType t) -> Value { |
| return builder.create<memref::DimOp>(loc, v, dim); |
| }) |
| .Default([&](Type t) { return Value(); }); |
| } |
| |
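/// Returns the size of dim `dim` of `v` as an IntegerAttr when it is
/// statically known, and as a Value otherwise.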
| OpFoldResult IREE::LinalgExt::getDim(OpBuilder &builder, Location loc, Value v, |
| int64_t dim) { |
| auto t = v.getType().cast<ShapedType>(); |
| if (t.isDynamicDim(dim)) { |
| return getDimValue(builder, loc, v, dim); |
| } |
| return builder.getI64IntegerAttr(t.getDimSize(dim)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ScatterOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ScatterOp::verify() { |
| Operation *op = getOperation(); |
| if (getInputs().size() != 2) { |
| return op->emitOpError("expected two input operands"); |
| } |
| if (getOutputs().size() != 1) { |
| return op->emitOpError("expected one output operand"); |
| } |
| auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { |
| return t1.getShape()[dim] == t2.getShape()[dim]; |
| }; |
| |
| auto indicesType = getIndicesType(); |
| if (indicesType.getRank() != 2 || |
| !indicesType.getElementType().isInteger(32)) { |
    return op->emitOpError(
        "expected indices to be of rank 2 with i32 element type");
| } |
| auto indexDepth = getIndexDepth(); |
| if (indexDepth == ShapedType::kDynamicSize) { |
    return op->emitOpError("expected index depth to be static");
| } |
| |
  // The first dimension of the indices should match the first dimension of the
  // update value; both give the number of updates.
| auto updateType = getUpdateType(); |
| if (updateType.getRank() < 1) { |
| return op->emitOpError("expected update value to be at least rank 1"); |
| } |
| if (!checkDimensionsMatch(indicesType, updateType, 0)) { |
| return op->emitOpError( |
| "mismatch in shape of indices and update value at dim#0"); |
| } |
| auto originalType = getOriginalType(); |
| if (updateType.getRank() - 1 > originalType.getRank()) { |
| return op->emitOpError( |
| "update value rank exceeds the rank of the original value"); |
| } |
| |
| // indexDepth + update dims should cover the original dims. The first dim of |
| // update is the number of updates. |
| if (originalType.getRank() > indexDepth + updateType.getRank() - 1) { |
    return op->emitOpError(
        "index depth and update value do not cover rank of original value");
| } |
| |
  // Validate that the non-indexed update dims cover the full slice size of the
  // original tensor.
| int64_t fullSliceDims = originalType.getRank() - indexDepth; |
| for (auto it : |
| llvm::zip(llvm::seq<unsigned>(indexDepth, originalType.getRank()), |
| llvm::seq<unsigned>(updateType.getRank() - fullSliceDims, |
| updateType.getRank()))) { |
| int64_t originalDim = std::get<0>(it); |
| int64_t updateDim = std::get<1>(it); |
| if (updateType.getDimSize(updateDim) != |
| originalType.getDimSize(originalDim)) { |
| return op->emitOpError("mismatch in shape of update value dim#") |
| << updateDim << " and original value at dim#" << originalDim; |
| } |
| } |
| |
  // Check that the remaining indexed update dims do not exceed the
  // corresponding original dims.
| int64_t insertDims = originalType.getRank() - updateType.getRank() + 1; |
| for (auto it : llvm::zip( |
| llvm::seq<unsigned>(insertDims, indexDepth), |
| llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) { |
| int64_t originalDim = std::get<0>(it); |
| int64_t updateDim = std::get<1>(it); |
| if (updateType.getDimSize(updateDim) > |
| originalType.getDimSize(originalDim)) { |
| return op->emitOpError("indexed shape of update value dim#") |
| << updateDim << " exceeds original value at dim#" << originalDim |
| << " " << updateType.getDimSize(updateDim) << " " |
| << originalType.getDimSize(originalDim); |
| } |
| } |
| |
| Region ®ion = this->getRegion(); |
| Block *body = ®ion.front(); |
| if (body->getNumArguments() != 2) { |
| return op->emitOpError("expected region to have two arguments"); |
| } |
| Type arg0Type = body->getArgument(0).getType(); |
| Type arg1Type = body->getArgument(1).getType(); |
| if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { |
| return op->emitOpError( |
| "expected region to have scalar argument of integer or float types"); |
| } |
| if (arg0Type != updateType.getElementType()) { |
| return op->emitOpError("mismatch in argument 0 of region ") |
| << arg0Type << " and element type of update value " |
| << updateType.getElementType(); |
| } |
| if (arg1Type != originalType.getElementType()) { |
| return op->emitOpError("mismatch in argument 1 of region ") |
| << arg1Type << " and element type of original value " |
| << originalType.getElementType(); |
| } |
| if (arg0Type != arg1Type) { |
| return op->emitOpError("mismatch in region argument types ") |
| << arg0Type << " and " << arg1Type; |
| } |
| auto yieldOp = cast<IREE::LinalgExt::YieldOp>(body->getTerminator()); |
| if (yieldOp->getNumOperands() != 1) { |
| return yieldOp.emitOpError("expected region to yield a single value"); |
| } |
| auto yieldedType = yieldOp->getOperand(0).getType(); |
| if (yieldedType != arg0Type) { |
| return yieldOp.emitOpError("mismatch in type of yielded value ") |
| << yieldedType << " and argument of the region " << arg0Type; |
| } |
| return success(); |
| } |
| |
| SmallVector<utils::IteratorType> ScatterOp::getLoopIteratorTypes() { |
| SmallVector<utils::IteratorType> iteratorTypes(getUpdateType().getRank(), |
| utils::IteratorType::parallel); |
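  // If the indices are not known to be unique, multiple updates may touch the
  // same element of the original, so the loop over updates is a reduction.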
| if (!getUniqueIndices()) { |
| iteratorTypes[0] = utils::IteratorType::reduction; |
| } |
| return iteratorTypes; |
| } |
| |
| SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) { |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| SmallVector<Range> ranges; |
| for (auto dim : llvm::seq<int64_t>(0, getUpdateType().getRank())) { |
| Value ub = getDimValue(builder, loc, updates(), dim); |
| ranges.emplace_back(Range{zero, ub, one}); |
| } |
| return ranges; |
| } |
| |
| SmallVector<Operation *> |
| ScatterOp::getTiledImplementation(OpBuilder &builder, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes) { |
| assert(offsets.size() >= 1 && sizes.size() >= 1); |
| Location loc = getLoc(); |
| auto zeroAttr = builder.getI64IntegerAttr(0); |
| auto oneAttr = builder.getI64IntegerAttr(1); |
| |
| // Slice of the updates. |
| auto updateRank = getUpdateType().getRank(); |
| SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr); |
| Value tiledUpdate = |
| getSlice(builder, loc, updates(), offsets, sizes, updateStrides); |
| assert(tiledUpdate && "failed to get slice of update"); |
| |
| // Slice of indices. |
| auto indicesRank = getIndicesType().getRank(); |
| SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr); |
| SmallVector<OpFoldResult> indicesSizes(indicesRank); |
| indicesOffsets[0] = offsets[0]; |
| indicesSizes[0] = sizes[0]; |
| for (auto dim : llvm::seq<int64_t>(1, indicesRank)) { |
| indicesSizes[dim] = getDim(builder, loc, indices(), dim); |
| } |
| SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr); |
| Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets, |
| indicesSizes, indicesStrides); |
| assert(tiledIndices && "failed to get slice of indices"); |
| |
| // Slice of the original. |
| SmallVector<OpFoldResult> originalOffsets, originalSizes; |
| if (failed(getResultTilePosition(builder, 0, offsets, sizes, originalOffsets, |
| originalSizes))) { |
| return {}; |
| } |
| auto originalRank = getOriginalType().getRank(); |
| SmallVector<OpFoldResult> originalStrides(originalRank, oneAttr); |
| Value tiledOriginal = getSlice(builder, loc, original(), originalOffsets, |
| originalSizes, originalStrides); |
| assert(tiledOriginal && "failed to get slice of original tensor"); |
| |
| SmallVector<Type> resultTypes; |
| if (getNumResults()) { |
| resultTypes.push_back(tiledOriginal.getType()); |
| } |
| Operation *tiledScatterOp = |
| cast<LinalgExtOp>(getOperation()) |
| .clone(builder, loc, resultTypes, |
| ValueRange{tiledUpdate, tiledIndices, tiledOriginal}); |
| return {tiledScatterOp}; |
| } |
| |
| LogicalResult ScatterOp::getResultTilePosition( |
| OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
| SmallVector<OpFoldResult> &resultSizes) { |
| auto zeroAttr = builder.getI64IntegerAttr(0); |
| // Slice of the original. |
| auto originalRank = getOriginalType().getRank(); |
| resultOffsets.resize(originalRank, zeroAttr); |
| resultSizes.resize(originalRank); |
| |
| auto updateRank = getUpdateType().getRank(); |
| Location loc = getLoc(); |
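  // The leading dims of the original are addressed through the scatter
  // indices and cannot be tiled; take them whole. The trailing dims mirror
  // the tile taken from the update.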
| for (auto dim : llvm::seq<int64_t>(0, originalRank - updateRank + 1)) { |
| resultSizes[dim] = getDim(builder, loc, original(), dim); |
| } |
| for (auto dim : |
| llvm::seq<int64_t>(originalRank - updateRank + 1, originalRank)) { |
| resultOffsets[dim] = offsets[dim - (originalRank - updateRank)]; |
| resultSizes[dim] = sizes[dim - (originalRank - updateRank)]; |
| } |
| return success(); |
| } |
| |
| LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, |
| Location loc, |
| ValueRange ivs) { |
| auto indexDepth = getIndexDepth(); |
| Value update = b.create<memref::LoadOp>(loc, updates(), ivs); |
| SmallVector<Value> starts; |
| SmallVector<Value> loadIndices; |
| loadIndices.push_back(ivs.front()); |
| loadIndices.push_back(Value()); |
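  // `loadIndices` addresses `indices()` as [update number, coordinate]; the
  // coordinate slot is filled in for each index-depth iteration below.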
| |
| // Populate with empty values. |
| auto originalTy = original().getType().cast<ShapedType>(); |
| starts.resize(originalTy.getRank(), Value()); |
| auto updateIvs = ivs.drop_front(1); |
| |
| int64_t offset = starts.size() - updateIvs.size(); |
| for (auto it : llvm::enumerate(updateIvs)) { |
| starts[it.index() + offset] = it.value(); |
| } |
| |
| for (auto i : llvm::seq<unsigned>(0, indexDepth)) { |
| loadIndices.back() = b.create<arith::ConstantIndexOp>(loc, i); |
| Value idx = b.create<memref::LoadOp>(loc, indices(), loadIndices); |
| Value cast = b.create<arith::IndexCastOp>(loc, b.getIndexType(), idx); |
| |
| if (starts[i]) |
| cast = b.create<arith::AddIOp>(loc, cast, starts[i]); |
| starts[i] = cast; |
| } |
| |
| Value init = b.create<memref::LoadOp>(loc, original(), starts); |
| |
| BlockAndValueMapping bvm; |
| Block &block = getRegion().front(); |
| bvm.map(block.getArgument(0), update); |
| bvm.map(block.getArgument(1), init); |
| for (auto &blockOp : block.without_terminator()) { |
| b.clone(blockOp, bvm); |
| } |
| // The last op is linalg_ext.yield op. Store the operand to |
| // destination. |
| b.create<memref::StoreOp>( |
| loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)), |
| original(), starts); |
| return success(); |
| } |
| |
| LogicalResult |
| ScatterOp::reifyResultShapes(OpBuilder &b, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| return cast<LinalgExtOp>(getOperation()) |
| .reifyResultShapes(b, reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SortOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult SortOp::verify() { |
| Operation *op = getOperation(); |
| if (getNumInputs()) { |
| return op->emitOpError("does not expect to take any inputs"); |
| } |
| if (getNumOutputs() == 0) { |
| return op->emitOpError("expected at least one `outs` operand"); |
| } |
| |
| Block &block = getRegion().front(); |
| size_t numOutputs = getNumOutputs(); |
| if (block.getNumArguments() != 2 * numOutputs) { |
| return op->emitOpError("region block should have ") |
| << 2 * numOutputs << " arguments"; |
| } |
| |
| int64_t rank = getOperandRank(); |
| int sortDim = getDimension(); |
| if (sortDim < 0 || sortDim >= rank) { |
    return op->emitOpError("dimension must be within [0, ") << rank << ")";
| } |
| |
| ArrayRef<int64_t> shape = getOperandShape(); |
| for (auto indexedOperand : llvm::enumerate(getOutputs())) { |
| int index = indexedOperand.index(); |
| auto operandType = getOperandType(index); |
| if (operandType.getRank() != rank) { |
| return op->emitOpError("expected operand ") |
| << index << " to be rank " << rank << ", same as other operands"; |
| } |
| if (operandType.getShape() != shape) { |
| return op->emitOpError("expected operand ") |
| << index << " to have same shape as other operands"; |
| } |
| Type elemType = operandType.getElementType(); |
| for (int i : {2 * index, 2 * index + 1}) { |
| Type argType = block.getArgument(i).getType(); |
| if (argType != elemType) { |
| return op->emitOpError("region block argument #") |
| << i << " should be of type " << elemType << " but got " |
| << argType; |
| } |
| } |
| } |
| |
| auto yieldOp = cast<YieldOp>(block.getTerminator()); |
| if (yieldOp.getNumOperands() != 1) { |
| return op->emitOpError("should yield exactly one operand"); |
| } |
| auto ty = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>(); |
| if (!ty || ty.getWidth() != 1) { |
| return op->emitOpError("should yield i1 type"); |
| } |
| |
| return success(); |
| } |
| |
| SmallVector<utils::IteratorType> SortOp::getLoopIteratorTypes() { |
| // All loops except the dimension to sort along are parallel. |
| SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(), |
| utils::IteratorType::parallel); |
| iteratorTypes[getDimension()] = utils::IteratorType::reduction; |
| return iteratorTypes; |
| } |
| |
| SmallVector<Range> SortOp::getIterationDomain(OpBuilder &builder) { |
| int64_t operandRank = getOperandRank(); |
| SmallVector<Range> loopBounds(operandRank); |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| Value source = operand(0); |
| for (auto dim : llvm::seq<int64_t>(0, operandRank)) { |
| loopBounds[dim].offset = zero; |
| loopBounds[dim].size = getDimValue(builder, loc, source, dim); |
| loopBounds[dim].stride = one; |
| } |
| return loopBounds; |
| } |
| |
| SmallVector<Operation *> |
| SortOp::getTiledImplementation(OpBuilder &builder, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes) { |
| int64_t rank = getOperandRank(); |
| assert(offsets.size() == static_cast<size_t>(rank) && |
| sizes.size() == static_cast<size_t>(rank)); |
| auto oneAttr = builder.getI64IntegerAttr(1); |
| SmallVector<OpFoldResult> strides(rank, oneAttr); |
| Location loc = getLoc(); |
| SmallVector<Value> tiledOperands(getOutputs().size()); |
| for (auto en : llvm::enumerate(getOutputs())) { |
| tiledOperands[en.index()] = |
| getSlice(builder, getLoc(), en.value(), offsets, sizes, strides); |
| assert(tiledOperands[en.index()] && "failed to get slice of operand"); |
| } |
| SmallVector<Type, 4> resultTypes; |
| if (getNumResults()) { |
| resultTypes = llvm::to_vector<4>( |
| llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); })); |
| } |
| Operation *tiledSortOp = cast<LinalgExtOp>(getOperation()) |
| .clone(builder, loc, resultTypes, tiledOperands); |
| return {tiledSortOp}; |
| } |
| |
| LogicalResult SortOp::getResultTilePosition( |
| OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
| SmallVector<OpFoldResult> &resultSizes) { |
| resultOffsets = llvm::to_vector(offsets); |
| resultSizes = llvm::to_vector(sizes); |
| return success(); |
| } |
| |
| LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc, |
| ValueRange ivs) { |
| auto sortDim = getDimension(); |
| SmallVector<Value> indices, sortBlkArgs; |
| indices.append(ivs.begin(), ivs.end()); |
| // Bubble sort innermost loop. |
| Value zero = b.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = b.create<arith::ConstantIndexOp>(loc, 1); |
| Value ub; |
| if (getOperandType(0).isDynamicDim(sortDim)) { |
| ub = b.create<memref::DimOp>(loc, operand(0), sortDim); |
| } else { |
| ub = b.create<arith::ConstantIndexOp>( |
| loc, getOperandType(0).getDimSize(sortDim)); |
| } |
| ub = b.create<arith::SubIOp>(loc, ub, one); |
| auto scfFor = b.create<scf::ForOp>( |
| loc, zero, ub, one, ValueRange{}, |
| [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) { |
| SmallVector<Value> indices(ivs); |
| Value ivPlusOne = b.create<arith::AddIOp>(loc, iv, one); |
| for (auto output : getOutputOperands()) { |
| indices[sortDim] = iv; |
| sortBlkArgs.push_back( |
| b.create<memref::LoadOp>(loc, output->get(), indices)); |
| indices[sortDim] = ivPlusOne; |
| sortBlkArgs.push_back( |
| b.create<memref::LoadOp>(loc, output->get(), indices)); |
| } |
| }); |
| |
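  // Inline the comparator region into the loop body, mapping its block
  // arguments to the loaded element pairs.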
| auto &srcBlock = getRegion().front(); |
| Region ®ion = scfFor.getRegion(); |
| BlockAndValueMapping bvm; |
| { |
| OpBuilder::InsertionGuard guard(b); |
| auto &block = region.front(); |
| b.setInsertionPointToEnd(&block); |
| for (auto it : llvm::zip(srcBlock.getArguments(), sortBlkArgs)) { |
| bvm.map(std::get<0>(it), std::get<1>(it)); |
| } |
| for (auto &blockOp : srcBlock.without_terminator()) { |
| b.clone(blockOp, bvm); |
| } |
| } |
| Value cond = bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)); |
| |
| OpBuilder::InsertionGuard g(b); |
| b.setInsertionPointToEnd(®ion.front()); |
| b.create<scf::IfOp>( |
| loc, TypeRange{}, cond, |
| [&](OpBuilder &b, Location loc) { |
| // Do not swap the pairs if true. |
| b.create<scf::YieldOp>(loc); |
| }, |
| [&](OpBuilder &b, Location loc) { |
| // Swap the pairs if false. |
| SmallVector<Value> indices(ivs.begin(), ivs.end()); |
| Value ivPlusOne = |
| b.create<arith::AddIOp>(loc, scfFor.getInductionVar(), one); |
| for (int i = 0, e = getNumOutputs(); i < e; ++i) { |
| Value v1 = sortBlkArgs[i * 2]; |
| Value v2 = sortBlkArgs[i * 2 + 1]; |
| indices[sortDim] = scfFor.getInductionVar(); |
| b.create<memref::StoreOp>(loc, v2, getOutputOperand(i)->get(), |
| indices); |
| indices[sortDim] = ivPlusOne; |
| b.create<memref::StoreOp>(loc, v1, getOutputOperand(i)->get(), |
| indices); |
| } |
| b.create<scf::YieldOp>(loc); |
| }); |
| b.create<scf::YieldOp>(loc); |
| return success(); |
| } |
| |
| LogicalResult |
| SortOp::reifyResultShapes(OpBuilder &b, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| return cast<LinalgExtOp>(getOperation()) |
| .reifyResultShapes(b, reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // FftOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult FftOp::verify() { |
| Operation *op = getOperation(); |
| auto length = getFftLength(); |
  // After tiling, the length can be dynamic: subview/subtensor does not infer
  // the type correctly on `(1 << x)` cases.
| if (length == ShapedType::kDynamicSize) |
| return success(); |
| if (length & (length - 1)) { |
| return op->emitOpError("only powers of 2 are handled currently"); |
| } |
| if (!getNumInputs() || !isScalar(getInputOperand(0))) { |
| return op->emitOpError("expected to carry `stage` input"); |
| } |
| if (getNumInputs() != 1) { |
| if (getNumInputs() != 3 || isScalar(getInputOperand(1)) || |
| isScalar(getInputOperand(2))) { |
| return op->emitOpError("expected to carry real and imag coeff inputs"); |
| } |
| } |
| if (getNumOutputs() != 2) { |
| return op->emitOpError( |
| "expected outputs to be real and imag tensor/memref"); |
| } |
| return success(); |
| } |
| |
| SmallVector<utils::IteratorType> FftOp::getLoopIteratorTypes() { |
  // There are `rank-1` outer loops. The fft itself has one loop for each
  // stage, which handles the merge step -- taking two half-size tensors and
  // merging them into one tensor.
| SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(), |
| utils::IteratorType::parallel); |
| return iteratorTypes; |
| } |
| |
| SmallVector<Range> FftOp::getIterationDomain(OpBuilder &builder) { |
| SmallVector<Range> res; |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| for (auto en : llvm::enumerate(getOperandShape().drop_back())) { |
| Value size; |
| if (en.value() == ShapedType::kDynamicSize) { |
| size = getDimValue(builder, loc, getReal(), en.index()); |
| } else { |
| size = builder.create<arith::ConstantIndexOp>(loc, en.value()); |
| } |
| res.emplace_back(Range{/*offset=*/zero, size, /*stride=*/one}); |
| } |
| |
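  // Each FFT stage merges blocks of size m = 2^stage, so the innermost loop
  // strides by m over the last dim.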
| Value size = getDimValue(builder, loc, getReal(), getOperandRank() - 1); |
| Value stride = builder.create<arith::ShLIOp>(loc, one, getStage()); |
| res.emplace_back(Range{/*offset=*/zero, size, /*stride=*/stride}); |
| return res; |
| } |
| |
| void FftOp::generateScalarImplWithoutCoeffBuf(OpBuilder &b, Location loc, |
| ArrayRef<Value> operands, |
| Value wholeSize) { |
| auto rank = getOperandRank(); |
| SmallVector<AffineMap> maps(operands.size(), b.getMultiDimIdentityMap(rank)); |
| |
| auto f32Type = b.getF32Type(); |
| auto indexToF32 = [](OpBuilder &builder, Location loc, Value v) -> Value { |
| v = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(), v); |
| return builder.create<arith::SIToFPOp>(loc, builder.getF32Type(), v); |
| }; |
| |
| // We will need exp(-2 * PI * j / m * I), compute "-2 * PI / m" for imag part |
| // first. |
| Value coeff = b.create<arith::ConstantFloatOp>( |
| loc, llvm::APFloat(static_cast<float>(-2 * acos(-1))), f32Type); |
| coeff = b.create<arith::DivFOp>(loc, coeff, indexToF32(b, loc, wholeSize)); |
| |
| SmallVector<StringRef> iteratorTypes = llvm::to_vector( |
| llvm::map_range(getLoopIteratorTypes(), [](utils::IteratorType it) { |
| return utils::stringifyIteratorType(it); |
| })); |
| b.create<linalg::GenericOp>( |
| loc, TypeRange{}, ValueRange{}, operands, maps, iteratorTypes, |
| [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value lhsReal = args[0]; |
| Value lhsImag = args[1]; |
| Value rhsReal = args[2]; |
| Value rhsImag = args[3]; |
| |
| // Compute "-2 * PI / m * j" |
| Value w = b.create<arith::MulFOp>( |
| loc, coeff, |
| indexToF32(b, loc, b.create<linalg::IndexOp>(loc, rank - 1))); |
| Value wReal = b.create<math::CosOp>(loc, w); |
| Value wImag = b.create<math::SinOp>(loc, w); |
| |
| // t = w * a[k + j + mh]; |
| // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i |
| Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal); |
| Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag); |
| Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag); |
| Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal); |
| Value tReal = b.create<arith::SubFOp>(loc, xu, yv); |
| Value tImag = b.create<arith::AddFOp>(loc, xv, yu); |
| |
| // cplx u = a[k + j]; |
| // a[k + j] = u + t; |
| // a[k + j + mh] = u - t; |
| Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal); |
| Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag); |
| Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal); |
| Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag); |
| b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4}); |
| }); |
| } |
| |
| void FftOp::generateScalarImplWithCoeffBuf(OpBuilder &b, Location loc, |
| ArrayRef<Value> operands) { |
| auto rank = getOperandRank(); |
| SmallVector<AffineMap> maps; |
  // The size of the coefficient buffer is expected to match `2^(stage-1)`,
  // which equals the last dim of the operands.
| maps.append( |
| 2, AffineMap::get(rank, 0, b.getAffineDimExpr(rank - 1), b.getContext())); |
| maps.append(operands.size(), b.getMultiDimIdentityMap(rank)); |
| |
| SmallVector<StringRef> iteratorTypes = llvm::to_vector( |
| llvm::map_range(getLoopIteratorTypes(), [](utils::IteratorType it) { |
| return utils::stringifyIteratorType(it); |
| })); |
| b.create<linalg::GenericOp>( |
| loc, TypeRange{}, ValueRange{getRealCoeff(), getImagCoeff()}, operands, |
| maps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { |
| Value wReal = args[0]; |
| Value wImag = args[1]; |
| Value lhsReal = args[2]; |
| Value lhsImag = args[3]; |
| Value rhsReal = args[4]; |
| Value rhsImag = args[5]; |
| |
| // t = w * a[k + j + mh]; |
| // -> (x + yi)(u + vi) = (xu - yv) + (xv + yu)i |
| Value xu = b.create<arith::MulFOp>(loc, wReal, rhsReal); |
| Value yv = b.create<arith::MulFOp>(loc, wImag, rhsImag); |
| Value xv = b.create<arith::MulFOp>(loc, wReal, rhsImag); |
| Value yu = b.create<arith::MulFOp>(loc, wImag, rhsReal); |
| Value tReal = b.create<arith::SubFOp>(loc, xu, yv); |
| Value tImag = b.create<arith::AddFOp>(loc, xv, yu); |
| |
| // cplx u = a[k + j]; |
| // a[k + j] = u + t; |
| // a[k + j + mh] = u - t; |
| Value r1 = b.create<arith::AddFOp>(loc, lhsReal, tReal); |
| Value r2 = b.create<arith::AddFOp>(loc, lhsImag, tImag); |
| Value r3 = b.create<arith::SubFOp>(loc, lhsReal, tReal); |
| Value r4 = b.create<arith::SubFOp>(loc, lhsImag, tImag); |
| b.create<linalg::YieldOp>(loc, ValueRange{r1, r2, r3, r4}); |
| }); |
| } |
| |
// Generates the FFT stage scalar implementation. This follows the
// Cooley–Tukey FFT algorithm. The pseudo reference code is:
| // let s <- stage of linalg_ext.fft |
| // int m = 1 << s; |
| // int mh = m >> 1; |
| // for (int k = 0; k < n; k += m) { |
| // for (int j = 0; j < mh; ++j) { |
| // cplx w = exp(-2 * PI * j / m * I); |
| // cplx t = w * a[k + j + mh]; |
| // cplx u = a[k + j]; |
| // a[k + j] = u + t; |
| // a[k + j + mh] = u - t; |
| // } |
| // } |
| LogicalResult FftOp::generateScalarImplementation(OpBuilder &b, Location loc, |
| ValueRange ivs) { |
| Value real = getReal(); |
| Value imag = getImag(); |
| Value stage = getStage(); |
| Value one = b.create<arith::ConstantIndexOp>(loc, 1); |
| Value wholeSize = b.create<arith::ShLIOp>(loc, one, stage); |
| Value halfSize = b.create<arith::ShRSIOp>(loc, wholeSize, one); |
| |
| auto rank = getOperandRank(); |
| SmallVector<Value> operands; |
| SmallVector<OpFoldResult> lhsIvs(ivs.begin(), ivs.end()); |
| SmallVector<OpFoldResult> ones(rank, b.getIndexAttr(1)); |
| SmallVector<OpFoldResult> sizes(rank, b.getIndexAttr(1)); |
| sizes.back() = halfSize; |
| operands.push_back( |
| b.create<memref::SubViewOp>(loc, real, lhsIvs, sizes, ones)); |
| operands.push_back( |
| b.create<memref::SubViewOp>(loc, imag, lhsIvs, sizes, ones)); |
| |
| SmallVector<OpFoldResult> rhsIvs(ivs.begin(), ivs.end()); |
| rhsIvs.back() = |
| b.create<arith::AddIOp>(loc, ivs.back(), halfSize).getResult(); |
| operands.push_back( |
| b.create<memref::SubViewOp>(loc, real, rhsIvs, sizes, ones)); |
| operands.push_back( |
| b.create<memref::SubViewOp>(loc, imag, rhsIvs, sizes, ones)); |
| |
| if (hasCoeff()) { |
| generateScalarImplWithCoeffBuf(b, loc, operands); |
| } else { |
| generateScalarImplWithoutCoeffBuf(b, loc, operands, wholeSize); |
| } |
| |
| return success(); |
| } |
| |
| SmallVector<Operation *> |
| FftOp::getTiledImplementation(OpBuilder &builder, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes) { |
| int64_t rank = getOperandRank(); |
| SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1)); |
| Location loc = getLoc(); |
| SmallVector<Value> tiledOperands(3); |
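  // The scalar `stage` value and the coefficient buffers are carried over
  // untiled; only the outputs are sliced.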
| tiledOperands[0] = getStage(); |
| tiledOperands[1] = getRealCoeff(); |
| tiledOperands[2] = getImagCoeff(); |
| SmallVector<Type, 4> resultTypes; |
| |
| for (auto out : getOutputs()) { |
| tiledOperands.push_back( |
| getSlice(builder, getLoc(), out, offsets, sizes, strides)); |
| if (hasTensorSemantics()) { |
| resultTypes.push_back(tiledOperands.back().getType()); |
| } |
| } |
| Operation *tiledFftOp = cast<LinalgExtOp>(getOperation()) |
| .clone(builder, loc, resultTypes, tiledOperands); |
| return {tiledFftOp}; |
| } |
| |
| LogicalResult FftOp::getResultTilePosition( |
| OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
| SmallVector<OpFoldResult> &resultSizes) { |
| resultOffsets.assign(offsets.begin(), offsets.end()); |
| resultSizes.assign(sizes.begin(), sizes.end()); |
| return success(); |
| } |
| |
| LogicalResult |
| FftOp::reifyResultShapes(OpBuilder &b, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| return cast<LinalgExtOp>(getOperation()) |
| .reifyResultShapes(b, reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ScanOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ScanOp::verify() { |
| Operation *op = getOperation(); |
| if (getNumInputs() != 1) { |
    return op->emitOpError("expected one input operand");
| } |
| if (getNumOutputs() != 2) { |
| return op->emitOpError("expected two output operands"); |
| } |
| if (!input().getType().isa<ShapedType>()) { |
    return op->emitOpError("expected input to be of shaped type");
| } |
| auto accumulatorType = accumulator().getType().cast<ShapedType>(); |
| auto inputType = input().getType().cast<ShapedType>(); |
| auto outputType = output().getType().cast<ShapedType>(); |
| ArrayRef<int64_t> inputShapes = inputType.getShape(); |
| ArrayRef<int64_t> outputShapes = outputType.getShape(); |
| if (accumulatorType.getElementType() != inputType.getElementType()) { |
| return op->emitOpError( |
| "expected input/accumulator element types to be identical"); |
| } |
| ArrayRef<int64_t> accumulatorShape = accumulatorType.getShape(); |
| int64_t accumulatorRank = accumulatorType.getRank(); |
| if (accumulatorRank != inputType.getRank() - 1) { |
| return op->emitOpError( |
| "expected accumulator rank to be equal to input rank - 1"); |
| } |
| SmallVector<int64_t> expectedAccumulatorShape; |
| for (int i = 0; i < inputType.getRank(); i++) { |
| if (i != getDimension()) |
| expectedAccumulatorShape.push_back(inputShapes[i]); |
| } |
| if (llvm::any_of(llvm::zip(expectedAccumulatorShape, accumulatorShape), |
| [](std::tuple<int64_t, int64_t> s) { |
| return std::get<0>(s) != ShapedType::kDynamicSize && |
| std::get<1>(s) != ShapedType::kDynamicSize && |
| std::get<0>(s) != std::get<1>(s); |
| })) { |
| return op->emitOpError("incompatible input/accumulator shapes"); |
| } |
| if (inputType.getElementType() != outputType.getElementType()) { |
| return op->emitOpError( |
| "expected input/output element types to be identical"); |
| } |
| if (inputShapes.size() != outputShapes.size()) { |
| return op->emitOpError("expected input/output to have identical ranks"); |
| } |
| if (llvm::any_of(llvm::zip(inputShapes, outputShapes), |
| [](std::tuple<int64_t, int64_t> s) { |
| return std::get<0>(s) != ShapedType::kDynamicSize && |
| std::get<1>(s) != ShapedType::kDynamicSize && |
| std::get<0>(s) != std::get<1>(s); |
| })) { |
| return op->emitOpError("incompatible input/output shapes"); |
| } |
| return success(); |
| } |
| |
| SmallVector<Range> ScanOp::getIterationDomain(OpBuilder &builder) { |
| int64_t operandRank = getOperandRank(); |
| SmallVector<Range> loopBounds(operandRank); |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| Value source = input(); |
| for (auto dim : llvm::seq<int64_t>(0, operandRank)) { |
| loopBounds[dim].offset = zero; |
| loopBounds[dim].size = getDimValue(builder, loc, source, dim); |
| loopBounds[dim].stride = one; |
| } |
| return loopBounds; |
| } |
| |
| SmallVector<utils::IteratorType> ScanOp::getLoopIteratorTypes() { |
| SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(), |
| utils::IteratorType::parallel); |
| iteratorTypes[getDimension()] = utils::IteratorType::reduction; |
| return iteratorTypes; |
| } |
| |
| // Generates naive scalar implementation of scan for a given operator f. |
| // For inclusive, |
| // output[0] = input[0] |
| // output[i] = f(output[i-1], input[i]) |
| // |
| // For exclusive, |
| // output[0] = 0 |
| // output[i] = f(output[i-1], input[i-1]) |
| LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, |
| ValueRange ivs) { |
| SmallVector<Value> indices, scanBlkArgs; |
| indices.append(ivs.begin(), ivs.end()); |
| Value zero = b.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = b.create<arith::ConstantIndexOp>(loc, 1); |
| auto scanDim = getDimension(); |
| auto cond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, |
| indices[scanDim], zero); |
| bool isInclusive = getInclusive(); |
| SmallVector<Value> accIndices; |
  for (size_t i = 0, e = indices.size(); i < e; i++) {
    if (i != scanDim)
      accIndices.push_back(indices[i]);
  }
| |
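  // At position 0 along the scan dim, seed the output: from the input when
  // inclusive, from the accumulator when exclusive. Elsewhere, combine the
  // previous output element with the current (inclusive) or previous
  // (exclusive) input element using the region.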
| auto scfIf = b.create<scf::IfOp>( |
| loc, TypeRange{}, cond, |
| [&](OpBuilder &b, Location loc) { |
| if (isInclusive) { |
| auto value = b.create<memref::LoadOp>(loc, input(), indices); |
| b.create<memref::StoreOp>(loc, value, output(), indices); |
| } else { |
| auto value = b.create<memref::LoadOp>(loc, accumulator(), accIndices); |
| b.create<memref::StoreOp>(loc, value, output(), indices); |
| } |
| b.create<scf::YieldOp>(loc); |
| }, |
| [&](OpBuilder &b, Location loc) { |
| SmallVector<Value> indices(ivs.begin(), ivs.end()); |
| Value iv = indices[scanDim]; |
| Value ivMinusOne = b.create<arith::SubIOp>(loc, iv, one); |
| indices[scanDim] = ivMinusOne; |
| scanBlkArgs.push_back(b.create<memref::LoadOp>(loc, output(), indices)); |
| Value i0; |
| if (!isInclusive) |
| i0 = b.create<memref::LoadOp>(loc, input(), indices); |
| indices[scanDim] = iv; |
| if (isInclusive) |
| i0 = b.create<memref::LoadOp>(loc, input(), indices); |
| scanBlkArgs.push_back(i0); |
| }); |
| |
| auto &srcBlock = getRegion().front(); |
| Region ®ion = scfIf.getElseRegion(); |
| BlockAndValueMapping bvm; |
| { |
| OpBuilder::InsertionGuard guard(b); |
| auto &block = region.front(); |
| b.setInsertionPointToEnd(&block); |
| for (auto it : llvm::zip(srcBlock.getArguments(), scanBlkArgs)) { |
| bvm.map(std::get<0>(it), std::get<1>(it)); |
| } |
| for (auto &blockOp : srcBlock.without_terminator()) { |
| b.clone(blockOp, bvm); |
| } |
| b.create<memref::StoreOp>( |
| loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), |
| output(), indices); |
| b.create<memref::StoreOp>( |
| loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), |
| accumulator(), accIndices); |
| b.create<scf::YieldOp>(loc); |
| } |
| return success(); |
| } |
| |
| SmallVector<Operation *> |
| ScanOp::getTiledImplementation(OpBuilder &builder, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes) { |
| int64_t rank = getOperandRank(); |
| assert(offsets.size() == static_cast<size_t>(rank) && |
| sizes.size() == static_cast<size_t>(rank)); |
| auto oneAttr = builder.getI64IntegerAttr(1); |
| SmallVector<OpFoldResult> strides(rank, oneAttr); |
| Location loc = getLoc(); |
| SmallVector<Value> tiledOperands; |
| tiledOperands.emplace_back( |
| getSlice(builder, getLoc(), input(), offsets, sizes, strides)); |
| tiledOperands.emplace_back( |
| getSlice(builder, getLoc(), getOutputs()[0], offsets, sizes, strides)); |
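  // The accumulator drops the scan dim, so for rank 1 it is taken whole;
  // otherwise slice it with the tile positions of all non-scan dims.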
| if (rank > 1) { |
| SmallVector<OpFoldResult> accumOffsets, accumSizes; |
| if (failed(getResultTilePosition(builder, 1, offsets, sizes, accumOffsets, |
| accumSizes))) { |
| return {}; |
| } |
| SmallVector<OpFoldResult> accumStrides(rank - 1, oneAttr); |
| tiledOperands.emplace_back(getSlice(builder, getLoc(), getOutputs()[1], |
| accumOffsets, accumSizes, |
| accumStrides)); |
| } else { |
| tiledOperands.emplace_back(getOutputs()[1]); |
| } |
| |
| SmallVector<Type, 4> resultTypes; |
| if (hasTensorSemantics()) { |
| resultTypes.push_back(tiledOperands[1].getType()); |
| resultTypes.push_back(tiledOperands[2].getType()); |
| } |
| |
| Operation *tiledScanOp = cast<LinalgExtOp>(getOperation()) |
| .clone(builder, loc, resultTypes, tiledOperands); |
| return {tiledScanOp}; |
| } |
| |
| LogicalResult ScanOp::getResultTilePosition( |
| OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
| SmallVector<OpFoldResult> &resultSizes) { |
| if (resultNumber == 0) { |
| resultOffsets.assign(offsets.begin(), offsets.end()); |
| resultSizes.assign(sizes.begin(), sizes.end()); |
| return success(); |
| } |
| if (resultNumber == 1) { |
| int64_t rank = getOperandRank(); |
| if (rank > 1) { |
| for (auto i : llvm::seq<int64_t>(0, rank)) { |
| if (i == getDimension()) |
| continue; |
| resultOffsets.push_back(offsets[i]); |
| resultSizes.push_back(sizes[i]); |
| } |
| } |
| return success(); |
| } |
| return failure(); |
| } |
| |
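/// Folds producing memref.cast ops into the operands of `op` when the cast is
/// compatible with the consumer; returns success if any operand was folded.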
| static LogicalResult foldMemRefCast(Operation *op) { |
| bool folded = false; |
| for (OpOperand &operand : op->getOpOperands()) { |
| auto castOp = operand.get().getDefiningOp<memref::CastOp>(); |
| if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { |
| operand.set(castOp.getOperand()); |
| folded = true; |
| } |
| } |
| return success(folded); |
| } |
| |
| LogicalResult ScanOp::fold(ArrayRef<Attribute>, |
| SmallVectorImpl<OpFoldResult> &) { |
| return foldMemRefCast(*this); |
| } |
| |
| LogicalResult |
| ScanOp::reifyResultShapes(OpBuilder &b, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| return cast<LinalgExtOp>(getOperation()) |
| .reifyResultShapes(b, reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ReverseOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult ReverseOp::verify() { |
| Operation *op = getOperation(); |
| if (getNumInputs() != 1) { |
| return op->emitOpError("expected exactly one input"); |
| } |
| if (getNumOutputs() != 1) { |
| return op->emitOpError("expected exactly one output"); |
| } |
| auto inputType = input().getType().cast<ShapedType>(); |
| auto outputType = output().getType().cast<ShapedType>(); |
| if (inputType.getElementType() != outputType.getElementType()) { |
| return op->emitOpError( |
| "expected input/output element types to be identical"); |
| } |
| ArrayRef<int64_t> inputShapes = inputType.getShape(); |
| ArrayRef<int64_t> outputShapes = outputType.getShape(); |
| if (inputShapes.size() != outputShapes.size()) { |
    return op->emitOpError("expected input/output to have identical ranks");
| } |
| if (llvm::any_of(llvm::zip(inputShapes, outputShapes), |
| [](std::tuple<int64_t, int64_t> s) { |
| return std::get<0>(s) != ShapedType::kDynamicSize && |
| std::get<1>(s) != ShapedType::kDynamicSize && |
| std::get<0>(s) != std::get<1>(s); |
| })) { |
| return op->emitOpError("incompatible input/output shapes"); |
| } |
| |
| int64_t rank = getOperandRank(); |
| llvm::SmallSetVector<int64_t, 4> s; |
| for (auto dim : dims()) { |
| if (dim < 0 || dim >= rank) { |
| return op->emitOpError("all the dimensions must be within [0, ") |
| << rank << ")"; |
| } |
| if (s.contains(dim)) { |
| return op->emitOpError("expected dimensions numbers are all unique"); |
| } |
| s.insert(dim); |
| } |
| |
| return success(); |
| } |
| |
| SmallVector<utils::IteratorType> ReverseOp::getLoopIteratorTypes() { |
| SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(), |
| utils::IteratorType::parallel); |
| return iteratorTypes; |
| } |
| |
| SmallVector<Range> ReverseOp::getIterationDomain(OpBuilder &builder) { |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| SmallVector<Range> ranges; |
| for (auto dim : llvm::seq<int64_t>(0, getOperandRank())) { |
| Value ub = getDimValue(builder, loc, input(), dim); |
| ranges.emplace_back(Range{zero, ub, one}); |
| } |
| return ranges; |
| } |
| |
| LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b, |
| Location loc, |
| ValueRange ivs) { |
| SmallVector<Value> mirrorIndices(ivs.begin(), ivs.end()); |
| for (auto dim : dims()) { |
| auto size = getDimValue(b, loc, input(), dim); |
| size = b.create<arith::SubIOp>(loc, size, |
| b.create<arith::ConstantIndexOp>(loc, 1)); |
| mirrorIndices[dim] = b.create<arith::SubIOp>(loc, size, mirrorIndices[dim]); |
| } |
| Value val = b.create<memref::LoadOp>(loc, input(), ivs); |
| b.create<memref::StoreOp>(loc, val, output(), mirrorIndices); |
| return success(); |
| } |
| |
| SmallVector<Operation *> |
| ReverseOp::getTiledImplementation(OpBuilder &builder, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes) { |
| int64_t rank = getOperandRank(); |
| SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1)); |
| Location loc = getLoc(); |
| SmallVector<OpFoldResult> mirrorOffsets, mirrorSizes; |
| if (failed(getResultTilePosition(builder, 0, offsets, sizes, mirrorOffsets, |
| mirrorSizes))) { |
| return {}; |
| } |
| |
| SmallVector<Value> tiledOperands; |
| tiledOperands.emplace_back( |
| getSlice(builder, loc, input(), offsets, sizes, strides)); |
| |
  tiledOperands.emplace_back(
      getSlice(builder, loc, output(), mirrorOffsets, sizes, strides));

  SmallVector<Type, 4> resultTypes;
  if (hasTensorSemantics()) {
    resultTypes.push_back(tiledOperands[1].getType());
  }
| |
| Operation *tiledRevOp = cast<LinalgExtOp>(getOperation()) |
| .clone(builder, loc, resultTypes, tiledOperands); |
| |
| return {tiledRevOp}; |
| } |
| |
| LogicalResult ReverseOp::getResultTilePosition( |
| OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
| SmallVector<OpFoldResult> &resultSizes) { |
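  // A tile [offset, offset + tileSize) along a reversed dim maps to
  // [dimSize - offset - tileSize, dimSize - offset) in the result.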
| AffineExpr sym0, sym1, sym2; |
| bindSymbols(builder.getContext(), sym0, sym1, sym2); |
| AffineMap map = |
| AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2}); |
| resultOffsets.assign(offsets.begin(), offsets.end()); |
| Location loc = getLoc(); |
| for (auto dim : dims()) { |
| Value size = getDimValue(builder, loc, input(), dim); |
| Value offset = |
| getValueOrCreateConstantIndexOp(builder, loc, resultOffsets[dim]); |
| Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]); |
| resultOffsets[dim] = |
| builder |
| .create<AffineApplyOp>(loc, map, ValueRange{size, offset, tileSize}) |
| .getResult(); |
| } |
| resultSizes.assign(sizes.begin(), sizes.end()); |
| return success(); |
| } |
| |
| LogicalResult |
| ReverseOp::reifyResultShapes(OpBuilder &b, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| return cast<LinalgExtOp>(getOperation()) |
| .reifyResultShapes(b, reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TopkOp |
| //===----------------------------------------------------------------------===// |
| |
| LogicalResult TopkOp::verify() { |
| Operation *op = getOperation(); |
| if (getNumInputs() != 1 && getNumInputs() != 2) { |
| return op->emitOpError("expected one or two input operands"); |
| } |
| if (getNumOutputs() != 2) { |
| return op->emitOpError("expected two output operands"); |
| } |
| if (getDimension() >= getInputRank()) { |
| return op->emitOpError("dimension exceeds rank"); |
| } |
| // Ensure input/output element types match |
| auto inputValuesType = values().getType().cast<ShapedType>(); |
| auto outputValuesType = outputValues().getType().cast<ShapedType>(); |
| if (inputValuesType.getElementType() != outputValuesType.getElementType()) { |
| return op->emitOpError("expected input/output value types to be identical"); |
| } |
| // Indices must be int if provided |
| auto outputIndicesType = outputIndices().getType().cast<ShapedType>(); |
| if (auto inputIndices = indices()) { |
| auto inputIndicesType = inputIndices->getType().cast<ShapedType>(); |
| if (!inputIndicesType.getElementType().isInteger(32) || |
| !outputIndicesType.getElementType().isInteger(32)) { |
| return op->emitOpError("expected input/output indices types to be int32"); |
| } |
| } |
| |
| // Ranks must match |
| if (inputValuesType.getRank() != outputValuesType.getRank()) { |
| return op->emitOpError("expected input/output to have the same rank"); |
| } |
| if (auto inputIndices = indices()) { |
| auto inputIndicesType = inputIndices->getType().cast<ShapedType>(); |
| if (inputIndicesType.getRank() != outputIndicesType.getRank()) { |
| return op->emitOpError("expected input/output to have the same rank"); |
| } |
| } |
  // Input indices and values must have the same shape.
| if (auto inputIndices = indices()) { |
| auto inputIndicesType = inputIndices->getType().cast<ShapedType>(); |
| if (llvm::any_of( |
| llvm::zip(inputValuesType.getShape(), inputIndicesType.getShape()), |
| [](std::tuple<int64_t, int64_t> s) { |
| return isShapedTypeDimEqual(std::get<0>(s), std::get<1>(s)); |
| })) { |
| return op->emitOpError("input indices/values shape must match"); |
| } |
| } |
  // Output indices and values must have the same shape.
| if (llvm::any_of( |
| llvm::zip(outputValuesType.getShape(), outputIndicesType.getShape()), |
| [](std::tuple<int64_t, int64_t> s) { |
| return isShapedTypeDimEqual(std::get<0>(s), std::get<1>(s)); |
| })) { |
| return op->emitOpError("output indices/values shape must match"); |
| } |
  // Input shape must match the output shape except along dimension().
| uint64_t dim = getDimension(); |
| if (llvm::any_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(), |
| outputValuesType.getShape())), |
| [dim](auto e) { |
| if (e.index() == dim) { |
| return false; |
| } |
| std::tuple<int64_t, int64_t> s = e.value(); |
| return isShapedTypeDimEqual(std::get<0>(s), |
| std::get<1>(s)); |
| })) { |
| return op->emitOpError("incompatible input/output shapes"); |
| } |
| // Check region compatibility |
| Block &block = getRegion().front(); |
| if (block.getNumArguments() != 2) { |
| return op->emitOpError("region block should have 2 arguments"); |
| } |
| if (block.getArgument(0).getType() != inputValuesType.getElementType() || |
| block.getArgument(1).getType() != inputValuesType.getElementType()) { |
| return op->emitOpError("region block types must match input"); |
| } |
| auto terminatorOp = llvm::dyn_cast<YieldOp>(block.getTerminator()); |
| if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) { |
    return op->emitOpError(
        "region block must end with a linalg_ext.yield of i1");
| } |
| return success(); |
| } |
| |
| SmallVector<Range> TopkOp::getIterationDomain(OpBuilder &builder) { |
| int64_t operandRank = getInputRank(); |
| SmallVector<Range> loopBounds(operandRank); |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| Value source = values(); |
| for (auto dim : llvm::enumerate(getInputType().getShape())) { |
| loopBounds[dim.index()].offset = zero; |
| loopBounds[dim.index()].size = |
| getDimValue(builder, loc, source, dim.index()); |
| loopBounds[dim.index()].stride = one; |
| } |
| return loopBounds; |
| } |
| |
| SmallVector<utils::IteratorType> TopkOp::getLoopIteratorTypes() { |
| SmallVector<utils::IteratorType> iteratorTypes(getInputRank(), |
| utils::IteratorType::parallel); |
| iteratorTypes[getDimension()] = utils::IteratorType::reduction; |
| return iteratorTypes; |
| } |
| |
| LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, |
| ValueRange ivs) { |
| uint64_t kDim = getDimension(); |
| Value zero = b.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = b.create<arith::ConstantIndexOp>(loc, 1); |
| Value initialValue = b.create<memref::LoadOp>(loc, values(), ivs); |
| |
| // If the indices tensor is not provided, the value index is derived from the |
| // loop induction variables. |
| Value initialIndex; |
| if (indices()) { |
| initialIndex = b.create<memref::LoadOp>(loc, *indices(), ivs); |
| } else { |
| Value rawInitialIndex = ivs[kDim]; |
| initialIndex = |
| b.create<arith::IndexCastOp>(loc, b.getI32Type(), rawInitialIndex); |
| } |
| |
| // Compute K (ub) from the selected dim of the output |
| Value ub = b.create<memref::DimOp>(loc, outputValues(), getDimension()); |
| |
| // Inner K loop functions: |
| // Load current K value and index |
| // Compare N/K using inserted block compare |
| // Check if N == K using strict weak ordering, select which index came first |
| // Select new K value from N/K comparison |
| // Select new K index from N/K comparison or which index came first |
| // Store new k value and index |
| // Yield loop carry values after K selection |
| Value kValue, kIndex; |
| auto scfFor = b.create<scf::ForOp>( |
| loc, zero, ub, one, ValueRange{initialValue, initialIndex}, |
| [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) { |
| SmallVector<Value> indices(ivs); |
| indices[kDim] = iv; |
| kValue = b.create<memref::LoadOp>(loc, outputValues(), indices); |
| kIndex = b.create<memref::LoadOp>(loc, outputIndices(), indices); |
| }); |
| |
| SmallVector<Value> indices(ivs); |
| indices[kDim] = scfFor.getInductionVar(); |
| auto loopCarryValues = scfFor.getRegionIterArgs(); |
| |
  // Retrieve the region as a black-box comparison function f(x,y) and plug it
  // into the op.
| auto &srcBlock = getRegion().front(); |
| BlockAndValueMapping bvmF; // f(x,y) |
| BlockAndValueMapping bvmR; // f(y,x) |
| { |
| // Save previous insertion point. Continue within loop body. |
| OpBuilder::InsertionGuard guard(b); |
| b.setInsertionPointToEnd(&scfFor.getRegion().front()); |
| SmallVector<Value> forwardValues{loopCarryValues[0], kValue}; |
| SmallVector<Value> reverseValues{kValue, loopCarryValues[0]}; |
| for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) { |
| bvmF.map(std::get<0>(it), std::get<1>(it)); |
| } |
| for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) { |
| bvmR.map(std::get<0>(it), std::get<1>(it)); |
| } |
| for (auto &blockOp : srcBlock.without_terminator()) { |
| b.clone(blockOp, bvmF); |
| b.clone(blockOp, bvmR); |
| } |
| Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0)); |
| Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0)); |
| |
    // Check value equality using the strict weak ordering from the region:
| // f(x,y) --> forwardCmpRes |
| // f(y,x) --> reverseCmpRes |
| // if forwardCmpRes == reverseCmpRes then select which came first |
| Value cmpValuesEqual = b.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes); |
| Value cmpFirstIndex = b.create<arith::CmpIOp>( |
| loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex); |
| Value combinedCmpEqRes = |
| b.create<arith::AndIOp>(loc, cmpValuesEqual, cmpFirstIndex); |
| // True if N > K or N came before K |
| Value indexCmpRes = |
| b.create<arith::OrIOp>(loc, forwardCmpRes, combinedCmpEqRes); |
| // Select results for K based on comparisons |
| Value resultKValue = b.create<arith::SelectOp>(loc, forwardCmpRes, |
| loopCarryValues[0], kValue); |
| Value resultKIndex = |
| b.create<arith::SelectOp>(loc, indexCmpRes, loopCarryValues[1], kIndex); |
| b.create<memref::StoreOp>(loc, resultKValue, outputValues(), indices); |
| b.create<memref::StoreOp>(loc, resultKIndex, outputIndices(), indices); |
| // Select loop carry, opposite of K results |
| Value resultCarryValue = b.create<arith::SelectOp>( |
| loc, forwardCmpRes, kValue, loopCarryValues[0]); |
| Value resultCarryIndex = |
| b.create<arith::SelectOp>(loc, indexCmpRes, kIndex, loopCarryValues[1]); |
| b.create<scf::YieldOp>(loc, ValueRange{resultCarryValue, resultCarryIndex}); |
| } |
| return success(); |
| } |
| |
| SmallVector<Operation *> |
| TopkOp::getTiledImplementation(OpBuilder &builder, |
| ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes) { |
| int64_t rank = getInputRank(); |
| assert(offsets.size() == static_cast<size_t>(rank) && |
| sizes.size() == static_cast<size_t>(rank)); |
| SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1)); |
| Location loc = getLoc(); |
| |
| SmallVector<OpFoldResult> outputOffsets, outputSizes; |
| if (failed(getResultTilePosition(builder, 0, offsets, sizes, outputOffsets, |
| outputSizes))) { |
| return {}; |
| } |
| |
| SmallVector<Value> tiledOperands; |
| tiledOperands.emplace_back( |
| getSlice(builder, loc, values(), offsets, sizes, strides)); |
| if (indices()) { |
| tiledOperands.emplace_back( |
| getSlice(builder, loc, *indices(), offsets, sizes, strides)); |
| } |
| |
| // Replace the tile size for the K dimension to use the output size instead of |
| // the input size. |
| Value kSize = getDimValue(builder, getLoc(), outputValues(), getDimension()); |
| outputSizes[getDimension()] = getAsOpFoldResult(kSize); |
| |
| tiledOperands.emplace_back( |
| getSlice(builder, loc, getOutputs()[0], offsets, outputSizes, strides)); |
| tiledOperands.emplace_back( |
| getSlice(builder, loc, getOutputs()[1], offsets, outputSizes, strides)); |
| SmallVector<Type, 2> resultTypes; |
| if (hasTensorSemantics()) { |
| resultTypes.push_back(tiledOperands[tiledOperands.size() - 2].getType()); |
| resultTypes.push_back(tiledOperands[tiledOperands.size() - 1].getType()); |
| } |
| |
| Operation *tiledTopkOp = cast<LinalgExtOp>(getOperation()) |
| .clone(builder, loc, resultTypes, tiledOperands); |
| return {tiledTopkOp}; |
| } |
| |
| LogicalResult TopkOp::getResultTilePosition( |
| OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
| ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
| SmallVector<OpFoldResult> &resultSizes) { |
| resultOffsets.assign(offsets.begin(), offsets.end()); |
| resultSizes.assign(sizes.begin(), sizes.end()); |
| Value kSize = getDimValue( |
| builder, getLoc(), getOutputOperand(resultNumber)->get(), getDimension()); |
| resultSizes[getDimension()] = getAsOpFoldResult(kSize); |
| return success(); |
| } |
| |
| LogicalResult |
| TopkOp::reifyResultShapes(OpBuilder &b, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| return cast<LinalgExtOp>(getOperation()) |
| .reifyResultShapes(b, reifiedReturnShapes); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // PackOp |
| //===----------------------------------------------------------------------===// |
| |
| // Return true if each element in `dimsPos` is >= 0 and < rank. |
| static bool isInBound(ArrayRef<int64_t> dimsPos, int64_t rank) { |
| return llvm::all_of( |
| dimsPos, [rank](int64_t dimPos) { return dimPos >= 0 && dimPos < rank; }); |
| } |
| |
// Interchange `elements` starting at offset `offset` based on the indices in
// `interchangeVector`.
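// For example (illustrative): elements = [a, b, c, d] with interchangeVector
// = [1, 0] and offset = 2 rearranges the tail, yielding [a, b, d, c].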
| template <typename T> |
| static SmallVector<T> interchange(ArrayRef<T> elements, |
| ArrayRef<int64_t> interchangeVector, |
| int64_t offset) { |
| SmallVector<T> rearrangedElements = llvm::to_vector(elements); |
| if (interchangeVector.empty()) |
| return rearrangedElements; |
| assert((rearrangedElements.size() - offset) == interchangeVector.size() && |
| "number of elements must equal number of permutations"); |
| for (int64_t idx = 0, end = interchangeVector.size(); idx < end; idx++) { |
| rearrangedElements[interchangeVector[idx] + offset] = |
| elements[idx + offset]; |
| } |
| return rearrangedElements; |
| } |
| |
| // Infer result/output type given the input and the tile sizes. |
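// For example (illustrative): an input of tensor<16x32xf32> with dims_pos =
// [0, 1] and tiles [8, 4] infers tensor<2x8x8x4xf32>: the tile loops yield
// 16/8 = 2 and 32/4 = 8, and the point loops append the tiles 8 and 4.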
| ShapedType PackOp::inferResultType() { |
| DenseMap<int64_t, OpFoldResult> tileAndPosMapping = getDimAndTileMapping(); |
| SmallVector<int64_t> inferredShape; |
| inferredShape.reserve(getOutputRank()); |
| ShapedType inputType = getInputType(); |
| int64_t rank = getInputRank(); |
| |
| // tile loop. |
| for (auto i : llvm::seq<int64_t>(0, rank)) { |
| if (tileAndPosMapping.count(i)) { |
| Optional<int64_t> tileSize = |
| getConstantIntValue(tileAndPosMapping.lookup(i)); |
| if (inputType.isDynamicDim(i) || !tileSize) { |
| inferredShape.push_back(ShapedType::kDynamicSize); |
| } else { |
| int64_t sizeTiledDim = ceilDiv(inputType.getDimSize(i), *tileSize); |
| inferredShape.push_back(sizeTiledDim); |
| } |
| } else { |
| inferredShape.push_back(inputType.getShape()[i]); |
| } |
| } |
| |
| // point loop. |
| auto staticTiles = getStaticTiles(); |
| inferredShape.append(staticTiles.begin(), staticTiles.end()); |
| |
| return TypeSwitch<Type, ShapedType>(inputType) |
| .Case<RankedTensorType>([&](RankedTensorType t) -> ShapedType { |
| return RankedTensorType::get(inferredShape, inputType.getElementType()); |
| }) |
| .Case<MemRefType>([&](MemRefType t) -> ShapedType { |
| return MemRefType::get(inferredShape, inputType.getElementType()); |
| }) |
| .Default([&](Type t) { |
| llvm_unreachable("unexpected type"); |
| return nullptr; |
| }); |
| } |
| |
| // Return true if at least one element in `tiles` is zero. |
| static bool hasZeros(ArrayRef<OpFoldResult> tiles) { |
| return llvm::any_of( |
| tiles, [&](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); |
| } |
| |
// Return true if `dimsPos` is invalid. It is invalid when it contains
// duplicates.
| static bool isInvalid(ArrayRef<int64_t> dimsPos) { |
| DenseSet<int64_t> uniqued; |
| for (int64_t dim : dimsPos) |
| uniqued.insert(dim); |
| return dimsPos.size() != uniqued.size(); |
| } |
| |
// Check if we have enough static information to catch undefined behavior when
// a tile size does not perfectly divide the corresponding input dimension.
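// For example (illustrative): inputShape = [16] with a constant tile of 5
// returns true (16 % 5 != 0 leaves a partial tile), while a dynamic shape or
// a dynamic tile returns false because nothing can be proven statically.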
| static bool areNotFullTiles(ArrayRef<int64_t> inputShape, |
| DenseMap<int64_t, OpFoldResult> dimAndTileMapping) { |
| int64_t rank = inputShape.size(); |
| for (int64_t dim = 0; dim < rank; dim++) { |
| if (inputShape[dim] == ShapedType::kDynamicSize) |
| continue; |
| if (dimAndTileMapping.count(dim)) { |
| Optional<int64_t> constantTile = |
| getConstantIntValue(dimAndTileMapping[dim]); |
| if (!constantTile) |
| continue; |
| if (inputShape[dim] % (*constantTile) != 0) |
| return true; |
| } |
| } |
| return false; |
| } |
| |
// Verifier for the pack operation.
| LogicalResult PackOp::verify() { |
| Operation *op = getOperation(); |
| size_t numberOfBlockingFactors = getMixedTiles().size(); |
| SmallVector<int64_t> dimsPos = extractFromI64ArrayAttr(getDimsPos()); |
// The number of blocking factors must be less than or equal to the input
// rank, and must match the number of `dims_pos` entries.
if (numberOfBlockingFactors > getInputRank()) {
return op->emitError(
"blocking factors must be less than or equal to the input rank");
| } |
| if (numberOfBlockingFactors != dimsPos.size()) { |
| return op->emitError( |
| "blocking factors must equal the number of dimensions to block"); |
| } |
| if (isInvalid(dimsPos)) |
| return op->emitError("invalid dims_pos vector"); |
// Require `dims_pos` to be in bounds. `dims_pos` carries the indices of the
// dimensions to block.
| if (!isInBound(dimsPos, getOutputRank())) |
| return op->emitError("out-of-bound position"); |
| |
| // Require output rank to match input rank + number of blocking factors. |
| if ((getInputRank() + numberOfBlockingFactors) != getOutputRank()) { |
| return op->emitError( |
| "output rank must equal input rank + blocking factors"); |
| } |
| |
| // Verify tiles. Make sure each provided tile is non-zero. |
| if (hasZeros(getMixedTiles())) |
| return op->emitError("invalid tile factor"); |
| |
| // Bail out if the tile does not divide the dimension fully. In the case of |
| // dynamic tile factors or dimensions, having a partial tile is undefined |
| // behavior. We will relax this constraint when we introduce padding |
| // semantics. |
| if (areNotFullTiles(getInputShape(), getDimAndTileMapping())) { |
| return op->emitError( |
| "invalid tile factor provided. Only full tiles are supported"); |
| } |
| |
| // Verify result type against inferred type. |
| ShapedType expectedType = inferResultType(); |
| if (expectedType != getOutputType()) { |
| return op->emitError( |
| "inferred type do not match provied output type. Expected ") |
| << expectedType << " but got: " << getOutputType(); |
| } |
| return success(); |
| } |
| |
| // Get the tile sizes as `OpFoldResult`. |
| SmallVector<OpFoldResult> PackOp::getMixedTiles() { |
| SmallVector<OpFoldResult> mixedInnerTiles; |
| mixedInnerTiles.reserve(getInputRank()); |
| unsigned dynamicValIndex = 0; |
| for (Attribute attr : getStaticInnerTiles()) { |
| auto tileAttr = attr.cast<IntegerAttr>(); |
| if (!ShapedType::isDynamic(tileAttr.getInt())) |
| mixedInnerTiles.push_back(tileAttr); |
| else |
| mixedInnerTiles.push_back(getInnerTiles()[dynamicValIndex++]); |
| } |
| return mixedInnerTiles; |
| } |
| |
// Return the tile sizes as `int64_t`. If a tile size is dynamic, the sentinel
// `kDynamicSize` is used at that position in the returned vector.
| SmallVector<int64_t> PackOp::getStaticTiles() { |
| SmallVector<Value> dynamicTiles; |
| SmallVector<int64_t> staticTiles; |
| dispatchIndexOpFoldResults(getMixedTiles(), dynamicTiles, staticTiles, |
| ShapedType::kDynamicSize); |
| return staticTiles; |
| } |
| |
// Implements the tiling interface: the number of loops equals the rank of the
// output tensor, and all loops are parallel.
| SmallVector<utils::IteratorType> PackOp::getLoopIteratorTypes() { |
| SmallVector<utils::IteratorType> iteratorTypes(getOutputRank(), |
| utils::IteratorType::parallel); |
| return iteratorTypes; |
| } |
| |
// Copied from:
| // https://mlir.llvm.org/doxygen/TensorInferTypeOpInterfaceImpl_8cpp_source.html |
| /// Helper function to convert a vector of `OpFoldResult`s into a vector of |
| /// `Value`s. |
| static SmallVector<Value> getAsValues(OpBuilder &b, Location loc, |
| ArrayRef<OpFoldResult> valueOrAttrVec) { |
| return llvm::to_vector<4>( |
| llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { |
| return getValueOrCreateConstantIndexOp(b, loc, value); |
| })); |
| } |
| |
| // Return a mapping from positions `dims_pos` to their `OpFoldResult` tile |
| // factors. |
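// For example (illustrative): dims_pos = [0, 2] with tiles [32, 16] produces
// the mapping {0 -> 32, 2 -> 16}.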
| DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() { |
| DenseMap<int64_t, OpFoldResult> dimAndTileMapping; |
| SmallVector<int64_t> dimsToBlock = extractFromI64ArrayAttr(getDimsPos()); |
| SmallVector<OpFoldResult> tiles = getMixedTiles(); |
| assert(tiles.size() == dimsToBlock.size() && |
| "tiles must match indices of dimension to block"); |
| // bind the dimension with the tile factor. |
| for (auto i : llvm::seq<int64_t>(0, dimsToBlock.size())) |
| dimAndTileMapping[dimsToBlock[i]] = tiles[i]; |
| return dimAndTileMapping; |
| } |
| |
// Implements `getIterationDomain` from the tiling interface. Each loop has a
// lower bound of zero and a step of one; the upper bound is inferred from the
// output tensor.
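// For example (illustrative): an output of tensor<2x8x8x4xf32> yields the
// domain [0, 2) x [0, 8) x [0, 8) x [0, 4), each with stride one.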
| SmallVector<Range> PackOp::getIterationDomain(OpBuilder &builder) { |
| int64_t outputRank = getOutputRank(); |
| SmallVector<Range> loopBounds(outputRank); |
| Location loc = getLoc(); |
| Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
| ReifiedRankedShapedTypeDims resultShape; |
| (void)reifyResultShapes(builder, resultShape); |
| for (auto dim : llvm::seq<int64_t>(0, outputRank)) { |
| loopBounds[dim].offset = zero; |
| loopBounds[dim].stride = one; |
| loopBounds[dim].size = resultShape[0][dim]; |
| } |
| return loopBounds; |
| } |
| |
| // Return the `interchangeVector` based on `dims_pos`. |
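// For example (illustrative): dims_pos = [2, 0] with inputRank = 3 yields
// [1, 0]: dimension 0 sits at position 1 in dims_pos, dimension 2 at
// position 0, and untiled dimension 1 contributes nothing.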
| SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, |
| int64_t inputRank) { |
| SmallVector<int64_t> interchangeVector; |
| interchangeVector.reserve(dimsPos.size()); |
// First map each dimension to its position in `dims_pos`. For example,
// dims_pos = [2, 0] maps to:
// [
//   [ key: 2, value: 0]
//   [ key: 0, value: 1]
// ]
// where the key is an element of dims_pos (a dimension to block) and the
// value is its position within dims_pos.
| DenseMap<int64_t, int64_t> dimsAndPosMapping; |
| for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) |
| dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx; |
| |
| // Scan the position in order and insert the value in the map |
| // to compute the interchange vector. |
| for (int64_t dimsIdx = 0; dimsIdx < inputRank; dimsIdx++) { |
| if (dimsAndPosMapping.count(dimsIdx)) |
| interchangeVector.push_back(dimsAndPosMapping[dimsIdx]); |
| } |
| return interchangeVector; |
| } |
| |
// Implements `generateScalarImplementation` from the tiling interface.
| LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder, |
| Location loc, |
| ValueRange ivs) { |
// Note: `ivs` are already in the correct order, possibly interchanged based
// on `dims_pos`. However, connecting the loops with the access patterns is
// difficult: what is the relation between the position of a tile loop and the
// position of its point loop? If we interchange `ivs` once more into the
// canonical blocking format ABCabc, the connection becomes trivial: each
// point loop is pointLoopsOffset + inputRank away from its tile loop.
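// For example (illustrative): packing a 2-D input with both dimensions tiled
// by (T0, T1) gives ivs (i, j, ii, jj) in ABab form, and the source access
// computed below is (i * T0 + ii, j * T1 + jj).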
| SmallVector<int64_t> dimsToBlock = extractFromI64ArrayAttr(getDimsPos()); |
| SmallVector<int64_t> testInterchangeVector = |
| computeInterchangeFromDimPos(dimsToBlock, getInputRank()); |
| SmallVector<Value> interchangedIvs = ivs; |
| interchangedIvs = interchange<Value>(interchangedIvs, testInterchangeVector, |
| /*offset=*/getInputRank()); |
| |
| SmallVector<OpFoldResult> tiles = getMixedTiles(); |
| DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping(); |
| SmallVector<OpFoldResult> sourceIndices; |
| size_t pointLoopsOffset = 0; |
| for (auto dim : llvm::seq<int64_t>(0, getInputRank())) { |
| if (dimAndTileMapping.count(dim)) { |
| AffineExpr i, j, tile; |
| bindDims(builder.getContext(), i, j); |
| bindSymbols(builder.getContext(), tile); |
| OpFoldResult sourceIndex = makeComposedFoldedAffineApply( |
| builder, loc, i * tile + j, |
| ArrayRef<OpFoldResult>{ |
| interchangedIvs[dim], |
| interchangedIvs[pointLoopsOffset + getInputRank()], |
| dimAndTileMapping[dim]}); |
| sourceIndices.push_back(sourceIndex); |
| ++pointLoopsOffset; |
| } else { |
| sourceIndices.push_back(interchangedIvs[dim]); |
| } |
| } |
| Value scalar = builder.create<memref::LoadOp>( |
| loc, getInput(), getAsValues(builder, loc, sourceIndices)); |
| builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs); |
| return success(); |
| } |
| |
| LogicalResult |
| PackOp::reifyResultShapes(OpBuilder &builder, |
| ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| |
| // Build the output dimension at pos `dimIdx`. |
| auto buildOutputDim = [&](OpBuilder &builder, size_t dimIdx) -> OpFoldResult { |
| ArrayRef<int64_t> outputShape = getOutputShape(); |
| if (!ShapedType::isDynamic(outputShape[dimIdx])) |
| return builder.getI64IntegerAttr(outputShape[dimIdx]); |
| |
| // Handle dynamic. |
| DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping(); |
| AffineExpr dim = builder.getAffineSymbolExpr(0); |
| AffineExpr tile = builder.getAffineSymbolExpr(1); |
| auto apply = [&](AffineExpr expr, |
| ArrayRef<OpFoldResult> values) -> OpFoldResult { |
| return makeComposedFoldedAffineApply(builder, getOperation()->getLoc(), |
| expr, values); |
| }; |
// If we are dealing with a tiled dimension, compose the map; otherwise
// return the dimension extracted with a `dim` op.
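// For example (illustrative): for a tiled dimension with tile size 8, the
// reified size is ceildiv(d, 8), where d is the size extracted just below.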
| OpFoldResult dimBound = |
| getDim(builder, getOperation()->getLoc(), getOutput(), dimIdx); |
| return (dimAndTileMapping.count(dimIdx)) |
| ? apply(dim.ceilDiv(tile), |
| ArrayRef<OpFoldResult>{dimBound, |
| dimAndTileMapping[dimIdx]}) |
| : dimBound; |
| }; |
| |
| reifiedReturnShapes.resize(1); |
| reifiedReturnShapes[0].reserve(getOutputRank()); |
| for (auto dimIdx : llvm::seq<int64_t>(0, getOutputRank())) { |
| reifiedReturnShapes[0].push_back(getAsValues( |
| builder, getOperation()->getLoc(), |
| ArrayRef<OpFoldResult>{buildOutputDim(builder, dimIdx)})[0]); |
| } |
| return success(); |
| } |
| |
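// Declare memory effects for each op: results are allocated, input buffers
// are read, and output buffers are both read and written (see getEffectsImpl
// above).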
| #define DEFINE_OP_GET_EFFECTS(OP_NAME) \ |
| void OP_NAME::getEffects( \ |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \ |
| &effects) { \ |
| SmallVector<Value> inputBuffers = getInputBufferOperands(); \ |
| SmallVector<Value> outputBuffers = getOutputBufferOperands(); \ |
| getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \ |
| outputBuffers); \ |
| } |
| |
| DEFINE_OP_GET_EFFECTS(ScatterOp) |
| DEFINE_OP_GET_EFFECTS(SortOp) |
| DEFINE_OP_GET_EFFECTS(FftOp) |
| DEFINE_OP_GET_EFFECTS(ReverseOp) |
| DEFINE_OP_GET_EFFECTS(ScanOp) |
| DEFINE_OP_GET_EFFECTS(TopkOp) |
| DEFINE_OP_GET_EFFECTS(PackOp) |
| namespace { |
| /// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any |
| /// changes. |
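/// For example (illustrative IR):
///   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
///   %1 = iree_linalg_ext.reverse ... ins(%0 : tensor<?x?xf32>) ...
/// The cast folds into the consumer so it operates on %arg0 directly; a
/// tensor.cast is re-introduced on any result whose type changed.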
| struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgExtOp> { |
| using OpInterfaceRewritePattern<LinalgExtOp>::OpInterfaceRewritePattern; |
| |
| LogicalResult matchAndRewrite(LinalgExtOp op, |
| PatternRewriter &rewriter) const override { |
// Fail if no operand comes from a tensor::CastOp that can be folded.
| bool hasTensorCastOperand = |
| llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) { |
| if (opOperand->get().isa<BlockArgument>()) |
| return false; |
| auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>(); |
| return castOp && canFoldIntoConsumerOp(castOp); |
| }); |
| if (!hasTensorCastOperand) |
| return failure(); |
| |
| SmallVector<Type, 4> newResultTypes; |
| newResultTypes.reserve(op->getNumResults()); |
| SmallVector<Value, 4> newOperands; |
| newOperands.reserve(op->getNumOperands()); |
| // Inputs may fold. |
| for (OpOperand *opOperand : op.getInputOperands()) { |
| auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); |
| newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp) |
| ? tensorCastOp.getSource() |
| : opOperand->get()); |
| } |
| // Init tensors may fold, in which case the resultType must also change. |
| for (OpOperand *opOperand : op.getOutputOperands()) { |
| auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>(); |
| bool fold = canFoldIntoConsumerOp(tensorCastOp); |
| newOperands.push_back(fold ? tensorCastOp.getOperand() |
| : opOperand->get()); |
| newResultTypes.push_back(newOperands.back().getType()); |
| } |
| // Clone op. |
| Operation *newOp = |
| op.clone(rewriter, op->getLoc(), newResultTypes, newOperands); |
| SmallVector<Value, 4> replacements; |
| replacements.reserve(newOp->getNumResults()); |
| for (auto result : llvm::zip(op->getResults(), newOp->getResults())) { |
| Value oldResult = std::get<0>(result); |
| Value newResult = std::get<1>(result); |
| if (newResult.getType() != oldResult.getType()) { |
| replacements.push_back(rewriter.create<tensor::CastOp>( |
| op->getLoc(), oldResult.getType(), newResult)); |
| } else { |
| replacements.push_back(newResult); |
| } |
| } |
| rewriter.replaceOp(op, replacements); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // LinalgExtDialect |
| //===----------------------------------------------------------------------===// |
| |
| void IREELinalgExtDialect::getCanonicalizationPatterns( |
| RewritePatternSet &results) const { |
| results.add<FoldTensorCastOp>(getContext()); |
| } |
| |
| #define GET_OP_CLASSES |
| #include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" |