Initial commit of vector_matmul4_asm_test

This adds a test that uses a hand-written optimized 4x4 matrix
multiplication function.

Change-Id: I7ed0a66f49ed8b15ea175607a9c8221b9d98cf25
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3dbdb92..ac3e2a9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,6 +36,7 @@
 add_subdirectory(vector_vadd_vsub_tests)
 add_subdirectory(vector_executive)
 add_subdirectory(vector_vset_tests)
+add_subdirectory(vector_matmul4_asm_test)
 
 add_subdirectory(pw_unit_test_demo)
 
diff --git a/vector_matmul4_asm_test/CMakeLists.txt b/vector_matmul4_asm_test/CMakeLists.txt
new file mode 100644
index 0000000..7b797a3
--- /dev/null
+++ b/vector_matmul4_asm_test/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Test binary: C++ harness (main.cpp) driving the hand-written RVV assembly.
+vec_cc_binary(
+    NAME
+      vector_matmul4_asm_test
+    SRCS
+      main.cpp
+      vector_matmul4_asm.S
+)
diff --git a/vector_matmul4_asm_test/main.cpp b/vector_matmul4_asm_test/main.cpp
new file mode 100644
index 0000000..d04624f
--- /dev/null
+++ b/vector_matmul4_asm_test/main.cpp
@@ -0,0 +1,125 @@
+#include <cinttypes>
+#include <climits>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <random>
+
+//#define PRINT_INPUTS_AND_OUTPUTS (1)
+
+#ifndef PRINT_INPUTS_AND_OUTPUTS
+#define PRINT_INPUTS_AND_OUTPUTS (0)
+#endif
+
+// Multiplies `count` pairs of 4x4 int8 matrices, producing `count` 4x4 int32
+// matrices. rhs_t must be pre-transposed. Implemented in vector_matmul4_asm.S.
+extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs, const int8_t *rhs_t, std::size_t count);
+
+extern "C" int main(void) {
+  // 37 batches of 4x4 int8 matrices; sizeof(lhs) equals the element count.
+  int8_t lhs[16*37];
+  int8_t rhs_t[16*37];
+  // Results plus one extra 4x4 guard matrix to detect overruns.
+  int32_t result[sizeof(lhs)+16];
+  int32_t golden[sizeof(lhs)+16];
+  std::default_random_engine generator;
+  // uniform_int_distribution<int8_t> is undefined behavior (the standard
+  // requires IntType to be at least short), so draw ints and narrow on store.
+  std::uniform_int_distribution<int> distribution(INT8_MIN, INT8_MAX);
+
+  for (std::size_t i = 0; i < sizeof(lhs); i++) {
+    lhs[i] = static_cast<int8_t>(distribution(generator));
+    rhs_t[i] = static_cast<int8_t>(distribution(generator));
+  }
+
+  // One extra guard matrix to ensure the assembly doesn't go past the end
+  for (std::size_t i = sizeof(lhs); i < sizeof(lhs)+16; i++) {
+    result[i] = 1337;
+    golden[i] = 1337;
+  }
+
+  vector_matmul4_asm(result, lhs, rhs_t, sizeof(lhs)/16);
+
+  // Scalar reference: golden[b][j][i] = sum_k lhs[b][j][k] * rhs_t[b][i][k].
+  for (std::size_t b = 0; b < sizeof(lhs)/16; b++) {
+    for (int j = 0; j < 4; j++) {
+      for (int i = 0; i < 4; i++) {
+        int32_t acc = 0;
+        for (int k = 0; k < 4; k++) {
+          acc += lhs[k+j*4+b*16] * rhs_t[k+i*4+b*16];
+        }
+        golden[i+j*4+b*16] = acc;
+      }
+    }
+  }
+
+  // Compare every element, including the guard matrix (hence sizeof(result)).
+  std::size_t errors = 0;
+  for (std::size_t b = 0; b < sizeof(result)/sizeof(int32_t)/16; b++) {
+    for (int j = 0; j < 4; j++) {
+      for (int i = 0; i < 4; i++) {
+        errors += result[i+4*j+b*16] == golden[i+4*j+b*16]? 0 : 1;
+      }
+    }
+  }
+
+  if (PRINT_INPUTS_AND_OUTPUTS) {
+    printf("lhs:\n");
+    for (std::size_t b = 0; b < sizeof(lhs)/sizeof(int8_t)/16; b++) {
+      printf("b = %zu:\n",b);
+      for (int j = 0; j < 4; j++) {
+        printf("    ");
+        for (int i = 0; i < 4; i++) {
+          printf("%5d,", (int)lhs[i+4*j+b*16]);
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("rhs_t:\n");
+    for (std::size_t b = 0; b < sizeof(rhs_t)/sizeof(int8_t)/16; b++) {
+      printf("b = %zu:\n",b);
+      for (int j = 0; j < 4; j++) {
+        printf("    ");
+        for (int i = 0; i < 4; i++) {
+          printf("%5d,", (int)rhs_t[i+4*j+b*16]);
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("golden:\n");
+    for (std::size_t b = 0; b < sizeof(golden)/sizeof(int32_t)/16; b++) {
+      printf("b = %zu:\n",b);
+      for (int j = 0; j < 4; j++) {
+        printf("    ");
+        for (int i = 0; i < 4; i++) {
+          printf("%7d,", (int)golden[i+4*j+b*16]);
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("\nresults:\n");
+    for (std::size_t b = 0; b < sizeof(result)/sizeof(int32_t)/16; b++) {
+      printf("b = %zu:\n",b);
+      for (int j = 0; j < 4; j++) {
+        printf("    ");
+        for (int i = 0; i < 4; i++) {
+          bool same = result[i+4*j+b*16] == golden[i+4*j+b*16];
+          printf("%7d%c", (int)result[i+4*j+b*16], same? ',' : '/');
+        }
+        printf("\n");
+      }
+      printf("\n");
+    }
+
+    printf("\n%zu errors\n", errors);
+  }
+
+  // Exit status is the mismatch count (clamped to INT_MAX), so 0 means pass.
+  return (errors > INT_MAX)? INT_MAX : (int)errors;
+}
diff --git a/vector_matmul4_asm_test/vector_matmul4_asm.S b/vector_matmul4_asm_test/vector_matmul4_asm.S
new file mode 100644
index 0000000..2290179
--- /dev/null
+++ b/vector_matmul4_asm_test/vector_matmul4_asm.S
@@ -0,0 +1,95 @@
+        .text
+        .globl vector_matmul4_asm
+        .p2align 2
+        .type vector_matmul4_asm,@function
+
+// extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs,
+//                                    const int8_t *rhs_t, size_t count);
+//
+// This function takes in two arrays of 4x4 int8 matrices and multiplies them to
+// produce an array of 4x4 int32 matrices. The rhs is assumed to be pre-
+// transposed.
+//
+// It will work as-is with VLEN from 64 to 512. Larger is possible, but requires
+// a different arrangement of gather instructions due to the number of lanes in
+// a register being larger than the LUT uint8_t element size. Smaller is not
+// possible because we need to be able to fit at least one matrix in a two-
+// register group.
+//
+// This concept may be extended to 8x8 matrices and will require a minimum VLEN
+// of 256, but will still be subjected to the 512 upper limit without working
+// around the uint8_t LUT element limit.
+
+// Register use notes (types match the C prototype above):
+//
+//    a0   int32_t (*out)[count][4][4]
+//    a1   int8_t (*lhs)[count][4][4]
+//    a2   int8_t (*rhs_t)[count][4][4]
+//    a3   count
+//
+//    t0   VLEN/4 (number of bytes in two registers)
+//    t1   VLEN   (number of bytes in eight registers)
+//    t2   avl    (number of stripmining lanes for the current loop iteration)
+//    t3   dump   (unused except as destination for vsetvli to set vl=vlmax)
+//
+//  v0- v3 row/col splat LUT
+//  v4- v5 lhs
+//  v6- v7 rhs_t
+//  v8- v9 lhs[k+0]
+// v10-v11 rhs_t[k+0]
+// v12-v13 lhs[k+1]
+// v14-v15 rhs_t[k+1]
+// v16-v19 mul[k+0]
+// v20-v23 mul[k+1]
+// v24-v31 accumulator
+
+vector_matmul4_asm:
+        beq            zero,   a3,   1f   // nothing to do when count == 0
+        slli             a3,   a3,    4   // a3 = count*16 = total int8 elements per input
+        // Fabricate the row/column splat LUT for vrgather
+        vsetvli          t0, zero,   e8, m2, ta, ma   // t0 = e8/m2 vlmax (VLEN/4)
+        slli             t1,   t0,    2   // t1 = 4*t0 = output bytes per full iteration
+        vid.v            v0
+        vid.v            v2
+        vsll.vi          v0,   v0,    6   // e8 shifts wrap mod 256, so this pair
+        vsrl.vi          v0,   v0,    4   //   yields v0[i] = (i%4)*4
+        vsrl.vi          v2,   v2,    4
+        vsll.vi          v2,   v2,    4   // v2[i] = (i/16)*16
+        vadd.vv          v2,   v0,   v2   // v2[i] = (i%4)*4 + (i/16)*16
+        vadd.vx          v2,   v2,   t0   // ...offset by t0 into the rhs_t half of the m4 group
+        vid.v            v0
+        vsrl.vi          v0,   v0,    2
+        vsll.vi          v0,   v0,    2   // v0[i] = (i/4)*4
+2:
+        vsetvli          t2,   a3,   e8, m2, ta, ma   // stripmine: t2 elements this pass
+        vle8.v           v4, (a1)                       //2
+        vle8.v           v6, (a2)                       //2
+        vsetvli          t3, zero,   e8, m4, ta, ma
+        vrgather.vv      v8,   v4,   v0                 //4
+        vslide1down.vx   v4,   v4, zero                 //4
+        vrgather.vv     v12,   v4,   v0                 //4
+        vslide1down.vx   v4,   v4, zero                 //4
+        vsetvli          t3, zero,   e8, m2, ta, ma
+        vwmul.vv        v16,   v8,  v10                 //2
+        vwmul.vv        v20,  v12,  v14                 //2
+        vsetvli          t3, zero,  e16, m4, ta, ma
+        vwadd.vv        v24,  v16,  v20                 //4
+        vsetvli          t3, zero,   e8, m4, ta, ma
+        vrgather.vv      v8,   v4,   v0                 //4
+        vslide1down.vx   v4,   v4, zero                 //4
+        vrgather.vv     v12,   v4,   v0                 //4
+        vsetvli          t3, zero,   e8, m2, ta, ma
+        vwmul.vv        v16,   v8,  v10                 //2
+        vwmul.vv        v20,  v12,  v14                 //2
+        vsetvli          t3, zero,  e16, m4, ta, ma
+        vwadd.wv        v24,  v24,  v16                 //4
+        vwadd.wv        v24,  v24,  v20                 //4
+        vsetvli        zero,   t2,  e32, m8, ta, ma
+        vse32.v         v24, (a0)                       //8
+        sub              a3,   a3,   t2   // t2 input elements consumed
+        add              a1,   a1,   t0   // advance int8 inputs by one group
+        add              a2,   a2,   t0
+        add              a0,   a0,   t1   // int32 outputs advance 4x as far
+        bne            zero,   a3,   2b
+1:
+        ret