blob: e552d52186917934b2fb999553938f1f50702dba [file] [log] [blame]
/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#ifndef TFLM_OPT_CONV_UTIL_H_
18#define TFLM_OPT_CONV_UTIL_H_
19
#include <cassert>
#include <cstdint>
#include <memory>

#include "crt/kelvin.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tflm/opt/util.h"
28
29namespace kelvin::opt {
// 16-entry lane permutation table. Entry 4 * r + c equals
// 4 * c + {0, 2, 1, 3}[r], i.e. a 4x4 transpose whose rows are emitted in
// the order 0, 2, 1, 3.
// `inline` (C++17) gives the array one definition shared by every
// translation unit that includes this header, instead of a separate
// internal-linkage copy per includer.
/* clang-format off */
inline constexpr int swizzle[16] = {
    0, 4,  8, 12,
    2, 6, 10, 14,
    1, 5,  9, 13,
    3, 7, 11, 15,
};
/* clang-format on */
38
// Dimension indices into filter shapes: height, width, input channels.
constexpr int kFilterHeightIndex{1};
constexpr int kFilterWidthIndex{2};
constexpr int kFilterInputChannelIndex{3};
// Dimension index of the channel axis in input and output shapes.
constexpr int kInputChannelIndex{3};
constexpr int kOutputChannelIndex{3};
44
// Vector register-name aliases. These expand to bare register tokens
// (v0, v8, ...), so they must stay preprocessor macros.

// Input operand register.
#define INA0 v0

// Filter operand registers: eight consecutive registers v8..v15.
#define FLTA0 v8
#define FLTA1 v9
#define FLTA2 v10
#define FLTA3 v11
#define FLTA4 v12
#define FLTA5 v13
#define FLTA6 v14
#define FLTA7 v15

// Accumulator base register (ACC and ACC0 are the same register) and the
// output register.
#define ACC v48
#define ACC0 v48
#define OUT0 v56
57
// Repacks an int8 filter tensor into a blocked layout.
//   N    : number of output channels (outer-most dimension, `zo`).
//   H, W : filter height and width (`ky`, `kx`).
//   M    : number of input channels (inner-most dimension, `zi`).
//
// Input layout:  [zo][ky][kx][zi]                          (N, H, W, M)
// Output layout: [zo_hi][ky][kx][zi_hi][zo_lo][zi_lo]
//   with zo_hi = zo / 8, zo_lo = zo % 8, zi_hi = zi / 4, zi_lo = zi % 4.
//
// When N is a multiple of 8 the output occupies exactly N * H * W * M bytes.
// Otherwise the last group of 8 output channels is only partially written,
// and `output` must still cover ceil(N / 8) such groups.
// M must be a multiple of 4. A remainder would alias into the next
// (kx, ky, zo_hi) slot, and the final write would run past the buffer.
template <int N>
inline void Filter_N_H_W_M(const int8_t* input, int8_t* output, int H, int W,
                           int M) {
  assert(N >= 4 && M >= 4);
  assert(M % 4 == 0);
  const int m_hi = M / 4;  // number of 4-wide input-channel groups
  for (int zo = 0; zo < N; ++zo) {
    const int zo_hi = zo >> 3;  // zo / 8
    const int zo_lo = zo & 7;   // zo % 8
    for (int ky = 0; ky < H; ++ky) {
      for (int kx = 0; kx < W; ++kx) {
        // Start of the input row in[zo][ky][kx][*] and of the output tile
        // out[zo_hi][ky][kx][*][*][*] (each tile is m_hi * 8 * 4 bytes).
        const int8_t* in_row = input + ((zo * H + ky) * W + kx) * M;
        int8_t* out_tile = output + ((zo_hi * H + ky) * W + kx) * m_hi * 32;
        for (int zi = 0; zi < M; ++zi) {
          const int zi_hi = zi >> 2;  // zi / 4
          const int zi_lo = zi & 3;   // zi % 4
          out_tile[zi_hi * 32 + zo_lo * 4 + zi_lo] = in_row[zi];
        }
      }
    }
  }
}
82
// Reorders int32 values into a swizzled lane order and replicates each
// reordered pair 4 times for strip-mining.
//
// For every group g of 8 consecutive inputs x[0..7] = input[8g .. 8g+7],
// 32 outputs are written starting at output[32g]:
//   [x0 x4]x4, [x2 x6]x4, [x1 x5]x4, [x3 x7]x4
// If `negate` is true, the first 4 * N outputs are then negated with
// two's-complement wrap-around (INT32_MIN stays INT32_MIN).
//
// `output` must have room for 4 * N values and must not alias `input`.
// If N is not a multiple of 8, the trailing N % 8 inputs are ignored.
// Outputs from index (N / 8) * 32 up to 4 * N are then left unwritten, but
// are still negated when `negate` is set.
inline void Swizzle(const int32_t* input, int32_t* output, int N,
                    bool negate = false) {
  // Offset inside each 32-value block of the replicated pair
  // (x[lane], x[lane + 4]) for lane = 0, 1, 2, 3.
  constexpr int kPairOffset[4] = {0, 16, 8, 24};
  for (int i = 0; i < N / 8; ++i) {
    const int32_t* group = input + i * 8;
    int32_t* block = output + i * 32;
    for (int lane = 0; lane < 4; ++lane) {
      int32_t* dst = block + kPairOffset[lane];
      for (int rep = 0; rep < 4; ++rep) {
        dst[2 * rep + 0] = group[lane];
        dst[2 * rep + 1] = group[lane + 4];
      }
    }
  }
  if (negate) {
    for (int i = 0; i < N * 4; ++i) {
      // Negate through uint32_t: `-x` is signed-overflow UB for INT32_MIN.
      output[i] = static_cast<int32_t>(static_cast<uint32_t>(0) -
                                       static_cast<uint32_t>(output[i]));
    }
  }
}
110
// Runs the strip-mined output pipeline (without bias addition) in place on
// the register group `result`. The steps, in order:
//   1. vdmulh_w_rn_vv_m with the multiplier register `mult`
//   2. vsha_w_r_vv_m with the shift register `shft`
//   3. add the scalar `output_offset`
//   4. clamp to [output_min, output_max]
// The clamp uses the macro's own `output_min`/`output_max` arguments rather
// than any caller-scope variable, so whatever bounds the caller passes are
// the ones applied.
// Arguments are deliberately not parenthesized: the register operands are
// presumably passed through as bare register tokens to the intrinsics.
#define INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(result, mult, shft, output_min, \
                                              output_max, output_offset)      \
  {                                                                           \
    vdmulh_w_rn_vv_m(result, result, mult);                                   \
    vsha_w_r_vv_m(result, result, shft);                                      \
    vadd_w_vx_m(result, result, output_offset);                               \
    vmax_w_vx_m(result, result, output_min);                                  \
    vmin_w_vx_m(result, result, output_max);                                  \
  }
122
// Runs the output pipeline on the int32 accumulators read with vcget(v48)
// and leaves the narrowed results in v48 and v52. The steps, in order:
//   1. vcget(v48), then load `bias`, `mult` and `shft` into `bias_reg`,
//      `mult_reg` and `shift_reg` with vld_w_x_m (presumably addresses of
//      per-channel int32 data; verify against callers).
//   2. add the bias, vdmulh_w_r_vv_m by the multiplier, vsha_w_r_vv_m by
//      the shift, then add the scalar `output_offset`.
//   3. clamp BOTH halves (v48 and v52) to [output_min, output_max].
//   4. vsraqs_b_vx(..., 0) on v48 and v52.
// The clamp runs after the offset, i.e. in the output domain, matching
// INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE. Clamping the raw accumulators
// before requantization would compare values in the wrong domain.
// Clobbers [v48-v55], bias_reg, mult_reg and shift_reg.
#define INT32_TO_INT8_OUTPUT_PIPELINE(bias, mult, shft, output_min,        \
                                      output_max, output_offset, bias_reg, \
                                      mult_reg, shift_reg)                 \
  {                                                                        \
    vcget(v48);                                                            \
    vld_w_x_m(bias_reg, bias);                                             \
    vld_w_x_m(mult_reg, mult);                                             \
    vld_w_x_m(shift_reg, shft);                                            \
    vadd_w_vv_m(v48, v48, bias_reg);                                       \
    vadd_w_vv_m(v52, v52, bias_reg);                                       \
    vdmulh_w_r_vv_m(v48, v48, mult_reg);                                   \
    vdmulh_w_r_vv_m(v52, v52, mult_reg);                                   \
    vsha_w_r_vv_m(v48, v48, shift_reg);                                    \
    vsha_w_r_vv_m(v52, v52, shift_reg);                                    \
    vadd_w_vx_m(v48, v48, output_offset);                                  \
    vadd_w_vx_m(v52, v52, output_offset);                                  \
    vmax_w_vx_m(v48, v48, output_min);                                     \
    vmax_w_vx_m(v52, v52, output_min);                                     \
    vmin_w_vx_m(v48, v48, output_max);                                     \
    vmin_w_vx_m(v52, v52, output_max);                                     \
    vsraqs_b_vx(v48, v48, 0);                                              \
    vsraqs_b_vx(v52, v52, 0);                                              \
  }
146} // namespace kelvin::opt
147
148#endif // TFLM_OPT_CONV_UTIL_H_