VMLA sort implementation with pseudo lowering (#3295)

Adds an implementation of Sort for VMLA. Includes the kernel, an intermediate
pseudo operation, and end-to-end tests. The 2D case is currently disabled due to
a bug in the gather kernel that is being fixed separately.
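
At a high level the lowering is a two-step argsort-and-gather: the new
vmla.sort pseudo op / kernel produces only the sorted indices along the last
dimension, and a gather (mhlo.torch_index_select, later vmla.gather) reorders
the actual values using those indices. A minimal standalone sketch of that idea
in plain C++ (illustrative only, not the VMLA API; ArgsortRow and GatherRow are
made-up names):

    #include <algorithm>
    #include <cstdint>
    #include <numeric>
    #include <vector>

    // Argsort one contiguous row: the role played by the vmla.sort kernel.
    std::vector<int32_t> ArgsortRow(const std::vector<float>& row) {
      std::vector<int32_t> indices(row.size());
      std::iota(indices.begin(), indices.end(), 0);
      std::stable_sort(indices.begin(), indices.end(),
                       [&](int32_t a, int32_t b) { return row[a] < row[b]; });
      return indices;
    }

    // Reorder the values by those indices: the role played by the
    // torch_index_select / vmla.gather that follows the pseudo op.
    std::vector<float> GatherRow(const std::vector<float>& row,
                                 const std::vector<int32_t>& indices) {
      std::vector<float> out(row.size());
      for (size_t i = 0; i < indices.size(); ++i) out[i] = row[indices[i]];
      return out;
    }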
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 7a07478..dc289d6 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -95,6 +95,7 @@
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
+    "sort_test.py",
     "strings_test.py",
 ]
 
@@ -112,6 +113,7 @@
     "range_test.py",
     "ring_buffer_test.py",  # TODO(b/148747011)
     "scatter_update_test.py",
+    "sort_test.py",
     "strings_test.py",
 ]
 
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index a80a260..6af672a 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -42,6 +42,7 @@
   // Pseudo-ops are illegal.
   // If we end up with a lot of these, consider using an "is pseudo" trait.
   addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
+  addIllegalOp<IREE::VMLA::SortPseudoOp>();
 
   // Allow other ops to pass through so long as their type is valid (not a
   // tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index ccd1a3f..75b640f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -704,6 +704,29 @@
   TypeConverter &typeConverter;
 };
 
+struct SortOpConversion : public OpConversionPattern<IREE::VMLA::SortPseudoOp> {
+  SortOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+      : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+  LogicalResult matchAndRewrite(
+      IREE::VMLA::SortPseudoOp srcOp, ArrayRef<Value> rawOperands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto inputType =
+        srcOp.getOperand().getType().cast<ShapedType>().getElementType();
+    auto src = rawOperands[0];
+    auto src_shape = VMLAConversionTarget::getTensorShape(
+        srcOp.getLoc(), srcOp.value(), typeConverter, rewriter);
+    auto dst = VMLAConversionTarget::allocateOutputBuffer(
+        srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+    rewriter.createOrFold<IREE::VMLA::SortOp>(srcOp.getLoc(), src, src_shape,
+                                              dst, TypeAttr::get(inputType));
+    rewriter.replaceOp(srcOp, {dst});
+    return success();
+  }
+
+  TypeConverter &typeConverter;
+};
+
 struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
   ConvertOpConversion(MLIRContext *context, TypeConverter &typeConverter)
       : OpConversionPattern(context), typeConverter(typeConverter) {}
@@ -769,6 +792,9 @@
                                    IREE::VMLA::BatchMatMulOp>>(context,
                                                                typeConverter);
 
+  // vmla.sort.pseudo
+  patterns.insert<SortOpConversion>(context, typeConverter);
+
   // Simple 1:1 conversion patterns using the automated trait-based converter.
   // Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
   patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
new file mode 100644
index 0000000..0903793
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @sort1D(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[C16:%.+]] = constant 16 : index
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
+  // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+  // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4]>), out [[BL]] : f32
+  // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+  // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4]>) {batch_dims = 0 : i64, dim = 0 : i64} : f32
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+  // CHECK: return [[BUF]] : !vmla.buffer
+  return %sort : tensor<4xf32>
+}
+
+
+// CHECK-LABEL: func @sort2D
+func @sort2D(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[C64:%.+]] = constant 64 : index
+  // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,4]>
+  // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+  // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BL]] : f32
+  // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+  // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4,4]>) {batch_dims = 1 : i64, dim = 1 : i64} : f32
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  // CHECK: return [[BUF]] : !vmla.buffer
+  return %sort : tensor<4x4xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 7e4f6ea..1b66485 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -321,6 +321,7 @@
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::FloorOp, "vmla.floor");
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::CeilOp, "vmla.ceil");
   VMLA_TYPED_IMPORT_OP(IREE::VMLA::RoundOp, "vmla.round");
+  VMLA_TYPED_IMPORT_OP(IREE::VMLA::SortOp, "vmla.sort");
 
   patterns.insert<VMLAConvertImportOpConversion>(context, importSymbols,
                                                  typeConverter, "vmla.convert");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f5cd0fe..422fed3 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -422,7 +422,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// VMLA Ops: Convultion
+// VMLA Ops: Convolution
 //===----------------------------------------------------------------------===//
 
 def VLMA_ConvOp : VMLA_Op<"conv", [VMLA_IncludeShapes]> {
@@ -460,6 +460,46 @@
 }
 
 //===----------------------------------------------------------------------===//
+// VMLA Ops: Sorting
+//===----------------------------------------------------------------------===//
+
+def VMLA_SortPseudoOp : VMLA_Op<"sort.pseudo"> {
+  let summary = "Tensor-level pseudo-op of VMLA::SortOp.";
+  let description = [{
+    This is a tensor-level version of VMLA::SortOp, to facilitate
+    the lowering process.
+
+    This operation generates a list of sorted indices along the last
+    dimension, treating all other dimensions as batch dimensions.
+  }];
+  let arguments = (ins
+    AnyTensor:$value
+  );
+  let results = (outs
+    I32Tensor:$dst
+  );
+
+  let assemblyFormat = [{
+    $value attr-dict `:` `(`type($value)`)` `->` type($dst)
+  }];
+}
+
+def VMLA_SortOp : VMLA_ElementTypeOp<"sort", [VMLA_IncludeShapes]> {
+  let arguments = (ins
+    VMLA_Buffer:$src,
+    VMLA_Shape:$src_shape,
+    VMLA_Buffer:$dst,
+    VMLA_AnyTypeAttr:$element_type
+  );
+
+  let assemblyFormat = [{
+    $src`(`$src_shape `:` type($src_shape)`)``,`
+    `out` $dst attr-dict `:` $element_type
+  }];
+}
+
+
+//===----------------------------------------------------------------------===//
 // VMLA Ops: GEMM/GEMV
 //===----------------------------------------------------------------------===//
 
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index d937026..62dd921 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -277,6 +277,96 @@
   }
 };
 
+// Lowers mhlo::SortOp to a pseudo SortOp in the VMLA dialect. The pseudo op
+// generates a set of ordered indices for the input along the last dimension.
+// A torch_index_select then reorders the values by those indices, which
+// supports arbitrary inputs.
+//
+// TODO(suderman): This lowering only covers the case of ascending values, we
+// should support a separate descending value case by having separate
+// SortAscending and SortDescending operations.
+class LowerSortOp : public OpRewritePattern<mhlo::SortOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::SortOp op,
+                                PatternRewriter &rewriter) const override {
+    auto operandTy = op.getOperand(0).getType().cast<RankedTensorType>();
+    bool lastDimension =
+        (op.dimension() == -1) || (op.dimension() == (operandTy.getRank() - 1));
+
+    // TODO(suderman): Add transpose to sort along the last dimension.
+    if (!lastDimension) return failure();
+
+    auto &comparator = op.comparator();
+    auto &block = comparator.getBlocks().front();
+    auto &operations = block.getOperations();
+    auto comparison = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
+
+    // First verify that the block is purely a return of a comparison. This
+    // handles sorting a single tensor of values.
+    if (!comparison) return failure();
+
+    auto returnOp =
+        dyn_cast_or_null<mhlo::ReturnOp>(&(*(++operations.begin())));
+    if (!returnOp) return failure();
+
+    if (returnOp.getOperand(0) != comparison.getResult()) return failure();
+
+    // Determine which operands are being compared.
+    auto lhs = comparison.getOperand(0);
+    auto rhs = comparison.getOperand(1);
+    auto lhsIndex = -1;
+    auto rhsIndex = -1;
+    for (auto arg : llvm::enumerate(block.getArguments())) {
+      if (arg.value() == lhs) lhsIndex = arg.index();
+      if (arg.value() == rhs) rhsIndex = arg.index();
+    }
+
+    // This should never happen but best to check.
+    if (lhsIndex == -1) return failure();
+    if (rhsIndex == -1) return failure();
+
+    // They should not be the same.
+    if (lhsIndex == rhsIndex) return failure();
+
+    // Comparisons need to pull from the same Sort operand.
+    auto lhsOperand = lhsIndex / 2;
+    auto rhsOperand = rhsIndex / 2;
+    if (lhsOperand != rhsOperand) return failure();
+
+    // Must be GT, GE, LT, or LE.
+    auto isGt = comparison.comparison_direction() == "GT" ||
+                comparison.comparison_direction() == "GE";
+    auto isLt = comparison.comparison_direction() == "LT" ||
+                comparison.comparison_direction() == "LE";
+    if (!isGt && !isLt) return failure();
+
+    bool operandParity = lhsIndex > rhsIndex;
+    auto isAscending = operandParity ^ isGt;
+    // TODO(suderman): Add support for descending sort.
+    if (!isAscending) return failure();
+
+    auto operand = op.getOperand(lhsOperand);
+    auto sortedIndices = rewriter.create<VMLA::SortPseudoOp>(
+        op.getLoc(),
+        RankedTensorType::get(operandTy.getShape(), rewriter.getI32Type()),
+        operand);
+
+    llvm::SmallVector<Value, 6> sortedResults;
+    for (auto operand : op.getOperands()) {
+      auto tensorTy = operand.getType().cast<RankedTensorType>();
+      auto gathered = rewriter.create<mhlo::TorchIndexSelectOp>(
+          op.getLoc(), tensorTy, operand, sortedIndices,
+          /*dim=*/operandTy.getRank() - 1,
+          /*batch_dims=*/operandTy.getRank() - 1);
+      sortedResults.push_back(gathered);
+    }
+
+    rewriter.replaceOp(op, sortedResults);
+    return success();
+  }
+};
+
 class PreConversionLoweringPass
     : public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
  public:
@@ -310,6 +400,8 @@
     patterns.insert<LowerBroadcastInDimOp>(context);
     target.addIllegalOp<mhlo::BroadcastOp>();
     patterns.insert<LowerBroadcastOp>(context);
+    target.addIllegalOp<mhlo::SortOp>();
+    patterns.insert<LowerSortOp>(context);
 
     if (failed(applyPartialConversion(getOperation(), target, patterns))) {
       return signalPassFailure();
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
index 3473e44..0b9cd82 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -17,6 +17,38 @@
 // -----
 
 // CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+  // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 0 : i64, dim = 0 : i64}
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+  // CHECK: return [[GATHER]]
+  return %sort : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+  // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+  // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 1 : i64, dim = 1 : i64}
+  %sort = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  // CHECK: return [[GATHER]]
+  return %sort : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
 func @f(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
   // CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
   %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 17d7e85..ff575b0 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -333,6 +333,20 @@
 vm.import @ceil.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
 vm.import @round.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
 
+
+vm.import @sort.i8(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i16(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.f32(
+  %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+  %dst : !vm.ref<!vmla.buffer>)
+
 //===----------------------------------------------------------------------===//
 // VMLA Ops: conversion
 //===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index c6ab9d6..ba5b8bc 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -168,6 +168,12 @@
                         absl::Span<const int32_t> dimensions);
 };
 
+struct Sort {
+  template <typename T>
+  static Status Execute(absl::Span<const T> src_buffer,
+                        absl::Span<int32_t> dst_buffer, ShapeSpan src_shape);
+};
+
 struct Broadcast {
   template <typename T>
   static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index d3545d1..0b4f904 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -15,7 +15,9 @@
 #ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
 #define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
 
+#include <algorithm>
 #include <cmath>
+#include <numeric>
 
 #include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
@@ -519,6 +522,25 @@
 }
 
 template <typename T>
+Status Sort::Execute(absl::Span<const T> src_buffer,
+                     absl::Span<int32_t> dst_buffer, ShapeSpan src_shape) {
+  int elements = src_buffer.size();
+  const int sort_size = src_shape.back();
+
+  for (int i = 0; i < elements; i += sort_size) {
+    auto src_subspan = src_buffer.subspan(i, sort_size);
+    auto dst_subspan = dst_buffer.subspan(i, sort_size);
+    std::iota(dst_subspan.begin(), dst_subspan.end(), 0);
+    std::stable_sort(dst_subspan.begin(), dst_subspan.end(),
+                     [&src_subspan](int32_t i1, int32_t i2) {
+                       return src_subspan[i1] < src_subspan[i2];
+                     });
+  }
+
+  return OkStatus();
+}
+
+template <typename T>
 Status Broadcast::Execute(absl::Span<const T> src_buffer,
                           absl::Span<T> dst_buffer) {
   for (size_t i = 0; i < dst_buffer.size(); ++i) {
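
For reference, a hedged usage sketch of the new kernel's contract (this assumes
the kernels namespace layout used by vmla_module.cc, i.e.
iree::hal::vmla::kernels, and that ShapeSpan accepts a span of int32_t; the
values mirror the 2D e2e test input):

    #include <cstdint>
    #include <vector>

    #include "absl/types/span.h"
    #include "iree/hal/vmla/op_kernels.h"

    int main() {
      // Two rows of four values, row-major.
      std::vector<int32_t> src = {1, 2, 3, 4,
                                  4, 3, 2, 1};
      std::vector<int32_t> indices(src.size());
      std::vector<int32_t> shape = {2, 4};
      // Each innermost run of shape.back() elements is argsorted independently;
      // expected indices: {0, 1, 2, 3} for row 0 and {3, 2, 1, 0} for row 1.
      auto status = iree::hal::vmla::kernels::Sort::Execute<int32_t>(
          absl::MakeConstSpan(src), absl::MakeSpan(indices),
          absl::MakeConstSpan(shape));
      (void)status;
      return 0;
    }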
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 09dbb3c..5852de0 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -642,6 +642,19 @@
   IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);
   IREE_VMLA_UNARY_OP(RoundF32, kernels::Round, float);
 
+#define IREE_VMLA_SORT_OP(name, type)                                        \
+  Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape,       \
+              const vm::ref<Buffer>& dst) {                                  \
+    IREE_TRACE_SCOPE0("VMLAModuleState::" #name);                            \
+    return kernels::Sort::Execute<type>(src->As<type>(), dst->As<int32_t>(), \
+                                        src_shape);                          \
+  }
+
+  IREE_VMLA_SORT_OP(SortI8, int8_t);
+  IREE_VMLA_SORT_OP(SortI16, int16_t);
+  IREE_VMLA_SORT_OP(SortI32, int32_t);
+  IREE_VMLA_SORT_OP(SortF32, float);
+
   //===--------------------------------------------------------------------===//
   // VMLA Ops: conversion
   //===--------------------------------------------------------------------===//
@@ -970,6 +983,10 @@
     vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
     vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
     vm::MakeNativeFunction("round.f32", &VMLAModuleState::RoundF32),
+    vm::MakeNativeFunction("sort.i8", &VMLAModuleState::SortI8),
+    vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16),
+    vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32),
+    vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32),
     vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
 
     vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
diff --git a/iree/test/e2e/xla_ops/sort.mlir b/iree/test/e2e/xla_ops/sort.mlir
new file mode 100644
index 0000000..1820d8d
--- /dev/null
+++ b/iree/test/e2e/xla_ops/sort.mlir
@@ -0,0 +1,40 @@
+func @sort1D() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[3, 2, 1, 4]> : tensor<4xi32>
+
+  %sort = "mhlo.sort"(%input) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 0 : i64, is_stable = false} : (tensor<4xi32>) -> tensor<4xi32>
+
+  check.expect_eq_const(%sort, dense<[1, 2, 3, 4]> : tensor<4xi32>) : tensor<4xi32>
+  return
+}
+
+func @sort2D() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[[1, 2, 3, 4],
+                                           [4, 3, 2, 1]]> : tensor<2x4xi32>
+
+  %sort = "mhlo.sort"(%input) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 1 : i64, is_stable = false} : (tensor<2x4xi32>) -> tensor<2x4xi32>
+
+  check.expect_eq_const(%sort, dense<[[1, 2, 3, 4], [1, 2, 3, 4]]> : tensor<2x4xi32>) : tensor<2x4xi32>
+  return
+}
+
+func @sort3D() attributes { iree.module.export } {
+  %input = iree.unfoldable_constant dense<[[[1, 2, 3, 4],
+                                            [4, 3, 2, 1]]]> : tensor<1x2x4xi32>
+
+  %sort = "mhlo.sort"(%input) ( {
+  ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):  // no predecessors
+    %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%compare) : (tensor<i1>) -> ()
+  }) {dimension = 2 : i64, is_stable = false} : (tensor<1x2x4xi32>) -> tensor<1x2x4xi32>
+
+  check.expect_eq_const(%sort, dense<[[[1, 2, 3, 4], [1, 2, 3, 4]]]> : tensor<1x2x4xi32>) : tensor<1x2x4xi32>
+  return
+}