[LinalgExt] Retire `LinalgExt::ReverseOp` (#17866)

`LinalgExt::ReverseOp` is only ever created when lowering `stablehlo::ReverseOp`. We can instead lower `stablehlo::ReverseOp` directly to a `linalg.generic` that gathers from mirrored input indices, and retire `LinalgExt::ReverseOp` along with its dedicated tiling, bufferization, and fusion support.
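
For reference, the new conversion emits roughly the following IR for a reverse along dimension 1 of a `tensor<3x5xi32>` (a sketch based on the updated lit tests below; SSA names are illustrative):

```mlir
%init = tensor.empty() : tensor<3x5xi32>
%rev = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    outs(%init : tensor<3x5xi32>) {
^bb0(%out: i32):
  // Mirror the index along the reversed dimension: j' = size(1) - 1 - j.
  %i = linalg.index 0 : index
  %j = linalg.index 1 : index
  %c1 = arith.constant 1 : index
  %dim = tensor.dim %arg0, %c1 : tensor<3x5xi32>
  %ub = arith.subi %dim, %c1 : index
  %rj = arith.subi %ub, %j : index
  %v = tensor.extract %arg0[%i, %rj] : tensor<3x5xi32>
  linalg.yield %v : i32
} -> tensor<3x5xi32>
```

Since the result is a plain `linalg.generic` with a `tensor.extract`, it should pick up the existing tiling, fusion, and bufferization paths for structured ops, and the ReverseOp-specific interfaces removed below are no longer needed.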

Fixes https://github.com/iree-org/iree/issues/16060

---------

Signed-off-by: Alan Li <me@alanli.org>
diff --git a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
index a20f872..cc7fa2b 100644
--- a/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
+++ b/compiler/plugins/input/StableHLO/Conversion/StableHLOToLinalgExt.cpp
@@ -7,6 +7,7 @@
 // Implements IREE-specific logic for lowering StableHLO/CHLO dialects to
 // LinalgExt dialect.
 
+#include <algorithm>
 #include <cmath>
 #include <complex>
 #include <memory>
@@ -427,7 +428,6 @@
 struct ReverseOpConversion final
     : OpConversionPattern<mlir::stablehlo::ReverseOp> {
   using OpConversionPattern::OpConversionPattern;
-
   LogicalResult
   matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -435,14 +435,47 @@
     if (!ty)
       return failure();
 
+    Value input = op.getOperand();
+    auto inputTy = cast<ShapedType>(input.getType());
+    auto resultTy = cast<ShapedType>(op.getType());
+    ArrayRef<int64_t> dims = op.getDimensions();
     Location loc = op.getLoc();
-    SmallVector<OpFoldResult> mixedSizes =
-        tensor::getMixedSizes(rewriter, loc, adaptor.getOperands()[0]);
-    Value emptyTensor =
-        rewriter.create<tensor::EmptyOp>(loc, mixedSizes, ty.getElementType());
-    rewriter.replaceOpWithNewOp<IREE::LinalgExt::ReverseOp>(
-        op, typeConverter->convertType(op.getType()), adaptor.getOperands(),
-        emptyTensor, rewriter.getI64TensorAttr(op.getDimensions()));
+    int64_t inputTyRank = inputTy.getRank();
+
+    // Create an empty init tensor for the reversing linalg.generic.
+    SmallVector<OpFoldResult> inputMixedSizes =
+        tensor::getMixedSizes(rewriter, loc, input);
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, inputMixedSizes, inputTy.getElementType());
+    SmallVector<AffineMap> affineMaps = {
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          llvm::SmallVector<Value> indices;
+          for (int64_t i = 0; i < inputTyRank; i++) {
+            Value index =
+                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
+            if (std::find(dims.begin(), dims.end(), i) != dims.end()) {
+              // Mirror the index along this dim: index' = size - 1 - index.
+              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
+              Value axisDimSize =
+                  rewriter.create<tensor::DimOp>(nestedLoc, input, i);
+              auto sizeMinusOne =
+                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
+              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
+                                                     index);
+            }
+            indices.push_back(index);
+          }
+
+          auto extract = nestedBuilder.create<tensor::ExtractOp>(
+              nestedLoc, input, indices);
+          nestedBuilder.create<linalg::YieldOp>(nestedLoc,
+                                                extract.getResult());
+        });
     return success();
   }
 };
diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir
index 917f2f8..09b2bd4 100644
--- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir
+++ b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir
@@ -495,12 +495,17 @@
   return %0 : tensor<3x5xi32>
 }
 // CHECK:        %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32>
-// CHECK:        %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME:     dimensions(dense<1> : tensor<1xi64>)
-// CHECK-SAME:     ins(%[[IN]] : tensor<3x5xi32>)
-// CHECK-SAME:     outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32>
-// CHECK:        return %[[REV]]
-
+// CHECK:        %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xi32>) {
+// CHECK:        %[[SAME_DIM:.+]] = linalg.index 0 : index
+// CHECK:        %[[REV_DIM:.+]] = linalg.index 1 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C1_0:.+]] = arith.constant 1 : index
+// CHECK:        %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xi32>
+// CHECK:        %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index
+// CHECK:        %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index
+// CHECK:        %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xi32>
+// CHECK:        linalg.yield %[[EXTRACTED]] : i32
+// CHECK:        return %[[GEN]]
 // -----
 
 func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> {
@@ -512,13 +517,18 @@
 // CHECK-LABEL: func.func @reverse_unsigned
 // CHECK-SAME:   %[[IN:[a-zA-Z0-9]+]]
 // CHECK:        %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32>
-// CHECK:        %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32>
-// CHECK:        %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME:     dimensions(dense<1> : tensor<1xi64>)
-// CHECK-SAME:     ins(%[[BITCAST]] : tensor<3x5xi32>)
-// CHECK-SAME:     outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32>
-// CHECK:        %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[REV]] : tensor<3x5xi32> to tensor<3x5xui32>
-// CHECK:        return %[[BITCAST]]
+// CHECK:        %[[INIT:.+]] = tensor.empty() : tensor<3x5xui32>
+// CHECK:        %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xui32>)
+// CHECK:        %[[SAME_DIM:.+]] = linalg.index 0 : index
+// CHECK:        %[[REV_DIM:.+]] = linalg.index 1 : index
+// CHECK:        %[[C1:.+]] = arith.constant 1 : index
+// CHECK:        %[[C1_0:.+]] = arith.constant 1 : index
+// CHECK:        %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xui32>
+// CHECK:        %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index
+// CHECK:        %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index
+// CHECK:        %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xui32>
+// CHECK:        linalg.yield %[[EXTRACTED]] : ui32
+// CHECK:        return %[[GEN]]
 
 // -----
 
@@ -530,16 +540,32 @@
   } : (tensor<?x?xi32>) -> tensor<?x?xi32>
   return %0 : tensor<?x?xi32>
 }
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]]
-// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]]
-// CHECK:        %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
-// CHECK:        %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME:     dimensions(dense<[0, 1]> : tensor<2xi64>)
-// CHECK-SAME:     ins(%[[IN]] : tensor<?x?xi32>)
-// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xi32>) : tensor<?x?xi32>
-// CHECK:        return %[[REV]]
+// CHECK:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:    %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor<?x?xi32>
+// CHECK:    %[[C1:.+]] = arith.constant 1 : index
+// CHECK:    %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor<?x?xi32>
+// CHECK:    %[[INIT:.+]] = tensor.empty(%[[D]], %[[D0]]) : tensor<?x?xi32>
+// CHECK:    %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {
+
+// First reverse dimension
+// CHECK:       %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK:       %[[C1_1:.+]] = arith.constant 1 : index
+// CHECK:       %[[C0_2:.+]] = arith.constant 0 : index
+// CHECK:       %[[DIM0:.+]] = tensor.dim %arg0, %[[C0_2]] : tensor<?x?xi32>
+// CHECK:       %[[DIM0SUB1:.+]] = arith.subi %[[DIM0]], %[[C1_1]] : index
+// CHECK:       %[[REV_IDX0:.+]] = arith.subi %[[DIM0SUB1]], %[[IDX0]] : index
+
+// Second reverse dimension
+// CHECK:       %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK:       %[[C1_4:.+]] = arith.constant 1 : index
+// CHECK:       %[[C1_5:.+]] = arith.constant 1 : index
+// CHECK:       %[[DIM1:.+]] = tensor.dim %arg0, %[[C1_5]] : tensor<?x?xi32>
+// CHECK:       %[[DIM1SUB1:.+]] = arith.subi %[[DIM1]], %[[C1_4]] : index
+// CHECK:       %[[REV_IDX1:.+]] = arith.subi %[[DIM1SUB1]], %[[IDX1]] : index
+
+// CHECK:        %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[REV_IDX0]], %[[REV_IDX1]]] : tensor<?x?xi32>
+// CHECK:        linalg.yield %[[EXTRACTED]] : i32
+// CHECK:        return %[[GEN]]
 
 // -----
 
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
index 68c0873..c075e09 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_to_destination_passing_style.mlir
@@ -526,46 +526,6 @@
 
 // -----
 
-func.func @linalg_ext_reverse_dim0() {
-  %c1 = arith.constant 1 : index
-  %c2 = arith.constant 2 : index
-  %c3 = arith.constant 3 : index
-  %c0 = arith.constant 0 : index
-  %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x3xf32>>
-  %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x3xf32>>
-  %workgroup_id_x = hal.interface.workgroup.id[0] : index
-  %workgroup_count_x = hal.interface.workgroup.count[0] : index
-  %workgroup_id_y = hal.interface.workgroup.id[1] : index
-  %workgroup_count_y = hal.interface.workgroup.count[1] : index
-  %2 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
-  %3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
-  scf.for %arg0 = %2 to %c2 step %3 {
-    %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
-    %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
-    scf.for %arg1 = %4 to %c3 step %5 {
-      %6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x3xf32>> -> tensor<2x3xf32>
-      %7 = tensor.empty() : tensor<2x3xf32>
-      %8 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%6 : tensor<2x3xf32>) outs(%7 : tensor<2x3xf32>) : tensor<2x3xf32>
-      %9 = affine.apply affine_map<()[s0] -> (-s0)>()[%arg0]
-      flow.dispatch.tensor.store %8, %1, offsets = [%9, %arg1], sizes = [2, 3], strides = [%c1, %c1] : tensor<2x3xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x3xf32>>
-    }
-  }
-  return
-}
-//      CHECK: func.func @linalg_ext_reverse_dim0()
-//  CHECK-DAG:   %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
-//  CHECK-DAG:   %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
-//      CHECK:   scf.for %[[IV0:.+]] =
-//      CHECK:     scf.for %[[IV1:.+]] =
-//  CHECK-DAG:       %[[IN_TILE:.+]] = flow.dispatch.tensor.load %[[IN]]
-//  CHECK-DAG:       %[[OUT_TILE:.+]] = flow.dispatch.tensor.load %[[OUT]]
-//      CHECK:       %[[REV_TILE:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME:           ins(%[[IN_TILE]] : tensor<2x3xf32>)
-// CHECK-SAME:           outs(%[[OUT_TILE]] : tensor<2x3xf32>)
-//      CHECK:       flow.dispatch.tensor.store %[[REV_TILE]], %[[OUT]]
-
-// -----
-
 func.func @sort1D() {
   %c0 = arith.constant 0 : index
   %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<4xi32>>
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
index 911a44d..f360aad 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir
@@ -2170,30 +2170,6 @@
 
 // -----
 
-// CHECK-LABEL: func.func @reverse_dim(
-//   CHECK-DAG:   %[[alloc:.*]] = memref.alloc()
-//   CHECK-DAG:   %[[cst:.*]] = bufferization.to_memref
-//       CHECK:   iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>)
-//  CHECK-SAME:       ins(%[[cst]] :
-//  CHECK-SAME:       outs(%[[alloc]] :
-//       CHECK:   %[[load:.*]] = memref.load %[[alloc]]
-//       CHECK:   return %[[load]]
-func.func @reverse_dim(%pos: index) -> f32 {
-  %input = arith.constant dense<[[1.0, 2.0, 3.0],
-                                 [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
-
-  %init = bufferization.alloc_tensor() : tensor<2x3xf32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<0> : tensor<1xi64>)
-         ins(%input : tensor<2x3xf32>)
-         outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>
-
-  %1 = tensor.extract %0[%pos, %pos] : tensor<2x3xf32>
-  return %1 : f32
-}
-
-// -----
-
 // CHECK-LABEL: func.func @fft_tensor(
 //       CHECK:   memref.alloc
 //       CHECK:   memref.alloc
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 0422f16..8e52db0 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -350,9 +350,9 @@
 
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    // TODO: Revisit this for Scatter/ReverseOp. We can then get rid of
+    // TODO: Revisit this for ScatterOp. We can then get rid of
     //       `bufferizesToMemoryRead` completely.
-    return !isa<IREE::LinalgExt::ScatterOp, IREE::LinalgExt::ReverseOp>(op);
+    return !isa<IREE::LinalgExt::ScatterOp>(op);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -630,8 +630,6 @@
         LinalgExtOpInterface<IREE::LinalgExt::PackOp>>(*ctx);
     IREE::LinalgExt::UnPackOp::attachInterface<
         LinalgExtOpInterface<IREE::LinalgExt::UnPackOp>>(*ctx);
-    IREE::LinalgExt::ReverseOp::attachInterface<
-        LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
     IREE::LinalgExt::ScanOp::attachInterface<
         LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
     IREE::LinalgExt::ScatterOp::attachInterface<
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
index cd289a5..b695e1e 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
@@ -241,8 +241,6 @@
         OuterParallelAsPartitionableLoops<IREE::LinalgExt::ScatterOp>>(*ctx);
     IREE::LinalgExt::SortOp::attachInterface<
         AllParallelAsPartitionableLoops<IREE::LinalgExt::SortOp>>(*ctx);
-    IREE::LinalgExt::ReverseOp::attachInterface<
-        OuterParallelAsPartitionableLoops<IREE::LinalgExt::ReverseOp>>(*ctx);
     IREE::LinalgExt::TopkOp::attachInterface<
         AllParallelAsPartitionableLoops<IREE::LinalgExt::TopkOp>>(*ctx);
     IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
index 6621451..4cca860 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/dispatch_linalg_ext_fusion.mlir
@@ -46,41 +46,6 @@
 // CHECK:         %[[GEN2:.+]] = linalg.generic
 // CHECK-SAME:      ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)
 
-
-
-// -----
-
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
-  %0 = tensor.empty() : tensor<10x10xi64>
-  %1 = tensor.empty() : tensor<10x10xi32>
-  %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) {
-  ^bb0(%in: i64, %out: i32):
-    %7 = arith.trunci %in : i64 to i32
-    linalg.yield %7 : i32
-  } -> tensor<10x10xi32>
-  %3 = tensor.empty() : tensor<10x10xi32>
-  %4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32>
-
-  // dont fuse with with reverse's consumer
-  %5 = tensor.empty() : tensor<10x10xi32>
-  %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) {
-  ^bb0(%in: i32, %out: i32):
-    %7 = arith.addi %in, %out : i32
-    linalg.yield %7 : i32
-  } -> tensor<10x10xi32>
-  util.return %6 : tensor<10x10xi32>
-}
-
-// CHECK:     util.func public @linalgext_reverse_fusion
-// CHECK:       flow.dispatch.workgroups
-// CHECK:       %[[SHRUNK:.+]] = linalg.generic
-// CHECK:       %[[REVERSED:.+]] = iree_linalg_ext.reverse
-// CHECK:         ins(%[[SHRUNK]] : tensor<10x10xi32>)
-// CHECK:       flow.dispatch.workgroups
-// CHECK:       %[[GEN:.+]] = linalg.generic
-
 // -----
 
 #map = affine_map<(d0, d1) -> (d0, d1)>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index ce40d00..4df93b3 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -440,71 +440,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// ReverseOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult ReverseOp::verify() {
-  Operation *op = getOperation();
-  if (getNumDpsInputs() != 1) {
-    return op->emitOpError("expected exactly one input");
-  }
-  if (getNumDpsInits() != 1) {
-    return op->emitOpError("expected exactly one output");
-  }
-  auto inputType = cast<ShapedType>(getInput().getType());
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  if (inputType.getElementType() != outputType.getElementType()) {
-    return op->emitOpError(
-        "expected input/output element types to be identical");
-  }
-  ArrayRef<int64_t> inputShapes = inputType.getShape();
-  ArrayRef<int64_t> outputShapes = outputType.getShape();
-  if (inputShapes.size() != outputShapes.size()) {
-    return op->emitOpError("expexted input/output to have identical ranks");
-  }
-  if (llvm::any_of(llvm::zip_equal(inputShapes, outputShapes),
-                   [](std::tuple<int64_t, int64_t> s) {
-                     return !ShapedType::isDynamic(std::get<0>(s)) &&
-                            !ShapedType::isDynamic(std::get<1>(s)) &&
-                            std::get<0>(s) != std::get<1>(s);
-                   })) {
-    return op->emitOpError("incompatible input/output shapes");
-  }
-
-  int64_t rank = getOperandRank();
-  llvm::SmallSetVector<int64_t, 4> s;
-  for (auto dim : getDimensionsArray()) {
-    if (dim < 0 || dim >= rank) {
-      return op->emitOpError("all the dimensions must be within [0, ")
-             << rank << ")";
-    }
-    if (s.contains(dim)) {
-      return op->emitOpError("expected dimensions numbers are all unique");
-    }
-    s.insert(dim);
-  }
-
-  return success();
-}
-
-LogicalResult
-ReverseOp::reifyResultShapes(OpBuilder &b,
-                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  return cast<LinalgExtOp>(getOperation())
-      .reifyResultShapes(b, reifiedReturnShapes);
-}
-
-SmallVector<AffineMap> ReverseOp::getIndexingMapsForOperands() {
-  Builder builder(getContext());
-  return {builder.getMultiDimIdentityMap(getOperandRank()),
-          /*output=*/AffineMap(nullptr)};
-}
-
-SmallVector<AffineMap> ReverseOp::getIndexingMapsForResults() {
-  return {AffineMap(nullptr)};
-}
-
-//===----------------------------------------------------------------------===//
 // TopkOp
 //===----------------------------------------------------------------------===//
 
@@ -1583,7 +1518,6 @@
 DEFINE_OP_GET_EFFECTS(ScatterOp)
 DEFINE_OP_GET_EFFECTS(SortOp)
 DEFINE_OP_GET_EFFECTS(FftOp)
-DEFINE_OP_GET_EFFECTS(ReverseOp)
 DEFINE_OP_GET_EFFECTS(ScanOp)
 DEFINE_OP_GET_EFFECTS(TopkOp)
 DEFINE_OP_GET_EFFECTS(PackOp)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 993ab3b..fe8693a 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -369,67 +369,6 @@
   }];
 }
 
-def IREELinalgExt_ReverseOp : IREELinalgExt_Op<"reverse", [
-  DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
-  DeclareOpInterfaceMethods<LinalgFusionInterface>,
-  DeclareOpInterfaceMethods<
-      TilingInterface,
-      ["generateScalarImplementation",
-       "getIterationDomain",
-       "getLoopIteratorTypes",
-       "getResultTilePosition",
-       "getTiledImplementation"]>,
-  DeclareOpInterfaceMethods<LinalgExtInterface>]> {
-  let summary = "Reverse operator";
-  let description = [{
-    A temporary solution for lowering reverse ops into IREE, allowing IREE to
-    tile and distribute them.
-    }
-  }];
-
-  let arguments = (ins Variadic<AnyShaped>:$inputs,
-                       Variadic<AnyShaped>:$outputs,
-                       I64ElementsAttr:$dimensions
-  );
-  let results = (outs Variadic<AnyRankedTensor>:$results);
-  let assemblyFormat = [{
-    attr-dict `dimensions` `(` $dimensions `)`
-    (`ins` `(` $inputs^ `:` type($inputs) `)`)?
-    (`outs` `(` $outputs^ `:` type($outputs) `)`)?
-    (`:` type($results)^)?
-  }];
-  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
-    Value getInput() {
-      return getDpsInputOperand(0)->get();
-    }
-    Value getOutput() {
-      return getDpsInitOperand(0)->get();
-    }
-    ShapedType getOperandType() {
-      return cast<ShapedType>(getInput().getType());
-    }
-    int64_t getOperandRank() {
-      return getOperandType().getRank();
-    }
-    ArrayRef<int64_t> getOprerandShape() {
-      return getOperandType().getShape();
-    }
-    SmallVector<int64_t> getDimensionsArray() {
-      SmallVector<int64_t> ret;
-      for (const APInt& elem : getDimensions()) {
-        ret.push_back(elem.getLimitedValue());
-      }
-      return ret;
-    }
-
-    // Method to implement for specifying output range for
-    // DestinationStyleOpInterface
-    MutableOperandRange getDpsInitsMutable() {
-      return getOutputsMutable();
-    }
-  }];
-}
-
 def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
   DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
   DeclareOpInterfaceMethods<LinalgExtInterface>,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir
index a2f89e8..87ee429 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/canonicalize.mlir
@@ -1,27 +1,5 @@
 // RUN: iree-opt --canonicalize --split-input-file %s | FileCheck %s
 
-func.func @tensor_cast(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
-  %init = tensor.empty() : tensor<3x5xi32>
-
-  %casted_arg0 = tensor.cast %arg0 : tensor<3x5xi32> to tensor<?x?xi32>
-  %casted_init = tensor.cast %init : tensor<3x5xi32> to tensor<?x?xi32>
-
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<0> : tensor<1xi64>)
-         ins(%casted_arg0 : tensor<?x?xi32>)
-         outs(%casted_init : tensor<?x?xi32>) : tensor<?x?xi32>
-
-  %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<3x5xi32>
-
-  return %1: tensor<3x5xi32>
-}
-// CHECK-LABEL: func.func @tensor_cast(
-//       CHECK:      iree_linalg_ext.reverse
-//  CHECK-SAME:   ins(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>)
-//  CHECK-SAME:  outs(%{{[a-zA-Z0-9]*}} : tensor<3x5xi32>)
-
-// -----
-
 func.func @pack_canonicalize(%arg0 : tensor<?x?xi32>,
     %arg1 : tensor<1x2x3x3xi32>) -> tensor<1x?x3x3xi32> {
   %c0_i32 = arith.constant 0 : i32
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 3363f1b..1d0280b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -338,42 +338,6 @@
 
 // -----
 
-func.func @reverse_diff_element_type(%arg0: tensor<3x5xi32>) -> tensor<3x5xf32> {
-  %init = tensor.empty() : tensor<3x5xf32>
-  // expected-error @+1 {{expected input/output element types to be identical}}
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<0> : tensor<1xi64>)
-         ins(%arg0 : tensor<3x5xi32>)
-         outs(%init : tensor<3x5xf32>) : tensor<3x5xf32>
-  return %0 : tensor<3x5xf32>
-}
-
-// -----
-
-func.func @reverse_diff_shape(%arg0: tensor<3x5xi32>) -> tensor<3x6xi32> {
-  %init = tensor.empty() : tensor<3x6xi32>
-  // expected-error @+1 {{incompatible input/output shapes}}
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<0> : tensor<1xi64>)
-         ins(%arg0 : tensor<3x5xi32>)
-         outs(%init : tensor<3x6xi32>) : tensor<3x6xi32>
-  return %0 : tensor<3x6xi32>
-}
-
-// -----
-
-func.func @reverse_dup_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
-  %init = tensor.empty() : tensor<3x5xi32>
-  // expected-error @+1 {{expected dimensions numbers are all unique}}
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<[0, 0]> : tensor<2xi64>)
-         ins(%arg0 : tensor<3x5xi32>)
-         outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
-  return %0 : tensor<3x5xi32>
-}
-
-// -----
-
 func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) {
   // expected-error@+1 {{expected one or two input operands}}
   %0:2 = iree_linalg_ext.topk
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index 88a1f2d..eddb2c7 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -484,111 +484,6 @@
 
 // -----
 
-func.func @reverse_tensor(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
-  %init = tensor.empty() : tensor<3x5xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<0> : tensor<1xi64>)
-         ins(%arg0 : tensor<3x5xi32>)
-         outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
-  return %0 : tensor<3x5xi32>
-}
-// CHECK-LABEL: func.func @reverse_tensor
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
-//       CHECK:   %[[INIT:.+]] = tensor.empty()
-//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.reverse
-//  CHECK-SAME:      dimensions(dense<0> : tensor<1xi64>)
-//  CHECK-SAME:      ins(%[[ARG0]]
-//  CHECK-SAME:      outs(%[[INIT]]
-
-// -----
-
-func.func @reverse_memref(%arg0: memref<3x5xi32>, %arg1: memref<3x5xi32>) {
-  iree_linalg_ext.reverse
-    dimensions(dense<0> : tensor<1xi64>)
-    ins(%arg0 : memref<3x5xi32>)
-    outs(%arg1 : memref<3x5xi32>)
-  return
-}
-// CHECK-LABEL: func.func @reverse_memref
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<3x5xi32>
-//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: memref<3x5xi32>
-//       CHECK:   iree_linalg_ext.reverse
-//  CHECK-SAME:      dimensions(dense<0> : tensor<1xi64>)
-//  CHECK-SAME:      ins(%[[ARG0]]
-//  CHECK-SAME:      outs(%[[ARG1]]
-
-// -----
-
-func.func @reverse_dynamic_tensor(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
-  %init = tensor.empty(%d0, %d1) : tensor<?x?xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<1> : tensor<1xi64>)
-         ins(%arg0 : tensor<?x?xi32>)
-         outs(%init : tensor<?x?xi32>) : tensor<?x?xi32>
-  return %0 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func.func @reverse_dynamic_tensor
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xi32>
-//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-//       CHECK:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
-//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.reverse
-//  CHECK-SAME:      dimensions(dense<1> : tensor<1xi64>)
-//  CHECK-SAME:      ins(%[[ARG0]]
-//  CHECK-SAME:      outs(%[[INIT]]
-
-// -----
-
-func.func @reverse_static_dynamic_tensor(%arg0: tensor<3x5xi32>) -> tensor<?x?xi32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<3x5xi32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<3x5xi32>
-  %init = tensor.empty(%d0, %d1) : tensor<?x?xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<1> : tensor<1xi64>)
-         ins(%arg0 : tensor<3x5xi32>)
-         outs(%init : tensor<?x?xi32>) : tensor<?x?xi32>
-  return %0 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func.func @reverse_static_dynamic_tensor
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
-//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-//       CHECK:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
-//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.reverse
-//  CHECK-SAME:      dimensions(dense<1> : tensor<1xi64>)
-//  CHECK-SAME:      ins(%[[ARG0]]
-//  CHECK-SAME:      outs(%[[INIT]]
-
-// -----
-
-func.func @reverse_multi_dims(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
-  %init = tensor.empty() : tensor<3x5xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<[0, 1]> : tensor<2xi64>)
-         ins(%arg0 : tensor<3x5xi32>)
-         outs(%init : tensor<3x5xi32>) : tensor<3x5xi32>
-  return %0 : tensor<3x5xi32>
-}
-// CHECK-LABEL: func.func @reverse_multi_dims
-//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x5xi32>
-//       CHECK:   %[[INIT:.+]] = tensor.empty()
-//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.reverse
-//  CHECK-SAME:      dimensions(dense<[0, 1]> : tensor<2xi64>)
-//  CHECK-SAME:      ins(%[[ARG0]]
-//  CHECK-SAME:      outs(%[[INIT]]
-
-// -----
-
 func.func @topk_tensor(%input_values: tensor<20x10x8x4xf32>, %input_indices: tensor<20x10x8x4xi32>) -> (tensor<20x10x3x4xf32>, tensor<20x10x3x4xi32>) {
   %out_values = tensor.empty() : tensor<20x10x3x4xf32>
   %out_indices = tensor.empty() : tensor<20x10x3x4xi32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
index 231f694..2565df8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/TilingInterfaceImpl.cpp
@@ -730,101 +730,6 @@
 }
 
 //===----------------------------------------------------------------------===//
-// ReverseOp
-//===----------------------------------------------------------------------===//
-
-SmallVector<utils::IteratorType> ReverseOp::getLoopIteratorTypes() {
-  SmallVector<utils::IteratorType> iteratorTypes(getOperandRank(),
-                                                 utils::IteratorType::parallel);
-  return iteratorTypes;
-}
-
-SmallVector<Range> ReverseOp::getIterationDomain(OpBuilder &builder) {
-  Location loc = getLoc();
-  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
-  SmallVector<Range> ranges;
-  for (auto dim : llvm::seq<int64_t>(0, getOperandRank())) {
-    Value ub = getDimValue(builder, loc, getInput(), dim);
-    ranges.emplace_back(Range{zero, ub, one});
-  }
-  return ranges;
-}
-
-LogicalResult ReverseOp::generateScalarImplementation(OpBuilder &b,
-                                                      Location loc,
-                                                      ValueRange ivs) {
-  SmallVector<Value> mirrorIndices(ivs.begin(), ivs.end());
-  for (auto dim : getDimensionsArray()) {
-    auto size = getDimValue(b, loc, getInput(), dim);
-    size = b.create<arith::SubIOp>(loc, size,
-                                   b.create<arith::ConstantIndexOp>(loc, 1));
-    mirrorIndices[dim] = b.create<arith::SubIOp>(loc, size, mirrorIndices[dim]);
-  }
-  Value val = b.create<memref::LoadOp>(loc, getInput(), ivs);
-  b.create<memref::StoreOp>(loc, val, getOutput(), mirrorIndices);
-  return success();
-}
-
-FailureOr<TilingResult>
-ReverseOp::getTiledImplementation(OpBuilder &builder,
-                                  ArrayRef<OpFoldResult> offsets,
-                                  ArrayRef<OpFoldResult> sizes) {
-  int64_t rank = getOperandRank();
-  SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
-  Location loc = getLoc();
-  SmallVector<OpFoldResult> mirrorOffsets, mirrorSizes;
-  if (failed(getResultTilePosition(builder, 0, offsets, sizes, mirrorOffsets,
-                                   mirrorSizes))) {
-    return {};
-  }
-
-  SmallVector<Value> tiledOperands;
-  tiledOperands.emplace_back(
-      getSlice(builder, loc, getInput(), offsets, sizes, strides));
-
-  SmallVector<Type, 4> resultTypes;
-  if (hasPureTensorSemantics()) {
-    tiledOperands.emplace_back(
-        getSlice(builder, loc, getOutput(), mirrorOffsets, sizes, strides));
-    resultTypes.push_back(tiledOperands[1].getType());
-  } else {
-    tiledOperands.emplace_back(
-        getSlice(builder, loc, getOutput(), mirrorOffsets, sizes, strides));
-  }
-
-  Operation *tiledRevOp =
-      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-
-  return TilingResult{{tiledRevOp},
-                      SmallVector<Value>(tiledRevOp->getResults())};
-}
-
-LogicalResult ReverseOp::getResultTilePosition(
-    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
-  AffineExpr sym0, sym1, sym2;
-  bindSymbols(builder.getContext(), sym0, sym1, sym2);
-  AffineMap map =
-      AffineMap::get(/*dimCount=*/0, /*symbolCount=*/3, {sym0 - sym1 - sym2});
-  resultOffsets.assign(offsets.begin(), offsets.end());
-  Location loc = getLoc();
-  for (auto dim : getDimensionsArray()) {
-    Value size = getDimValue(builder, loc, getInput(), dim);
-    Value offset =
-        getValueOrCreateConstantIndexOp(builder, loc, resultOffsets[dim]);
-    Value tileSize = getValueOrCreateConstantIndexOp(builder, loc, sizes[dim]);
-    resultOffsets[dim] = builder
-                             .create<affine::AffineApplyOp>(
-                                 loc, map, ValueRange{size, offset, tileSize})
-                             .getResult();
-  }
-  resultSizes.assign(sizes.begin(), sizes.end());
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
 // TopkOp
 //===----------------------------------------------------------------------===//
 
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
index 3623426..f136ab5 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir
@@ -507,28 +507,6 @@
 
 // -----
 
-func.func @reverse_dim_0(%arg0: memref<?x?xi32>, %arg1: memref<?x?xi32>) {
-  iree_linalg_ext.reverse
-    dimensions(dense<0> : tensor<1xi64>)
-    ins(%arg0 : memref<?x?xi32>)
-    outs(%arg1 : memref<?x?xi32>)
-  return
-}
-// CHECK-LABEL: func.func @reverse_dim_0
-// CHECK-SAME:    %[[IN:[a-zA-Z0-9]+]]
-// CHECK-SAME:    %[[OUT:[a-zA-Z0-9]+]]
-// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:     %[[D0:.+]] = memref.dim %arg0, %c0 : memref<?x?xi32>
-// CHECK-DAG:     %[[D1:.+]] = memref.dim %arg0, %c1 : memref<?x?xi32>
-// CHECK:         scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C1]]
-// CHECK:           scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C1]]
-// CHECK:             %[[T0:.+]] = memref.dim %[[IN]], %[[C0]]
-// CHECK:             %[[T1:.+]] = arith.subi %[[T0]], %[[C1]] : index
-// CHECK:             %[[T2:.+]] = arith.subi %[[T1]], %[[I]] : index
-// CHECK:             %[[V0:.+]] = memref.load %[[IN]][%[[I]], %[[J]]]
-// CHECK:             memref.store %[[V0]], %[[OUT]][%[[T2]], %[[J]]] : memref<?x?xi32>
-
 func.func @scan_1d_inclusive(%0: memref<128xi32>, %1: memref<128xi32>) {
   %c0 = memref.alloc() : memref<i32>
   iree_linalg_ext.scan dimension(0) inclusive(true)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
index 42e6193..67a189b 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir
@@ -494,95 +494,6 @@
 
 // -----
 
-func.func @reverse_memref(%arg0: memref<?xi32>, %arg1: memref<?xi32>) {
-  iree_linalg_ext.reverse
-    dimensions(dense<0> : tensor<1xi64>)
-    ins(%arg0: memref<?xi32>)
-    outs(%arg1: memref<?xi32>)
-  return
-}
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["iree_linalg_ext.reverse"]} in %module_op : (!transform.any_op) -> !transform.any_op
-    %1, %loops = transform.structured.tile_using_for %0 tile_sizes [10] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
-// CHECK:      func.func @reverse_memref(
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG:    %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]] : memref<?xi32>
-// CHECK:        scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]] {
-// CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]]
-// CHECK-DAG:      %[[IDX:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE]]]
-// CHECK-DAG:      %[[SUB_IN:.+]] =  memref.subview %[[ARG0]][%[[I]]] [%[[SIZE]]] [1]
-// CHECK-DAG:      %[[SUB_OUT:.+]] = memref.subview %[[ARG1]][%[[IDX]]] [%[[SIZE]]] [1]
-// CHECK:          iree_linalg_ext.reverse
-// CHECK-SAME:       dimensions(dense<0> : tensor<1xi64>)
-// CHECK-SAME:       ins(%[[SUB_IN]]
-// CHECK-SAME:       outs(%[[SUB_OUT]]
-
-// -----
-
-func.func @reverse_tensor_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
-  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
-  %init = tensor.empty(%d0, %d1) : tensor<?x?xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<[0, 1]> : tensor<2xi64>)
-         ins(%arg0: tensor<?x?xi32>)
-         outs(%init: tensor<?x?xi32>) : tensor<?x?xi32>
-  return %0 : tensor<?x?xi32>
-}
-module attributes { transform.with_named_sequence } {
-  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["iree_linalg_ext.reverse"]} in %module_op : (!transform.any_op) -> !transform.any_op
-    %1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [10, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-    transform.yield
-  }
-}
-// CHECK-DAG:  #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
-// CHECK-DAG:  #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
-// CHECK-DAG:  #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s0 - s1 - s2)>
-// CHECK:      func.func @reverse_tensor_multi_dim(
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
-// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:    %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG:    %[[C20:.+]] = arith.constant 20 : index
-// CHECK-DAG:    %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xi32>
-// CHECK-DAG:    %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xi32>
-// CHECK:        %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
-// CHECK:        %[[RES:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
-// CHECK-SAME:     iter_args(%[[INIT2:.+]] = %[[INIT]]) -> (tensor<?x?xi32>) {
-// CHECK:          %[[RES2:.+]] = scf.for %[[J:.+]] = %[[C0]] to %[[D1]] step %[[C20]]
-// CHECK-SAME:       iter_args(%[[INIT3:.+]] = %[[INIT2]]) -> (tensor<?x?xi32>) {
-// CHECK-DAG:        %[[SIZE_I:.+]] = affine.min #[[MAP0]](%[[I]])[%[[D0]]]
-// CHECK-DAG:        %[[SIZE_J:.+]] = affine.min #[[MAP1]](%[[J]])[%[[D1]]]
-// CHECK-DAG:        %[[IDX0:.+]] = affine.apply #[[MAP2]]()[%[[D0]], %[[I]], %[[SIZE_I]]]
-// CHECK-DAG:        %[[IDX1:.+]] = affine.apply #[[MAP2]]()[%[[D1]], %[[J]], %[[SIZE_J]]]
-// CHECK:            %[[SUB_IN:.+]] = tensor.extract_slice
-// CHECK-SAME:         %[[ARG0]][%[[I]], %[[J]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
-// CHECK:            %[[SUB_INIT:.+]] = tensor.extract_slice
-// CHECK-SAME:         %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
-// CHECK:            %[[REV:.+]] = iree_linalg_ext.reverse
-// CHECK-SAME:          dimensions(dense<[0, 1]> : tensor<2xi64>)
-// CHECK-SAME:          ins(%[[SUB_IN]]
-// CHECK-SAME:          outs(%[[SUB_INIT]]
-// CHECK:            %[[RES3:.+]] = tensor.insert_slice %[[REV]] into
-// CHECK-SAME:         %[[INIT3]][%[[IDX0]], %[[IDX1]]] [%[[SIZE_I]], %[[SIZE_J]]] [1, 1]
-// CHECK:            scf.yield %[[RES3]]
-// CHECK:          scf.yield %[[RES2]]
-// CHECK:        return %[[RES]]
-
-// -----
-
 func.func @scan_1d(%0: tensor<128xi32>) -> tensor<128xi32> {
   %c0 = tensor.empty() : tensor<i32>
   %1 = tensor.empty() : tensor<128xi32>
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index 0a9a1b5..681b336 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -313,8 +313,6 @@
         LinalgOpTiedOpInterface<IREE::LinalgExt::FftOp>>(*context);
     IREE::LinalgExt::ScanOp::attachInterface<
         LinalgOpTiedOpInterface<IREE::LinalgExt::ScanOp>>(*context);
-    IREE::LinalgExt::ReverseOp::attachInterface<
-        LinalgOpTiedOpInterface<IREE::LinalgExt::ReverseOp>>(*context);
     IREE::LinalgExt::TopkOp::attachInterface<
         LinalgOpTiedOpInterface<IREE::LinalgExt::TopkOp>>(*context);
     IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel
index 8c6f3dd..f655d9b 100644
--- a/tests/e2e/linalg_ext_ops/BUILD.bazel
+++ b/tests/e2e/linalg_ext_ops/BUILD.bazel
@@ -16,7 +16,6 @@
     # keep sorted
     [
         "attention.mlir",
-        "reverse.mlir",
         "scan.mlir",
         "scatter.mlir",
         "sort.mlir",
@@ -42,7 +41,6 @@
 VMVX_SRCS = enforce_glob(
     # keep sorted
     [
-        "reverse.mlir",
         "scan.mlir",
         "scatter.mlir",
         "sort.mlir",
@@ -66,7 +64,6 @@
 LLVM_GPU_SRCS = enforce_glob(
     # keep sorted
     [
-        "reverse.mlir",
         "scan.mlir",
         "scatter.mlir",
         "sort.mlir",
@@ -107,7 +104,6 @@
     srcs = enforce_glob(
         # keep sorted
         [
-            "reverse.mlir",
             "scan.mlir",
             "scatter.mlir",
             "sort.mlir",
@@ -138,7 +134,6 @@
         include = ["*.mlir"],
         exclude = [
             "attention.mlir",
-            "reverse.mlir",  #TODO(#12415): disabled due to miscompilation on Pixel 6.
             "top-k.mlir",
         ],
     ),
diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt
index 1e84c09..97dc732 100644
--- a/tests/e2e/linalg_ext_ops/CMakeLists.txt
+++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt
@@ -15,7 +15,6 @@
     check_llvm-cpu_local-task
   SRCS
     "attention.mlir"
-    "reverse.mlir"
     "scan.mlir"
     "scatter.mlir"
     "sort.mlir"
@@ -34,7 +33,6 @@
   NAME
     check_vmvx_local-task
   SRCS
-    "reverse.mlir"
     "scan.mlir"
     "scatter.mlir"
     "sort.mlir"
@@ -51,7 +49,6 @@
   NAME
     check_cuda
   SRCS
-    "reverse.mlir"
     "scan.mlir"
     "scatter.mlir"
     "sort.mlir"
@@ -74,7 +71,6 @@
   NAME
     check_rocm_hip
   SRCS
-    "reverse.mlir"
     "scan.mlir"
     "scatter.mlir"
     "sort.mlir"
@@ -91,7 +87,6 @@
   NAME
     check_metal-spirv_vulkan
   SRCS
-    "reverse.mlir"
     "scan.mlir"
     "scatter.mlir"
     "sort.mlir"
diff --git a/tests/e2e/linalg_ext_ops/reverse.mlir b/tests/e2e/linalg_ext_ops/reverse.mlir
deleted file mode 100644
index db1610b..0000000
--- a/tests/e2e/linalg_ext_ops/reverse.mlir
+++ /dev/null
@@ -1,53 +0,0 @@
-func.func @reverse_dim0() {
-  %input = util.unfoldable_constant dense<[[1.0, 2.0, 3.0],
-                                           [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
-
-  %init = tensor.empty() : tensor<2x3xf32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<0> : tensor<1xi64>)
-         ins(%input : tensor<2x3xf32>)
-         outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>
-
-  check.expect_almost_eq_const(
-      %0,
-      dense<[[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32>
-  ) : tensor<2x3xf32>
-
-  return
-}
-
-func.func @reverse_dim1() {
-  %input = util.unfoldable_constant dense<[[1, 2, 3],
-                                           [4, 5, 6]]> : tensor<2x3xi32>
-
-  %init = tensor.empty() : tensor<2x3xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<1> : tensor<1xi64>)
-         ins(%input : tensor<2x3xi32>)
-         outs(%init : tensor<2x3xi32>) : tensor<2x3xi32>
-
-  check.expect_eq_const(
-      %0,
-      dense<[[3, 2, 1], [6, 5, 4]]> : tensor<2x3xi32>
-  ) : tensor<2x3xi32>
-
-  return
-}
-
-func.func @reverse_multi_dims() {
-  %input = util.unfoldable_constant dense<[[1, 2, 3],
-                                           [4, 5, 6]]> : tensor<2x3xi32>
-
-  %init = tensor.empty() : tensor<2x3xi32>
-  %0 = iree_linalg_ext.reverse
-         dimensions(dense<[0, 1]> : tensor<2xi64>)
-         ins(%input : tensor<2x3xi32>)
-         outs(%init : tensor<2x3xi32>) : tensor<2x3xi32>
-
-  check.expect_eq_const(
-      %0,
-      dense<[[6, 5, 4], [3, 2, 1]]> : tensor<2x3xi32>
-  ) : tensor<2x3xi32>
-
-  return
-}