Initial implementation of tiling (+ distribution) of LinalgExt ops using `TiledOpInterface` (#6423)
This change adds tiling transformations for LinalgExt ops using the
`TiledOpInterface`. The `linalg_ext.scatter` and `linalg_ext.sort`
(parallel dims only) operations are used as the initial candidates for
tiling. These operations implement the `TiledOpInterface`, which the
tiling transformation uses to tile + distribute the op.
The tiling transformation is controlled in the same way as for Linalg ops,
using `LinalgTilingOptions`, `LinalgTransformationFilter`, and
`LinalgLoopDistributionOptions`.
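
A minimal sketch of how the new patterns are meant to be configured, mirroring the test pass added in `Tiling.cpp` below (the tile sizes and filter marker names here are illustrative):

```cpp
// Sketch only: configure tiling for linalg_ext.scatter, following the test
// pass in Tiling.cpp. Tile sizes and marker names are illustrative.
RewritePatternSet patterns(context);
patterns.add<LinalgExtTilingPattern<ScatterOp>>(
    context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
    linalg::LinalgTransformationFilter(
        Identifier::get("tiling_input", context),
        Identifier::get("tiling_output", context)));
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
```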
diff --git a/iree/compiler/Dialect/LinalgExt/IR/BUILD b/iree/compiler/Dialect/LinalgExt/IR/BUILD
index 2530839..23c4f1e 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/IR/BUILD
@@ -23,6 +23,7 @@
"LinalgExtBase.td",
"LinalgExtOps.td",
"LinalgExtInterfaces.td",
+ "TiledOpInterface.td",
],
include = ["*.td"],
),
@@ -47,13 +48,17 @@
deps = [
":LinalgExtInterfacesGen",
":LinalgExtOpsGen",
+ ":TiledOpInterface",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ControlFlowInterfaces",
+ "@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
],
)
@@ -116,3 +121,41 @@
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
],
)
+
+gentbl_cc_library(
+ name = "TiledOpInterfaceGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "TiledOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "TiledOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = "@llvm-project//mlir:mlir-tblgen",
+ td_file = "TiledOpInterface.td",
+ td_srcs = [
+ "@llvm-project//mlir:OpBaseTdFiles",
+ ],
+)
+
+cc_library(
+ name = "TiledOpInterface",
+ srcs = [
+ "TiledOpInterface.cpp",
+ "TiledOpInterface.cpp.inc",
+ ],
+ hdrs = [
+ "TiledOpInterface.h",
+ "TiledOpInterface.h.inc",
+ ],
+ deps = [
+ ":TiledOpInterfaceGen",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Support",
+ "@llvm-project//mlir:ViewLikeInterface",
+ ],
+)
diff --git a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
index 010ef31..271f30e 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt
@@ -28,13 +28,16 @@
DEPS
::LinalgExtInterfacesGen
::LinalgExtOpsGen
+ ::TiledOpInterface
LLVMSupport
MLIRControlFlowInterfaces
MLIRIR
+ MLIRMemRef
MLIRParser
MLIRSideEffectInterfaces
MLIRStandard
MLIRSupport
+ MLIRTensor
PUBLIC
)
@@ -67,4 +70,32 @@
-gen-dialect-doc LinalgExtDialect.md
)
+iree_tablegen_library(
+ NAME
+ TiledOpInterfaceGen
+ TD_FILE
+ "TiledOpInterface.td"
+ OUTS
+ -gen-op-interface-decls TiledOpInterface.h.inc
+ -gen-op-interface-defs TiledOpInterface.cpp.inc
+)
+
+iree_cc_library(
+ NAME
+ TiledOpInterface
+ HDRS
+ "TiledOpInterface.h"
+ "TiledOpInterface.h.inc"
+ SRCS
+ "TiledOpInterface.cpp"
+ "TiledOpInterface.cpp.inc"
+ DEPS
+ ::TiledOpInterfaceGen
+ LLVMSupport
+ MLIRIR
+ MLIRSupport
+ MLIRViewLikeInterface
+ PUBLIC
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 832b5c5..eb0c002 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -9,10 +9,16 @@
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@@ -49,6 +55,43 @@
}
}
+/// Returns a memref.subview or a tensor.extract_slice based on the type of the
+/// `source`.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ ArrayRef<OpFoldResult> strides) {
+ return TypeSwitch<Type, Value>(source.getType())
+ .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+ return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+ strides);
+ })
+ .Case<MemRefType>([&](MemRefType type) -> Value {
+ return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+ strides);
+ })
+ .Default([&](Type t) { return nullptr; });
+}
+
+Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) {
+ return TypeSwitch<Type, Value>(v.getType())
+ .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
+ return builder.create<tensor::DimOp>(loc, v, dim);
+ })
+ .Case<MemRefType>([&](MemRefType t) -> Value {
+ return builder.create<memref::DimOp>(loc, v, dim);
+ })
+ .Default([&](Type t) { return Value(); });
+}
+
+OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim) {
+ auto t = v.getType().cast<ShapedType>();
+ if (t.isDynamicDim(dim)) {
+ return getDimValue(builder, loc, v, dim);
+ }
+ return builder.getI64IntegerAttr(t.getDimSize(dim));
+}
+
//===----------------------------------------------------------------------===//
// Common methods from Linalg dialect.
//===----------------------------------------------------------------------===//
@@ -123,7 +166,7 @@
return t1.getShape()[dim] == t2.getShape()[dim];
};
- auto indicesType = op.inputs()[1].getType().cast<ShapedType>();
+ auto indicesType = op.getIndicesType();
if (indicesType.getRank() != 2 ||
!indicesType.getElementType().isInteger(32)) {
return op.emitOpError(
@@ -136,7 +179,7 @@
// The first dimension of the indices should match the first dimension of the
// output. They indicate to the number of updates.
- auto updateType = op.inputs()[0].getType().cast<ShapedType>();
+ auto updateType = op.getUpdateType();
if (updateType.getRank() < 1) {
return op.emitOpError("expected update value to be at least rank 1");
}
@@ -144,7 +187,7 @@
return op.emitOpError(
"mismatch in shape of indices and update value at dim#0");
}
- auto originalType = op.outputs()[0].getType().cast<ShapedType>();
+ auto originalType = op.getOriginalType();
// indexDepth + update dims should match to original dims. The first dim of
// update is the number of updates.
if (originalType.getRank() != indexDepth + updateType.getRank() - 1) {
@@ -196,6 +239,66 @@
return success();
}
+SmallVector<StringRef> ScatterOp::getLoopIteratorTypes() {
+ return {getParallelIteratorTypeName()};
+}
+
+SmallVector<Range> ScatterOp::getLoopBounds(OpBuilder &builder) {
+ Location loc = getLoc();
+ Value zero = builder.create<ConstantIndexOp>(loc, 0);
+ Value one = builder.create<ConstantIndexOp>(loc, 1);
+ Value ub = getDimValue(builder, loc, updates(), 0);
+ return {Range{zero, ub, one}};
+}
+
+Operation *ScatterOp::getTiledImplementation(
+ OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets) {
+ assert(outputs.size() == 1 && offsets.size() == 1 && sizes.size() == 1);
+ Location loc = getLoc();
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ auto oneAttr = builder.getI64IntegerAttr(1);
+
+ // Slice of the updates.
+ auto updateRank = getUpdateType().getRank();
+ SmallVector<OpFoldResult> updateOffsets(updateRank, zeroAttr);
+ SmallVector<OpFoldResult> updateSizes(updateRank, zeroAttr);
+ updateOffsets[0] = offsets[0];
+ updateSizes[0] = sizes[0];
+ for (auto dim : llvm::seq<int64_t>(1, updateRank)) {
+ updateSizes[dim] = getDim(builder, loc, updates(), dim);
+ }
+ SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr);
+ Value tiledUpdate = getSlice(builder, loc, updates(), updateOffsets,
+ updateSizes, updateStrides);
+ assert(tiledUpdate && "failed to get slice of update");
+
+ // Slice of indices.
+ auto indicesRank = getIndicesType().getRank();
+ SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr);
+ SmallVector<OpFoldResult> indicesSizes(indicesRank, zeroAttr);
+ indicesOffsets[0] = offsets[0];
+ indicesSizes[0] = sizes[0];
+ for (auto dim : llvm::seq<int64_t>(1, indicesRank)) {
+ indicesSizes[dim] = getDim(builder, loc, indices(), dim);
+ }
+ SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
+ Value tiledIndices = getSlice(builder, loc, indices(), indicesOffsets,
+ indicesSizes, indicesStrides);
+ assert(tiledIndices && "failed to get slice of indices");
+
+ resultOffsets.resize(1);
+ resultOffsets[0].resize(getUpdateType().getRank(), zeroAttr);
+ SmallVector<Type> resultTypes;
+ if (getNumResults()) {
+ resultTypes.push_back(getResultTypes()[0]);
+ }
+ return cast<LinalgExtOp>(getOperation())
+ .clone(builder, loc, resultTypes,
+ ValueRange{tiledUpdate, tiledIndices, outputs[0]});
+}
+
//===----------------------------------------------------------------------===//
// SortOp
//===----------------------------------------------------------------------===//
@@ -213,6 +316,9 @@
if (op.getNumInputs()) {
return op.emitOpError("does not expect to take any inputs");
}
+ if (op.getNumOutputs() == 0) {
+ return op.emitOpError("expected at least one `outs` operand");
+ }
Block &block = op.region().front();
size_t numOutputs = op.getNumOutputs();
@@ -221,7 +327,8 @@
<< 2 * numOutputs << " arguments";
}
- int rank = op.getRank(op.getOutputOperand(0));
+ int64_t rank = op.getOperandRank();
+ ArrayRef<int64_t> shape = op.getOperandShape();
if (rank > 1 && !op.dimensionAttr()) {
return op.emitOpError("dimension must be specified if rank > 1");
}
@@ -233,10 +340,18 @@
return op.emitOpError("dimension must be within (0, ") << rank << "]";
}
- for (auto indexedOperand : llvm::enumerate(op.inputs())) {
+ for (auto indexedOperand : llvm::enumerate(op.outputs())) {
int index = indexedOperand.index();
- Type elemType =
- indexedOperand.value().getType().cast<ShapedType>().getElementType();
+ auto operandType = op.getOperandType(index);
+ if (operandType.getRank() != rank) {
+ return op.emitOpError("expected operand ")
+ << index << " to be rank " << rank << ", same as other operands";
+ }
+ if (operandType.getShape() != shape) {
+ return op.emitOpError("expected operand ")
+ << index << " to have same shape as other operands";
+ }
+ Type elemType = operandType.getElementType();
for (int i : {2 * index, 2 * index + 1}) {
Type argType = block.getArgument(i).getType();
if (argType != elemType) {
@@ -259,6 +374,57 @@
return success();
}
+SmallVector<StringRef> SortOp::getLoopIteratorTypes() {
+ // All loops except the dimension to sort along are parallel.
+ SmallVector<StringRef> iteratorTypes(getOperandRank(),
+ getParallelIteratorTypeName());
+ iteratorTypes[getSortedDimension()] = getReductionIteratorTypeName();
+ return iteratorTypes;
+}
+
+SmallVector<Range> SortOp::getLoopBounds(OpBuilder &builder) {
+ int64_t operandRank = getOperandRank();
+ SmallVector<Range> loopBounds(operandRank);
+ Location loc = getLoc();
+ Value zero = builder.create<ConstantIndexOp>(loc, 0);
+ Value one = builder.create<ConstantIndexOp>(loc, 1);
+ Value source = operand(0);
+ for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
+ loopBounds[dim].offset = zero;
+ loopBounds[dim].size = getDimValue(builder, loc, source, dim);
+ loopBounds[dim].stride = one;
+ }
+ return loopBounds;
+}
+
+Operation *SortOp::getTiledImplementation(
+ OpBuilder &builder, ValueRange outputs, ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<SmallVector<OpFoldResult, 4>> &resultOffsets) {
+ assert(outputs.size() == this->outputs().size());
+ int64_t rank = getOperandRank();
+ assert(offsets.size() == static_cast<size_t>(rank) &&
+ sizes.size() == static_cast<size_t>(rank));
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(rank, oneAttr);
+ Location loc = getLoc();
+ SmallVector<Value> tiledOperands(outputs.size());
+ resultOffsets.resize(outputs.size());
+ for (auto en : llvm::enumerate(outputs)) {
+ tiledOperands[en.index()] =
+ getSlice(builder, getLoc(), en.value(), offsets, sizes, strides);
+ assert(tiledOperands[en.index()] && "failed to get slice of operand");
+ resultOffsets[en.index()].assign(offsets.begin(), offsets.end());
+ }
+ SmallVector<Type, 4> resultTypes;
+ if (getNumResults()) {
+ resultTypes = llvm::to_vector<4>(
+ llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); }));
+ }
+ return cast<LinalgExtOp>(getOperation())
+ .clone(builder, loc, resultTypes, tiledOperands);
+}
+
} // namespace linalg_ext
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
index 1925f54..3100b6f 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h
@@ -7,6 +7,8 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_IR_LINALGEXTOPS_H_
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -17,14 +19,19 @@
namespace mlir {
namespace iree_compiler {
namespace linalg_ext {
-class LinalgExtOp;
+
+/// Returns a `memref.dim` or `tensor.dim` operation that computes the size of
+/// dimension `dim` of `v`.
+Value getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim);
+
+/// Returns a `memref.dim` or `tensor.dim` operation that computes the size of
+/// dimension `dim` of `v`. If the size is static, returns it as an
+/// `IntegerAttr` instead.
+OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, int64_t dim);
} // namespace linalg_ext
} // namespace iree_compiler
} // namespace mlir
-#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
-
#define GET_OP_CLASSES
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index e76c441..d691fed 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -9,6 +9,7 @@
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
+include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -34,7 +35,9 @@
// Non-structured ops
//===----------------------------------------------------------------------===//
-def LinalgExt_ScatterOp : LinalgExt_Op<"scatter"> {
+def LinalgExt_ScatterOp : LinalgExt_Op<"scatter",
+ [DeclareOpInterfaceMethods<TiledOpInterface,
+ ["getTiledImplementation"]>]> {
let summary = "Scatter operator";
let description = [{
Based on XLA operation semantics, takes two `inputs` (`update` and
@@ -88,14 +91,26 @@
return getInputOperand(0)->get();
}
+ ShapedType getUpdateType() {
+ return updates().getType().cast<ShapedType>();
+ }
+
Value indices() {
return getInputOperand(1)->get();
}
+ ShapedType getIndicesType() {
+ return indices().getType().cast<ShapedType>();
+ }
+
Value original() {
return getOutputOperand(0)->get();
}
+ ShapedType getOriginalType() {
+ return original().getType().cast<ShapedType>();
+ }
+
int64_t getUpdateSliceRank() {
return updates().getType().cast<ShapedType>().getRank() - 1;
}
@@ -106,7 +121,9 @@
}];
}
-def LinalgExt_SortOp : LinalgExt_Op<"sort"> {
+def LinalgExt_SortOp : LinalgExt_Op<"sort",
+ [DeclareOpInterfaceMethods<TiledOpInterface,
+ ["getTiledImplementation"]>]> {
let summary = "Sort operator";
let description = [{
Based on XLA operation semantics, sorts the given `operands` at the given
@@ -116,7 +133,7 @@
}];
// Define arguments and results like linalg.generic op. The attribute has the
- // same definision as mhlo.sort::dimension. If the rank is greater than 1,
+ // same definition as mhlo.sort::dimension. If the rank is greater than 1,
// the attribute must be set. If the rank is exactly 1, the dimension is
// optional.
let arguments = (ins Variadic<AnyType>:$inputs,
@@ -131,6 +148,27 @@
custom<LinalgExtOutsList>($outputs, type($outputs))
$region (`->` type($results)^)?
}];
+ let extraClassDeclaration = [{
+ Value operand(int index) {
+ return outputs()[index];
+ }
+ ShapedType getOperandType(int index) {
+ return operand(index).getType().cast<ShapedType>();
+ }
+ int64_t getOperandRank() {
+ return getOperandType(0).getRank();
+ }
+ ArrayRef<int64_t> getOperandShape() {
+ return getOperandType(0).getShape();
+ }
+ uint64_t getSortedDimension() {
+ uint64_t sortedDim = 0;
+ if (auto setSortedDim = dimension()) {
+ sortedDim = *setSortedDim;
+ }
+ return sortedDim;
+ }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
new file mode 100644
index 0000000..b65d5e8
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp
@@ -0,0 +1,17 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace linalg_ext {
+
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.cpp.inc"
+
+}  // namespace linalg_ext
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h
new file mode 100644
index 0000000..164531b
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h
@@ -0,0 +1,19 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+#define IREE_COMPILER_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LLVM.h"
+
+/// Include the ODS generated interface header files.
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h.inc"
+
+#endif // IREE_COMPILER_DIALECT_LINALGEXT_IR_TILEDOPINTERFACE_H_
diff --git a/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
new file mode 100644
index 0000000..697acae
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.td
@@ -0,0 +1,61 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
+#define IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def TiledOpInterface : OpInterface<"TiledOpInterface"> {
+ let description = [{
+ Interface that allows operations to expose the information needed to
+ tile them (similar to LinalgOp, but without requiring access to
+ indexing maps).
+ }];
+ let cppNamespace = "::mlir::iree_compiler::linalg_ext";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of `StringRef`s that describe the number of
+ loops and the iterator types of the operation. The list is
+ expected to use
+ `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
+ from MLIR Structured Op Utils.
+ }],
+ /*retTy=*/"SmallVector<StringRef>",
+ /*methodName=*/"getLoopIteratorTypes"
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns a list of ranges that describe the loop bounds and
+ step for the loops of the operation.
+ }],
+ /*retTy=*/"SmallVector<Range>",
+ /*methodName=*/"getLoopBounds",
+ /*args=*/(ins "OpBuilder &":$b)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Generates a tiled version of the operation given the tile
+ size for the loops.
+ }],
+ /*retTy=*/"Operation *",
+ /*methodName=*/"getTiledImplementation",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "ValueRange ":$outputs,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultOffsets),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return nullptr;
+ }]
+ >
+ ];
+}
+
+#endif  // IREE_DIALECT_LINALGEXT_TILEDOPINTERFACE
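
To make the interface contract concrete, here is a sketch of how a tiling driver queries it, condensed from `tileLinalgExtOp` in `Tiling.cpp` below; `op` and `builder` are assumed to be in scope:

```cpp
// Sketch only: query the interface before generating tiled loops.
if (auto tilableOp = dyn_cast<TiledOpInterface>(op)) {
  SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
  SmallVector<Range> loopBounds = tilableOp.getLoopBounds(builder);
  // ... generate the loop nest, then call
  // tilableOp.getTiledImplementation(...) for the innermost body.
}
```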
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 13f9acd..86db7f0 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -26,6 +26,34 @@
// -----
+func @sort_mismatch_rank(%arg0: tensor<?x?xi32>, %arg1: tensor<?xf32>)
+ -> (tensor<?x?xi32>, tensor<?xf32>) {
+ // expected-error @+1 {{expected operand 1 to be rank 2, same as other operands}}
+ %0:2 = linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %1 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %1 : i1
+ } -> tensor<?x?xi32>, tensor<?xf32>
+ return %0#0, %0#1 : tensor<?x?xi32>, tensor<?xf32>
+}
+
+// -----
+
+func @sort_mismatch_shape(%arg0: tensor<?xi32>, %arg1: tensor<42xf32>)
+ -> (tensor<?xi32>, tensor<42xf32>) {
+ // expected-error @+1 {{expected operand 1 to have same shape as other operands}}
+ %0:2 = linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?xi32>, tensor<42xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %1 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %1 : i1
+ } -> tensor<?xi32>, tensor<42xf32>
+ return %0#0, %0#1 : tensor<?xi32>, tensor<42xf32>
+}
+
+// -----
+
func @scatter_mixed_tensor_memref(
%update : memref<?x?xf32>, %indices : tensor<?x1xi32>,
%original : tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 12dc335..6d73fbf 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -32,6 +32,44 @@
// -----
+func @sort_multi_result_tensor(
+ %arg0: tensor<?x?xi32>, %arg1: tensor<?x?xf32>)
+ -> (tensor<?x?xi32>, tensor<?x?xf32>) {
+ %0:2 = linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %1 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %1 : i1
+ } -> tensor<?x?xi32>, tensor<?x?xf32>
+ return %0#0, %0#1 : tensor<?x?xi32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @sort_multi_result_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]]:2 = linalg_ext.sort dimension(0)
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]]
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func @sort_multi_result_memref(
+ %arg0: memref<?x?xi32>, %arg1: memref<?x?xf32>) {
+ linalg_ext.sort dimension(0)
+ outs(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %1 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %1 : i1
+ }
+ return
+}
+// CHECK-LABEL: func @sort_multi_result_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?xi32>
+// CHECK-SAME: %[[ARG1:.+]]: memref<?x?xf32>
+// CHECK: linalg_ext.sort dimension(0)
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG1]]
+
+// -----
+
func @scatter_tensor_dynamic(
%original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
%update: tensor<?x?xf32>) -> tensor<?x?xf32> {
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
index 3139752..e48f8ef 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/BUILD
@@ -31,24 +31,30 @@
name = "Transforms",
srcs = [
"ConvertToLoops.cpp",
- "PassDetail.h",
"Passes.cpp",
+ "Tiling.cpp",
],
hdrs = [
+ "PassDetail.h",
"Passes.h",
"Passes.h.inc",
+ "Transforms.h",
],
deps = [
":PassesIncGen",
+ "//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/LinalgExt/IR",
"@llvm-project//llvm:Support",
+ "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
+ "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
+ "@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
index b0ae75c..e0edb4c 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/CMakeLists.txt
@@ -23,23 +23,29 @@
NAME
Transforms
HDRS
+ "PassDetail.h"
"Passes.h"
"Passes.h.inc"
+ "Transforms.h"
SRCS
"ConvertToLoops.cpp"
- "PassDetail.h"
"Passes.cpp"
+ "Tiling.cpp"
DEPS
::PassesIncGen
LLVMSupport
+ MLIRAffine
MLIRIR
MLIRLinalg
+ MLIRLinalgTransforms
MLIRMemRef
MLIRPass
MLIRSCF
MLIRStandard
MLIRSupport
+ MLIRTensor
MLIRTransforms
+ iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::LinalgExt::IR
PUBLIC
)
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
index df5ed5a..d84b97e 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h
@@ -13,6 +13,8 @@
namespace iree_compiler {
namespace linalg_ext {
+std::unique_ptr<OperationPass<FuncOp>> createLinalgExtTilingPass();
+
std::unique_ptr<OperationPass<FuncOp>> createLinalgExtToLoopsPass();
void registerLinalgExtPasses();
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
index 41cc182..e221e33 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Passes.td
@@ -15,4 +15,10 @@
let constructor = "mlir::iree_compiler::linalg_ext::createLinalgExtToLoopsPass()";
}
+def LinalgExtTiling :
+ Pass<"iree-linalg-ext-tile", "FuncOp"> {
+ let summary = "Test pass for tiling LinalgExt ops using TiledOpInterface";
+ let constructor = "mlir::iree_compiler::linalg_ext::createLinalgExtTilingPass()";
+}
+
#endif // IREE_DIALECT_LINALGEXT_PASSES
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
new file mode 100644
index 0000000..7b4be2c
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Tiling.cpp
@@ -0,0 +1,433 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
+#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/PassDetail.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
+#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace linalg_ext {
+
+//===----------------------------------------------------------------------===//
+// Utility methods for tiling a linalg_ext operation that implements a
+// TiledOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Returns failure if the options are unsupported.
+static LogicalResult verifySupportedTilingOptions(
+ PatternRewriter &rewriter, Operation *op,
+ const linalg::LinalgTilingOptions &options) {
+ if (!options.interchangeVector.empty()) {
+ return rewriter.notifyMatchFailure(op,
+ "unsupported interchange during tiling");
+ }
+ if (options.paddingValueComputationFunction) {
+ return rewriter.notifyMatchFailure(op, "unsupported tile + pad option");
+ }
+ if (options.loopType != linalg::LinalgTilingLoopType::Loops) {
+ return rewriter.notifyMatchFailure(op,
+ "only tiling with scf.for is supported");
+ }
+ if (options.distribution) {
+ if (llvm::any_of(options.distribution->distributionMethod,
+ [](linalg::DistributionMethod method) {
+ return method != linalg::DistributionMethod::Cyclic;
+ })) {
+ return rewriter.notifyMatchFailure(op,
+ "only cyclic distibution is allowed");
+ }
+ }
+ return success();
+}
+
+/// Converts a `Value` to an `OpFoldResult` by extracting the constant value
+/// if the value is defined by a constant op.
+static OpFoldResult getOpFoldResult(Value value) {
+ IntegerAttr::ValueType attr;
+ if (matchPattern(value, m_ConstantInt(&attr))) {
+ return IntegerAttr::get(value.getType(), attr);
+ }
+ return value;
+}
+static SmallVector<OpFoldResult, 4> getOpFoldResult(ArrayRef<Value> values) {
+ return llvm::to_vector<4>(llvm::map_range(
+ values, [](Value value) { return getOpFoldResult(value); }));
+}
+
+/// Converts an `OpFoldResult` to a `Value` by building a constant op if
+/// the `OpFoldResult` is an `IntegerAttr`.
+static Value getValue(OpBuilder &builder, Location loc,
+ OpFoldResult valueOrAttr) {
+ if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+ return builder.create<ConstantIndexOp>(loc,
+ attr.cast<IntegerAttr>().getInt());
+ }
+ return valueOrAttr.get<Value>();
+}
+
+/// Returns true if the loop is untiled, i.e., the tile size is statically
+/// zero. It is assumed that a `Value` defined by a constant op has already
+/// been converted to an `IntegerAttr` of that value, so it suffices to
+/// return true if this is an attribute with a zero value.
+static bool isUntiledLoop(OpFoldResult valueOrAttr) {
+ auto attr = valueOrAttr.dyn_cast<Attribute>();
+ return attr && attr.cast<IntegerAttr>().getValue() == 0;
+}
+
+/// Generates the tiled loops and the body by invoking the interface methods of
+/// TiledOpInterface.
+/// - `outputs` are the operands to use for outputs of the tiled operation.
+/// - `tileSizes` are the tile sizes specified for all loops of the operation.
+///   If a loop is to be untiled, its tile size is set to 0.
+/// - `iteratorTypes` are the iterator types of the loops of the op, as
+///   returned by the TiledOpInterface.
+/// - `loopBounds` are the bounds of all the loops of the op returned by the
+/// TiledOpInterface.
+/// - `loopDepth` is the current loop depth being processed.
+/// - `offsets` are the `Value`s that represent the position of the tile being
+/// operated on. The offsets are computed as the tiled loops are being
+/// generated.
+/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
+/// distributed loops. It is a stack, and once an entry at the top of the
+/// stack is used for distribution it is popped before processing the inner
+/// loops.
+static FailureOr<TiledOp> tileLinalgExtOpImpl(
+ OpBuilder &builder, TiledOpInterface op, ValueRange outputs,
+ MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
+ ArrayRef<Range> loopBounds, unsigned loopDepth,
+ SmallVectorImpl<OpFoldResult> &offsets,
+ ArrayRef<linalg::ProcInfo> distributionInfo) {
+ Location loc = op.getLoc();
+ // If this is the innermost loop, then generate the tiled implementation of
+ // the op by invoking the TiledOpInterface methods.
+ if (loopDepth == tileSizes.size()) {
+ SmallVector<SmallVector<OpFoldResult, 4>> resultOffsets;
+ Operation *tiledOp = op.getTiledImplementation(builder, outputs, offsets,
+ tileSizes, resultOffsets);
+ if (!tiledOp) {
+ return static_cast<LogicalResult>(
+ op.emitOpError("failed to get tiled implementation"));
+ }
+ assert(tiledOp->getNumResults() == 0 ||
+ (resultOffsets.size() == tiledOp->getNumResults()));
+ TiledOp ret;
+ ret.op = tiledOp;
+
+ // If the operation has results, then the result of the tiled operation is
+ // to be inserted into the `initValues` and returned.
+ if (tiledOp->getNumResults()) {
+ SmallVector<Value> results;
+ results.reserve(tiledOp->getNumResults());
+ for (auto en : llvm::enumerate(tiledOp->getResults())) {
+ Value result = en.value();
+ ArrayRef<OpFoldResult> offsets(resultOffsets[en.index()]);
+ auto resultType = result.getType().cast<ShapedType>();
+ auto oneAttr = builder.getI64IntegerAttr(1);
+ SmallVector<OpFoldResult> strides(resultType.getRank(), oneAttr);
+ auto sizes = llvm::to_vector<4>(llvm::map_range(
+ llvm::seq<int64_t>(0, resultType.getRank()),
+ [&](int64_t dim) { return getDim(builder, loc, result, dim); }));
+ Value insert = builder.create<tensor::InsertSliceOp>(
+ loc, result, outputs[en.index()], offsets, sizes, strides);
+ results.push_back(insert);
+ }
+ std::swap(ret.results, results);
+ }
+ return ret;
+ }
+
+ // If the tile size at this depth is zero, the loop is untiled; use the
+ // full loop bound as the tile size and recurse to the next loop depth.
+ if (isUntiledLoop(tileSizes[loopDepth])) {
+ auto zeroAttr = builder.getI64IntegerAttr(0);
+ offsets.push_back(zeroAttr);
+ assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
+ "expected loop bounds to have lower bound of zero");
+ tileSizes[loopDepth] = getOpFoldResult(loopBounds[loopDepth].size);
+ return tileLinalgExtOpImpl(builder, op, outputs, tileSizes, iteratorTypes,
+ loopBounds, loopDepth + 1, offsets,
+ distributionInfo);
+ }
+
+ // Generate an scf.for for the current loop depth.
+ Value lb = loopBounds[loopDepth].offset;
+ Value ub = loopBounds[loopDepth].size;
+ if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
+ return static_cast<LogicalResult>(
+ op.emitOpError("expected stride to be 1"));
+ }
+ Value step = getValue(builder, loc, tileSizes[loopDepth]);
+
+ // Update lb, ub and step for cyclic distribution.
+ if (!distributionInfo.empty() &&
+ iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
+ linalg::updateBoundsForCyclicDistribution(
+ builder, loc, distributionInfo.front().procId,
+ distributionInfo.front().nprocs, lb, ub, step);
+ distributionInfo = distributionInfo.drop_front();
+ }
+ FailureOr<TiledOp> innerReturnValue;
+ bool isBufferTiling = op->getNumResults() == 0;
+ ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
+ auto forOp = builder.create<scf::ForOp>(
+ loc, lb, ub, step, initValues,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
+ offsets.push_back(iv);
+ auto affineMaps = AffineMap::inferFromExprList({ArrayRef<AffineExpr>{
+ b.getAffineSymbolExpr(0),
+ b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0];
+ // Similar to linalg tiling, the tile size is the min(tileSizes, ub -
+ // iv) to account for cases where tile size does not divide (ub - lb)
+ // exactly.
+ Value inBoundsTileSize = b.create<AffineMinOp>(
+ loc, affineMaps,
+ ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub});
+ tileSizes[loopDepth] = getOpFoldResult(inBoundsTileSize);
+ // Recursively proceed to generate the tiled loop for the next level.
+ innerReturnValue = tileLinalgExtOpImpl(
+ b, op, (isBufferTiling ? outputs : args), tileSizes, iteratorTypes,
+ loopBounds, loopDepth + 1, offsets, distributionInfo);
+ if (failed(innerReturnValue)) return;
+ b.create<scf::YieldOp>(loc, innerReturnValue->results);
+ });
+ if (failed(innerReturnValue)) {
+ return innerReturnValue;
+ }
+ innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
+ forOp.getOperation());
+ innerReturnValue->results = forOp.getResults();
+ return innerReturnValue;
+}
+
+FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, LinalgExtOp op,
+ const linalg::LinalgTilingOptions &options) {
+ TiledOpInterface tilableOp = dyn_cast<TiledOpInterface>(op.getOperation());
+ if (!tilableOp) return TiledOp{};
+
+ SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+ SmallVector<Value, 4> tileSizesVals =
+ options.tileSizeComputationFunction(b, tilableOp.getOperation());
+ auto zeroAttr = b.getI64IntegerAttr(0);
+
+ // For the actual tile sizes used, convert a `Value` defined by a constant
+ // zero op into a zero integer attribute. Currently, if the iterator type is
+ // not "parallel", the tile size is required to be zero as well.
+ auto tileSizes = getOpFoldResult(tileSizesVals);
+ tileSizes.resize(iteratorTypes.size(), zeroAttr);
+ for (auto en : llvm::enumerate(iteratorTypes)) {
+ if (en.value() == getParallelIteratorTypeName()) continue;
+ if (!isUntiledLoop(tileSizes[en.index()])) {
+ return static_cast<LogicalResult>(op.emitOpError(
+ "unimplemented tiling of non-parallel loop iterator type"));
+ }
+ }
+
+ // Trivial early exit case of tile sizes being zero for all parallel loops.
+ if (llvm::all_of(tileSizes, isUntiledLoop)) {
+ return TiledOp{op.getOperation(), {}, {}};
+ }
+
+ SmallVector<Range> loopBounds = tilableOp.getLoopBounds(b);
+ SmallVector<linalg::ProcInfo> distributionInfo;
+ // If the tiled loops are distributed, get the proc_id and nprocs for the
+ // distributed loops. First collect the parallel loops by iterating over the
+ // tileSizes and getting the loops that are distributed, i.e.,
+ // - parallel, i.e. iteratorTypes is "parallel"
+ // - tiled, i.e. tileSize != 0
+ if (options.distribution) {
+ SmallVector<Range> distributedLoopRange;
+ for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
+ if (isUntiledLoop(tileSizes[i])) continue;
+ if (iteratorTypes[i] != getParallelIteratorTypeName()) continue;
+ distributedLoopRange.push_back(loopBounds[i]);
+ }
+ distributionInfo =
+ options.distribution->procInfo(b, op.getLoc(), distributedLoopRange);
+ }
+
+ SmallVector<OpFoldResult> offsets;
+ return tileLinalgExtOpImpl(b, tilableOp, op.outputs(), tileSizes,
+ iteratorTypes, loopBounds, 0, offsets,
+ distributionInfo);
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns for tiling LinalgExtOps.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Base pattern for tiling LinalgExtOps.
+struct LinalgExtBaseTilingPattern : public RewritePattern {
+ LinalgExtBaseTilingPattern(StringRef opName, MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : RewritePattern(opName, benefit, context),
+ filter(filter),
+ options(options) {}
+
+ LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+ TiledOp &result) const;
+
+ private:
+ /// LinalgTransformMarker handles special attribute manipulations.
+ linalg::LinalgTransformationFilter filter;
+ /// Options to control tiling.
+ linalg::LinalgTilingOptions options;
+};
+
+template <typename OpTy>
+struct LinalgExtTilingPattern : public LinalgExtBaseTilingPattern {
+ LinalgExtTilingPattern(MLIRContext *context,
+ linalg::LinalgTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : LinalgExtBaseTilingPattern(OpTy::getOperationName(), context, options,
+ filter, benefit) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ TiledOp tiledOp;
+ // Check for failure.
+ if (failed(LinalgExtBaseTilingPattern::matchAndRewriteBase(op, rewriter,
+ tiledOp))) {
+ return failure();
+ }
+ // Check for do-nothing case.
+ if (!tiledOp.op) return failure();
+ if (tiledOp.op != op) {
+ if (tiledOp.results.empty()) {
+ rewriter.eraseOp(op);
+ } else {
+ rewriter.replaceOp(op, tiledOp.results);
+ }
+ }
+ return success();
+ }
+};
+} // namespace
+
+LogicalResult LinalgExtBaseTilingPattern::matchAndRewriteBase(
+ Operation *op, PatternRewriter &rewriter, TiledOp &result) const {
+ auto linalgExtOp = dyn_cast<LinalgExtOp>(op);
+ if (!linalgExtOp) return failure();
+ if (failed(filter.checkAndNotify(rewriter, op))) return failure();
+ if (failed(verifySupportedTilingOptions(rewriter, op, options))) {
+ return failure();
+ }
+
+ FailureOr<TiledOp> res = tileLinalgExtOp(rewriter, linalgExtOp, options);
+ if (failed(res)) return res;
+ result = *res;
+ if (result.op) {
+ filter.replaceLinalgTransformationFilter(rewriter, result.op);
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Test pass for tiling Linalg Ext ops
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LinalgExtTilingPass : public LinalgExtTilingBase<LinalgExtTilingPass> {
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry
+ .insert<AffineDialect, IREE::Flow::FlowDialect, linalg::LinalgDialect,
+ memref::MemRefDialect, StandardOpsDialect,
+ tensor::TensorDialect, scf::SCFDialect>();
+ }
+
+ void runOnOperation() override;
+};
+} // namespace
+
+template <typename OpTy>
+static Value buildFlowWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
+ return b.template create<OpTy>(b.getInsertionPoint()->getLoc(), dim);
+}
+
+void LinalgExtTilingPass::runOnOperation() {
+ FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp.getContext();
+ RewritePatternSet patterns(context);
+ patterns.add<LinalgExtTilingPattern<ScatterOp>>(
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("tiling_input", context),
+ Identifier::get("tiling_output", context)));
+ patterns.add<LinalgExtTilingPattern<ScatterOp>>(
+ context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("no_tiling_input", context),
+ Identifier::get("no_tiling_output", context)));
+ patterns.add<LinalgExtTilingPattern<SortOp>>(
+ context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("outer_reduce_input", context),
+ Identifier::get("outer_reduce_output", context)));
+ patterns.add<LinalgExtTilingPattern<SortOp>>(
+ context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("inner_reduce_input", context),
+ Identifier::get("inner_reduce_output", context)));
+
+ static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+ [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+ auto numParallelDims = parallelLoopRanges.size();
+
+ SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
+ for (size_t dim = 0; dim < numParallelDims; ++dim) {
+ procInfo[numParallelDims - dim - 1] = {
+ buildFlowWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupIDOp>(
+ builder, dim),
+ buildFlowWorkgroupInfoOp<IREE::Flow::DispatchWorkgroupCountOp>(
+ builder, dim)};
+ }
+ return procInfo;
+ },
+ {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+ linalg::DistributionMethod::Cyclic},
+ DenseMap<StringRef,
+ std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+
+ patterns
+ .add<LinalgExtTilingPattern<ScatterOp>, LinalgExtTilingPattern<SortOp>>(
+ context,
+ linalg::LinalgTilingOptions()
+ .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
+ .setDistributionOptions(workgroupDistributionOptions),
+ linalg::LinalgTransformationFilter(
+ Identifier::get("distribute_input", context),
+ Identifier::get("distribute_output", context)));
+
+ if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<FuncOp>> createLinalgExtTilingPass() {
+ return std::make_unique<LinalgExtTilingPass>();
+}
+
+} // namespace linalg_ext
+} // namespace iree_compiler
+} // namespace mlir
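
For completeness, a sketch of how the new test pass could be run programmatically, assuming a standard `PassManager` setup (`context` and `module` are hypothetical; the `iree-opt -iree-linalg-ext-tile` invocation in the tests below is equivalent):

```cpp
// Sketch only: run the LinalgExt tiling test pass over a module.
PassManager pm(context);
pm.addNestedPass<FuncOp>(linalg_ext::createLinalgExtTilingPass());
if (failed(pm.run(module))) {
  // handle the failure
}
```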
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
new file mode 100644
index 0000000..08e1a81
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h
@@ -0,0 +1,37 @@
+// Copyright 2021 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+#define IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
+
+#include "iree/compiler/Dialect/LinalgExt/IR/TiledOpInterface.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+
+namespace mlir {
+namespace iree_compiler {
+namespace linalg_ext {
+
+/// Structure to represent the result of tiling operation.
+struct TiledOp {
+ /// Tiled op.
+ Operation *op;
+ /// Loops generated during tiling.
+ SmallVector<Operation *> loops;
+ /// Values that are replacements for the untiled operations.
+ SmallVector<Value> results;
+};
+
+/// Main entry point for tiling LinalgExtOps using TiledOpInterface. If the
+/// `op` does not implement the `TiledOpInterface`, returns an empty `TiledOp{}`.
+FailureOr<TiledOp> tileLinalgExtOp(OpBuilder &b, LinalgExtOp op,
+ const linalg::LinalgTilingOptions &options);
+
+} // namespace linalg_ext
+} // namespace iree_compiler
+} // namespace mlir
+
+#endif // IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_TRANSFORMS_H_
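
A sketch of how a client pattern is expected to consume this entry point, condensed from `LinalgExtTilingPattern::matchAndRewrite` in `Tiling.cpp`; `rewriter`, `op`, `linalgExtOp`, and `options` are assumed to be in scope:

```cpp
// Sketch only: tile an op and replace/erase the original based on whether
// the tiled op produced tensor results (tensor vs. buffer semantics).
FailureOr<TiledOp> tiled = tileLinalgExtOp(rewriter, linalgExtOp, options);
if (failed(tiled)) return failure();
if (tiled->op && tiled->op != op) {
  if (tiled->results.empty())
    rewriter.eraseOp(op);                    // buffer semantics: no results
  else
    rewriter.replaceOp(op, tiled->results);  // tensor semantics
}
return success();
```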
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD b/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD
index a2aef95..e12869b 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/BUILD
@@ -18,6 +18,7 @@
srcs = enforce_glob(
[
"convert_to_loops.mlir",
+ "tiling.mlir",
],
include = ["*.mlir"],
),
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt b/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
index 931e320..0c94851 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/CMakeLists.txt
@@ -15,6 +15,7 @@
lit
SRCS
"convert_to_loops.mlir"
+ "tiling.mlir"
DATA
iree::tools::IreeFileCheck
iree::tools::iree-opt
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
new file mode 100644
index 0000000..b870530
--- /dev/null
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -0,0 +1,422 @@
+// RUN: iree-opt -iree-linalg-ext-tile -split-input-file %s | IreeFileCheck %s
+
+func @scatter_tiling(
+ %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+ %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg_ext.scatter
+ {__internal_linalg_transform__ = "tiling_input"}
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @scatter_tiling(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]]
+// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], 1]
+// CHECK: %[[SCATTER_TILE:.+]] = linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "tiling_output"
+// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+// CHECK-SAME: outs(%[[INIT]]
+// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C0]]
+// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C1]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
+// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_tiling_memref(
+ %original: memref<?x?xf32>, %indices: memref<?x1xi32>,
+ %update : memref<?x?xf32>) {
+ linalg_ext.scatter
+ {__internal_linalg_transform__ = "tiling_input"}
+ ins(%update, %indices : memref<?x?xf32>, memref<?x1xi32>)
+ outs(%original : memref<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ }
+ return
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @scatter_tiling_memref(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]]
+// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]]
+// CHECK: %[[UPDATE_SLICE:.+]] = memref.subview %[[UPDATES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[INDEX_SLICE:.+]] = memref.subview %[[INDICES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], 1]
+// CHECK: linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "tiling_output"
+// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+
+// -----
+
+func @scatter_tiling_distribution(
+ %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+ %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg_ext.scatter
+ {__internal_linalg_transform__ = "distribute_input"}
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @scatter_tiling_distribution(
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]]
+// CHECK-DAG: %[[ID:.+]] = flow.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[COUNT:.+]] = flow.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[ID]]]
+// CHECK-DAG: %[[STEP:.+]] = affine.apply #[[MAP0]]()[%[[COUNT]]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[OFFSET]] to %[[D0]] step %[[STEP]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]]
+// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], 1]
+// CHECK: %[[SCATTER_TILE:.+]] = linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "distribute_output"
+// CHECK-SAME: ins(%[[UPDATE_SLICE]], %[[INDEX_SLICE]]
+// CHECK-SAME: outs(%[[INIT]]
+// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C0]]
+// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[SCATTER_TILE]], %[[C1]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0]
+// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @scatter_no_tiling(
+ %original: tensor<?x?xf32>, %indices: tensor<?x1xi32>,
+ %update : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg_ext.scatter
+ {__internal_linalg_transform__ = "no_tiling_input"}
+ ins(%update, %indices : tensor<?x?xf32>, tensor<?x1xi32>)
+ outs(%original : tensor<?x?xf32>) {
+ ^bb0(%arg1: f32, %arg2: f32):
+ %1 = addf %arg1, %arg2 : f32
+ linalg_ext.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK: func @scatter_no_tiling
+// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor<?x1xi32>
+// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = linalg_ext.scatter
+// CHECK-SAME: __internal_linalg_transform__ = "no_tiling_output"
+// CHECK-SAME: ins(%[[UPDATES]], %[[INDICES]]
+// CHECK-SAME: outs(%[[ORIGINAL]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_1d(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+ %0 = linalg_ext.sort
+ {__internal_linalg_transform__ = "outer_reduce_input"}
+ outs(%arg0 : tensor<?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %0 = cmpi sgt, %arg2, %arg3 : i32
+ linalg_ext.yield %0 : i1
+ } -> tensor<?xi32>
+ return %0 : tensor<?xi32>
+}
+// CHECK: func @sort_1d(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<?xi32>
+// CHECK: %[[RESULT:.+]] = linalg_ext.sort
+// CHECK-SAME: {__internal_linalg_transform__ = "outer_reduce_output"}
+// CHECK-SAME: outs(%[[OPERAND]] :
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_2d(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg_ext.sort dimension(1)
+ {__internal_linalg_transform__ = "inner_reduce_input"}
+ outs(%arg0 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %0 = cmpi sgt, %arg2, %arg3 : i32
+ linalg_ext.yield %0 : i1
+ } -> tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @sort_2d(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[SORT_TILE:.+]] = linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND_SLICE]]
+// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[SORT_TILE]], %[[C0]]
+// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[SORT_TILE]], %[[C1]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][%[[IV]], 0]
+// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @sort_2d_inner_parallel(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg_ext.sort dimension(0)
+ {__internal_linalg_transform__ = "outer_reduce_input"}
+ outs(%arg0 : tensor<?x?xi32>) {
+ ^bb0(%arg2: i32, %arg3: i32): // no predecessors
+ %0 = cmpi sgt, %arg2, %arg3 : i32
+ linalg_ext.yield %0 : i1
+ } -> tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func @sort_2d_inner_parallel(
+// CHECK-SAME: %[[OPERAND:.+]]: tensor<?x?xi32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 20 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]]
+// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: %[[SORT_TILE:.+]] = linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND_SLICE]]
+// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[SORT_TILE]], %[[C0]]
+// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[SORT_TILE]], %[[C1]]
+// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][0, %[[IV]]]
+// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
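+// Multi-result sort on tensors: both outs operands are sliced with the same
+// offsets and sizes, and each tiled result is inserted back into its own
+// iteration argument before being yielded.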
+func @sort_2d_multi_result(
+ %arg0: tensor<?x?xi32>, %arg1: tensor<?x?xf32>)
+ -> (tensor<?x?xi32>, tensor<?x?xf32>) {
+ %0:2 = linalg_ext.sort dimension(1)
+ {__internal_linalg_transform__ = "inner_reduce_input"}
+ outs(%arg0, %arg1 : tensor<?x?xi32>, tensor<?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %1 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %1 : i1
+ } -> tensor<?x?xi32>, tensor<?x?xf32>
+ return %0#0, %0#1 : tensor<?x?xi32>, tensor<?x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @sort_2d_multi_result(
+// CHECK-SAME: %[[OPERAND1:.+]]: tensor<?x?xi32>
+// CHECK-SAME: %[[OPERAND2:.+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]]
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]])
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT2]][%[[IV]], 0]
+// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]]
+// CHECK: %[[SORT_TILE:.+]]:2 = linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[SORT_TILE]]#0, %[[C0]]
+// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[SORT_TILE]]#0, %[[C1]]
+// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_TILE]]#0 into %[[INIT1]][%[[IV]], 0]
+// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]]
+// CHECK-DAG: %[[SLICE_D0_0:.+]] = tensor.dim %[[SORT_TILE]]#1, %[[C0]]
+// CHECK-DAG: %[[SLICE_D1_0:.+]] = tensor.dim %[[SORT_TILE]]#1, %[[C1]]
+// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_TILE]]#1 into %[[INIT2]][%[[IV]], 0]
+// CHECK-SAME: [%[[SLICE_D0_0]], %[[SLICE_D1_0]]]
+// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]]
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
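+// Multi-result sort on memrefs: slices are taken with memref.subview and the
+// tiled op updates the buffers in place, so no iter_args or scf.yield appear.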
+func @sort_2d_multi_result_memref(
+ %arg0: memref<?x?xi32>, %arg1: memref<?x?xf32>) {
+ linalg_ext.sort dimension(0)
+ {__internal_linalg_transform__ = "outer_reduce_input"}
+ outs(%arg0, %arg1 : memref<?x?xi32>, memref<?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %0 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK: func @sort_2d_multi_result_memref(
+// CHECK-SAME: %[[OPERAND1:.+]]: memref<?x?xi32>
+// CHECK-SAME: %[[OPERAND2:.+]]: memref<?x?xf32>
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 20 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]]
+// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]]
+// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][0, %[[IV]]]
+// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]]
+// CHECK: linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+
+// -----
+
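+// Tiling + distribution of a 3-d sort along the middle dimension: the two
+// parallel dimensions are tiled (tile sizes 10 and 30) and the resulting
+// loops are distributed over flow.dispatch.workgroup id/count.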
+func @sort_3d_multi_result_distribute(
+ %arg0: tensor<?x?x?xi32>, %arg1 : tensor<?x?x?xf32>)
+ -> (tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+ %0, %1 = linalg_ext.sort dimension(1)
+ {__internal_linalg_transform__ = "distribute_input"}
+ outs(%arg0, %arg1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %2 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %2 : i1
+ } -> tensor<?x?x?xi32>, tensor<?x?x?xf32>
+ return %0, %1 : tensor<?x?x?xi32>, tensor<?x?x?xf32>
+}
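+// MAP0/MAP2 compute the distributed lower bound and step (id * tilesize and
+// count * tilesize); MAP1/MAP3 clamp the tile size at the boundary.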
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK: func @sort_3d_multi_result_distribute(
+// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: tensor<?x?x?xi32>
+// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+// CHECK-DAG: %[[TILESIZE1:.+]] = constant 10 : index
+// CHECK-DAG: %[[TILESIZE2:.+]] = constant 30 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[OPERAND1]], %[[C2]]
+// CHECK-DAG: %[[IDX:.+]] = flow.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[COUNTX:.+]] = flow.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[IDY:.+]] = flow.dispatch.workgroup.id[1]
+// CHECK-DAG: %[[COUNTY:.+]] = flow.dispatch.workgroup.count[1]
+// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]]
+// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]]
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]]
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]])
+// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]]
+// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]]
+// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]]
+// CHECK: %[[RESULT_INNER:.+]]:2 = scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]]
+// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT1]], %[[INIT4:.+]] = %[[INIT2]])
+// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT3]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT4]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: %[[SORT_SLICE:.+]]:2 = linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "distribute_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]
+// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_SLICE]]#0
+// CHECK-SAME: into %[[INIT3]][%[[IV0]], 0, %[[IV1]]]
+// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_SLICE]]#1
+// CHECK-SAME: into %[[INIT4]][%[[IV0]], 0, %[[IV1]]]
+// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]]
+// CHECK: scf.yield %[[RESULT_INNER]]#0, %[[RESULT_INNER]]#1
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
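+// Memref version of the distributed 3-d sort above: subviews replace the
+// extract_slice/insert_slice pairs and the loops carry no iteration arguments.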
+func @sort_3d_multi_result_distribute_memref(
+ %arg0: memref<?x?x?xi32>, %arg1 : memref<?x?x?xf32>) {
+ linalg_ext.sort dimension(1)
+ {__internal_linalg_transform__ = "distribute_input"}
+ outs(%arg0, %arg1 : memref<?x?x?xi32>, memref<?x?x?xf32>) {
+ ^bb0(%arg2: i32, %arg3: i32, %arg4 : f32, %arg5 : f32): // no predecessors
+ %0 = cmpf ogt, %arg4, %arg5 : f32
+ linalg_ext.yield %0 : i1
+ }
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
+// CHECK: func @sort_3d_multi_result_distribute_memref(
+// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: memref<?x?x?xi32>
+// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: memref<?x?x?xf32>
+// CHECK-DAG: %[[TILESIZE1:.+]] = constant 10 : index
+// CHECK-DAG: %[[TILESIZE2:.+]] = constant 30 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]]
+// CHECK-DAG: %[[D2:.+]] = memref.dim %[[OPERAND1]], %[[C2]]
+// CHECK-DAG: %[[IDX:.+]] = flow.dispatch.workgroup.id[0]
+// CHECK-DAG: %[[COUNTX:.+]] = flow.dispatch.workgroup.count[0]
+// CHECK-DAG: %[[IDY:.+]] = flow.dispatch.workgroup.id[1]
+// CHECK-DAG: %[[COUNTY:.+]] = flow.dispatch.workgroup.count[1]
+// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]]
+// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]]
+// CHECK: scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]]
+// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]]
+// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]]
+// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]]
+// CHECK: scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]]
+// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]]
+// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]]
+// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]]
+// CHECK: linalg_ext.sort
+// CHECK-SAME: __internal_linalg_transform__ = "distribute_output"
+// CHECK-SAME: outs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]]