#ifndef TEST_V_HELPERS_H
#define TEST_V_HELPERS_H

#include <stdint.h>
#include <string.h>

#include <bit>
#include <tuple>

#include "pw_unit_test/framework.h"

namespace test_v_helpers {

// Sizing constants shared by the vector tests.
//
// LMUL_MAX     largest register-group multiplier (presumably corresponds to
//              LMUL_M8 below).
// VLEN         vector register length.
//              NOTE(review): presumably in bits, as in the RISC-V V spec —
//              confirm against the target configuration.
// MAXVL_BYTES  number of bytes vector_test_setup() zeroes in each test buffer,
//              so every buffer passed to it must be at least this large.
//              NOTE(review): if VLEN is in bits, VLEN * LMUL_MAX is 8x the
//              byte size of a full LMUL_MAX register group — confirm whether
//              the extra headroom is intentional before changing it.
const int LMUL_MAX = 8;
const int VLEN = 512;
const int MAXVL_BYTES = VLEN * LMUL_MAX;

// Table of AVL values for tests to iterate over. The entries are deliberately
// unsorted and mix powers of two with odd sizes, ranging from 1 to 32768.
// NOTE(review): values above the hardware vlmax are presumably there to
// exercise clamping (vector_test_setup() clamps avl to vlmax) — confirm.
const int32_t AVLS[] = {1,    4,    3,     2,     16,    8,    5,    17,
                        32,   36,   64,    55,    100,   321,  256,  128,
                        512,  623,  1024,  1100,  1543,  2048, 3052, 4096,
                        5555, 8192, 10241, 16384, 24325, 32768};
// Number of entries in AVLS.
const int32_t AVL_COUNT = sizeof(AVLS) / sizeof(AVLS[0]);

// Selected element width (SEW) choices available to the tests.
// vector_test_setup() maps 1-, 2- and 4-byte element types to SEW_E8, SEW_E16
// and SEW_E32 respectively, so the numeric values 0, 1, 2 must not change.
// NOTE(review): the values presumably match the vsew field encoding of vtype
// — confirm against get_vtype().
enum VSEW {
  SEW_E8 = 0,
  SEW_E16 = 1,
  SEW_E32 = 2,
  /* // SEW limited to E32
    SEW_E64 = 3,
    SEW_E128 = 4,
    SEW_E256 = 5,
    SEW_E512 = 6,
    SEW_E1024 = 7,
  */
};

// Register-group multiplier (LMUL) choices available to the tests. Only the
// integer multipliers M1..M8 are enabled; the fractional ones are kept below
// as a commented-out reference.
// NOTE(review): the values presumably match the vlmul field encoding of vtype
// — confirm against get_vtype().
enum VLMUL {

  /* // Fractional LMUL not supported by our intrinsic compiler
    LMUL_MF8 = 5,
    LMUL_MF4 = 6,
    LMUL_MF2 = 7,
  */
  LMUL_M1 = 0,
  LMUL_M2 = 1,
  LMUL_M4 = 2,
  LMUL_M8 = 3,
};

// The functions below are only declared here; their definitions are not part
// of this header, so the notes on them are limited to what callers can see.

// Presumably packs sew, lmul and the two agnostic flags into a vtype value —
// confirm the bit layout in the definition.
uint32_t get_vtype(VSEW sew, VLMUL lmul, bool tail_agnostic,
                   bool mask_agnostic);

// vsetvl  rd, rs1, rs2      # rd = new vl, rs1 = AVL, rs2 = new vtype value
// NOTE(review): the return value is presumably the new vl (rd) — confirm.
uint32_t set_vsetvl(VSEW sew, VLMUL lmul, uint32_t avl, bool tail_agnostic,
                    bool mask_agnostic);

// Configures the vector length for the given sew/lmul and requested avl.
// vector_test_setup() treats the return value as the resulting vl and expects
// it to equal avl once avl has been clamped to vlmax.
int set_vsetvl_intrinsic(VSEW sew, VLMUL lmul, uint32_t avl);

// vector_test_setup() treats the return value as vlmax for the given sew/lmul.
int get_vsetvlmax_intrinsic(VSEW sew, VLMUL lmul);

// NOTE(review): presumably the vsetvli (immediate vtype) counterpart of
// set_vsetvl, returning the new vl — confirm in the definition.
int set_vsetvli(VSEW sew, VLMUL lmul, uint32_t avl);

// Clears all vector registers; vector_test_setup() calls this first.
void zero_vector_registers();
// Set AVL = constant 17 for now.
// NOTE(review): set_vsetivli() takes no avl argument, so it presumably uses
// AVL_CONST as its AVL — confirm in the definition.
const uint32_t AVL_CONST = 17;
uint32_t set_vsetivli(VSEW sew, VLMUL lmul);

// Compares two buffers element by element, viewing each as an array of T.
//
// avl            number of leading elements to compare; nothing is compared
//                when it is zero or negative.
// test_vector_1, test_vector_2
//                buffers holding at least avl elements of type T each.
//
// Every pair is checked with ASSERT_EQ, in index order.
template <typename T>
void assert_vec_elem_eq(int avl, void *test_vector_1, void *test_vector_2) {
  T *const ptr_vec_1 = static_cast<T *>(test_vector_1);
  T *const ptr_vec_2 = static_cast<T *>(test_vector_2);
  int idx = 0;
  while (idx < avl) {
    ASSERT_EQ(ptr_vec_1[idx], ptr_vec_2[idx]);
    ++idx;
  }
}

// Prepares the vector unit and both scratch buffers for a test whose elements
// have type T, and returns (vlmax, vl).
//
// T selects the element width: 1-, 2- and 4-byte types map to SEW_E8, SEW_E16
// and SEW_E32. Any other size is rejected at compile time because VSEW has no
// enumerator for it.
//
// lmul           register grouping passed to the vsetvl helpers.
// avl            requested vector length; clamped to vlmax before use.
// test_vector_1, test_vector_2
//                buffers of at least MAXVL_BYTES bytes each; both are zeroed.
//
// Reports an EXPECT_EQ failure if the vl returned by set_vsetvl_intrinsic()
// differs from the (clamped) avl.
template <typename T>
static std::tuple<int, int> vector_test_setup(VLMUL lmul, int32_t avl,
                                              uint8_t *test_vector_1,
                                              uint8_t *test_vector_2) {
  static_assert(sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4,
                "vector_test_setup: T must be 1, 2 or 4 bytes wide (SEW is "
                "limited to E32)");
  // Clear all vector registers
  zero_vector_registers();
  // VSEW encoding for sizeof(T): 1 -> 0 (SEW_E8), 2 -> 1 (SEW_E16),
  // 4 -> 2 (SEW_E32). Spelled out rather than computed with
  // std::__bit_width, which is a standard-library-internal name.
  constexpr int sew_encoding = sizeof(T) == 1 ? 0 : (sizeof(T) == 2 ? 1 : 2);
  const VSEW sew = static_cast<VSEW>(sew_encoding);
  // Determine vlmax and clamp the requested length to it
  int vlmax = get_vsetvlmax_intrinsic(sew, lmul);
  if (avl > vlmax) {
    avl = vlmax;
  }
  // Start every test from zeroed buffers
  memset(test_vector_1, 0, MAXVL_BYTES);
  memset(test_vector_2, 0, MAXVL_BYTES);
  int vl = set_vsetvl_intrinsic(sew, lmul, avl);

  EXPECT_EQ(avl, vl);

  return std::make_tuple(vlmax, vl);
}

}  // namespace test_v_helpers

#endif
