blob: f4183ec214d3ce86a8cfe98b0a92810e7b430bc2 [file] [log] [blame]
#include <limits.h>
#include <riscv_vector.h>
#include <springbok.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <bit>
#include <tuple>
#include "pw_unit_test/framework.h"
#include "test_v_helpers.h"
namespace vmax_vv_test {
namespace {
using namespace test_v_helpers;
// Scratch buffers shared by every test in this file. Each test views them as
// arrays of its element type: test_vector_1 supplies the operand loaded into
// v8 and is then overwritten with the C-emulated expected result;
// test_vector_2 supplies the operand loaded into v16 and is then overwritten
// with the result stored from v24. Both are sized by MAXVL_BYTES from
// test_v_helpers.h (presumably the largest vector register group in bytes --
// confirm there).
uint8_t test_vector_1[MAXVL_BYTES];
uint8_t test_vector_2[MAXVL_BYTES];
// Returns log2(sizeof(T)), the value this file casts to VSEW to select the
// element width matching T (1 byte -> 0, 2 -> 1, 4 -> 2, 8 -> 3). Computed
// explicitly rather than through std::__bit_width, which is a reserved,
// library-internal name and not portable across standard libraries.
template <typename T>
static constexpr uint32_t vsew_index_for_type() {
  static_assert(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 ||
                    sizeof(T) == 8,
                "element type must be 1, 2, 4, or 8 bytes wide");
  return sizeof(T) == 1 ? 0u : sizeof(T) == 2 ? 1u : sizeof(T) == 4 ? 2u : 3u;
}

// Prepares one vmax.vv test iteration for element type T and register-group
// multiplier `lmul`:
//   * clears all vector registers,
//   * zero-fills both test_vector_1 and test_vector_2,
//   * clamps `avl` to vlmax (as reported by get_vsetvlmax_intrinsic) and
//     programs vl via set_vsetvl_intrinsic,
//   * records a non-fatal test failure if the granted vl differs from the
//     clamped avl.
// Returns {vlmax, vl}. `avl` is taken by value, so the caller's copy is not
// clamped; callers compare their own avl against vlmax to skip such cases.
template <typename T>
static std::tuple<int, int> vmax_vv_test_setup(VLMUL lmul, int32_t avl) {
  zero_vector_registers();

  const VSEW sew = static_cast<VSEW>(vsew_index_for_type<T>());
  const int vlmax = get_vsetvlmax_intrinsic(sew, lmul);
  if (avl > vlmax) {
    avl = vlmax;  // never request more elements than vlmax
  }

  // Zero both buffers so elements past vl compare equal in the final check.
  memset(test_vector_1, 0, MAXVL_BYTES);
  memset(test_vector_2, 0, MAXVL_BYTES);

  const int vl = set_vsetvl_intrinsic(sew, lmul, avl);
  EXPECT_EQ(avl, vl);
  return std::make_tuple(vlmax, vl);
}
class VmaxVvTest : public ::testing::Test {
protected:
void SetUp() override { zero_vector_registers(); }
void TearDown() override { zero_vector_registers(); }
};
// Below is a non-macro version of the test for more convenient debugging.
// Remove the "DISABLED_" prefix to enable this test for debugging.
TEST_F(VmaxVvTest, DISABLED_vmax_vv_demo) {
  for (int i = 0; i < AVL_COUNT; i++) {
    int32_t avl = AVLS[i];
    int vlmax;
    int vl;
    std::tie(vlmax, vl) = vmax_vv_test_setup<int16_t>(VLMUL::LMUL_M1, avl);
    if (avl > vlmax) {
      continue;
    }
    int16_t *ptr_vec_1 = reinterpret_cast<int16_t *>(test_vector_1);
    int16_t *ptr_vec_2 = reinterpret_cast<int16_t *>(test_vector_2);
    const int16_t start_value = INT16_MAX;

    // vs1 operand of vmax.vv (loaded into v8). The sum is formed in int and
    // narrowed back to int16_t, so starting at INT16_MAX the sequence wraps
    // to negative values (modular on GCC/Clang; implementation-defined
    // before C++20).
    for (int idx = 0; idx < vl; idx++) {
      ptr_vec_1[idx] = static_cast<int16_t>(idx + start_value);
    }
    // The "memory" clobber makes the compiler complete the stores above
    // before the asm reads the buffer.
    __asm__ volatile("vle16.v v8, (%0)" : : "r"(ptr_vec_1) : "memory");

    // vs2 operand of vmax.vv (loaded into v16): an offset sequence.
    for (int idx = 0; idx < vl; idx++) {
      ptr_vec_2[idx] = static_cast<int16_t>(idx - start_value);
    }
    __asm__ volatile("vle16.v v16, (%0)" : : "r"(ptr_vec_2) : "memory");

    // v24 = elementwise signed maximum of v16 (vs2) and v8 (vs1).
    __asm__ volatile("vmax.vv v24, v16, v8");

    // Emulate the operation in C; the expected result overwrites ptr_vec_1.
    for (int idx = 0; idx < vl; idx++) {
      ptr_vec_1[idx] =
          (ptr_vec_1[idx] > ptr_vec_2[idx]) ? ptr_vec_1[idx] : ptr_vec_2[idx];
    }

    // Store the hardware result over ptr_vec_2, then compare vlmax elements.
    // Both buffers were zero-filled by the setup, so elements past vl are
    // expected to be zero in both.
    __asm__ volatile("vse16.v v24, (%0)" : : "r"(ptr_vec_2) : "memory");
    assert_vec_elem_eq<int16_t>(vlmax, test_vector_1, test_vector_2);
  }
}
// Defines TEST_F(VmaxVvTest, vmax_vv<SEW>m<LMUL>) for signed _SEW_-bit
// elements and register-group multiplier VLMUL::LMUL_M<_LMUL_>. For every AVL
// in AVLS that does not exceed vlmax, the generated test:
//   * fills the vs1 operand (v8) with idx + START_VALUE,
//   * fills the vs2 operand (v16) with idx - START_VALUE (both wrapped to the
//     element width),
//   * runs vmax.vv v24, v16, v8,
//   * compares the stored result against a C emulation.
// Comments inside the macro body use /* */ because a // comment would also
// swallow the spliced continuation lines.
#define DEFINE_TEST_VMAX_VV(_SEW_, _LMUL_, START_VALUE)                       \
  TEST_F(VmaxVvTest, vmax_vv##_SEW_##m##_LMUL_) {                             \
    using elem_t = int##_SEW_##_t;                                            \
    using uelem_t = uint##_SEW_##_t;                                          \
    for (int i = 0; i < AVL_COUNT; i++) {                                     \
      int32_t avl = AVLS[i];                                                  \
      int vlmax;                                                              \
      int vl;                                                                 \
      std::tie(vlmax, vl) =                                                   \
          vmax_vv_test_setup<elem_t>(VLMUL::LMUL_M##_LMUL_, avl);             \
      if (avl > vlmax) {                                                      \
        continue;                                                             \
      }                                                                       \
      elem_t *ptr_vec_1 = reinterpret_cast<elem_t *>(test_vector_1);          \
      elem_t *ptr_vec_2 = reinterpret_cast<elem_t *>(test_vector_2);          \
      const elem_t start_value = START_VALUE;                                 \
      const uelem_t start_bits = static_cast<uelem_t>(start_value);           \
      /* idx +/- start is formed in the unsigned element type so the         \
         wraparound is well defined; in plain int arithmetic                  \
         idx + INT32_MAX and idx - INT32_MIN overflow (undefined behavior). */ \
      for (int idx = 0; idx < vl; idx++) {                                    \
        const uelem_t idx_bits = static_cast<uelem_t>(idx);                   \
        /* vs1 operand, loaded into v8 */                                     \
        ptr_vec_1[idx] = static_cast<elem_t>(idx_bits + start_bits);          \
        /* vs2 operand, loaded into v16 */                                    \
        ptr_vec_2[idx] = static_cast<elem_t>(idx_bits - start_bits);          \
      }                                                                       \
      /* "memory" clobbers keep buffer stores/loads ordered around the asm */ \
      __asm__ volatile("vle" #_SEW_ ".v v8, (%0)"                             \
                       :                                                      \
                       : "r"(ptr_vec_1)                                       \
                       : "memory");                                           \
      __asm__ volatile("vle" #_SEW_ ".v v16, (%0)"                            \
                       :                                                      \
                       : "r"(ptr_vec_2)                                       \
                       : "memory");                                           \
      __asm__ volatile("vmax.vv v24, v16, v8");                               \
      /* C emulation; the expected result overwrites ptr_vec_1 */             \
      for (int idx = 0; idx < vl; idx++) {                                    \
        ptr_vec_1[idx] = (ptr_vec_1[idx] > ptr_vec_2[idx]) ? ptr_vec_1[idx]   \
                                                           : ptr_vec_2[idx];  \
      }                                                                       \
      __asm__ volatile("vse" #_SEW_ ".v v24, (%0)"                            \
                       :                                                      \
                       : "r"(ptr_vec_2)                                       \
                       : "memory");                                           \
      assert_vec_elem_eq<elem_t>(vlmax, test_vector_1, test_vector_2);        \
    }                                                                         \
  }
// Instantiate one vmax.vv test per (SEW, LMUL) pair; each expands to a test
// named vmax_vv<SEW>m<LMUL> (e.g. vmax_vv8m1). The third argument is the start
// value the operand sequences are built from. The extreme values
// (INT*_MAX / INT*_MIN) make one of the operand sequences wrap around the
// element range, while the others exercise ordinary in-range values.
// TODO(gkielian): modify macro to permit more than one test per sew/lmul pair
DEFINE_TEST_VMAX_VV(8, 1, INT8_MAX)
DEFINE_TEST_VMAX_VV(8, 2, INT8_MIN)
DEFINE_TEST_VMAX_VV(8, 4, 120)
DEFINE_TEST_VMAX_VV(8, 8, 2)
DEFINE_TEST_VMAX_VV(16, 1, INT16_MAX)
DEFINE_TEST_VMAX_VV(16, 2, INT16_MIN)
DEFINE_TEST_VMAX_VV(16, 4, 16000)
DEFINE_TEST_VMAX_VV(16, 8, 2)
DEFINE_TEST_VMAX_VV(32, 1, INT32_MAX)
DEFINE_TEST_VMAX_VV(32, 2, INT32_MIN)
DEFINE_TEST_VMAX_VV(32, 4, 1000000000)
DEFINE_TEST_VMAX_VV(32, 8, 2)
} // namespace
} // namespace vmax_vv_test