blob: 552ab87254d3c04ddaa93d1b31154108104b6468 [file] [log] [blame]
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
// Wavefront-cooperative argmax over `reductionSize` f16 values that start at
// `inputBuffer[input_offset]`. The index of the maximum, relative to
// `input_offset`, is stored to `outputBuffer[output_offset]`. When several
// elements hold the maximum, the smallest index is reported.
//
// Every lane of the wavefront must reach this call: it performs
// wavefront-wide reductions and a ballot.
// NOTE(review): the lane index is taken from workitem_id_x, which presumably
// means the caller launches exactly one wavefront along x - confirm against
// the call site.
[[clang::always_inline]] void
iree_uk_amdgpu_argmax_f16i64(const _Float16 *inputBuffer, int64_t input_offset,
                             int64_t *outputBuffer, int64_t output_offset,
                             int64_t reductionSize) {
  const int warpSize = __builtin_amdgcn_wavefrontsize();
  // Identity of the max reduction. This has to be -inf rather than the most
  // negative finite f16 (-65504): otherwise a lane that holds no element
  // would beat an input made only of -inf and an index past the end of the
  // reduction would be reported.
  // NOTE(review): assumes this file is not built with a no-infinities
  // fast-math mode - confirm against the ukernel build flags.
  const _Float16 NEG_F16_INF = (_Float16)(-__builtin_inff());
  const int32_t laneID = __builtin_amdgcn_workitem_id_x();

  // First batch: lanes past the end of the problem start from the identity so
  // that sizes smaller than the wavefront are handled.
  _Float16 laneMax = laneID >= reductionSize
                         ? NEG_F16_INF
                         : inputBuffer[input_offset + laneID];
  int64_t laneResult = laneID;

  // Remaining batches: each lane walks its own strided slice and only touches
  // in-range elements, so padding can never be recorded as a result. The
  // index is 64-bit to match `reductionSize`. No wavefront-wide operation is
  // issued here, so lanes may run a different number of iterations.
  for (int64_t idx = (int64_t)laneID + warpSize; idx < reductionSize;
       idx += warpSize) {
    const _Float16 newIn = inputBuffer[input_offset + idx];
    // On a tie keep the index seen first, which is the smaller one.
    if (newIn == laneMax)
      continue;
    laneMax = __builtin_fmaxf16(newIn, laneMax);
    laneResult = newIn == laneMax ? idx : laneResult;
  }

  // Final reduction with one subgroup.
  const _Float16 wgMax = __ockl_wfred_max_f16(laneMax);
  // Check if there are multiple max value holders.
  const uint64_t laneHasMaxValmask = __ballot(wgMax == laneMax);
  if (__builtin_popcountll(laneHasMaxValmask) == 1) {
    // Exactly one lane holds the maximum: that lane writes its own index.
    if (wgMax == laneMax) {
      outputBuffer[output_offset] = laneResult;
    }
  } else {
    // Several lanes hold the maximum: reduce to the smallest index (argmax
    // semantics). Lanes that do not hold it contribute INT64_MAX so they
    // cannot win the min reduction.
    const int64_t indexVal = wgMax == laneMax ? laneResult : INT64_MAX;
    laneResult = __ockl_wfred_min_i64(indexVal);
    if (laneID == 0) {
      outputBuffer[output_offset] = laneResult;
    }
  }
  // TODO(bjacob): this fence should be on the caller side. Move to TileAndFuse?
  __threadfence_block();
}