blob: 55fa9baa11ee8d00f977def45a26b65e5fe31751 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test_v_helpers.h"
#ifndef LIBSPRINGBOK_NO_VECTOR_SUPPORT
#include <riscv_vector.h>
namespace test_v_helpers {
// Packs SEW, LMUL and the tail/mask policy flags into a vtype value:
//   bits [2:0] vlmul, bits [5:3] vsew, bit 6 vta, bit 7 vma.
// Each field is masked to its width before being shifted into place.
// Assumes the VSEW/VLMUL enumerators hold the raw vsew/vlmul encodings.
uint32_t get_vtype(VSEW sew, VLMUL lmul, bool tail_agnostic,
                   bool mask_agnostic) {
  const int vlmul_field = static_cast<int>(lmul) & 0x7;
  const int vsew_field = static_cast<int>(sew) & 0x7;
  const int vta_field = tail_agnostic & 0x1;
  const int vma_field = mask_agnostic & 0x1;
  return vma_field << 7 | vta_field << 6 | vsew_field << 3 | vlmul_field;
}
// Executes a register-form `vsetvl` whose vtype operand is built by
// get_vtype() from the requested SEW/LMUL/policy, and returns the vl the
// hardware granted for `avl`. The asm is volatile, so the instruction (and
// its vl/vtype side effect) is not optimized away.
uint32_t set_vsetvl(VSEW sew, VLMUL lmul, uint32_t avl, bool tail_agnostic,
                    bool mask_agnostic) {
  uint32_t granted_vl;
  __asm__ volatile("vsetvl %[VL], %[AVL], %[VTYPE]"
                   : [VL] "=r"(granted_vl)
                   : [AVL] "r"(avl),
                     [VTYPE] "r"(get_vtype(sew, lmul, tail_agnostic,
                                           mask_agnostic)));
  return granted_vl;
}
// Returns the vl computed by the `__riscv_vsetvl_<sew><lmul>(avl)` intrinsic
// that matches the requested SEW/LMUL.
//
// Returns -1 when this helper has no intrinsic for the combination:
//   - any SEW other than 8/16/32,
//   - SEW=32 with LMUL=1/4,
//   - SEW other than 8 with LMUL=1/8,
//   - an LMUL not listed below.
//
// NOTE: the RVV C intrinsics leave vl/vtype management to the compiler, so
// this computes a vl value but is not guaranteed to leave the CSRs
// configured for subsequent hand-written vector assembly.
int set_vsetvl_intrinsic(VSEW sew, VLMUL lmul, uint32_t avl) {
  const bool is_e8 = (sew == VSEW::SEW_E8);
  const bool is_e16 = (sew == VSEW::SEW_E16);
  const bool is_e32 = (sew == VSEW::SEW_E32);

  if (lmul == VLMUL::LMUL_M1) {
    if (is_e8) return __riscv_vsetvl_e8m1(avl);
    if (is_e16) return __riscv_vsetvl_e16m1(avl);
    if (is_e32) return __riscv_vsetvl_e32m1(avl);
  } else if (lmul == VLMUL::LMUL_M2) {
    if (is_e8) return __riscv_vsetvl_e8m2(avl);
    if (is_e16) return __riscv_vsetvl_e16m2(avl);
    if (is_e32) return __riscv_vsetvl_e32m2(avl);
  } else if (lmul == VLMUL::LMUL_M4) {
    if (is_e8) return __riscv_vsetvl_e8m4(avl);
    if (is_e16) return __riscv_vsetvl_e16m4(avl);
    if (is_e32) return __riscv_vsetvl_e32m4(avl);
  } else if (lmul == VLMUL::LMUL_M8) {
    if (is_e8) return __riscv_vsetvl_e8m8(avl);
    if (is_e16) return __riscv_vsetvl_e16m8(avl);
    if (is_e32) return __riscv_vsetvl_e32m8(avl);
  } else if (lmul == VLMUL::LMUL_MF2) {
    if (is_e8) return __riscv_vsetvl_e8mf2(avl);
    if (is_e16) return __riscv_vsetvl_e16mf2(avl);
    if (is_e32) return __riscv_vsetvl_e32mf2(avl);
  } else if (lmul == VLMUL::LMUL_MF4) {
    if (is_e8) return __riscv_vsetvl_e8mf4(avl);
    if (is_e16) return __riscv_vsetvl_e16mf4(avl);
  } else if (lmul == VLMUL::LMUL_MF8) {
    if (is_e8) return __riscv_vsetvl_e8mf8(avl);
  }
  // Unrecognised LMUL, or a SEW this helper does not pair with that LMUL.
  return -1;
}
// Returns VLMAX for the requested SEW/LMUL, obtained from the matching
// `__riscv_vsetvlmax_<sew><lmul>()` intrinsic.
//
// Returns -1 for combinations this helper does not cover:
//   - any SEW other than 8/16/32,
//   - SEW=32 with LMUL=1/4,
//   - SEW other than 8 with LMUL=1/8,
//   - an LMUL not listed below.
int get_vsetvlmax_intrinsic(VSEW sew, VLMUL lmul) {
  const bool is_e8 = (sew == VSEW::SEW_E8);
  const bool is_e16 = (sew == VSEW::SEW_E16);
  const bool is_e32 = (sew == VSEW::SEW_E32);

  if (lmul == VLMUL::LMUL_M1) {
    if (is_e8) return __riscv_vsetvlmax_e8m1();
    if (is_e16) return __riscv_vsetvlmax_e16m1();
    if (is_e32) return __riscv_vsetvlmax_e32m1();
  } else if (lmul == VLMUL::LMUL_M2) {
    if (is_e8) return __riscv_vsetvlmax_e8m2();
    if (is_e16) return __riscv_vsetvlmax_e16m2();
    if (is_e32) return __riscv_vsetvlmax_e32m2();
  } else if (lmul == VLMUL::LMUL_M4) {
    if (is_e8) return __riscv_vsetvlmax_e8m4();
    if (is_e16) return __riscv_vsetvlmax_e16m4();
    if (is_e32) return __riscv_vsetvlmax_e32m4();
  } else if (lmul == VLMUL::LMUL_M8) {
    if (is_e8) return __riscv_vsetvlmax_e8m8();
    if (is_e16) return __riscv_vsetvlmax_e16m8();
    if (is_e32) return __riscv_vsetvlmax_e32m8();
  } else if (lmul == VLMUL::LMUL_MF2) {
    if (is_e8) return __riscv_vsetvlmax_e8mf2();
    if (is_e16) return __riscv_vsetvlmax_e16mf2();
    if (is_e32) return __riscv_vsetvlmax_e32mf2();
  } else if (lmul == VLMUL::LMUL_MF4) {
    if (is_e8) return __riscv_vsetvlmax_e8mf4();
    if (is_e16) return __riscv_vsetvlmax_e16mf4();
  } else if (lmul == VLMUL::LMUL_MF8) {
    if (is_e8) return __riscv_vsetvlmax_e8mf8();
  }
  // Unrecognised LMUL, or a SEW this helper does not pair with that LMUL.
  return -1;
}
// Executes an immediate-vtype `vsetvli` for the requested SEW/LMUL and
// returns the vl granted for `avl`.
//
// Integer LMULs (m1, m2, m4, m8) use the "tu, mu" policy. Fractional LMULs
// (mf2, mf4, mf8) use "ta, mu". For any combination not listed below it
// returns 0 without executing a vsetvli.
int set_vsetvli(VSEW sew, VLMUL lmul, uint32_t avl) {
  uint32_t granted_vl = 0;
// An asm template must be a string literal. The per-configuration operand
// text is therefore spliced in by the preprocessor via literal concatenation.
#define TEST_V_VSETVLI(CONFIG)                      \
  __asm__ volatile("vsetvli %[VL], %[AVL], " CONFIG \
                   : [VL] "=r"(granted_vl)          \
                   : [AVL] "r"(avl))

  if (lmul == VLMUL::LMUL_M1) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, m1, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETVLI("e16, m1, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETVLI("e32, m1, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_M2) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, m2, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETVLI("e16, m2, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETVLI("e32, m2, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_M4) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, m4, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETVLI("e16, m4, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETVLI("e32, m4, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_M8) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, m8, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETVLI("e16, m8, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETVLI("e32, m8, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_MF2) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, mf2, ta, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETVLI("e16, mf2, ta, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETVLI("e32, mf2, ta, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_MF4) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, mf4, ta, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETVLI("e16, mf4, ta, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_MF8) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETVLI("e8, mf8, ta, mu");
    } else {
      return 0;
    }
  } else {
    return 0;
  }
#undef TEST_V_VSETVLI
  return granted_vl;
}
// Clears all 32 vector registers (v0-v31). It does this by writing zero to
// the four LMUL=8 register groups v0-v7, v8-v15, v16-v23 and v24-v31.
//
// Side effect: leaves vtype = e32, m8, ta, ma and vl = VLMAX.
void zero_vector_registers() {
  // The configuring vsetvli is emitted in the same volatile asm statement as
  // the vmv.v.i instructions. The __riscv_vsetvl_* intrinsics only compute a
  // vl value, because the compiler owns vl/vtype: an unused intrinsic result
  // may be dropped or moved. Such intrinsics therefore cannot be relied on to
  // program vl/vtype for hand-written vector instructions.
  // AVL = x0 with a non-x0 destination requests vl = VLMAX, so every element
  // of each register group is written. With vl = VLMAX there is no tail, and
  // vmv.v.i is unmasked, so the agnostic policies do not affect the result.
  uint32_t vl;
  __asm__ volatile(
      "vsetvli %[VL], zero, e32, m8, ta, ma\n\t"
      "vmv.v.i v0, 0\n\t"
      "vmv.v.i v8, 0\n\t"
      "vmv.v.i v16, 0\n\t"
      "vmv.v.i v24, 0"
      : [VL] "=r"(vl));
  (void)vl;  // Only needed as the vsetvli destination register.
}
// Executes `vsetivli` with the compile-time AVL constant AVL_CONST for the
// requested SEW/LMUL, and returns the granted vl.
//
// Integer LMULs (m1, m2, m4, m8) use the "tu, mu" policy. Fractional LMULs
// (mf2, mf4, mf8) use "ta, mu". For any combination not listed below it
// returns 0 without executing a vsetivli.
uint32_t set_vsetivli(VSEW sew, VLMUL lmul) {
  uint32_t granted_vl = 0;
// An asm template must be a string literal. The per-configuration operand
// text is therefore spliced in by the preprocessor via literal concatenation.
#define TEST_V_VSETIVLI(CONFIG)                      \
  __asm__ volatile("vsetivli %[VL], %[AVL], " CONFIG \
                   : [VL] "=r"(granted_vl)           \
                   : [AVL] "n"(AVL_CONST))

  if (lmul == VLMUL::LMUL_M1) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, m1, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETIVLI("e16, m1, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETIVLI("e32, m1, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_M2) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, m2, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETIVLI("e16, m2, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETIVLI("e32, m2, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_M4) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, m4, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETIVLI("e16, m4, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETIVLI("e32, m4, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_M8) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, m8, tu, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETIVLI("e16, m8, tu, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETIVLI("e32, m8, tu, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_MF2) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, mf2, ta, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETIVLI("e16, mf2, ta, mu");
    } else if (sew == VSEW::SEW_E32) {
      TEST_V_VSETIVLI("e32, mf2, ta, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_MF4) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, mf4, ta, mu");
    } else if (sew == VSEW::SEW_E16) {
      TEST_V_VSETIVLI("e16, mf4, ta, mu");
    } else {
      return 0;
    }
  } else if (lmul == VLMUL::LMUL_MF8) {
    if (sew == VSEW::SEW_E8) {
      TEST_V_VSETIVLI("e8, mf8, ta, mu");
    } else {
      return 0;
    }
  } else {
    return 0;
  }
#undef TEST_V_VSETIVLI
  return granted_vl;
}
} // namespace test_v_helpers
#endif // #ifndef LIBSPRINGBOK_NO_VECTOR_SUPPORT