[ROCM][Ukernel] Fix index types. (#16154)
diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c b/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c
index 9c2caf3..81f4b85 100644
--- a/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/argmax_ukernel.c
@@ -109,7 +109,7 @@
}
// if there are multiple max value holder, find smallest index (argmax
// semantics).
- int32_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__;
+ int64_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__;
laneResult = __ockl_wfred_min_i64(indexVal);
if (laneID == 0)
outputBuffer[output_offset] = laneResult;
@@ -191,7 +191,7 @@
}
// if there are multiple max value holder, find smallest index (argmax
// semantics).
- int32_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__;
+ int64_t indexVal = wgMax == laneMax ? laneResult : __INT64_MAX__;
laneResult = __ockl_wfred_min_i64(indexVal);
if (laneID == 0)
outputBuffer[output_offset] = laneResult;