blob: ba14818965b021f92479b44020554be550278883 [file] [log] [blame]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

7#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
8#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10#include "mlir/Dialect/Linalg/Utils/Utils.h"
11#include "mlir/Dialect/Tensor/IR/Tensor.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/Operation.h"
15#include "mlir/IR/OperationSupport.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19using namespace mlir;
20using namespace mlir::iree_compiler::IREE::LinalgExt;
21
22struct TilingResult {
23 TileOp tileOp;
24 Operation *tiledOp;
25};
26
27static TilingResult tileToTileOp(PatternRewriter &rewriter, TilingInterface op,
28 int64_t tiledDim, Value tileSize) {
29 Location loc = op->getLoc();
30 OpBuilder::InsertionGuard g(rewriter);
31 // TODO: Handle the case where the `loopRanges` are empty.
32 SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
33 assert(loopRanges.size() >= 1 &&
34 "expected at least a single loop in operation");
35 auto destOperands = op.getDestinationOperands(rewriter);
36 Operation *tiledOp = nullptr;
37 auto tileOp = rewriter.create<TileOp>(
38 loc, tileSize, destOperands, tiledDim,
39 [&](OpBuilder &b, Location loc, Value offset, Value size,
40 ValueRange outSlices) {
41 // TODO: support `getTiledImplementation` with >1 produced tiled ops.
42 int64_t nLoops = loopRanges.size();
43 SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
44 tiledOffsets.reserve(nLoops);
45 tiledSizes.reserve(nLoops);
46 for (unsigned i = 0; i < nLoops; ++i) {
47 if (i == tiledDim) {
48 tiledOffsets.push_back(offset);
49 tiledSizes.push_back(size);
50 } else {
51 tiledOffsets.push_back(loopRanges[i].offset);
52 tiledSizes.push_back(loopRanges[i].size);
53 }
54 }
55 SmallVector<Operation *> tiledOps = op.getTiledImplementation(
56 b, outSlices, tiledOffsets, tiledSizes, /*tileDestOperands=*/false);
57 assert(tiledOps.size() == 1 && "expected single tiled op");
58 tiledOp = tiledOps.front();
59 b.create<TileYieldOp>(loc, tiledOp->getResults());
60 });
61 return TilingResult{tileOp, tiledOp};
62}
63
Nicolas Vasilache11fb8d02022-03-21 20:33:30 +010064FailureOr<Operation *>
65mlir::iree_compiler::IREE::LinalgExt::LinalgExtTilingPattern::
66 returningMatchAndRewrite(TilingInterface op,
67 PatternRewriter &rewriter) const {
Nicolas Vasilachef3612672022-03-17 21:58:56 +010068 /// Currently only handle single result operations.
69 if (op->getNumResults() != 1)
70 return rewriter.notifyMatchFailure(op, "Not a single result");
71
72 // Get rank and tile sizes.
73 // TODO: consider moving these checks to a common place that the TransformOp
74 // verifier can also use.
75 SmallVector<Value> tileSizes =
76 options.tileSizeComputationFunction(rewriter, op);
77 int64_t dim = -1;
78 for (auto en : llvm::enumerate(tileSizes)) {
79 Optional<int64_t> maybeTileSize = getConstantIntValue(en.value());
Nicolas Vasilache11fb8d02022-03-21 20:33:30 +010080 if (maybeTileSize && *maybeTileSize == 0)
81 continue;
Nicolas Vasilachef3612672022-03-17 21:58:56 +010082 if (maybeTileSize && *maybeTileSize < 0)
83 return rewriter.notifyMatchFailure(op, "Negative tile size");
84 if (dim >= 0)
85 return rewriter.notifyMatchFailure(op,
86 "Could not find a single tiling dim");
87 dim = en.index();
88 }
89 if (dim < 0)
90 return rewriter.notifyMatchFailure(op,
91 "Could not find a single tiling dim");
92
93 /// Currently only handle tiling operations on a parallel iterator type.
94 auto loopIteratorTypes = op.getLoopIteratorTypes();
95 // Scalar operation, nothing to do, so just return.
96 if (loopIteratorTypes.empty())
97 return rewriter.notifyMatchFailure(op, "Scalar op, no tiling possible");
98 ArrayRef<StringRef> loopIteratorTypesRef(loopIteratorTypes);
99 if (loopIteratorTypesRef[dim] != getParallelIteratorTypeName())
100 return rewriter.notifyMatchFailure(op, "Trying to tile a non-parallel dim");
101
102 TilingResult tilingResult = tileToTileOp(rewriter, op, dim, tileSizes[dim]);
103 rewriter.replaceOp(op, tilingResult.tileOp->getResults());
104
105 return tilingResult.tiledOp;
106}