Implement tiling method for linalg_ext.reverse ops. (#7159)
To tile a `linalg_ext.reverse` op, we have to store each tiled result at the
mirrored offset in the output: a tile read at `offset` along a reversed
dimension of size `dimSize` is written back at `dimSize - offset - tileSize`.
E.g.,
Before:
```
[T_0], [T_1], [T_2], [T_3]
```
After:
```
[Rev_t3], [Rev_t2], [Rev_t1], [Rev_t0]
```
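
For illustration, here is a minimal scalar sketch of the scheme (the
`tiledReverse` helper below is hypothetical, not part of this patch): each
tile is reversed while being copied to the mirrored offset
`dimSize - offset - tileSize`, so the tiles land in reverse order and the
whole buffer ends up reversed.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical scalar model of tiled reversal: reverse a 1-D buffer tile by
// tile, storing each reversed tile at the mirrored offset
// dimSize - offset - tileSize.
std::vector<int32_t> tiledReverse(const std::vector<int32_t> &in,
                                  int64_t tileSize) {
  const int64_t dimSize = static_cast<int64_t>(in.size());
  std::vector<int32_t> out(in.size());
  for (int64_t offset = 0; offset < dimSize; offset += tileSize) {
    // Boundary tiles may be smaller than tileSize.
    const int64_t size = std::min(tileSize, dimSize - offset);
    const int64_t mirrorOffset = dimSize - offset - size;
    // Reverse the tile itself while copying it to the mirrored position.
    for (int64_t i = 0; i < size; ++i)
      out[mirrorOffset + i] = in[offset + size - 1 - i];
  }
  return out;
}

int main() {
  // Tiles [0,1,2] [3,4,5] [6,7] of size 3 land at mirrored offsets, so the
  // result is the fully reversed sequence 7, 6, ..., 0.
  std::vector<int32_t> in = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<int32_t> rev = tiledReverse(in, 3);
  for (size_t i = 0; i < in.size(); ++i)
    assert(rev[i] == in[in.size() - 1 - i]);
  return 0;
}
```

This `dimSize - offset - tileSize` rule is exactly the `s0 - s1 - s2` affine
map that `getTiledImplementation` constructs in the diff below.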
The snippet below, taken after DispatchLinalgOnTensors, shows that the op is
tiled and distributed. Note how `%16` and `%17` compute the mirrored store
offset (`dimSize - offset - tileSize`) that is then used by
`flow.dispatch.tensor.store`:
```mlir
func private @_foo(%arg0: tensor<?xi32>, %arg1: !shapex.ranked_shape<[?]>) -> (tensor<?xi32>, !shapex.ranked_shape<[?]>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %0 = shapex.ranked_dim %arg1[0] : !shapex.ranked_shape<[?]> -> index
  %1 = shapex.tie_shape %arg0, %arg1 : tensor<?xi32>, !shapex.ranked_shape<[?]>
  %2 = linalg.init_tensor [%0] : tensor<?xi32>
  %3 = tensor.dim %1, %c0 : tensor<?xi32>
  %4 = tensor.dim %2, %c0 : tensor<?xi32>
  %5 = tensor.dim %1, %c0 : tensor<?xi32>
  %6 = flow.dispatch.workgroups[%3, %c1, %c1](%1, %0) : (tensor<?xi32>{%5}, index) -> tensor<?xi32>{%4} =
      (%arg2: !flow.dispatch.tensor<readonly:?xi32>, %arg3: index, %arg4: !flow.dispatch.tensor<writeonly:?xi32>) {
    %c0_0 = constant 0 : index
    %8 = flow.dispatch.tensor.load %arg2, offsets = [], sizes = [], strides = [] : !flow.dispatch.tensor<readonly:?xi32> -> tensor<?xi32>
    %9 = linalg.init_tensor [%arg3] : tensor<?xi32>
    %workgroup_size_0 = flow.dispatch.workgroup.size[0] : index
    %10 = tensor.dim %8, %c0_0 : tensor<?xi32>
    %workgroup_id_0 = flow.dispatch.workgroup.id[0] : index
    %workgroup_count_0 = flow.dispatch.workgroup.count[0] : index
    %11 = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%workgroup_id_0)[%workgroup_size_0]
    %12 = affine.apply affine_map<(d0)[s0] -> (d0 * s0)>(%workgroup_count_0)[%workgroup_size_0]
    scf.for %arg5 = %11 to %10 step %12 {
      %13 = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%arg5)[%workgroup_size_0, %10]
      %14 = flow.dispatch.tensor.load %arg2, offsets = [%arg5], sizes = [%13], strides = [1] : !flow.dispatch.tensor<readonly:?xi32> -> tensor<?xi32>
      %15 = tensor.dim %8, %c0_0 : tensor<?xi32>
      %16 = subi %15, %arg5 : index
      %17 = subi %16, %13 : index
      %18 = tensor.extract_slice %9[%arg5] [%13] [1] : tensor<?xi32> to tensor<?xi32>
      %19 = linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) {__internal_linalg_transform__ = "workgroup"} ins(%14 : tensor<?xi32>) outs(%18 : tensor<?xi32>) : tensor<?xi32>
      flow.dispatch.tensor.store %19, %arg4, offsets = [%17], sizes = [%13], strides = [1] : tensor<?xi32> -> !flow.dispatch.tensor<writeonly:?xi32>
    }
    flow.return
  }
  %7 = shapex.get_ranked_shape %6 : tensor<?xi32> -> !shapex.ranked_shape<[?]>
  return %6, %7 : tensor<?xi32>, !shapex.ranked_shape<[?]>
}
```
The `print-ir-after-all` log can be found at:
https://gist.githubusercontent.com/hanhanW/bba1d36149f37d948e8cea5fe95a6287/raw
This is a step toward https://github.com/google/iree/issues/5045
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 7497397..c22444c 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -12,11 +12,13 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SMLoc.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
@@ -794,9 +796,8 @@
bool ReverseOp::payloadUsesValueFromOperand(OpOperand *) { return false; }

SmallVector<StringRef> ReverseOp::getLoopIteratorTypes() {
-  // TODO(hanchung): Mark them parallel after tiling method is implemented.
  SmallVector<StringRef> iteratorTypes(getOperandRank(),
-                                      getReductionIteratorTypeName());
+                                      getParallelIteratorTypeName());
  return iteratorTypes;
}
@@ -826,6 +827,59 @@
  return success();
}

+Operation *ReverseOp::getTiledImplementation(OpBuilder &builder,
+                                             ValueRange outputs,
+                                             ArrayRef<OpFoldResult> offsets,
+                                             ArrayRef<OpFoldResult> sizes,
+                                             SmallVectorImpl<Value> &results) {
+  int64_t rank = getOperandRank();
+  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
+  Location loc = getLoc();
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(
+      getSlice(builder, loc, input(), offsets, sizes, strides));
+
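+  // Mirror the offsets: mirrorOffset = dimSize - offset - tileSize.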
+  AffineExpr sym0, sym1, sym2;
+  bindSymbols(builder.getContext(), sym0, sym1, sym2);
+  AffineMap map =
+      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2});
+  SmallVector<OpFoldResult> mirrorOffsets(offsets.begin(), offsets.end());
+  for (auto dim : dims()) {
+    Value size = getDimValue(builder, loc, input(), dim);
+    Value offset =
+        getValueOrCreateConstantIndexOp(builder, loc, mirrorOffsets[dim]);
+    Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]);
+    mirrorOffsets[dim] =
+        builder
+            .create<AffineApplyOp>(loc, map, ValueRange{size, offset, tileSize})
+            .getResult();
+  }
+
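+  // Slice the output at the mirrored offsets; only tensors carry result types.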
+  SmallVector<Type, 4> resultTypes;
+  if (hasTensorSemantics()) {
+    tiledOperands.emplace_back(
+        getSlice(builder, loc, output(), mirrorOffsets, sizes, strides));
+    resultTypes.push_back(tiledOperands[1].getType());
+  } else {
+    tiledOperands.emplace_back(
+        getSlice(builder, loc, output(), mirrorOffsets, sizes, strides));
+  }
+
+  Operation *tiledRevOp = cast<LinalgExtOp>(getOperation())
+                              .clone(builder, loc, resultTypes, tiledOperands);
+
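+  // Insert each tiled result into the full output at the mirrored offsets.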
+  for (auto result : llvm::enumerate(tiledRevOp->getResults())) {
+    auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+        loc, result.value(), outputs[result.index()], mirrorOffsets, sizes,
+        strides);
+    results.push_back(insertSliceOp.getResult());
+  }
+  return tiledRevOp;
+}
+
#define DEFINE_OP_GET_EFFECTS(OP_NAME) \
void OP_NAME::getEffects( \
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \
diff --git a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 9894831..c98d9de 100644
--- a/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -246,7 +246,9 @@
}
def LinalgExt_ReverseOp : LinalgExt_Op<"reverse", [
- DeclareOpInterfaceMethods<TiledOpInterface, ["generateScalarImplementation"]>,
+ DeclareOpInterfaceMethods<
+ TiledOpInterface,
+ ["generateScalarImplementation", "getTiledImplementation"]>,
DeclareOpInterfaceMethods<LinalgExtInterface,
// ReverseOp does not have a region, so we have to
// overwrite the method.
diff --git a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index a848535..21c3596 100644
--- a/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -598,3 +598,88 @@
// CHECK-SAME: {__internal_linalg_transform__ = "tiling_1d_stage5_fft_output"}
// CHECK-SAME: ins(%[[C5]], %[[COEF_REAL]], %[[COEF_IMAG]] : index, memref<16xf32>, memref<16xf32>)
// CHECK-SAME: outs(%[[SUB1]], %[[SUB2]] : memref<?xf32, #[[MAP1]]>, memref<?xf32, #[[MAP1]]>)
+
+// -----
+
+func @reverse_memref(%arg0: memref<?xi32>, %arg1: memref<?xi32>) {
+  linalg_ext.reverse
+    dimensions(dense<0> : tensor<1xi64>)
+    {__internal_linalg_transform__ = "tiling_input"}
+    ins(%arg0: memref<?xi32>)
+    outs(%arg1: memref<?xi32>)
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
+// CHECK: func @reverse_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = constant 10 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xi32>
+// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] {
+// CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[D0]]]
+// CHECK: %[[SUB_IN:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SIZE]]] [1]
+// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xi32>
+// CHECK: %[[IDX:.+]] = affine.apply #[[MAP2]]()[%[[T0]], %[[I]], %[[SIZE]]]
+// CHECK: %[[SUB_OUT:.+]] = memref.subview %[[ARG1]][%[[IDX]]] [%[[SIZE]]] [1]
+// CHECK: linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_output"}
+// CHECK-SAME: ins(%[[SUB_IN]]
+// CHECK-SAME: outs(%[[SUB_OUT]]
+
+// -----
+
+func @reverse_tensor_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xi32>
+  %0 = linalg_ext.reverse
+    dimensions(dense<[0, 1]> : tensor<2xi64>)
+    {__internal_linalg_transform__ = "tiling_input"}
+    ins(%arg0: tensor<?x?xi32>)
+    outs(%init: tensor<?x?xi32>) : tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
+// CHECK: func @reverse_tensor_multi_dim(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C10:.+]] = constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = constant 20 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] : tensor<?x?xi32>
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[SIZE_I:.+]] = affine.min #[[MAP0]](%[[I]])[%[[C10]], %[[D0]]]
+// CHECK: %[[RES2:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT2]]) -> (tensor<?x?xi32>) {
+// CHECK: %[[SIZE_J:.+]] = affine.min #[[MAP1]](%[[J]])[%[[C20]], %[[D1]]]
+// CHECK: %[[SUB_IN:.+]] = tensor.extract_slice
+// CHECK-SAME: %[[ARG0]][%[[I]], %[[J]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
+// CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
+// CHECK: %[[IDX0:.+]] = affine.apply #[[MAP2]]()[%[[T0]], %[[I]], %[[SIZE_I]]]
+// CHECK: %[[T1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[IDX1:.+]] = affine.apply #[[MAP2]]()[%[[T1]], %[[J]], %[[SIZE_J]]]
+// CHECK: %[[SUB_INIT:.+]] = tensor.extract_slice
+// CHECK-SAME: %[[INIT]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
+// CHECK: %[[REV:.+]] = linalg_ext.reverse
+// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
+// CHECK-SAME: {__internal_linalg_transform__ = "tiling_output"}
+// CHECK-SAME: ins(%[[SUB_IN]]
+// CHECK-SAME: outs(%[[SUB_INIT]]
+// CHECK: %[[RES3:.+]] = tensor.insert_slice %[[REV]] into
+// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
+// CHECK: scf.yield %[[RES3]]
+// CHECK: scf.yield %[[RES2]]
+// CHECK: return %[[RES]]