blob: 98e62e7e0d334bd3d3222fb2bc5170c2c276f0f1 [file] [log] [blame]
// RUN: iree-opt --lower-xla-to-iree-interpreter %s | FileCheck %s --dump-input=fail
// Scalar start index (tensor<i64>) with index_vector_dim = 0.
// Expected lowering, per the directives below:
//   1. The index memref is reshaped to a 1-element vector.
//   2. It is concatenated with a 2-element zero vector, giving one start
//      offset per dimension of the rank-3 input.
//   3. A 1x2x3 slice (the slice_sizes) is copied into a newly allocated
//      1x2x3 memref.
//   4. That memref is reshaped to the 2x3 result, with dim 0 collapsed.
// CHECK-LABEL: @gather
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[START_INDICES:%[a-zA-Z0-9]+]]
func @gather(%input : tensor<5x2x3xf32>, %start_indices : tensor<i64>) -> tensor<2x3xf32> {
// CHECK-DAG: [[SRC:%.+]] = iree.tensor_to_memref([[INPUT]] : tensor<5x2x3xf32>)
// CHECK-DAG: [[START_INDICES_MEMREF:%.+]] = iree.tensor_to_memref([[START_INDICES]] : tensor<i64>)
// Rank-0 index becomes a 1-element vector before padding.
// CHECK-DAG: [[START_INDICES_NEW_SHAPE:%.+]] = iree.constant[dense<1> : tensor<1xi64>
// CHECK-DAG: [[START_INDICES_RESHAPED:%.+]] = "iree_hl_interp.reshape"([[START_INDICES_MEMREF]], [[START_INDICES_NEW_SHAPE]])
// Zero offsets for the two input dims not addressed by the index.
// CHECK-DAG: [[ZEROES:%.+]] = iree.constant[dense<0> : tensor<2xi64>
// CHECK-DAG: [[START_INDICES_PADDED:%.+]] = "iree_hl_interp.concat"([[START_INDICES_RESHAPED]], [[ZEROES]])
// CHECK-DAG: [[DST:%.+]] = "iree_hl_interp.alloc_heap"() : () -> memref<1x2x3xf32>
// CHECK-DAG: [[DST_INDICES:%.+]] = iree.constant[dense<0>
// CHECK-DAG: [[LENGTHS:%.+]] = iree.constant[dense<[1, 2, 3]>
// CHECK-NEXT: "iree_hl_interp.copy"([[SRC]], [[START_INDICES_PADDED]], [[DST]], [[DST_INDICES]], [[LENGTHS]])
// Dropping the collapsed dim 0 yields the 2x3 result.
// CHECK-DAG: [[NEW_SHAPE:%.+]] = iree.constant[dense<[2, 3]>
// CHECK-DAG: [[RESHAPED:%.+]] = "iree_hl_interp.reshape"([[DST]], [[NEW_SHAPE]])
// CHECK-DAG: [[RESULT_TENSOR:%.+]] = iree.memref_to_tensor([[RESHAPED]] : memref<2x3xf32>)
%result = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<0> : tensor<1xi64>
} : (tensor<5x2x3xf32>, tensor<i64>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RESULT_TENSOR]]
return %result : tensor<2x3xf32>
}
// 1-element start-index vector (tensor<1xi64>). This matches the scalar case,
// except that the index is already a vector. No reshape is expected: the index
// memref is concatenated directly with a 2-element zero vector to form the
// full 3-element start offset.
// CHECK-LABEL: @gather_nonscalar_indices
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[START_INDICES:%[a-zA-Z0-9]+]]
func @gather_nonscalar_indices(%input : tensor<5x2x3xf32>, %start_indices : tensor<1xi64>) -> tensor<2x3xf32> {
// CHECK-DAG: [[SRC:%.+]] = iree.tensor_to_memref([[INPUT]] : tensor<5x2x3xf32>)
// CHECK-DAG: [[START_INDICES_MEMREF:%.+]] = iree.tensor_to_memref([[START_INDICES]] : tensor<1xi64>)
// Zero offsets for the two input dims not addressed by the index.
// CHECK-DAG: [[ZEROES:%.+]] = iree.constant[dense<0> : tensor<2xi64>
// CHECK-DAG: [[START_INDICES_PADDED:%.+]] = "iree_hl_interp.concat"([[START_INDICES_MEMREF]], [[ZEROES]])
// CHECK-DAG: [[DST:%.+]] = "iree_hl_interp.alloc_heap"() : () -> memref<1x2x3xf32>
// CHECK-DAG: [[DST_INDICES:%.+]] = iree.constant[dense<0>
// CHECK-DAG: [[LENGTHS:%.+]] = iree.constant[dense<[1, 2, 3]>
// CHECK-NEXT: "iree_hl_interp.copy"([[SRC]], [[START_INDICES_PADDED]], [[DST]], [[DST_INDICES]], [[LENGTHS]])
// Dropping the collapsed dim 0 yields the 2x3 result.
// CHECK-DAG: [[NEW_SHAPE:%.+]] = iree.constant[dense<[2, 3]>
// CHECK-DAG: [[RESHAPED:%.+]] = "iree_hl_interp.reshape"([[DST]], [[NEW_SHAPE]])
// CHECK-DAG: [[RESULT_TENSOR:%.+]] = iree.memref_to_tensor([[RESHAPED]] : memref<2x3xf32>)
%result = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<0> : tensor<1xi64>
} : (tensor<5x2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RESULT_TENSOR]]
return %result : tensor<2x3xf32>
}
// 3-element start-index vector (tensor<3xi64>) that supplies an offset for
// every dimension of the rank-3 input. No reshape or zero padding is
// expected: the index memref is passed straight to the copy as its source
// start indices.
// NOTE(review): start_index_map below lists only dim 0 while the index vector
// has 3 entries. XLA gather semantics normally require these lengths to match,
// so confirm the lowering intentionally ignores start_index_map in this case.
// CHECK-LABEL: @gather_fully_specified_indices
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[START_INDICES:%[a-zA-Z0-9]+]]
func @gather_fully_specified_indices(%input : tensor<5x2x3xf32>, %start_indices : tensor<3xi64>) -> tensor<2x3xf32> {
// CHECK-DAG: [[SRC:%.+]] = iree.tensor_to_memref([[INPUT]] : tensor<5x2x3xf32>)
// CHECK-DAG: [[START_INDICES_MEMREF:%.+]] = iree.tensor_to_memref([[START_INDICES]] : tensor<3xi64>)
// CHECK-DAG: [[DST:%.+]] = "iree_hl_interp.alloc_heap"() : () -> memref<1x2x3xf32>
// CHECK-DAG: [[DST_INDICES:%.+]] = iree.constant[dense<0>
// CHECK-DAG: [[LENGTHS:%.+]] = iree.constant[dense<[1, 2, 3]>
// The unmodified index memref is the copy's source offset.
// CHECK-NEXT: "iree_hl_interp.copy"([[SRC]], [[START_INDICES_MEMREF]], [[DST]], [[DST_INDICES]], [[LENGTHS]])
// Dropping the collapsed dim 0 yields the 2x3 result.
// CHECK-DAG: [[NEW_SHAPE:%.+]] = iree.constant[dense<[2, 3]>
// CHECK-DAG: [[RESHAPED:%.+]] = "iree_hl_interp.reshape"([[DST]], [[NEW_SHAPE]])
// CHECK-DAG: [[RESULT_TENSOR:%.+]] = iree.memref_to_tensor([[RESHAPED]] : memref<2x3xf32>)
%result = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<0> : tensor<1xi64>
} : (tensor<5x2x3xf32>, tensor<3xi64>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RESULT_TENSOR]]
return %result : tensor<2x3xf32>
}
// Gathers that the lowering is expected to leave untouched. Each op below
// differs from the lowered tests in this file in at least one attribute. The
// expected output keeps all four "xla_hlo.gather" ops, one per line, in
// source order.
// Each per-op directive must end with a colon. FileCheck does not recognize a
// directive written without one, so the line is silently ignored and the
// check never runs.
// CHECK-LABEL: @gather_not_lowered
// CHECK-SAME: [[INPUT:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[START_INDICES:%[a-zA-Z0-9]+]]
func @gather_not_lowered(%input : tensor<5x2x3xf32>, %start_indices : tensor<2x2xi64>) {
// %axis_1: index_vector_dim = 1 (the lowered tests use 0).
// CHECK-NEXT: "xla_hlo.gather"
%axis_1 = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 1 : i64,
offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<0> : tensor<1xi64>
} : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
// %collapse_1: collapsed_slice_dims = [1] instead of [0].
// CHECK-NEXT: "xla_hlo.gather"
%collapse_1 = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<1> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<0> : tensor<1xi64>
} : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
// %transposes: start_index_map = [1, 0] rather than [0].
// CHECK-NEXT: "xla_hlo.gather"
%transposes = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<[1, 0]> : tensor<2xi64>
} : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
// %has_batch_dims: offset_dims = [1] covers only one result dim, so the other
// result dim is presumably a batch dim. start_index_map is also [1, 0].
// CHECK-NEXT: "xla_hlo.gather"
%has_batch_dims = "xla_hlo.gather"(%input, %start_indices) {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<1> : tensor<1xi64>,
slice_sizes = dense<[1, 2, 3]> : tensor<3xi64>,
start_index_map = dense<[1, 0]> : tensor<2xi64>
} : (tensor<5x2x3xf32>, tensor<2x2xi64>) -> tensor<2x3xf32>
return
}