Modify `pack` op iteration space description to make it amenable to Tiling. (#10569)
The innermost loop dimensions of the `pack` operation deal with packing the data into a tile. These loops cannot be tiled. The loops that iterate over the different data tiles of the output can be tiled. Change the iteration space description to only specify the loops that iterate over the data tiles of the output. The intra-data-tile loops are generated as part of the scalar implementation of the `pack` operation.
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index c6d1cfd..32771a5 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1681,7 +1681,7 @@
// Implement the tiling interface. The number of loops equals
// the rank of the output tensors. All the loops are parallel.
SmallVector<utils::IteratorType> PackOp::getLoopIteratorTypes() {
- SmallVector<utils::IteratorType> iteratorTypes(getOutputRank(),
+ SmallVector<utils::IteratorType> iteratorTypes(getInputRank(),
utils::IteratorType::parallel);
return iteratorTypes;
}
@@ -1714,16 +1714,17 @@
// Implements `getIterationDomain` from the tiling interface. In each
// loop the lower bound is zero and the step is one. For upper bound
-// is inferred from the output tensor.
+// is inferred from the output tensor for the dimensions that are
+// not part of the data tile created.
SmallVector<Range> PackOp::getIterationDomain(OpBuilder &builder) {
- int64_t outputRank = getOutputRank();
- SmallVector<Range> loopBounds(outputRank);
+ int64_t inputRank = getInputRank();
+ SmallVector<Range> loopBounds(inputRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
ReifiedRankedShapedTypeDims resultShape;
(void)reifyResultShapes(builder, resultShape);
- for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
+ for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
loopBounds[dim].offset = zero;
loopBounds[dim].stride = one;
loopBounds[dim].size = resultShape[0][dim];
@@ -1756,8 +1757,10 @@
return interchangeVector;
}
-// Implements `getIterationDomain` from the tiling interface.
-LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder,
+/// Generate the body of the innermost loop of the scalar implementation
+/// of `pack` operation.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+ OpBuilder &builder,
Location loc,
ValueRange ivs) {
// Note: `ivs` are already in the correct order, possibly interchanged based
@@ -1766,18 +1769,20 @@
// the point loop? However, if we interchange `ivs` once more to go to the
// canonical blocking format: ABCabc, this connection becomes trivial: Each
// point loop is pointLoopsOffset + inputRank away from the tiled loop.
- SmallVector<int64_t> dimsToBlock = extractFromI64ArrayAttr(getDimsPos());
- SmallVector<int64_t> testInterchangeVector =
- computeInterchangeFromDimPos(dimsToBlock, getInputRank());
+ SmallVector<int64_t> dimsToBlock =
+ extractFromI64ArrayAttr(packOp.getDimsPos());
SmallVector<Value> interchangedIvs = ivs;
- interchangedIvs = interchange<Value>(interchangedIvs, testInterchangeVector,
- /*offset=*/getInputRank());
+ SmallVector<int64_t> interchangeVector =
+ computeInterchangeFromDimPos(dimsToBlock, packOp.getInputRank());
+ interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
+ /*offset=*/packOp.getInputRank());
- SmallVector<OpFoldResult> tiles = getMixedTiles();
- DenseMap<int64_t, OpFoldResult> dimAndTileMapping = getDimAndTileMapping();
+ SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
SmallVector<OpFoldResult> sourceIndices;
size_t pointLoopsOffset = 0;
- for (auto dim : llvm::seq<int64_t>(0, getInputRank())) {
+ for (auto dim : llvm::seq<int64_t>(0, packOp.getInputRank())) {
if (dimAndTileMapping.count(dim)) {
AffineExpr i, j, tile;
bindDims(builder.getContext(), i, j);
@@ -1786,7 +1791,7 @@
builder, loc, i * tile + j,
ArrayRef<OpFoldResult>{
interchangedIvs[dim],
- interchangedIvs[pointLoopsOffset + getInputRank()],
+ interchangedIvs[pointLoopsOffset + packOp.getInputRank()],
dimAndTileMapping[dim]});
sourceIndices.push_back(sourceIndex);
++pointLoopsOffset;
@@ -1795,15 +1800,58 @@
}
}
Value scalar = builder.create<memref::LoadOp>(
- loc, getInput(), getAsValues(builder, loc, sourceIndices));
- builder.create<memref::StoreOp>(loc, scalar, getOutput(), ivs);
+ loc, packOp.getInput(), getAsValues(builder, loc, sourceIndices));
+ builder.create<memref::StoreOp>(loc, scalar, packOp.getOutput(), ivs);
+}
+
+// Implements `generateScalarImplementation` from the tiling interface.
+LogicalResult PackOp::generateScalarImplementation(OpBuilder &builder,
+ Location loc,
+ ValueRange ivs) {
+ OpBuilder::InsertionGuard g(builder);
+ // The `ivs` already represent the position into the output tensor for the
+ // non data-tile dimensions.
+ SmallVector<Value> ivVec = llvm::to_vector(ivs);
+ ReifiedRankedShapedTypeDims outputShape;
+ if (failed(reifyResultShapes(builder, outputShape)))
+ return getOperation()->emitOpError("failed to reify result shape");
+ if (outputShape.size() != 1 || outputShape[0].size() != getOutputRank()) {
+ return getOperation()->emitOpError(
+ "expected shape of one result value of rank")
+ << getOutputRank();
+ }
+
+ // Generate the loops that iterate over the data tile.
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+ // All loops except the innermost are simple loops that just iterate
+ // over the tile dimensions.
+ for (auto dataTileDim :
+ llvm::seq<unsigned>(getInputRank(), getOutputRank() - 1)) {
+ Value ub = outputShape[0][dataTileDim];
+ scf::ForOp loop = builder.create<scf::ForOp>(loc, zero, ub, one);
+ builder.setInsertionPointToStart(loop.getBody());
+ ivVec.push_back(loop.getInductionVar());
+ }
+ // The body of the innermost loop does the actual data movement.
+ builder.create<scf::ForOp>(loc, zero, outputShape[0].back(), one,
+ ValueRange{},
+ [&](OpBuilder &bodyBuilder, Location bodyLoc,
+ Value iv, ValueRange regionIterArgs) {
+ ivVec.push_back(iv);
+ generatePackOpScalarImplementationBody(
+ *this, bodyBuilder, bodyLoc, ivVec);
+ bodyBuilder.create<scf::YieldOp>(bodyLoc);
+ });
return success();
}
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPoint(getOperation());
// Build the output dimension at pos `dimIdx`.
auto buildOutputDim = [&](OpBuilder &builder, size_t dimIdx) -> OpFoldResult {
ArrayRef<int64_t> outputShape = getOutputShape();