Remove the patterns that create dispatch regions without tile+fuse+distribute. (#7328)
With TiledOpInterface most of the operations to be executed on the
device can be tiled and distributed. So the pattern that created a
non-tiled and distributed dispatch region can be removed.
Simplify the code that decides which ops are root ops.
To make all tests pass, the last operation that needs to be tiled and
distributed is tensor.extract_slice. Add an external model to
implement the tile+distribute of this op to allow deleting the fall
back path.
This is recommit of PR #7306 which was reverted in #7319 due to failure
on integrate. The changes from #7306 fix the breakages.
Resolve tensor.dim operation before legalization during dispatch
region formation.
Allow tiling of tensor.extract_slice with non-unit strides.
diff --git a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
index e839d2a..c10fd83 100644
--- a/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
+++ b/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -307,6 +307,13 @@
rootOperation = op;
break;
}
+ if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+ // linalg.generic with `reduction` iterator types are roots as well.
+ if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
+ rootOperation = op;
+ break;
+ }
+ }
}
if (!rootOperation) {
diff --git a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 27bccf4..be4fd7d 100644
--- a/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -499,6 +499,14 @@
// If unsuccessful, try to tile and distribute.
return setDefaultOpConfig(limits, op);
})
+ .Case<linalg::GenericOp>([limits](auto op) {
+ // If generic op has reduction iterator types, it is a root as
+ // well. Just set the default configuration, which marks it as a root.
+ if (op.getNumLoops() != op.getNumParallelLoops()) {
+ return setDefaultOpConfig(limits, op);
+ }
+ return success();
+ })
.Default([](Operation *) { return success(); });
};
@@ -529,8 +537,7 @@
return funcOp.emitOpError("failed to get compute ops");
}
- int64_t subgroupSize =
- targetEnv.getResourceLimits().subgroup_size().getValue().getSExtValue();
+ int64_t subgroupSize = limits.subgroup_size().getValue().getSExtValue();
// If the dispatch region does not contain tiled and distributed Linalg ops,
// invoke the pipeline to distribute to global invocations.
@@ -571,8 +578,8 @@
// Check if the op configuration was set.
if (!getLoweringConfig(computeOp)) {
return computeOp->emitOpError(
- "without known roots, the last operation in the tiled loop body "
- "is expected to be set as root");
+ "without known roots, the last compute operation in the tiled "
+ "loop body is expected to be set as root");
}
rootOperation = computeOp;
break;
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
index 863b089..4f59a67 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOpFolders.cpp
@@ -338,7 +338,8 @@
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
- return DispatchTensorLoadOp::inferResultType(
+ return DispatchTensorLoadOp::inferRankReducedResultType(
+ loadOp.result().getType().cast<RankedTensorType>().getRank(),
loadOp.source().getType().cast<DispatchTensorType>(), mixedSizes);
}
};
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
index cee79f4..89899c5 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp
@@ -14,6 +14,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
@@ -82,6 +83,36 @@
}
}
+RankedTensorType DispatchTensorLoadOp::inferRankReducedResultType(
+ unsigned resultRank, IREE::Flow::DispatchTensorType sourceType,
+ ArrayRef<OpFoldResult> mixedSizes) {
+ // This is using logic from
+ // `tensor::ExtractSliceOp::inferRankReducedResultType`. Eventually just use
+ // that.
+ auto shape = llvm::to_vector<4>(
+ llvm::map_range(mixedSizes, [&](OpFoldResult valueOrAttr) -> int64_t {
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ return attr.cast<IntegerAttr>().getInt();
+ }
+ return DispatchTensorType::kDynamicSize;
+ }));
+ auto inferredType = RankedTensorType::get(shape, sourceType.getElementType());
+ int rankDiff = sourceType.getRank() - resultRank;
+ if (rankDiff > 0) {
+ llvm::SmallDenseSet<unsigned> dimsToProject;
+ mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
+ SmallVector<int64_t> projectedShape;
+ for (unsigned pos = 0, e = shape.size(); pos < e; ++pos) {
+ if (!dimsToProject.contains(pos)) {
+ projectedShape.push_back(shape[pos]);
+ }
+ }
+ inferredType =
+ RankedTensorType::get(projectedShape, inferredType.getElementType());
+ }
+ return inferredType;
+}
+
RankedTensorType DispatchTensorLoadOp::inferResultType(
IREE::Flow::DispatchTensorType sourceType,
ArrayRef<OpFoldResult> mixedSizes) {
diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.td b/iree/compiler/Dialect/Flow/IR/FlowOps.td
index 3dbc19f..17d50cc 100644
--- a/iree/compiler/Dialect/Flow/IR/FlowOps.td
+++ b/iree/compiler/Dialect/Flow/IR/FlowOps.td
@@ -350,8 +350,8 @@
/// Return the expected rank of each of the`static_offsets`, `static_sizes`
/// and `static_strides` attributes.
std::array<unsigned, 3> getArrayAttrMaxRanks() {
- unsigned resultRank = getResult().getType().cast<ShapedType>().getRank();
- return {resultRank, resultRank, resultRank};
+ unsigned sourceRank = source().getType().cast<DispatchTensorType>().asTensorType().getRank();
+ return {sourceRank, sourceRank, sourceRank};
}
/// Return the number of leading operands before the `offsets`, `sizes` and
@@ -362,7 +362,12 @@
static RankedTensorType inferResultType
(IREE::Flow::DispatchTensorType sourceType,
ArrayRef<OpFoldResult> mixedSizes);
- }];
+
+ /// Returns the type of the result based on the sizes.
+ static RankedTensorType inferRankReducedResultType
+ (unsigned resultRank, IREE::Flow::DispatchTensorType sourceType,
+ ArrayRef<OpFoldResult> mixedSizes);
+}];
let hasCanonicalizer = 1;
let hasFolder = 1;
diff --git a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
index 8905b71..6ba9c4b 100644
--- a/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
+++ b/iree/compiler/Dialect/Flow/Transforms/DispatchLinalgOnTensors.cpp
@@ -12,14 +12,12 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
-#include "iree/compiler/Dialect/Shape/IR/Builders.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
-#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -143,43 +141,30 @@
// Op property charecterizations
//===----------------------------------------------------------------------===//
-/// The current fusion algorithm has some embedded heuristics that are meant to
-/// be a first simple start, and can be adapted over time. Note hoever that it
-/// is better to have a simple default strategy and use some search-based
-/// techniques for actual heuristics. Current heuristics classify operations in
-/// this heirarchy
-/// - Root Op : These are ops that are computationally intensive and most
-/// probably dominate model execution time. These are in general named ops
-/// like linalg.matmul, linalg.conv, etc. These are tiled and distributed
-/// across workgroups.
-/// - Dispatchable ops : These are ops that are not root operations, but still
-/// perform some "meaningful" computation. Typically, fused element-wise
-/// operations, represented as linalg.generic. These could be fused with root
-/// operations using tile + fuse, or could be in their own dispatch regions.
-/// - Always fused dispatchable ops : These are ops that are chosen to always be
-/// fused into dispatch regions that use their values, since when bufferized
-/// they can be converted into being no-copy/aliasing operations. Examples of
-/// this is linalg.tensor_reshape that can be converted to a linalg.reshape on
-/// bufferization. These are different from dispatchable ops in that they are
-/// never in their own dispatch region unless there is no consumer to fuse
-/// them with. Typically when the result of the operation is the
-/// output.
-/// - Always cloned into dispatch op : These are operations that are operations
-/// that are always cloned into their consuming dispatch regions and never end
-/// up in their own dispatch regions. Typical examples are splat constants and
-/// linalg.init_tensor operations.
-
+/// Operations that are treated as root operations for dispatch region
+/// formation.
static bool isRootOp(Operation *op) {
if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
return false;
}
- return (isa<linalg::LinalgOp>(op) &&
- !isa<linalg::GenericOp, linalg::FillOp>(op)) ||
- isa<linalg_ext::TiledOpInterface>(op);
+ // Any Linalg named op or generic op with reduction iterator types is a root
+ // op.
+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+ if (isa<linalg::GenericOp>(op)) {
+ return linalgOp.getNumReductionLoops() != 0;
+ }
+ return !isa<linalg::FillOp>(op);
+ }
+ return isa<linalg_ext::TiledOpInterface>(op) &&
+ !isa<tensor::ExtractSliceOp>(op);
}
-static bool isAlwaysClonedIntoDispatchOp(Operation *op) {
- if (isa<arith::IndexCastOp, linalg::InitTensorOp, tensor::ExtractOp>(op)) {
+/// Operations that are cloned into dispatch regions formed with other
+/// operations as roots.
+static bool isClonableIntoDispatchOp(Operation *op) {
+ if (isa<arith::IndexCastOp, linalg::InitTensorOp,
+ linalg::TensorCollapseShapeOp, linalg::TensorExpandShapeOp,
+ tensor::ExtractOp, tensor::ExtractSliceOp>(op)) {
return true;
}
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
@@ -194,32 +179,6 @@
return false;
}
-static bool isDispatchableOp(Operation *op) {
- // Ignore operations already in dispatch regions.
- if (op->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
- return false;
- }
- // Linalg ops are marked dispatchable.
- if ((op->getDialect() !=
- op->getContext()->getLoadedDialect<linalg::LinalgDialect>()) &&
- !isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(op)) {
- return false;
- }
-
- // Mark linalg.fill as non-dispatchable so that for those linalg.fill ops that
- // cannot be fused together with some root op, they are left out of dispatch
- // region formation, and to be picked up by DMA op conversion.
- if (isa<linalg::FillOp>(op)) return false;
-
- return !isAlwaysClonedIntoDispatchOp(op);
-}
-
-static bool isAlwaysFusedIntoDispatchOp(Operation *op) {
- return isDispatchableOp(op) &&
- (isa<linalg::TensorCollapseShapeOp, tensor::ExtractSliceOp>(op) ||
- isa<linalg::TensorExpandShapeOp, tensor::ExtractSliceOp>(op));
-}
-
//===----------------------------------------------------------------------===//
// Methods that help creating the dispatch regions
//===----------------------------------------------------------------------===//
@@ -272,8 +231,8 @@
}
rewriter.create<IREE::Flow::ReturnOp>(loc);
}
- DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs() << "Created dispatchOp shell "
- << *dispatchOp << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "Created dispatchOp shell \n"
+ << *dispatchOp << "\n");
return {dispatchOp, clonedOp};
}
@@ -290,8 +249,8 @@
PatternRewriter &rewriter, IREE::Flow::DispatchWorkgroupsOp dispatchOp,
linalg::LinalgOp tiledOp, ValueRange untiledOpOperands,
ArrayRef<Operation *> tiledLoops, int64_t groupNum) {
- DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs() << "pull in producers for tiled op: "
- << tiledOp << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "pull in producers for tiled op: " << tiledOp
+ << "\n");
// Scoped within DispatchWorkgroupOp.
OpBuilder::InsertionGuard g(rewriter);
@@ -308,8 +267,7 @@
linalg::LinalgOp fusedProducer;
if (tiledLoops.empty()) {
- DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
- << "no loops; just copy over the op\n");
+ LLVM_DEBUG(llvm::dbgs() << "no loops; just copy over the op\n");
// The root op wasn't tiled. We are done then.
removeFusionGroupsAttribute(clonedOrigProducer);
fusedProducer = cast<linalg::LinalgOp>(clonedOrigProducer);
@@ -321,12 +279,10 @@
rewriter, clonedOrigProducer->getResult(opResult.getResultNumber()),
tiledOp->getOpOperand(en.index()));
if (!maybeFusionInfo.hasValue()) {
- DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
- << "failed to fuse with tensor\n");
+ LLVM_DEBUG(llvm::dbgs() << "failed to fuse with tensor\n");
rewriter.replaceOp(clonedOrigProducer, producer->getResults());
} else {
- DEBUG_WITH_TYPE(DEBUG_TYPE, llvm::dbgs()
- << "succeeded to fuse with tensor\n");
+ LLVM_DEBUG(llvm::dbgs() << "succeeded to fuse with tensor\n");
removeFusionGroupsAttribute(maybeFusionInfo->fusedProducer);
fusedProducer = maybeFusionInfo->fusedProducer;
}
@@ -355,7 +311,7 @@
/// Reorders the operations in `ops` such that they could be inlined into the
/// dispatch region in that order to satisfy dependencies.
static SmallVector<Operation *> orderOperations(ArrayRef<Operation *> ops) {
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ LLVM_DEBUG({
llvm::dbgs() << "Ops to be inlined :\n";
for (auto op : ops) {
llvm::dbgs() << "\t";
@@ -421,7 +377,7 @@
readyOps = ArrayRef<Operation *>(orderedOps).drop_front(startPos);
}
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ LLVM_DEBUG({
llvm::dbgs() << "Ops to be inlined (sorted) : \n";
for (auto op : orderedOps) {
llvm::dbgs() << "\t";
@@ -450,8 +406,7 @@
if (visited.count(outsideValue)) continue;
visited.insert(outsideValue);
Operation *definingOp = outsideValue.getDefiningOp();
- if (!definingOp || !(isAlwaysClonedIntoDispatchOp(definingOp) ||
- isAlwaysFusedIntoDispatchOp(definingOp))) {
+ if (!definingOp || !(isClonableIntoDispatchOp(definingOp))) {
valuesDefinedAbove.insert(outsideValue);
continue;
}
@@ -594,19 +549,6 @@
block->eraseArguments(eraseArguments);
}
-static void replaceAllUsesWithinDispatchOp(
- IREE::Flow::DispatchWorkgroupsOp dispatchOp, Value value,
- Value replacement) {
- SmallPtrSet<Operation *, 4> usesOutsideDispatch;
- for (Operation *user : value.getUsers()) {
- if (isa<IREE::Flow::DispatchWorkgroupsOp>(user) ||
- !dispatchOp->isAncestor(user)) {
- usesOutsideDispatch.insert(user);
- }
- }
- value.replaceAllUsesExcept(replacement, usesOutsideDispatch);
-}
-
// After outlining in dispatch region we can rewrite the dispatch ops with
// proper captures.
// A later RematerializeDispatchConstants should be called to avoid passing
@@ -730,40 +672,25 @@
return {};
}
-/// Computes the shape of the output. This is used to get the workload of the
-/// dispatch region if a dispatch region contains a single "Dispatchable op"
-static Optional<SmallVector<SmallVector<Value, 4>, 1>> computeOutputShape(
- OpBuilder &builder, Operation *op) {
- SmallVector<SmallVector<Value, 4>, 1> outputShapes;
- for (auto outputType : op->getResultTypes()) {
- // Add empty shape for scalar values.
- if (outputType.isIntOrFloat()) {
- outputShapes.push_back({});
- continue;
- }
-
- // TODO(ravishankarm): For now only handle static shapes. For dynamic
- // shapes, the shape of the output needs to be resolved using tie shapes,
- // etc.
- if (auto shapedType = outputType.dyn_cast<ShapedType>()) {
- if (!shapedType.hasStaticShape()) return llvm::None;
- outputShapes.push_back(llvm::to_vector<4>(
- llvm::map_range(shapedType.getShape(), [&](int64_t dim) -> Value {
- return builder.create<arith::ConstantIndexOp>(op->getLoc(), dim);
- })));
- continue;
- }
- return llvm::None;
- }
- return outputShapes;
-}
-
static bool hasOnlyDimUses(Operation *op) {
return llvm::all_of(op->getUsers(), [&](Operation *user) {
return isa<tensor::DimOp>(user);
});
}
+/// For a value `v` append to `dynamicDims` `Value`s that represent the shape of
+/// the dynamic dimensions.
+static void appendDynamicDims(OpBuilder &builder, Location loc, Value v,
+ SmallVectorImpl<Value> &dynamicDims) {
+ auto shapedType = v.getType().dyn_cast<RankedTensorType>();
+ if (!shapedType) return;
+ for (auto shape : enumerate(shapedType.getShape())) {
+ if (shape.value() != ShapedType::kDynamicSize) continue;
+ Value dim = builder.create<tensor::DimOp>(loc, v, shape.index());
+ dynamicDims.push_back(dim);
+ }
+}
+
//===----------------------------------------------------------------------===//
// Patterns that create the dispatch region.
//===----------------------------------------------------------------------===//
@@ -807,8 +734,7 @@
// Capture dynamic result dimensions.
SmallVector<Value, 4> resultDynamicDims;
for (auto result : linalgOp.outputs()) {
- resultDynamicDims.append(Shape::buildOrFindDynamicDimsForValue(
- linalgOp.getLoc(), result, rewriter));
+ appendDynamicDims(rewriter, loc, result, resultDynamicDims);
}
// Note: DispatchTensorStoreOp generated by the
@@ -829,27 +755,37 @@
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(clonedLinalgOp);
- linalg::TiledLinalgOp tiledLinalgOp;
- LogicalResult tilingResult =
- Base::matchAndRewriteBase(clonedLinalgOp, rewriter, tiledLinalgOp);
- if (failed(tilingResult)) {
- // GreedyPatternRewriter is not transactional and does not stop on
- // failure. Must explicitly delete on all failure paths.
- rewriter.eraseOp(dispatchOp);
- return failure();
+ auto nLoops = linalgOp.getNumParallelLoops();
+ if (nLoops) {
+ linalg::TiledLinalgOp tiledLinalgOp;
+ LogicalResult tilingResult =
+ Base::matchAndRewriteBase(clonedLinalgOp, rewriter, tiledLinalgOp);
+ if (failed(tilingResult)) {
+ // GreedyPatternRewriter is not transactional and does not stop on
+ // failure. Must explicitly delete on all failure paths.
+ rewriter.eraseOp(dispatchOp);
+ return failure();
+ }
+
+ SmallVector<Value> clonedOpOperands =
+ clonedLinalgOp.getInputAndOutputOperands();
+ pullInProducersInSameGroup(rewriter, dispatchOp, tiledLinalgOp.op,
+ clonedOpOperands, tiledLinalgOp.loops,
+ getRootNumber(op));
+
+ // Keep track of the tiledOpOperands for fusion.
+ rewriter.replaceOp(clonedLinalgOp, tiledLinalgOp.tensorResults);
+
+ removeRootOpAttribute(tiledLinalgOp.op);
+ } else {
+ SmallVector<Value> clonedOpOperands =
+ clonedLinalgOp.getInputAndOutputOperands();
+ pullInProducersInSameGroup(
+ rewriter, dispatchOp, clonedLinalgOp, clonedOpOperands,
+ /*tiledLoops =*/ArrayRef<Operation *>{}, getRootNumber(op));
+ removeRootOpAttribute(clonedLinalgOp);
}
- SmallVector<Value> clonedOpOperands =
- clonedLinalgOp.getInputAndOutputOperands();
- pullInProducersInSameGroup(rewriter, dispatchOp, tiledLinalgOp.op,
- clonedOpOperands, tiledLinalgOp.loops,
- getRootNumber(op));
-
- // Keep track of the tiledOpOperands for fusion.
- rewriter.replaceOp(clonedLinalgOp, tiledLinalgOp.tensorResults);
-
- removeRootOpAttribute(tiledLinalgOp.op);
-
rewriter.replaceOpWithIf(op, dispatchOp.getResults(),
[&](OpOperand &operand) {
return !isa<tensor::DimOp>(operand.getOwner());
@@ -884,9 +820,8 @@
// Capture dynamic result dimensions.
SmallVector<Value, 4> resultDynamicDims;
- for (auto result : tilableOp.getDestinationOperands()) {
- resultDynamicDims.append(
- Shape::buildOrFindDynamicDimsForValue(loc, result, rewriter));
+ for (auto result : tilableOp.getDestinationOperands(rewriter)) {
+ appendDynamicDims(rewriter, loc, result, resultDynamicDims);
}
// Note: DispatchTensorStoreOp generated by the
@@ -931,163 +866,6 @@
return success();
}
};
-
-/// Given a list of shapes, returns whether it is statically provable that all
-/// shapes are the same. For now checks if
-/// 1) Each dimension has the same dynamic value, or,
-/// 2) The defining op for each dimension is a `constant` op with the same
-/// scalar value.
-static bool areAllShapesEqual(ArrayRef<SmallVector<Value>> shapes) {
- assert(!shapes.empty());
- if (shapes.size() == 1) return true;
- auto isSameShape = [&](ArrayRef<Value> lhsShape,
- ArrayRef<Value> rhsShape) -> bool {
- if (lhsShape.size() != rhsShape.size()) return false;
- return llvm::all_of(
- llvm::zip(lhsShape, rhsShape), [&](std::tuple<Value, Value> vals) {
- APInt lhsInt, rhsInt;
- Value lhs = std::get<0>(vals);
- Value rhs = std::get<1>(vals);
- return lhs == rhs || (matchPattern(lhs, m_ConstantInt(&lhsInt)) &&
- matchPattern(rhs, m_ConstantInt(&rhsInt)) &&
- lhsInt == rhsInt);
- });
- };
- return llvm::all_of(
- llvm::make_range(std::next(shapes.begin()), shapes.end()),
- [&](ArrayRef<Value> shape) { return isSameShape(shapes[0], shape); });
-}
-
-/// The workload is computed based on the problem size. For a given operation,
-/// return the shape of all its results.
-static Optional<SmallVector<SmallVector<Value>>> getResultShapes(
- PatternRewriter &rewriter, Operation *op) {
- if (op->getNumResults() == 0) return llvm::None;
- ReifiedRankedShapedTypeDims resultShapes;
- // Check if the op implements the shape interface.
- if (auto shapedOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op)) {
- if (failed(shapedOp.reifyResultShapes(rewriter, resultShapes))) {
- return llvm::None;
- }
- return resultShapes;
- }
-
- // Fallback is to get the shape using `dim` of the outputs. Since the
- // workload depends on the output shape, set the insertion point to after
- // the operation. After dim canonicalization, the original operation should
- // become dead.
- rewriter.setInsertionPointAfter(op);
- Location loc = op->getLoc();
- auto getShapeOfShapedTypeVal = [&](Value v) -> SmallVector<Value> {
- SmallVector<Value> shape;
- for (auto dim :
- llvm::seq<int64_t>(0, v.getType().cast<ShapedType>().getRank())) {
- shape.push_back(rewriter.createOrFold<tensor::DimOp>(loc, v, dim));
- }
- return shape;
- };
- for (OpResult result : op->getResults()) {
- auto resultType = result.getType().dyn_cast<ShapedType>();
- if (!resultType) return llvm::None;
- rewriter.setInsertionPointAfter(op);
- auto resultShape = getShapeOfShapedTypeVal(result);
- resultShapes.emplace_back(std::move(resultShape));
- }
- return resultShapes;
-}
-
-/// Puts ops that are not-tilable or arent tiled into a
-/// `flow.dispatch.workgroups` operation. For example tile and distribute of
-/// element-wise operations is not beneficial. These are handled appropriately
-/// by the backends.
-struct MakeDispatchWorkgroupsOp : public RewritePattern {
- MakeDispatchWorkgroupsOp(MLIRContext *context, PatternBenefit benefit = 1)
- : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- if (!isDispatchableOp(op) || hasOnlyDimUses(op)) return failure();
-
- // If this is a dispatchable op that is to be fused into dispatch ops, and
- // all its uses are dispatchable ops, don't do anything.
- if ((hasFusionGroupsAttribute(op) || isAlwaysFusedIntoDispatchOp(op)) &&
- llvm::all_of(op->getUsers(), [](Operation *user) {
- return isDispatchableOp(user) ||
- user->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>() ||
- isa<IREE::Flow::DispatchWorkgroupsOp, tensor::DimOp>(user);
- })) {
- return failure();
- }
-
- // The workgroup count is based on the result shape.
- Optional<SmallVector<SmallVector<Value>>> resultShapesOpt =
- getResultShapes(rewriter, op);
- if (!resultShapesOpt) return failure();
- ArrayRef<SmallVector<Value>> resultShapes = *resultShapesOpt;
- if (resultShapes.size() != op->getNumResults() ||
- !areAllShapesEqual(resultShapes))
- return failure();
-
- // TODO(ravishankarm): For now the Flow -> HAL conversion only handles
- // workload count of 3, though it should be generalized. For now making sure
- // the flow has three elements of workload size (x, y, z) by linearizing the
- // workloads for all higher dimensions greater than or equal to
- // kNumMaxParallelDims.
- Location loc = op->getLoc();
- SmallVector<Value, 4> count(resultShapes[0].begin(), resultShapes[0].end());
- if (count.size() > kNumMaxParallelDims) {
- unsigned numSymbols = 0;
- AffineExpr expr = rewriter.getAffineSymbolExpr(numSymbols++);
- for (int64_t i = 1; i < count.size() - kNumMaxParallelDims + 1; i++) {
- expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
- }
- count[count.size() - kNumMaxParallelDims] = linalg::applyMapToValues(
- rewriter, loc, AffineMap::get(0, numSymbols, expr),
- ArrayRef<Value>(count).take_front(count.size() - kNumMaxParallelDims +
- 1))[0];
- count = llvm::to_vector<4>(
- ArrayRef<Value>(count).take_back(kNumMaxParallelDims));
- }
- auto workload = convertToWorkload(rewriter, loc, count);
-
- // Capture dynamic result dimensions.
- SmallVector<Value, 4> resultDynamicDims;
- for (auto result : llvm::enumerate(op->getResults())) {
- auto resultType = result.value().getType().cast<ShapedType>();
- for (unsigned i = 0; i < resultType.getRank(); ++i) {
- if (resultType.isDynamicDim(i)) {
- resultDynamicDims.push_back(resultShapes[result.index()][i]);
- }
- }
- }
-
- auto en = buildOperandLessFlowDispatchWorkgroupOp(rewriter, op->getLoc(),
- workload, op);
- IREE::Flow::DispatchWorkgroupsOp dispatchOp = en.first;
- dispatchOp.result_dimsMutable().assign(resultDynamicDims);
-
- // If this is a root op for fusion, try to pull in the ops to be fused
- // together with it.
- if (hasRootOpAttribute(op)) {
- auto clonedLinalgOp = dyn_cast<linalg::LinalgOp>(en.second);
- if (clonedLinalgOp) {
- SmallVector<Value> opOperandsVal =
- clonedLinalgOp.getInputAndOutputOperands();
- pullInProducersInSameGroup(
- rewriter, dispatchOp, clonedLinalgOp, opOperandsVal,
- /*tiledLoops=*/ArrayRef<Operation *>(), getRootNumber(op));
- removeRootOpAttribute(clonedLinalgOp);
- }
- }
-
- rewriter.replaceOpWithIf(op, dispatchOp.getOperation()->getResults(),
- [&](OpOperand &operand) {
- Operation *user = operand.getOwner();
- return !isa<tensor::DimOp>(user);
- });
- return success();
- }
-};
}; // namespace
//===----------------------------------------------------------------------===//
@@ -1102,46 +880,6 @@
/// fuse with multiple root operations (i.e. replicated). For now a very simple
/// heuristic is used below, but the mechanism should be general enough to
/// capture any heuristic.
-
-/// Sets elementwise operations as root operations.
-// TODO(#5045): After the regression issue on CPU side is addressed, this can be
-// folded into the main logic of fusion.
-template <typename GenericOpTy>
-static unsigned makeElementwiseOpsRootOps(mlir::FuncOp funcOp,
- unsigned numRoots) {
- MLIRContext *context = funcOp.getContext();
- OpBuilder builder(context);
- for (Block &block : funcOp) {
- auto linalgOps = block.getOps<linalg::LinalgOp>();
- for (linalg::LinalgOp linalgOp : llvm::reverse(linalgOps)) {
- Operation *op = linalgOp.getOperation();
- if (hasRootOpAttribute(op) || hasFusionGroupsAttribute(op)) {
- continue;
- }
- if (!isa<GenericOpTy>(op) ||
- !llvm::all_of(
- cast<linalg::LinalgOp>(op).getIndexingMaps(),
- [](AffineMap map) { return map.isProjectedPermutation(); })) {
- continue;
- }
- unsigned newGroup = numRoots++;
- setRootAttribute(context, op, newGroup);
-
- for (OpOperand *operand : linalgOp.getOutputTensorOperands()) {
- auto producer = operand->get().getDefiningOp<linalg::LinalgOp>();
- if (!producer) continue;
- if (producer.getNumLoops() != producer.getNumParallelLoops()) continue;
- appendToFusionGroup(producer, newGroup);
- }
- }
- }
- return numRoots;
-}
-
-/// For a given block partition the LinalgOps in the block into fusable
-/// groups. All analysis of what to fuse happens here. For now this is just
-/// hard-wiring from basic heuristic but this could be adapted to have 1) better
-/// heuristics and 2) use a search approach to decide what all should be fused.
static unsigned decideFusableLinalgOps(mlir::FuncOp funcOp) {
unsigned numRootOps = 0;
MLIRContext *context = funcOp.getContext();
@@ -1210,6 +948,7 @@
appendToFusionGroup(op, rootNumber);
}
}
+
return numRootOps;
}
@@ -1232,21 +971,9 @@
};
} // namespace
-void DispatchLinalgOnTensorsPass::runOnOperation() {
- auto funcOp = getOperation();
-
+/// For all ops within `funcOp` tagged as root ops, create dispatch regions.
+LogicalResult createDispatchRegionsFromRootOps(FuncOp funcOp) {
MLIRContext *context = funcOp->getContext();
- context->allowUnregisteredDialects(true);
-
- unsigned numRoots = decideFusableLinalgOps(funcOp);
- numRoots = makeElementwiseOpsRootOps<linalg::GenericOp>(funcOp, numRoots);
-
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
- llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
// Distribution strategy along at most 3 dimensions with WorkgroupIdOp in
// range [0, WorkgroupSizeOp).
static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
@@ -1268,6 +995,12 @@
DenseMap<StringRef,
std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+ // Tile size selection function. Sets the tile size now to
+ // flow.dispatch.workgroup.size op, with 0 for the innermost parallel loop
+ // partitioned, 1 for the next outermost loop partitioned and so on. Use the
+ // workgroup size as a proxy for tile size here. At the flow level this
+ // represents the "workload" per processors and is not necessarily tied to the
+ // workgroup size specified by the backend.
auto tileSizeFn = [&](OpBuilder &builder,
Operation *op) -> SmallVector<Value, 4> {
SmallVector<unsigned> partitionedLoops = getPartitionedLoops(op);
@@ -1303,11 +1036,10 @@
return useTileSizes;
};
+ // Create the dispatch region, first without the isolate region from above
+ // property.
{
- // Use the workgroup size as a proxy for tile size here. At the flow level
- // this represents the "workload" per processors and is not necessarily tied
- // to the workgroup size specified by the backend.
- OwningRewritePatternList patterns(&getContext());
+ OwningRewritePatternList patterns(context);
auto linalgTilingOptions =
linalg::LinalgTilingOptions()
.setDistributionOptions(workgroupDistributionOptions)
@@ -1320,64 +1052,33 @@
// TODO(nicolavasilache): use refactored `getWorkgroupMarker()`
linalg::LinalgTransformationFilter(
ArrayRef<Identifier>(), Identifier::get("workgroup", context)));
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return failure();
+ }
- // Run canonicalization patterns.
- OwningRewritePatternList canonicalizationPattterns(&getContext());
+ // Run canonicalization patterns and pattern to resolve tensor.dim of result
+ // values into tensor.dim of its operands..
+ OwningRewritePatternList canonicalizationPatterns(context);
linalg::populateLinalgTilingCanonicalizationPatterns(
- canonicalizationPattterns);
- (void)applyPatternsAndFoldGreedily(funcOp,
- std::move(canonicalizationPattterns));
+ canonicalizationPatterns);
+ if (failed(applyPatternsAndFoldGreedily(
+ funcOp, std::move(canonicalizationPatterns)))) {
+ return failure();
+ }
}
- // If elementwise operations are not tiled and distributed, the wont be marked
- // as root ops previously. Mark them so here to allow fusion of `fill` etc.
- numRoots = makeElementwiseOpsRootOps<linalg::GenericOp>(funcOp, numRoots);
-
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
- llvm::dbgs()
- << "\n--- After annotating linalg op fusion scheme for fallback ---\n";
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After dispatch op formation ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
- // After outlining in dispatch region we can rewrite the dispatch ops with
- // proper captures.
- if (funcOp
- .walk([&](IREE::Flow::DispatchWorkgroupsOp op) -> WalkResult {
- return legalizeDispatchWorkgroupOperands(op);
- })
- .wasInterrupted()) {
- return signalPassFailure();
- }
-
- // Move other operations into their own dispatch regions.
+ // Run necessary canonicalization patterns before rewrite destructive updates.
{
OwningRewritePatternList patterns(context);
- patterns.insert<MakeDispatchWorkgroupsOp>(context);
- (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
- }
-
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
- llvm::dbgs() << "\n--- After dispatch region creation ---\n";
- funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
- llvm::dbgs() << "\n\n";
- });
-
- // After outlining in dispatch region we can rewrite the dispatch ops with
- // proper captures.
- if (funcOp
- .walk([&](IREE::Flow::DispatchWorkgroupsOp op) -> WalkResult {
- numDispatches++;
- return legalizeDispatchWorkgroupOperands(op);
- })
- .wasInterrupted()) {
- return signalPassFailure();
- }
-
- // Run necessary canonicalization patterns before destructive updates.
- {
- OwningRewritePatternList patterns(&getContext());
+ // Resolve `tensor.dim` of result of operations into operations on its
+ // operands using the `ReifyRankedShapedTypeOpInterface`.
+ memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
// This is needed because tiling and distribution may create
// subtensor_insert ops whose source operands come from tensor.cast ops.
// Those tensor.cast ops cast tensors into a more dynamic shape, in order
@@ -1388,12 +1089,106 @@
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ // After outlining in dispatch region we can rewrite the dispatch ops with
+ // proper captures to make it isolated from above.
+ if (funcOp
+ .walk([&](IREE::Flow::DispatchWorkgroupsOp op) -> WalkResult {
+ return legalizeDispatchWorkgroupOperands(op);
+ })
+ .wasInterrupted()) {
+ return failure();
+ }
+
+ LLVM_DEBUG({
llvm::dbgs() << "\n--- After dispatch op legalization ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});
+ return success();
+}
+
+void DispatchLinalgOnTensorsPass::runOnOperation() {
+ auto funcOp = getOperation();
+
+ MLIRContext *context = funcOp->getContext();
+ context->allowUnregisteredDialects(true);
+
+ unsigned numRoots = decideFusableLinalgOps(funcOp);
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After annotating linalg op fusion scheme ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ if (failed(createDispatchRegionsFromRootOps(funcOp))) {
+ return signalPassFailure();
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After first step of dispatch region formation ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ /// Iterate over the remaining ops and pick up whatever needs to go into
+ /// dispatch regions and mark them as root ops.
+ for (Block &block : funcOp) {
+ for (Operation &op : block) {
+ // Ignore ops that
+ // - Do not implement the `LinalgOp` interface.
+ // - linalg.fill ops.
+ if (!isa<linalg::LinalgOp>(&op)) continue;
+ if (isa<linalg::FillOp>(&op)) continue;
+ assert(!hasRootOpAttribute(&op) &&
+ "unexpected root operation outside of dispatch region");
+ removeFusionGroupsAttribute(&op);
+ setRootAttribute(context, &op, numRoots++);
+ }
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "\n--- After annotating remaining linalg ops as roots ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ if (failed(createDispatchRegionsFromRootOps(funcOp))) {
+ return signalPassFailure();
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs()
+ << "\n--- After second step of dispatch region formation ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ /// Iterate over the remaining ops and pick up whatever needs to go into
+ /// dispatch regions and mark them as root ops.
+ for (Block &block : funcOp) {
+ for (Operation &op : block) {
+ // Ignore ops that do not implement the `TiledOpInterface` interface.
+ if (!isa<linalg_ext::TiledOpInterface>(&op)) continue;
+ assert(!hasRootOpAttribute(&op) &&
+ "unexpected root operation outside of dispatch region");
+ removeFusionGroupsAttribute(&op);
+ setRootAttribute(context, &op, numRoots++);
+ }
+ }
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "\n--- After annotating remaining ops as roots ---\n";
+ funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
+ llvm::dbgs() << "\n\n";
+ });
+
+ if (failed(createDispatchRegionsFromRootOps(funcOp))) {
+ return signalPassFailure();
+ }
+
// Rewrite destructive updates and ensure no remaining store remains to the
// full output.
if (funcOp
@@ -1409,7 +1204,7 @@
signalPassFailure();
}
- DEBUG_WITH_TYPE(DEBUG_TYPE, {
+ LLVM_DEBUG({
llvm::dbgs() << "\n--- After rewriting destructive updates ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
diff --git a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
index a4b0747..7edfc27 100644
--- a/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
+++ b/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_on_tensors.mlir
@@ -205,29 +205,6 @@
// -----
-func @fuse_reshape_op(%arg0: tensor<?x?xf32>) -> tensor<?xf32>
-{
- %0 = linalg.tensor_collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
- return %0 : tensor<?xf32>
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
-// CHECK: func @fuse_reshape_op
-// CHECK-SAME: (%[[ARG0:.+]]: tensor<?x?xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[WORKLOAD:.+]] = affine.apply #[[MAP0]]()[%[[D0]], %[[D1]]]
-// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
-// CHECK-SAME: [%[[WORKLOAD]], %[[C1]], %[[C1]]](%[[ARG0]])
-// CHECK-NEXT: %[[ARG1:.+]]: !flow.dispatch.tensor<readonly:?x?xf32>
-// CHECK-SAME: %[[ARG2:.+]]: !flow.dispatch.tensor<writeonly:?xf32>
-// CHECK: %[[LOAD:.+]] = flow.dispatch.tensor.load %[[ARG1]], {{.*}}
-// CHECK: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[LOAD]] {{\[}}[0, 1]]
-// CHECK: flow.dispatch.tensor.store %[[RESHAPE]], %[[ARG2]], {{.*}}
-
-// -----
-
func @tile_4d_generic_op_alone
(%A: tensor<?x?x?x?xf32>, %B: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%c0 = arith.constant 0 : index
@@ -308,10 +285,10 @@
// CHECK-DAG: %[[M:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[RESULT1:.+]] = flow.dispatch.workgroups[%[[N1]], %[[M]], %[[C1]]]
-// CHECK-SAME: (%[[ARG0]], %[[RHS1]], %[[M]], %[[N1]])
+// CHECK-SAME: (%[[RHS1]], %[[ARG0]], %[[M]], %[[N1]])
// CHECK: %[[N2:.+]] = tensor.dim %[[RHS2]], %[[C1]]
// CHECK: %[[RESULT2:.+]] = flow.dispatch.workgroups[%[[N2]], %[[M]], %[[C1]]]
-// CHECK-SAME: (%[[ARG0]], %[[RHS2]], %[[M]], %[[N2]])
+// CHECK-SAME: (%[[RHS2]], %[[ARG0]], %[[M]], %[[N2]])
// CHECK: return %[[RESULT1]], %[[RESULT2]]
// -----
@@ -804,37 +781,6 @@
// -----
-func @multi_result_fallback(%arg0: tensor<?x10xi32>, %arg1: tensor<?x10xi32>)
- -> (tensor<?x10xi32>, tensor<?x10xi32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %0 = tensor.dim %arg0, %c0 : tensor<?x10xi32>
- %1 = linalg.init_tensor [%0, 10] : tensor<?x10xi32>
- %2:2 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, 10-d1)>,
- affine_map<(d0, d1) -> (d0, 10-d1)>,
- affine_map<(d0, d1) -> (d0, d1)>,
- affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]}
- ins(%arg0, %arg1 : tensor<?x10xi32>, tensor<?x10xi32>)
- outs(%1, %1 : tensor<?x10xi32>, tensor<?x10xi32>) {
- ^bb0(%arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32):
- linalg.yield %arg2, %arg3 : i32, i32
- } -> (tensor<?x10xi32>, tensor<?x10xi32>)
- return %2#0, %2#1 : tensor<?x10xi32>, tensor<?x10xi32>
-}
-// CHECK-LABEL: func @multi_result_fallback
-// CHECK: %[[RESULT:.+]]:2 = flow.dispatch.workgroup
-// CHECK-NEXT: %[[ARG5:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?x10xi32>
-// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: !flow.dispatch.tensor<writeonly:?x10xi32>
-// CHECK-NOT: scf.for
-// CHECK: %[[OP_RESULT:.+]]:2 = linalg.generic
-// CHECK-DAG: flow.dispatch.tensor.store %[[OP_RESULT]]#0, %[[ARG5]]
-// CHECK-DAG: flow.dispatch.tensor.store %[[OP_RESULT]]#1, %[[ARG6]]
-// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
-
-// -----
-
func @dynamic_slice(%arg0: tensor<?x?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3 : index) -> tensor<1x?xi32> {
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
@@ -861,7 +807,7 @@
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[RESULT:.+]] = flow.dispatch.workgroups
// CHECK-SAME: [%[[ARG3]], %[[C1]], %[[C1]]]
-// CHECK-SAME: (%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]])
+// CHECK-SAME: (%[[ARG3]], %[[ARG1]], %[[ARG2]], %[[ARG0]])
// CHECK-DAG: cmpi
// CHECK-DAG: select
// CHECK-DAG: cmpi
@@ -872,7 +818,13 @@
// CHECK-DAG: select
// CHECK-DAG: index_cast
// CHECK-DAG: index_cast
-// CHECK: tensor.extract_slice
+// CHECK-NOT: tensor.extract
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK-NOT: tensor.extract
+// CHECK: flow.dispatch.tensor.load
+// CHECK-NOT: tensor.extract
+// CHECK: flow.dispatch.tensor.store
// CHECK: flow.return
// CHECK: return %[[RESULT]]
diff --git a/iree/compiler/Dialect/LinalgExt/IR/BUILD b/iree/compiler/Dialect/LinalgExt/IR/BUILD
index abf4a49..5fa8109 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/IR/BUILD
@@ -163,6 +163,7 @@
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
diff --git a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index 59d4920..46c09f1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -100,6 +100,7 @@
LLVMSupport
MLIRAffine
MLIRIR
+ MLIRLinalg
MLIRStandard
MLIRSupport
MLIRTensor
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index c98d9de..bd8bc30 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -30,7 +30,7 @@
let printer = [{ return print$cppClass(p, *this); }];
let parser = [{ return parse$cppClass(parser, result); }];
code extraLinalgExtOpClassDeclaration = [{
- SmallVector<Value> getDestinationOperands() {
+ SmallVector<Value> getDestinationOperands(OpBuilder &b) {
SmallVector<Value> dest(outputs().begin(), outputs().end());
return dest;
}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
index f6bc021..b8bd224 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -8,6 +8,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -36,10 +37,110 @@
//===----------------------------------------------------------------------===//
namespace {
+
+/// External model for `tensor.extract_slice`.
+struct ExtractSliceTiledOpInterface
+ : public TiledOpInterface::ExternalModel<ExtractSliceTiledOpInterface,
+ tensor::ExtractSliceOp> {
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+ // No operand of `tensor.extract_slice` serves as a destination operand. So
+ // create an `init_tensor` op of the same size as the result.
+ auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+ SmallVector<Value> dest;
+ ReifiedRankedShapedTypeDims returnShape;
+ (void)extractSliceOp.reifyResultShapes(b, returnShape);
+ auto ofrShape = llvm::to_vector<4>(llvm::map_range(
+ returnShape[0], [](Value v) { return getAsOpFoldResult(v); }));
+ Value initTensor = b.create<linalg::InitTensorOp>(
+ op->getLoc(), ofrShape, extractSliceOp.getType().getElementType());
+ return {initTensor};
+ }
+
+ SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+ auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+ return SmallVector<StringRef>(extractSliceOp.getType().getRank(),
+ getParallelIteratorTypeName());
+ }
+
+ SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+ auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
+ SmallVector<Value> dest;
+ ReifiedRankedShapedTypeDims returnShape;
+ (void)extractSliceOp.reifyResultShapes(b, returnShape);
+ Location loc = op->getLoc();
+ Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+ SmallVector<Range> loopRanges(returnShape[0].size(),
+ Range{zero, nullptr, one});
+ for (auto ub : enumerate(returnShape[0])) {
+ loopRanges[ub.index()].size = ub.value();
+ }
+ return loopRanges;
+ }
+
+ Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+ ValueRange outputs,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<Value> &results) const {
+ auto extractOp = cast<tensor::ExtractSliceOp>(op);
+ // Check that strides are 1. For now abort if they arent
+ Location loc = extractOp.getLoc();
+ auto oneAttr = b.getI64IntegerAttr(1);
+
+ // Compute the offset and sizes for the tiled `tensor.extract_slice`
+ // operation.
+ llvm::SmallDenseSet<unsigned> droppedDims = extractOp.getDroppedDims();
+ unsigned resultDimPos = 0;
+ auto opOffsets = extractOp.getMixedOffsets();
+ auto opSizes = extractOp.getMixedSizes();
+ auto opStrides = extractOp.getMixedStrides();
+ MLIRContext *context = b.getContext();
+ SmallVector<OpFoldResult> newOffset, newSizes, newStrides;
+ for (auto opOffset : enumerate(opOffsets)) {
+ // If the dimension is dropped, use the same offset.
+ if (droppedDims.count(opOffset.index())) {
+ newOffset.push_back(opOffset.value());
+ newSizes.push_back(opSizes[opOffset.index()]);
+ } else {
+ AffineExpr d0, s0, s1;
+ bindDims(context, d0);
+ bindSymbols(context, s0, s1);
+ AffineMap map = AffineMap::get(1, 2, d0 * s0 + s1);
+ SmallVector<Value> operands = {
+ getValue(b, loc, offsets[resultDimPos]),
+ getValue(b, loc, opStrides[opOffset.index()]),
+ getValue(b, loc, opOffset.value())};
+ Value offset = b.create<AffineApplyOp>(loc, map, operands);
+ newOffset.push_back(offset);
+ newSizes.push_back(sizes[resultDimPos]);
+ resultDimPos++;
+ }
+ newStrides.push_back(opStrides[opOffset.index()]);
+ }
+
+ // Generate the tiled `tensor.extract_slice` operation.
+ Type resultType = tensor::ExtractSliceOp::inferRankReducedResultType(
+ extractOp.getType().getRank(), extractOp.getSourceType(), newOffset,
+ newSizes, newStrides);
+ auto tiledExtractOp = b.create<tensor::ExtractSliceOp>(
+ loc, resultType.cast<RankedTensorType>(), extractOp.source(), newOffset,
+ newSizes, newStrides);
+
+ // Insert the tiled extract into the result tensor.
+ SmallVector<OpFoldResult> resultStrides(offsets.size(), oneAttr);
+ auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+ loc, tiledExtractOp.result(), outputs[0], offsets, sizes,
+ resultStrides);
+ results.push_back(tiledInsertOp.result());
+ return tiledExtractOp;
+ }
+};
+
struct InsertSliceTiledOpInterface
: public TiledOpInterface::ExternalModel<InsertSliceTiledOpInterface,
tensor::InsertSliceOp> {
- SmallVector<Value> getDestinationOperands(Operation *op) const {
+ SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
SmallVector<Value> dest;
dest.push_back(cast<tensor::InsertSliceOp>(op).dest());
return dest;
@@ -82,6 +183,7 @@
op->emitOpError("unable to tile operation with non-unit stride");
return nullptr;
}
+ MLIRContext *context = b.getContext();
Location loc = insertOp.getLoc();
auto oneAttr = b.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
@@ -116,13 +218,14 @@
resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
*getConstantIntValue(opOffsetVal) + *getConstantIntValue(offset));
} else {
- AffineMap map = AffineMap::get(
- 1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
+ AffineExpr d0, s0;
+ bindDims(context, d0);
+ bindSymbols(context, s0);
+ AffineMap map = AffineMap::get(1, 1, d0 + s0);
+ SmallVector<Value> operands = {getValue(b, loc, offset),
+ getValue(b, loc, opOffsetVal)};
resultOffsets[opOffsetIndex] =
- b.create<AffineApplyOp>(loc, map,
- ValueRange{getValue(b, loc, offset),
- getValue(b, loc, opOffsetVal)})
- .getResult();
+ b.create<AffineApplyOp>(loc, map, operands).getResult();
}
resultSizes[opOffsetIndex] = sizes[offsetIndex];
offsetIndex++;
@@ -138,9 +241,10 @@
} // namespace
void registerTiledOpInterfaceExternalModels(DialectRegistry ®istry) {
- LLVM_DEBUG({
- llvm::dbgs() << "Adding tiled op interface for tensor.insert_slice\n";
- });
+ LLVM_DEBUG(
+ { llvm::dbgs() << "Adding external models of tiled op interface\n"; });
+ registry
+ .addOpInterface<tensor::ExtractSliceOp, ExtractSliceTiledOpInterface>();
registry.addOpInterface<tensor::InsertSliceOp, InsertSliceTiledOpInterface>();
}
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
index 337b273..af89913 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
@@ -29,7 +29,7 @@
}],
/*retType=*/"SmallVector<Value>",
/*methodName=*/"getDestinationOperands",
- /*args=*/(ins),
+ /*args=*/(ins "OpBuilder &":$b),
/*methodBody=*/"",
/*defaultImplementation=*/"return ValueRange{};"
>,
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
index 3224687..31b5140 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -183,7 +183,7 @@
FailureOr<TiledOp> tileInterfaceOp(OpBuilder &b, TiledOpInterface tilableOp,
const linalg::LinalgTilingOptions &options) {
- SmallVector<Value> dest = tilableOp.getDestinationOperands();
+ SmallVector<Value> dest = tilableOp.getDestinationOperands(b);
if (dest.empty()) {
return static_cast<LogicalResult>(tilableOp.emitOpError(
"cannot tile operation without destination operands"));
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index f269e75..3482609 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -745,3 +745,425 @@
// CHECK-SAME: [%[[ARG2]], %[[APPLY]], %[[ARG4]]] [1, %[[TILESIZE]], 1]
// CHECK: scf.yield %[[YIELD]]
// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice(%arg0 : tensor<?x?xf32>, %arg1: index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index,
+ %arg6 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2] [%arg3, %arg4] [%arg5, %arg6]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG3]], %[[ARG4]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG3]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG3]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG4]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG5]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG6]], %[[ARG2]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]]] [%[[TILE_Y]], %[[TILE_X]]] [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_static(%arg0 : tensor<50x60xf32>) -> tensor<20x30xf32> {
+ %0 = tensor.extract_slice %arg0[2, 3] [20, 30] [5, 6]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<50x60xf32> to tensor<20x30xf32>
+ return %0 : tensor<20x30xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_static
+// CHECK-SAME: %[[ARG0:.+]]: tensor<50x60xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[C6:.+]] = arith.constant 6 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [20, 30]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[C20]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<20x30xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[C20]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[C30]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<20x30xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[C30]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[C5]], %[[C2]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[C6]], %[[C3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]]] [%[[TILE_Y]], %[[TILE_X]]] [5, 6]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_outer(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3] [1, %arg4, %arg5] [%arg6, %arg7, %arg8]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_outer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG4]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG5]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG2]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG8]], %[[ARG3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[ARG1]], %[[OFFSET_Y]], %[[OFFSET_X]]]
+// CHECK-SAME: [1, %[[TILE_Y]], %[[TILE_X]]] [%[[ARG6]], %[[ARG7]], %[[ARG8]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_middle(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3] [%arg4, 1, %arg5] [%arg6, %arg7, %arg8]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_middle
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG4]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG5]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG6]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG8]], %[[ARG3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[ARG2]], %[[OFFSET_X]]]
+// CHECK-SAME: [%[[TILE_Y]], 1, %[[TILE_X]]] [%[[ARG6]], %[[ARG7]], %[[ARG8]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_inner(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3] [%arg4, %arg5, 1] [%arg6, %arg7, %arg8]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_inner
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG4]], %[[ARG5]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG4]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG4]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG5]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG6]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG7]], %[[ARG2]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]], %[[ARG3]]]
+// CHECK-SAME: [%[[TILE_Y]], %[[TILE_X]], 1] [%[[ARG6]], %[[ARG7]], %[[ARG8]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_two_dims_1(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [%arg5, 1, %arg6, 1] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_1
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG9]], %[[ARG3]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[ARG2]], %[[OFFSET_X]], %[[ARG4]]]
+// CHECK-SAME: [%[[TILE_Y]], 1, %[[TILE_X]], 1]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_two_dims_2(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [%arg5, 1, 1, %arg6] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_2
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG10]], %[[ARG4]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[ARG2]], %[[ARG3]], %[[OFFSET_X]]]
+// CHECK-SAME: [%[[TILE_Y]], 1, 1, %[[TILE_X]]]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_two_dims_3(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [1, %arg5, 1, %arg6] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_3
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG8]], %[[ARG2]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG10]], %[[ARG4]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[ARG1]], %[[OFFSET_Y]], %[[ARG3]], %[[OFFSET_X]]]
+// CHECK-SAME: [1, %[[TILE_Y]], 1, %[[TILE_X]]]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @extract_slice_reduced_rank_two_dims_4(%arg0 : tensor<?x?x?x?xf32>, %arg1 : index,
+ %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index,
+ %arg7 : index, %arg8 : index, %arg9 : index, %arg10 : index) -> tensor<?x?xf32> {
+ %0 = tensor.extract_slice %arg0[%arg1, %arg2, %arg3, %arg4] [%arg5, %arg6, 1, 1] [%arg7, %arg8, %arg9, %arg10]
+ {__internal_linalg_transform__ = "tiling_input"} : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
+// CHECK: func @extract_slice_reduced_rank_two_dims_4
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[ARG5]], %[[ARG6]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG5]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ITER1:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[C10]], %[[ARG5]]]
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:.+]] = %[[C0]]
+// CHECK-SAME: to %[[ARG6]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ITER2:.+]] = %[[ITER1]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[TILE_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[C20]], %[[ARG6]]]
+// CHECK-DAG: %[[OFFSET_Y:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG7]], %[[ARG1]]]
+// CHECK-DAG: %[[OFFSET_X:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG8]], %[[ARG2]]]
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[OFFSET_Y]], %[[OFFSET_X]], %[[ARG3]], %[[ARG4]]]
+// CHECK-SAME: [%[[TILE_Y]], %[[TILE_X]], 1, 1]
+// CHECK-SAME: [%[[ARG7]], %[[ARG8]], %[[ARG9]], %[[ARG10]]]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ITER2]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TILE_Y]], %[[TILE_X]]] [1, 1]
+// CHECK: scf.yield %[[INSERT]]
+// CHECK: }
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: }
+// CHECK: return %[[RESULT]]
diff --git a/iree/test/e2e/regression/BUILD b/iree/test/e2e/regression/BUILD
index d57b54d..872d0b4 100644
--- a/iree/test/e2e/regression/BUILD
+++ b/iree/test/e2e/regression/BUILD
@@ -29,7 +29,6 @@
"dynamic_torch_index_select_vector.mlir",
"linalg_ops.mlir",
"linalg_ext_ops.mlir",
- "tensor_insert_slice.mlir",
]
iree_lit_test_suite(
@@ -38,7 +37,6 @@
[
"globals.mlir",
"scalar.mlir",
- "tensor_cast.mlir",
"trace_dispatch_tensors.mlir",
"unused_args.mlir",
],
diff --git a/iree/test/e2e/regression/CMakeLists.txt b/iree/test/e2e/regression/CMakeLists.txt
index c676a6f..cb1ea0b 100644
--- a/iree/test/e2e/regression/CMakeLists.txt
+++ b/iree/test/e2e/regression/CMakeLists.txt
@@ -16,7 +16,6 @@
SRCS
"globals.mlir"
"scalar.mlir"
- "tensor_cast.mlir"
"trace_dispatch_tensors.mlir"
"unused_args.mlir"
DATA
@@ -43,7 +42,6 @@
"linalg_ext_ops.mlir"
"linalg_ops.mlir"
"lowering_config.mlir"
- "tensor_insert_slice.mlir"
TARGET_BACKEND
"dylib-llvm-aot"
DRIVER
@@ -66,7 +64,6 @@
"dynamic_torch_index_select_vector.mlir"
"linalg_ext_ops.mlir"
"linalg_ops.mlir"
- "tensor_insert_slice.mlir"
TARGET_BACKEND
"vmvx"
DRIVER
@@ -89,7 +86,6 @@
"dynamic_torch_index_select_vector.mlir"
"linalg_ext_ops.mlir"
"linalg_ops.mlir"
- "tensor_insert_slice.mlir"
TARGET_BACKEND
"vulkan-spirv"
DRIVER
@@ -112,7 +108,6 @@
"dynamic_torch_index_select_vector.mlir"
"linalg_ext_ops.mlir"
"linalg_ops.mlir"
- "tensor_insert_slice.mlir"
TARGET_BACKEND
"cuda"
DRIVER
diff --git a/iree/test/e2e/tensor_ops/BUILD b/iree/test/e2e/tensor_ops/BUILD
new file mode 100644
index 0000000..4b2a1ce
--- /dev/null
+++ b/iree/test/e2e/tensor_ops/BUILD
@@ -0,0 +1,102 @@
+# Copyright 2021 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests of end-to-end IREE support for individual ops in the XLA HLO dialect.
+# Each test file should have a name matching the corresponding XLA HLO op and test only the
+# functionality of that op (though may make use of other ops where necessary). Tests should be
+# written using the IREE Check framework and should always pass on the reference VMLA backend.
+# See https://github.com/google/iree/blob/main/docs/developers/developing_iree/testing_guide.md#iree-core-end-to-end-tests.
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//iree:lit_test.bzl", "iree_lit_test_suite")
+load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+iree_lit_test_suite(
+ name = "lit",
+ srcs = [
+ "tensor_cast.mlir",
+ ],
+ data = [
+ "//iree/tools:IreeFileCheck",
+ "//iree/tools:iree-benchmark-module",
+ "//iree/tools:iree-run-mlir",
+ "//iree/tools:iree-translate",
+ ],
+ tags = ["hostonly"],
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_dylib_embedded-llvm-aot_dylib",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "extract_slice.mlir",
+ "tensor_insert_slice.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ "tensor_cast.mlir",
+ ],
+ ),
+ compiler_flags = [
+ "-iree-llvm-link-embedded=true",
+ ],
+ driver = "dylib",
+ target_backend = "dylib-llvm-aot",
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_cuda",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "extract_slice.mlir",
+ "tensor_insert_slice.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ "tensor_cast.mlir",
+ ],
+ ),
+ compiler_flags = [
+ "-iree-llvm-link-embedded=true",
+ ],
+ driver = "cuda",
+ tags = [
+ "noasan",
+ "nomsan",
+ "notsan",
+ "noubsan",
+ "requires-gpu-nvidia",
+ ],
+ target_backend = "cuda",
+)
+
+iree_check_single_backend_test_suite(
+ name = "check_vulkan-spirv_vulkan",
+ srcs = enforce_glob(
+ # keep sorted
+ [
+ "extract_slice.mlir",
+ "tensor_insert_slice.mlir",
+ ],
+ include = ["*.mlir"],
+ exclude = [
+ "tensor_cast.mlir",
+ ],
+ ),
+ compiler_flags = [
+ "-iree-llvm-link-embedded=true",
+ ],
+ driver = "vulkan",
+ target_backend = "vulkan-spirv",
+)
diff --git a/iree/test/e2e/tensor_ops/CMakeLists.txt b/iree/test/e2e/tensor_ops/CMakeLists.txt
new file mode 100644
index 0000000..325e859
--- /dev/null
+++ b/iree/test/e2e/tensor_ops/CMakeLists.txt
@@ -0,0 +1,75 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# iree/test/e2e/tensor_ops/BUILD #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+ NAME
+ lit
+ SRCS
+ "tensor_cast.mlir"
+ DATA
+ iree::tools::IreeFileCheck
+ iree::tools::iree-benchmark-module
+ iree::tools::iree-run-mlir
+ iree::tools::iree-translate
+ LABELS
+ "hostonly"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_dylib_embedded-llvm-aot_dylib
+ SRCS
+ "extract_slice.mlir"
+ "tensor_insert_slice.mlir"
+ TARGET_BACKEND
+ "dylib-llvm-aot"
+ DRIVER
+ "dylib"
+ COMPILER_FLAGS
+ "-iree-llvm-link-embedded=true"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_cuda
+ SRCS
+ "extract_slice.mlir"
+ "tensor_insert_slice.mlir"
+ TARGET_BACKEND
+ "cuda"
+ DRIVER
+ "cuda"
+ COMPILER_FLAGS
+ "-iree-llvm-link-embedded=true"
+ LABELS
+ "noasan"
+ "nomsan"
+ "notsan"
+ "noubsan"
+ "requires-gpu-nvidia"
+)
+
+iree_check_single_backend_test_suite(
+ NAME
+ check_vulkan-spirv_vulkan
+ SRCS
+ "extract_slice.mlir"
+ "tensor_insert_slice.mlir"
+ TARGET_BACKEND
+ "vulkan-spirv"
+ DRIVER
+ "vulkan"
+ COMPILER_FLAGS
+ "-iree-llvm-link-embedded=true"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/test/e2e/tensor_ops/extract_slice.mlir b/iree/test/e2e/tensor_ops/extract_slice.mlir
new file mode 100644
index 0000000..c81c955
--- /dev/null
+++ b/iree/test/e2e/tensor_ops/extract_slice.mlir
@@ -0,0 +1,44 @@
+func @extract_slice_strided() {
+ %0 = linalg.init_tensor [500, 750] : tensor<500x750xi32>
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%0 : tensor<500x750xi32>) {
+ ^bb0(%arg0 : i32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = arith.index_cast %2 : index to i32
+ %c750_i32 = arith.constant 750 : i32
+ %5 = arith.muli %4, %c750_i32 : i32
+ %6 = arith.index_cast %3 : index to i32
+ %7 = arith.addi %5, %6 : i32
+ linalg.yield %7 : i32
+ } -> tensor<500x750xi32>
+ %2 = tensor.extract_slice %1[20, 30] [50, 75] [2, 3]
+ : tensor<500x750xi32> to tensor<50x75xi32>
+ %3 = linalg.init_tensor [50, 75] : tensor<50x75xi32>
+ %4 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ outs(%3 : tensor<50x75xi32>) {
+ ^bb0(%arg0 : i32) :
+ %5 = linalg.index 0 : index
+ %6 = linalg.index 1 : index
+ %c20_i32 = arith.constant 20 : i32
+ %c30_i32 = arith.constant 30 : i32
+ %c2_i32 = arith.constant 2 : i32
+ %c3_i32 = arith.constant 3 : i32
+ %7 = arith.index_cast %6 : index to i32
+ %8 = arith.muli %7, %c3_i32 : i32
+ %9 = arith.addi %c30_i32, %8 : i32
+ %10 = arith.index_cast %5 : index to i32
+ %11 = arith.muli %10, %c2_i32 : i32
+ %12 = arith.addi %c20_i32, %11 : i32
+ %c750_i32 = arith.constant 750 : i32
+ %13 = arith.muli %12, %c750_i32 : i32
+ %14 = arith.addi %13, %9 : i32
+ linalg.yield %14 : i32
+ } -> tensor<50x75xi32>
+ check.expect_eq(%2, %4) : tensor<50x75xi32>
+ return
+}
diff --git a/iree/test/e2e/regression/tensor_cast.mlir b/iree/test/e2e/tensor_ops/tensor_cast.mlir
similarity index 100%
rename from iree/test/e2e/regression/tensor_cast.mlir
rename to iree/test/e2e/tensor_ops/tensor_cast.mlir
diff --git a/iree/test/e2e/regression/tensor_insert_slice.mlir b/iree/test/e2e/tensor_ops/tensor_insert_slice.mlir
similarity index 100%
rename from iree/test/e2e/regression/tensor_insert_slice.mlir
rename to iree/test/e2e/tensor_ops/tensor_insert_slice.mlir