Add pattern to map iota->sort->slice to topK (#13972)

diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 6dad3e3..1569238 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -1540,6 +1540,164 @@
   }
 };
 
+// Rewrites `stablehlo.sort(data, iota)` whose two results are each consumed by
+// a leading `stablehlo.slice` into a single `chlo.top_k`.
+//
+// The match requires that:
+//   * exactly one sort operand is produced by `stablehlo.iota`; it enumerates
+//     the sort dimension with i32 elements, and the other operand is the data;
+//   * the sort runs along the last dimension of a ranked data tensor;
+//   * the comparator is only `compare GT/GE` of the two data block arguments,
+//     returned directly;
+//   * each sort result has a single user: a unit-stride slice starting at zero
+//     that keeps the leading dimensions whole and the first `k` elements of
+//     the sort dimension, with the same `k` for both results.
+struct IotaSortSliceIsTopK final : OpRewritePattern<mlir::stablehlo::SortOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::SortOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 2 || op->getNumResults() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort that maps to TopK must have exactly two inputs/outputs");
+    }
+
+    // One operand must come from an iota (the indices); the other is the data
+    // being ranked. The data position selects both the comparator block
+    // arguments and the sort result that carries the sorted values.
+    mlir::stablehlo::IotaOp iotaOp;
+    Value topKInput;
+    unsigned inputIdx = 0;
+    for (auto [idx, operand] : llvm::enumerate(op->getOperands())) {
+      if (auto iota = operand.getDefiningOp<mlir::stablehlo::IotaOp>()) {
+        iotaOp = iota;
+      } else {
+        topKInput = operand;
+        inputIdx = static_cast<unsigned>(idx);
+      }
+    }
+    if (!iotaOp) {
+      return rewriter.notifyMatchFailure(op, "Sort isn't called from Iota.");
+    }
+    if (!topKInput) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort must have exactly one non-iota input.");
+    }
+    const unsigned iotaIdx = 1 - inputIdx;
+
+    // chlo.top_k ranks along the last dimension, so the sort and the iota must
+    // both use it. The sort `dimension` may be negative; normalize it first.
+    auto inputType = llvm::dyn_cast<RankedTensorType>(topKInput.getType());
+    if (!inputType || inputType.getRank() < 1) {
+      return rewriter.notifyMatchFailure(op, "TopK input must be ranked.");
+    }
+    const int64_t rank = inputType.getRank();
+    const int64_t lastDim = rank - 1;
+    int64_t sortDim = static_cast<int64_t>(op.getDimension());
+    if (sortDim < 0) {
+      sortDim += rank;
+    }
+    if (sortDim != lastDim) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort must be along the last dimension to match to TopK.");
+    }
+    if (static_cast<int64_t>(iotaOp.getIotaDimension()) != lastDim) {
+      return rewriter.notifyMatchFailure(
+          op, "Iota must enumerate the sort dimension.");
+    }
+    auto iotaType = llvm::cast<ShapedType>(iotaOp.getResult().getType());
+    if (!iotaType.getElementType().isInteger(32)) {
+      return rewriter.notifyMatchFailure(op, "Iota must produce i32 indices.");
+    }
+
+    // The comparator must be exactly `return compare(lhs, rhs)` over the data
+    // block arguments; anything else is not a plain descending sort.
+    Block &body = op.getRegion().front();
+    auto compareOp = dyn_cast<mlir::stablehlo::CompareOp>(&body.front());
+    if (!compareOp) {
+      return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
+    }
+    Operation *terminator = body.getTerminator();
+    if (compareOp->getNextNode() != terminator ||
+        terminator->getNumOperands() != 1 ||
+        terminator->getOperand(0) != compareOp.getResult()) {
+      return rewriter.notifyMatchFailure(
+          op, "Comparator must only return the compare result.");
+    }
+    if (compareOp.getLhs() != body.getArgument(2 * inputIdx) ||
+        compareOp.getRhs() != body.getArgument(2 * inputIdx + 1)) {
+      return rewriter.notifyMatchFailure(
+          op, "Comparator must compare the TopK input elements in order.");
+    }
+
+    auto direction = compareOp.getComparisonDirection();
+    bool getTop = direction == mlir::stablehlo::ComparisonDirection::GT ||
+                  direction == mlir::stablehlo::ComparisonDirection::GE;
+    if (!getTop) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Unsupported comparison direction");
+    }
+
+    // Each sorted result must feed exactly one slice that keeps the first `k`
+    // entries along the sort dimension and everything else untouched.
+    mlir::stablehlo::SliceOp slices[2];
+    int64_t k = 0;
+    for (auto [idx, result] : llvm::enumerate(op->getResults())) {
+      // `hasOneUse` also rules out an unused result, for which there is no
+      // user to inspect.
+      if (!result.hasOneUse()) {
+        return rewriter.notifyMatchFailure(
+            op, "Sort results must each have exactly one user.");
+      }
+      auto sliceOp =
+          dyn_cast<mlir::stablehlo::SliceOp>(*result.getUsers().begin());
+      if (!sliceOp) {
+        return rewriter.notifyMatchFailure(
+            op, "Sort isn't calling into a slice op.");
+      }
+
+      for (auto stride : sliceOp.getStrides().getValues<int64_t>()) {
+        if (stride != 1) {
+          return rewriter.notifyMatchFailure(
+              op, "All slice strides must be 1 in order to match to TopK.");
+        }
+      }
+      for (auto start : sliceOp.getStartIndices().getValues<int64_t>()) {
+        if (start != 0) {
+          return rewriter.notifyMatchFailure(
+              op, "All slice starts must be 0 in order to match to TopK.");
+        }
+      }
+
+      auto limits = sliceOp.getLimitIndices().getValues<int64_t>();
+      for (int64_t dim = 0; dim < lastDim; ++dim) {
+        if (limits[dim] != inputType.getDimSize(dim)) {
+          return rewriter.notifyMatchFailure(
+              op, "Slice may only shrink the sort dimension.");
+        }
+      }
+      if (idx == 0) {
+        k = limits[lastDim];
+      } else if (limits[lastDim] != k) {
+        return rewriter.notifyMatchFailure(
+            op, "Both slices must keep the same number of elements.");
+      }
+      slices[idx] = sliceOp;
+    }
+
+    Value topV = slices[inputIdx].getResult();
+    Value topI = slices[iotaIdx].getResult();
+    auto topK = rewriter.create<chlo::TopKOp>(
+        op.getLoc(), TypeRange{topV.getType(), topI.getType()}, topKInput, k);
+    // Go through the rewriter so the driver is notified of every change. The
+    // slices were the only users of the sort, which is dead afterwards.
+    rewriter.replaceOp(slices[inputIdx], topK->getResult(0));
+    rewriter.replaceOp(slices[iotaIdx], topK->getResult(1));
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 struct StableHLOToStableHLOPreprocessing final
     : impl::StableHLOToStableHLOPreprocessingBase<
           StableHLOToStableHLOPreprocessing> {
@@ -1605,6 +1686,9 @@
     // Identify known custom calls and convert them to equivalent StableHLO.
     patterns.insert<CustomCallIsTopK>(context);
 
+    // Identify an iota->sort->slice pattern that maps to TopK.
+    patterns.insert<IotaSortSliceIsTopK>(context);
+
     // Additional canonicalizers that simplify to computationally
     // less-complex operations.
     patterns.insert<DotToMul>(context);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
index 90ffb6c..8d7c9d5 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -390,3 +390,22 @@
 // CHECK-SAME: %[[ARG0:[a-z0-9]+]]
 // CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
 // CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>
+
+// -----
+
+func.func @iota_sort_slice_is_topk(%in : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
+  %positions = "stablehlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<16x16xi32>
+  %sorted:2 = "stablehlo.sort"(%in, %positions) ({
+  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>, %lhs_pos: tensor<i32>, %rhs_pos: tensor<i32>):
+    %is_greater = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "stablehlo.return"(%is_greater) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
+  %top_values = "stablehlo.slice"(%sorted#0) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xf32>) -> tensor<16x8xf32>
+  %top_positions = "stablehlo.slice"(%sorted#1) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xi32>) -> tensor<16x8xi32>
+  return %top_values, %top_positions : tensor<16x8xf32>, tensor<16x8xi32>
+}
+
+// CHECK-LABEL: @iota_sort_slice_is_topk
+// CHECK-SAME: %[[IN:[a-z0-9]+]]
+// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[IN]], k = 8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
+// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<16x8xf32>, tensor<16x8xi32>