[LinalgExt] Add topk_v2 op with roundtrip and invalid mlir test (#24054)
**Context**: Refer to [Discord
discussion](https://discord.com/channels/689900678990135345/1491812950451814540)
for the design choice. We aim to add a top-K op with bitonic sort
vectorization along the VectorDistribute pipeline. Refer to Issue:
https://github.com/iree-org/iree/issues/24053 for more details.
This PR, as the first step, defines the` iree_linalg_ext.topk_v2`
operation and defines its verifier along with corresponding roundtrip
and invalid tests (`roundtrip.mlir` and `invalid.mlir`).
Support for topk_v2 will be added incrementally across multiple PRs:
- Add the `iree_linalg_ext.topk_v2` op definition and verification
logic, along with roundtrip and invalid IR tests (this PR)
- Add tiling support
- Add convert_to_loops lowering support
- Add e2e support along the VectorDistribute pipeline (vectorization,
layout analysis, and distribution, etc)
---------
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 13a6347..a10cfbc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1290,6 +1290,111 @@
}
//===----------------------------------------------------------------------===//
+// TopkV2Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult TopkV2Op::verify() {
+ auto inputValuesType = getInputType();
+ auto outputValuesType = cast<ShapedType>(getOutputValues().getType());
+ uint64_t dim = getDimension();
+
+ if (dim >= static_cast<uint64_t>(getInputRank())) {
+ return emitOpError("dimension exceeds rank");
+ }
+ if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
+ return emitOpError("expected input/output value types to be identical");
+ }
+ if (inputValuesType.getRank() != outputValuesType.getRank()) {
+ return emitOpError("expected input/output to have the same rank");
+ }
+
+ if (Value inputIndices = getInputIndices()) {
+ if (!getOutputIndices()) {
+ return emitOpError(
+ "input indices require output indices to carry provenance");
+ }
+ auto inputIndicesType = cast<ShapedType>(inputIndices.getType());
+ if (!isa<IntegerType>(inputIndicesType.getElementType())) {
+ return emitOpError("expected input indices to be integer type");
+ }
+ if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) {
+ return emitOpError("input values/indices shape must match");
+ }
+ }
+
+ if (Value outputIndices = getOutputIndices()) {
+ auto outputIndicesType = cast<ShapedType>(outputIndices.getType());
+ if (!isa<IntegerType>(outputIndicesType.getElementType())) {
+ return emitOpError("expected output indices to be integer type");
+ }
+ if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType))) {
+ return emitOpError("output values/indices shape must match");
+ }
+ }
+
+ // All dimensions except the sort dimension must match.
+ for (auto [idx, inDim, outDim] : llvm::enumerate(
+ inputValuesType.getShape(), outputValuesType.getShape())) {
+ if (idx == dim) {
+ continue;
+ }
+ if (ShapedType::isStatic(inDim) && ShapedType::isStatic(outDim) &&
+ inDim != outDim) {
+ return emitOpError("incompatible input/output shapes at dimension ")
+ << idx;
+ }
+ }
+
+ // Validate that output K does not exceed input along the sort dimension.
+ int64_t inputDimSize = inputValuesType.getDimSize(dim);
+ int64_t outputDimSize = outputValuesType.getDimSize(dim);
+ if (ShapedType::isStatic(inputDimSize) &&
+ ShapedType::isStatic(outputDimSize)) {
+ if (outputDimSize == 0) {
+ return emitOpError("output dimension must be positive");
+ }
+ if (outputDimSize > inputDimSize) {
+ return emitOpError("output dimension must not exceed input, got ")
+ << outputDimSize << " > " << inputDimSize;
+ }
+ }
+
+ return success();
+}
+
+LogicalResult TopkV2Op::verifyRegions() {
+ auto inputValuesType = getInputType();
+ Block &block = getRegion().front();
+ if (block.getNumArguments() != 2) {
+ return emitOpError("region block should have 2 arguments");
+ }
+ if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
+ block.getArgument(1).getType() != inputValuesType.getElementType()) {
+ return emitOpError("region block types must match input value type");
+ }
+ auto terminatorOp = cast<YieldOp>(block.getTerminator());
+ if (terminatorOp.getNumOperands() != 1 ||
+ !terminatorOp.getOperand(0).getType().isInteger(1)) {
+ return emitOpError("region block must end with a linalg_ext.yield i1");
+ }
+ return success();
+}
+
+LogicalResult
+TopkV2Op::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ return cast<LinalgExtOp>(getOperation())
+ .reifyResultShapes(b, reifiedReturnShapes);
+}
+
+MutableOperandRange TopkV2Op::getDpsInitsMutable() {
+ // Operands order: values, [input_indices], output_values, [output_indices]
+ unsigned numInputs = 1 + (getInputIndices() ? 1 : 0);
+ unsigned numInits = 1 + (getOutputIndices() ? 1 : 0);
+ return MutableOperandRange(*this, numInputs, numInits);
+}
+
+//===----------------------------------------------------------------------===//
// ArgCompareOp
//===----------------------------------------------------------------------===//
@@ -3130,6 +3235,7 @@
DEFINE_OP_GET_EFFECTS(FftOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(TopkOp)
+DEFINE_OP_GET_EFFECTS(TopkV2Op)
DEFINE_OP_GET_EFFECTS(ArgCompareOp)
DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 1da653c..01f198d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -680,6 +680,77 @@
}];
}
+def IREELinalgExt_TopkV2Op : IREELinalgExt_Op<"topk_v2",[
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, ["reifyResultShapes"]>,
+ DeclareOpInterfaceMethods<LinalgExtInterface>
+]>{
+ let summary = [{Top-K v2 operator.}];
+ let description = [{
+ Selects the top-K elements along the given `dimension` using the given
+ `comparator`. K is determined by the output shape. The output sort
+ dimension must be less than or equal to the input sort dimension.
+
+ This op replaces the legacy `topk` op with several improvements:
+ output indices are optional instead of mandatory, input indices of any
+ integer type are supported for pre-permuted inputs, and the comparator
+ region uses a pure ordering predicate decoupled from the swap semantics
+ of the legacy op.
+
+ Accepts a single N-D input tensor of values and an optional N-D tensor of
+ input indices (any integer type). If input indices aren't provided, the
+ index mapping defaults to [0, N) along the sort dimension, where N is
+ the input size. Input indices may only be provided when output indices
+ are also present, since the comparator operates solely on values and
+ input indices only carry positional provenance to the output. Produces
+ output values and optionally output indices (integer type) tracking
+ original positions.
+
+ When `is_sorted` is present, the output top-K elements are guaranteed to
+ be in sorted order along the sort dimension. When absent (default), the
+ top-K elements are returned but their relative order is unspecified.
+
+ Comparator region accepts two scalar arguments of the input value element
+ type and yields an i1 indicating whether the first argument should be
+ ordered before the second.
+ }];
+
+ let arguments = (ins I64Attr:$dimension,
+ UnitAttr:$is_sorted,
+ AnyShaped:$values,
+ Optional<AnyShaped>:$input_indices,
+ AnyShaped:$output_values,
+ Optional<AnyShaped>:$output_indices
+ );
+ let results = (outs Variadic<AnyRankedTensor>:$results);
+ let regions = (region SizedRegion<1>:$region);
+ let hasRegionVerifier = 1;
+ let assemblyFormat = [{
+ attr-dict
+ `dimension` `(` $dimension `)`
+ (`is_sorted` $is_sorted^)?
+ `ins` `(` $values (`,` $input_indices^)? `:` type($values) (`,` type($input_indices)^)? `)`
+ `outs` `(` $output_values (`,` $output_indices^)? `:`
+ type($output_values) (`,` type($output_indices)^)? `)`
+ $region (`->` type($results)^)?
+ }];
+ let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+ ShapedType getInputType() {
+ return cast<ShapedType>(getValues().getType());
+ }
+ int64_t getInputRank() {
+ return getInputType().getRank();
+ }
+ bool hasOutputIndices() {
+ return static_cast<bool>(getOutputIndices());
+ }
+
+ // Method to implement for specifying output range for
+ // DestinationStyleOpInterface
+ MutableOperandRange getDpsInitsMutable();
+ }];
+}
+
def IREELinalgExt_ArgCompareOp : IREELinalgExt_Op<"arg_compare", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, ["reifyResultShapes"]>,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 2902eb2..5529e94 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -981,6 +981,224 @@
// -----
+func.func @topk_v2_dimension_exceeds_rank(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+ %out = tensor.empty() : tensor<128xf32>
+ // expected-error@+1 {{dimension exceeds rank}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<128xf32>)
+ outs(%out : tensor<128xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<128xf32>
+ return %0 : tensor<128xf32>
+}
+
+// -----
+
+func.func @topk_v2_value_type_mismatch(%arg0: tensor<4x128xi32>) -> tensor<4x128xf32> {
+ %out = tensor.empty() : tensor<4x128xf32>
+ // expected-error@+1 {{expected input/output value types to be identical}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xi32>)
+ outs(%out : tensor<4x128xf32>) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>
+ return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_rank_mismatch(%arg0: tensor<4x128xf32>) -> tensor<128xf32> {
+ %out = tensor.empty() : tensor<128xf32>
+ // expected-error@+1 {{expected input/output to have the same rank}}
+ %0 = iree_linalg_ext.topk_v2 dimension(0)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out : tensor<128xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<128xf32>
+ return %0 : tensor<128xf32>
+}
+
+// -----
+
+func.func @topk_v2_input_indices_not_integer(
+ %arg0: tensor<4x128xf32>, %arg1: tensor<4x128xf32>)
+ -> (tensor<4x128xf32>, tensor<4x128xi32>) {
+ %out_values = tensor.empty() : tensor<4x128xf32>
+ %out_indices = tensor.empty() : tensor<4x128xi32>
+ // expected-error@+1 {{expected input indices to be integer type}}
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<4x128xf32>)
+ outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>, tensor<4x128xi32>
+ return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi32>
+}
+
+// -----
+
+func.func @topk_v2_input_indices_shape_mismatch(
+ %arg0: tensor<4x128xf32>, %arg1: tensor<8x128xi32>)
+ -> (tensor<4x128xf32>, tensor<4x128xi32>) {
+ %out_values = tensor.empty() : tensor<4x128xf32>
+ %out_indices = tensor.empty() : tensor<4x128xi32>
+ // expected-error@+1 {{input values/indices shape must match}}
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<8x128xi32>)
+ outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>, tensor<4x128xi32>
+ return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi32>
+}
+
+// -----
+
+func.func @topk_v2_indices_not_integer(%arg0: tensor<4x128xf32>) -> (tensor<4x128xf32>, tensor<4x128xf32>) {
+ %out_values = tensor.empty() : tensor<4x128xf32>
+ %out_indices = tensor.empty() : tensor<4x128xf32>
+ // expected-error@+1 {{expected output indices to be integer type}}
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>, tensor<4x128xf32>
+ return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_indices_shape_mismatch(%arg0: tensor<4x128xf32>) -> (tensor<4x128xf32>, tensor<8x128xi32>) {
+ %out_values = tensor.empty() : tensor<4x128xf32>
+ %out_indices = tensor.empty() : tensor<8x128xi32>
+ // expected-error@+1 {{output values/indices shape must match}}
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<8x128xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>, tensor<8x128xi32>
+ return %0#0, %0#1 : tensor<4x128xf32>, tensor<8x128xi32>
+}
+
+// -----
+
+func.func @topk_v2_non_sort_dim_mismatch(%arg0: tensor<4x128xf32>) -> tensor<8x128xf32> {
+ %out = tensor.empty() : tensor<8x128xf32>
+ // expected-error@+1 {{incompatible input/output shapes at dimension 0}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out : tensor<8x128xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<8x128xf32>
+ return %0 : tensor<8x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_wrong_comparator_args(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> {
+ %out = tensor.empty() : tensor<4x128xf32>
+ // expected-error@+1 {{region block should have 2 arguments}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out : tensor<4x128xf32>) {
+ ^bb0(%lhs: f32, %mid: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>
+ return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_comparator_type_mismatch(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> {
+ %out = tensor.empty() : tensor<4x128xf32>
+ // expected-error@+1 {{region block types must match input value type}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out : tensor<4x128xf32>) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>
+ return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_comparator_not_i1(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> {
+ %out = tensor.empty() : tensor<4x128xf32>
+ // expected-error@+1 {{region block must end with a linalg_ext.yield i1}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out : tensor<4x128xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ iree_linalg_ext.yield %lhs : f32
+ } -> tensor<4x128xf32>
+ return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_zero_k(%arg0: tensor<4x8xf32>) -> tensor<4x0xf32> {
+ %out = tensor.empty() : tensor<4x0xf32>
+ // expected-error@+1 {{output dimension must be positive}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x8xf32>)
+ outs(%out : tensor<4x0xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x0xf32>
+ return %0 : tensor<4x0xf32>
+}
+
+// -----
+
+func.func @topk_v2_output_exceeds_input(%arg0: tensor<4x8xf32>) -> tensor<4x16xf32> {
+ %out = tensor.empty() : tensor<4x16xf32>
+ // expected-error@+1 {{output dimension must not exceed input}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x8xf32>)
+ outs(%out : tensor<4x16xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x16xf32>
+ return %0 : tensor<4x16xf32>
+}
+
+// -----
+
+func.func @topk_v2_input_indices_without_output_indices(
+ %arg0: tensor<4x128xf32>, %arg1: tensor<4x128xi32>) -> tensor<4x64xf32> {
+ %out = tensor.empty() : tensor<4x64xf32>
+ // expected-error@+1 {{input indices require output indices to carry provenance}}
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<4x128xi32>)
+ outs(%out : tensor<4x64xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x64xf32>
+ return %0 : tensor<4x64xf32>
+}
+
+// -----
+
func.func @exp_reduction_non_zero(%S: tensor<2x3xf32>) -> tensor<2xf32> {
%M = tensor.empty() : tensor<2xf32>
%out = tensor.empty() : tensor<2xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index ac54caa..691f9df 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1272,6 +1272,342 @@
// -----
+func.func @topk_v2_tensor(%arg0: tensor<4x1024xf32>) -> tensor<4x1024xf32> {
+ %out = tensor.empty() : tensor<4x1024xf32>
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x1024xf32>)
+ outs(%out : tensor<4x1024xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x1024xf32>
+ return %0 : tensor<4x1024xf32>
+}
+// CHECK-LABEL: func.func @topk_v2_tensor(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK: %[[OUT:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_tensor_with_indices(%arg0: tensor<4x1024xf32>) -> (tensor<4x1024xf32>, tensor<4x1024xi32>) {
+ %out_values = tensor.empty() : tensor<4x1024xf32>
+ %out_indices = tensor.empty() : tensor<4x1024xi32>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x1024xf32>)
+ outs(%out_values, %out_indices : tensor<4x1024xf32>, tensor<4x1024xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x1024xf32>, tensor<4x1024xi32>
+ return %0#0, %0#1 : tensor<4x1024xf32>, tensor<4x1024xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_tensor_with_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_memref(%arg0: memref<4x1024xf32>, %arg1: memref<4x1024xf32>) {
+ iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : memref<4x1024xf32>)
+ outs(%arg1 : memref<4x1024xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ }
+ return
+}
+// CHECK-LABEL: func.func @topk_v2_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<4x1024xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<4x1024xf32>
+// CHECK: iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[ARG1]]
+// CHECK: iree_linalg_ext.yield
+
+// -----
+
+func.func @topk_v2_topk_tensor(%arg0: tensor<4x1024xf32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+ %out_values = tensor.empty() : tensor<4x8xf32>
+ %out_indices = tensor.empty() : tensor<4x8xi32>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x1024xf32>)
+ outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x8xf32>, tensor<4x8xi32>
+ return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_topk_tensor(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_topk_no_indices(%arg0: tensor<4x1024xf32>) -> tensor<4x8xf32> {
+ %out_values = tensor.empty() : tensor<4x8xf32>
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x1024xf32>)
+ outs(%out_values : tensor<4x8xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+// CHECK-LABEL: func.func @topk_v2_topk_no_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_topk_memref(%arg0: memref<4x1024xf32>, %out_values: memref<4x8xf32>, %out_indices: memref<4x8xi32>) {
+ iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : memref<4x1024xf32>)
+ outs(%out_values, %out_indices : memref<4x8xf32>, memref<4x8xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ }
+ return
+}
+// CHECK-LABEL: func.func @topk_v2_topk_memref(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<4x1024xf32>
+// CHECK-SAME: %[[OUT_VALUES:[a-zA-Z0-9]+]]: memref<4x8xf32>
+// CHECK-SAME: %[[OUT_INDICES:[a-zA-Z0-9]+]]: memref<4x8xi32>
+// CHECK: iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+
+// -----
+
+func.func @topk_v2_with_input_indices(%arg0: tensor<4x128xf32>, %arg1: tensor<4x128xi32>) -> (tensor<4x128xf32>, tensor<4x128xi32>) {
+ %out_values = tensor.empty() : tensor<4x128xf32>
+ %out_indices = tensor.empty() : tensor<4x128xi32>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<4x128xi32>)
+ outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>, tensor<4x128xi32>
+ return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_with_input_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x128xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x128xi32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_topk_with_input_indices(%arg0: tensor<4x1024xf32>, %arg1: tensor<4x1024xi32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+ %out_values = tensor.empty() : tensor<4x8xf32>
+ %out_indices = tensor.empty() : tensor<4x8xi32>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0, %arg1 : tensor<4x1024xf32>, tensor<4x1024xi32>)
+ outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x8xf32>, tensor<4x8xi32>
+ return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_topk_with_input_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x1024xi32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_i64_indices(%arg0: tensor<4x128xf32>) -> (tensor<4x128xf32>, tensor<4x128xi64>) {
+ %out_values = tensor.empty() : tensor<4x128xf32>
+ %out_indices = tensor.empty() : tensor<4x128xi64>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xf32>)
+ outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi64>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xf32>, tensor<4x128xi64>
+ return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi64>
+}
+// CHECK-LABEL: func.func @topk_v2_i64_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x128xf32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_integer_values(%arg0: tensor<4x128xi32>) -> tensor<4x128xi32> {
+ %out = tensor.empty() : tensor<4x128xi32>
+ %0 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<4x128xi32>)
+ outs(%out : tensor<4x128xi32>) {
+ ^bb0(%lhs: i32, %rhs: i32):
+ %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x128xi32>
+ return %0 : tensor<4x128xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_integer_values(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x128xi32>
+// CHECK: %[[OUT:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_dynamic(%arg0: tensor<?x?xf32>, %out_values: tensor<?x?xf32>, %out_indices: tensor<?x?xi32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%out_values, %out_indices : tensor<?x?xf32>, tensor<?x?xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<?x?xf32>, tensor<?x?xi32>
+ return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[OUT_VALUES:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[OUT_INDICES:[a-zA-Z0-9]+]]: tensor<?x?xi32>
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1)
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_sorted_tensor(%arg0: tensor<4x1024xf32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+ %out_values = tensor.empty() : tensor<4x8xf32>
+ %out_indices = tensor.empty() : tensor<4x8xi32>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1) is_sorted
+ ins(%arg0 : tensor<4x1024xf32>)
+ outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x8xf32>, tensor<4x8xi32>
+ return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_sorted_tensor(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1) is_sorted
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_sorted_no_indices(%arg0: tensor<4x1024xf32>) -> tensor<4x8xf32> {
+ %out_values = tensor.empty() : tensor<4x8xf32>
+ %0 = iree_linalg_ext.topk_v2 dimension(1) is_sorted
+ ins(%arg0 : tensor<4x1024xf32>)
+ outs(%out_values : tensor<4x8xf32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x8xf32>
+ return %0 : tensor<4x8xf32>
+}
+// CHECK-LABEL: func.func @topk_v2_sorted_no_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1) is_sorted
+// CHECK-SAME: ins(%[[ARG0]]
+// CHECK-SAME: outs(%[[OUT_VALUES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_sorted_with_input_indices(%arg0: tensor<4x1024xf32>, %arg1: tensor<4x1024xi32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+ %out_values = tensor.empty() : tensor<4x8xf32>
+ %out_indices = tensor.empty() : tensor<4x8xi32>
+ %0:2 = iree_linalg_ext.topk_v2 dimension(1) is_sorted
+ ins(%arg0, %arg1 : tensor<4x1024xf32>, tensor<4x1024xi32>)
+ outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+ ^bb0(%lhs: f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ iree_linalg_ext.yield %cmp : i1
+ } -> tensor<4x8xf32>, tensor<4x8xi32>
+ return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_sorted_with_input_indices(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x1024xi32>
+// CHECK: %[[OUT_VALUES:.+]] = tensor.empty()
+// CHECK: %[[OUT_INDICES:.+]] = tensor.empty()
+// CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+// CHECK-SAME: dimension(1) is_sorted
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]]
+// CHECK-SAME: outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+// CHECK: iree_linalg_ext.yield
+// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
func.func @exp_reduction(%S: tensor<2x3xf32>) -> tensor<2xf32> {
%M = tensor.empty() : tensor<2xf32>
%out = tensor.empty() : tensor<2xf32>