Initial commit of vector_matmul4_asm_test

This adds a test that uses a hand-written optimized 4x4 matrix
multiplication function.

Change-Id: I7ed0a66f49ed8b15ea175607a9c8221b9d98cf25
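For reference, the operation under test multiplies arrays of 4x4 int8
matrices into 4x4 int32 matrices, with the right-hand side stored
pre-transposed. A minimal scalar sketch of these semantics (the name
matmul4_ref is illustrative; the logic mirrors the golden model in
main.cpp below):

    #include <cstddef>
    #include <cstdint>

    void matmul4_ref(int32_t *out, const int8_t *lhs, const int8_t *rhs_t,
                     std::size_t count) {
      for (std::size_t b = 0; b < count; b++) {  // one 4x4 matrix per block
        for (int j = 0; j < 4; j++) {            // output row
          for (int i = 0; i < 4; i++) {          // output column
            int32_t acc = 0;
            for (int k = 0; k < 4; k++) {
              // rhs_t is pre-transposed, so both operands are read row-wise
              acc += lhs[b*16 + j*4 + k] * rhs_t[b*16 + i*4 + k];
            }
            out[b*16 + j*4 + i] = acc;
          }
        }
      }
    }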
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3dbdb92..ac3e2a9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(vector_vadd_vsub_tests)
 add_subdirectory(vector_executive)
 add_subdirectory(vector_vset_tests)
+add_subdirectory(vector_matmul4_asm_test)
 add_subdirectory(pw_unit_test_demo)
diff --git a/vector_matmul4_asm_test/CMakeLists.txt b/vector_matmul4_asm_test/CMakeLists.txt
new file mode 100644
index 0000000..7b797a3
--- /dev/null
+++ b/vector_matmul4_asm_test/CMakeLists.txt
@@ -0,0 +1,7 @@
+vec_cc_binary(
+  NAME
+    vector_matmul4_asm_test
+  SRCS
+    main.cpp
+    vector_matmul4_asm.S
+)
diff --git a/vector_matmul4_asm_test/main.cpp b/vector_matmul4_asm_test/main.cpp
new file mode 100644
index 0000000..d04624f
--- /dev/null
+++ b/vector_matmul4_asm_test/main.cpp
@@ -0,0 +1,120 @@
+#include <cinttypes>
+#include <climits>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <random>
+
+//#define PRINT_INPUTS_AND_OUTPUTS (1)
+
+#ifndef PRINT_INPUTS_AND_OUTPUTS
+#define PRINT_INPUTS_AND_OUTPUTS (0)
+#endif
+
+extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs, const int8_t *rhs_t, std::size_t count);
+
+extern "C" int main(void) {
+  int8_t lhs[16*37];
+  int8_t rhs_t[16*37];
+  int32_t result[sizeof(lhs)+16];
+  int32_t golden[sizeof(lhs)+16];
+  std::default_random_engine generator;
+  // int8_t is not a valid IntType for std::uniform_int_distribution, so
+  // draw plain ints and narrow on assignment.
+  std::uniform_int_distribution<int> distribution(INT8_MIN, INT8_MAX);
+
+  for (std::size_t i = 0; i < sizeof(lhs); i++) {
+    lhs[i] = static_cast<int8_t>(distribution(generator));
+    rhs_t[i] = static_cast<int8_t>(distribution(generator));
+  }
+
+  // One extra guard matrix to ensure the assembly doesn't go past the end
+  for (std::size_t i = sizeof(lhs); i < sizeof(lhs)+16; i++) {
+    result[i] = 1337;
+    golden[i] = 1337;
+  }
+
+  vector_matmul4_asm(result, lhs, rhs_t, sizeof(lhs)/16);
+
+  for (std::size_t b = 0; b < sizeof(lhs)/16; b++) {
+    for (int j = 0; j < 4; j++) {
+      for (int i = 0; i < 4; i++) {
+        int32_t acc = 0;
+        for (int k = 0; k < 4; k++) {
+          acc += lhs[k+j*4+b*16] * rhs_t[k+i*4+b*16];
+        }
+        golden[i+j*4+b*16] = acc;
+      }
+    }
+  }
+
+  std::size_t errors = 0;
+  // The guard matrices are included in the comparison, so an overrun by
+  // the assembly shows up as an error.
+  for (std::size_t b = 0; b < sizeof(result)/sizeof(int32_t)/16; b++) {
+    for (int j = 0; j < 4; j++) {
+      for (int i = 0; i < 4; i++) {
+        errors += result[i+4*j+b*16] == golden[i+4*j+b*16] ? 0 : 1;
+      }
+    }
+  }
+
+  if (PRINT_INPUTS_AND_OUTPUTS) {
+    printf("lhs:\n");
+    for (std::size_t b = 0; b < sizeof(lhs)/sizeof(int8_t)/16; b++) {
+      printf("b = %zu:\n", b);
+      for (int j = 0; j < 4; j++) {
+        printf(" ");
+        for (int i = 0; i < 4; i++) {
+          printf("%5d,", (int)lhs[i+4*j+b*16]);
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("rhs_t:\n");
+    for (std::size_t b = 0; b < sizeof(rhs_t)/sizeof(int8_t)/16; b++) {
+      printf("b = %zu:\n", b);
+      for (int j = 0; j < 4; j++) {
+        printf(" ");
+        for (int i = 0; i < 4; i++) {
+          printf("%5d,", (int)rhs_t[i+4*j+b*16]);
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("golden:\n");
+    for (std::size_t b = 0; b < sizeof(golden)/sizeof(int32_t)/16; b++) {
+      printf("b = %zu:\n", b);
+      for (int j = 0; j < 4; j++) {
+        printf(" ");
+        for (int i = 0; i < 4; i++) {
+          printf("%7d,", (int)golden[i+4*j+b*16]);
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("\nresults:\n");
+    for (std::size_t b = 0; b < sizeof(result)/sizeof(int32_t)/16; b++) {
+      printf("b = %zu:\n", b);
+      for (int j = 0; j < 4; j++) {
+        printf(" ");
+        for (int i = 0; i < 4; i++) {
+          bool same = result[i+4*j+b*16] == golden[i+4*j+b*16];
+          printf("%7d%c", (int)result[i+4*j+b*16], same ? ',' : '/');
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("\n%zu errors\n", errors);
+  }
+
+  return (errors > INT_MAX) ? INT_MAX : (int)errors;
+}
diff --git a/vector_matmul4_asm_test/vector_matmul4_asm.S b/vector_matmul4_asm_test/vector_matmul4_asm.S
new file mode 100644
index 0000000..2290179
--- /dev/null
+++ b/vector_matmul4_asm_test/vector_matmul4_asm.S
@@ -0,0 +1,95 @@
+        .text
+        .globl vector_matmul4_asm
+        .p2align 2
+        .type vector_matmul4_asm,@function
+
+// extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs,
+//                                    const int8_t *rhs_t, size_t count);
+//
+// This function takes in two arrays of 4x4 int8 matrices and multiplies
+// them to produce an array of 4x4 int32 matrices. The rhs is assumed to
+// be pre-transposed.
+//
+// It will work as-is with VLEN from 64 to 512. Larger is possible, but
+// requires a different arrangement of gather instructions due to the
+// number of lanes in a register being larger than the uint8_t LUT element
+// size. Smaller is not possible because we need to be able to fit at
+// least one matrix in a two-register group.
+//
+// This concept may be extended to 8x8 matrices and will require a minimum
+// VLEN of 256, but will still be subject to the 512 upper limit without
+// working around the uint8_t LUT element limit.
+
+// Register use notes:
+//
+// a0      int32_t (*out)[count][4][4]
+// a1      int8_t  (*lhs)[count][4][4]
+// a2      int8_t  (*rhs_t)[count][4][4]
+// a3      count
+//
+// t0      VLEN/4 (number of bytes in two registers)
+// t1      VLEN   (number of bytes in eight registers)
+// t2      avl (number of stripmining lanes for the current loop iteration)
+// t3      dump (unused except as destination for vsetvli to set vl=vlmax)
+//
+// v0- v3  row/col splat LUT
+// v4- v5  lhs
+// v6- v7  rhs_t
+// v8- v9  lhs[k+0]
+// v10-v11 rhs_t[k+0]
+// v12-v13 lhs[k+1]
+// v14-v15 rhs_t[k+1]
+// v16-v19 mul[k+0]
+// v20-v23 mul[k+1]
+// v24-v31 accumulator
+
+vector_matmul4_asm:
+        beq zero, a3, 1f
+        slli a3, a3, 4                  // a3 = total elements (16 per matrix)
+        // Fabricate the row/column splat LUT for vrgather
+        vsetvli t0, zero, e8, m2, ta, ma
+        slli t1, t0, 2                  // t1 = VLEN bytes (int32 output step)
+        vid.v v0
+        vid.v v2
+        vsll.vi v0, v0, 6               // v0 = (i & 3) << 6 (e8 truncation)...
+        vsrl.vi v0, v0, 4               // ...= (i & 3) * 4: rhs_t row offset
+        vsrl.vi v2, v2, 4               // v2 = i / 16...
+        vsll.vi v2, v2, 4               // ...= i & ~15: base of lane's block
+        vadd.vv v2, v0, v2
+        vadd.vx v2, v2, t0              // v2-v3 index the rhs_t source half
+        vid.v v0
+        vsrl.vi v0, v0, 2               // v0 = i / 4...
+        vsll.vi v0, v0, 2               // ...= i & ~3: splat lhs[j][k] per row
+2:      // Stripmine loop: t2 elements (t2/16 matrices) per iteration
+        vsetvli t2, a3, e8, m2, ta, ma
+        vle8.v v4, (a1)                 //2
+        vle8.v v6, (a2)                 //2
+        vsetvli t3, zero, e8, m4, ta, ma
+        vrgather.vv v8, v4, v0          //4
+        vslide1down.vx v4, v4, zero     //4
+        vrgather.vv v12, v4, v0         //4
+        vslide1down.vx v4, v4, zero     //4
+        vsetvli t3, zero, e8, m2, ta, ma
+        vwmul.vv v16, v8, v10           //2
+        vwmul.vv v20, v12, v14          //2
+        vsetvli t3, zero, e16, m4, ta, ma
+        vwadd.vv v24, v16, v20          //4
+        vsetvli t3, zero, e8, m4, ta, ma
+        vrgather.vv v8, v4, v0          //4
+        vslide1down.vx v4, v4, zero     //4
+        vrgather.vv v12, v4, v0         //4
+        vsetvli t3, zero, e8, m2, ta, ma
+        vwmul.vv v16, v8, v10           //2
+        vwmul.vv v20, v12, v14          //2
+        vsetvli t3, zero, e16, m4, ta, ma
+        vwadd.wv v24, v24, v16          //4
+        vwadd.wv v24, v24, v20          //4
+        vsetvli zero, t2, e32, m8, ta, ma
+        vse32.v v24, (a0)               //8
+        sub a3, a3, t2
+        add a1, a1, t0
+        add a2, a2, t0
+        add a0, a0, t1
+        bne zero, a3, 2b
+1:
+        ret
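As a cross-check of the LUT fabrication above, the gather indices it
builds in v0-v3 can be modeled in plain C++. This is a sketch only:
build_matmul4_lut is a hypothetical name, and t0 stands for vlmax at
e8/m2, i.e. VLEN/4 bytes.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Models the vrgather index group v0-v3: lanes 0..t0-1 (v0-v1) splat
    // one lhs element across the four lanes of its output row; lanes
    // t0..2*t0-1 (v2-v3) pick the matching rhs_t row (a column of rhs)
    // from the upper half of the m4 source group.
    std::vector<uint8_t> build_matmul4_lut(std::size_t t0) {
      std::vector<uint8_t> lut(2 * t0);
      for (std::size_t i = 0; i < t0; i++) {
        // vid; srl 2; sll 2  ->  i & ~3
        lut[i] = static_cast<uint8_t>(i & ~std::size_t{3});
        // (vid; sll 6; srl 4) + (vid; srl 4; sll 4) + t0
        //   ->  (i % 4) * 4 + (i & ~15) + t0
        lut[t0 + i] =
            static_cast<uint8_t>((i & 3) * 4 + (i & ~std::size_t{15}) + t0);
      }
      return lut;
    }

Each vslide1down.vx then shifts the whole source group down by one
element, so the same LUT selects the k+1 operands on the next pair of
gathers; four gathers (three slides) cover k = 0..3 of the dot product.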