VMLA sort implementation with pseudo lowering (#3295)
Adds an implementation of Sort for VMLA. Includes the kernel, an intermediate
pseudo op with its lowerings, and end-to-end tests. The 2D case is currently
disabled due to a bug in the gather kernel that is being addressed.
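
At a high level the lowering happens in two stages. As a rough sketch (distilled
from the lit tests in this change rather than verbatim compiler output; value
names are illustrative), the pre-conversion lowering rewrites a 1D mhlo.sort
into a pseudo op that produces the sorting indices plus a gather of the
original values:

  // Stage 1 (-iree-vmla-pre-conversion-lowering): indices + reorder.
  %indices = vmla.sort.pseudo %input : (tensor<4xf32>) -> tensor<4xi32>
  %sorted = "mhlo.torch_index_select"(%input, %indices)
      {batch_dims = 0 : i64, dim = 0 : i64}
      : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32>

The VMLA conversion (-iree-vmla-conversion) then maps these onto buffers:
vmla.sort writes the sorted indices into an allocated output buffer and
vmla.gather applies them to the source values.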
diff --git a/integrations/tensorflow/e2e/BUILD b/integrations/tensorflow/e2e/BUILD
index 7a07478..dc289d6 100644
--- a/integrations/tensorflow/e2e/BUILD
+++ b/integrations/tensorflow/e2e/BUILD
@@ -95,6 +95,7 @@
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
+ "sort_test.py",
"strings_test.py",
]
@@ -112,6 +113,7 @@
"range_test.py",
"ring_buffer_test.py", # TODO(b/148747011)
"scatter_update_test.py",
+ "sort_test.py",
"strings_test.py",
]
diff --git a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
index a80a260..6af672a 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/ConversionTarget.cpp
@@ -42,6 +42,7 @@
// Pseudo-ops are illegal.
// If we end up with a lot of these, consider using an "is pseudo" trait.
addIllegalOp<IREE::VMLA::BatchMatMulPseudoOp>();
+ addIllegalOp<IREE::VMLA::SortPseudoOp>();
// Allow other ops to pass through so long as their type is valid (not a
// tensor, basically).
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
index ccd1a3f..75b640f 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/ConvertHLOToVMLA.cpp
@@ -704,6 +704,29 @@
TypeConverter &typeConverter;
};
+struct SortOpConversion : public OpConversionPattern<IREE::VMLA::SortPseudoOp> {
+ SortOpConversion(MLIRContext *context, TypeConverter &typeConverter)
+ : OpConversionPattern(context), typeConverter(typeConverter) {}
+
+ LogicalResult matchAndRewrite(
+ IREE::VMLA::SortPseudoOp srcOp, ArrayRef<Value> rawOperands,
+ ConversionPatternRewriter &rewriter) const override {
+ auto inputType =
+ srcOp.getOperand().getType().cast<ShapedType>().getElementType();
+ auto src = rawOperands[0];
+ auto src_shape = VMLAConversionTarget::getTensorShape(
+ srcOp.getLoc(), srcOp.value(), typeConverter, rewriter);
+ auto dst = VMLAConversionTarget::allocateOutputBuffer(
+ srcOp.getLoc(), srcOp.getResult(), typeConverter, rewriter);
+ rewriter.createOrFold<IREE::VMLA::SortOp>(srcOp.getLoc(), src, src_shape,
+ dst, TypeAttr::get(inputType));
+ rewriter.replaceOp(srcOp, {dst});
+ return success();
+ }
+
+ TypeConverter &typeConverter;
+};
+
struct ConvertOpConversion : public OpConversionPattern<mhlo::ConvertOp> {
ConvertOpConversion(MLIRContext *context, TypeConverter &typeConverter)
: OpConversionPattern(context), typeConverter(typeConverter) {}
@@ -769,6 +792,9 @@
IREE::VMLA::BatchMatMulOp>>(context,
typeConverter);
+ // vmla.sort.pseudo
+ patterns.insert<SortOpConversion>(context, typeConverter);
+
// Simple 1:1 conversion patterns using the automated trait-based converter.
// Used for HLO ops that have equivalent VMLA ops such as most arithmetic ops.
patterns.insert<VMLAOpConversion<mhlo::AddOp, IREE::VMLA::AddOp>>(
diff --git a/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
new file mode 100644
index 0000000..0903793
--- /dev/null
+++ b/iree/compiler/Dialect/VMLA/Conversion/HLOToVMLA/test/sort.mlir
@@ -0,0 +1,37 @@
+// RUN: iree-opt -split-input-file -iree-vmla-pre-conversion-lowering -iree-vmla-conversion -canonicalize %s | IreeFileCheck %s
+
+func @sort1D(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[C16:%.+]] = constant 16 : index
+ // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4]>
+ // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+ // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4]>), out [[BL]] : f32
+ // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C16]] : !vmla.buffer
+ // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4]>) {batch_dims = 0 : i64, dim = 0 : i64} : f32
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+ // CHECK: return [[BUF]] : !vmla.buffer
+ return %sort : tensor<4xf32>
+}
+
+
+// CHECK-LABEL: func @sort2D
+func @sort2D(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[C64:%.+]] = constant 64 : index
+ // CHECK-DAG: [[RS:%.+]] = shapex.const_ranked_shape : !shapex.ranked_shape<[4,4]>
+ // CHECK-DAG: [[BL:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+ // CHECK-DAG: vmla.sort %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BL]] : f32
+ // CHECK-DAG: [[BUF:%.+]] = vmla.buffer.alloc byte_length = [[C64]] : !vmla.buffer
+ // CHECK-DAG: vmla.gather %arg0([[RS]] : !shapex.ranked_shape<[4,4]>), [[BL]]([[RS]] : !shapex.ranked_shape<[4,4]>), out [[BUF]]([[RS]] : !shapex.ranked_shape<[4,4]>) {batch_dims = 1 : i64, dim = 1 : i64} : f32
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+ // CHECK: return [[BUF]] : !vmla.buffer
+ return %sort : tensor<4x4xf32>
+}
diff --git a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
index 7e4f6ea..1b66485 100644
--- a/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
+++ b/iree/compiler/Dialect/VMLA/Conversion/VMLAToVM/ConvertVMLAToVM.cpp
@@ -321,6 +321,7 @@
VMLA_TYPED_IMPORT_OP(IREE::VMLA::FloorOp, "vmla.floor");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::CeilOp, "vmla.ceil");
VMLA_TYPED_IMPORT_OP(IREE::VMLA::RoundOp, "vmla.round");
+ VMLA_TYPED_IMPORT_OP(IREE::VMLA::SortOp, "vmla.sort");
patterns.insert<VMLAConvertImportOpConversion>(context, importSymbols,
typeConverter, "vmla.convert");
diff --git a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
index f5cd0fe..422fed3 100644
--- a/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
+++ b/iree/compiler/Dialect/VMLA/IR/VMLAOps.td
@@ -422,7 +422,7 @@
}
//===----------------------------------------------------------------------===//
-// VMLA Ops: Convultion
+// VMLA Ops: Convolution
//===----------------------------------------------------------------------===//
def VLMA_ConvOp : VMLA_Op<"conv", [VMLA_IncludeShapes]> {
@@ -460,6 +460,46 @@
}
//===----------------------------------------------------------------------===//
+// VMLA Ops: Sorting
+//===----------------------------------------------------------------------===//
+
+def VMLA_SortPseudoOp : VMLA_Op<"sort.pseudo"> {
+ let summary = "Tensor-level pseudo-op of VMLA::SortOp.";
+ let description = [{
+ This is a tensor-level version of VMLA::SortOp, to facilitate
+ the lowering process.
+
+    This operation generates the list of indices that sorts the input along
+    its last dimension, operating batch-wise along all other dimensions.
+ }];
+ let arguments = (ins
+ AnyTensor:$value
+ );
+ let results = (outs
+ I32Tensor:$dst
+ );
+
+ let assemblyFormat = [{
+ $value attr-dict `:` `(`type($value)`)` `->` type($dst)
+ }];
+}
+
+def VMLA_SortOp : VMLA_ElementTypeOp<"sort", [VMLA_IncludeShapes]> {
+ let arguments = (ins
+ VMLA_Buffer:$src,
+ VMLA_Shape:$src_shape,
+ VMLA_Buffer:$dst,
+ VMLA_AnyTypeAttr:$element_type
+ );
+
+ let assemblyFormat = [{
+ $src`(`$src_shape `:` type($src_shape)`)``,`
+ `out` $dst attr-dict `:` $element_type
+ }];
+}
+
+
+//===----------------------------------------------------------------------===//
// VMLA Ops: GEMM/GEMV
//===----------------------------------------------------------------------===//
diff --git a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
index d937026..62dd921 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
+++ b/iree/compiler/Dialect/VMLA/Transforms/PreConversionLowering.cpp
@@ -277,6 +277,96 @@
}
};
+// Lowers mhlo::SortOp to a pseudo SortOp in the VMLA dialect. The pseudo op
+// generates the set of indices that sorts the input along its last dimension;
+// a torch_index_select then uses those indices to reorder the values, which
+// supports arbitrary operands.
+//
+// TODO(suderman): This lowering only covers the ascending case; we should
+// support descending sorts by adding separate SortAscending and
+// SortDescending operations.
+class LowerSortOp : public OpRewritePattern<mhlo::SortOp> {
+ public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(mhlo::SortOp op,
+ PatternRewriter &rewriter) const override {
+ auto operandTy = op.getOperand(0).getType().cast<RankedTensorType>();
+ bool lastDimension =
+ (op.dimension() == -1) || (op.dimension() == (operandTy.getRank() - 1));
+
+ // TODO(suderman): Add transpose to sort along the last dimension.
+ if (!lastDimension) return failure();
+
+ auto &comparator = op.comparator();
+ auto &block = comparator.getBlocks().front();
+ auto &operations = block.getOperations();
+ auto comparison = dyn_cast_or_null<mhlo::CompareOp>(&operations.front());
+
+ // First verify that the block is purely a return of a comparison. This
+ // handles sorting a single tensor of values.
+ if (!comparison) return failure();
+
+ auto returnOp =
+ dyn_cast_or_null<mhlo::ReturnOp>(&(*(++operations.begin())));
+ if (!returnOp) return failure();
+
+ if (returnOp.getOperand(0) != comparison.getResult()) return failure();
+
+    // Determine which block arguments are being compared.
+ auto lhs = comparison.getOperand(0);
+ auto rhs = comparison.getOperand(1);
+ auto lhsIndex = -1;
+ auto rhsIndex = -1;
+ for (auto arg : llvm::enumerate(block.getArguments())) {
+ if (arg.value() == lhs) lhsIndex = arg.index();
+ if (arg.value() == rhs) rhsIndex = arg.index();
+ }
+
+ // This should never happen but best to check.
+ if (lhsIndex == -1) return failure();
+ if (rhsIndex == -1) return failure();
+
+ // They should not be the same.
+ if (lhsIndex == rhsIndex) return failure();
+
+    // Comparisons need to pull from the same Sort operand.
+ auto lhsOperand = lhsIndex / 2;
+ auto rhsOperand = rhsIndex / 2;
+ if (lhsOperand != rhsOperand) return failure();
+
+ // Must be GT, GE, LT, or LE.
+ auto isGt = comparison.comparison_direction() == "GT" ||
+ comparison.comparison_direction() == "GE";
+ auto isLt = comparison.comparison_direction() == "LT" ||
+ comparison.comparison_direction() == "LE";
+ if (!isGt && !isLt) return failure();
+
+ bool operandParity = lhsIndex > rhsIndex;
+ auto isAscending = operandParity ^ isGt;
+    // TODO(suderman): Add support for descending sorts.
+ if (!isAscending) return failure();
+
+ auto operand = op.getOperand(lhsOperand);
+ auto sortedIndices = rewriter.create<VMLA::SortPseudoOp>(
+ op.getLoc(),
+ RankedTensorType::get(operandTy.getShape(), rewriter.getI32Type()),
+ operand);
+
+ llvm::SmallVector<Value, 6> sortedResults;
+ for (auto operand : op.getOperands()) {
+ auto tensorTy = operand.getType().cast<RankedTensorType>();
+ auto gathered = rewriter.create<mhlo::TorchIndexSelectOp>(
+ op.getLoc(), tensorTy, operand, sortedIndices,
+        /*dim=*/operandTy.getRank() - 1,
+        /*batch_dims=*/operandTy.getRank() - 1);
+ sortedResults.push_back(gathered);
+ }
+
+ rewriter.replaceOp(op, sortedResults);
+ return success();
+ }
+};
+
class PreConversionLoweringPass
: public PassWrapper<PreConversionLoweringPass, OperationPass<FuncOp>> {
public:
@@ -310,6 +400,8 @@
patterns.insert<LowerBroadcastInDimOp>(context);
target.addIllegalOp<mhlo::BroadcastOp>();
patterns.insert<LowerBroadcastOp>(context);
+ target.addIllegalOp<mhlo::SortOp>();
+ patterns.insert<LowerSortOp>(context);
if (failed(applyPartialConversion(getOperation(), target, patterns))) {
return signalPassFailure();
diff --git a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
index 3473e44..0b9cd82 100644
--- a/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
+++ b/iree/compiler/Dialect/VMLA/Transforms/test/pre_conversion_lowering.mlir
@@ -17,6 +17,38 @@
// -----
// CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4xf32>) -> tensor<4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+ // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 0 : i64, dim = 0 : i64}
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 0 : i64, is_stable = false} : (tensor<4xf32>) -> tensor<4xf32>
+
+ // CHECK: return [[GATHER]]
+ return %sort : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
+func @f(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> attributes { sym_visibility = "private" } {
+ // CHECK-DAG: [[SORT:%.+]] = vmla.sort.pseudo %arg0
+ // CHECK-DAG: [[GATHER:%.+]] = "mhlo.torch_index_select"(%arg0, [[SORT]]) {batch_dims = 1 : i64, dim = 1 : i64}
+ %sort = "mhlo.sort"(%arg0) ( {
+ ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = false} : (tensor<4x4xf32>) -> tensor<4x4xf32>
+
+  // CHECK: return [[GATHER]]
+ return %sort : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @f
func @f(%arg0: tensor<3xf32>) -> tensor<4x3xf32> {
// CHECK: "shapex.ranked_broadcast_in_dim"(%arg0, %rs4_3)
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xf32>) -> tensor<4x3xf32>
diff --git a/iree/compiler/Dialect/VMLA/vmla.imports.mlir b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
index 17d7e85..ff575b0 100644
--- a/iree/compiler/Dialect/VMLA/vmla.imports.mlir
+++ b/iree/compiler/Dialect/VMLA/vmla.imports.mlir
@@ -333,6 +333,20 @@
vm.import @ceil.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
vm.import @round.f32(%src : !vm.ref<!vmla.buffer>, %dst : !vm.ref<!vmla.buffer>)
+
+vm.import @sort.i8(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i16(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.i32(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+vm.import @sort.f32(
+ %src : !vm.ref<!vmla.buffer>, %src_shape : i32 ...,
+ %dst : !vm.ref<!vmla.buffer>)
+
//===----------------------------------------------------------------------===//
// VMLA Ops: conversion
//===----------------------------------------------------------------------===//
diff --git a/iree/hal/vmla/op_kernels.h b/iree/hal/vmla/op_kernels.h
index c6ab9d6..ba5b8bc 100644
--- a/iree/hal/vmla/op_kernels.h
+++ b/iree/hal/vmla/op_kernels.h
@@ -168,6 +168,12 @@
absl::Span<const int32_t> dimensions);
};
+struct Sort {
+ template <typename T>
+ static Status Execute(absl::Span<const T> src_buffer,
+ absl::Span<int32_t> dst_buffer, ShapeSpan src_shape);
+};
+
struct Broadcast {
template <typename T>
static Status Execute(absl::Span<const T> src_buffer,
diff --git a/iree/hal/vmla/op_kernels_generic.h b/iree/hal/vmla/op_kernels_generic.h
index d3545d1..0b4f904 100644
--- a/iree/hal/vmla/op_kernels_generic.h
+++ b/iree/hal/vmla/op_kernels_generic.h
@@ -15,7 +15,9 @@
#ifndef IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
#define IREE_HAL_VMLA_OP_KERNELS_GENERIC_H_
+#include <algorithm>
#include <cmath>
+#include <numeric>
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
@@ -519,6 +522,25 @@
}
template <typename T>
+Status Sort::Execute(absl::Span<const T> src_buffer,
+ absl::Span<int32_t> dst_buffer, ShapeSpan src_shape) {
+ int elements = src_buffer.size();
+ const int sort_size = src_shape.back();
+
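+  // Argsort: each contiguous run of `sort_size` elements along the innermost
+  // dimension is handled independently, writing the indices that would
+  // stably sort that run in ascending order into the matching dst run.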
+ for (int i = 0; i < elements; i += sort_size) {
+ auto src_subspan = src_buffer.subspan(i, sort_size);
+ auto dst_subspan = dst_buffer.subspan(i, sort_size);
+ std::iota(dst_subspan.begin(), dst_subspan.end(), 0);
+ std::stable_sort(dst_subspan.begin(), dst_subspan.end(),
+ [&src_subspan](int32_t i1, int32_t i2) {
+ return src_subspan[i1] < src_subspan[i2];
+ });
+ }
+
+ return OkStatus();
+}
+
+template <typename T>
Status Broadcast::Execute(absl::Span<const T> src_buffer,
absl::Span<T> dst_buffer) {
for (size_t i = 0; i < dst_buffer.size(); ++i) {
diff --git a/iree/hal/vmla/vmla_module.cc b/iree/hal/vmla/vmla_module.cc
index 09dbb3c..5852de0 100644
--- a/iree/hal/vmla/vmla_module.cc
+++ b/iree/hal/vmla/vmla_module.cc
@@ -642,6 +642,19 @@
IREE_VMLA_UNARY_OP(CeilF32, kernels::Ceil, float);
IREE_VMLA_UNARY_OP(RoundF32, kernels::Round, float);
+#define IREE_VMLA_SORT_OP(name, type) \
+ Status name(const vm::ref<Buffer>& src, iree_vmla_shape_t src_shape, \
+ const vm::ref<Buffer>& dst) { \
+ IREE_TRACE_SCOPE0("VMLAModuleState::" #name); \
+ return kernels::Sort::Execute<type>(src->As<type>(), dst->As<int32_t>(), \
+ src_shape); \
+ }
+
+ IREE_VMLA_SORT_OP(SortI8, int8_t);
+ IREE_VMLA_SORT_OP(SortI16, int16_t);
+ IREE_VMLA_SORT_OP(SortI32, int32_t);
+ IREE_VMLA_SORT_OP(SortF32, float);
+
//===--------------------------------------------------------------------===//
// VMLA Ops: conversion
//===--------------------------------------------------------------------===//
@@ -970,6 +983,10 @@
vm::MakeNativeFunction("floor.f32", &VMLAModuleState::FloorF32),
vm::MakeNativeFunction("ceil.f32", &VMLAModuleState::CeilF32),
vm::MakeNativeFunction("round.f32", &VMLAModuleState::RoundF32),
+ vm::MakeNativeFunction("sort.i8", &VMLAModuleState::SortI8),
+ vm::MakeNativeFunction("sort.i16", &VMLAModuleState::SortI16),
+ vm::MakeNativeFunction("sort.i32", &VMLAModuleState::SortI32),
+ vm::MakeNativeFunction("sort.f32", &VMLAModuleState::SortF32),
vm::MakeNativeFunction("finite.f32", &VMLAModuleState::FiniteF32),
vm::MakeNativeFunction("convert.i8.i16", &VMLAModuleState::ConvertI8I16),
diff --git a/iree/test/e2e/xla_ops/sort.mlir b/iree/test/e2e/xla_ops/sort.mlir
new file mode 100644
index 0000000..1820d8d
--- /dev/null
+++ b/iree/test/e2e/xla_ops/sort.mlir
@@ -0,0 +1,40 @@
+func @sort1D() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[3, 2, 1, 4]> : tensor<4xi32>
+
+ %sort = "mhlo.sort"(%input) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 0 : i64, is_stable = false} : (tensor<4xi32>) -> tensor<4xi32>
+
+ check.expect_eq_const(%sort, dense<[1, 2, 3, 4]> : tensor<4xi32>) : tensor<4xi32>
+ return
+}
+
+func @sort2D() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[[1, 2, 3, 4],
+ [4, 3, 2, 1]]> : tensor<2x4xi32>
+
+ %sort = "mhlo.sort"(%input) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 1 : i64, is_stable = false} : (tensor<2x4xi32>) -> tensor<2x4xi32>
+
+ check.expect_eq_const(%sort, dense<[[1, 2, 3, 4], [1, 2, 3, 4]]> : tensor<2x4xi32>) : tensor<2x4xi32>
+ return
+}
+
+func @sort3D() attributes { iree.module.export } {
+ %input = iree.unfoldable_constant dense<[[[1, 2, 3, 4],
+ [4, 3, 2, 1]]]> : tensor<1x2x4xi32>
+
+ %sort = "mhlo.sort"(%input) ( {
+ ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>): // no predecessors
+ %compare = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ "mhlo.return"(%compare) : (tensor<i1>) -> ()
+ }) {dimension = 2 : i64, is_stable = false} : (tensor<1x2x4xi32>) -> tensor<1x2x4xi32>
+
+ check.expect_eq_const(%sort, dense<[[[1, 2, 3, 4], [1, 2, 3, 4]]]> : tensor<1x2x4xi32>) : tensor<1x2x4xi32>
+ return
+}