Introduce reverse op to LinalgExt dialect. (#7124)
Defines the linalg_ext.reverse op and implements the op's basic
interface methods.
Writing reverse as a Linalg op requires Linalg-unfriendly affine
expressions in its indexing maps, which blocks the tile-and-distribute
transforms. This is a short/mid-term solution for tiling and
distributing reverse ops.
It's a step toward https://github.com/google/iree/issues/5045
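For illustration only (not part of this change): a reverse along dim 0
of a static 3x5 tensor written as a linalg.generic needs an indexing
map with a negated dimension, which is the kind of affine expression
the tiling transforms cannot handle. The function and map names below
are made up for this sketch:

  #reversed = affine_map<(d0, d1) -> (-d0 + 2, d1)>
  #identity = affine_map<(d0, d1) -> (d0, d1)>
  func @reverse_dim0_generic(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
    // The init tensor only provides the output shape; its values are unused.
    %init = linalg.init_tensor [3, 5] : tensor<3x5xi32>
    // The input map reads element (2 - d0, d1), i.e. dim 0 reversed.
    %0 = linalg.generic {indexing_maps = [#reversed, #identity],
                         iterator_types = ["parallel", "parallel"]}
        ins(%arg0 : tensor<3x5xi32>) outs(%init : tensor<3x5xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
    } -> tensor<3x5xi32>
    return %0 : tensor<3x5xi32>
  }

The linalg_ext.reverse op below instead carries the reversed dimension
as an attribute and lowers to a plain swap loop nest, so the usual
tile-and-distribute machinery can apply; see the roundtrip and
convert_to_loops tests in this patch for the op's assembly and the
generated loops.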
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index a06b026..4ae4556 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -759,6 +759,69 @@
return tiledFftOp;
}
+//===----------------------------------------------------------------------===//
+// ReverseOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyReverseOp(ReverseOp op) {
+ if (op.getNumInputs()) {
+ return op.emitOpError("expected no inputs");
+ }
+ if (op.getNumOutputs() != 1) {
+ return op.emitOpError("expected exactly one output");
+ }
+
+ int64_t rank = op.getOperandRank();
+ int dimension = op.dimension();
+ if (dimension < 0 || dimension >= rank) {
+ return op.emitOpError("dimension must be within (0, ") << rank << "]";
+ }
+
+ return success();
+}
+
+bool ReverseOp::payloadUsesValueFromOperand(OpOperand *) { return false; }
+
+SmallVector<StringRef> ReverseOp::getLoopIteratorTypes() {
+ SmallVector<StringRef> iteratorTypes(getOperandRank(),
+ getParallelIteratorTypeName());
+ return iteratorTypes;
+}
+
+SmallVector<Range> ReverseOp::getLoopBounds(OpBuilder &builder) {
+ Location loc = getLoc();
+ Value zero = builder.create<ConstantIndexOp>(loc, 0);
+ Value one = builder.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Range> ranges;
+ for (auto dim : llvm::seq<int64_t>(0, getOperandRank())) {
+ Value ub = getDimValue(builder, loc, operand(), dim);
+ ranges.emplace_back(Range{zero, ub, one});
+ }
+ auto dim = dimension();
+ ranges[dim].size = builder.create<SignedDivIOp>(
+ loc, ranges[dim].size, builder.create<ConstantIndexOp>(loc, 2));
+ return ranges;
+}
+
+LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b,
+ Location loc,
+ ValueRange ivs) {
+ SmallVector<Value> mirrorIndices(ivs.begin(), ivs.end());
+ auto dim = dimension();
+ auto size = getDimValue(b, loc, operand(), dim);
+ size = b.create<SubIOp>(loc, size, b.create<ConstantIndexOp>(loc, 1));
+ mirrorIndices[dim] = b.create<SubIOp>(loc, size, mirrorIndices[dim]);
+
+ // for (int i = 0; i < n / 2; ++i) {
+ // swap(array[i], array[n - 1 - i]);
+ // }
+ Value v1 = b.create<memref::LoadOp>(loc, operand(), ivs);
+ Value v2 = b.create<memref::LoadOp>(loc, operand(), mirrorIndices);
+ b.create<memref::StoreOp>(loc, v1, operand(), mirrorIndices);
+ b.create<memref::StoreOp>(loc, v2, operand(), ivs);
+ return success();
+}
+
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
@@ -772,6 +835,7 @@
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(FftOp)
+DEFINE_OP_GET_EFFECTS(ReverseOp)
} // namespace linalg_ext
} // namespace iree_compiler
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index fe847fb..9841da1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -245,6 +245,49 @@
}];
}
+def LinalgExt_ReverseOp : LinalgExt_Op<"reverse", [
+ DeclareOpInterfaceMethods<TiledOpInterface, ["generateScalarImplementation"]>,
+ DeclareOpInterfaceMethods<LinalgExtInterface,
+ // ReverseOp does not have a region, so we have to
+ // override the method.
+ ["payloadUsesValueFromOperand"]>]> {
+ let summary = "Reverse operator";
+ let description = [{
+ A temporary solution for a reverse op. The loop bound along the reversed
+ dimension is half of its size because we can simply swap elements, e.g.,
+
+ for (int i = 0; i < n / 2; ++i) {
+ std::swap(a[i], a[n - 1 - i]);
+ }
+ }];
+
+ let arguments = (ins Variadic<AnyShaped>:$inputs,
+ Variadic<AnyShaped>:$outputs,
+ I64Attr:$dimension
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+ let assemblyFormat = [{
+ `dimension` `(` $dimension `)`
+ attr-dict (`ins` `(` $inputs^ `:` type($inputs) `)`)?
+ `outs` `(` $outputs `:` type($outputs) `)`
+ (`:` type($results)^)?
+ }];
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+ Value operand() {
+ return getOutputOperand(0)->get();
+ }
+ ShapedType getOperandType() {
+ return operand().getType().cast<ShapedType>();
+ }
+ int64_t getOperandRank() {
+ return getOperandType().getRank();
+ }
+ ArrayRef<int64_t> getOperandShape() {
+ return getOperandType().getShape();
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Pure ops
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index b2deb0a..e488bc1 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -387,3 +387,42 @@
// CHECK-SAME: outs(%[[REAL]], %[[IMAG]] : tensor<1024xf32>, tensor<1024xf32>)
// CHECK-SAME: : tensor<1024xf32>, tensor<1024xf32>
// CHECK: return %[[RES]]#0, %[[RES]]#1
+
+// -----
+
+func @reverse_tensor(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
+ %0 = linalg_ext.reverse
+ dimension(0)
+ outs(%arg0 : tensor<3x5xi32>) : tensor<3x5xi32>
+ return %0 : tensor<3x5xi32>
+}
+// CHECK-LABEL: func @reverse_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5xi32>
+// CHECK: %[[RESULT:.+]] = linalg_ext.reverse dimension(0)
+// CHECK-SAME: outs(%[[ARG0]]
+
+// -----
+
+func @reverse_memref(%arg0: memref<3x5xi32>) {
+ linalg_ext.reverse
+ dimension(0)
+ outs(%arg0 : memref<3x5xi32>)
+ return
+}
+// CHECK-LABEL: func @reverse_memref
+// CHECK-SAME: %[[ARG0:.+]]: memref<3x5xi32>
+// CHECK: linalg_ext.reverse dimension(0)
+// CHECK-SAME: outs(%[[ARG0]]
+
+// -----
+
+func @reverse_dynamic_tensor(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %0 = linalg_ext.reverse
+ dimension(1)
+ outs(%arg0 : tensor<?x?xi32>) : tensor<?x?xi32>
+ return %0 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @reverse_dynamic_tensor
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+// CHECK: %[[RESULT:.+]] = linalg_ext.reverse dimension(1)
+// CHECK-SAME: outs(%[[ARG0]]
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 423faca..84883b9 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -481,3 +481,29 @@
// CHECK: %[[RES3:.+]] = subf %[[L_REAL]], %[[T_REAL]]
// CHECK: %[[RES4:.+]] = subf %[[L_IMAG]], %[[T_IMAG]]
// CHECK: linalg.yield %[[RES1]], %[[RES2]], %[[RES3]], %[[RES4]]
+
+// -----
+
+func @reverse_dim_0(%arg0: memref<?x?xi32>) {
+ linalg_ext.reverse
+ dimension(0)
+ outs(%arg0 : memref<?x?xi32>)
+ return
+}
+// CHECK-LABEL: func @reverse_dim_0
+// CHECK-SAME: %[[BUF:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C2:.+]] = constant 2 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %arg0, %c0 : memref<?x?xi32>
+// CHECK-DAG: %[[D1:.+]] = memref.dim %arg0, %c1 : memref<?x?xi32>
+// CHECK-DAG: %[[REV_UB:.+]] = divi_signed %[[D0]], %[[C2]] : index
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[REV_UB]] step %[[C1]]
+// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]]
+// CHECK: %[[T0:.+]] = memref.dim %[[BUF]], %[[C0]]
+// CHECK: %[[T1:.+]] = subi %[[T0]], %[[C1]] : index
+// CHECK: %[[T2:.+]] = subi %[[T1]], %[[I]] : index
+// CHECK: %[[V0:.+]] = memref.load %[[BUF]][%[[I]], %[[J]]]
+// CHECK: %[[V1:.+]] = memref.load %[[BUF]][%[[T2]], %[[J]]]
+// CHECK: memref.store %[[V0]], %[[BUF]][%[[T2]], %[[J]]] : memref<?x?xi32>
+// CHECK: memref.store %[[V1]], %[[BUF]][%[[I]], %[[J]]] : memref<?x?xi32>