# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
.text
.globl vector_matmul4_asm
.p2align 2
.type vector_matmul4_asm,@function
// extern "C" void vector_matmul4_asm(int32_t *out, const int8_t *lhs,
// const int8_t *rhs_t, size_t count);
//
// This function takes in two arrays of 4x4 int8 matrices and multiplies them to
// produce an array of 4x4 int32 matrices. The rhs is assumed to be pre-
// transposed.
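//
// For reference, a scalar C sketch of the intended computation (illustrative
// only, not part of this file; the name vector_matmul4_ref is hypothetical and
// types come from <stdint.h>/<stddef.h>):
//
//   void vector_matmul4_ref(int32_t *out, const int8_t *lhs,
//                           const int8_t *rhs_t, size_t count) {
//     for (size_t n = 0; n < count; ++n) {    // one 4x4 matrix per iteration
//       for (int r = 0; r < 4; ++r) {
//         for (int c = 0; c < 4; ++c) {
//           int32_t acc = 0;
//           for (int k = 0; k < 4; ++k) {
//             // rhs is pre-transposed: column c of rhs is row c of rhs_t.
//             acc += (int32_t)lhs[n * 16 + r * 4 + k] *
//                    (int32_t)rhs_t[n * 16 + c * 4 + k];
//           }
//           out[n * 16 + r * 4 + c] = acc;
//         }
//       }
//     }
//   }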
//
// It will work as-is with VLEN from 64 to 512. Larger is possible, but requires
// a different arrangement of gather instructions because the number of lanes in
// the gather source group would exceed what a uint8_t LUT index can address.
// Smaller is not possible because we need to be able to fit at least one matrix
// (16 bytes) in a two-register group.
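// (With VLEN=512 the e8/LMUL=4 gather source group has 4 * 512 / 8 = 256 lanes,
// exactly the range a uint8_t index can address; with VLEN=64 a two-register
// group is 2 * 64 / 8 = 16 bytes, exactly one matrix.)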
//
// This concept may be extended to 8x8 matrices, which would require a minimum
// VLEN of 256 but would still be subject to the 512 upper limit without working
// around the uint8_t LUT index limit.
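// (An 8x8 int8 matrix is 64 bytes, so fitting one matrix in a two-register
// group requires 2 * VLEN / 8 >= 64, i.e. VLEN >= 256.)
//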
// Register use notes:
//
// a0 int32_t (*out)[count][4][4]
// a1 const int8_t (*lhs)[count][4][4]
// a2 const int8_t (*rhs_t)[count][4][4]
// a3 count
//
// t0 VLEN/4 (number of bytes in two registers)
// t1 VLEN (number of bytes in eight registers)
// t2 vl (number of stripmined lanes processed in the current loop iteration)
// t3 dump (unused except as destination for vsetvli to set vl=vlmax)
//
// v0- v3 row/col splat LUT
// v4- v5 lhs
// v6- v7 rhs_t
// v8- v9 lhs[k+0]
// v10-v11 rhs_t[k+0]
// v12-v13 lhs[k+1]
// v14-v15 rhs_t[k+1]
// v16-v19 mul[k+0]
// v20-v23 mul[k+1]
// v24-v31 accumulator
vector_matmul4_asm:
beq zero, a3, 1f // nothing to do if count == 0
slli a3, a3, 4 // a3 = count * 16 = total int8 elements per input array
// Fabricate the row/column splat LUT for vrgather
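// The LUT is built here with SEW=8/LMUL=2 (two groups, v0-v1 and v2-v3) but is
// applied below with LMUL=4, so v0-v1 indexes the lhs half (v4-v5) of the
// source group and v2-v3, biased by t0, indexes the rhs_t half (v6-v7). For
// destination lane i of each half the index values are:
//   lhs:   i & ~3                        (k=0 element of output lane i's row)
//   rhs_t: (i & ~15) + (i & 3) * 4 + t0  (k=0 element of output lane i's column)
// Each vslide1down of the source group in the loop advances the gathered
// element from k to k+1.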
vsetvli t0, zero, e8, m2, ta, ma
slli t1, t0, 2
vid.v v0
vid.v v2
vsll.vi v0, v0, 6
vsrl.vi v0, v0, 4
vsrl.vi v2, v2, 4
vsll.vi v2, v2, 4
vadd.vv v2, v0, v2
vadd.vx v2, v2, t0
vid.v v0
vsrl.vi v0, v0, 2
vsll.vi v0, v0, 2
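// Stripmining loop: each pass loads t2 bytes of lhs and rhs_t, expands them
// into row/column splats with vrgather (two k steps per gather/multiply round),
// widening-multiplies to int16, accumulates all four k terms into int32 in
// v24-v31, and stores t2 int32 results.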
2:
vsetvli t2, a3, e8, m2, ta, ma
vle8.v v4, (a1) //2
vle8.v v6, (a2) //2
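// Gather the k=0 row/column splats into v8-v11, then slide the four-register
// source group down one element so the same LUT yields the k=1 elements in
// v12-v15; a second slide prepares the group for k=2.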
vsetvli t3, zero, e8, m4, ta, ma
vrgather.vv v8, v4, v0 //4
vslide1down.vx v4, v4, zero //4
vrgather.vv v12, v4, v0 //4
vslide1down.vx v4, v4, zero //4
vsetvli t3, zero, e8, m2, ta, ma
vwmul.vv v16, v8, v10 //2
vwmul.vv v20, v12, v14 //2
vsetvli t3, zero, e16, m4, ta, ma
vwadd.vv v24, v16, v20 //4
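// v24-v31 now holds the k=0 + k=1 partial sums widened to int32. The source
// group has already been slid twice, so the same gathers below pick out the
// k=2 and k=3 elements.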
vsetvli t3, zero, e8, m4, ta, ma
vrgather.vv v8, v4, v0 //4
vslide1down.vx v4, v4, zero //4
vrgather.vv v12, v4, v0 //4
vsetvli t3, zero, e8, m2, ta, ma
vwmul.vv v16, v8, v10 //2
vwmul.vv v20, v12, v14 //2
vsetvli t3, zero, e16, m4, ta, ma
vwadd.wv v24, v24, v16 //4
vwadd.wv v24, v24, v20 //4
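// The accumulator now holds the complete four-term int32 dot products; switch
// to e32/m8 to store t2 results.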
vsetvli zero, t2, e32, m8, ta, ma
vse32.v v24, (a0) //8
sub a3, a3, t2 // int8 elements remaining
add a1, a1, t0 // advance lhs by t0 bytes
add a2, a2, t0 // advance rhs_t by t0 bytes
add a0, a0, t1 // advance out by t1 bytes (t0 int32 elements)
bne zero, a3, 2b
1:
ret