blob: 499a7d21c3c4a89e6b8dddbedd1d50f5b54ddd6b [file] [log] [blame]
// Gathers rows 0 and 2 along dimension 0 of a 5x1x5 operand with a rank-1
// index tensor. The result keeps rank 3; its leading dimension equals the
// number of indices (2).
func.func @torch_select_index_0() {
  %operand = util.unfoldable_constant dense<[
      [[1, 2, 3, 4, 5]],
      [[6, 7, 8, 9, 10]],
      [[11, 12, 13, 14, 15]],
      [[16, 17, 18, 19, 20]],
      [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
  %idx = util.unfoldable_constant dense<[0, 2]> : tensor<2xi32>
  %selected = "stablehlo.torch_index_select"(%operand, %idx) {
    batch_dims = 0 : i64,
    dim = 0 : i64
  } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
  check.expect_eq_const(%selected, dense<[
      [[1, 2, 3, 4, 5]],
      [[11, 12, 13, 14, 15]]]> : tensor<2x1x5xi32>) : tensor<2x1x5xi32>
  return
}
// Selects along dimension 1 of a 3x2x2 operand with a rank-1 index tensor.
// The indices are the permutation [1, 0] rather than the identity [0, 1].
// With the identity, an implementation that ignored the index values (or
// simply forwarded the operand) would still pass. With the permutation,
// every 2x2 slice must come back with its two rows swapped.
func.func @torch_select_index_1() {
  %input = util.unfoldable_constant dense<[
      [[ 1,  2], [ 3,  4]],
      [[ 5,  6], [ 7,  8]],
      [[ 9, 10], [11, 12]]]> : tensor<3x2x2xi32>
  %indices = util.unfoldable_constant dense<[1, 0]> : tensor<2xi32>
  %res = "stablehlo.torch_index_select"(%input, %indices) {
    dim = 1 : i64,
    batch_dims = 0 : i64
  } : (tensor<3x2x2xi32>, tensor<2xi32>) -> tensor<3x2x2xi32>
  // out[i][j] = input[i][indices[j]]: row 1 first, then row 0.
  check.expect_eq_const(%res, dense<[
      [[ 3,  4], [ 1,  2]],
      [[ 7,  8], [ 5,  6]],
      [[11, 12], [ 9, 10]]]> : tensor<3x2x2xi32>) : tensor<3x2x2xi32>
  return
}
// Uses a rank-0 (scalar) index: picking element 0 along dimension 0 of a
// 5x1x5 operand drops that dimension and yields the 1x5 slice at index 0.
func.func @torch_select_index_2() {
  %operand = util.unfoldable_constant dense<[
      [[1, 2, 3, 4, 5]],
      [[6, 7, 8, 9, 10]],
      [[11, 12, 13, 14, 15]],
      [[16, 17, 18, 19, 20]],
      [[21, 22, 23, 24, 25]]]> : tensor<5x1x5xi32>
  %scalar_idx = util.unfoldable_constant dense<0> : tensor<i32>
  %slice = "stablehlo.torch_index_select"(%operand, %scalar_idx) {
    batch_dims = 0 : i64,
    dim = 0 : i64
  } : (tensor<5x1x5xi32>, tensor<i32>) -> tensor<1x5xi32>
  check.expect_eq_const(%slice, dense<[[1, 2, 3, 4, 5]]> : tensor<1x5xi32>) : tensor<1x5xi32>
  return
}