Update LinalgExt to encompass iree-llvm-sandbox additions. (#8554)
Moves LinalgExt-related developments from `iree-llvm-sandbox` into `iree-dialects`. With this revision, `iree-llvm-sandbox` can integrate more closely with IREE.
See PSA https://github.com/google/iree-llvm-sandbox/issues/373
Co-authored-by: Mahesh Ravishankar <ravishankarm@google.com>
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
index 9f57627..126b878 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(IR)
+add_subdirectory(Passes)
add_subdirectory(Transforms)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index b692d30..cabc5c4 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
@@ -1242,6 +1243,388 @@
} // namespace
//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
+void TileOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
+ Value tileSize, ValueRange outs, int64_t tiledDim,
+ TileOp::TileOpBodyBuilderFn bodyBuilder) {
+ result.addOperands(tileSize);
+ result.addOperands(outs);
+ result.addAttribute(TileOp::getTiledDimAttrName(),
+ builder.getI64IntegerAttr(tiledDim));
+ result.addTypes(outs.getType());
+
+ Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new Block);
+ Block &bodyBlock = bodyRegion->front();
+ // TODO: Pass a better location here.
+ Location loc = tileSize.getLoc();
+ bodyBlock.addArgument(builder.getIndexType(), loc);
+ bodyBlock.addArgument(builder.getIndexType(), loc);
+ // Handle the sliced out types in a conservative fashion: all dimensions
+ // become dynamic and a later canonicalization is expected to recover static
+ // types.
+ // TODO: should we relax this and use something less strict?
+ auto dynamicTypes =
+ llvm::to_vector(llvm::map_range(outs.getTypes(), [](Type t) -> Type {
+ auto rankedTensorType = t.cast<RankedTensorType>();
+ RankedTensorType::Builder rttb(rankedTensorType);
+ SmallVector<int64_t> dynamicShape(rankedTensorType.getRank(),
+ ShapedType::kDynamicSize);
+ return rttb.setShape(dynamicShape);
+ }));
+ SmallVector<Location> locs(dynamicTypes.size(), loc);
+ bodyBlock.addArguments(dynamicTypes, locs);
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&bodyBlock);
+ bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
+ bodyBlock.getArgument(1), bodyBlock.getArguments().drop_front(2));
+}
+
+void TileOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
+ Value tileSize, ValueRange outs,
+ TileOp::TileOpBodyBuilderFn bodyBuilder) {
+ TileOp::build(builder, result, tileSize, outs, 0, bodyBuilder);
+}
+
+// TODO(#81): Impl me.
+LogicalResult TileOp::verify() { return success(); }
+
+void TileOp::print(OpAsmPrinter &p) {
+ p << ' ' << tile_size() << ' ';
+ if (tiled_dim() > 0) p << "tiled_dim = " << tiled_dim() << ' ';
+ if (!outs().empty()) {
+ p << "outs(";
+ llvm::interleaveComma(outs(), p,
+ [&p](Value v) { p << v << ": " << v.getType(); });
+ p << ')';
+ }
+ p << " -> (" << getResultTypes() << ") ";
+ p.printRegion(region(),
+ /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/true);
+ p.printOptionalAttrDict(getOperation()->getAttrs(),
+ /*elidedAttrs=*/{TileOp::getTiledDimAttrName()});
+}
+
+ParseResult TileOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto &builder = parser.getBuilder();
+
+ OpAsmParser::OperandType tileSizes;
+ // TODO: also allow tensor<..xindex> and figure out a good syntax.
+ // Type tensorOfIndexType =
+ // RankedTensorType::get({ShapedType::kDynamicSize}, indexType);
+ Type tileSizesType = builder.getIndexType();
+ SmallVector<Type> outsTypes;
+ SmallVector<OpAsmParser::OperandType, 4> outsOperands;
+
+ llvm::SMLoc outputsOperandsLoc;
+ if (parser.parseOperand(tileSizes) ||
+ parser.resolveOperand(tileSizes, tileSizesType, result.operands))
+ return failure();
+
+ // Parse the `tiled_dim` attribute or set it to 0 implicitly when elided.
+ if (succeeded(parser.parseOptionalKeyword(TileOp::getTiledDimAttrName()))) {
+ outputsOperandsLoc = parser.getCurrentLocation();
+ Attribute valueAttr;
+ if (parser.parseAttribute(valueAttr, TileOp::getTiledDimAttrName(),
+ result.attributes)) return failure();
+ } else {
+ result.attributes.append(TileOp::getTiledDimAttrName(),
+ parser.getBuilder().getI64IntegerAttr(0));
+ }
+
+ if (succeeded(parser.parseOptionalKeyword("outs"))) {
+ bool isVariadic;
+ SmallVector<NamedAttrList> parsedArgAttrs;
+ SmallVector<Location> parsedArgLocations;
+ outputsOperandsLoc = parser.getCurrentLocation();
+ if (mlir::function_interface_impl::parseFunctionArgumentList(
+ parser,
+ /*allowAttributes=*/false,
+ /*allowVariadic=*/false, outsOperands, outsTypes,
+ /*argAttrs=*/parsedArgAttrs, /*argLocations=*/parsedArgLocations,
+ /*isVariadic=*/isVariadic) ||
+ parser.resolveOperands(outsOperands, outsTypes, outputsOperandsLoc,
+ result.operands))
+ return failure();
+ }
+ if (parser.parseArrowTypeList(result.types)) return failure();
+
+ SmallVector<OpAsmParser::OperandType, 8> regionOperands;
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ SmallVector<Type, 8> operandTypes, regionTypes;
+ if (parser.parseRegion(*region, regionOperands, regionTypes))
+ return failure();
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+
+ TileOp::ensureTerminator(*region, builder, result.location);
+ result.addRegion(std::move(region));
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// InParallelOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult InParallelOp::verify() {
+ // Check that the body defines a single block argument for the thread index.
+ auto *body = getBody();
+ if (body->getNumArguments() != 1)
+ return emitOpError("body expects exactly one argument");
+ if (!body->getArgument(0).getType().isIndex())
+ return emitOpError(
+ "expected body first argument to be an index argument for "
+ "the thread index");
+
+ // Verify consistency between the result types and the terminator.
+ auto terminatorTypes = getTerminator().yieldedTypes();
+ auto opResults = getResults();
+ if (opResults.size() != terminatorTypes.size())
+ return emitOpError("produces ")
+ << opResults.size() << " results, but its terminator yields "
+ << terminatorTypes.size() << " values";
+ unsigned i = 0;
+ for (auto e : llvm::zip(terminatorTypes, opResults)) {
+ if (std::get<0>(e) != std::get<1>(e).getType())
+ return emitOpError() << "type mismatch between " << i
+ << "th result of in_parallel (" << std::get<0>(e)
+ << ") and " << i << "th result yielded by its "
+ << "terminator (" << std::get<1>(e).getType() << ")";
+ i++;
+ }
+
+ return success();
+}
+
+void InParallelOp::print(OpAsmPrinter &p) {
+ p << ' ' << num_threads() << ' ';
+ p << " -> (" << getResultTypes() << ") ";
+ p.printRegion(region(),
+ /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/true);
+ p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+ParseResult InParallelOp::parse(OpAsmParser &parser, OperationState &result) {
+ auto &builder = parser.getBuilder();
+
+ OpAsmParser::OperandType numThreads;
+ Type indexType = builder.getIndexType();
+
+ if (parser.parseOperand(numThreads) ||
+ parser.resolveOperand(numThreads, indexType, result.operands))
+ return failure();
+ if (parser.parseArrowTypeList(result.types)) return failure();
+
+ SmallVector<OpAsmParser::OperandType, 8> regionOperands;
+ SmallVector<Type, 8> regionTypes;
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ if (parser.parseRegion(*region, regionOperands, regionTypes))
+ return failure();
+ InParallelOp::ensureTerminator(*region, builder, result.location);
+ result.addRegion(std::move(region));
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ return success();
+}
+
+// Bodyless builder, result types must be specified.
+void InParallelOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
+ TypeRange resultTypes, Value numThreads) {
+ // TODO: Pass better location.
+ Location loc = numThreads.getLoc();
+ result.addOperands(numThreads);
+
+ Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new Block);
+ Block &bodyBlock = bodyRegion->front();
+ bodyBlock.addArgument(builder.getIndexType(), loc);
+
+ // This overload takes no body builder: create the default terminator and
+ // leave populating the body (and the values it eventually yields) to the
+ // caller.
+ InParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
+ result.addTypes(resultTypes);
+}
+
+// Builder that takes a bodyBuilder lambda; result types are inferred from
+// the terminator.
+void InParallelOp::build(
+ mlir::OpBuilder &builder, mlir::OperationState &result, Value numThreads,
+ function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
+ // TODO: Pass better location.
+ Location loc = numThreads.getLoc();
+ result.addOperands(numThreads);
+
+ Region *bodyRegion = result.addRegion();
+ bodyRegion->push_back(new Block);
+ Block &bodyBlock = bodyRegion->front();
+ bodyBlock.addArgument(builder.getIndexType(), loc);
+
+ OpBuilder::InsertionGuard guard(builder);
+ builder.setInsertionPointToStart(&bodyBlock);
+ bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
+ auto terminator =
+ llvm::cast<PerformConcurrentlyOp>(bodyBlock.getTerminator());
+ result.addTypes(terminator.yieldedTypes());
+}
+
+// The ensureTerminator method generated by SingleBlockImplicitTerminator is
+// unaware of the fact that our terminator also needs a region to be well
+// formed. We override it here to ensure that we do the right thing.
+void InParallelOp::ensureTerminator(Region ®ion, Builder &builder,
+ Location loc) {
+ OpTrait::SingleBlockImplicitTerminator<PerformConcurrentlyOp>::Impl<
+ InParallelOp>::ensureTerminator(region, builder, loc);
+ auto terminator =
+ llvm::dyn_cast<PerformConcurrentlyOp>(region.front().getTerminator());
+ PerformConcurrentlyOp::ensureTerminator(terminator.getRegion(), builder, loc);
+}
+
+PerformConcurrentlyOp InParallelOp::getTerminator() {
+ return cast<PerformConcurrentlyOp>(getBody()->getTerminator());
+}
+
+//===----------------------------------------------------------------------===//
+// ParallelInsertSliceOp
+//===----------------------------------------------------------------------===//
+
+// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+ Value source, Value dest,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides,
+ ArrayRef<NamedAttribute> attrs) {
+ SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
+ SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
+ ShapedType::kDynamicStrideOrOffset);
+ dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
+ ShapedType::kDynamicSize);
+ dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
+ ShapedType::kDynamicStrideOrOffset);
+ build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
+ dynamicStrides, b.getI64ArrayAttr(staticOffsets),
+ b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
+ result.addAttributes(attrs);
+}
+
+// Build a ParallelInsertSliceOp with dynamic entries.
+void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
+ Value source, Value dest, ValueRange offsets,
+ ValueRange sizes, ValueRange strides,
+ ArrayRef<NamedAttribute> attrs) {
+ SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
+ llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
+ SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
+ llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
+ SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
+ llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
+ build(b, result, source, dest, offsetValues, sizeValues, strideValues);
+}
+
+namespace {
+/// Pattern to rewrite a parallel_insert_slice op with constant arguments.
+class ParallelInsertSliceOpConstantArgumentFolder final
+ : public OpRewritePattern<ParallelInsertSliceOp> {
+ public:
+ using OpRewritePattern<ParallelInsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ParallelInsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override {
+ // No constant operand, just return.
+ if (llvm::none_of(insertSliceOp.getOperands(), [](Value operand) {
+ return matchPattern(operand, matchConstantIndex());
+ }))
+ return failure();
+
+ // At least one of offsets/sizes/strides is a new constant.
+ // Form the new list of operands and constant attributes from the
+ // existing.
+ SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
+ SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
+ SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
+ canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
+ canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
+ canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
+
+ // Create the new op in canonical form.
+ rewriter.replaceOpWithNewOp<ParallelInsertSliceOp>(
+ insertSliceOp, insertSliceOp.source(), insertSliceOp.dest(),
+ mixedOffsets, mixedSizes, mixedStrides);
+ return success();
+ }
+};
+} // namespace
+
+void ParallelInsertSliceOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<ParallelInsertSliceOpConstantArgumentFolder>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// PerformConcurrentlyOp
+//===----------------------------------------------------------------------===//
+
+// TODO(ntv,apaszke): Implement this
+LogicalResult PerformConcurrentlyOp::verify() { return success(); }
+
+void PerformConcurrentlyOp::print(OpAsmPrinter &p) {
+ p << " ";
+ p.printRegion(region(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/false);
+ p.printOptionalAttrDict(getOperation()->getAttrs());
+}
+
+ParseResult PerformConcurrentlyOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ auto &builder = parser.getBuilder();
+
+ SmallVector<OpAsmParser::OperandType, 8> regionOperands;
+ SmallVector<Type, 8> regionTypes;
+ std::unique_ptr<Region> region = std::make_unique<Region>();
+ if (parser.parseRegion(*region, regionOperands, regionTypes))
+ return failure();
+ PerformConcurrentlyOp::ensureTerminator(*region, builder, result.location);
+ result.addRegion(std::move(region));
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes)) return failure();
+ return success();
+}
+
+SmallVector<Type> PerformConcurrentlyOp::yieldedTypes() {
+ return llvm::to_vector(llvm::map_range(
+ this->yieldingOps(),
+ [](ParallelInsertSliceOp op) { return op.yieldedType(); }));
+}
+
+SmallVector<ParallelInsertSliceOp> PerformConcurrentlyOp::yieldingOps() {
+ SmallVector<ParallelInsertSliceOp> ret;
+ for (Operation &op : *getBody()) {
+ // TODO: interface when this grows up.
+ if (auto sliceOp = llvm::dyn_cast<ParallelInsertSliceOp>(op)) {
+ ret.push_back(sliceOp);
+ continue;
+ }
+ if (auto endPerformOp = llvm::dyn_cast<EndPerformConcurrentlyOp>(op)) {
+ continue;
+ }
+ llvm_unreachable("Unexpected operation in perform_concurrently");
+ }
+ return ret;
+}
+
+//===----------------------------------------------------------------------===//
// LinalgExtDialect
//===----------------------------------------------------------------------===//
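For reviewers unfamiliar with the new lambda-based builders, here is a minimal usage sketch of the TileOp builder added above. It is illustrative only: the function and value names (`buildExampleTile`, `tileSize`, `outTensor`) are hypothetical and not part of this change.

```cpp
// Hedged sketch: create a TileOp via the lambda-based builder declared above.
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;

static Value buildExampleTile(OpBuilder &builder, Location loc, Value tileSize,
                              Value outTensor) {
  // Tile dimension 0 of `outTensor` by `tileSize`. The body receives the
  // offset and size of the current tile plus dynamically-shaped views of the
  // outs operands (see TileOp::build above).
  auto tileOp = builder.create<TileOp>(
      loc, tileSize, ValueRange{outTensor}, /*tiledDim=*/0,
      [](OpBuilder &b, Location nestedLoc, Value offset, Value size,
         ValueRange outSlices) {
        // A real body would compute the tile here and forward it through the
        // op's implicit terminator.
      });
  return tileOp->getResult(0);
}
```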
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
new file mode 100644
index 0000000..e26003e
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_library(IREELinalgExtPasses
+ ConvertToLoops.cpp
+ PadContractionToBlockSize.cpp
+ Passes.cpp
+ Tiling.cpp
+
+ DEPENDS
+ IREELinalgExtPassesIncGen
+
+ LINK_LIBS PUBLIC
+ IREEInputDialect
+ IREELinalgExtDialect
+ MLIRAffine
+ MLIRIR
+ MLIRLinalg
+ MLIRLinalgTransforms
+ MLIRMath
+ MLIRMemRef
+ MLIRPass
+ MLIRSCF
+ MLIRFunc
+ MLIRSupport
+ MLIRTensor
+ MLIRTransforms
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp
similarity index 96%
rename from llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp
index 52c9dcf..da62126 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/ConvertToLoops.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/ConvertToLoops.cpp
@@ -6,8 +6,8 @@
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp
similarity index 97%
rename from llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp
index b050cc7..a2fe9bd 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/PadContractionToBlockSize.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/PadContractionToBlockSize.cpp
@@ -6,8 +6,8 @@
#include "iree-dialects/Dialect/Input/InputDialect.h"
#include "iree-dialects/Dialect/Input/InputOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
similarity index 82%
rename from llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp
rename to llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
index c41b9ed..f038541 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Passes.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Passes.cpp
@@ -4,7 +4,7 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@@ -20,7 +20,7 @@
namespace detail {
#define GEN_PASS_REGISTRATION
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h.inc" // IWYU pragma: export
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h.inc" // IWYU pragma: export
} // namespace detail
} // namespace LinalgExt
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
new file mode 100644
index 0000000..fd66bff
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/Tiling.cpp
@@ -0,0 +1,360 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree-dialects/Dialect/Input/InputDialect.h"
+#include "iree-dialects/Dialect/Input/InputOps.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/PassDetail.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+namespace IREE = mlir::iree_compiler::IREE;
+using namespace IREE::LinalgExt;
+
+//===----------------------------------------------------------------------===//
+// Utility methods for tiling a linalg_ext operation that implements a
+// TiledOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Returns failure if the options are unsupported.
+static LogicalResult verifySupportedTilingOptions(
+ PatternRewriter &rewriter, Operation *op,
+ const linalg::LinalgTilingOptions &options) {
+ if (!options.interchangeVector.empty()) {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported interchange during tiling");
+ }
+ if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
+ return rewriter.notifyMatchFailure(op,
+ "only tiling with scf.for is supported");
+ }
+ if (options.distribution) {
+ if (llvm::any_of(options.distribution->distributionMethod,
+ [](linalg::DistributionMethod method) {
+ return method != linalg::DistributionMethod::Cyclic;
+ })) {
+ return rewriter.notifyMatchFailure(op,
+ "only cyclic distibution is allowed");
+ }
+ }
+ return success();
+}
+
+/// Converts an `OpFoldResult` to a `Value` by building a constant op if the
+/// `OpFoldResult` is an `IntegerAttr`.
+static Value getValue(OpBuilder &builder, Location loc,
+ OpFoldResult valueOrAttr) {
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ return builder.create<arith::ConstantIndexOp>(
+ loc, attr.cast<IntegerAttr>().getInt());
+ }
+ return valueOrAttr.get<Value>();
+}
+
+/// Returns true if the loop is untiled, i.e., its tile size is statically
+/// zero. A `Value` defined by a constant op is assumed to have already been
+/// converted to an `IntegerAttr` holding that value, so it suffices to check
+/// for an attribute with a zero value.
+static bool isUntiledLoop(OpFoldResult valueOrAttr) {
+ Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
+ return intVal && *intVal == 0;
+}
+
+/// Generates the tiled loops and the body by invoking the interface methods of
+/// TiledOpInterface.
+/// - `outputs` are the operands to use for outputs of the tiled operation.
+/// - `tileSizes` are tile sizes specified for all loops of the operation. If a
+/// loop is to be untiled it is set to 0.
+/// - `iteratorTypes` are the iterator types of all the loops of the op, as
+/// returned by the TiledOpInterface.
+/// - `loopBounds` are the bounds of all the loops of the op returned by the
+/// TiledOpInterface.
+/// - `loopDepth` is the current loop depth being processed.
+/// - `offsets` are the `Value`s that represent the position of the tile being
+/// operated on. The offsets are computed as the tiled loops are being
+/// generated.
+/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
+/// distributed loops. It is a stack, and once an entry at the top of the
+/// stack is used for distribution it is popped before processing the inner
+/// loops.
+static FailureOr<TiledOp> tileInterfaceOpImpl(
+ OpBuilder &builder, TiledOpInterface tilableOp, ValueRange outputs,
+ MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<Range> loopBounds, unsigned loopDepth,
+ SmallVectorImpl<OpFoldResult> &offsets,
+ ArrayRef<linalg::ProcInfo> distributionInfo) {
+ Location loc = tilableOp.getLoc();
+ // If this is the innermost loop, then generate the tiled implementation of
+ // the op by invoking the TiledOpInterface methods.
+ if (loopDepth == tileSizes.size()) {
+ TiledOp ret;
+ ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets,
+ tileSizes, ret.results);
+ if (!ret.op) {
+ return static_cast<LogicalResult>(
+ tilableOp.emitOpError("failed to get tiled implementation"));
+ }
+ return ret;
+ }
+
+ // If the tile size at this depth is zero, the loop is untiled: record a
+ // zero offset, use the full extent as the tile size, and recurse.
+ if (isUntiledLoop(tileSizes[loopDepth])) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ offsets.push_back(zeroAttr);
+ assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
+ "expected loop bounds to have lower bound of zero");
+ tileSizes[loopDepth] = getAsOpFoldResult(loopBounds[loopDepth].size);
+ return tileInterfaceOpImpl(builder, tilableOp, outputs, tileSizes,
+ iteratorTypes, loopBounds, loopDepth + 1,
+ offsets, distributionInfo);
+ }
+
+ // Generate an scf.for for the current loop depth.
+ Value lb = loopBounds[loopDepth].offset;
+ Value ub = loopBounds[loopDepth].size;
+ // TODO(#7073): Put the check back. This is required by tiling linalg_ext.fft
+ // op. We can put the check back after updating linalg_ext.fft semantics.
+ // if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
+ // return static_cast<LogicalResult>(
+ // tilableOp.emitOpError("expected stride to be 1"));
+ //}
+ Value step = getValue(builder, loc, tileSizes[loopDepth]);
+
+ // Update lb, ub and step for cyclic distribution.
+ if (!distributionInfo.empty() &&
+ iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
+ linalg::updateBoundsForCyclicDistribution(
+ builder, loc, distributionInfo.front().procId,
+ distributionInfo.front().nprocs, lb, ub, step);
+ distributionInfo = distributionInfo.drop_front();
+ }
+ FailureOr<TiledOp> innerReturnValue;
+ bool isBufferTiling = tilableOp->getNumResults() == 0;
+ ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
+ auto forOp = builder.create<scf::ForOp>(
+ loc, lb, ub, step, initValues,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ offsets.push_back(iv);
+ auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
+ b.getAffineSymbolExpr(0),
+ b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
+ // Similar to linalg tiling, the tile size is the min(tileSizes, ub -
+ // iv) to account for cases where tile size does not divide (ub - lb)
+ // exactly.
+ Value inBoundsTileSize = b.create<AffineMinOp>(
+ loc, affineMaps,
+ ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});
+ tileSizes[loopDepth] = getAsOpFoldResult(inBoundsTileSize);
+ // Recursively proceed to generate the tiled loop for the next level.
+ innerReturnValue =
+ tileInterfaceOpImpl(b, tilableOp, (isBufferTiling ? outputs : args),
+ tileSizes, iteratorTypes, loopBounds,
+ loopDepth + 1, offsets, distributionInfo);
+ if (failed(innerReturnValue)) return;
+ b.create<scf::YieldOp>(loc, innerReturnValue->results);
+ });
+ if (failed(innerReturnValue)) {
+ return innerReturnValue;
+ }
+ innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
+ forOp.getOperation());
+ innerReturnValue->results = forOp.getResults();
+ return innerReturnValue;
+}
+
+FailureOr<TiledOp> tileInterfaceOp(OpBuilder &b, TiledOpInterface tilableOp,
+ const linalg::LinalgTilingOptions &options) {
+ SmallVector<Value> dest = tilableOp.getDestinationOperands(b);
+ if (dest.empty()) {
+ return static_cast<LogicalResult>(tilableOp.emitOpError(
+ "cannot tile operation without destination operands"));
+ }
+
+ SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+ SmallVector<Value, 4> tileSizesVals =
+ options.tileSizeComputationFunction(b, tilableOp);
+ auto zeroAttr = b.getI64IntegerAttr(0);
+
+ // The tile sizes are normalized so that a `Value` defined by a constant 0 is
+ // converted to a zero integer attribute. Currently, if the iterator type is
+ // not "parallel", the tile size is required to be zero as well.
+ auto tileSizes = getAsOpFoldResult(tileSizesVals);
+ tileSizes.resize(iteratorTypes.size(), zeroAttr);
+ for (auto en : llvm::enumerate(iteratorTypes)) {
+ if (en.value() == getParallelIteratorTypeName()) continue;
+ if (!isUntiledLoop(tileSizes[en.index()])) {
+ return static_cast<LogicalResult>(tilableOp.emitOpError(
+ "unimplemented tiling of non-parallel loop iterator type"));
+ }
+ }
+
+ // Trivial early exit case of tile sizes being zero for all parallel loops.
+ if (llvm::all_of(tileSizes, isUntiledLoop)) {
+ return TiledOp{tilableOp, {}, {}};
+ }
+
+ SmallVector<Range> loopBounds = tilableOp.getIterationDomain(b);
+ SmallVector<linalg::ProcInfo> distributionInfo;
+ // If the tiled loops are distributed, get the proc_id and nprocs for the
+ // distributed loops. First collect the loops to distribute by iterating over
+ // the tile sizes and keeping the loops that are
+ // - parallel, i.e. the iterator type is "parallel", and
+ // - tiled, i.e. the tile size is != 0.
+ if (options.distribution) {
+ SmallVector<Range> distributedLoopRange;
+ for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
+ if (isUntiledLoop(tileSizes[i])) continue;
+ if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
+ distributedLoopRange.push_back(loopBounds[i]);
+ }
+ distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(),
+ distributedLoopRange);
+ }
+
+ SmallVector<OpFoldResult> offsets;
+ return tileInterfaceOpImpl(b, tilableOp, dest, tileSizes, iteratorTypes,
+ loopBounds, 0, offsets, distributionInfo);
+}
+
+LogicalResult TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
+ TiledOpInterface tilableOp, PatternRewriter &rewriter,
+ TiledOp &result) const {
+ if (failed(filter.checkAndNotify(rewriter, tilableOp))) {
+ return failure();
+ }
+ if (failed(verifySupportedTilingOptions(rewriter, tilableOp, options))) {
+ return failure();
+ }
+
+ FailureOr<TiledOp> res = tileInterfaceOp(rewriter, tilableOp, options);
+ if (failed(res)) return res;
+ result = *res;
+ if (result.op) {
+ filter.replaceLinalgTransformationFilter(rewriter, result.op);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass for tiling Linalg Ext ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TiledOpInterfaceTilingPass
+ : public TiledOpInterfaceTilingBase<TiledOpInterfaceTilingPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<
+ AffineDialect, IREE::Input::IREEInputDialect, linalg::LinalgDialect,
+ IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
+ func::FuncDialect, mlir::arith::ArithmeticDialect, math::MathDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
+ }
+ void runOnOperation() override;
+};
+} // namespace
+
+template <typename OpTy>
+static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
+ return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
+}
+
+void TiledOpInterfaceTilingPass::runOnOperation() {
+ FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp.getContext();
+
+ RewritePatternSet patterns(context);
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "tiling_input"),
+ StringAttr::get(context, "tiling_output")));
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "no_tiling_input"),
+ StringAttr::get(context, "no_tiling_output")));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "outer_reduce_input"),
+ StringAttr::get(context, "outer_reduce_output")));
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "inner_reduce_input"),
+ StringAttr::get(context, "inner_reduce_output")));
+
+ static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+ [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+ auto numParallelDims = parallelLoopRanges.size();
+
+ SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
+ for (size_t dim = 0; dim < numParallelDims; ++dim) {
+ procInfo[numParallelDims - dim - 1] = {
+ buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupIDOp>(
+ builder, dim),
+ buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupCountOp>(
+ builder, dim)};
+ }
+ return procInfo;
+ },
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic},
+ DenseMap<StringRef,
+ std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
+ .setDistributionOptions(workgroupDistributionOptions),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "distribute_input"),
+ StringAttr::get(context, "distribute_output")));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context,
+ linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{32}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "tiling_1d_stage5_fft_input"),
+ StringAttr::get(context, "tiling_1d_stage5_fft_output")));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context,
+ linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{10, 32}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "tiling_2d_stage5_fft_input"),
+ StringAttr::get(context, "tiling_2d_stage5_fft_output")));
+
+ patterns.add<TiledOpInterfaceTilingPattern>(
+ context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+ linalg::LinalgTransformationFilter(
+ StringAttr::get(context, "tiling_repeated_indices_scatter_input"),
+ StringAttr::get(context, "tiling_repeated_indices_scatter_output")));
+
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<FuncOp>>
+IREE::LinalgExt::createTiledOpInterfaceTilingPass() {
+ return std::make_unique<TiledOpInterfaceTilingPass>();
+}
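As a reading aid, a hedged sketch of driving the standalone `tileInterfaceOp` entry point above outside of the test pass. The assumption that `TiledOp` and `tileInterfaceOp` are declared in `Passes/Transforms.h` follows Tiling.cpp's own includes; the function name and tile sizes are illustrative.

```cpp
// Hedged sketch: tile an op implementing TiledOpInterface with fixed sizes.
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree-dialects/Dialect/LinalgExt/Passes/Transforms.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;

static LogicalResult tileWithFixedSizes(OpBuilder &b,
                                        TiledOpInterface tilableOp) {
  // A tile size of 0 leaves the corresponding loop untiled.
  linalg::LinalgTilingOptions options;
  options.setTileSizes({10, 20});
  FailureOr<TiledOp> tiled = tileInterfaceOp(b, tilableOp, options);
  if (failed(tiled)) return failure();
  // On success, `tiled->op` is the tiled operation, `tiled->loops` the
  // generated loop nest, and `tiled->results` the replacement values.
  return success();
}
```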
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
index 0cd7fd0..a174ba1 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -1,25 +1,44 @@
-add_mlir_library(IREELinalgExtPasses
- ConvertToLoops.cpp
- PadContractionToBlockSize.cpp
- Passes.cpp
+add_mlir_library(IREELinalgExtTransforms
+ InParallelToAsync.cpp
+ InParallelToSequentialFor.cpp
+ TilingExternalModels.cpp
+ TileToSequentialFor.cpp
+ TileToInParallel.cpp
Tiling.cpp
+ TilingToTileOp.cpp
+ Utils.cpp
+ PARTIAL_SOURCES_INTENDED
DEPENDS
- IREELinalgExtTransformsPassesIncGen
+ mlir-headers
+ IREELinalgExtDialect
LINK_LIBS PUBLIC
- IREEInputDialect
IREELinalgExtDialect
- MLIRAffine
+
+ MLIRAffineToStandard
+ MLIRAsync
+ MLIRSCFToControlFlow
+ MLIRLinalgToLLVM
+ MLIRVectorToLLVM
+ MLIRMathToLLVM
+ MLIRMemRefToLLVM
MLIRIR
+ MLIRMath
MLIRLinalg
MLIRLinalgTransforms
- MLIRMath
- MLIRMemRef
MLIRPass
MLIRSCF
- MLIRFunc
- MLIRSupport
- MLIRTensor
MLIRTransforms
)
+
+add_mlir_library(IREELinalgExtOpInterfaceImpl
+ LinalgExtBufferization.cpp
+
+ PARTIAL_SOURCES_INTENDED
+ LINK_LIBS PUBLIC
+ IREELinalgExtDialect
+
+ MLIRBufferization
+ MLIRTensorTransforms
+)
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp
new file mode 100644
index 0000000..64514bb
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToAsync.cpp
@@ -0,0 +1,91 @@
+//===- InParallelToAsync.cpp - Rewrite InParallel as Async ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <cstdlib>
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+FailureOr<Operation *> mlir::iree_compiler::IREE::LinalgExt::
+ InParallelOpToAsyncRewriter::returningMatchAndRewrite(
+ iree_compiler::IREE::LinalgExt::InParallelOp inParallelOp,
+ PatternRewriter &rewriter) const {
+ assert(inParallelOp.getNumResults() == 0 &&
+ "expected bufferized InParallelOp");
+
+ // Only consider the top-level InParallelOp and skip it if it already
+ // contains an async::ExecuteOp.
+ if (inParallelOp
+ ->getParentOfType<iree_compiler::IREE::LinalgExt::InParallelOp>() ||
+ llvm::any_of(inParallelOp.getBody()->getOperations(),
+ [](Operation &op) { return isa<async::ExecuteOp>(&op); }))
+ return failure();
+
+ auto *ctx = inParallelOp.getContext();
+ Location loc = inParallelOp.getLoc();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value numThreads = inParallelOp.num_threads();
+
+ // Wrap the linalg_ext.in_parallel into an async::ExecuteOp.
+ // 1. Create the async::GroupType object on which we synchronize.
+ Value asyncGroup = rewriter.create<async::CreateGroupOp>(
+ loc, async::GroupType::get(ctx), numThreads);
+
+ // 2. Create a bodyless forOp.
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, zero, numThreads, one);
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // 3. Create an empty executeOp, nested within the forOp.
+ auto noopExec = [&](OpBuilder &executeBuilder, Location executeLoc,
+ ValueRange executeArgs) {};
+ auto executeOp =
+ rewriter.create<async::ExecuteOp>(loc, /*resultTypes=*/TypeRange(),
+ /*dependencies=*/ValueRange(),
+ /*operands=*/ValueRange(), noopExec);
+
+ // 3.a. Steal the body of the iree_compiler::IREE::LinalgExt::InParallelOp,
+ // except its terminator, into the body of the async::ExecuteOp, just before
+ // the ExecuteOp's terminator.
+ SmallVector<Value> bbArgsTranslated{forOp.getInductionVar()};
+ rewriter.mergeBlocks(&inParallelOp.region().front(), executeOp.getBody(),
+ bbArgsTranslated);
+ // 3.b. Erase the terminator stolen from inParallelOp.
+ rewriter.eraseOp(&executeOp.getBody()->back());
+ // 3.c. Erase inParallelOp.
+ rewriter.eraseOp(inParallelOp);
+ // 3.d. Add ExecuteOp terminator.
+ rewriter.setInsertionPointToEnd(executeOp.getBody());
+ rewriter.create<async::YieldOp>(loc, ValueRange{});
+ // 3.e. Add to group within the loop.
+ rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
+ rewriter.create<async::AddToGroupOp>(loc, rewriter.getIndexType(),
+ executeOp.token(), asyncGroup);
+
+ // 4. After the iree_compiler::IREE::LinalgExt::InParallel, await all async
+ // tasks in `asyncGroup`.
+ rewriter.setInsertionPointAfter(forOp);
+ return rewriter.create<async::AwaitAllOp>(loc, asyncGroup).getOperation();
+}
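The rewriter above consumes `in_parallel` ops; for context, a hedged sketch of constructing one programmatically with the lambda-based builder added in LinalgExtOps.cpp. The function name and values are hypothetical.

```cpp
// Hedged sketch: build an InParallelOp whose result types are inferred from
// its terminator (see InParallelOp::build in LinalgExtOps.cpp).
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;

static InParallelOp buildExampleInParallel(OpBuilder &builder, Location loc,
                                           Value numThreads) {
  return builder.create<InParallelOp>(
      loc, numThreads, [](OpBuilder &b, Location nestedLoc, Value threadId) {
        // A real body would compute a per-thread tile based on `threadId` and
        // publish it through the implicit perform_concurrently terminator,
        // e.g. with ParallelInsertSliceOp.
      });
}
```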
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp
new file mode 100644
index 0000000..683629b
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/InParallelToSequentialFor.cpp
@@ -0,0 +1,111 @@
+//===- InParallelToSequentialFor.cpp - Rewrite InParallel as ForOp -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+namespace {
+
+SmallVector<Value> getValuesToYield(PerformConcurrentlyOp op) {
+ return llvm::to_vector(llvm::map_range(
+ op.yieldingOps(), [](ParallelInsertSliceOp op) { return op.dest(); }));
+}
+
+} // namespace
+
+FailureOr<scf::ForOp> InParallelOpToScfForRewriter::returningMatchAndRewrite(
+ InParallelOp inParallelOp, PatternRewriter &rewriter) const {
+ // Construct the loop bounds based on the canonical arithmetic progression.
+ Location loc = inParallelOp.getLoc();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value numThreads = inParallelOp.num_threads();
+
+ // Construct the op without a body builder: we need to clone the ops in the
+ // body explicitly after having access to the new bbArgs.
+ // As a consequence, `ensureTerminator` is not called and the `forOp` body
+ // has no terminator.
+ PerformConcurrentlyOp performConcurrentlyOp = inParallelOp.getTerminator();
+ SmallVector<Value> valuesToYield = getValuesToYield(performConcurrentlyOp);
+ scf::ForOp forOp =
+ rewriter.create<scf::ForOp>(loc, zero, numThreads, one, valuesToYield);
+
+ // Move the body while replacing the threadId by the forOp iv.
+ SmallVector<Value> bbArgsTranslated{forOp.getInductionVar()};
+ Block *body = forOp.getBody();
+ bool hasTerminator =
+ !body->empty() && body->back().hasTrait<OpTrait::IsTerminator>();
+ if (hasTerminator) {
+ rewriter.mergeBlockBefore(&inParallelOp.region().front(),
+ body->getTerminator(), bbArgsTranslated);
+ } else {
+ rewriter.mergeBlocks(&inParallelOp.region().front(), body,
+ bbArgsTranslated);
+ }
+
+ rewriter.setInsertionPointToStart(body);
+ BlockAndValueMapping bvm;
+ bvm.map(valuesToYield, forOp.getRegionIterArgs());
+
+ // Create sequential insertSlice ops.
+ SmallVector<Value> toYield;
+ rewriter.setInsertionPoint(performConcurrentlyOp);
+ for (ParallelInsertSliceOp op : performConcurrentlyOp.yieldingOps()) {
+ toYield.push_back(rewriter.createOrFold<tensor::InsertSliceOp>(
+ loc, op.source(), bvm.lookup(op.dest()), op.getMixedOffsets(),
+ op.getMixedSizes(), op.getMixedStrides()));
+ }
+
+ // performConcurrentlyOp.yieldedValues come from above, not from bbArgs.
+ // There is no rewriter method to make mergeBlocks update non-bbArgs.
+ // Need to manually clone + bvm all uses that are now nested under forOp.
+ // Warning: this replacement is currently optimistic and may change the
+ // semantics as explained in the pass description in Passes.td.
+ SmallVector<Operation *> opsToReplace;
+ for (Value toReplace : valuesToYield) {
+ for (OpOperand &u : toReplace.getUses()) {
+ Operation *op = u.getOwner();
+ if (!forOp->isProperAncestor(op)) continue;
+ opsToReplace.push_back(op);
+ }
+ }
+ for (Operation *op : opsToReplace) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(op);
+ Operation *cloned = rewriter.clone(*op, bvm);
+ rewriter.replaceOp(op, cloned->getResults());
+ }
+
+ // Insert terminator.
+ if (!hasTerminator) {
+ rewriter.setInsertionPointToEnd(body);
+ rewriter.create<scf::YieldOp>(loc, toYield);
+ }
+
+ // Cleanup and replace.
+ rewriter.eraseOp(performConcurrentlyOp);
+ rewriter.replaceOp(inParallelOp, forOp.getResults());
+
+ return forOp;
+}
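The lowering above turns each ParallelInsertSliceOp into a `tensor.insert_slice`; for reference, a hedged sketch of creating a ParallelInsertSliceOp with the mixed static/dynamic builder from LinalgExtOps.cpp. The helper name and the unit stride are illustrative only.

```cpp
// Hedged sketch: create a ParallelInsertSliceOp with the OpFoldResult builder.
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::iree_compiler::IREE::LinalgExt;

static void buildExampleParallelInsert(OpBuilder &b, Location loc, Value tile,
                                       Value dest, Value offset, Value size) {
  // Insert `tile` into `dest` at [offset][size] with unit stride; static
  // entries can be passed as IntegerAttrs instead of Values.
  SmallVector<OpFoldResult> offsets{offset};
  SmallVector<OpFoldResult> sizes{size};
  SmallVector<OpFoldResult> strides{b.getIndexAttr(1)};
  b.create<ParallelInsertSliceOp>(loc, tile, dest, offsets, sizes, strides,
                                  /*attrs=*/ArrayRef<NamedAttribute>{});
}
```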
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
new file mode 100644
index 0000000..6a03048
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/LinalgExtBufferization.cpp
@@ -0,0 +1,347 @@
+//===-- LinalgExtBufferization.cpp - Linalg Extension bufferization -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/LinalgExtBufferization.h"
+
+#include "mlir/IR/BuiltinOps.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+/// Return the destinations that an InParallelOp is inserting into. One per
+/// ParallelInsertSliceOp.
+static SmallVector<OpOperand *> getInsertionDest(InParallelOp inParallelOp) {
+ Operation *terminator = inParallelOp.region().front().getTerminator();
+ auto performConcOp = dyn_cast<PerformConcurrentlyOp>(terminator);
+ assert(performConcOp && "expected PerformConcurrentlyOp as terminator");
+
+ SmallVector<OpOperand *> result;
+ performConcOp.walk([&](ParallelInsertSliceOp insertOp) {
+ result.push_back(&insertOp->getOpOperand(1) /*dest*/);
+ });
+
+ return result;
+}
+
+namespace mlir {
+
+using bufferization::BufferizableOpInterface;
+using bufferization::BufferizationState;
+using bufferization::BufferRelation;
+using bufferization::getMemRefType;
+using bufferization::replaceOpWithBufferizedValues;
+using bufferization::replaceOpWithNewBufferizedOp;
+using tensor::ExtractSliceOp;
+
+namespace iree_compiler {
+namespace IREE {
+namespace LinalgExt {
+
+/// Bufferization of InParallelOp. This also bufferizes the terminator of the
+/// region. There are op interfaces for the terminators (PerformConcurrentlyOp
+/// and ParallelInsertSliceOp), but these are only used during analysis, not
+/// for bufferization.
+struct InParallelOpInterface
+ : public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
+ InParallelOp> {
+ SmallVector<OpOperand *> getAliasingOpOperand(
+ Operation *op, OpResult opResult, const BufferizationState &state) const {
+ // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
+ auto inParallelOp = cast<InParallelOp>(op);
+ return {getInsertionDest(inParallelOp)[opResult.getResultNumber()]};
+ }
+
+ bool isMemoryWrite(Operation *op, OpResult opResult,
+ const BufferizationState &state) const {
+ // This op is a memory write. Stop lookup here to avoid finding false
+ // conflicts involving this op and one of the ops in the region. This is
+ // similar to how scf.if ops are analyzed.
+ return true;
+ }
+
+ bool isAllocationHoistingBarrier(Operation *op) const { return true; }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const BufferizationState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &b,
+ const BufferizationState &state) const {
+ OpBuilder::InsertionGuard g(b);
+ auto inParallelOp = cast<InParallelOp>(op);
+ Block *body = &inParallelOp.region().front();
+ Operation *oldTerminator = body->getTerminator();
+ assert(isa<PerformConcurrentlyOp>(oldTerminator) &&
+ "unexpected terminator");
+
+ // Gather new results of the InParallelOp.
+ SmallVector<Value> newResults;
+ for (OpResult opResult : inParallelOp->getOpResults()) {
+ SmallVector<OpOperand *> insertDestOperands =
+ state.getAliasingOpOperand(opResult);
+ assert(insertDestOperands.size() == 1 &&
+ "expected exactly one aliasing OpOperand");
+ // Insert copies right before the PerformConcurrentlyOp terminator. They
+ // should not be inside the terminator (which would be the default insertion
+ // point).
+ Value buffer = *state.getBuffer(
+ b, *insertDestOperands.front(), /*forceInPlace=*/false,
+ /*customCopyInsertionPoint=*/oldTerminator);
+ newResults.push_back(buffer);
+ Value destTensor = insertDestOperands.front()->get();
+
+ // Replace all uses of the insert dest tensor inside the InParallelOp
+ // with the result buffer.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPointToStart(body);
+ Value toTensorOp =
+ b.create<bufferization::ToTensorOp>(inParallelOp.getLoc(), buffer);
+ for (OpOperand &use : destTensor.getUses())
+ if (body->findAncestorOpInBlock(*use.getOwner()))
+ // This is a use inside the InParallelOp.
+ use.set(toTensorOp);
+ }
+
+ // Create new InParallelOp without any results.
+ TypeRange newResultTypes;
+ auto newInParallelOp = b.create<InParallelOp>(
+ inParallelOp.getLoc(), newResultTypes, inParallelOp.num_threads());
+
+ // Delete terminator.
+ newInParallelOp.getBody()->getTerminator()->erase();
+
+ // Move over block contents of the old op.
+ b.mergeBlocks(inParallelOp.getBody(), newInParallelOp.getBody(),
+ {newInParallelOp.getBody()->getArgument(0)});
+
+ // Bufferize terminator.
+ auto performConcurrentlyOp =
+ cast<PerformConcurrentlyOp>(newInParallelOp.getBody()->getTerminator());
+ b.setInsertionPoint(performConcurrentlyOp);
+ WalkResult walkResult =
+ performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
+ Location loc = insertOp.getLoc();
+ Type srcType = getMemRefType(
+ insertOp.source().getType().cast<RankedTensorType>(),
+ state.getOptions());
+ Type destType =
+ getMemRefType(insertOp.dest().getType().cast<RankedTensorType>(),
+ state.getOptions());
+ // ParallelInsertSliceOp bufferizes to a copy.
+ auto srcMemref = b.create<bufferization::ToMemrefOp>(
+ loc, srcType, insertOp.source());
+ auto destMemref = b.create<bufferization::ToMemrefOp>(
+ loc, destType, insertOp.dest());
+ Value subview = b.create<memref::SubViewOp>(
+ loc, destMemref, insertOp.getMixedOffsets(),
+ insertOp.getMixedSizes(), insertOp.getMixedStrides());
+ // This memcpy will fold away if everything bufferizes in-place.
+ if (failed(createMemCpy(b, insertOp.getLoc(), srcMemref, subview,
+ state.getOptions())))
+ return WalkResult::interrupt();
+ b.eraseOp(insertOp);
+ return WalkResult::advance();
+ });
+ if (walkResult.wasInterrupted()) return failure();
+
+ // Replace the op.
+ replaceOpWithBufferizedValues(b, op, newResults);
+
+ return success();
+ }
+};
+
+/// Nothing to do for PerformConcurrentlyOp.
+struct PerformConcurrentlyOpInterface
+ : public BufferizableOpInterface::ExternalModel<
+ PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
+ LogicalResult bufferize(Operation *op, RewriterBase &b,
+ const BufferizationState &state) const {
+ llvm_unreachable("op does not have any tensor OpOperands / OpResults");
+ return failure();
+ }
+};
+
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches,
+/// i.e., the extract source and the insert destination bufferize to
+/// equivalent values and both ops use the same offsets/sizes/strides.
+static bool areEquivalentExtractSliceOps(const BufferizationState &state,
+ ExtractSliceOp st,
+ ParallelInsertSliceOp sti) {
+ if (!st || !sti) return false;
+ if (st != sti &&
+ !state.areEquivalentBufferizedValues(st.source(), sti.dest()))
+ return false;
+ if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+ return false;
+ return true;
+}
+
+/// Return true if `value` originates from an ExtractSliceOp that matches the
+/// given ParallelInsertSliceOp.
+static bool hasMatchingExtractSliceOp(const BufferizationState &state,
+ Value value,
+ ParallelInsertSliceOp insertOp) {
+ auto condition = [&](Value val) {
+ if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+ if (areEquivalentExtractSliceOps(state, extractOp, insertOp)) return true;
+ return false;
+ };
+
+ return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+ condition);
+}
+
+/// Analysis of ParallelInsertSliceOp.
+struct ParallelInsertSliceOpInterface
+ : public BufferizableOpInterface::ExternalModel<
+ ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+ SmallVector<OpResult> getAliasingOpResult(
+ Operation *op, OpOperand &opOperand,
+ const BufferizationState &state) const {
+ if (&opOperand != &op->getOpOperand(1) /*dest*/) return {};
+
+ // ParallelInsertSliceOp itself has no results. Tensors are returned via
+ // the parent op.
+ auto inParallelOp = op->getParentOfType<InParallelOp>();
+ assert(inParallelOp &&
+ "could not find valid owner of parallel_insert_slice");
+
+ // The destination of the i-th ParallelInsertSliceOp in the block is
+ // returned via the i-th OpResult of the parent InParallelOp.
+ Block *block = op->getBlock();
+ unsigned int opIdx = 0;
+ for (ParallelInsertSliceOp insertOp :
+ block->getOps<ParallelInsertSliceOp>()) {
+ if (insertOp.getOperation() == op) break;
+ ++opIdx;
+ }
+ assert(opIdx < inParallelOp->getNumResults() &&
+ "could not find op inside terminator op");
+
+ return {inParallelOp->getResult(opIdx)};
+ }
+
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const BufferizationState &state) const {
+ return true;
+ }
+
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ const BufferizationState &state) const {
+ return &opOperand == &op->getOpOperand(1) /*dest*/;
+ }
+
+ BufferRelation bufferRelation(Operation *op, OpResult opResult,
+ const BufferizationState &state) const {
+ return BufferRelation::Equivalent;
+ }
+
+ LogicalResult bufferize(Operation *op, RewriterBase &b,
+ const BufferizationState &state) const {
+ // Will be bufferized as part of InParallelOp.
+ return failure();
+ }
+
+ // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+ // the code.
+ bool isNotConflicting(Operation *op, OpOperand *uRead,
+ OpOperand *uConflictingWrite,
+ const BufferizationState &state) const {
+ Operation *readingOp = uRead->getOwner();
+ Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+ // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+ // uRead is an InsertSliceOp...
+ if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+
+ // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+ if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+ insertSliceOp))
+ // Case 1: The main insight is that InsertSliceOp reads only part of
+ // the destination tensor. The overwritten area is not read. If
+ // uConflictingWrite writes into exactly the memory location that is
+ // being read by uRead, this is not a conflict.
+ //
+ // In the above example:
+ // uRead = OpOperand 1 (%t) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+ //
+ // The read of %t does not conflict with the write of the FillOp
+ // (same aliases!) because the area that the FillOp operates on is
+ // exactly the one that is *not* read via %t.
+ return true;
+
+ if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+ uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+ // Case 2: The read of the source tensor and the write to the dest
+ // tensor via an InsertSliceOp is not a conflict if the read is
+ // reading exactly that part of an equivalent tensor that the
+ // InsertSliceOp is writing.
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of tensor.insert_slice
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ return true;
+ }
+
+ // If uConflictingWrite is an InsertSliceOp...
+ if (auto insertSliceOp =
+ dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+ // As an example, consider the following IR.
+ //
+ // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+ // %1 = linalg.fill %cst, %0 {inplace= [true] }
+ // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+ // {inplace= [true] }
+ // %3 = vector.transfer_read %1, %cst
+ //
+ // In the above example:
+ // uRead = OpOperand 0 (%1) of vector.transfer_read
+ // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+ // lastWrite = %1
+ //
+ // This is not a conflict because the InsertSliceOp overwrites the
+ // memory segment of %1 with the exact same data. (Effectively, there
+ // is no memory write here.)
+ if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+ state.areEquivalentBufferizedValues(uRead->get(),
+ insertSliceOp.source()) &&
+ hasMatchingExtractSliceOp(state, insertSliceOp.source(),
+ insertSliceOp))
+ return true;
+
+ return false;
+ }
+};
+} // namespace LinalgExt
+} // namespace IREE
+} // namespace iree_compiler
+} // namespace mlir
+
+void mlir::iree_compiler::IREE::LinalgExt::
+ registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+ registry.addOpInterface<InParallelOp, InParallelOpInterface>();
+ registry
+ .addOpInterface<PerformConcurrentlyOp, PerformConcurrentlyOpInterface>();
+ registry
+ .addOpInterface<ParallelInsertSliceOp, ParallelInsertSliceOpInterface>();
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp
new file mode 100644
index 0000000..83ece71
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToInParallel.cpp
@@ -0,0 +1,132 @@
+//===- TileToInParallel.cpp - Rewrite TileOp as InParallel ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+FailureOr<iree_compiler::IREE::LinalgExt::InParallelOp> mlir::iree_compiler::
+ IREE::LinalgExt::TileOpToInParallelRewriter::returningMatchAndRewrite(
+ iree_compiler::IREE::LinalgExt::TileOp tileOp,
+ PatternRewriter &rewriter) const {
+ // TODO: verifier.
+ assert(tileOp.getNumResults() > 0 &&
+ tileOp.outs().size() == tileOp.getNumResults());
+
+ // TODO: when supported, iterate over the tensor of sizes. This will be
+ // iterating through a level of indirection.
+
+ int64_t tiledDim = tileOp.tiled_dim();
+
+ // Construct the loop bounds based on the canonical arithmetic progression.
+ Location loc = tileOp.getLoc();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value tiledDimValue = rewriter.create<arith::ConstantIndexOp>(loc, tiledDim);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value totalSize =
+ rewriter.create<tensor::DimOp>(loc, tileOp.outs().front(), tiledDimValue);
+ Value step = tileOp.tile_size();
+ assert(step.getType().isa<IndexType>() && "NYI: not an index type");
+
+ using AV = AffineValueExpr;
+ AffineBuilder ab(rewriter, loc);
+ AffineExpr i, j, M;
+ bindDims(rewriter.getContext(), i, j);
+ bindSymbols(rewriter.getContext(), M);
+ Value numThreads = ab.ceil(AV(i).bind(totalSize), AV(M).bind(step));
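+ // For illustration only, this roughly materializes
+ //   affine.apply affine_map<(d0)[s0] -> (d0 ceildiv s0)>(%total_size)[%tile_size]
+ // i.e. numThreads = ceil(totalSize / tileSize); names are illustrative.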
+
+ // Construct the op without a body builder: we need to clone the ops in the
+ // body explicitly after having access to the new bbArgs.
+ // As a consequence, `ensureTerminator` is not called and the body has no
+ // terminator.
+ iree_compiler::IREE::LinalgExt::InParallelOp inParallelOp =
+ rewriter.create<iree_compiler::IREE::LinalgExt::InParallelOp>(
+ loc, tileOp->getResultTypes(), numThreads);
+
+ // At the beginning of the InParallelOp, compute offset and sizes.
+ rewriter.setInsertionPointToStart(inParallelOp.getBody());
+
+ // Materialize the implicit subtensors as explicit subset_extract.
+ // TODO: generalize to multiple offset/chunk_size bbargs if needed.
+ // TODO: generalize the subset op.
+ SmallVector<Value> leadingOffsets, leadingSizes, leadingStrides;
+ for (int64_t i = 0; i < tiledDim; ++i) {
+ leadingOffsets.push_back(zero);
+ leadingSizes.push_back(
+ rewriter.createOrFold<tensor::DimOp>(loc, tileOp.outs().front(), i));
+ leadingStrides.push_back(one);
+ }
+ // clang-format off
+ Value offset = ab.mul(AV(i).bind(inParallelOp.getThreadIndex()),
+ AV(M).bind(step));
+ Value size = ab.min(
+ ValueRange{ab.sub(AV(i).bind(totalSize), AV(j).bind(offset)),
+ step});
+ // clang-format on
+ leadingOffsets.push_back(offset);
+ leadingSizes.push_back(size);
+ leadingStrides.push_back(one);
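+ // In scalar terms (illustrative): offset = thread_index * tile_size and
+ // size = min(total_size - offset, tile_size), so the trailing tile may be
+ // smaller when the tile size does not divide the total size evenly.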
+
+ SmallVector<Value> implicitSubtensorExtracts;
+ for (Value tensor : tileOp.outs()) {
+ implicitSubtensorExtracts.push_back(
+ createSubsetExtractOpFromLeadingOffsetsSizesAndStrides(
+ rewriter, loc, tensor, leadingOffsets, leadingSizes,
+ leadingStrides));
+ }
+
+ // Get a reference to the TileOp terminator before the body is merged, after
+ // which the terminator becomes harder to reach.
+ auto tileYieldOp = cast<TileYieldOp>(tileOp.getBody()->getTerminator());
+
+ // Regroup the values that replace the tileOp's bbArg and move the body.
+ SmallVector<Value> bbArgsTranslated{offset, size};
+ llvm::append_range(bbArgsTranslated, implicitSubtensorExtracts);
+ rewriter.mergeBlockBefore(&tileOp.region().front(),
+ inParallelOp.getBody()->getTerminator(),
+ bbArgsTranslated);
+
+ // The merged tileOp terminator is not the InParallelOp terminator: insert
+ // explicit parallel_insert_slice ops into the perform_concurrently
+ // terminator instead.
+ PerformConcurrentlyOp performConcurrentlyOp = inParallelOp.getTerminator();
+
+ for (auto it : llvm::zip(tileYieldOp->getOperands(), tileOp.outs())) {
+ SmallVector<Value> offsets, sizes, strides;
+ completeOffsetsSizesAndStrides(rewriter, loc, std::get<0>(it),
+ leadingOffsets, leadingSizes, leadingStrides,
+ offsets, sizes, strides);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(
+ performConcurrentlyOp.getBody()->getTerminator());
+ createParallelInsertSliceOpFromLeadingOffsetsSizesAndStrides(
+ rewriter, loc, std::get<0>(it), std::get<1>(it), offsets, sizes,
+ strides);
+ }
+
+ // Cleanup and replace.
+ rewriter.eraseOp(tileYieldOp);
+ rewriter.replaceOp(tileOp, inParallelOp.getResults());
+
+ return inParallelOp;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp
new file mode 100644
index 0000000..657eedd
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TileToSequentialFor.cpp
@@ -0,0 +1,106 @@
+//===- LowerToSCF.cpp - Lower to SCF -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+FailureOr<scf::ForOp> mlir::iree_compiler::IREE::LinalgExt::
+ TileOpToSCFRewriter::returningMatchAndRewrite(
+ iree_compiler::IREE::LinalgExt::TileOp tileOp,
+ PatternRewriter &rewriter) const {
+ // TODO: verifier.
+ assert(tileOp.getNumResults() > 0 &&
+ tileOp.outs().size() == tileOp.getNumResults());
+
+ // TODO: when supported, iterate over the tensor of sizes. This will be
+ // iterating through a level of indirection.
+
+ // Construct the loop bounds based on the canonical arithmetic progression.
+ Location loc = tileOp.getLoc();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value totalSize =
+ rewriter.create<tensor::DimOp>(loc, tileOp.outs().front(), zero);
+ Value step = tileOp.tile_size();
+ assert(step.getType().isa<IndexType>() && "NYI: not an index type");
+
+ // Construct the op without a body builder: we need to clone the ops in the
+ // body explicitly after having access to the new bbArgs.
+ // As a consequence, `ensureTerminator` is not called and the body has no
+ // terminator.
+ scf::ForOp forOp =
+ rewriter.create<scf::ForOp>(loc, zero, totalSize, step, tileOp.outs());
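+ // Illustrative skeleton of the loop built above, with one iter_arg per
+ // entry in `outs` (types depend on the tiled op):
+ //   scf.for %iv = %c0 to %total_size step %tile_size
+ //       iter_args(%arg0 = %out0, ...) -> (...)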
+
+ rewriter.setInsertionPointToStart(forOp.getBody());
+
+ // TODO: when supported, also compute from the tensor of sizes.
+ using AV = AffineValueExpr;
+ AffineBuilder ab(rewriter, loc);
+ AffineExpr i, j, M;
+ bindDims(rewriter.getContext(), i, j);
+ bindSymbols(rewriter.getContext(), M);
+
+ // Materialize the implicit subtensors as explicit subset_extract.
+ // TODO: generalize to multiple offset/chunk_size bbargs if needed.
+ // TODO: generalize the subset op.
+ Value offset = forOp.getInductionVar();
+ // clang-format off
+ Value size = ab.min(
+ ValueRange{ab.sub(AV(i).bind(totalSize), AV(j).bind(offset)),
+ step});
+ // clang-format on
+ SmallVector<Value> implicitSubtensorExtracts;
+ for (Value tensor : forOp.getRegionIterArgs()) {
+ implicitSubtensorExtracts.push_back(
+ createSubsetExtractOpFromLeadingOffsetsSizesAndStrides(
+ rewriter, loc, tensor, offset, size, one));
+ }
+
+ // Regroup the values that replace the tileOp's bbArg and move the body.
+ SmallVector<Value> bbArgsTranslated{offset, size};
+ llvm::append_range(bbArgsTranslated, implicitSubtensorExtracts);
+ rewriter.mergeBlocks(&tileOp.region().front(), forOp.getBody(),
+ bbArgsTranslated);
+ // tileOp's terminator is not the terminator, insert explicit subset_insert
+ // ops and feed them to a new scf.yield terminator that we can now add.
+ auto tileYieldOp = cast<TileYieldOp>(&forOp.getBody()->back());
+ SmallVector<Value> implicitSubtensorInserts;
+ for (auto it : llvm::zip(implicitSubtensorExtracts, tileYieldOp.getOperands(),
+ forOp.getRegionIterArgs())) {
+ implicitSubtensorInserts.push_back(createMatchingSubsetInsertOp(
+ rewriter, loc,
+ /*subsetExtractOp=*/
+ std::get<0>(it).getDefiningOp<tensor::ExtractSliceOp>(),
+ /*source=*/std::get<1>(it), /*dest=*/std::get<2>(it)));
+ }
+ // Insert terminator.
+ rewriter.setInsertionPointToEnd(forOp.getBody());
+ rewriter.create<scf::YieldOp>(loc, implicitSubtensorInserts);
+
+ // Cleanup and replace.
+ rewriter.eraseOp(tileYieldOp);
+ rewriter.replaceOp(tileOp, forOp.getResults());
+
+ return forOp;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
index 25df1f8..0e55970 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -1,360 +1,216 @@
-// Copyright 2021 The IREE Authors
+//===- Tiling.cpp - Tiling using TilingInterface --------------------------===//
//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
-#include "iree-dialects/Dialect/Input/InputDialect.h"
-#include "iree-dialects/Dialect/Input/InputOps.h"
-#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/PassDetail.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Passes.h"
-#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/SCF.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
-namespace IREE = mlir::iree_compiler::IREE;
-using namespace IREE::LinalgExt;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
-//===----------------------------------------------------------------------===//
-// Utility methods for tiling a linalg_ext operation that implements a
-// TiledOpInterface
-//===----------------------------------------------------------------------===//
+// TODO: connect these patterns to PDL. Either via the transform dialect or via
+// PDLL.
-/// Returns failure if the options are unsupported.
-static LogicalResult verifySupportedTilingOptions(
- PatternRewriter &rewriter, Operation *op,
- const linalg::LinalgTilingOptions &options) {
- if (!options.interchangeVector.empty()) {
- return rewriter.notifyMatchFailure(op,
- "unsupported interchange during tiling");
+static bool isZero(Value v) {
+ if (auto cst = v.getDefiningOp<arith::ConstantIndexOp>())
+ return cst.value() == 0;
+ return false;
+}
+
+SmallVector<Value> tileToSCF(PatternRewriter &rewriter, TilingInterface op,
+ TilingInterface clonedOp, ValueRange tileSizes) {
+ // Compute lower and upper bounds of the loop nest.
+ SmallVector<Range> ranges = clonedOp.getIterationDomain(rewriter);
+ assert(tileSizes.size() <= ranges.size() &&
+ "expected tile sizes to match the number of loops");
+
+ // Fill the tile sizes with zeros for the untiled dimensions.
+ Location loc = op->getLoc();
+ SmallVector<Value> tileSizesVec(tileSizes.begin(), tileSizes.end());
+ if (ranges.size() != tileSizes.size()) {
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ tileSizesVec.resize(ranges.size(), zero);
}
- if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
- return rewriter.notifyMatchFailure(op,
- "only tiling with scf.for is supported");
- }
- if (options.distribution) {
- if (llvm::any_of(options.distribution->distributionMethod,
- [](linalg::DistributionMethod method) {
- return method != linalg::DistributionMethod::Cyclic;
- })) {
- return rewriter.notifyMatchFailure(op,
- "only cyclic distibution is allowed");
+
+ SmallVector<Value> lbs, dims, allDims, steps;
+ for (auto it : llvm::enumerate(ranges)) {
+ allDims.push_back(it.value().size);
+ if (!isZero(tileSizesVec[it.index()])) {
+ lbs.push_back(it.value().offset);
+ dims.push_back(it.value().size);
+ steps.push_back(tileSizesVec[it.index()]);
}
}
- return success();
-}
-/// Converts an `OpFoldResult` to a `Value` by building a constant op if
-/// if the `OpFoldResult` is an `IntegerAttr`.
-static Value getValue(OpBuilder &builder, Location loc,
- OpFoldResult valueOrAttr) {
- if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
- return builder.create<arith::ConstantIndexOp>(
- loc, attr.cast<IntegerAttr>().getInt());
- }
- return valueOrAttr.get<Value>();
-}
-
-/// Returns true if loop is untiled. Only checks if the value is statically
-/// zero. It is assumed that a `Value` defined by a constant op is already
-/// converted to an `IntegerAttr` of that value. So here just return true if
-/// this is an attribute with a zero value.
-static bool isUntiledLoop(OpFoldResult valueOrAttr) {
- Optional<int64_t> intVal = getConstantIntValue(valueOrAttr);
- return intVal && *intVal == 0;
-}
-
-/// Generates the tiled loops and the body by invoking the interface methods of
-/// TiledOpInterface.
-/// - `outputs` are the operands to use for outputs of the tiled operation.
-/// - `tileSizes` are tile sizes specified for all loops of the operation. If a
-/// loop is to be untiled it is set to 0.
-/// - `iteratorType` is the type of the loop iterator returned by the
-/// TiledOpInterface.
-/// - `loopBounds` are the bounds of all the loops of the op returned by the
-/// TiledOpInterface.
-/// - `loopDepth` is the current loop depth being processed.
-/// - `offsets` are the `Value`s that represent the position of the tile being
-/// operated on. The offsets are computed as the tiled loops are being
-/// generated.
-/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
-/// distributed loops. It is a stack, and once an entry at the top of the
-/// stack is used for distribution it is popped before processing the inner
-/// loops.
-static FailureOr<TiledOp> tileInterfaceOpImpl(
- OpBuilder &builder, TiledOpInterface tilableOp, ValueRange outputs,
- MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
- ArrayRef<Range> loopBounds, unsigned loopDepth,
- SmallVectorImpl<OpFoldResult> &offsets,
- ArrayRef<linalg::ProcInfo> distributionInfo) {
- Location loc = tilableOp.getLoc();
- // If this is the innermost loop, then generated the tiled implementation of
- // the op by invoking the TiledOpInterface methods.
- if (loopDepth == tileSizes.size()) {
- TiledOp ret;
- ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets,
- tileSizes, ret.results);
- if (!ret.op) {
- return static_cast<LogicalResult>(
- tilableOp.emitOpError("failed to get tiled implementation"));
- }
- return ret;
- }
-
- // If tile size at this depth is empty, do nothing.
- if (isUntiledLoop(tileSizes[loopDepth])) {
- auto zeroAttr = builder.getI64IntegerAttr(0);
- offsets.push_back(zeroAttr);
- assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
- "expected loop bounds to have lower bound of zero");
- tileSizes[loopDepth] = getAsOpFoldResult(loopBounds[loopDepth].size);
- return tileInterfaceOpImpl(builder, tilableOp, outputs, tileSizes,
- iteratorTypes, loopBounds, loopDepth + 1,
- offsets, distributionInfo);
- }
-
- // Generate an scf.for for the current loop depth.
- Value lb = loopBounds[loopDepth].offset;
- Value ub = loopBounds[loopDepth].size;
- // TODO(#7073): Put the check back. This is required by tiling linalg_ext.fft
- // op. We can put the check back after updating linalg_ext.fft semantics.
- // if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
- // return static_cast<LogicalResult>(
- // tilableOp.emitOpError("expected stride to be 1"));
- //}
- Value step = getValue(builder, loc, tileSizes[loopDepth]);
-
- // Update lb, ub and step for cyclic distribution.
- if (!distributionInfo.empty() &&
- iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
- linalg::updateBoundsForCyclicDistribution(
- builder, loc, distributionInfo.front().procId,
- distributionInfo.front().nprocs, lb, ub, step);
- distributionInfo = distributionInfo.drop_front();
- }
- FailureOr<TiledOp> innerReturnValue;
- bool isBufferTiling = tilableOp->getNumResults() == 0;
- ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
- auto forOp = builder.create<scf::ForOp>(
- loc, lb, ub, step, initValues,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
- offsets.push_back(iv);
- auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
- b.getAffineSymbolExpr(0),
- b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
- // Similar to linalg tiling, the tile size is the min(tileSizes, ub -
- // iv) to account for cases where tile size does not divide (ub - lb)
- // exactly.
- Value inBoundsTileSize = b.create<AffineMinOp>(
- loc, affineMaps,
- ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});
- tileSizes[loopDepth] = getAsOpFoldResult(inBoundsTileSize);
- // Recursively proceed to generate the tiled loop for the next level.
- innerReturnValue =
- tileInterfaceOpImpl(b, tilableOp, (isBufferTiling ? outputs : args),
- tileSizes, iteratorTypes, loopBounds,
- loopDepth + 1, offsets, distributionInfo);
- if (failed(innerReturnValue)) return;
- b.create<scf::YieldOp>(loc, innerReturnValue->results);
+ // Generate the loop nest: one loop per dimension with a non-zero tile size.
+ llvm::SmallPtrSet<Operation *, 1> preservedUses;
+ SmallVector<Value> destOperand = clonedOp.getDestinationOperands(rewriter);
+ auto loopNest = mlir::scf::buildLoopNest(
+ rewriter, loc, lbs, /*ubs=*/dims, steps, ValueRange(destOperand),
+ [&](OpBuilder &b, Location loc, ValueRange localIvs,
+ ValueRange iterArgs) -> scf::ValueVector {
+ // Compute offsets and sizes of ExtractSliceOp.
+ SmallVector<Value> offsets =
+ linalg::computeTileOffsets(b, loc, localIvs, tileSizesVec);
+ SmallVector<Value> sizes =
+ linalg::computeTileSizes(b, loc, localIvs, tileSizesVec, allDims);
+ // Create ExtractSliceOp: extract a tile from the result of the cloned op.
+ // The cloned op sits outside of the loop nest; the SliceOpTiledOpSwapPattern
+ // below later replaces this slice with the tiled computation.
+ auto map =
+ AffineMap::getMultiDimIdentityMap(ranges.size(), b.getContext());
+ assert(clonedOp->getNumResults() == 1 && "expected single result op");
+ Value tiledOutput =
+ linalg::makeTiledShape(b, loc, clonedOp->getResult(0), tileSizesVec,
+ map, offsets, allDims, sizes);
+ auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
+ preservedUses.insert(sliceOp);
+ assert(sliceOp && "expected ExtractSliceOp");
+ // Insert the tile into the output tensor.
+ Value yieldValue =
+ createMatchingSubsetInsertOp(b, loc, sliceOp, sliceOp, iterArgs[0]);
+ return scf::ValueVector({yieldValue});
});
- if (failed(innerReturnValue)) {
- return innerReturnValue;
- }
- innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
- forOp.getOperation());
- innerReturnValue->results = forOp.getResults();
- return innerReturnValue;
+ return loopNest.getResults();
}
-FailureOr<TiledOp> tileInterfaceOp(OpBuilder &b, TiledOpInterface tilableOp,
- const linalg::LinalgTilingOptions &options) {
- SmallVector<Value> dest = tilableOp.getDestinationOperands(b);
- if (dest.empty()) {
- return static_cast<LogicalResult>(tilableOp.emitOpError(
- "cannot tile operation without destination operands"));
- }
-
- SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
- SmallVector<Value, 4> tileSizesVals =
- options.tileSizeComputationFunction(b, tilableOp);
- auto zeroAttr = b.getI64IntegerAttr(0);
-
- // The actual tile sizes used converts `Value` defined as constant 0, to a
- // zero integer attributes. Currently if the iterator type is not "parallel",
- // the tile size is forced to zero as well.
- auto tileSizes = getAsOpFoldResult(tileSizesVals);
- tileSizes.resize(iteratorTypes.size(), zeroAttr);
- for (auto en : llvm::enumerate(iteratorTypes)) {
- if (en.value() == getParallelIteratorTypeName()) continue;
- if (!isUntiledLoop(tileSizes[en.index()])) {
- return static_cast<LogicalResult>(tilableOp.emitOpError(
- "unimplemented tiling of non-parallel loop iterator type"));
- }
- }
-
- // Trivial early exit case of tile sizes being zero for all parallel loops.
- if (llvm::all_of(tileSizes, isUntiledLoop)) {
- return TiledOp{tilableOp, {}, {}};
- }
-
- SmallVector<Range> loopBounds = tilableOp.getIterationDomain(b);
- SmallVector<linalg::ProcInfo> distributionInfo;
- // If the tiled loops are distributed, get the proc_id and nprocs for the
- // distributed loops. First collect the parallel loops by iterating over the
- // tileSizes and getting the loops that are distribute, i.e.,
- // - parallel, i.e. iteratorTypes is "parallel"
- // - tiled, i.e. tileSize != 0
- if (options.distribution) {
- SmallVector<Range> distributedLoopRange;
- for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
- if (isUntiledLoop(tileSizes[i])) continue;
- if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
- distributedLoopRange.push_back(loopBounds[i]);
- }
- distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(),
- distributedLoopRange);
- }
-
- SmallVector<OpFoldResult> offsets;
- return tileInterfaceOpImpl(b, tilableOp, dest, tileSizes, iteratorTypes,
- loopBounds, 0, offsets, distributionInfo);
-}
-
-LogicalResult TiledOpInterfaceBaseTilingPattern::matchAndRewriteBase(
- TiledOpInterface tilableOp, PatternRewriter &rewriter,
- TiledOp &result) const {
- if (failed(filter.checkAndNotify(rewriter, tilableOp))) {
- return failure();
- }
- if (failed(verifySupportedTilingOptions(rewriter, tilableOp, options))) {
- return failure();
- }
-
- FailureOr<TiledOp> res = tileInterfaceOp(rewriter, tilableOp, options);
- if (failed(res)) return res;
- result = *res;
- if (result.op) {
- filter.replaceLinalgTransformationFilter(rewriter, result.op);
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Test pass for tiling Linalg Ext ops
-//===----------------------------------------------------------------------===//
-
namespace {
-struct TiledOpInterfaceTilingPass
- : public TiledOpInterfaceTilingBase<TiledOpInterfaceTilingPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<
- AffineDialect, IREE::Input::IREEInputDialect, linalg::LinalgDialect,
- IREE::LinalgExt::IREELinalgExtDialect, memref::MemRefDialect,
- func::FuncDialect, mlir::arith::ArithmeticDialect, math::MathDialect,
- tensor::TensorDialect, scf::SCFDialect>();
+
+/// The tiling here works in two steps. The first step creates a loop nest
+/// based on the loop bounds of the operation, obtained from `TilingInterface`.
+///
+/// ```mlir
+/// %1 = <tiling interface op> ins(...) outs(%0 : ...)
+/// ... <use_op> ... %1 ...
+/// ```
+///
+/// is rewritten using a "noop" subtensor extract/insert pair
+///
+/// ```mlir
+/// %1 = <tiling interface op> ins(...) outs(%0 : ...)
+/// %2 = scf.for %iv0 = ... iter_args(%arg0 = %0) {
+/// %3 = scf.for %iv1 = ... iter_args(%arg1 = %arg0) {
+/// ...
+/// %4 = tensor.extract_slice %1[%iv0, %iv1]....
+/// %5 = tensor.insert_slice %4 into %arg1[%iv0, %iv1]...
+/// scf.yield %5
+/// }
+/// scf.yield %3
+/// }
+/// ... <use_op> ... %2 ...
+/// ```
+///
+/// Following this, the `TilingInterface` -> `tensor::ExtractSliceOp` pattern is
+/// replaced with
+///
+/// ```mlir
+/// %2 = scf.for %iv0 = ... iter_args(%arg0 = %0) {
+/// %3 = scf.for %iv1 = ... iter_args(%arg1 = %arg0) {
+/// ...
+/// %4 = tensor.extract_slice %0[%iv0, %iv1]
+/// %5 = <tiling interface op> ins(...) outs(%4 : ...)
+/// %6 = tensor.insert_slice %5 into %arg1[%iv0, %iv1]...
+/// scf.yield %6
+/// }
+/// scf.yield %3
+/// }
+/// ... <use_op> ... %2 ...
+/// ```
+///
+/// TODO(ravishankarm): The current approach only seems to work for tiling the
+/// parallel loops of the operation. Specifically,
+/// 1) the `%0` in the third snippet needs to be `%arg1` for cases where the
+/// tiled loop is a reduction.
+/// 2) the current implementation uses the `getIterationDomain` method to get
+/// the initial loop structure described in the second snippet. If any of
+/// those loops are reductions, that IR snippet itself is wrong (substitute
+/// the case of `linalg.matmul` and the error becomes apparent).
+
+/// First pattern to introduce the loop nests.
+struct OpTilingPattern : public OpInterfaceRewritePattern<TilingInterface> {
+ OpTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : OpInterfaceRewritePattern<TilingInterface>(context),
+ options(opt),
+ filter(filt) {}
+
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, op))) return failure();
+
+ /// Currently only handle single result operations.
+ if (op->getNumResults() != 1) return failure();
+
+ Location loc = op->getLoc();
+ // Get rank and tile sizes.
+ SmallVector<Value> tileSizes =
+ options.tileSizeComputationFunction(rewriter, op);
+ auto iteratorTypes = op.getLoopIteratorTypes();
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ tileSizes.resize(iteratorTypes.size(), zero);
+
+ /// Currently only handle operations with all parallel iterator types.
+ for (auto iteratorType : enumerate(iteratorTypes)) {
+ if (iteratorType.value() != getParallelIteratorTypeName() &&
+ !isZero(tileSizes[iteratorType.index()])) {
+ return rewriter.notifyMatchFailure(
+ op, "unhandled tiling of non-parallel iterator");
+ }
+ }
+
+ auto clonedOp = cast<TilingInterface>(rewriter.clone(*op.getOperation()));
+ SmallVector<Value> results = tileToSCF(rewriter, op, clonedOp, tileSizes);
+
+ filter.replaceLinalgTransformationFilter(rewriter, clonedOp);
+ rewriter.replaceOp(op, results);
+ return success();
}
- void runOnOperation() override;
+
+ private:
+ linalg::LinalgTilingOptions options;
+ linalg::LinalgTransformationFilter filter;
};
-} // namespace
-template <typename OpTy>
-static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
- return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
-}
+/// Second pattern to implement the switch of `TilingInterface` ->
+/// `tensor.extract_slice` to `tensor.extract_slice` -> `TilingInterface`.
+struct SliceOpTiledOpSwapPattern
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ SliceOpTiledOpSwapPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions opt,
+ linalg::LinalgTransformationFilter filt)
+ : OpRewritePattern<tensor::ExtractSliceOp>(context),
+ options(opt),
+ filter(filt) {}
-void TiledOpInterfaceTilingPass::runOnOperation() {
- FuncOp funcOp = getOperation();
- MLIRContext *context = funcOp.getContext();
-
- RewritePatternSet patterns(context);
- patterns.add<TiledOpInterfaceTilingPattern>(
- context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "tiling_input"),
- StringAttr::get(context, "tiling_output")));
- patterns.add<TiledOpInterfaceTilingPattern>(
- context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "no_tiling_input"),
- StringAttr::get(context, "no_tiling_output")));
-
- patterns.add<TiledOpInterfaceTilingPattern>(
- context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "outer_reduce_input"),
- StringAttr::get(context, "outer_reduce_output")));
- patterns.add<TiledOpInterfaceTilingPattern>(
- context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "inner_reduce_input"),
- StringAttr::get(context, "inner_reduce_output")));
-
- static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
- [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
- auto numParallelDims = parallelLoopRanges.size();
-
- SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
- for (size_t dim = 0; dim < numParallelDims; ++dim) {
- procInfo[numParallelDims - dim - 1] = {
- buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupIDOp>(
- builder, dim),
- buildFlowWorkgroupInfoOp<IREE::Input::DispatchWorkgroupCountOp>(
- builder, dim)};
- }
- return procInfo;
- },
- {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
- linalg::DistributionMethod::Cyclic},
- DenseMap<StringRef,
- std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
-
- patterns.add<TiledOpInterfaceTilingPattern>(
- context,
- linalg::LinalgTilingOptions()
- .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
- .setDistributionOptions(workgroupDistributionOptions),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "distribute_input"),
- StringAttr::get(context, "distribute_output")));
-
- patterns.add<TiledOpInterfaceTilingPattern>(
- context,
- linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{32}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "tiling_1d_stage5_fft_input"),
- StringAttr::get(context, "tiling_1d_stage5_fft_output")));
-
- patterns.add<TiledOpInterfaceTilingPattern>(
- context,
- linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{10, 32}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "tiling_2d_stage5_fft_input"),
- StringAttr::get(context, "tiling_2d_stage5_fft_output")));
-
- patterns.add<TiledOpInterfaceTilingPattern>(
- context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
- linalg::LinalgTransformationFilter(
- StringAttr::get(context, "tiling_repeated_indices_scatter_input"),
- StringAttr::get(context, "tiling_repeated_indices_scatter_output")));
-
- if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
- return signalPassFailure();
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceOp = sliceOp.source().getDefiningOp<TilingInterface>();
+ if (!sourceOp || !filter.hasReplacementFilter(sourceOp)) return failure();
+ SmallVector<Operation *> tiledOps = sourceOp.getTiledImplementation(
+ rewriter, sourceOp.getDestinationOperands(rewriter),
+ sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+ /*tileDestOperands=*/true);
+ assert(tiledOps.size() && "expected single tiled op");
+ Operation *tiledOp = tiledOps.front();
+ rewriter.replaceOp(sliceOp, tiledOp->getResults());
+ return success();
}
-}
-std::unique_ptr<OperationPass<FuncOp>>
-IREE::LinalgExt::createTiledOpInterfaceTilingPass() {
- return std::make_unique<TiledOpInterfaceTilingPass>();
-}
+ private:
+ linalg::LinalgTilingOptions options;
+ linalg::LinalgTransformationFilter filter;
+};
+
+} // namespace
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp
new file mode 100644
index 0000000..7174daa
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingExternalModels.cpp
@@ -0,0 +1,178 @@
+//===- TilingExternalModels.cpp - External models for TilingInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+#define DEBUG_TYPE "linalg-ext-tiling"
+
+using namespace mlir;
+using namespace mlir::linalg;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+static Value getAsValue(OpBuilder &b, Location loc, OpFoldResult ofr) {
+ if (auto v = ofr.dyn_cast<Value>()) return v;
+ return b.create<arith::ConstantIndexOp>(
+ loc, ofr.get<Attribute>().cast<IntegerAttr>().getInt());
+}
+static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
+ ArrayRef<OpFoldResult> ofrs) {
+ SmallVector<Value> vals;
+ vals.reserve(ofrs.size());
+ for (auto ofr : ofrs) vals.push_back(getAsValue(b, loc, ofr));
+ return vals;
+}
+
+static SmallVector<Value, 4> makeTiledInputShapes(OpBuilder &b, Location loc,
+ LinalgOp linalgOp,
+ ArrayRef<Value> valuesToTile,
+ ArrayRef<Value> ivsRef,
+ ArrayRef<Value> tileSizesRef,
+ ArrayRef<Value> sizeBounds) {
+ assert(static_cast<int64_t>(valuesToTile.size()) == linalgOp.getNumInputs() &&
+ "expected one value to tile for every operand");
+
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> tileSizes{tileSizesRef.begin(), tileSizesRef.end()};
+ tileSizes.append(sizeBounds.size() - tileSizes.size(), zero);
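+ // Pad with zeros so there is one tile size per loop dimension; a zero tile
+ // size means that dimension is not tiled.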
+
+ // Construct (potentially temporary) mins and maxes on which to apply maps
+ // that define tile subshapes.
+ SmallVector<Value> lbs = computeTileOffsets(b, loc, ivsRef, tileSizes);
+ SmallVector<Value> subShapeSizes =
+ computeTileSizes(b, loc, ivsRef, tileSizes, sizeBounds);
+
+ SmallVector<Value, 4> tiledShapes;
+ tiledShapes.reserve(valuesToTile.size());
+ for (OpOperand *opOperand : linalgOp.getInputOperands()) {
+ Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
+ LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
+ AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
+ LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
+ tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
+ sizeBounds, subShapeSizes));
+ }
+
+ return tiledShapes;
+}
+
+namespace {
+
+/// External model implementation of TilingInterface for LinalgOps. This is
+/// templated on the actual Linalg named op for now since the registration of
+/// the external model requires the original operation.
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+ : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+ LinalgOpTy> {
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ return linalgOp.getOutputOperands();
+ }
+
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ SmallVector<StringRef> iteratorTypes;
+ iteratorTypes.reserve(linalgOp.iterator_types().size());
+ for (Attribute iteratorAttr : linalgOp.iterator_types()) {
+ iteratorTypes.push_back(iteratorAttr.cast<StringAttr>().getValue());
+ }
+ return iteratorTypes;
+ }
+
+ SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ return linalgOp.createLoopRanges(b, op->getLoc());
+ }
+
+ SmallVector<Operation *> getTiledImplementation(
+ Operation *op, OpBuilder &b, ValueRange tiledDest,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ bool tileDestOperands) const {
+ LinalgOp linalgOp = cast<LinalgOp>(op);
+ if (op->getNumResults() != 1) {
+ // TODO: Need a failure message here, but `notifyMatchFailure` is only a
+ // method on `PatternRewriter`.
+ return {};
+ }
+ Location loc = op->getLoc();
+ AffineMap shapeSizesToLoopsMap = linalgOp.getShapesToLoopsMap();
+ auto allShapeSizes = linalgOp.createFlatListOfOperandDims(b, loc);
+ if (!shapeSizesToLoopsMap) return {};
+
+ OpOperand *outOperand = linalgOp.getOutputOperand(0);
+ AffineMap indexingMap = linalgOp.getTiedIndexingMap(outOperand);
+ if (!indexingMap.isProjectedPermutation()) return {};
+
+ SmallVector<Value> offsetsVals = getAsValues(b, loc, offsets);
+ SmallVector<Value> sizeVals = getAsValues(b, loc, sizes);
+ SmallVector<Value> sizeBounds =
+ applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
+
+ // The offsets and sizes from the slice operation only give you the tile
+ // size of the output. Use that to compute the tile sizes and offsets of the
+ // loops. For loops not used to access the output, set the tile size to the
+ // loop bound and the offset to 0.
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> tileOffsets(sizeBounds.size(), zero);
+ SmallVector<Value> tileSizes = sizeBounds;
+ for (auto result : enumerate(indexingMap.getResults())) {
+ unsigned position = result.value().cast<AffineDimExpr>().getPosition();
+ tileOffsets[position] = offsetsVals[result.index()];
+ tileSizes[position] = sizeVals[result.index()];
+ }
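+ // For example (illustrative), with a matmul-style output indexing map
+ // (d0, d1, d2) -> (d0, d1), the slice offsets/sizes map to loop dimensions
+ // 0 and 1, while the reduction dimension d2 keeps offset 0 and its full
+ // size.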
+
+ SmallVector<Value> valuesToTile = linalgOp.getInputOperands();
+ SmallVector<Value> tiledOperands;
+ if (tileDestOperands) {
+ // Append the outputs then tile both the inputs and outputs.
+ valuesToTile.append(tiledDest.begin(), tiledDest.end());
+ tiledOperands = makeTiledShapes(b, loc, linalgOp, valuesToTile,
+ tileOffsets, tileSizes, sizeBounds);
+ } else {
+ // Only tile the inputs, then append the outputs.
+ int64_t dim = offsets.size();
+ ArrayRef<Value> tileOffsetsRef{tileOffsets.begin(), tileOffsets.end()};
+ ArrayRef<Value> tileSizesRef{tileSizes.begin(), tileSizes.end()};
+ tiledOperands = makeTiledInputShapes(
+ b, loc, linalgOp, valuesToTile, tileOffsetsRef.take_front(dim + 1),
+ tileSizesRef.take_front(dim + 1), sizeBounds);
+ tiledOperands.append(tiledDest.begin(), tiledDest.end());
+ }
+ return {linalgOp.clone(b, loc, tiledDest.getTypes(), tiledOperands)};
+ }
+};
+} // namespace
+
+template <typename OpType>
+void registerOne(DialectRegistry &registry) {
+ registry.addOpInterface<OpType, LinalgOpTilingInterface<OpType>>();
+}
+
+/// Variadic helper function.
+template <typename... OpTypes>
+void registerAll(DialectRegistry &registry) {
+ // FIXME: In C++17 this can be simplified by using 'fold expressions'.
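+ // With C++17 the body below could instead be written as the fold expression
+ //   (registerOne<OpTypes>(registry), ...);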
+ (void)std::initializer_list<int>{0, (registerOne<OpTypes>(registry), 0)...};
+}
+
+#define GET_OP_LIST
+
+void mlir::iree_compiler::IREE::LinalgExt::
+ registerTilingInterfaceExternalModels(DialectRegistry &registry) {
+ registerOne<linalg::GenericOp>(registry);
+ registerAll<
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(registry);
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
new file mode 100644
index 0000000..ba8cc4d
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/TilingToTileOp.cpp
@@ -0,0 +1,106 @@
+//===- TilingToTileOp.cpp - Tiling to TileOp using TilingInterface --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+struct TilingResult {
+ TileOp tileOp;
+ Operation *tiledOp;
+};
+
+static TilingResult tileToTileOp(PatternRewriter &rewriter, TilingInterface op,
+ int64_t tiledDim, Value tileSize) {
+ Location loc = op->getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ // TODO: Handle the case where the `loopRanges` are empty.
+ SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
+ assert(loopRanges.size() >= 1 &&
+ "expected at least a single loop in operation");
+ auto destOperands = op.getDestinationOperands(rewriter);
+ Operation *tiledOp = nullptr;
+ auto tileOp = rewriter.create<TileOp>(
+ loc, tileSize, destOperands, tiledDim,
+ [&](OpBuilder &b, Location loc, Value offset, Value size,
+ ValueRange outSlices) {
+ // TODO: support `getTiledImplementation` with >1 produced tiled ops.
+ int64_t nLoops = loopRanges.size();
+ SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+ tiledOffsets.reserve(nLoops);
+ tiledSizes.reserve(nLoops);
+ for (unsigned i = 0; i < nLoops; ++i) {
+ if (i == tiledDim) {
+ tiledOffsets.push_back(offset);
+ tiledSizes.push_back(size);
+ } else {
+ tiledOffsets.push_back(loopRanges[i].offset);
+ tiledSizes.push_back(loopRanges[i].size);
+ }
+ }
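+ // Only the tiled dimension uses the dynamic (offset, size) pair supplied by
+ // the TileOp body; every other dimension keeps its full iteration range.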
+ SmallVector<Operation *> tiledOps = op.getTiledImplementation(
+ b, outSlices, tiledOffsets, tiledSizes, /*tileDestOperands=*/false);
+ assert(tiledOps.size() == 1 && "expected single tiled op");
+ tiledOp = tiledOps.front();
+ b.create<TileYieldOp>(loc, tiledOp->getResults());
+ });
+ return TilingResult{tileOp, tiledOp};
+}
+
+FailureOr<Operation *> mlir::iree_compiler::IREE::LinalgExt::
+ LinalgExtTilingPattern::returningMatchAndRewrite(
+ TilingInterface op, PatternRewriter &rewriter) const {
+ /// Currently only handle single result operations.
+ if (op->getNumResults() != 1)
+ return rewriter.notifyMatchFailure(op, "Not a single result");
+
+ // Get rank and tile sizes.
+ // TODO: consider moving these checks to a common place that the TransformOp
+ // verifier can also use.
+ SmallVector<Value> tileSizes =
+ options.tileSizeComputationFunction(rewriter, op);
+ int64_t dim = -1;
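+ // Exactly one non-zero tile size is expected; its index selects the single
+ // dimension that is tiled to a TileOp.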
+ for (auto en : llvm::enumerate(tileSizes)) {
+ Optional<int64_t> maybeTileSize = getConstantIntValue(en.value());
+ if (maybeTileSize && *maybeTileSize == 0) continue;
+ if (maybeTileSize && *maybeTileSize < 0)
+ return rewriter.notifyMatchFailure(op, "Negative tile size");
+ if (dim >= 0)
+ return rewriter.notifyMatchFailure(op,
+ "Could not find a single tiling dim");
+ dim = en.index();
+ }
+ if (dim < 0)
+ return rewriter.notifyMatchFailure(op,
+ "Could not find a single tiling dim");
+
+ /// Currently only handle tiling operations on a parallel iterator type.
+ auto loopIteratorTypes = op.getLoopIteratorTypes();
+ // Scalar operation, nothing to do, so just return.
+ if (loopIteratorTypes.empty())
+ return rewriter.notifyMatchFailure(op, "Scalar op, no tiling possible");
+ ArrayRef<StringRef> loopIteratorTypesRef(loopIteratorTypes);
+ if (loopIteratorTypesRef[dim] != getParallelIteratorTypeName())
+ return rewriter.notifyMatchFailure(op, "Trying to tile a non-parallel dim");
+
+ TilingResult tilingResult = tileToTileOp(rewriter, op, dim, tileSizes[dim]);
+ rewriter.replaceOp(op, tilingResult.tileOp->getResults());
+
+ return tilingResult.tiledOp;
+}
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp
new file mode 100644
index 0000000..9b250b8
--- /dev/null
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Utils.cpp
@@ -0,0 +1,104 @@
+//===- Utils.cpp - LinalgExt transform utils ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "iree-dialects/Dialect/LinalgExt/Transforms/Utils.h"
+
+#include "iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::iree_compiler::IREE::LinalgExt;
+
+void mlir::iree_compiler::IREE::LinalgExt::completeOffsetsSizesAndStrides(
+ OpBuilder &b, Location loc, Value tensor, ArrayRef<Value> leadingOffsets,
+ ArrayRef<Value> leadingSizes, ArrayRef<Value> leadingStrides,
+ SmallVectorImpl<Value> &offsets, SmallVectorImpl<Value> &sizes,
+ SmallVectorImpl<Value> &strides) {
+ assert(leadingOffsets.size() == leadingSizes.size() &&
+ "expected matching lengths");
+ assert(leadingSizes.size() == leadingStrides.size() &&
+ "expected matching lengths");
+
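+ // For example (illustrative): completing a rank-3 tensor from a single
+ // leading (offset, size, stride) triple yields offsets = [off0, 0, 0],
+ // sizes = [size0, dim(t, 1), dim(t, 2)] and strides = [stride0, 1, 1].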
+ auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
+ int64_t tensorRank = rankedTensorType.getRank();
+ int64_t leadingRank = leadingOffsets.size();
+ offsets = SmallVector<Value>(leadingOffsets.begin(), leadingOffsets.end());
+ sizes = SmallVector<Value>(leadingSizes.begin(), leadingSizes.end());
+ strides = SmallVector<Value>(leadingStrides.begin(), leadingStrides.end());
+ if (leadingRank >= tensorRank) return;
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ for (int64_t i = leadingRank, e = tensorRank; i < e; ++i) {
+ offsets.push_back(zero);
+ sizes.push_back(b.createOrFold<tensor::DimOp>(loc, tensor, i));
+ strides.push_back(one);
+ }
+}
+
+/// Create a tensor::ExtractSliceOp by auto-completing the missing trailing
+/// dimensions to always be offset = 0, size = dim, stride = 1.
+Value mlir::iree_compiler::IREE::LinalgExt::
+ createSubsetExtractOpFromLeadingOffsetsSizesAndStrides(
+ OpBuilder &b, Location loc, Value tensor,
+ ArrayRef<Value> leadingOffsets, ArrayRef<Value> leadingSizes,
+ ArrayRef<Value> leadingStrides) {
+ SmallVector<Value> offsets, sizes, strides;
+ completeOffsetsSizesAndStrides(b, loc, tensor, leadingOffsets, leadingSizes,
+ leadingStrides, offsets, sizes, strides);
+ return b.createOrFold<tensor::ExtractSliceOp>(loc, tensor, offsets, sizes,
+ strides);
+}
+
+/// Create a tensor::InsertSliceOp by auto-completing the missing trailing
+/// dimensions to always be offset = 0, size = dim, stride = 1.
+Value mlir::iree_compiler::IREE::LinalgExt::
+ createSubsetInsertOpFromLeadingOffsetsSizesAndStrides(
+ OpBuilder &b, Location loc, Value tensor, Value dest,
+ ArrayRef<Value> leadingOffsets, ArrayRef<Value> leadingSizes,
+ ArrayRef<Value> leadingStrides) {
+ SmallVector<Value> offsets, sizes, strides;
+ completeOffsetsSizesAndStrides(b, loc, tensor, leadingOffsets, leadingSizes,
+ leadingStrides, offsets, sizes, strides);
+ return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
+ sizes, strides);
+}
+
+/// Create an iree_compiler::IREE::LinalgExt::ParallelInsertSliceOp by
+/// auto-completing the missing trailing dimensions to always be offset = 0,
+/// size = dim, stride = 1.
+Operation *mlir::iree_compiler::IREE::LinalgExt::
+ createParallelInsertSliceOpFromLeadingOffsetsSizesAndStrides(
+ OpBuilder &b, Location loc, Value tensor, Value dest,
+ ArrayRef<Value> leadingOffsets, ArrayRef<Value> leadingSizes,
+ ArrayRef<Value> leadingStrides) {
+ SmallVector<Value> offsets, sizes, strides;
+ completeOffsetsSizesAndStrides(b, loc, tensor, leadingOffsets, leadingSizes,
+ leadingStrides, offsets, sizes, strides);
+ return b.createOrFold<iree_compiler::IREE::LinalgExt::ParallelInsertSliceOp>(
+ loc, tensor, dest, offsets, sizes, strides);
+}
+
+/// Insert the `source` tensor into the `dest` tensor by creating the relevant
+/// `subset_insert` op. The details of the `subset_insert` op are retrieved
+/// from the `subset_extract` op so that they form a matching extract/insert
+/// pair.
+Value mlir::iree_compiler::IREE::LinalgExt::createMatchingSubsetInsertOp(
+ OpBuilder &b, Location loc, tensor::ExtractSliceOp subsetExtractOp,
+ Value source, Value dest) {
+ return b.create<tensor::InsertSliceOp>(
+ loc, subsetExtractOp.source().getType(), source, dest,
+ subsetExtractOp.offsets(), subsetExtractOp.sizes(),
+ subsetExtractOp.strides(), subsetExtractOp.static_offsets(),
+ subsetExtractOp.static_sizes(), subsetExtractOp.static_strides());
+}