blob: 31e391fb00fbde43d95eb59274d4584f3891b9e1 [file] [log] [blame]
/*
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdio>
#include <cstring>
#include "tests/kelvin_isa/kelvin_test.h"
#include "tests/kelvin_isa/vdwconv_data.h"
#ifdef TEST_GEN
static int32_t dwconv(const vdwconv_u8_t& cmd, uint8_t ina[3], uint8_t inb[3]) {
int32_t sbias1 = cmd.sbias1;
int32_t sbias2 = cmd.sbias2;
int32_t accum = 0;
for (int i = 0; i < 3; ++i) {
int32_t sdata1 = cmd.sdata1 ? int8_t(ina[i]) : uint8_t(ina[i]);
int32_t sdata2 = cmd.sdata2 ? int8_t(inb[i]) : uint8_t(inb[i]);
accum += (sdata1 + sbias1) * (sdata2 + sbias2);
}
return accum;
}
#endif // TEST_GEN
void dwconv(const vdwconv_u8_t& cmd, const uint8_t ina[3][kZlen],
const uint8_t inb[3][kZlen], const uint32_t ref[4][kZlen / 4],
uint32_t dut[4][kZlen / 4]) {
uint32_t cmdw;
memcpy(&cmdw, &cmd, 4);
int sparsity = cmd.sparsity;
int regbase = cmd.regbase;
#ifdef TEST_GEN
for (int j = 0; j < kZlen / 4; ++j) {
for (int i = 0; i < 4; ++i) {
int idx = i + 4 * j;
uint8_t va[3];
uint8_t vb[3] = {inb[0][idx], inb[1][idx], inb[2][idx]};
if (sparsity == 0) {
va[0] = ina[0][idx];
va[1] = ina[1][idx];
va[2] = ina[2][idx];
} else if (sparsity == 1) {
va[0] = ina[1][idx - 4];
va[1] = ina[1][idx + 0];
va[2] = ina[1][idx + 4];
} else if (sparsity == 2) {
va[0] = ina[0][idx + 0];
va[1] = ina[0][idx + 4];
va[2] = ina[0][idx + 8];
}
const int interleave[4] = {0, 2, 1, 3};
ref[interleave[i]][j] = dwconv(cmd, va, vb);
}
}
#endif // TEST_GEN
int vlenb;
getmaxvl_b(vlenb);
vdup_b_x(v16, 0);
vdup_b_x(v17, 0);
vdup_b_x(v18, 0);
vdup_b_x(v19, 0);
vdup_b_x(v20, 0);
vdup_b_x(v21, 0);
vdup_b_x(v22, 0);
vdup_b_x(v23, 0);
vdup_b_x(v24, 0);
vdup_w_x_m(v52, 0);
for (int i = 0; i < kZlen; i += vlenb) {
const int j = i / 4;
// dense
const uint8_t* pp = ina[0] + i; // prev
const uint8_t* pc = ina[1] + i; // curr
const uint8_t* pn = ina[2] + i; // next
if (sparsity == 1) {
pp = ina[1] + i - vlenb; // prev
pc = ina[1] + i; // curr
pn = ina[1] + i + vlenb; // next
}
if (sparsity == 2) {
pp = ina[0] + i; // curr
pc = ina[0] + i + vlenb; // next
pn = ina[0] + i + vlenb; // unused
}
switch (regbase) {
case 0:
vld_b_x(v16, pp);
vld_b_x(v17, pc);
vld_b_x(v18, pn);
break;
case 1:
vld_b_x(v17, pp);
vld_b_x(v18, pc);
vld_b_x(v19, pn);
break;
case 2:
vld_b_x(v18, pp);
vld_b_x(v19, pc);
vld_b_x(v20, pn);
break;
case 3:
vld_b_x(v19, pp);
vld_b_x(v20, pc);
vld_b_x(v21, pn);
break;
case 4:
vld_b_x(v20, pp);
vld_b_x(v21, pc);
vld_b_x(v22, pn);
break;
case 5:
vld_b_x(v21, pp);
vld_b_x(v22, pc);
vld_b_x(v23, pn);
break;
case 6:
vld_b_x(v22, pp);
vld_b_x(v23, pc);
vld_b_x(v24, pn);
break;
case 7:
vld_b_x(v17, pp);
vld_b_x(v16, pc);
vld_b_x(v18, pn);
break;
case 8:
vld_b_x(v17, pp);
vld_b_x(v18, pc);
vld_b_x(v16, pn);
break;
case 9:
vld_b_x(v19, pp);
vld_b_x(v20, pc);
vld_b_x(v16, pn);
break;
case 10:
vld_b_x(v21, pp);
vld_b_x(v22, pc);
vld_b_x(v16, pn);
break;
case 11:
vld_b_x(v23, pp);
vld_b_x(v24, pc);
vld_b_x(v16, pn);
break;
case 12:
vld_b_x(v18, pp);
vld_b_x(v16, pc);
vld_b_x(v17, pn);
break;
case 13:
vld_b_x(v20, pp);
vld_b_x(v16, pc);
vld_b_x(v17, pn);
break;
case 14:
vld_b_x(v22, pp);
vld_b_x(v16, pc);
vld_b_x(v17, pn);
break;
case 15:
vld_b_x(v24, pp);
vld_b_x(v16, pc);
vld_b_x(v17, pn);
break;
default:
exit(-1);
break;
}
vld_b_x(v32, inb[0] + i);
vld_b_x(v33, inb[1] + i);
vld_b_x(v34, inb[2] + i);
adwinit_v(v48, v52);
vdwconv_vxv(v48, v16, cmdw, v32);
vst_w_x(v48, dut[0] + j);
vst_w_x(v49, dut[1] + j);
vst_w_x(v50, dut[2] + j);
vst_w_x(v51, dut[3] + j);
}
#ifdef TEST_GEN
// Print the results.
printf("{ %d, %d, %d, %d, ", cmd.sdata1, cmd.sdata2, cmd.sbias1, cmd.sbias2);
printf("{ ");
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < kZlen / 4; ++j) {
printf("0x%lx, ", ref[i][j]);
}
}
printf("} },\n");
#else
// Check the results.
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < kZlen / 4; ++j) {
uint32_t r = ref[i][j];
uint32_t d = dut[i][j];
if (r != d) {
printf("**error::test_dwconv(%d,%d)(%d,%d,%d,%d)[%d,%d] ", cmd.sparsity,
cmd.regbase, cmd.sdata1, cmd.sdata2, cmd.sbias1, cmd.sbias2, i,
j);
printf("0x%lx 0x%lx\n", r, d);
exit(-1);
}
}
}
#endif // TEST_GEN
}
template <int step, bool use_accum>
void test_vdwconv() {
const uint32_t ref1[4] = {0x00001ba4, 0x01021ea8, 0x00061bac, 0x00001bb0};
const uint32_t ref2[4] = {0x00004dbb, 0x010250bf, 0x00064dc3, 0x00004dc7};
const uint32_t ref3[4] = {0x000066af, 0x010269b3, 0x000666b7, 0x000066bb};
const uint32_t* ref = step == 1 ? ref1 : step == 2 ? ref2 : ref3;
uint32_t dut[4];
uint32_t cmdw = 0;
vdup_w_x(v12, 0x00000000);
vdup_w_x(v13, 0x01020304);
vdup_w_x(v14, 0x00060008);
vdup_w_x(v15, 0x0000000c);
vdup_b_x(v16, 23);
vdup_b_x(v17, 34);
vdup_b_x(v18, 45);
vdup_b_x(v32, 56);
vdup_b_x(v33, 67);
vdup_b_x(v34, 78);
vdup_b_x(v36, 76);
vdup_b_x(v37, 65);
vdup_b_x(v38, 54);
adwinit_v(v0, v12);
if (!use_accum) {
vdwconv_vxv(v0, v16, cmdw, v32);
if (step >= 2) {
vdwconv_vxv(v0, v32, cmdw, v36);
}
if (step >= 3) {
vdwconv_vxv(v0, v36, cmdw, v16);
}
} else {
if (step == 1) {
vdwconv_vxv(v0, v16, cmdw, v32);
} else if (step == 2) {
adwconv_vxv(v0, v16, cmdw, v32);
vdwconv_vxv(v0, v32, cmdw, v36);
} else if (step == 3) {
adwconv_vxv(v0, v16, cmdw, v32);
adwconv_vxv(v0, v32, cmdw, v36);
vdwconv_vxv(v0, v36, cmdw, v16);
}
}
vst_w_l_xx(v0, dut + 0, 1);
vst_w_l_xx(v1, dut + 1, 1);
vst_w_l_xx(v2, dut + 2, 1);
vst_w_l_xx(v3, dut + 3, 1);
for (int i = 0; i < 4; ++i) {
if (ref[i] != dut[i]) {
printf("**error::test_dwconv<%d,%d>[%d] 0x%lx 0x%lx\n", step, use_accum,
i, ref[i], dut[i]);
exit(-1);
}
}
}
void test_vdwconv(int sparsity, int regbase, const test_dwconv_t& test) {
uint32_t dut[4][kZlen / 4];
vdwconv_u8_t cmd;
cmd.mode = 0;
cmd.sparsity = sparsity;
cmd.regbase = regbase;
cmd.sbias1 = test.sbias1;
cmd.sdata1 = test.sdata1;
cmd.sbias2 = test.sbias2;
cmd.sdata2 = test.sdata2;
dwconv(cmd, ina_, inb_, test.data, dut);
}
int main() {
#ifdef TEST_GEN
uint32_t* pw_ina = reinterpret_cast<uint32_t*>(ina_);
uint32_t* pw_inb = reinterpret_cast<uint32_t*>(inb_);
uint8_t* pb_ina = reinterpret_cast<uint8_t*>(ina_);
uint8_t* pb_inb = reinterpret_cast<uint8_t*>(inb_);
for (int i = 0; i < 3 * kZlen / 4; ++i) {
pw_ina[i] = krand();
pw_inb[i] = krand();
}
printf("{ ");
for (int i = 0; i < 3 * kZlen; ++i) {
printf("0x%02x, ", pb_ina[i]);
}
printf("};\n");
printf("{ ");
for (int i = 0; i < 3 * kZlen; ++i) {
printf("0x%02x, ", pb_inb[i]);
}
printf("};\n");
#endif // TEST_GEN
// Accumulator test.
test_vdwconv<1, false>();
test_vdwconv<2, false>();
test_vdwconv<3, false>();
test_vdwconv<1, true>();
test_vdwconv<2, true>();
test_vdwconv<3, true>();
// Regbase tests.
for (int i = 0; i < 16; ++i) {
test_vdwconv(0, i, ref_[0]);
}
// Bias tests.
for (int i = 0; i < 20; ++i) {
test_vdwconv(0, 0, ref_[i]);
}
// Sparsity tests.
test_vdwconv(1, 0, ref_sparsity1_[0]);
test_vdwconv(2, 0, ref_sparsity2_[0]);
return 0;
}