#ifndef TEST_V_HELPERS_H
#define TEST_V_HELPERS_H

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <bit>
#include <initializer_list>
#include <tuple>

#include "pw_unit_test/framework.h"

namespace test_v_helpers {

// Largest register-group multiplier used by these helpers (cf. LMUL_M8).
constexpr int LMUL_MAX = 8;

// Vector register length. Presumably expressed in bits -- TODO confirm
// against the target configuration.
constexpr int VLEN = 512;

// VLEN * LMUL_MAX (= 4096). vector_test_setup() zeroes exactly this many
// bytes of every buffer it is handed, so test buffers must be at least this
// large.
// NOTE(review): if VLEN is in bits, this is 8x the number of bytes held by an
// LMUL_MAX register group; confirm the over-sizing is intentional.
constexpr int MAXVL_BYTES = VLEN * LMUL_MAX;

#ifdef FOR_TBM
const int32_t AVLS[] = {17};
#else
const int32_t AVLS[] = {1,    4,    3,     2,     16,    8,    5,    17,
                        32,   36,   64,    55,    100,   321,  256,  128,
                        512,  623,  1024,  1100,  1543,  2048, 3052, 4096,
                        5555, 8192, 10241, 16384, 24325, 32768};
#endif

const int32_t AVL_COUNT = sizeof(AVLS) / sizeof(AVLS[0]);

// Selected element width. The numeric value equals log2 of the element size
// in bytes, which is how vector_test_setup() derives it from sizeof(T).
// The encoding appears to match the RISC-V vtype.vsew field -- confirm
// against the spec before relying on it.
//
// SEW is limited to E32 here. The wider settings are intentionally left out:
//   SEW_E64 = 3, SEW_E128 = 4, SEW_E256 = 5, SEW_E512 = 6, SEW_E1024 = 7
enum VSEW {
  SEW_E8 = 0,   // 1-byte elements
  SEW_E16 = 1,  // 2-byte elements
  SEW_E32 = 2   // 4-byte elements
};

// Register-group multiplier setting, listed in increasing numeric order.
// The names suggest whole multipliers (M1..M8) and fractional ones
// (MF2 = 1/2, MF4 = 1/4, MF8 = 1/8); the values appear to match the RISC-V
// vtype.vlmul field -- confirm against the spec. The value 4 has no
// enumerator.
enum VLMUL {
  LMUL_M1 = 0,
  LMUL_M2 = 1,
  LMUL_M4 = 2,
  LMUL_M8 = 3,

  LMUL_MF8 = 5,
  LMUL_MF4 = 6,
  LMUL_MF2 = 7
};

// The functions below are only declared here; their definitions live in
// another translation unit, so the notes on them are limited to what this
// header itself shows.

// Produces a 32-bit vtype value from an element width, a register-group
// multiplier and the two agnostic flags. Presumably packs them into the
// RISC-V vtype layout -- TODO confirm against the definition.
uint32_t get_vtype(VSEW sew, VLMUL lmul, bool tail_agnostic,
                   bool mask_agnostic);

// vsetvl  rd, rs1, rs2      # rd = new vl, rs1 = AVL, rs2 = new vtype value
// Presumably executes the instruction above with the given AVL and a vtype
// built from the remaining arguments, returning the new vl -- TODO confirm.
uint32_t set_vsetvl(VSEW sew, VLMUL lmul, uint32_t avl, bool tail_agnostic,
                    bool mask_agnostic);

// Requests a vector length of `avl` for the given SEW/LMUL. The return value
// is used as the resulting vl by vector_test_setup().
int set_vsetvl_intrinsic(VSEW sew, VLMUL lmul, uint32_t avl);

// The return value is used as vlmax for the given SEW/LMUL by
// vector_test_setup(), which clamps the requested AVL to it.
int get_vsetvlmax_intrinsic(VSEW sew, VLMUL lmul);

// Presumably the vsetvli counterpart of set_vsetvl() (name-based guess;
// verify against the definition).
int set_vsetvli(VSEW sew, VLMUL lmul, uint32_t avl);

// Called at the start of vector_test_setup() to clear all vector registers.
void zero_vector_registers();
// Set AVL = constant 17 for now.
const uint32_t AVL_CONST = 17;
// Takes no AVL argument; presumably issues vsetivli with AVL_CONST as the
// AVL -- TODO confirm against the definition.
uint32_t set_vsetivli(VSEW sew, VLMUL lmul);

// Element-wise equality check of two buffers, both viewed as arrays of T.
//
//   avl            number of leading elements to compare; nothing is
//                  compared when avl <= 0.
//   test_vector_1  first buffer, read as T[avl].
//   test_vector_2  second buffer, read as T[avl].
//
// Pairs are checked with ASSERT_EQ in increasing index order.
template <typename T>
void assert_vec_elem_eq(int avl, void *test_vector_1, void *test_vector_2) {
  // A void* converts to an object pointer with static_cast; no
  // reinterpretation is needed.
  T *const ptr_vec_1 = static_cast<T *>(test_vector_1);
  T *const ptr_vec_2 = static_cast<T *>(test_vector_2);
  int idx = 0;
  while (idx < avl) {
    ASSERT_EQ(ptr_vec_1[idx], ptr_vec_2[idx]);
    ++idx;
  }
}

template <typename T>
void assert_vec_mask_eq(int avl, void *test_vector_1, void *test_vector_2) {
  const unsigned int bw_required = std::__bit_width(sizeof(T)* 8);
  const unsigned int shift = bw_required - 1;
  T *ptr_vec_1 = reinterpret_cast<T *>(test_vector_1);
  T *ptr_vec_2 = reinterpret_cast<T *>(test_vector_2);
  for (int idx = 0; idx < avl; idx++) {
    unsigned int element_idx = idx >> shift;  // Eqivalent to idx / (sizeof(T) * 8)
    unsigned int element_pos = idx & ~(element_idx << shift); // Equivalent to idx % (sizeof(T) * 8)
    T *e1 = ptr_vec_1 + element_idx;
    T *e2 = ptr_vec_2 + element_idx;
    ASSERT_EQ(*e1 & (1 << element_pos), *e2 & (1 << element_pos));
  }
}

// Common prologue for a vector test that operates on elements of type T.
//
// In order, it:
//   1. clears the vector registers via zero_vector_registers();
//   2. derives the SEW setting from sizeof(T): the VSEW value is
//      floor(log2(sizeof(T))), so 1-, 2- and 4-byte types select SEW_E8,
//      SEW_E16 and SEW_E32 (wider types yield values with no VSEW
//      enumerator);
//   3. queries vlmax with get_vsetvlmax_intrinsic() and clamps avl to it;
//   4. zeroes the first MAXVL_BYTES bytes of every buffer in vec_list;
//   5. calls set_vsetvl_intrinsic() with the clamped avl and EXPECTs the
//      returned vl to equal it.
//
//   lmul      register-group multiplier to configure.
//   avl       requested vector length; values above vlmax are clamped.
//   vec_list  buffers to clear. Each pointer must reference at least
//             MAXVL_BYTES writable bytes.
//
// Returns (vlmax, vl).
template <typename T>
static std::tuple<int, int> vector_test_setup(
    VLMUL lmul, int32_t avl, const std::initializer_list<void *> &vec_list) {
  // Clear all vector registers
  zero_vector_registers();

  // floor(log2(sizeof(T))), computed without the implementation-internal
  // std::__bit_width so the header does not depend on one standard library.
  uint32_t log2_elem_bytes = 0;
  for (uint32_t bytes = sizeof(T); bytes > 1; bytes >>= 1) {
    ++log2_elem_bytes;
  }
  const VSEW sew = static_cast<VSEW>(log2_elem_bytes);

  // Determine vlmax and cap the requested length to it.
  const int vlmax = get_vsetvlmax_intrinsic(sew, lmul);
  if (avl > vlmax) {
    avl = vlmax;
  }

  // Reset every test buffer before the test body fills it.
  for (void *vec : vec_list) {
    memset(vec, 0, MAXVL_BYTES);
  }

  const int vl = set_vsetvl_intrinsic(sew, lmul, avl);

  EXPECT_EQ(avl, vl);

  return std::make_tuple(vlmax, vl);
}

// Overwrites vec[0 .. avl-1] with successive rand() results, each converted
// to T with static_cast. Elements are written in increasing index order, so
// the contents are reproducible for a given rand() state. Does nothing when
// avl <= 0. This function does not seed the generator.
template <typename T>
void fill_random_vector(T *vec, int32_t avl) {
  T *slot = vec;
  for (int32_t remaining = avl; remaining > 0; --remaining) {
    *slot = static_cast<T>(rand());
    ++slot;
  }
}

// Writes each element's own index into it: vec[i] = static_cast<T>(i) for
// i in [0, avl). Indices that do not fit in T are converted by static_cast
// (e.g. they wrap for unsigned T). Does nothing when avl <= 0.
template <typename T>
void fill_vector_with_index(T *vec, int32_t avl) {
  int32_t position = 0;
  while (position < avl) {
    vec[position] = static_cast<T>(position);
    ++position;
  }
}

}  // namespace test_v_helpers

#endif
