// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <riscv_vector.h>
#include <springbok.h>
#include <stdio.h>
#include <stdlib.h>

#include <bit>
#include <tuple>

#include "pw_unit_test/framework.h"
#include "test_v_helpers/include/test_v_helpers.h"

namespace vmv_test {
namespace {

using namespace test_v_helpers;

// Scratch buffers shared by every test in this file. The tests reinterpret
// them as int8_t, int16_t and int32_t arrays (both for plain C++ stores and as
// vector load/store addresses), so they are aligned for the widest of those
// element types instead of relying on the 1-byte alignment of uint8_t.
alignas(int32_t) uint8_t test_vector_1[MAXVL_BYTES];
alignas(int32_t) uint8_t reference_vector_1[MAXVL_BYTES];

// Fixture used by every vmv test in this file: zero_vector_registers() is
// invoked from both SetUp() and TearDown(), i.e. around each test body.
class VmvTest : public ::testing::Test {
 protected:
  void SetUp() override {
    zero_vector_registers();
  }

  void TearDown() override {
    zero_vector_registers();
  }
};

// Inline-assembly demo of vmv.v.v at SEW=8, LMUL=1 (registered with the
// DISABLED_ prefix).
//
// For every AVL in AVLS that does not exceed vlmax, the test loads
// test_vector_1 into v0, copies v0 to v1 with vmv.v.v, stores v1 to
// reference_vector_1 and checks that both buffers agree over vlmax elements.
TEST_F(VmvTest, DISABLED_vmv_demo) {
  for (int i = 0; i < AVL_COUNT; i++) {
    int32_t avl = AVLS[i];
    int vlmax;
    // Only vlmax is needed here; the vl half of the result is discarded.
    std::tie(vlmax, std::ignore) = vector_test_setup<int8_t>(
        VLMUL::LMUL_M1, avl, {test_vector_1, reference_vector_1});
    if (avl > vlmax) {
      continue;
    }

    // The buffers are only named through a pointer held in a register, so the
    // "memory" clobber is what tells the compiler that the load reads, and the
    // store writes, memory that is not listed among the asm operands.
    __asm__ volatile("vle8.v v0, (%0)" : : "r"(test_vector_1) : "memory");
    __asm__ volatile("vmv.v.v v1, v0");
    __asm__ volatile("vse8.v v1, (%0)"
                     :
                     : "r"(reference_vector_1)
                     : "memory");
    assert_vec_elem_eq<int8_t>(vlmax, test_vector_1, reference_vector_1);
  }
}

// TODO(henryherman): clang vle intrinsic uses vs1r.v (unsupported in renode)
//
// Intrinsic flavour of the SEW=8, LMUL=1 vmv.v.v demo (registered with the
// DISABLED_ prefix). For each AVL that does not exceed vlmax: load vl elements
// from test_vector_1, copy the register with __riscv_vmv_v_v_i8m1, store the
// copy to reference_vector_1 and compare the buffers over vlmax elements.
TEST_F(VmvTest, DISABLED_intrinsic_vmv_demo) {
  int8_t *const src = reinterpret_cast<int8_t *>(test_vector_1);
  int8_t *const dst = reinterpret_cast<int8_t *>(reference_vector_1);
  for (int idx = 0; idx < AVL_COUNT; ++idx) {
    int32_t avl = AVLS[idx];
    auto dims = vector_test_setup<int8_t>(
        VLMUL::LMUL_M1, avl, {test_vector_1, reference_vector_1});
    int vlmax = std::get<0>(dims);
    int vl = std::get<1>(dims);
    // AVLs larger than vlmax are not exercised.
    if (avl <= vlmax) {
      vint8m1_t loaded = __riscv_vle8_v_i8m1(src, vl);
      vint8m1_t moved = __riscv_vmv_v_v_i8m1(loaded, vl);
      __riscv_vse8_v_i8m1(dst, moved, vl);
      assert_vec_elem_eq<int8_t>(vlmax, test_vector_1, reference_vector_1);
    }
  }
}

// Generates VmvTest.DISABLED_intrinsic_vmv_v_v_i<SEW>m<LMUL>, the intrinsic
// form of the vmv.v.v test.
//
//   _SEW_  : element width in bits; selects int<SEW>_t and the matching
//            __riscv_vle<SEW>/__riscv_vse<SEW> intrinsics.
//   _LMUL_ : register grouping; selects VLMUL::LMUL_M<LMUL> and the
//            vint<SEW>m<LMUL>_t register type.
//
// For each AVL that does not exceed vlmax the generated test loads vl elements
// from test_vector_1, copies them with __riscv_vmv_v_v, stores the copy to
// reference_vector_1 and compares the two buffers over vlmax elements.
#define DEFINE_TEST_VMV_V_V_I_INTRINSIC(_SEW_, _LMUL_)                      \
  TEST_F(VmvTest, DISABLED_intrinsic_vmv_v_v_i##_SEW_##m##_LMUL_) {         \
    int##_SEW_##_t *const src =                                             \
        reinterpret_cast<int##_SEW_##_t *>(test_vector_1);                  \
    int##_SEW_##_t *const dst =                                             \
        reinterpret_cast<int##_SEW_##_t *>(reference_vector_1);             \
    for (int idx = 0; idx < AVL_COUNT; ++idx) {                             \
      int32_t avl = AVLS[idx];                                              \
      auto dims = vector_test_setup<int##_SEW_##_t>(                        \
          VLMUL::LMUL_M##_LMUL_, avl, {test_vector_1, reference_vector_1}); \
      int vlmax = std::get<0>(dims);                                        \
      int vl = std::get<1>(dims);                                           \
      if (avl <= vlmax) {                                                   \
        vint##_SEW_##m##_LMUL_##_t loaded =                                 \
            __riscv_vle##_SEW_##_v_i##_SEW_##m##_LMUL_(src, vl);            \
        vint##_SEW_##m##_LMUL_##_t moved =                                  \
            __riscv_vmv_v_v_i##_SEW_##m##_LMUL_(loaded, vl);                \
        __riscv_vse##_SEW_##_v_i##_SEW_##m##_LMUL_(dst, moved, vl);         \
        assert_vec_elem_eq<int##_SEW_##_t>(vlmax, test_vector_1,            \
                                           reference_vector_1);             \
      }                                                                     \
    }                                                                       \
  }

// Instantiate the intrinsic vmv.v.v test for SEW 8, 16 and 32, each at
// LMUL 1, 2, 4 and 8. Arguments are (SEW, LMUL).
DEFINE_TEST_VMV_V_V_I_INTRINSIC(8, 1)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(8, 2)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(8, 4)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(8, 8)

DEFINE_TEST_VMV_V_V_I_INTRINSIC(16, 1)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(16, 2)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(16, 4)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(16, 8)

DEFINE_TEST_VMV_V_V_I_INTRINSIC(32, 1)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(32, 2)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(32, 4)
DEFINE_TEST_VMV_V_V_I_INTRINSIC(32, 8)

// Defines VmvTest.vmv_v_v_i<SEW>m<LMUL>, an inline-assembly test of vmv.v.v.
//
//   _SEW_  : element width in bits; selects int<SEW>_t and the
//            vle<SEW>.v / vse<SEW>.v instructions.
//   _LMUL_ : register grouping, passed on as VLMUL::LMUL_M<LMUL>.
//
// For every AVL in AVLS that does not exceed vlmax, the generated test loads
// test_vector_1 into v0, copies v0 to v8 with vmv.v.v, stores v8 to
// reference_vector_1 and compares both buffers over vlmax elements.
//
// The load and the store carry a "memory" clobber because the buffers are only
// named through a pointer held in a register; without it the compiler is not
// told that the asm reads or writes them. Only vlmax is needed from
// vector_test_setup, so the vl half of its result is discarded.
#define DEFINE_TEST_VMV_V_V_I(_SEW_, _LMUL_)                                \
  TEST_F(VmvTest, vmv_v_v_i##_SEW_##m##_LMUL_) {                            \
    for (int i = 0; i < AVL_COUNT; i++) {                                   \
      int32_t avl = AVLS[i];                                                \
      int vlmax;                                                            \
      std::tie(vlmax, std::ignore) = vector_test_setup<int##_SEW_##_t>(     \
          VLMUL::LMUL_M##_LMUL_, avl, {test_vector_1, reference_vector_1}); \
      if (avl > vlmax) {                                                    \
        continue;                                                           \
      }                                                                     \
      int##_SEW_##_t *ptr_vec_1 =                                           \
          reinterpret_cast<int##_SEW_##_t *>(test_vector_1);                \
      int##_SEW_##_t *ptr_vec_2 =                                           \
          reinterpret_cast<int##_SEW_##_t *>(reference_vector_1);           \
      __asm__ volatile("vle" #_SEW_ ".v v0, (%0)"                           \
                       :                                                    \
                       : "r"(ptr_vec_1)                                     \
                       : "memory");                                         \
      __asm__ volatile("vmv.v.v v8, v0");                                   \
      __asm__ volatile("vse" #_SEW_ ".v v8, (%0)"                           \
                       :                                                    \
                       : "r"(ptr_vec_2)                                     \
                       : "memory");                                         \
      assert_vec_elem_eq<int##_SEW_##_t>(vlmax, test_vector_1,              \
                                         reference_vector_1);               \
    }                                                                       \
  }

// Instantiate the inline-assembly vmv.v.v test for SEW 8, 16 and 32, each at
// LMUL 1, 2, 4 and 8. Arguments are (SEW, LMUL).
DEFINE_TEST_VMV_V_V_I(8, 1)
DEFINE_TEST_VMV_V_V_I(8, 2)
DEFINE_TEST_VMV_V_V_I(8, 4)
DEFINE_TEST_VMV_V_V_I(8, 8)

DEFINE_TEST_VMV_V_V_I(16, 1)
DEFINE_TEST_VMV_V_V_I(16, 2)
DEFINE_TEST_VMV_V_V_I(16, 4)
DEFINE_TEST_VMV_V_V_I(16, 8)

DEFINE_TEST_VMV_V_V_I(32, 1)
DEFINE_TEST_VMV_V_V_I(32, 2)
DEFINE_TEST_VMV_V_V_I(32, 4)
DEFINE_TEST_VMV_V_V_I(32, 8)

// Inline-assembly demo of vmv.v.x at SEW=8, LMUL=1.
//
// For every AVL in AVLS that does not exceed vlmax, the scalar test value is
// moved into v8 with vmv.v.x and v8 is stored to reference_vector_1. The first
// vl elements of test_vector_1 are set to the same scalar in plain C++, and
// the two buffers are then compared over vlmax elements.
TEST_F(VmvTest, vmv_v_x_demo) {
  for (int i = 0; i < AVL_COUNT; i++) {
    int32_t avl = AVLS[i];
    int vlmax;
    int vl;
    std::tie(vlmax, vl) = vector_test_setup<int8_t>(
        VLMUL::LMUL_M1, avl, {test_vector_1, reference_vector_1});
    if (avl > vlmax) {
      continue;
    }
    int8_t *ptr_vec_1 = reinterpret_cast<int8_t *>(test_vector_1);
    int8_t *ptr_vec_2 = reinterpret_cast<int8_t *>(reference_vector_1);
    // 0xAB does not fit in int8_t; the cast makes the truncation explicit
    // instead of relying on an implicit narrowing conversion.
    const int8_t test_val = static_cast<int8_t>(0xAB);
    __asm__ volatile("vmv.v.x v8, %[RS1]" ::[RS1] "r"(test_val));
    // Build the expected result; the index is named so that it does not
    // shadow the outer AVL loop index.
    for (int elem = 0; elem < vl; elem++) {
      ptr_vec_1[elem] = test_val;
    }
    // The destination is only named through a pointer register, so the
    // "memory" clobber tells the compiler that the store writes memory.
    __asm__ volatile("vse8.v v8, (%0)" : : "r"(ptr_vec_2) : "memory");
    assert_vec_elem_eq<int8_t>(vlmax, test_vector_1, reference_vector_1);
  }
}

// Defines VmvTest.vmv_v_x_e<SEW>m<LMUL>, an inline-assembly test of vmv.v.x.
//
//   _SEW_    : element width in bits; selects int<SEW>_t and vse<SEW>.v.
//   _LMUL_   : register grouping, passed on as VLMUL::LMUL_M<LMUL>.
//   TEST_VAL : scalar moved into the vector register; it is explicitly cast
//              to int<SEW>_t because several callers pass hex constants that
//              do not fit the signed element type.
//
// For every AVL in AVLS that does not exceed vlmax, the generated test moves
// the scalar into v8 with vmv.v.x, stores v8 to reference_vector_1, writes the
// same scalar to the first vl elements of test_vector_1 and compares the two
// buffers over vlmax elements.
//
// The store carries a "memory" clobber because the destination is only named
// through a pointer held in a register. The inner loop uses its own index
// name so that it does not shadow the outer AVL loop index.
#define DEFINE_TEST_VMV_V_X_I(_SEW_, _LMUL_, TEST_VAL)                      \
  TEST_F(VmvTest, vmv_v_x_e##_SEW_##m##_LMUL_) {                            \
    for (int i = 0; i < AVL_COUNT; i++) {                                   \
      int32_t avl = AVLS[i];                                                \
      int vlmax;                                                            \
      int vl;                                                               \
      std::tie(vlmax, vl) = vector_test_setup<int##_SEW_##_t>(              \
          VLMUL::LMUL_M##_LMUL_, avl, {test_vector_1, reference_vector_1}); \
      if (avl > vlmax) {                                                    \
        continue;                                                           \
      }                                                                     \
      int##_SEW_##_t *ptr_vec_1 =                                           \
          reinterpret_cast<int##_SEW_##_t *>(test_vector_1);                \
      int##_SEW_##_t *ptr_vec_2 =                                           \
          reinterpret_cast<int##_SEW_##_t *>(reference_vector_1);           \
      const int##_SEW_##_t test_val =                                       \
          static_cast<int##_SEW_##_t>(TEST_VAL);                            \
      __asm__ volatile("vmv.v.x v8, %[RS1]" ::[RS1] "r"(test_val));         \
      for (int j = 0; j < vl; j++) {                                        \
        ptr_vec_1[j] = test_val;                                            \
      }                                                                     \
      __asm__ volatile("vse" #_SEW_ ".v v8, (%0)"                           \
                       :                                                    \
                       : "r"(ptr_vec_2)                                     \
                       : "memory");                                         \
      assert_vec_elem_eq<int##_SEW_##_t>(vlmax, test_vector_1,              \
                                         reference_vector_1);               \
    }                                                                       \
  }

// Instantiate the vmv.v.x test for SEW 8, 16 and 32, each at LMUL 1, 2, 4 and
// 8. Arguments are (SEW, LMUL, scalar test value); every instance uses a
// different scalar.
DEFINE_TEST_VMV_V_X_I(8, 1, 0xab)
DEFINE_TEST_VMV_V_X_I(8, 2, 0xac)
DEFINE_TEST_VMV_V_X_I(8, 4, 0xad)
DEFINE_TEST_VMV_V_X_I(8, 8, 0xae)

DEFINE_TEST_VMV_V_X_I(16, 1, 0xabc1)
DEFINE_TEST_VMV_V_X_I(16, 2, 0xabc2)
DEFINE_TEST_VMV_V_X_I(16, 4, 0xabc3)
DEFINE_TEST_VMV_V_X_I(16, 8, 0xabc4)

DEFINE_TEST_VMV_V_X_I(32, 1, 0xabcdef12)
DEFINE_TEST_VMV_V_X_I(32, 2, 0xabcdef13)
DEFINE_TEST_VMV_V_X_I(32, 4, 0xabcdef14)
DEFINE_TEST_VMV_V_X_I(32, 8, 0xabcdef15)

// Inline-assembly demo of vmv.v.i at SEW=8, LMUL=1.
//
// For every AVL in AVLS that does not exceed vlmax, the immediate -12 is moved
// into v8 with vmv.v.i and v8 is stored to reference_vector_1. The first vl
// elements of test_vector_1 are set to the same value in plain C++, and the
// two buffers are then compared over vlmax elements.
TEST_F(VmvTest, vmv_v_i_demo) {
  for (int i = 0; i < AVL_COUNT; i++) {
    int32_t avl = AVLS[i];
    int vlmax;
    int vl;
    std::tie(vlmax, vl) = vector_test_setup<int8_t>(
        VLMUL::LMUL_M1, avl, {test_vector_1, reference_vector_1});
    if (avl > vlmax) {
      continue;
    }
    int8_t *ptr_vec_1 = reinterpret_cast<int8_t *>(test_vector_1);
    int8_t *ptr_vec_2 = reinterpret_cast<int8_t *>(reference_vector_1);
    // Must stay in sync with the immediate hard-coded in the asm string below.
    int8_t test_val = -12;
    __asm__ volatile("vmv.v.i v8, -12" ::);
    // Build the expected result; the index is named so that it does not
    // shadow the outer AVL loop index.
    for (int elem = 0; elem < vl; elem++) {
      ptr_vec_1[elem] = test_val;
    }
    // The destination is only named through a pointer register, so the
    // "memory" clobber tells the compiler that the store writes memory.
    __asm__ volatile("vse8.v v8, (%0)" : : "r"(ptr_vec_2) : "memory");
    assert_vec_elem_eq<int8_t>(vlmax, test_vector_1, reference_vector_1);
  }
}

// TODO(gkielian): Allow mechanism for multiple tests for same sew,lmul pair
//
// Defines VmvTest.vmv_v_i_e<SEW>m<LMUL>, an inline-assembly test of vmv.v.i.
//
//   _SEW_    : element width in bits; selects int<SEW>_t and vse<SEW>.v.
//   _LMUL_   : register grouping, passed on as VLMUL::LMUL_M<LMUL>.
//   TEST_VAL : immediate operand. It is stringified straight into the
//              instruction text, so it must be written as a literal.
//
// For every AVL in AVLS that does not exceed vlmax, the generated test moves
// the immediate into v8 with vmv.v.i, stores v8 to reference_vector_1, writes
// the same value to the first vl elements of test_vector_1 and compares the
// two buffers over vlmax elements.
//
// The store carries a "memory" clobber because the destination is only named
// through a pointer held in a register. The inner loop uses its own index
// name so that it does not shadow the outer AVL loop index.
#define DEFINE_TEST_VMV_V_I_I(_SEW_, _LMUL_, TEST_VAL)                      \
  TEST_F(VmvTest, vmv_v_i_e##_SEW_##m##_LMUL_) {                            \
    for (int i = 0; i < AVL_COUNT; i++) {                                   \
      int32_t avl = AVLS[i];                                                \
      int vlmax;                                                            \
      int vl;                                                               \
      std::tie(vlmax, vl) = vector_test_setup<int##_SEW_##_t>(              \
          VLMUL::LMUL_M##_LMUL_, avl, {test_vector_1, reference_vector_1}); \
      if (avl > vlmax) {                                                    \
        continue;                                                           \
      }                                                                     \
      int##_SEW_##_t *ptr_vec_1 =                                           \
          reinterpret_cast<int##_SEW_##_t *>(test_vector_1);                \
      int##_SEW_##_t *ptr_vec_2 =                                           \
          reinterpret_cast<int##_SEW_##_t *>(reference_vector_1);           \
      int##_SEW_##_t test_val = TEST_VAL;                                   \
      __asm__ volatile("vmv.v.i v8, " #TEST_VAL);                           \
      for (int j = 0; j < vl; j++) {                                        \
        ptr_vec_1[j] = test_val;                                            \
      }                                                                     \
      __asm__ volatile("vse" #_SEW_ ".v v8, (%0)"                           \
                       :                                                    \
                       : "r"(ptr_vec_2)                                     \
                       : "memory");                                         \
      assert_vec_elem_eq<int##_SEW_##_t>(vlmax, test_vector_1,              \
                                         reference_vector_1);               \
    }                                                                       \
  }

// Instantiate the vmv.v.i test for SEW 8, 16 and 32, each at LMUL 1, 2, 4 and
// 8. Arguments are (SEW, LMUL, immediate); every instance uses a different
// negative immediate.
DEFINE_TEST_VMV_V_I_I(8, 1, -11)
DEFINE_TEST_VMV_V_I_I(8, 2, -12)
DEFINE_TEST_VMV_V_I_I(8, 4, -13)
DEFINE_TEST_VMV_V_I_I(8, 8, -14)

DEFINE_TEST_VMV_V_I_I(16, 1, -10)
DEFINE_TEST_VMV_V_I_I(16, 2, -9)
DEFINE_TEST_VMV_V_I_I(16, 4, -8)
DEFINE_TEST_VMV_V_I_I(16, 8, -7)

DEFINE_TEST_VMV_V_I_I(32, 1, -2)
DEFINE_TEST_VMV_V_I_I(32, 2, -3)
DEFINE_TEST_VMV_V_I_I(32, 4, -4)
DEFINE_TEST_VMV_V_I_I(32, 8, -5)

}  // namespace
}  // namespace vmv_test
