Lower stablehlo.custom_call @TopK to chlo.top_k (#13937)
TopK can sometimes appear as a custom call after running through the
SPMD splitter. It is best to pattern match after the conversion and
raise the call to chlo.top_k, where we have a fast path.
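
For example (shapes are illustrative, mirroring the lit tests below), the
splitter produces a custom call such as

  %0:2 = stablehlo.custom_call @TopK(%arg0) {called_computations = [@comparison]}
      : (tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>)

which the new pattern raises to

  %values, %indices = chlo.top_k(%arg0, k = 40)
      : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
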
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
index 278287e..c63c8e9 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Passes.cpp
@@ -68,8 +68,7 @@
passManager.addNestedPass<func::FuncOp>(createTopLevelSCFToCFGPass());
if (detuple) passManager.addPass(createFlattenTuplesInCFG());
- passManager.addNestedPass<func::FuncOp>(
- createStableHLOToStableHLOPreprocessing());
+ passManager.addPass(createStableHLOToStableHLOPreprocessing());
passManager.addNestedPass<func::FuncOp>(createStableHLOCanonicalize());
// Various shape functions may have been materialized in the `shape.shape_of`
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
index 201f8f1..b52cb4b 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.td
@@ -10,7 +10,7 @@
include "mlir/Pass/PassBase.td"
def StableHLOToStableHLOPreprocessing :
- Pass<"iree-stablehlo-to-stablehlo-preprocessing", "func::FuncOp"> {
+ Pass<"iree-stablehlo-to-stablehlo-preprocessing", "ModuleOp"> {
let summary = "Applies IREE-specific stablehlo to stablehlo preprocessing transformations";
let options = [
Option<"orderConvFeatures", "order-conv-features", "bool", /*default=*/"true",
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
index 845bb9b..6dad3e3 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/StableHLOToStableHLO.cpp
@@ -1433,6 +1433,113 @@
}
};
+struct CustomCallIsTopK final
+ : OpRewritePattern<mlir::stablehlo::CustomCallOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getCallTargetName() != "TopK") {
+ return rewriter.notifyMatchFailure(op, "not a TopK custom call");
+ }
+
+ if (op.getNumOperands() != 1 ||
+ !(op.getNumResults() == 1 || op.getNumResults() == 2)) {
+ return rewriter.notifyMatchFailure(
+ op, "incorrect number of operands / results");
+ }
+
+ ArrayAttr computations = op.getCalledComputations();
+ if (computations.size() != 1) {
+ return rewriter.notifyMatchFailure(op,
+ "incorrect number of computations");
+ }
+
+ SymbolRefAttr computation = dyn_cast<SymbolRefAttr>(computations[0]);
+ if (!computation) {
+      return rewriter.notifyMatchFailure(op, "not a symbol reference");
+ }
+
+ auto operand = op.getOperand(0);
+ auto operandTy = cast<ShapedType>(operand.getType());
+ if (!operandTy.hasRank() || operandTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "expected a rank-2 operand");
+ }
+
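+    // The results may come back either as a single tuple of (values,
+    // indices) or as two separate results; extract the result types for
+    // both conventions.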
+ ShapedType topVTy;
+ ShapedType topITy;
+ if (op.getNumResults() == 1) {
+ if (auto tupleTy = dyn_cast<TupleType>(op.getType(0))) {
+ if (tupleTy.size() != 2) {
+ return rewriter.notifyMatchFailure(
+            op, "tuple result does not contain exactly two values");
+ }
+ topVTy = dyn_cast<ShapedType>(tupleTy.getType(0));
+ topITy = dyn_cast<ShapedType>(tupleTy.getType(1));
+ }
+ }
+
+ if (op.getNumResults() == 2) {
+ topVTy = dyn_cast<ShapedType>(op.getType(0));
+ topITy = dyn_cast<ShapedType>(op.getType(1));
+ }
+
+ if (!topVTy || !topITy) {
+ return rewriter.notifyMatchFailure(op, "unknown return type behavior");
+ }
+
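+    // k is encoded as the static extent of the second dimension of the
+    // values result.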
+ int64_t k = topVTy.getDimSize(1);
+ if (k == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(op, "dynamic top-k k value");
+ }
+
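+    // Resolve the comparator function referenced by the custom call.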
+    auto moduleOp = op->getParentOfType<ModuleOp>();
+    auto funcOp =
+        dyn_cast_or_null<func::FuncOp>(moduleOp.lookupSymbol(computation));
+    if (!funcOp) {
+      return rewriter.notifyMatchFailure(op, "comparator is not a func.func");
+    }
+
+    Block &block = funcOp.getRegion().front();
+ auto stablehloCompareOp =
+ dyn_cast<mlir::stablehlo::CompareOp>(block.front());
+ if (!stablehloCompareOp) {
+ return rewriter.notifyMatchFailure(op, "not stablehlo compare op");
+ }
+
+ auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator());
+ if (!returnOp) {
+ return rewriter.notifyMatchFailure(op, "could not find ReturnOp");
+ }
+
+ if (returnOp.getNumOperands() != 1 ||
+ returnOp.getOperand(0) != stablehloCompareOp.getResult()) {
+ return rewriter.notifyMatchFailure(op, "ReturnOp operand not compare op");
+ }
+
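+    // chlo.top_k selects the largest elements, so the comparator must sort
+    // in descending order (GT or GE).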
+ auto direction = stablehloCompareOp.getComparisonDirection();
+ bool getTop = direction == mlir::stablehlo::ComparisonDirection::GT ||
+ direction == mlir::stablehlo::ComparisonDirection::GE;
+
+ if (!getTop) {
+ return rewriter.notifyMatchFailure(op,
+                                         "unsupported comparison direction");
+ }
+
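+    // Build the replacement op, re-tupling the results if the original
+    // custom call returned a tuple.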
+ auto newTopK = rewriter.create<chlo::TopKOp>(
+ op.getLoc(), TypeRange{topVTy, topITy}, operand, k);
+
+ if (op.getNumResults() == 2) {
+ rewriter.replaceOp(op, newTopK.getResults());
+ return success();
+ }
+
+    if (isa<TupleType>(op.getType(0))) {
+ rewriter.replaceOpWithNewOp<mlir::stablehlo::TupleOp>(
+ op, op.getType(0), newTopK.getResults());
+ return success();
+ }
+
+ return failure();
+ }
+};
+
struct StableHLOToStableHLOPreprocessing final
: impl::StableHLOToStableHLOPreprocessingBase<
StableHLOToStableHLOPreprocessing> {
@@ -1441,7 +1548,7 @@
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mlir::stablehlo::StablehloDialect,
- tensor::TensorDialect>();
+ chlo::ChloDialect, tensor::TensorDialect>();
}
void runOnOperation() override {
@@ -1495,6 +1602,9 @@
context,
/*benefit=*/400);
+ // Identify known custom calls and convert them to equivalent StableHLO.
+ patterns.insert<CustomCallIsTopK>(context);
+
// Additional canonicalizers that simplify to computationally
// less-complex operations.
patterns.insert<DotToMul>(context);
diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
index f9a8e30..90ffb6c 100644
--- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
+++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/stablehlo_to_stablehlo.mlir
@@ -356,3 +356,37 @@
%2 = "stablehlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_batching_dimensions = [0, 2], rhs_batching_dimensions = [0, 2], lhs_contracting_dimensions = [3], rhs_contracting_dimensions = [3]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x1024x16x64xf32>, tensor<?x1024x16x64xf32>) -> tensor<?x16x1024x1024xf32>
return %2 : tensor<?x16x1024x1024xf32>
}
+
+// -----
+
+func.func @custom_call_topk_tuple(%arg0: tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>) {
+ %0 = stablehlo.custom_call @TopK(%arg0) {called_computations = [@comparison], xla_shape = "(bf16[4,40]{1,0}, s32[4,40]{1,0})"} : (tensor<4x8000xbf16>) -> tuple<tensor<4x40xbf16>, tensor<4x40xi32>>
+ %1 = stablehlo.get_tuple_element %0[0] : (tuple<tensor<4x40xbf16>, tensor<4x40xi32>>) -> tensor<4x40xbf16>
+ %2 = stablehlo.get_tuple_element %0[1] : (tuple<tensor<4x40xbf16>, tensor<4x40xi32>>) -> tensor<4x40xi32>
+ return %1, %2 : tensor<4x40xbf16>, tensor<4x40xi32>
+}
+func.func private @comparison(%arg0: tensor<bf16>, %arg1: tensor<bf16>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
+ %0 = stablehlo.compare GT, %arg0, %arg1, TOTALORDER : (tensor<bf16>, tensor<bf16>) -> tensor<i1>
+ return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: @custom_call_topk_tuple
+// CHECK-SAME: %[[ARG0:[a-z0-9]+]]
+// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
+// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>
+
+// -----
+
+func.func @custom_call_topk_returns(%arg0: tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>) {
+ %0:2 = stablehlo.custom_call @TopK(%arg0) {called_computations = [@comparison], xla_shape = "(bf16[4,40]{1,0}, s32[4,40]{1,0})"} : (tensor<4x8000xbf16>) -> (tensor<4x40xbf16>, tensor<4x40xi32>)
+ return %0#0, %0#1 : tensor<4x40xbf16>, tensor<4x40xi32>
+}
+func.func private @comparison(%arg0: tensor<bf16>, %arg1: tensor<bf16>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<i1> {
+ %0 = stablehlo.compare GT, %arg0, %arg1, TOTALORDER : (tensor<bf16>, tensor<bf16>) -> tensor<i1>
+ return %0 : tensor<i1>
+}
+
+// CHECK-LABEL: @custom_call_topk_returns
+// CHECK-SAME: %[[ARG0:[a-z0-9]+]]
+// CHECK: %[[VALUES:.+]], %[[INDICES:.+]] = chlo.top_k(%[[ARG0]], k = 40) : tensor<4x8000xbf16> -> (tensor<4x40xbf16>, tensor<4x40xi32>)
+// CHECK: return %[[VALUES]], %[[INDICES]] : tensor<4x40xbf16>, tensor<4x40xi32>