blob: 2579d05d59e243b1181f01be9f6ce147088d4aef [file] [log] [blame]
// Argmax over a 1-D tensor of 131072 f32 elements. Every element is 7.0
// except the last one (index 131071), which is 53.0. The reduction carries a
// (running max value, index of that value) pair, and the checks require the
// result to be 53.0 at index 131071.
func.func @argmax_1d() {
  // Accumulator seeds: the running value starts at -inf (0xFF800000) and the
  // running index starts at 0.
  %minus_inf = arith.constant 0xFF800000 : f32
  %zero_i32 = arith.constant 0 : i32
  %acc_val_empty = tensor.empty() : tensor<f32>
  %acc_val_seed = linalg.fill ins(%minus_inf : f32) outs(%acc_val_empty : tensor<f32>) -> tensor<f32>
  %acc_idx_empty = tensor.empty() : tensor<i32>
  %acc_idx_seed = linalg.fill ins(%zero_i32 : i32) outs(%acc_idx_empty : tensor<i32>) -> tensor<i32>

  // Input: fill every slot with 7.0, then overwrite the final slot with 53.0.
  %fill_value = arith.constant 7.0 : f32
  %peak_value = arith.constant 53.0 : f32
  %peak_pos = arith.constant 131071 : index
  %data_empty = tensor.empty() : tensor<131072xf32>
  %data_flat = linalg.fill ins(%fill_value : f32) outs(%data_empty : tensor<131072xf32>) -> tensor<131072xf32>
  %data = tensor.insert %peak_value into %data_flat[%peak_pos] : tensor<131072xf32>

  // Single-dimension reduction producing two scalar (rank-0) results.
  %argmax:2 = linalg.generic {
    indexing_maps = [
      affine_map<(d0) -> (d0)>,
      affine_map<(d0) -> ()>,
      affine_map<(d0) -> ()>
    ],
    iterator_types = ["reduction"]
  } ins(%data : tensor<131072xf32>)
    outs(%acc_val_seed, %acc_idx_seed : tensor<f32>, tensor<i32>) {
  ^bb0(%elem: f32, %best: f32, %best_pos: i32):
    %pos = linalg.index 0 : index
    %pos_i32 = arith.index_cast %pos : index to i32
    // Strictly-greater test (best < elem): when elem only ties the current
    // best, the index already held in the accumulator is kept.
    %is_better = arith.cmpf olt, %best, %elem : f32
    %next_pos = arith.select %is_better, %pos_i32, %best_pos : i32
    %next_best = arith.maximumf %elem, %best : f32
    linalg.yield %next_best, %next_pos : f32, i32
  } -> (tensor<f32>, tensor<i32>)

  check.expect_almost_eq_const(%argmax#0, dense<53.0> : tensor<f32>) : tensor<f32>
  check.expect_eq_const(%argmax#1, dense<131071> : tensor<i32>) : tensor<i32>
  return
}