blob: 4f4b83003de99fc199844c0ba80199a897018ab8 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <assert.h>
#include <string.h>
#include "vector_vadd_vsub_tests.h"
#include <springbok.h>
// TODO(b/194689843): Re-enable e64 and mf[2|4|8] tests.
// Forward declarations for the file-local helpers defined later in this file;
// the MAKE_TEST-generated functions call them before their definitions.

// Fills the first (length * element_size / 8) bytes of `array` with random
// data. `element_size` is in bits (callers pass 8 * sizeof(T)).
static void randomize_array(void *array, size_t length, size_t element_size);
// Asserts that the first `byte_count` bytes of the two buffers are equal.
static void check_array_equality(void *array1, void *array2,
                                 size_t byte_count);
// Writes the reference result of `operation_str` (e.g. "vadd.vv") applied to
// `len` elements of type `datatype_str` into `output`.
static void calculate_expected_output(void *output, void *input1, void *input2,
                                      uint32_t len, uint8_t sizeof_datatype,
                                      const char *datatype_str,
                                      const char *operation_str);
// Prepares the vector unit for one test and captures the resulting vector
// length. Expands to VSET(BUFFER_SIZE, VTYPE, m1), then COPY_SCALAR_REG(VL_DST),
// then logs "<VL_DST> = <value>" with a %u format. VL_DST must be an assignable
// variable; the MAKE_TEST bodies use its value as the element count to test.
// NOTE(review): VSET and COPY_SCALAR_REG come from springbok.h (not visible
// here). Presumably VSET requests BUFFER_SIZE elements of width VTYPE at LMUL=1
// and COPY_SCALAR_REG reads back the granted vl from a register -- confirm. If
// so, keep the two adjacent so nothing can clobber that register in between.
#define SETUP_TEST(VTYPE, BUFFER_SIZE, VL_DST) \
do { \
VSET(BUFFER_SIZE, VTYPE, m1); \
COPY_SCALAR_REG(VL_DST); \
LOG_INFO(#VL_DST " = %u", VL_DST); \
} while (0)
// MAKE_TEST(VTYPE, OPERATION, SUBOPERATION, DATATYPE) defines two functions.
//
// run_operations_test_vector_<OPERATION>_<SUBOPERATION>_<VTYPE>_<DATATYPE>(
//     output, input1, input2)
//   Loads input1 into v1, computes OPERATION.SUBOPERATION into v3 and stores
//   v3 to output, under whatever vl/vtype the caller configured. The third
//   operand depends on SUBOPERATION:
//     "vv": input2 is a vector, loaded into v2;
//     "vx": *input2 is a 32-bit scalar, read as uint32_t when DATATYPE starts
//           with 'u' and as int32_t otherwise;
//     "vi": *input2 is an int32_t selecting one of the immediates encoded
//           below (-16 or 15); any other value trips an assert.
//
// test_vector_<OPERATION>_<SUBOPERATION>_<VTYPE>_<DATATYPE>(void)
//   Sets up vl via SETUP_TEST, fills the inputs with random data, runs the
//   operation and asserts the result matches calculate_expected_output().
//
// VTYPE ("e8", "e16", ...) is pasted into the vl<VTYPE>.v / vs<VTYPE>.v
// mnemonics, so it must name a valid unit-stride load/store width.
#define MAKE_TEST(VTYPE, OPERATION, SUBOPERATION, DATATYPE)                     \
  void                                                                          \
  run_operations_test_vector_##OPERATION##_##SUBOPERATION##_##VTYPE##_##DATATYPE( \
      void *output, void *input1, void *input2) {                               \
    /* The asm only receives buffer addresses in registers, so each vector     \
       load/store declares a "memory" clobber; otherwise the compiler could    \
       assume the buffers are untouched and reorder or drop surrounding        \
       accesses to them. */                                                     \
    __asm__ volatile("vl" #VTYPE ".v v1, (%0)" ::"r"(input1) : "memory");      \
    if (strequal(#SUBOPERATION, "vv")) {                                        \
      __asm__ volatile("vl" #VTYPE ".v v2, (%0)" ::"r"(input2) : "memory");    \
      if (strequal(#OPERATION, "vadd")) {                                       \
        __asm__ volatile("vadd.vv v3, v1, v2");                                 \
      }                                                                         \
      if (strequal(#OPERATION, "vsub")) {                                       \
        __asm__ volatile("vsub.vv v3, v1, v2");                                 \
      }                                                                         \
    } else if (strequal(#SUBOPERATION, "vx")) {                                 \
      uint32_t unsigned_scalar = *((uint32_t *)input2);                         \
      int32_t signed_scalar = *((int32_t *)input2);                             \
      if (strequal(#OPERATION, "vadd")) {                                       \
        if (*(#DATATYPE) == 'u') {                                              \
          __asm__ volatile(                                                     \
              "vadd.vx v3, v1, %[RS1]" ::[RS1] "r"(unsigned_scalar));           \
        } else {                                                                \
          __asm__ volatile(                                                     \
              "vadd.vx v3, v1, %[RS1]" ::[RS1] "r"(signed_scalar));             \
        }                                                                       \
      }                                                                         \
      if (strequal(#OPERATION, "vsub")) {                                       \
        if (*(#DATATYPE) == 'u') {                                              \
          __asm__ volatile(                                                     \
              "vsub.vx v3, v1, %[RS1]" ::[RS1] "r"(unsigned_scalar));           \
        } else {                                                                \
          __asm__ volatile(                                                     \
              "vsub.vx v3, v1, %[RS1]" ::[RS1] "r"(signed_scalar));             \
        }                                                                       \
      }                                                                         \
      if (strequal(#OPERATION, "vrsub")) {                                      \
        if (*(#DATATYPE) == 'u') {                                              \
          __asm__ volatile(                                                     \
              "vrsub.vx v3, v1, %[RS1]" ::[RS1] "r"(unsigned_scalar));          \
        } else {                                                                \
          __asm__ volatile(                                                     \
              "vrsub.vx v3, v1, %[RS1]" ::[RS1] "r"(signed_scalar));            \
        }                                                                       \
      }                                                                         \
    } else if (strequal(#SUBOPERATION, "vi")) {                                 \
      int32_t input2_ = *((int32_t *)input2);                                   \
      /* The .vi forms encode the operand in the instruction itself, so only   \
         the immediates spelled out here can be exercised. */                   \
      switch (input2_) {                                                        \
        case -16:                                                               \
          if (strequal(#OPERATION, "vadd")) {                                   \
            __asm__ volatile("vadd.vi v3, v1, -16");                            \
          }                                                                     \
          if (strequal(#OPERATION, "vrsub")) {                                  \
            __asm__ volatile("vrsub.vi v3, v1, -16");                           \
          }                                                                     \
          break;                                                                \
        case 15:                                                                \
          if (strequal(#OPERATION, "vadd")) {                                   \
            __asm__ volatile("vadd.vi v3, v1, 15");                             \
          }                                                                     \
          if (strequal(#OPERATION, "vrsub")) {                                  \
            __asm__ volatile("vrsub.vi v3, v1, 15");                            \
          }                                                                     \
          break;                                                                \
        default:                                                                \
          assert(false && "unhandled immediate for " #OPERATION ".vi");         \
          break;                                                                \
      }                                                                         \
    }                                                                           \
    __asm__ volatile("vs" #VTYPE ".v v3, (%0)" ::"r"(output) : "memory");       \
  }                                                                             \
                                                                                \
  void test_vector_##OPERATION##_##SUBOPERATION##_##VTYPE##_##DATATYPE(void) {  \
    LOG_INFO("%s", __FUNCTION__);                                               \
    uint32_t allocation_size = 256;                                             \
    DATATYPE input1[allocation_size];                                           \
    /* Second vector operand; only read by the "vv" form. */                    \
    DATATYPE input2[allocation_size];                                           \
    /* Scalar operand for the "vx" form: one random 32-bit value, viewed as    \
       unsigned or signed to match DATATYPE's signedness. */                    \
    uint32_t unsigned_scalar = (uint32_t)random64();                            \
    int32_t signed_scalar = *((int32_t *)&unsigned_scalar);                     \
    DATATYPE observed_output[allocation_size];                                  \
    DATATYPE expected_output[allocation_size];                                  \
    uint32_t new_vl;                                                            \
    SETUP_TEST(VTYPE, allocation_size, new_vl);                                 \
    randomize_array(input1, new_vl, 8 * sizeof(DATATYPE));                      \
    randomize_array(input2, new_vl, 8 * sizeof(DATATYPE));                      \
    if (strequal(#SUBOPERATION, "vi")) {                                        \
      /* Both ends of the signed 5-bit immediate range. */                      \
      int32_t immediate_inputs[] = {-16, 15};                                   \
      for (size_t i = 0;                                                        \
           i < sizeof(immediate_inputs) / sizeof(immediate_inputs[0]); i++) {   \
        run_operations_test_vector_##OPERATION##_##SUBOPERATION##_##VTYPE##_##DATATYPE( \
            observed_output, input1, &immediate_inputs[i]);                     \
        calculate_expected_output(expected_output, input1,                      \
                                  &immediate_inputs[i], new_vl,                 \
                                  sizeof(DATATYPE), #DATATYPE,                  \
                                  #OPERATION "." #SUBOPERATION);                \
        check_array_equality(observed_output, expected_output,                  \
                             new_vl * sizeof(DATATYPE));                        \
      }                                                                         \
    } else if (strequal(#SUBOPERATION, "vx")) {                                 \
      void *scalar = (*(#DATATYPE) == 'u') ? (void *)&unsigned_scalar           \
                                           : (void *)&signed_scalar;            \
      run_operations_test_vector_##OPERATION##_##SUBOPERATION##_##VTYPE##_##DATATYPE( \
          observed_output, input1, scalar);                                     \
      calculate_expected_output(expected_output, input1, scalar, new_vl,        \
                                sizeof(DATATYPE), #DATATYPE,                    \
                                #OPERATION "." #SUBOPERATION);                  \
      check_array_equality(observed_output, expected_output,                    \
                           new_vl * sizeof(DATATYPE));                          \
    } else {                                                                    \
      run_operations_test_vector_##OPERATION##_##SUBOPERATION##_##VTYPE##_##DATATYPE( \
          observed_output, input1, input2);                                     \
      calculate_expected_output(expected_output, input1, input2, new_vl,        \
                                sizeof(DATATYPE), #DATATYPE,                    \
                                #OPERATION "." #SUBOPERATION);                  \
      check_array_equality(observed_output, expected_output,                    \
                           new_vl * sizeof(DATATYPE));                          \
    }                                                                           \
  }
// Top-level entry point. Runs every operand form enabled at build time
// (TEST_VV, TEST_VX, TEST_VI) for 100 rounds; each generated test draws fresh
// random inputs, so repeated rounds cover more operand values. Always returns
// true -- mismatches are caught by assert() in check_array_equality().
bool test_vector(void) {
  LOG_INFO("%s", __FUNCTION__);
  int rounds_left = 100;
  while (rounds_left-- > 0) {
#if defined(TEST_VV)
    test_vector_vadd_vv();
    test_vector_vsub_vv();
#endif
#if defined(TEST_VX)
    test_vector_vadd_vx();
    test_vector_vsub_vx();
    test_vector_vrsub_vx();
#endif
#if defined(TEST_VI)
    test_vector_vadd_vi();
    test_vector_vrsub_vi();
#endif
  }
  return true;
}
MAKE_TEST(e8, vadd, vv, uint8_t);
MAKE_TEST(e8, vadd, vv, int8_t);
MAKE_TEST(e16, vadd, vv, uint16_t);
MAKE_TEST(e16, vadd, vv, int16_t);
MAKE_TEST(e32, vadd, vv, uint32_t);
MAKE_TEST(e32, vadd, vv, int32_t);
// MAKE_TEST(e64, vadd, vv, uint64_t);
// MAKE_TEST(e64, vadd, vv, int64_t);

// vadd.vv suite: runs the generated test for every enabled element width, in
// unsigned/signed pairs. Logs its own name like the other suite wrappers.
void test_vector_vadd_vv(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vadd_vv_e8_uint8_t();
  test_vector_vadd_vv_e8_int8_t();
  test_vector_vadd_vv_e16_uint16_t();
  test_vector_vadd_vv_e16_int16_t();
  test_vector_vadd_vv_e32_uint32_t();
  test_vector_vadd_vv_e32_int32_t();
  // test_vector_vadd_vv_e64_uint64_t();
  // test_vector_vadd_vv_e64_int64_t();
}
MAKE_TEST(e8, vsub, vv, uint8_t);
MAKE_TEST(e8, vsub, vv, int8_t);
MAKE_TEST(e16, vsub, vv, uint16_t);
MAKE_TEST(e16, vsub, vv, int16_t);
MAKE_TEST(e32, vsub, vv, uint32_t);
MAKE_TEST(e32, vsub, vv, int32_t);
// MAKE_TEST(e64, vsub, vv, uint64_t);
// MAKE_TEST(e64, vsub, vv, int64_t);

// vsub.vv suite: runs the generated test for every enabled element width, in
// unsigned/signed pairs. Logs its own name like the other suite wrappers.
void test_vector_vsub_vv(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vsub_vv_e8_uint8_t();
  test_vector_vsub_vv_e8_int8_t();
  test_vector_vsub_vv_e16_uint16_t();
  test_vector_vsub_vv_e16_int16_t();
  test_vector_vsub_vv_e32_uint32_t();
  test_vector_vsub_vv_e32_int32_t();
  // test_vector_vsub_vv_e64_uint64_t();
  // test_vector_vsub_vv_e64_int64_t();
}
MAKE_TEST(e8, vadd, vx, uint8_t);
MAKE_TEST(e8, vadd, vx, int8_t);
MAKE_TEST(e16, vadd, vx, uint16_t);
MAKE_TEST(e16, vadd, vx, int16_t);
MAKE_TEST(e32, vadd, vx, uint32_t);
MAKE_TEST(e32, vadd, vx, int32_t);
// MAKE_TEST(e64, vadd, vx, uint64_t);
// MAKE_TEST(e64, vadd, vx, int64_t);

// vadd.vx suite: runs the generated vector-scalar test for every enabled
// element width, in unsigned/signed pairs.
void test_vector_vadd_vx(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vadd_vx_e8_uint8_t();
  test_vector_vadd_vx_e8_int8_t();
  test_vector_vadd_vx_e16_uint16_t();
  // Previously called the .vv variant here, so vadd.vx/int16 was never run.
  test_vector_vadd_vx_e16_int16_t();
  test_vector_vadd_vx_e32_uint32_t();
  test_vector_vadd_vx_e32_int32_t();
  // test_vector_vadd_vx_e64_uint64_t();
  // test_vector_vadd_vx_e64_int64_t();
}
MAKE_TEST(e8, vsub, vx, uint8_t);
MAKE_TEST(e8, vsub, vx, int8_t);
MAKE_TEST(e16, vsub, vx, uint16_t);
MAKE_TEST(e16, vsub, vx, int16_t);
MAKE_TEST(e32, vsub, vx, uint32_t);
MAKE_TEST(e32, vsub, vx, int32_t);
// MAKE_TEST(e64, vsub, vx, uint64_t);
// MAKE_TEST(e64, vsub, vx, int64_t);

// vsub.vx suite: runs the generated vector-scalar test for every enabled
// element width, in unsigned/signed pairs.
void test_vector_vsub_vx(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vsub_vx_e8_uint8_t();
  test_vector_vsub_vx_e8_int8_t();
  test_vector_vsub_vx_e16_uint16_t();
  // Previously called the .vv variant here, so vsub.vx/int16 was never run.
  test_vector_vsub_vx_e16_int16_t();
  test_vector_vsub_vx_e32_uint32_t();
  test_vector_vsub_vx_e32_int32_t();
  // test_vector_vsub_vx_e64_uint64_t();
  // test_vector_vsub_vx_e64_int64_t();
}
MAKE_TEST(e8, vadd, vi, uint8_t);
MAKE_TEST(e8, vadd, vi, int8_t);
MAKE_TEST(e16, vadd, vi, uint16_t);
MAKE_TEST(e16, vadd, vi, int16_t);
MAKE_TEST(e32, vadd, vi, uint32_t);
MAKE_TEST(e32, vadd, vi, int32_t);
#if 0  // TODO(b/194689843): e64 variants are disabled.
MAKE_TEST(e64, vadd, vi, uint64_t);
MAKE_TEST(e64, vadd, vi, int64_t);
#endif

// vadd.vi suite: runs the generated vector-immediate test for each enabled
// element width, unsigned before signed, narrowest width first.
void test_vector_vadd_vi(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vadd_vi_e8_uint8_t();
  test_vector_vadd_vi_e8_int8_t();
  test_vector_vadd_vi_e16_uint16_t();
  test_vector_vadd_vi_e16_int16_t();
  test_vector_vadd_vi_e32_uint32_t();
  test_vector_vadd_vi_e32_int32_t();
#if 0  // TODO(b/194689843): enable together with the e64 MAKE_TESTs above.
  test_vector_vadd_vi_e64_uint64_t();
  test_vector_vadd_vi_e64_int64_t();
#endif
}
MAKE_TEST(e8, vrsub, vx, uint8_t);
MAKE_TEST(e8, vrsub, vx, int8_t);
MAKE_TEST(e16, vrsub, vx, uint16_t);
MAKE_TEST(e16, vrsub, vx, int16_t);
MAKE_TEST(e32, vrsub, vx, uint32_t);
MAKE_TEST(e32, vrsub, vx, int32_t);
#if 0  // TODO(b/194689843): e64 variants are disabled.
MAKE_TEST(e64, vrsub, vx, uint64_t);
MAKE_TEST(e64, vrsub, vx, int64_t);
#endif

// vrsub.vx suite (scalar minus vector): runs the generated test for each
// enabled element width, unsigned before signed, narrowest width first.
void test_vector_vrsub_vx(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vrsub_vx_e8_uint8_t();
  test_vector_vrsub_vx_e8_int8_t();
  test_vector_vrsub_vx_e16_uint16_t();
  test_vector_vrsub_vx_e16_int16_t();
  test_vector_vrsub_vx_e32_uint32_t();
  test_vector_vrsub_vx_e32_int32_t();
#if 0  // TODO(b/194689843): enable together with the e64 MAKE_TESTs above.
  test_vector_vrsub_vx_e64_uint64_t();
  test_vector_vrsub_vx_e64_int64_t();
#endif
}
MAKE_TEST(e8, vrsub, vi, uint8_t);
MAKE_TEST(e8, vrsub, vi, int8_t);
MAKE_TEST(e16, vrsub, vi, uint16_t);
MAKE_TEST(e16, vrsub, vi, int16_t);
MAKE_TEST(e32, vrsub, vi, uint32_t);
MAKE_TEST(e32, vrsub, vi, int32_t);
#if 0  // TODO(b/194689843): e64 variants are disabled.
MAKE_TEST(e64, vrsub, vi, uint64_t);
MAKE_TEST(e64, vrsub, vi, int64_t);
#endif

// vrsub.vi suite (immediate minus vector): runs the generated test for each
// enabled element width, unsigned before signed, narrowest width first.
void test_vector_vrsub_vi(void) {
  LOG_INFO("%s", __FUNCTION__);
  // TODO(julianmb): test signed + unsigned
  test_vector_vrsub_vi_e8_uint8_t();
  test_vector_vrsub_vi_e8_int8_t();
  test_vector_vrsub_vi_e16_uint16_t();
  test_vector_vrsub_vi_e16_int16_t();
  test_vector_vrsub_vi_e32_uint32_t();
  test_vector_vrsub_vi_e32_int32_t();
#if 0  // TODO(b/194689843): enable together with the e64 MAKE_TESTs above.
  test_vector_vrsub_vi_e64_uint64_t();
  test_vector_vrsub_vi_e64_int64_t();
#endif
}
// Overwrites the first `length` elements of `array` with random bytes.
// `element_size` is the element width in BITS (callers pass 8 * sizeof(T)),
// so the total byte count is length * element_size / 8.
static void randomize_array(void *array, size_t length, size_t element_size) {
  const size_t total_bytes = (length * element_size) >> 3;
  uint8_t *bytes = (uint8_t *)array;
  for (size_t idx = 0; idx < total_bytes; ++idx) {
    // Only the low byte of each random64() draw is kept.
    bytes[idx] = (uint8_t)random64();
  }
}
// Asserts that the first `byte_count` bytes of `array1` and `array2` are
// identical (bytewise memcmp, so element type and order of the two arguments
// do not matter).
// NOTE(review): the comparison lives entirely inside assert(), so it is
// compiled out when NDEBUG is defined and mismatches would then go unreported.
static void check_array_equality(void *array1, void *array2,
size_t byte_count) {
assert(memcmp(array1, array2, byte_count) == 0);
}
// Reference implementations used by calculate_expected_output(). Each macro
// expands to a statement that writes one DTYPE element to `output` when
// `operation_str` names its instruction, and does nothing otherwise. They expect
// `output`, `input1`, `input2` (element pointers) and `operation_str` to be in
// scope at the expansion site. For the ".vx"/".vi" forms `input2` points at a
// 32-bit scalar. The ".vx" forms read it as uint32_t when DTYPE's name starts
// with 'u' and as int32_t when it starts with 'i'. The ".vi" forms always read
// it as int32_t.
//
// All arithmetic is done in uint64_t and then narrowed to DTYPE. The vector
// instructions under test produce wrapping results, and doing the math in a
// signed type such as int32_t would be signed overflow -- undefined behavior in
// C -- whenever the random operands overflow. Unsigned arithmetic wraps by
// definition and yields the same low-order bits. (Narrowing an out-of-range
// value to a signed DTYPE is implementation-defined; GCC and Clang reduce it
// modulo 2^N.)
#define VADD_VV_EXPECTED_OUTPUT(DTYPE)                                       \
  do {                                                                       \
    if (strequal("vadd.vv", operation_str)) {                                \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) +            \
                                   (uint64_t)(*(DTYPE *)input2));            \
    }                                                                        \
  } while (0)
#define VSUB_VV_EXPECTED_OUTPUT(DTYPE)                                       \
  do {                                                                       \
    if (strequal("vsub.vv", operation_str)) {                                \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) -            \
                                   (uint64_t)(*(DTYPE *)input2));            \
    }                                                                        \
  } while (0)
#define VADD_VX_EXPECTED_OUTPUT(DTYPE)                                       \
  do {                                                                       \
    if (strequal("vadd.vx", operation_str) && *(#DTYPE) == 'u') {            \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) +            \
                                   (uint64_t)(*(uint32_t *)input2));         \
    } else if (strequal("vadd.vx", operation_str) && *(#DTYPE) == 'i') {     \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) +            \
                                   (uint64_t)(*(int32_t *)input2));          \
    }                                                                        \
  } while (0)
#define VSUB_VX_EXPECTED_OUTPUT(DTYPE)                                       \
  do {                                                                       \
    if (strequal("vsub.vx", operation_str) && *(#DTYPE) == 'u') {            \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) -            \
                                   (uint64_t)(*(uint32_t *)input2));         \
    } else if (strequal("vsub.vx", operation_str) && *(#DTYPE) == 'i') {     \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) -            \
                                   (uint64_t)(*(int32_t *)input2));          \
    }                                                                        \
  } while (0)
#define VADD_VI_EXPECTED_OUTPUT(DTYPE)                                       \
  do {                                                                       \
    if (strequal("vadd.vi", operation_str)) {                                \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(DTYPE *)input1) +            \
                                   (uint64_t)(*(int32_t *)input2));          \
    }                                                                        \
  } while (0)
#define VRSUB_VX_EXPECTED_OUTPUT(DTYPE)                                      \
  do {                                                                       \
    if (strequal("vrsub.vx", operation_str) && *(#DTYPE) == 'u') {           \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(uint32_t *)input2) -         \
                                   (uint64_t)(*(DTYPE *)input1));            \
    } else if (strequal("vrsub.vx", operation_str) && *(#DTYPE) == 'i') {    \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(int32_t *)input2) -          \
                                   (uint64_t)(*(DTYPE *)input1));            \
    }                                                                        \
  } while (0)
#define VRSUB_VI_EXPECTED_OUTPUT(DTYPE)                                      \
  do {                                                                       \
    if (strequal("vrsub.vi", operation_str)) {                               \
      *((DTYPE *)output) = (DTYPE)((uint64_t)(*(int32_t *)input2) -          \
                                   (uint64_t)(*(DTYPE *)input1));            \
    }                                                                        \
  } while (0)
// Fills `output` with the reference result of `operation_str` applied
// element-wise to `len` elements whose C type is named by `datatype_str`.
//
// output:          destination buffer of `len` elements.
// input1:          first vector operand, `len` elements.
// input2:          for "vadd.vv"/"vsub.vv", a second vector of `len` elements;
//                  for every other operation, a single 32-bit scalar reused for
//                  each element.
// len:             number of elements to compute.
// sizeof_datatype: element size in bytes (the stride through output/input1).
// datatype_str:    element type spelled as a C type name, e.g. "int16_t".
// operation_str:   "<op>.<form>", e.g. "vrsub.vi", selecting which
//                  *_EXPECTED_OUTPUT macro produces the value.
static void calculate_expected_output(void *output, void *input1, void *input2,
                                      uint32_t len, uint8_t sizeof_datatype,
                                      const char *datatype_str,
                                      const char *operation_str) {
  const uint8_t stride = sizeof_datatype;
  // Only the vector-vector forms walk input2; scalar/immediate forms keep
  // re-reading the same 32-bit operand.
  const uint8_t input2_stride =
      (strequal("vadd.vv", operation_str) || strequal("vsub.vv", operation_str))
          ? stride
          : 0;
  for (uint32_t i = 0; i < len; i++) {
    // The type names are mutually exclusive, so stop at the first match.
    if (strequal("uint8_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(uint8_t);
      VSUB_VV_EXPECTED_OUTPUT(uint8_t);
      VADD_VX_EXPECTED_OUTPUT(uint8_t);
      VSUB_VX_EXPECTED_OUTPUT(uint8_t);
      VADD_VI_EXPECTED_OUTPUT(uint8_t);
      VRSUB_VX_EXPECTED_OUTPUT(uint8_t);
      VRSUB_VI_EXPECTED_OUTPUT(uint8_t);
    } else if (strequal("int8_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(int8_t);
      VSUB_VV_EXPECTED_OUTPUT(int8_t);
      VADD_VX_EXPECTED_OUTPUT(int8_t);
      VSUB_VX_EXPECTED_OUTPUT(int8_t);
      VADD_VI_EXPECTED_OUTPUT(int8_t);
      VRSUB_VX_EXPECTED_OUTPUT(int8_t);
      VRSUB_VI_EXPECTED_OUTPUT(int8_t);
    } else if (strequal("uint16_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(uint16_t);
      VSUB_VV_EXPECTED_OUTPUT(uint16_t);
      VADD_VX_EXPECTED_OUTPUT(uint16_t);
      VSUB_VX_EXPECTED_OUTPUT(uint16_t);
      VADD_VI_EXPECTED_OUTPUT(uint16_t);
      VRSUB_VX_EXPECTED_OUTPUT(uint16_t);
      VRSUB_VI_EXPECTED_OUTPUT(uint16_t);
    } else if (strequal("int16_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(int16_t);
      VSUB_VV_EXPECTED_OUTPUT(int16_t);
      VADD_VX_EXPECTED_OUTPUT(int16_t);
      VSUB_VX_EXPECTED_OUTPUT(int16_t);
      VADD_VI_EXPECTED_OUTPUT(int16_t);
      VRSUB_VX_EXPECTED_OUTPUT(int16_t);
      VRSUB_VI_EXPECTED_OUTPUT(int16_t);
    } else if (strequal("uint32_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(uint32_t);
      VSUB_VV_EXPECTED_OUTPUT(uint32_t);
      VADD_VX_EXPECTED_OUTPUT(uint32_t);
      VSUB_VX_EXPECTED_OUTPUT(uint32_t);
      VADD_VI_EXPECTED_OUTPUT(uint32_t);
      VRSUB_VX_EXPECTED_OUTPUT(uint32_t);
      VRSUB_VI_EXPECTED_OUTPUT(uint32_t);
    } else if (strequal("int32_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(int32_t);
      VSUB_VV_EXPECTED_OUTPUT(int32_t);
      VADD_VX_EXPECTED_OUTPUT(int32_t);
      VSUB_VX_EXPECTED_OUTPUT(int32_t);
      VADD_VI_EXPECTED_OUTPUT(int32_t);
      VRSUB_VX_EXPECTED_OUTPUT(int32_t);
      VRSUB_VI_EXPECTED_OUTPUT(int32_t);
    } else if (strequal("uint64_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(uint64_t);
      VSUB_VV_EXPECTED_OUTPUT(uint64_t);
      VADD_VX_EXPECTED_OUTPUT(uint64_t);
      VSUB_VX_EXPECTED_OUTPUT(uint64_t);
      VADD_VI_EXPECTED_OUTPUT(uint64_t);
      VRSUB_VX_EXPECTED_OUTPUT(uint64_t);
      VRSUB_VI_EXPECTED_OUTPUT(uint64_t);
    } else if (strequal("int64_t", datatype_str)) {
      VADD_VV_EXPECTED_OUTPUT(int64_t);
      VSUB_VV_EXPECTED_OUTPUT(int64_t);
      VADD_VX_EXPECTED_OUTPUT(int64_t);
      VSUB_VX_EXPECTED_OUTPUT(int64_t);
      VADD_VI_EXPECTED_OUTPUT(int64_t);
      VRSUB_VX_EXPECTED_OUTPUT(int64_t);
      VRSUB_VI_EXPECTED_OUTPUT(int64_t);
    }
    // Step through byte pointers: arithmetic directly on void * is a GNU
    // extension, not ISO C.
    output = (uint8_t *)output + stride;
    input1 = (uint8_t *)input1 + stride;
    input2 = (uint8_t *)input2 + input2_stride;
  }
}