blob: d50981d7e8894eff66ea479f7ed4c45d2531bf47 [file] [log] [blame]
#include <riscv_vector.h>
#include <stdint.h>
// Problem dimensions handed to MatMul by main():
// kLhsRows -> lhs_rows, kInner -> inner, kRhsCols -> rhs_cols.
constexpr size_t kLhsRows = 16;
constexpr size_t kRhsCols = 16;
constexpr size_t kInner = 24;

// Operand and result buffers. Each one is pinned to the .data section and
// aligned to 16 bytes.
// NOTE(review): presumably placed in .data so an external harness can
// preload/inspect them -- confirm against the test flow.
int8_t lhs_input[kLhsRows * kInner]
    __attribute__((section(".data"), aligned(16)));
int8_t rhs_input[kInner * kRhsCols]
    __attribute__((section(".data"), aligned(16)));
int32_t result_output[kLhsRows * kRhsCols]
    __attribute__((section(".data"), aligned(16)));
// Computes result = lhs * rhs for int8 operands with int32 accumulation,
// using RISC-V Vector (RVV) inline assembly.
//
// Layouts (as indexed below):
//   lhs    : lhs_rows x inner, row major     -> lhs[r * inner + k]
//   rhs    : inner x rhs_cols, column major  -> rhs[c * inner + k]
//   result : lhs_rows x rhs_cols, row major  -> result[r * rhs_cols + c]
//
// Vector register usage:
//   v0        zero seed for vredsum (only element 0 is read)
//   v8-v11    e32/m4 accumulator for the current output element
//   v12-v13   e16/m2 widened int8*int8 products
//   v14, v15  e8/m1 chunks of the lhs row / rhs column
//
// NOTE: v0 and v8-v11 carry state across separate asm statements. The clobber
// lists tell the compiler these registers are written, but nothing prevents
// compiler-generated vector code from being scheduled between statements, so
// this routine relies on the surrounding scalar C not being vectorized.
void MatMul(size_t lhs_rows, size_t inner, size_t rhs_cols,
            const int8_t* lhs, const int8_t* rhs, int32_t* result) {
  // VLEN in bytes. With SEW=32/LMUL=4, SEW=16/LMUL=2 and SEW=8/LMUL=1 the
  // maximum vector length is VLEN/8 elements in every case, i.e. vlenb.
  const size_t vlenb = __riscv_vlenb();

  // Create the zero seed register for vredsum once, up front.
  __asm__ volatile(
      "vsetvli zero, %[vlmax], e32, m4, ta, ma\n\t"
      "vmv.v.i v0, 0"
      :
      : [vlmax] "r"(vlenb)
      : "v0", "v1", "v2", "v3");

  for (size_t r = 0; r < lhs_rows; r++) {
    const int8_t* lhs_data = lhs + (r * inner);
    int32_t* result_row = result + (r * rhs_cols);

    for (size_t c = 0; c < rhs_cols; c++) {
      const int8_t* rhs_data = rhs + (c * inner);

      // Reset all vlenb lanes of the accumulator. vsetvli and the dependent
      // instruction live in one statement so nothing can be scheduled between
      // them.
      __asm__ volatile(
          "vsetvli zero, %[vlmax], e32, m4, ta, ma\n\t"
          "vmv.v.i v8, 0"
          :
          : [vlmax] "r"(vlenb)
          : "v8", "v9", "v10", "v11");

      // Inner dot product loop, processing up to vlenb int8 pairs per pass.
      size_t k = 0;
      while (k < inner) {
        const size_t remaining = inner - k;
        const size_t vl = remaining < vlenb ? remaining : vlenb;

        // Load both chunks, widen-multiply int8 -> int16, then widen-add the
        // products into the int32 accumulator.
        //
        // The accumulate step must be tail-UNDISTURBED (tu): when vl < vlenb
        // the accumulator lanes [vl, vlenb) still hold partial sums (or the
        // zeros from the reset) that the final reduction reads. Under
        // tail-agnostic (ta) the hardware is allowed to overwrite them with
        // all 1s, which would corrupt the result.
        __asm__ volatile(
            "vsetvli zero, %[vl], e8, m1, ta, ma\n\t"
            "vle8.v v14, (%[a])\n\t"
            "vle8.v v15, (%[b])\n\t"
            "vwmul.vv v12, v14, v15\n\t"
            "vsetvli zero, %[vl], e16, m2, tu, ma\n\t"
            "vwadd.wv v8, v8, v12"
            :
            : [vl] "r"(vl), [a] "r"(lhs_data + k), [b] "r"(rhs_data + k)
            : "memory", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15");

        k += vl;
      }

      // Sum all vlenb accumulator lanes into v8[0] (seeded with v0[0] == 0)
      // and store that single int32 to result[r * rhs_cols + c]. The "memory"
      // clobber tells the compiler the asm writes through the pointer.
      __asm__ volatile(
          "vsetvli zero, %[vlmax], e32, m4, ta, ma\n\t"
          "vredsum.vs v8, v8, v0\n\t"
          "vsetivli zero, 1, e32, m1, ta, ma\n\t"
          "vse32.v v8, (%[dst])"
          :
          : [vlmax] "r"(vlenb), [dst] "r"(result_row + c)
          : "memory", "v8");
    }
  }
}
// Entry point: runs one kLhsRows x kInner by kInner x kRhsCols multiply over
// the statically allocated buffers, writing into result_output.
int main() {
  const int8_t* const lhs = lhs_input;
  const int8_t* const rhs = rhs_input;
  int32_t* const out = result_output;

  MatMul(kLhsRows, kInner, kRhsCols, lhs, rhs, out);
  return 0;
}