Add pattern to map iota->sort->slice to topK (#13972)
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 6dad3e3..1569238 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -1540,6 +1540,87 @@
}
};
+/// Folds `sort(values, iota)` followed by leading slices of both results
+/// into a single `chlo.top_k`.
+///
+/// The match is deliberately strict, because `chlo.top_k` ranks along the
+/// last dimension and reports i32 indices:
+///   * the sort has two operands and its comparator returns a GT/GE compare
+///     of one operand's block arguments, in order (descending sort);
+///   * the other operand is an i32 `stablehlo.iota` along the sorted
+///     dimension, which is the last dimension;
+///   * each sort result has exactly one use: a unit-stride slice starting at
+///     zero that keeps the leading dimensions whole and the same first `k`
+///     entries of the last dimension.
+/// On success the slices are replaced by the `chlo.top_k` results and the
+/// then-unused sort is erased.
+struct IotaSortSliceIsTopK final : OpRewritePattern<mlir::stablehlo::SortOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::SortOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 2 || op->getNumResults() != 2) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort that maps to TopK must have exactly two inputs/outputs");
+    }
+
+    // The comparator has to open with a compare whose result is returned.
+    Block &block = op.getRegion().front();
+    auto stablehloCompareOp =
+        dyn_cast<mlir::stablehlo::CompareOp>(block.front());
+    if (!stablehloCompareOp) {
+      return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
+    }
+    Operation *terminator = block.getTerminator();
+    if (terminator->getNumOperands() != 1 ||
+        terminator->getOperand(0) != stablehloCompareOp.getResult()) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort comparator must return the compare result");
+    }
+
+    auto direction = stablehloCompareOp.getComparisonDirection();
+    bool getTop = direction == mlir::stablehlo::ComparisonDirection::GT ||
+                  direction == mlir::stablehlo::ComparisonDirection::GE;
+    if (!getTop) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Unsupported comparison direction");
+    }
+
+    // Sort operand i feeds comparator arguments 2*i and 2*i+1; the operand
+    // whose arguments are compared in order holds the values to rank.
+    if (block.getNumArguments() != 4) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort comparator must take two arguments per operand");
+    }
+    Value lhs = stablehloCompareOp.getLhs();
+    Value rhs = stablehloCompareOp.getRhs();
+    unsigned valuesIdx;
+    if (lhs == block.getArgument(0) && rhs == block.getArgument(1)) {
+      valuesIdx = 0;
+    } else if (lhs == block.getArgument(2) && rhs == block.getArgument(3)) {
+      valuesIdx = 1;
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Comparator must compare one operand's arguments in order");
+    }
+    unsigned indicesIdx = 1 - valuesIdx;
+
+    // The operand that is carried along (not compared) must be the iota.
+    Value topKInput = op->getOperand(valuesIdx);
+    auto iotaOp = op->getOperand(indicesIdx)
+                      .getDefiningOp<mlir::stablehlo::IotaOp>();
+    if (!iotaOp) {
+      return rewriter.notifyMatchFailure(op, "Sort isn't called from Iota.");
+    }
+
+    // StableHLO requires all sort operands to share one shape, so the iota's
+    // type also describes the extents of the values operand.
+    auto iotaType = iotaOp.getType();
+    int64_t rank = iotaType.getRank();
+    int64_t sortDim = static_cast<int64_t>(op.getDimension());
+    if (sortDim < 0) {
+      sortDim += rank;
+    }
+    if (sortDim != rank - 1 ||
+        static_cast<int64_t>(iotaOp.getIotaDimension()) != sortDim) {
+      return rewriter.notifyMatchFailure(
+          op, "Sort and iota must both run along the last dimension");
+    }
+    if (!iotaType.getElementType().isSignlessInteger(32)) {
+      return rewriter.notifyMatchFailure(op, "TopK indices must be i32");
+    }
+
+    // Check that each output of the sort op is fed into exactly one slice
+    // that keeps the first k entries of the sorted dimension.
+    mlir::stablehlo::SliceOp slices[2];
+    int64_t k = 0;
+    for (auto [idx, result] : llvm::enumerate(op->getResults())) {
+      // Guards the dereference below and lets the sort be erased afterwards.
+      if (!result.hasOneUse()) {
+        return rewriter.notifyMatchFailure(
+            op, "Each sort result must have exactly one use");
+      }
+      auto sliceOp =
+          dyn_cast<mlir::stablehlo::SliceOp>(*result.getUsers().begin());
+      if (!sliceOp) {
+        return rewriter.notifyMatchFailure(
+            op, "Sort isn't calling into a slice op.");
+      }
+
+      auto starts = sliceOp.getStartIndices().getValues<int64_t>();
+      auto limits = sliceOp.getLimitIndices().getValues<int64_t>();
+      auto strides = sliceOp.getStrides().getValues<int64_t>();
+      for (int64_t dim = 0; dim < rank; ++dim) {
+        if (strides[dim] != 1) {
+          return rewriter.notifyMatchFailure(
+              op, "All slice strides must be 1 in order to match to TopK.");
+        }
+        if (starts[dim] != 0) {
+          return rewriter.notifyMatchFailure(
+              op, "Slices must start at 0 in order to match to TopK");
+        }
+        if (dim != sortDim && limits[dim] != iotaType.getDimSize(dim)) {
+          return rewriter.notifyMatchFailure(
+              op, "Slices may only shrink the sorted dimension");
+        }
+      }
+
+      // Both slices have to keep the same number of leading entries.
+      int64_t sliceK = limits[sortDim];
+      if (idx == 0) {
+        k = sliceK;
+      } else if (sliceK != k) {
+        return rewriter.notifyMatchFailure(
+            op, "Value and index slices must keep the same k");
+      }
+      slices[idx] = sliceOp;
+    }
+
+    // Result i of the sort corresponds to operand i.
+    mlir::stablehlo::SliceOp valuesSlice = slices[valuesIdx];
+    mlir::stablehlo::SliceOp indicesSlice = slices[indicesIdx];
+    Value topV = valuesSlice.getResult();
+    Value topI = indicesSlice.getResult();
+
+    auto topK = rewriter.create<chlo::TopKOp>(
+        op.getLoc(), TypeRange{topV.getType(), topI.getType()}, topKInput, k);
+    // Go through the rewriter so the driver is told about every IR change.
+    rewriter.replaceOp(valuesSlice, topK.getResults()[0]);
+    rewriter.replaceOp(indicesSlice, topK.getResults()[1]);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
struct StableHLOToStableHLOPreprocessing final
: impl::StableHLOToStableHLOPreprocessingBase<
StableHLOToStableHLOPreprocessing> {
@@ -1605,6 +1686,9 @@
// Identify known custom calls and convert them to equivalent StableHLO.
patterns.insert<CustomCallIsTopK>(context);
+ // Identify an iota->sort->slice pattern that maps to TopK.
+ patterns.insert<IotaSortSliceIsTopK>(context);
+
// Additional canonicalizers that simplify to computationally
// less-complex operations.
patterns.insert<DotToMul>(context);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
index 90ffb6c..8d7c9d5 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -390,3 +390,22 @@
// CHECK-SAME: %[[ARG0:[a-z0-9]+]]
// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>
+
+// -----
+
+// Descending sort of (%in, iota) along dimension 1 whose two results are each
+// trimmed to their first 8 columns; this should fold into a top-k with k = 8.
+func.func @iota_sort_slice_is_topk(%in : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) {
+  %indices = "stablehlo.iota"() { iota_dimension = 1 : i64 } : () -> tensor<16x16xi32>
+  %sorted:2 = "stablehlo.sort"(%in, %indices) ({
+  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>, %lhs_idx: tensor<i32>, %rhs_idx: tensor<i32>):
+    %gt = "stablehlo.compare"(%lhs, %rhs) {comparison_direction = #stablehlo<comparison_direction GT>} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "stablehlo.return"(%gt) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
+  %top_values = "stablehlo.slice"(%sorted#0) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xf32>) -> tensor<16x8xf32>
+  %top_indices = "stablehlo.slice"(%sorted#1) { start_indices = dense<[0, 0]> : tensor<2xi64>, limit_indices = dense<[16, 8]> : tensor<2xi64>, strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<16x16xi32>) -> tensor<16x8xi32>
+  return %top_values, %top_indices : tensor<16x8xf32>, tensor<16x8xi32>
+}
+
+// CHECK-LABEL: @iota_sort_slice_is_topk
+// CHECK-SAME: %[[IN:[a-z0-9]+]]
+// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[IN]], k = 8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>)
+// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<16x8xf32>, tensor<16x8xi32>