# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

.text
.globl vector_matmul4_asm
.p2align 2
.type vector_matmul4_asm,@function

// extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs,
//                                    const int8_t *rhs_t, size_t count);
//
// This function takes two arrays of 4x4 int8 matrices and multiplies
// corresponding pairs to produce an array of 4x4 int32 matrices. The rhs
// matrices are assumed to be pre-transposed.
//
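// As a reference for what the kernel is expected to compute, here is a scalar
// C sketch of the per-pair computation (the helper name is illustrative, not
// part of this API):
//
//   void matmul4_ref(int32_t *out, const int8_t *lhs, const int8_t *rhs_t,
//                    size_t count) {
//     for (size_t n = 0; n < count; ++n)     // one 4x4 matrix pair per step
//       for (int i = 0; i < 4; ++i)          // output row
//         for (int j = 0; j < 4; ++j) {      // output column
//           int32_t acc = 0;
//           for (int k = 0; k < 4; ++k)      // rhs_t rows are rhs columns
//             acc += (int32_t)lhs[n * 16 + i * 4 + k] *
//                    (int32_t)rhs_t[n * 16 + j * 4 + k];
//           out[n * 16 + i * 4 + j] = acc;
//         }
//   }
//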
// It will work as-is with VLEN from 64 to 512. Larger is possible, but
// requires a different arrangement of gather instructions, because the gather
// source group would then hold more lanes than a uint8_t LUT element can
// index. Smaller is not possible because we need to be able to fit at least
// one matrix in a two-register group.
//
// This concept may be extended to 8x8 matrices, which would require a minimum
// VLEN of 256 but would still be subject to the 512 upper limit unless the
// uint8_t LUT index limit is worked around.
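//
// Concretely: the vrgather below indexes a four-register source group of int8
// elements, i.e. 4*VLEN/8 lanes. At VLEN=512 that is 256 lanes, so every index
// still fits in a uint8_t; at VLEN=1024 the group would hold 512 lanes, beyond
// what an 8-bit index can address. At the low end, a two-register group at
// VLEN=64 is 2*64/8 = 16 bytes, exactly one 4x4 int8 matrix.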

// Register use notes:
//
// a0 int32_t (*out)[count][4][4]
// a1 int8_t (*lhs)[count][4][4]
// a2 int8_t (*rhs_t)[count][4][4]
// a3 count (scaled to the remaining input bytes, count*16, inside the loop)
//
// t0 VLEN/4 (number of bytes in two registers)
// t1 VLEN (number of bytes in eight registers)
// t2 vl (number of stripmining lanes granted for the current loop iteration)
// t3 dump (unused except as destination for vsetvli to set vl=vlmax)
//
// v0- v3 row/col splat LUT
// v4- v5 lhs
// v6- v7 rhs_t
// v8- v9 lhs[k+0]
// v10-v11 rhs_t[k+0]
// v12-v13 lhs[k+1]
// v14-v15 rhs_t[k+1]
// v16-v19 mul[k+0]
// v20-v23 mul[k+1]
// v24-v31 accumulator

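// For lane p of a two-register (VLEN/4-byte) group, with matrix = p/16,
// row = (p%16)/4, and col = p%4, the gather LUT built below maps:
//
//   v0-v1[p] = (p/4)*4                   -> lhs[matrix][row][0] (v4-v5 half)
//   v2-v3[p] = t0 + (p%4)*4 + (p/16)*16  -> rhs_t[matrix][col][0] (v6-v7 half)
//
// Each vslide1down in the loop shifts the v4-v7 source group down one byte, so
// re-gathering with the same LUT picks up column k+1. The //N suffixes in the
// body track the LMUL set by the preceding vsetvli.
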
vector_matmul4_asm:
beq zero, a3, 1f                  // nothing to do for count == 0
slli a3, a3, 4                    // a3 = count * 16 input bytes
// Fabricate the row/column splat LUT for vrgather
vsetvli t0, zero, e8, m2, ta, ma  // t0 = vlmax = VLEN/4 bytes
slli t1, t0, 2                    // t1 = 4*t0 output bytes per full stripmine
vid.v v0                          // v0[p] = p
vid.v v2                          // v2[p] = p
vsll.vi v0, v0, 6                 // keep p%4 in the top two bits
vsrl.vi v0, v0, 4                 // v0[p] = (p%4)*4
vsrl.vi v2, v2, 4                 // v2[p] = p/16
vsll.vi v2, v2, 4                 // v2[p] = (p/16)*16 (matrix base)
vadd.vv v2, v0, v2                // v2[p] = (p/16)*16 + (p%4)*4
vadd.vx v2, v2, t0                // ... + t0: point into the rhs_t half
vid.v v0                          // v0[p] = p (rebuild for the lhs half)
vsrl.vi v0, v0, 2                 // v0[p] = p/4
vsll.vi v0, v0, 2                 // v0[p] = (p/4)*4 (lhs row base)
2:
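// Take a stripmine of up to VLEN/4 input bytes and load the lhs and rhs_t
// matrices.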
vsetvli t2, a3, e8, m2, ta, ma
vle8.v v4, (a1) //2
vle8.v v6, (a2) //2
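// Splat lhs rows (v8-v9) and rhs_t rows (v10-v11) for k=0, then slide the
// source group down one byte and splat again for k=1 (v12-v15).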
vsetvli t3, zero, e8, m4, ta, ma
vrgather.vv v8, v4, v0 //4
vslide1down.vx v4, v4, zero //4
vrgather.vv v12, v4, v0 //4
vslide1down.vx v4, v4, zero //4
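// Widening multiply: int8 x int8 -> int16 partial products for k=0 and k=1.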
vsetvli t3, zero, e8, m2, ta, ma
vwmul.vv v16, v8, v10 //2
vwmul.vv v20, v12, v14 //2
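// Widening add: int16 + int16 -> int32, initializing the accumulator with the
// k=0 and k=1 products.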
vsetvli t3, zero, e16, m4, ta, ma
vwadd.vv v24, v16, v20 //4
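// Repeat the splats for k=2 and k=3; the source group has already slid down
// two bytes.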
vsetvli t3, zero, e8, m4, ta, ma
vrgather.vv v8, v4, v0 //4
vslide1down.vx v4, v4, zero //4
vrgather.vv v12, v4, v0 //4
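// Partial products for k=2 and k=3.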
vsetvli t3, zero, e8, m2, ta, ma
vwmul.vv v16, v8, v10 //2
vwmul.vv v20, v12, v14 //2
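// Widen and accumulate the k=2 and k=3 products into the int32 sums.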
vsetvli t3, zero, e16, m4, ta, ma
vwadd.wv v24, v24, v16 //4
vwadd.wv v24, v24, v20 //4
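// Store one int32 result per input byte lane processed this iteration.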
vsetvli zero, t2, e32, m8, ta, ma
vse32.v v24, (a0) //8
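// Advance the inputs by a full stripmine (t0 bytes) and the output by t1 =
// 4*t0 bytes; a shorter final iteration overshoots harmlessly because the
// loop then exits. Retire the t2 bytes actually processed.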
sub a3, a3, t2
add a1, a1, t0
add a2, a2, t0
add a0, a0, t1
bne zero, a3, 2b
1:
ret