Lower stablehlo.custom_call @TopK to chlo.top_k (#13937)

TopK can sometimes appear as a stablehlo.custom_call after running through
the SPMD splitter. Pattern-match these calls during StableHLO preprocessing
and raise them to chlo.top_k, for which we have a fast lowering path. The
preprocessing pass now runs on the module so the pattern can safely resolve
the custom call's called computation.
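
For example (adapted from the lit tests added below), the pattern rewrites

  %0:2 = stablehlo.custom_call @TopK(%arg0) {called_computations = [@comparison]}
      : (tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>)

into

  %values, %indices = chlo.top_k(%arg0, k = 40)
      : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)

where k is read from the trailing result dimension and the called comparator
must reduce to a simple GT/GE compare.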
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
index 278287e..c63c8e9 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
@@ -68,8 +68,7 @@
   passManager.addNestedPass<func::FuncOp>(createTopLevelSCFToCFGPass());
   if (detuple) passManager.addPass(createFlattenTuplesInCFG());
 
-  passManager.addNestedPass<func::FuncOp>(
-      createStableHLOToStableHLOPreprocessing());
+  passManager.addPass(createStableHLOToStableHLOPreprocessing());
   passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
 
   // Various shape functions may have been materialized in the `shape.shape_of`
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
index 201f8f1..b52cb4b 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
@@ -10,7 +10,7 @@
 include "mlir/Pass/PassBase.td"
 
 def StableHLOToStableHLOPreprocessing :
-    Pass<"iree-stablehlo-to-stablehlo-preprocessing", "func::FuncOp"> {
+    Pass<"iree-stablehlo-to-stablehlo-preprocessing", "ModuleOp"> {
   let summary = "Applies IREE-specific stablehlo to stablehlo preprocessing transformations";
   let options = [
     Option<"orderConvFeatures", "order-conv-features", "bool", /*default=*/"true",
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 845bb9b..6dad3e3 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -1433,6 +1433,122 @@
   }
 };
 
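+// Matches XLA's TopK custom call and raises it to chlo.top_k when the
+// called comparator is a simple descending (GT/GE) compare.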
+struct CustomCallIsTopK final
+    : OpRewritePattern<mlir::stablehlo::CustomCallOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getCallTargetName() != "TopK") {
+      return rewriter.notifyMatchFailure(op, "not a TopK custom call");
+    }
+
+    if (op.getNumOperands() != 1 ||
+        !(op.getNumResults() == 1 || op.getNumResults() == 2)) {
+      return rewriter.notifyMatchFailure(
+          op, "incorrect number of operands / results");
+    }
+
+    ArrayAttr computations = op.getCalledComputations();
+    if (computations.size() != 1) {
+      return rewriter.notifyMatchFailure(op,
+                                         "incorrect number of computations");
+    }
+
+    SymbolRefAttr computation = dyn_cast<SymbolRefAttr>(computations[0]);
+    if (!computation) {
+      return rewriter.notifyMatchFailure(op, "not a ref attr");
+    }
+
+    auto operand = op.getOperand(0);
+    auto operandTy = cast<ShapedType>(operand.getType());
+    if (!operandTy.hasRank() || operandTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "rank-2 input not found");
+    }
+
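+    // Results are returned either as two values or as a single 2-tuple.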
+    ShapedType topVTy;
+    ShapedType topITy;
+    if (op.getNumResults() == 1) {
+      if (auto tupleTy = dyn_cast<TupleType>(op.getType(0))) {
+        if (tupleTy.size() != 2) {
+          return rewriter.notifyMatchFailure(
+              op, "tuple return does not contain exactly two values");
+        }
+        topVTy = dyn_cast<ShapedType>(tupleTy.getType(0));
+        topITy = dyn_cast<ShapedType>(tupleTy.getType(1));
+      }
+    }
+
+    if (op.getNumResults() == 2) {
+      topVTy = dyn_cast<ShapedType>(op.getType(0));
+      topITy = dyn_cast<ShapedType>(op.getType(1));
+    }
+
+    if (!topVTy || !topITy) {
+      return rewriter.notifyMatchFailure(op, "unknown return type behavior");
+    }
+
+    int64_t k = topVTy.getDimSize(1);
+    if (k == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(op, "dynamic top-k k value");
+    }
+
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto funcOp =
+        dyn_cast_or_null<func::FuncOp>(moduleOp.lookupSymbol(computation));
+    if (!funcOp) {
+      return rewriter.notifyMatchFailure(op, "computation is not a func.func");
+    }
+
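+    // The comparator's first op must be a compare whose result is returned
+    // directly by the terminator.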
+    Block &block = funcOp.getRegion().front();
+    auto stablehloCompareOp =
+        dyn_cast<mlir::stablehlo::CompareOp>(block.front());
+    if (!stablehloCompareOp) {
+      return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
+    }
+
+    auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator());
+    if (!returnOp) {
+      return rewriter.notifyMatchFailure(op, "could not find ReturnOp");
+    }
+
+    if (returnOp.getNumOperands() != 1 ||
+        returnOp.getOperand(0) != stablehloCompareOp.getResult()) {
+      return rewriter.notifyMatchFailure(op, "ReturnOp operand not compare op");
+    }
+
+    auto direction = stablehloCompareOp.getComparisonDirection();
+    bool getTop = direction == mlir::stablehlo::ComparisonDirection::GT ||
+                  direction == mlir::stablehlo::ComparisonDirection::GE;
+
+    if (!getTop) {
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported comparison direction");
+    }
+
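+    // Raise to chlo.top_k, which has a fast lowering path in IREE.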
+    auto newTopK = rewriter.create<chlo::TopKOp>(
+        op.getLoc(), TypeRange{topVTy, topITy}, operand, k);
+
+    if (op.getNumResults() == 2) {
+      rewriter.replaceOp(op, newTopK.getResults());
+      return success();
+    }
+
+    if (auto tupleTy = dyn_cast<TupleType>(op.getType(0))) {
+      rewriter.replaceOpWithNewOp<mlir::stablehlo::TupleOp>(
+          op, op.getType(0), newTopK.getResults());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 struct StableHLOToStableHLOPreprocessing final
     : impl::StableHLOToStableHLOPreprocessingBase<
           StableHLOToStableHLOPreprocessing> {
@@ -1441,7 +1557,7 @@
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<shape::ShapeDialect, mlir::stablehlo::StablehloDialect,
-                    tensor::TensorDialect>();
+                    chlo::ChloDialect, tensor::TensorDialect>();
   }
 
   void runOnOperation() override {
@@ -1495,6 +1611,9 @@
         context,
         /*benefit=*/400);
 
+    // Identify known custom calls and convert them to equivalent StableHLO.
+    patterns.insert<CustomCallIsTopK>(context);
+
     // Additional canonicalizers that simplify to computationally
     // less-complex operations.
     patterns.insert<DotToMul>(context);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
index f9a8e30..90ffb6c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -356,3 +356,37 @@
   %2 = "stablehlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0, 2], rhs_batching_dimensions = [0, 2], lhs_contracting_dimensions = [3], rhs_contracting_dimensions = [3]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x1024x16x64xf32>, tensor<?x1024x16x64xf32>) -> tensor<?x16x1024x1024xf32>
   return %2 : tensor<?x16x1024x1024xf32>
 }
+
+// -----
+
+func.func @custom_call_topk_tuple(%arg0: tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>) {
+  %0 = stablehlo.custom_call @TopK(%arg0) {called_computations = [@comparison], xla_shape = "(bf16[4,40]{1,0}, s32[4,40]{1,0})"} : (tensor<4x8000xbf16>) -> tuple<tensor<4x40xbf16>, tensor<4x40xi32>>
+  %1 = stablehlo.get_tuple_element %0[0] : (tuple<tensor<4x40xbf16>, tensor<4x40xi32>>) -> tensor<4x40xbf16>
+  %2 = stablehlo.get_tuple_element %0[1] : (tuple<tensor<4x40xbf16>, tensor<4x40xi32>>) -> tensor<4x40xi32>
+  return %1, %2 : tensor<4x40xbf16>, tensor<4x40xi32>
+}
+func.func private @comparison(%arg0: tensor<bf16>, %arg1: tensor<bf16>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
+  %0 = stablehlo.compare  GT, %arg0, %arg1,  TOTALORDER : (tensor<bf16>, tensor<bf16>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: @custom_call_topk_tuple
+// CHECK-SAME: %[[ARG0:[a-z0-9]+]]
+// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
+// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>
+
+// -----
+
+func.func @custom_call_topk_returns(%arg0: tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>) {
+  %0:2 = stablehlo.custom_call @TopK(%arg0) {called_computations = [@comparison], xla_shape = "(bf16[4,40]{1,0}, s32[4,40]{1,0})"} : (tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>)
+  return %0#0, %0#1 : tensor<4x40xbf16>, tensor<4x40xi32>
+}
+func.func private @comparison(%arg0: tensor<bf16>, %arg1: tensor<bf16>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
+  %0 = stablehlo.compare  GT, %arg0, %arg1,  TOTALORDER : (tensor<bf16>, tensor<bf16>) -> tensor<i1>
+  return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: @custom_call_topk_returns
+// CHECK-SAME: %[[ARG0:[a-z0-9]+]]
+// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
+// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>