[LinalgExt] Retire `LinalgExt::ReverseOp` (#17866)
`LinalgExt::ReverseOp` is only ever created when lowering `stablehlo::ReverseOp`. We
can instead lower `stablehlo::ReverseOp` directly to a `linalg.generic` that gathers
each element from the mirrored source index, and retire `LinalgExt::ReverseOp` along
with its tiling, fusion, and bufferization hooks.
Fixes https://github.com/iree-org/iree/issues/16060
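For reference, here is roughly what the new lowering produces for a reverse along
dimension 1 of a `tensor<3x5xi32>`. This is a sketch distilled from the updated
FileCheck expectations below (names and the exact `stablehlo.reverse` assembly are
illustrative), not verbatim compiler output:

```mlir
// Before: reverse along dimension 1.
%0 = stablehlo.reverse %arg0, dims = [1] : tensor<3x5xi32>

// After: a parallel linalg.generic with no inputs that gathers from the
// source at the mirrored index (size - 1 - j) along each reversed dim.
#map = affine_map<(d0, d1) -> (d0, d1)>
%init = tensor.empty() : tensor<3x5xi32>
%rev = linalg.generic {indexing_maps = [#map],
                       iterator_types = ["parallel", "parallel"]}
    outs(%init : tensor<3x5xi32>) {
^bb0(%out: i32):
  %i = linalg.index 0 : index
  %j = linalg.index 1 : index
  %c1 = arith.constant 1 : index
  %dim = tensor.dim %arg0, %c1 : tensor<3x5xi32>
  %size_minus_1 = arith.subi %dim, %c1 : index
  %rev_j = arith.subi %size_minus_1, %j : index
  %elem = tensor.extract %arg0[%i, %rev_j] : tensor<3x5xi32>
  linalg.yield %elem : i32
} -> tensor<3x5xi32>
```

Because every loop is parallel and the body is a plain gather, the result flows
through the standard `linalg` tiling, fusion, and bufferization paths, which is why
all of the `ReverseOp`-specific interface registrations and tests below can simply
be deleted.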
---------
Signed-off-by: Alan Li <me@alanli.org>
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
index a20f872..cc7fa2b 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
@@ -7,6 +7,7 @@
// Implements IREE-specific logic for lowering StableHLO/CHLO dialects to
// LinalgExt dialect.
+#include <algorithm>
#include <cmath>
#include <complex>
#include <memory>
@@ -427,7 +428,6 @@
struct ReverseOpConversion final
: OpConversionPattern<mlir::stablehlo::ReverseOp> {
using OpConversionPattern::OpConversionPattern;
-
LogicalResult
matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -435,14 +435,45 @@
if (!ty)
return failure();
+ Value input = op.getOperand();
+ auto inputTy = cast<ShapedType>(input.getType());
+ auto resultTy = cast<ShapedType>(op.getType());
+ ArrayRef<int64_t> dims = op.getDimensions();
Location loc = op.getLoc();
- SmallVector<OpFoldResult> mixedSizes =
- tensor::getMixedSizes(rewriter, loc, adaptor.getOperands()[0]);
- Value emptyTensor =
- rewriter.create<tensor::EmptyOp>(loc, mixedSizes, ty.getElementType());
- rewriter.replaceOpWithNewOp<IREE::LinalgExt::ReverseOp>(
- op, typeConverter->convertType(op.getType()), adaptor.getOperands(),
- emptyTensor, rewriter.getI64TensorAttr(op.getDimensions()));
+ int64_t inputTyRank = inputTy.getRank();
+
+    // Create an init tensor matching the input shape. The linalg.generic
+    // below takes no inputs; its body computes the (possibly mirrored)
+    // source index for each output element and extracts from `input`.
+ SmallVector<OpFoldResult> inputMixedSizes =
+ tensor::getMixedSizes(rewriter, loc, input);
+ auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+ loc, inputMixedSizes, inputTy.getElementType());
+ SmallVector<AffineMap> affineMaps = {
+ rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+ rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+ op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
+ getNParallelLoopsAttrs(resultTy.getRank()),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+ llvm::SmallVector<Value> indices;
+ for (unsigned int i = 0; i < inputTyRank; i++) {
+ Value index =
+ rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
+ if (std::find(dims.begin(), dims.end(), i) != dims.end()) {
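+              // This dimension is reversed: map index i to (size - 1 - i).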
+ auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
+              Value axisDimSize =
+                  rewriter.create<tensor::DimOp>(nestedLoc, input, i);
+ auto sizeMinusOne =
+ rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
+ index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
+ index);
+ }
+ indices.push_back(index);
+ }
+
+ auto extract = nestedBuilder.create<tensor::ExtractOp>(
+ nestedLoc, input, indices);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc,
+                                                extract.getResult());
+ });
return success();
}
};
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir
index 917f2f8..09b2bd4 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir
@@ -495,12 +495,17 @@
return %0 : tensor<3x5xi32>
}
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32>
-// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[IN]] : tensor<3x5xi32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32>
-// CHECK: return %[[REV]]
-
+// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xi32>) {
+// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index
+// CHECK: %[[REV_DIM:.+]] = linalg.index 1 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xi32>
+// CHECK: %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index
+// CHECK: %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index
+// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xi32>
+// CHECK: linalg.yield %[[EXTRACTED]] : i32
+// CHECK: return %[[GEN]]
// -----
func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> {
@@ -512,13 +517,18 @@
// CHECK-LABEL: func.func @reverse_unsigned
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32>
-// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32>
-// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[BITCAST]] : tensor<3x5xi32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32>
-// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[REV]] : tensor<3x5xi32> to tensor<3x5xui32>
-// CHECK: return %[[BITCAST]]
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xui32>
+// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xui32>)
+// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index
+// CHECK: %[[REV_DIM:.+]] = linalg.index 1 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xui32>
+// CHECK: %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index
+// CHECK: %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index
+// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xui32>
+// CHECK: linalg.yield %[[EXTRACTED]] : ui32
+// CHECK: return %[[GEN]]
// -----
@@ -530,16 +540,32 @@
} : (tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]]
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
-// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
-// CHECK-SAME: ins(%[[IN]] : tensor<?x?xi32>)
-// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xi32>) : tensor<?x?xi32>
-// CHECK: return %[[REV]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor<?x?xi32>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor<?x?xi32>
+// CHECK: %[[INIT:.+]] = tensor.empty(%[[D]], %[[D0]]) : tensor<?x?xi32>
+// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {
+
+// First reversed dimension
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[C1_1:.+]] = arith.constant 1 : index
+// CHECK: %[[C0_2:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0_2]] : tensor<?x?xi32>
+// CHECK: %[[DIM0SUB1:.+]] = arith.subi %[[DIM0]], %[[C1_1]] : index
+// CHECK: %[[REV_IDX0:.+]] = arith.subi %[[DIM0SUB1]], %[[IDX0]] : index
+
+// Second reversed dimension
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[C1_4:.+]] = arith.constant 1 : index
+// CHECK: %[[C1_5:.+]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1_5]] : tensor<?x?xi32>
+// CHECK: %[[DIM1SUB1:.+]] = arith.subi %[[DIM1]], %[[C1_4]] : index
+// CHECK: %[[REV_IDX1:.+]] = arith.subi %[[DIM1SUB1]], %[[IDX1]] : index
+
+// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[REV_IDX0]], %[[REV_IDX1]]] : tensor<?x?xi32>
+// CHECK: linalg.yield %[[EXTRACTED]] : i32
+// CHECK: return %[[GEN]]
// -----
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index 68c0873..c075e09 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -526,46 +526,6 @@
// -----
-func.func @linalg_ext_reverse_dim0() {
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c3 = arith.constant 3 : index
- %c0 = arith.constant 0 : index
- %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x3xf32>>
- %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x3xf32>>
- %workgroup_id_x = hal.interface.workgroup.id[0] : index
- %workgroup_count_x = hal.interface.workgroup.count[0] : index
- %workgroup_id_y = hal.interface.workgroup.id[1] : index
- %workgroup_count_y = hal.interface.workgroup.count[1] : index
- %2 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
- %3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
- scf.for %arg0 = %2 to %c2 step %3 {
- %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
- %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
- scf.for %arg1 = %4 to %c3 step %5 {
- %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x3xf32>> -> tensor<2x3xf32>
- %7 = tensor.empty() : tensor<2x3xf32>
- %8 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%6 : tensor<2x3xf32>) outs(%7 : tensor<2x3xf32>) : tensor<2x3xf32>
- %9 = affine.apply affine_map<()[s0] -> (-s0)>()[%arg0]
- flow.dispatch.tensor.store %8, %1, offsets = [%9, %arg1], sizes = [2, 3], strides = [%c1, %c1] : tensor<2x3xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x3xf32>>
- }
- }
- return
-}
-// CHECK: func.func @linalg_ext_reverse_dim0()
-// CHECK-DAG: %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
-// CHECK-DAG: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
-// CHECK: scf.for %[[IV0:.+]] =
-// CHECK: scf.for %[[IV1:.+]] =
-// CHECK-DAG: %[[IN_TILE:.+]] = flow.dispatch.tensor.load %[[IN]]
-// CHECK-DAG: %[[OUT_TILE:.+]] = flow.dispatch.tensor.load %[[OUT]]
-// CHECK: %[[REV_TILE:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: ins(%[[IN_TILE]] : tensor<2x3xf32>)
-// CHECK-SAME: outs(%[[OUT_TILE]] : tensor<2x3xf32>)
-// CHECK: flow.dispatch.tensor.store %[[REV_TILE]], %[[OUT]]
-
-// -----
-
func.func @sort1D() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<4xi32>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index 911a44d..f360aad 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2170,30 +2170,6 @@
// -----
-// CHECK-LABEL: func.func @reverse_dim(
-// CHECK-DAG: %[[alloc:.*]] = memref.alloc()
-// CHECK-DAG: %[[cst:.*]] = bufferization.to_memref
-// CHECK: iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[cst]] :
-// CHECK-SAME: outs(%[[alloc]] :
-// CHECK: %[[load:.*]] = memref.load %[[alloc]]
-// CHECK: return %[[load]]
-func.func @reverse_dim(%pos: index) -> f32 {
- %input = arith.constant dense<[[1.0, 2.0, 3.0],
- [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
-
- %init = bufferization.alloc_tensor() : tensor<2x3xf32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%input : tensor<2x3xf32>)
- outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>
-
- %1 = tensor.extract %0[%pos, %pos] : tensor<2x3xf32>
- return %1 : f32
-}
-
-// -----
-
// CHECK-LABEL: func.func @fft_tensor(
// CHECK: memref.alloc
// CHECK: memref.alloc
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 0422f16..8e52db0 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -350,9 +350,9 @@
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- // TODO: Revisit this for Scatter/ReverseOp. We can then get rid of
+ // TODO: Revisit this for ScatterOp. We can then get rid of
// `bufferizesToMemoryRead` completely.
- return !isa<IREE::LinalgExt::ScatterOp, IREE::LinalgExt::ReverseOp>(op);
+ return !isa<IREE::LinalgExt::ScatterOp>(op);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -630,8 +630,6 @@
LinalgExtOpInterface<IREE::LinalgExt::PackOp>>(*ctx);
IREE::LinalgExt::UnPackOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::UnPackOp>>(*ctx);
- IREE::LinalgExt::ReverseOp::attachInterface<
- LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::ScanOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
IREE::LinalgExt::ScatterOp::attachInterface<
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
index cd289a5..b695e1e 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
@@ -241,8 +241,6 @@
OuterParallelAsPartitionableLoops<IREE::LinalgExt::ScatterOp>>(*ctx);
IREE::LinalgExt::SortOp::attachInterface<
AllParallelAsPartitionableLoops<IREE::LinalgExt::SortOp>>(*ctx);
- IREE::LinalgExt::ReverseOp::attachInterface<
- OuterParallelAsPartitionableLoops<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::TopkOp::attachInterface<
AllParallelAsPartitionableLoops<IREE::LinalgExt::TopkOp>>(*ctx);
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
index 6621451..4cca860 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
@@ -46,41 +46,6 @@
// CHECK: %[[GEN2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)
-
-
-// -----
-
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
- %0 = tensor.empty() : tensor<10x10xi64>
- %1 = tensor.empty() : tensor<10x10xi32>
- %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) {
- ^bb0(%in: i64, %out: i32):
- %7 = arith.trunci %in : i64 to i32
- linalg.yield %7 : i32
- } -> tensor<10x10xi32>
- %3 = tensor.empty() : tensor<10x10xi32>
- %4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32>
-
- // dont fuse with with reverse's consumer
- %5 = tensor.empty() : tensor<10x10xi32>
- %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) {
- ^bb0(%in: i32, %out: i32):
- %7 = arith.addi %in, %out : i32
- linalg.yield %7 : i32
- } -> tensor<10x10xi32>
- util.return %6 : tensor<10x10xi32>
-}
-
-// CHECK: util.func public @linalgext_reverse_fusion
-// CHECK: flow.dispatch.workgroups
-// CHECK: %[[SHRUNK:.+]] = linalg.generic
-// CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse
-// CHECK: ins(%[[SHRUNK]] : tensor<10x10xi32>)
-// CHECK: flow.dispatch.workgroups
-// CHECK: %[[GEN:.+]] = linalg.generic
-
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index ce40d00..4df93b3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -440,71 +440,6 @@
}
//===----------------------------------------------------------------------===//
-// ReverseOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ReverseOp::verify() {
- Operation *op = getOperation();
- if (getNumDpsInputs() != 1) {
- return op->emitOpError("expected exactly one input");
- }
- if (getNumDpsInits() != 1) {
- return op->emitOpError("expected exactly one output");
- }
- auto inputType = cast<ShapedType>(getInput().getType());
- auto outputType = cast<ShapedType>(getOutput().getType());
- if (inputType.getElementType() != outputType.getElementType()) {
- return op->emitOpError(
- "expected input/output element types to be identical");
- }
- ArrayRef<int64_t> inputShapes = inputType.getShape();
- ArrayRef<int64_t> outputShapes = outputType.getShape();
- if (inputShapes.size() != outputShapes.size()) {
- return op->emitOpError("expexted input/output to have identical ranks");
- }
- if (llvm::any_of(llvm::zip_equal(inputShapes, outputShapes),
- [](std::tuple<int64_t, int64_t> s) {
- return !ShapedType::isDynamic(std::get<0>(s)) &&
- !ShapedType::isDynamic(std::get<1>(s)) &&
- std::get<0>(s) != std::get<1>(s);
- })) {
- return op->emitOpError("incompatible input/output shapes");
- }
-
- int64_t rank = getOperandRank();
- llvm::SmallSetVector<int64_t, 4> s;
- for (auto dim : getDimensionsArray()) {
- if (dim < 0 || dim >= rank) {
- return op->emitOpError("all the dimensions must be within [0, ")
- << rank << ")";
- }
- if (s.contains(dim)) {
- return op->emitOpError("expected dimensions numbers are all unique");
- }
- s.insert(dim);
- }
-
- return success();
-}
-
-LogicalResult
-ReverseOp::reifyResultShapes(OpBuilder &b,
- ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- return cast<LinalgExtOp>(getOperation())
- .reifyResultShapes(b, reifiedReturnShapes);
-}
-
-SmallVector<AffineMap> ReverseOp::getIndexingMapsForOperands() {
- Builder builder(getContext());
- return {builder.getMultiDimIdentityMap(getOperandRank()),
- /*output=*/AffineMap(nullptr)};
-}
-
-SmallVector<AffineMap> ReverseOp::getIndexingMapsForResults() {
- return {AffineMap(nullptr)};
-}
-
-//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
@@ -1583,7 +1518,6 @@
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(FftOp)
-DEFINE_OP_GET_EFFECTS(ReverseOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(TopkOp)
DEFINE_OP_GET_EFFECTS(PackOp)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 993ab3b..fe8693a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -369,67 +369,6 @@
}];
}
-def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
- DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
- DeclareOpInterfaceMethods<LinalgFusionInterface>,
- DeclareOpInterfaceMethods<
- TilingInterface,
- ["generateScalarImplementation",
- "getIterationDomain",
- "getLoopIteratorTypes",
- "getResultTilePosition",
- "getTiledImplementation"]>,
- DeclareOpInterfaceMethods<LinalgExtInterface>]> {
- let summary = "Reverse operator";
- let description = [{
- A temporary solution for lowering reverse ops into IREE, allowing IREE to
- tile and distribute them.
-  }];
-
- let arguments = (ins Variadic<AnyShaped>:$inputs,
- Variadic<AnyShaped>:$outputs,
- I64ElementsAttr:$dimensions
- );
- let results = (outs Variadic<AnyRankedTensor>:$results);
- let assemblyFormat = [{
- attr-dict `dimensions` `(` $dimensions `)`
- (`ins` `(` $inputs^ `:` type($inputs) `)`)?
- (`outs` `(` $outputs^ `:` type($outputs) `)`)?
- (`:` type($results)^)?
- }];
- let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
- Value getInput() {
- return getDpsInputOperand(0)->get();
- }
- Value getOutput() {
- return getDpsInitOperand(0)->get();
- }
- ShapedType getOperandType() {
- return cast<ShapedType>(getInput().getType());
- }
- int64_t getOperandRank() {
- return getOperandType().getRank();
- }
- ArrayRef<int64_t> getOprerandShape() {
- return getOperandType().getShape();
- }
- SmallVector<int64_t> getDimensionsArray() {
- SmallVector<int64_t> ret;
- for (const APInt& elem : getDimensions()) {
- ret.push_back(elem.getLimitedValue());
- }
- return ret;
- }
-
- // Method to implement for specifying output range for
- // DestinationStyleOpInterface
- MutableOperandRange getDpsInitsMutable() {
- return getOutputsMutable();
- }
- }];
-}
-
def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<LinalgExtInterface>,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir
index a2f89e8..87ee429 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir
@@ -1,27 +1,5 @@
// RUN: iree-opt --canonicalize --split-input-file %s | FileCheck %s
-func.func @tensor_cast(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
- %init = tensor.empty() : tensor<3x5xi32>
-
- %casted_arg0 = tensor.cast %arg0 : tensor<3x5xi32> to tensor<?x?xi32>
- %casted_init = tensor.cast %init : tensor<3x5xi32> to tensor<?x?xi32>
-
- %0 = iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%casted_arg0 : tensor<?x?xi32>)
- outs(%casted_init : tensor<?x?xi32>) : tensor<?x?xi32>
-
- %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<3x5xi32>
-
- return %1: tensor<3x5xi32>
-}
-// CHECK-LABEL: func.func @tensor_cast(
-// CHECK: iree_linalg_ext.reverse
-// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>)
-// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>)
-
-// -----
-
func.func @pack_canonicalize(%arg0 : tensor<?x?xi32>,
%arg1 : tensor<1x2x3x3xi32>) -> tensor<1x?x3x3xi32> {
%c0_i32 = arith.constant 0 : i32
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 3363f1b..1d0280b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -338,42 +338,6 @@
// -----
-func.func @reverse_diff_element_type(%arg0: tensor<3x5xi32>) -> tensor<3x5xf32> {
- %init = tensor.empty() : tensor<3x5xf32>
- // expected-error @+1 {{expected input/output element types to be identical}}
- %0 = iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%arg0 : tensor<3x5xi32>)
- outs(%init : tensor<3x5xf32>) : tensor<3x5xf32>
- return %0 : tensor<3x5xf32>
-}
-
-// -----
-
-func.func @reverse_diff_shape(%arg0: tensor<3x5xi32>) -> tensor<3x6xi32> {
- %init = tensor.empty() : tensor<3x6xi32>
- // expected-error @+1 {{incompatible input/output shapes}}
- %0 = iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%arg0 : tensor<3x5xi32>)
- outs(%init : tensor<3x6xi32>) : tensor<3x6xi32>
- return %0 : tensor<3x6xi32>
-}
-
-// -----
-
-func.func @reverse_dup_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
- %init = tensor.empty() : tensor<3x5xi32>
- // expected-error @+1 {{expected dimensions numbers are all unique}}
- %0 = iree_linalg_ext.reverse
- dimensions(dense<[0, 0]> : tensor<2xi64>)
- ins(%arg0 : tensor<3x5xi32>)
- outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
- return %0 : tensor<3x5xi32>
-}
-
-// -----
-
func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) {
// expected-error@+1 {{expected one or two input operands}}
%0:2 = iree_linalg_ext.topk
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 88a1f2d..eddb2c7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -484,111 +484,6 @@
// -----
-func.func @reverse_tensor(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
- %init = tensor.empty() : tensor<3x5xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%arg0 : tensor<3x5xi32>)
- outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
- return %0 : tensor<3x5xi32>
-}
-// CHECK-LABEL: func.func @reverse_tensor
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
-// CHECK: %[[INIT:.+]] = tensor.empty()
-// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[INIT]]
-
-// -----
-
-func.func @reverse_memref(%arg0: memref<3x5xi32>, %arg1: memref<3x5xi32>) {
- iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%arg0 : memref<3x5xi32>)
- outs(%arg1 : memref<3x5xi32>)
- return
-}
-// CHECK-LABEL: func.func @reverse_memref
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<3x5xi32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<3x5xi32>
-// CHECK: iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[ARG1]]
-
-// -----
-
-func.func @reverse_dynamic_tensor(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
- %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
- %init = tensor.empty(%d0, %d1) : tensor<?x?xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<1> : tensor<1xi64>)
- ins(%arg0 : tensor<?x?xi32>)
- outs(%init : tensor<?x?xi32>) : tensor<?x?xi32>
- return %0 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func.func @reverse_dynamic_tensor
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
-// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[INIT]]
-
-// -----
-
-func.func @reverse_static_dynamic_tensor(%arg0: tensor<3x5xi32>) -> tensor<?x?xi32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %d0 = tensor.dim %arg0, %c0 : tensor<3x5xi32>
- %d1 = tensor.dim %arg0, %c1 : tensor<3x5xi32>
- %init = tensor.empty(%d0, %d1) : tensor<?x?xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<1> : tensor<1xi64>)
- ins(%arg0 : tensor<3x5xi32>)
- outs(%init : tensor<?x?xi32>) : tensor<?x?xi32>
- return %0 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func.func @reverse_static_dynamic_tensor
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
-// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[INIT]]
-
-// -----
-
-func.func @reverse_multi_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
- %init = tensor.empty() : tensor<3x5xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<[0, 1]> : tensor<2xi64>)
- ins(%arg0 : tensor<3x5xi32>)
- outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
- return %0 : tensor<3x5xi32>
-}
-// CHECK-LABEL: func.func @reverse_multi_dims
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
-// CHECK: %[[INIT:.+]] = tensor.empty()
-// CHECK: %[[RESULT:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
-// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[INIT]]
-
-// -----
-
func.func @topk_tensor(%input_values: tensor<20x10x8x4xf32>, %input_indices: tensor<20x10x8x4xi32>) -> (tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) {
%out_values = tensor.empty() : tensor<20x10x3x4xf32>
%out_indices = tensor.empty() : tensor<20x10x3x4xi32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 231f694..2565df8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -730,101 +730,6 @@
}
//===----------------------------------------------------------------------===//
-// ReverseOp
-//===----------------------------------------------------------------------===//
-
-SmallVector<utils::IteratorType> ReverseOp::getLoopIteratorTypes() {
- SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(),
- utils::IteratorType::parallel);
- return iteratorTypes;
-}
-
-SmallVector<Range> ReverseOp::getIterationDomain(OpBuilder &builder) {
- Location loc = getLoc();
- Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
- Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
- SmallVector<Range> ranges;
- for (auto dim : llvm::seq<int64_t>(0, getOperandRank())) {
- Value ub = getDimValue(builder, loc, getInput(), dim);
- ranges.emplace_back(Range{zero, ub, one});
- }
- return ranges;
-}
-
-LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b,
- Location loc,
- ValueRange ivs) {
- SmallVector<Value> mirrorIndices(ivs.begin(), ivs.end());
- for (auto dim : getDimensionsArray()) {
- auto size = getDimValue(b, loc, getInput(), dim);
- size = b.create<arith::SubIOp>(loc, size,
- b.create<arith::ConstantIndexOp>(loc, 1));
- mirrorIndices[dim] = b.create<arith::SubIOp>(loc, size, mirrorIndices[dim]);
- }
- Value val = b.create<memref::LoadOp>(loc, getInput(), ivs);
- b.create<memref::StoreOp>(loc, val, getOutput(), mirrorIndices);
- return success();
-}
-
-FailureOr<TilingResult>
-ReverseOp::getTiledImplementation(OpBuilder &builder,
- ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes) {
- int64_t rank = getOperandRank();
- SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
- Location loc = getLoc();
- SmallVector<OpFoldResult> mirrorOffsets, mirrorSizes;
- if (failed(getResultTilePosition(builder, 0, offsets, sizes, mirrorOffsets,
- mirrorSizes))) {
- return {};
- }
-
- SmallVector<Value> tiledOperands;
- tiledOperands.emplace_back(
- getSlice(builder, loc, getInput(), offsets, sizes, strides));
-
- SmallVector<Type, 4> resultTypes;
- if (hasPureTensorSemantics()) {
- tiledOperands.emplace_back(
- getSlice(builder, loc, getOutput(), mirrorOffsets, sizes, strides));
- resultTypes.push_back(tiledOperands[1].getType());
- } else {
- tiledOperands.emplace_back(
- getSlice(builder, loc, getOutput(), mirrorOffsets, sizes, strides));
- }
-
- Operation *tiledRevOp =
- mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-
- return TilingResult{{tiledRevOp},
- SmallVector<Value>(tiledRevOp->getResults())};
-}
-
-LogicalResult ReverseOp::getResultTilePosition(
- OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) {
- AffineExpr sym0, sym1, sym2;
- bindSymbols(builder.getContext(), sym0, sym1, sym2);
- AffineMap map =
- AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2});
- resultOffsets.assign(offsets.begin(), offsets.end());
- Location loc = getLoc();
- for (auto dim : getDimensionsArray()) {
- Value size = getDimValue(builder, loc, getInput(), dim);
- Value offset =
- getValueOrCreateConstantIndexOp(builder, loc, resultOffsets[dim]);
- Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]);
- resultOffsets[dim] = builder
- .create<affine::AffineApplyOp>(
- loc, map, ValueRange{size, offset, tileSize})
- .getResult();
- }
- resultSizes.assign(sizes.begin(), sizes.end());
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 3623426..f136ab5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -507,28 +507,6 @@
// -----
-func.func @reverse_dim_0(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>) {
- iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%arg0 : memref<?x?xi32>)
- outs(%arg1 : memref<?x?xi32>)
- return
-}
-// CHECK-LABEL: func.func @reverse_dim_0
-// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
-// CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %arg0, %c0 : memref<?x?xi32>
-// CHECK-DAG: %[[D1:.+]] = memref.dim %arg0, %c1 : memref<?x?xi32>
-// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
-// CHECK: scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]]
-// CHECK: %[[T0:.+]] = memref.dim %[[IN]], %[[C0]]
-// CHECK: %[[T1:.+]] = arith.subi %[[T0]], %[[C1]] : index
-// CHECK: %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index
-// CHECK: %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]]
-// CHECK: memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref<?x?xi32>
-
func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
%c0 = memref.alloc() : memref<i32>
iree_linalg_ext.scan dimension(0) inclusive(true)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 42e6193..67a189b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -494,95 +494,6 @@
// -----
-func.func @reverse_memref(%arg0: memref<?xi32>, %arg1: memref<?xi32>) {
- iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%arg0: memref<?xi32>)
- outs(%arg1: memref<?xi32>)
- return
-}
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["iree_linalg_ext.reverse"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %1, %loops = transform.structured.tile_using_for %0 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
-// CHECK: func.func @reverse_memref(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xi32>
-// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] {
-// CHECK-DAG: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]]
-// CHECK-DAG: %[[IDX:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE]]]
-// CHECK-DAG: %[[SUB_IN:.+]] = memref.subview %[[ARG0]][%[[I]]] [%[[SIZE]]] [1]
-// CHECK-DAG: %[[SUB_OUT:.+]] = memref.subview %[[ARG1]][%[[IDX]]] [%[[SIZE]]] [1]
-// CHECK: iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<0> : tensor<1xi64>)
-// CHECK-SAME: ins(%[[SUB_IN]]
-// CHECK-SAME: outs(%[[SUB_OUT]]
-
-// -----
-
-func.func @reverse_tensor_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
- %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
- %init = tensor.empty(%d0, %d1) : tensor<?x?xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<[0, 1]> : tensor<2xi64>)
- ins(%arg0: tensor<?x?xi32>)
- outs(%init: tensor<?x?xi32>) : tensor<?x?xi32>
- return %0 : tensor<?x?xi32>
-}
-module attributes { transform.with_named_sequence } {
- transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["iree_linalg_ext.reverse"]} in %module_op : (!transform.any_op) -> !transform.any_op
- %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
- transform.yield
- }
-}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
-// CHECK: func.func @reverse_tensor_multi_dim(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
-// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
-// CHECK: %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
-// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT]]) -> (tensor<?x?xi32>) {
-// CHECK: %[[RES2:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
-// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT2]]) -> (tensor<?x?xi32>) {
-// CHECK-DAG: %[[SIZE_I:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]]
-// CHECK-DAG: %[[SIZE_J:.+]] = affine.min #[[MAP1]](%[[J]])[%[[D1]]]
-// CHECK-DAG: %[[IDX0:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE_I]]]
-// CHECK-DAG: %[[IDX1:.+]] = affine.apply #[[MAP2]]()[%[[D1]], %[[J]], %[[SIZE_J]]]
-// CHECK: %[[SUB_IN:.+]] = tensor.extract_slice
-// CHECK-SAME: %[[ARG0]][%[[I]], %[[J]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
-// CHECK: %[[SUB_INIT:.+]] = tensor.extract_slice
-// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
-// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
-// CHECK-SAME: ins(%[[SUB_IN]]
-// CHECK-SAME: outs(%[[SUB_INIT]]
-// CHECK: %[[RES3:.+]] = tensor.insert_slice %[[REV]] into
-// CHECK-SAME: %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
-// CHECK: scf.yield %[[RES3]]
-// CHECK: scf.yield %[[RES2]]
-// CHECK: return %[[RES]]
-
-// -----
-
func.func @scan_1d(%0: tensor<128xi32>) -> tensor<128xi32> {
%c0 = tensor.empty() : tensor<i32>
%1 = tensor.empty() : tensor<128xi32>
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index 0a9a1b5..681b336 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -313,8 +313,6 @@
LinalgOpTiedOpInterface<IREE::LinalgExt::FftOp>>(*context);
IREE::LinalgExt::ScanOp::attachInterface<
LinalgOpTiedOpInterface<IREE::LinalgExt::ScanOp>>(*context);
- IREE::LinalgExt::ReverseOp::attachInterface<
- LinalgOpTiedOpInterface<IREE::LinalgExt::ReverseOp>>(*context);
IREE::LinalgExt::TopkOp::attachInterface<
LinalgOpTiedOpInterface<IREE::LinalgExt::TopkOp>>(*context);
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel
index 8c6f3dd..f655d9b 100644
--- a/tests/e2e/linalg_ext_ops/BUILD.bazel
+++ b/tests/e2e/linalg_ext_ops/BUILD.bazel
@@ -16,7 +16,6 @@
# keep sorted
[
"attention.mlir",
- "reverse.mlir",
"scan.mlir",
"scatter.mlir",
"sort.mlir",
@@ -42,7 +41,6 @@
VMVX_SRCS = enforce_glob(
# keep sorted
[
- "reverse.mlir",
"scan.mlir",
"scatter.mlir",
"sort.mlir",
@@ -66,7 +64,6 @@
LLVM_GPU_SRCS = enforce_glob(
# keep sorted
[
- "reverse.mlir",
"scan.mlir",
"scatter.mlir",
"sort.mlir",
@@ -107,7 +104,6 @@
srcs = enforce_glob(
# keep sorted
[
- "reverse.mlir",
"scan.mlir",
"scatter.mlir",
"sort.mlir",
@@ -138,7 +134,6 @@
include = ["*.mlir"],
exclude = [
"attention.mlir",
- "reverse.mlir", #TODO(#12415): disabled due to miscompilation on Pixel 6.
"top-k.mlir",
],
),
diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt
index 1e84c09..97dc732 100644
--- a/tests/e2e/linalg_ext_ops/CMakeLists.txt
+++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt
@@ -15,7 +15,6 @@
check_llvm-cpu_local-task
SRCS
"attention.mlir"
- "reverse.mlir"
"scan.mlir"
"scatter.mlir"
"sort.mlir"
@@ -34,7 +33,6 @@
NAME
check_vmvx_local-task
SRCS
- "reverse.mlir"
"scan.mlir"
"scatter.mlir"
"sort.mlir"
@@ -51,7 +49,6 @@
NAME
check_cuda
SRCS
- "reverse.mlir"
"scan.mlir"
"scatter.mlir"
"sort.mlir"
@@ -74,7 +71,6 @@
NAME
check_rocm_hip
SRCS
- "reverse.mlir"
"scan.mlir"
"scatter.mlir"
"sort.mlir"
@@ -91,7 +87,6 @@
NAME
check_metal-spirv_vulkan
SRCS
- "reverse.mlir"
"scan.mlir"
"scatter.mlir"
"sort.mlir"
diff --git a/tests/e2e/linalg_ext_ops/reverse.mlir b/tests/e2e/linalg_ext_ops/reverse.mlir
deleted file mode 100644
index db1610b..0000000
--- a/tests/e2e/linalg_ext_ops/reverse.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-func.func @reverse_dim0() {
- %input = util.unfoldable_constant dense<[[1.0, 2.0, 3.0],
- [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
-
- %init = tensor.empty() : tensor<2x3xf32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<0> : tensor<1xi64>)
- ins(%input : tensor<2x3xf32>)
- outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>
-
- check.expect_almost_eq_const(
- %0,
- dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32>
- ) : tensor<2x3xf32>
-
- return
-}
-
-func.func @reverse_dim1() {
- %input = util.unfoldable_constant dense<[[1, 2, 3],
- [4, 5, 6]]> : tensor<2x3xi32>
-
- %init = tensor.empty() : tensor<2x3xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<1> : tensor<1xi64>)
- ins(%input : tensor<2x3xi32>)
- outs(%init : tensor<2x3xi32>) : tensor<2x3xi32>
-
- check.expect_eq_const(
- %0,
- dense<[[3, 2, 1], [6, 5, 4]]> : tensor<2x3xi32>
- ) : tensor<2x3xi32>
-
- return
-}
-
-func.func @reverse_multi_dims() {
- %input = util.unfoldable_constant dense<[[1, 2, 3],
- [4, 5, 6]]> : tensor<2x3xi32>
-
- %init = tensor.empty() : tensor<2x3xi32>
- %0 = iree_linalg_ext.reverse
- dimensions(dense<[0, 1]> : tensor<2xi64>)
- ins(%input : tensor<2x3xi32>)
- outs(%init : tensor<2x3xi32>) : tensor<2x3xi32>
-
- check.expect_eq_const(
- %0,
- dense<[[6, 5, 4], [3, 2, 1]]> : tensor<2x3xi32>
- ) : tensor<2x3xi32>
-
- return
-}