[LinalgExt] Add topk_v2 op with roundtrip and invalid mlir test (#24054)

**Context**: Refer to [Discord
discussion](https://discord.com/channels/689900678990135345/1491812950451814540)
for the design choice. We aim to add a top-K op with bitonic sort
vectorization along the VectorDistribute pipeline. Refer to Issue:
https://github.com/iree-org/iree/issues/24053 for more details.

This PR, as the first step, defines the` iree_linalg_ext.topk_v2`
operation and defines its verifier along with corresponding roundtrip
and invalid tests (`roundtrip.mlir` and `invalid.mlir`).

Support for topk_v2 will be added incrementally across multiple PRs: 

- Add the `iree_linalg_ext.topk_v2` op definition and verification
logic, along with roundtrip and invalid IR tests (this PR)
- Add tiling support
- Add convert_to_loops lowering support
- Add e2e support along the VectorDistribute pipeline (vectorization,
layout analysis, and distribution, etc)

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 13a6347..a10cfbc 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -1290,6 +1290,111 @@
 }
 
 //===----------------------------------------------------------------------===//
+// TopkV2Op
+//===----------------------------------------------------------------------===//
+
+LogicalResult TopkV2Op::verify() {
+  auto inputValuesType = getInputType();
+  auto outputValuesType = cast<ShapedType>(getOutputValues().getType());
+  uint64_t dim = getDimension();
+
+  if (dim >= static_cast<uint64_t>(getInputRank())) {
+    return emitOpError("dimension exceeds rank");
+  }
+  if (inputValuesType.getElementType() != outputValuesType.getElementType()) {
+    return emitOpError("expected input/output value types to be identical");
+  }
+  if (inputValuesType.getRank() != outputValuesType.getRank()) {
+    return emitOpError("expected input/output to have the same rank");
+  }
+
+  if (Value inputIndices = getInputIndices()) {
+    if (!getOutputIndices()) {
+      return emitOpError(
+          "input indices require output indices to carry provenance");
+    }
+    auto inputIndicesType = cast<ShapedType>(inputIndices.getType());
+    if (!isa<IntegerType>(inputIndicesType.getElementType())) {
+      return emitOpError("expected input indices to be integer type");
+    }
+    if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) {
+      return emitOpError("input values/indices shape must match");
+    }
+  }
+
+  if (Value outputIndices = getOutputIndices()) {
+    auto outputIndicesType = cast<ShapedType>(outputIndices.getType());
+    if (!isa<IntegerType>(outputIndicesType.getElementType())) {
+      return emitOpError("expected output indices to be integer type");
+    }
+    if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType))) {
+      return emitOpError("output values/indices shape must match");
+    }
+  }
+
+  // All dimensions except the sort dimension must match.
+  for (auto [idx, inDim, outDim] : llvm::enumerate(
+           inputValuesType.getShape(), outputValuesType.getShape())) {
+    if (idx == dim) {
+      continue;
+    }
+    if (ShapedType::isStatic(inDim) && ShapedType::isStatic(outDim) &&
+        inDim != outDim) {
+      return emitOpError("incompatible input/output shapes at dimension ")
+             << idx;
+    }
+  }
+
+  // Validate that output K does not exceed input along the sort dimension.
+  int64_t inputDimSize = inputValuesType.getDimSize(dim);
+  int64_t outputDimSize = outputValuesType.getDimSize(dim);
+  if (ShapedType::isStatic(inputDimSize) &&
+      ShapedType::isStatic(outputDimSize)) {
+    if (outputDimSize == 0) {
+      return emitOpError("output dimension must be positive");
+    }
+    if (outputDimSize > inputDimSize) {
+      return emitOpError("output dimension must not exceed input, got ")
+             << outputDimSize << " > " << inputDimSize;
+    }
+  }
+
+  return success();
+}
+
+LogicalResult TopkV2Op::verifyRegions() {
+  auto inputValuesType = getInputType();
+  Block &block = getRegion().front();
+  if (block.getNumArguments() != 2) {
+    return emitOpError("region block should have 2 arguments");
+  }
+  if (block.getArgument(0).getType() != inputValuesType.getElementType() ||
+      block.getArgument(1).getType() != inputValuesType.getElementType()) {
+    return emitOpError("region block types must match input value type");
+  }
+  auto terminatorOp = cast<YieldOp>(block.getTerminator());
+  if (terminatorOp.getNumOperands() != 1 ||
+      !terminatorOp.getOperand(0).getType().isInteger(1)) {
+    return emitOpError("region block must end with a linalg_ext.yield i1");
+  }
+  return success();
+}
+
+LogicalResult
+TopkV2Op::reifyResultShapes(OpBuilder &b,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return cast<LinalgExtOp>(getOperation())
+      .reifyResultShapes(b, reifiedReturnShapes);
+}
+
+MutableOperandRange TopkV2Op::getDpsInitsMutable() {
+  // Operands order: values, [input_indices], output_values, [output_indices]
+  unsigned numInputs = 1 + (getInputIndices() ? 1 : 0);
+  unsigned numInits = 1 + (getOutputIndices() ? 1 : 0);
+  return MutableOperandRange(*this, numInputs, numInits);
+}
+
+//===----------------------------------------------------------------------===//
 // ArgCompareOp
 //===----------------------------------------------------------------------===//
 
@@ -3130,6 +3235,7 @@
 DEFINE_OP_GET_EFFECTS(FftOp)
 DEFINE_OP_GET_EFFECTS(ScanOp)
 DEFINE_OP_GET_EFFECTS(TopkOp)
+DEFINE_OP_GET_EFFECTS(TopkV2Op)
 DEFINE_OP_GET_EFFECTS(ArgCompareOp)
 DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
 DEFINE_OP_GET_EFFECTS(WinogradFilterTransformOp)
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 1da653c..01f198d 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -680,6 +680,77 @@
   }];
 }
 
+def IREELinalgExt_TopkV2Op : IREELinalgExt_Op<"topk_v2",[
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, ["reifyResultShapes"]>,
+  DeclareOpInterfaceMethods<LinalgExtInterface>
+]>{
+  let summary = [{Top-K v2 operator.}];
+  let description = [{
+    Selects the top-K elements along the given `dimension` using the given
+    `comparator`. K is determined by the output shape. The output sort
+    dimension must be less than or equal to the input sort dimension.
+
+    This op replaces the legacy `topk` op with several improvements:
+    output indices are optional instead of mandatory, input indices of any
+    integer type are supported for pre-permuted inputs, and the comparator
+    region uses a pure ordering predicate decoupled from the swap semantics
+    of the legacy op.
+
+    Accepts a single N-D input tensor of values and an optional N-D tensor of
+    input indices (any integer type). If input indices aren't provided, the
+    index mapping defaults to [0, N) along the sort dimension, where N is
+    the input size. Input indices may only be provided when output indices
+    are also present, since the comparator operates solely on values and
+    input indices only carry positional provenance to the output. Produces
+    output values and optionally output indices (integer type) tracking
+    original positions.
+
+    When `is_sorted` is present, the output top-K elements are guaranteed to
+    be in sorted order along the sort dimension. When absent (default), the
+    top-K elements are returned but their relative order is unspecified.
+
+    Comparator region accepts two scalar arguments of the input value element
+    type and yields an i1 indicating whether the first argument should be
+    ordered before the second.
+  }];
+
+  let arguments = (ins I64Attr:$dimension,
+                       UnitAttr:$is_sorted,
+                       AnyShaped:$values,
+                       Optional<AnyShaped>:$input_indices,
+                       AnyShaped:$output_values,
+                       Optional<AnyShaped>:$output_indices
+  );
+  let results = (outs Variadic<AnyRankedTensor>:$results);
+  let regions = (region SizedRegion<1>:$region);
+  let hasRegionVerifier = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `dimension` `(` $dimension `)`
+    (`is_sorted` $is_sorted^)?
+    `ins` `(` $values (`,` $input_indices^)? `:` type($values) (`,` type($input_indices)^)? `)`
+    `outs` `(` $output_values (`,` $output_indices^)? `:`
+               type($output_values) (`,` type($output_indices)^)? `)`
+    $region (`->` type($results)^)?
+  }];
+  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+    ShapedType getInputType() {
+      return cast<ShapedType>(getValues().getType());
+    }
+    int64_t getInputRank() {
+      return getInputType().getRank();
+    }
+    bool hasOutputIndices() {
+      return static_cast<bool>(getOutputIndices());
+    }
+
+    // Method to implement for specifying output range for
+    // DestinationStyleOpInterface
+    MutableOperandRange getDpsInitsMutable();
+  }];
+}
+
 def IREELinalgExt_ArgCompareOp : IREELinalgExt_Op<"arg_compare", [
   AttrSizedOperandSegments,
   DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface, ["reifyResultShapes"]>,
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
index 2902eb2..5529e94 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir
@@ -981,6 +981,224 @@
 
 // -----
 
+func.func @topk_v2_dimension_exceeds_rank(%arg0: tensor<128xf32>) -> tensor<128xf32> {
+  %out = tensor.empty() : tensor<128xf32>
+  // expected-error@+1 {{dimension exceeds rank}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<128xf32>)
+      outs(%out : tensor<128xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<128xf32>
+  return %0 : tensor<128xf32>
+}
+
+// -----
+
+func.func @topk_v2_value_type_mismatch(%arg0: tensor<4x128xi32>) -> tensor<4x128xf32> {
+  %out = tensor.empty() : tensor<4x128xf32>
+  // expected-error@+1 {{expected input/output value types to be identical}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xi32>)
+      outs(%out : tensor<4x128xf32>) {
+      ^bb0(%lhs: i32, %rhs: i32):
+        %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>
+  return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_rank_mismatch(%arg0: tensor<4x128xf32>) -> tensor<128xf32> {
+  %out = tensor.empty() : tensor<128xf32>
+  // expected-error@+1 {{expected input/output to have the same rank}}
+  %0 = iree_linalg_ext.topk_v2 dimension(0)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out : tensor<128xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<128xf32>
+  return %0 : tensor<128xf32>
+}
+
+// -----
+
+func.func @topk_v2_input_indices_not_integer(
+    %arg0: tensor<4x128xf32>, %arg1: tensor<4x128xf32>)
+    -> (tensor<4x128xf32>, tensor<4x128xi32>) {
+  %out_values = tensor.empty() : tensor<4x128xf32>
+  %out_indices = tensor.empty() : tensor<4x128xi32>
+  // expected-error@+1 {{expected input indices to be integer type}}
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<4x128xf32>)
+      outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>, tensor<4x128xi32>
+  return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi32>
+}
+
+// -----
+
+func.func @topk_v2_input_indices_shape_mismatch(
+    %arg0: tensor<4x128xf32>, %arg1: tensor<8x128xi32>)
+    -> (tensor<4x128xf32>, tensor<4x128xi32>) {
+  %out_values = tensor.empty() : tensor<4x128xf32>
+  %out_indices = tensor.empty() : tensor<4x128xi32>
+  // expected-error@+1 {{input values/indices shape must match}}
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<8x128xi32>)
+      outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>, tensor<4x128xi32>
+  return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi32>
+}
+
+// -----
+
+func.func @topk_v2_indices_not_integer(%arg0: tensor<4x128xf32>) -> (tensor<4x128xf32>, tensor<4x128xf32>) {
+  %out_values = tensor.empty() : tensor<4x128xf32>
+  %out_indices = tensor.empty() : tensor<4x128xf32>
+  // expected-error@+1 {{expected output indices to be integer type}}
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>, tensor<4x128xf32>
+  return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_indices_shape_mismatch(%arg0: tensor<4x128xf32>) -> (tensor<4x128xf32>, tensor<8x128xi32>) {
+  %out_values = tensor.empty() : tensor<4x128xf32>
+  %out_indices = tensor.empty() : tensor<8x128xi32>
+  // expected-error@+1 {{output values/indices shape must match}}
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<8x128xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>, tensor<8x128xi32>
+  return %0#0, %0#1 : tensor<4x128xf32>, tensor<8x128xi32>
+}
+
+// -----
+
+func.func @topk_v2_non_sort_dim_mismatch(%arg0: tensor<4x128xf32>) -> tensor<8x128xf32> {
+  %out = tensor.empty() : tensor<8x128xf32>
+  // expected-error@+1 {{incompatible input/output shapes at dimension 0}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out : tensor<8x128xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<8x128xf32>
+  return %0 : tensor<8x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_wrong_comparator_args(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> {
+  %out = tensor.empty() : tensor<4x128xf32>
+  // expected-error@+1 {{region block should have 2 arguments}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out : tensor<4x128xf32>) {
+      ^bb0(%lhs: f32, %mid: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>
+  return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_comparator_type_mismatch(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> {
+  %out = tensor.empty() : tensor<4x128xf32>
+  // expected-error@+1 {{region block types must match input value type}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out : tensor<4x128xf32>) {
+      ^bb0(%lhs: i32, %rhs: i32):
+        %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>
+  return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_comparator_not_i1(%arg0: tensor<4x128xf32>) -> tensor<4x128xf32> {
+  %out = tensor.empty() : tensor<4x128xf32>
+  // expected-error@+1 {{region block must end with a linalg_ext.yield i1}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out : tensor<4x128xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        iree_linalg_ext.yield %lhs : f32
+      } -> tensor<4x128xf32>
+  return %0 : tensor<4x128xf32>
+}
+
+// -----
+
+func.func @topk_v2_zero_k(%arg0: tensor<4x8xf32>) -> tensor<4x0xf32> {
+  %out = tensor.empty() : tensor<4x0xf32>
+  // expected-error@+1 {{output dimension must be positive}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x8xf32>)
+      outs(%out : tensor<4x0xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x0xf32>
+  return %0 : tensor<4x0xf32>
+}
+
+// -----
+
+func.func @topk_v2_output_exceeds_input(%arg0: tensor<4x8xf32>) -> tensor<4x16xf32> {
+  %out = tensor.empty() : tensor<4x16xf32>
+  // expected-error@+1 {{output dimension must not exceed input}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x8xf32>)
+      outs(%out : tensor<4x16xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x16xf32>
+  return %0 : tensor<4x16xf32>
+}
+
+// -----
+
+func.func @topk_v2_input_indices_without_output_indices(
+    %arg0: tensor<4x128xf32>, %arg1: tensor<4x128xi32>) -> tensor<4x64xf32> {
+  %out = tensor.empty() : tensor<4x64xf32>
+  // expected-error@+1 {{input indices require output indices to carry provenance}}
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<4x128xi32>)
+      outs(%out : tensor<4x64xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x64xf32>
+  return %0 : tensor<4x64xf32>
+}
+
+// -----
+
 func.func @exp_reduction_non_zero(%S: tensor<2x3xf32>) -> tensor<2xf32> {
   %M = tensor.empty() : tensor<2xf32>
   %out = tensor.empty() : tensor<2xf32>
diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
index ac54caa..691f9df 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir
@@ -1272,6 +1272,342 @@
 
 // -----
 
+func.func @topk_v2_tensor(%arg0: tensor<4x1024xf32>) -> tensor<4x1024xf32> {
+  %out = tensor.empty() : tensor<4x1024xf32>
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x1024xf32>)
+      outs(%out : tensor<4x1024xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x1024xf32>
+  return %0 : tensor<4x1024xf32>
+}
+// CHECK-LABEL: func.func @topk_v2_tensor(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//       CHECK:   %[[OUT:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_tensor_with_indices(%arg0: tensor<4x1024xf32>) -> (tensor<4x1024xf32>, tensor<4x1024xi32>) {
+  %out_values = tensor.empty() : tensor<4x1024xf32>
+  %out_indices = tensor.empty() : tensor<4x1024xi32>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x1024xf32>)
+      outs(%out_values, %out_indices : tensor<4x1024xf32>, tensor<4x1024xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x1024xf32>, tensor<4x1024xi32>
+  return %0#0, %0#1 : tensor<4x1024xf32>, tensor<4x1024xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_tensor_with_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_memref(%arg0: memref<4x1024xf32>, %arg1: memref<4x1024xf32>) {
+  iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : memref<4x1024xf32>)
+      outs(%arg1 : memref<4x1024xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      }
+  return
+}
+// CHECK-LABEL: func.func @topk_v2_memref(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<4x1024xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: memref<4x1024xf32>
+//       CHECK:   iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[ARG1]]
+//       CHECK:      iree_linalg_ext.yield
+
+// -----
+
+func.func @topk_v2_topk_tensor(%arg0: tensor<4x1024xf32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+  %out_values = tensor.empty() : tensor<4x8xf32>
+  %out_indices = tensor.empty() : tensor<4x8xi32>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x1024xf32>)
+      outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x8xf32>, tensor<4x8xi32>
+  return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_topk_tensor(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_topk_no_indices(%arg0: tensor<4x1024xf32>) -> tensor<4x8xf32> {
+  %out_values = tensor.empty() : tensor<4x8xf32>
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x1024xf32>)
+      outs(%out_values : tensor<4x8xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+// CHECK-LABEL: func.func @topk_v2_topk_no_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_topk_memref(%arg0: memref<4x1024xf32>, %out_values: memref<4x8xf32>, %out_indices: memref<4x8xi32>) {
+  iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : memref<4x1024xf32>)
+      outs(%out_values, %out_indices : memref<4x8xf32>, memref<4x8xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      }
+  return
+}
+// CHECK-LABEL: func.func @topk_v2_topk_memref(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<4x1024xf32>
+//  CHECK-SAME:   %[[OUT_VALUES:[a-zA-Z0-9]+]]: memref<4x8xf32>
+//  CHECK-SAME:   %[[OUT_INDICES:[a-zA-Z0-9]+]]: memref<4x8xi32>
+//       CHECK:   iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+
+// -----
+
+func.func @topk_v2_with_input_indices(%arg0: tensor<4x128xf32>, %arg1: tensor<4x128xi32>) -> (tensor<4x128xf32>, tensor<4x128xi32>) {
+  %out_values = tensor.empty() : tensor<4x128xf32>
+  %out_indices = tensor.empty() : tensor<4x128xi32>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0, %arg1 : tensor<4x128xf32>, tensor<4x128xi32>)
+      outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>, tensor<4x128xi32>
+  return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_with_input_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x128xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x128xi32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_topk_with_input_indices(%arg0: tensor<4x1024xf32>, %arg1: tensor<4x1024xi32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+  %out_values = tensor.empty() : tensor<4x8xf32>
+  %out_indices = tensor.empty() : tensor<4x8xi32>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0, %arg1 : tensor<4x1024xf32>, tensor<4x1024xi32>)
+      outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x8xf32>, tensor<4x8xi32>
+  return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_topk_with_input_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x1024xi32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_i64_indices(%arg0: tensor<4x128xf32>) -> (tensor<4x128xf32>, tensor<4x128xi64>) {
+  %out_values = tensor.empty() : tensor<4x128xf32>
+  %out_indices = tensor.empty() : tensor<4x128xi64>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xf32>)
+      outs(%out_values, %out_indices : tensor<4x128xf32>, tensor<4x128xi64>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xf32>, tensor<4x128xi64>
+  return %0#0, %0#1 : tensor<4x128xf32>, tensor<4x128xi64>
+}
+// CHECK-LABEL: func.func @topk_v2_i64_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x128xf32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_integer_values(%arg0: tensor<4x128xi32>) -> tensor<4x128xi32> {
+  %out = tensor.empty() : tensor<4x128xi32>
+  %0 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<4x128xi32>)
+      outs(%out : tensor<4x128xi32>) {
+      ^bb0(%lhs: i32, %rhs: i32):
+        %cmp = arith.cmpi sgt, %lhs, %rhs : i32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x128xi32>
+  return %0 : tensor<4x128xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_integer_values(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x128xi32>
+//       CHECK:   %[[OUT:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_dynamic(%arg0: tensor<?x?xf32>, %out_values: tensor<?x?xf32>, %out_indices: tensor<?x?xi32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1)
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%out_values, %out_indices : tensor<?x?xf32>, tensor<?x?xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<?x?xf32>, tensor<?x?xi32>
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_dynamic(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[OUT_VALUES:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[OUT_INDICES:[a-zA-Z0-9]+]]: tensor<?x?xi32>
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1)
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_sorted_tensor(%arg0: tensor<4x1024xf32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+  %out_values = tensor.empty() : tensor<4x8xf32>
+  %out_indices = tensor.empty() : tensor<4x8xi32>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1) is_sorted
+      ins(%arg0 : tensor<4x1024xf32>)
+      outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x8xf32>, tensor<4x8xi32>
+  return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_sorted_tensor(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1) is_sorted
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
+func.func @topk_v2_sorted_no_indices(%arg0: tensor<4x1024xf32>) -> tensor<4x8xf32> {
+  %out_values = tensor.empty() : tensor<4x8xf32>
+  %0 = iree_linalg_ext.topk_v2 dimension(1) is_sorted
+      ins(%arg0 : tensor<4x1024xf32>)
+      outs(%out_values : tensor<4x8xf32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x8xf32>
+  return %0 : tensor<4x8xf32>
+}
+// CHECK-LABEL: func.func @topk_v2_sorted_no_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]] = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1) is_sorted
+//  CHECK-SAME:      ins(%[[ARG0]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
+func.func @topk_v2_sorted_with_input_indices(%arg0: tensor<4x1024xf32>, %arg1: tensor<4x1024xi32>) -> (tensor<4x8xf32>, tensor<4x8xi32>) {
+  %out_values = tensor.empty() : tensor<4x8xf32>
+  %out_indices = tensor.empty() : tensor<4x8xi32>
+  %0:2 = iree_linalg_ext.topk_v2 dimension(1) is_sorted
+      ins(%arg0, %arg1 : tensor<4x1024xf32>, tensor<4x1024xi32>)
+      outs(%out_values, %out_indices : tensor<4x8xf32>, tensor<4x8xi32>) {
+      ^bb0(%lhs: f32, %rhs: f32):
+        %cmp = arith.cmpf oge, %lhs, %rhs : f32
+        iree_linalg_ext.yield %cmp : i1
+      } -> tensor<4x8xf32>, tensor<4x8xi32>
+  return %0#0, %0#1 : tensor<4x8xf32>, tensor<4x8xi32>
+}
+// CHECK-LABEL: func.func @topk_v2_sorted_with_input_indices(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<4x1024xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<4x1024xi32>
+//       CHECK:   %[[OUT_VALUES:.+]] = tensor.empty()
+//       CHECK:   %[[OUT_INDICES:.+]] = tensor.empty()
+//       CHECK:   %[[RESULT:.+]]:2 = iree_linalg_ext.topk_v2
+//  CHECK-SAME:      dimension(1) is_sorted
+//  CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]]
+//  CHECK-SAME:      outs(%[[OUT_VALUES]], %[[OUT_INDICES]]
+//       CHECK:      iree_linalg_ext.yield
+//       CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1
+
+// -----
+
 func.func @exp_reduction(%S: tensor<2x3xf32>) -> tensor<2xf32> {
   %M = tensor.empty() : tensor<2xf32>
   %out = tensor.empty() : tensor<2xf32>