Update LinalgExt to encompass iree-llvm-sandbox additions. (#8554)
Moves LinalgExt-related developments from `iree-llvm-sandbox` into `iree-dialects`. With this revision, `iree-llvm-sandbox` can integrate more closely with IREE.
See PSA https://github.com/google/iree-llvm-sandbox/issues/373
Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
new file mode 100644
index 0000000..ba8cc4d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
@@ -0,0 +1,106 @@
+//===- TilingToTileOp.cpp - Tiling to TileOp using TilingInterface --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+namespace {
+/// Result of tiling an operation along a single dimension into a `TileOp`.
+/// Kept in an anonymous namespace: it is a file-local helper type and must not
+/// collide (ODR) with any other `TilingResult` defined elsewhere.
+struct TilingResult {
+  /// The `TileOp` whose body holds the tiled computation.
+  TileOp tileOp;
+  /// The operation produced by `getTiledImplementation` inside `tileOp`'s
+  /// body; null until the body builder has run.
+  Operation *tiledOp = nullptr;
+};
+} // namespace
+
+/// Tiles `op` along the single loop dimension `tiledDim` with `tileSize` and
+/// wraps the tiled computation in a newly created `TileOp`.
+///
+/// Every loop dimension other than `tiledDim` keeps its full iteration range
+/// (offset/size from `getIterationDomain`). Inside the `TileOp` body, the
+/// body-provided `offset`/`size` are used for `tiledDim`, the body's
+/// `outSlices` are passed through to `getTiledImplementation`, and the results
+/// of the single tiled op are yielded with a `TileYieldOp`.
+///
+/// Preconditions (asserted): `op` has at least one loop, `tiledDim` indexes
+/// into its iteration domain, and `getTiledImplementation` produces exactly one
+/// op.
+static TilingResult tileToTileOp(PatternRewriter &rewriter, TilingInterface op,
+                                 int64_t tiledDim, Value tileSize) {
+  Location loc = op->getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  // TODO: Handle the case where the `loopRanges` are empty.
+  SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+  int64_t nLoops = loopRanges.size();
+  assert(nLoops >= 1 && "expected at least a single loop in operation");
+  // An out-of-range `tiledDim` would never match below, silently leaving every
+  // dimension untiled while still wrapping the op in a `TileOp`.
+  assert(tiledDim >= 0 && tiledDim < nLoops &&
+         "expected tiled dimension to be within the iteration domain");
+  auto destOperands = op.getDestinationOperands(rewriter);
+  Operation *tiledOp = nullptr;
+  auto tileOp = rewriter.create<TileOp>(
+      loc, tileSize, destOperands, tiledDim,
+      [&](OpBuilder &b, Location nestedLoc, Value offset, Value size,
+          ValueRange outSlices) {
+        // TODO: support `getTiledImplementation` with >1 produced tiled ops.
+        SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+        tiledOffsets.reserve(nLoops);
+        tiledSizes.reserve(nLoops);
+        for (int64_t i = 0; i < nLoops; ++i) {
+          if (i == tiledDim) {
+            tiledOffsets.push_back(offset);
+            tiledSizes.push_back(size);
+          } else {
+            tiledOffsets.push_back(loopRanges[i].offset);
+            tiledSizes.push_back(loopRanges[i].size);
+          }
+        }
+        SmallVector<Operation *> tiledOps = op.getTiledImplementation(
+            b, outSlices, tiledOffsets, tiledSizes,
+            /*tileDestOperands=*/false);
+        assert(tiledOps.size() == 1 && "expected single tiled op");
+        tiledOp = tiledOps.front();
+        b.create<TileYieldOp>(nestedLoc, tiledOp->getResults());
+      });
+  return TilingResult{tileOp, tiledOp};
+}
+
+/// Tiles a single-result `TilingInterface` op along exactly one parallel loop
+/// dimension and replaces it with the results of a `TileOp` wrapping the tiled
+/// computation.
+///
+/// The tiled dimension is the unique entry of
+/// `options.tileSizeComputationFunction(rewriter, op)` that is not a constant
+/// 0. On success, returns the tiled op created inside the `TileOp`. Returns a
+/// match failure when:
+///   - the op does not have exactly one result;
+///   - the op has no loops (scalar op);
+///   - a constant negative tile size is encountered;
+///   - zero, or more than one, tile sizes are non-zero;
+///   - the selected dimension lies outside the op's loop nest;
+///   - the selected dimension is not a parallel loop.
+FailureOr<Operation *> mlir::iree_compiler::IREE::LinalgExt::
+    LinalgExtTilingPattern::returningMatchAndRewrite(
+        TilingInterface op, PatternRewriter &rewriter) const {
+  // Currently only handle single result operations.
+  if (op->getNumResults() != 1)
+    return rewriter.notifyMatchFailure(op, "Not a single result");
+
+  // Query the loop structure before computing tile sizes. The tile size
+  // computation receives the rewriter and presumably materializes IR, which a
+  // pattern should avoid on a path that ends in a match failure.
+  auto loopIteratorTypes = op.getLoopIteratorTypes();
+  // Scalar operation: nothing to tile.
+  if (loopIteratorTypes.empty())
+    return rewriter.notifyMatchFailure(op, "Scalar op, no tiling possible");
+
+  // Get rank and tile sizes.
+  // TODO: consider moving these checks to a common place that the TransformOp
+  // verifier can also use.
+  SmallVector<Value> tileSizes =
+      options.tileSizeComputationFunction(rewriter, op);
+  // Find the unique dimension to tile; a constant 0 tile size means the
+  // dimension is left untiled.
+  int64_t dim = -1;
+  for (auto en : llvm::enumerate(tileSizes)) {
+    Optional<int64_t> maybeTileSize = getConstantIntValue(en.value());
+    if (maybeTileSize && *maybeTileSize == 0)
+      continue;
+    if (maybeTileSize && *maybeTileSize < 0)
+      return rewriter.notifyMatchFailure(op, "Negative tile size");
+    if (dim >= 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "Could not find a single tiling dim");
+    dim = en.index();
+  }
+  if (dim < 0)
+    return rewriter.notifyMatchFailure(op,
+                                       "Could not find a single tiling dim");
+
+  // `tileSizes` may contain more entries than the op has loops; indexing the
+  // iterator types with such a `dim` would read out of bounds.
+  if (dim >= static_cast<int64_t>(loopIteratorTypes.size()))
+    return rewriter.notifyMatchFailure(
+        op, "Tiling dim is out of bounds of the loop nest");
+
+  // Currently only handle tiling operations on a parallel iterator type.
+  if (loopIteratorTypes[dim] != getParallelIteratorTypeName())
+    return rewriter.notifyMatchFailure(op, "Trying to tile a non-parallel dim");
+
+  TilingResult tilingResult = tileToTileOp(rewriter, op, dim, tileSizes[dim]);
+  rewriter.replaceOp(op, tilingResult.tileOp->getResults());
+
+  return tilingResult.tiledOp;
+}