blob: aa7b6bb3d3fc4e97fd15b52a2e0acaa822d18ed8 [file] [log] [blame]
Naveen Doddabe4ab972024-04-17 17:47:46 +00001/*
Lun Dong3b8d3cb2024-05-07 01:50:35 -07002 * Copyright 2024 Google LLC
Naveen Doddabe4ab972024-04-17 17:47:46 +00003 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Lun Dong3b8d3cb2024-05-07 01:50:35 -070017// Depthwise convolution based on Kelvin ops
18// Data types: input: s8, filter: s8, bias s32
Naveen Doddabe4ab972024-04-17 17:47:46 +000019
Naveen Doddabe4ab972024-04-17 17:47:46 +000020#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
Lun Dong3b8d3cb2024-05-07 01:50:35 -070021#include "tflm/opt/conv_util.h"
Naveen Doddabe4ab972024-04-17 17:47:46 +000022
23namespace kelvin::opt {
Lun Dong3b8d3cb2024-05-07 01:50:35 -070024namespace {
Alex Van Dammecd3d0e32024-05-10 15:27:06 -070025
// Reorders a vector to match the register-lane pattern produced after
// double-widening, scattering each consecutive group of four inputs to the
// four quarter-offsets of the output:
//   output[i]          = input[4*i + 0]
//   output[i + N/2]    = input[4*i + 1]
//   output[i + N/4]    = input[4*i + 2]
//   output[i + 3*N/4]  = input[4*i + 3]
// N must be a multiple of 4.
//
// Fix: the offsets were previously hard-coded as 16/8/24, which is only
// correct for N == 32 (the assert allowed any multiple of 4, so other sizes
// would have written out of bounds).  They are now derived from N, which is
// behavior-identical for the existing N == 32 callers.  The non-standard
// VLA-reference casts (a GNU extension) are replaced with plain indexing.
void VectorSwizzle(const int32_t* input, int32_t* output, int N) {
  assert(N >= 4 && N % 4 == 0);
  const int quarter = N / 4;
  const int32_t* p_in = input;
  for (int i = 0; i < quarter; ++i) {
    output[i] = *p_in++;                // offset 0
    output[i + 2 * quarter] = *p_in++;  // offset N/2
    output[i + quarter] = *p_in++;      // offset N/4
    output[i + 3 * quarter] = *p_in++;  // offset 3N/4
  }
}
Alex Van Damme088841b2024-06-03 16:16:54 -070044
45// special case of input depth = 32n, filter shape of 3x3, strides of 1
46void DepthwiseConvS83x3D32_Stride1(
47 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
48 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
49 const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
50 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
51 const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
52 int8_t* output_data
53) {
54 const int stride_width = params.stride_width;
55 const int stride_height = params.stride_height;
56 const int pad_width = params.padding_values.width;
57 const int pad_height = params.padding_values.height;
58 const int32_t input_offset = params.input_offset;
59 const int32_t output_offset = params.output_offset;
60 const int32_t output_activation_min = params.quantized_activation_min;
61 const int32_t output_activation_max = params.quantized_activation_max;
62 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
63 const int input_height = input_shape.Dims(1);
64 const int input_width = input_shape.Dims(2);
65 const int input_depth = input_shape.Dims(3);
66 const int output_height = output_shape.Dims(1);
67 const int output_width = output_shape.Dims(2);
68 const int output_depth = output_shape.Dims(3);
69 int32_t swizzled_bias_data[32];
70 int32_t swizzled_shift_multi[32];
71 int32_t swizzled_output_multi[32];
72
73 for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
74 const int output_channel = in_channel;
75 int8_t* p_output = output_data + output_channel;
76 VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
77 VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
78 VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
79
80 vld_w_x_m(v52, swizzled_bias_data);
81 vld_w_x_m(v56, swizzled_output_multi);
82 vld_w_x_m(v60, swizzled_shift_multi);
83 vrsub_w_vx_m(v60, v60, 0);
84
85 union {
86 vdwconv_u8_t dwconv;
87 uint32_t raw;
88 } cmds;
89 cmds.raw = 0;
90 cmds.dwconv.sdata1 = true;
91 cmds.dwconv.sbias1 = input_offset;
92 cmds.dwconv.sdata2 = true;
93 cmds.dwconv.sbias2 = 0;
94 cmds.dwconv.mode = 0;
95 cmds.dwconv.sparsity = 0;
96 cmds.dwconv.regbase = 0;
97
Alex Van Dammed22d7d62024-06-07 10:53:00 -070098#define FLT_0_0 v0
99#define FLT_0_1 v3
100#define FLT_0_2 v6
101#define FLT_1_0 v1
102#define FLT_1_1 v4
103#define FLT_1_2 v7
104#define FLT_2_0 v2
105#define FLT_2_1 v5
106#define FLT_2_2 v8
107
108#define INPUT_0_0 v9
109#define INPUT_0_1 v12
110#define INPUT_0_2 v15
111#define INPUT_0_3 v18
112#define INPUT_0_4 v21
113#define INPUT_0_5 v24
114#define INPUT_1_0 v10
115#define INPUT_1_1 v13
116#define INPUT_1_2 v16
117#define INPUT_1_3 v19
118#define INPUT_1_4 v22
119#define INPUT_1_5 v25
120#define INPUT_2_0 v11
121#define INPUT_2_1 v14
122#define INPUT_2_2 v17
123#define INPUT_2_3 v20
124#define INPUT_2_4 v23
125#define INPUT_2_5 v26
126
127#define INPUT_PTRS(_strides) \
128 const int in_y_origin = (out_y * stride_height) - pad_height; \
129 const int in_x_origin = (out_x * stride_width) - pad_width; \
130 const int8_t* p_in_0 = input_data + \
131 (batch * input_height * input_width * input_depth) + \
132 (in_y_origin * input_width * input_depth) + \
133 ((in_x_origin + _strides) * input_depth) + \
134 in_channel; \
135 const int8_t* p_in_1 = p_in_0 + (input_width * input_depth); \
136 const int8_t* p_in_2 = p_in_1 + (input_width * input_depth); \
137 (void)p_in_2;
138
139#define COMPUTE() \
140 adwinit_v(v48, v48); \
141 adwconv_vxv(v48, INPUT_0_0, cmds, FLT_0_0); \
142 adwconv_vxv(v48, INPUT_0_1, cmds, FLT_0_1); \
143 vdwconv_vxv(v48, INPUT_0_2, cmds, FLT_0_2);
144
Alex Van Damme088841b2024-06-03 16:16:54 -0700145 // Don't reorder me, otherwise data will not be
146 // loaded in the correct order
147 // (we can reuse the p_flt* due to the `p` vld variant).
148 const int8_t* p_flt0 = filter_data + in_channel;
149 const int8_t* p_flt1 = p_flt0 + input_depth;
150 const int32_t stride = 2 * input_depth;
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700151 vld_b_sp_xx(FLT_0_0, p_flt0, stride);
152 vld_b_sp_xx(FLT_0_1, p_flt1, stride);
153 vld_b_sp_xx(FLT_0_2, p_flt0, stride);
154 vld_b_sp_xx(FLT_1_0, p_flt1, stride);
155 vld_b_sp_xx(FLT_1_1, p_flt0, stride);
156 vld_b_sp_xx(FLT_1_2, p_flt1, stride);
157 vld_b_sp_xx(FLT_2_0, p_flt0, stride);
158 vld_b_sp_xx(FLT_2_1, p_flt1, stride);
159 vld_b_sp_xx(FLT_2_2, p_flt0, stride);
Alex Van Damme088841b2024-06-03 16:16:54 -0700160
161 for (int batch = 0; batch < batches; ++batch) {
162 int out_y = 0;
163 for (; out_y < pad_height; ++out_y) {
164 int out_x = 0;
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700165 vdup_b_x(INPUT_0_0, -input_offset);
166 vdup_b_x(INPUT_0_1, -input_offset);
167 vdup_b_x(INPUT_0_2, -input_offset);
Alex Van Damme088841b2024-06-03 16:16:54 -0700168 for (; out_x < pad_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700169 INPUT_PTRS(1);
Alex Van Damme088841b2024-06-03 16:16:54 -0700170 vmv_v_m(v48, v52);
171
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700172 vdup_b_x(INPUT_1_0, -input_offset);
173 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
174 vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth);
175 vdup_b_x(INPUT_2_0, -input_offset);
176 vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth);
177 vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700178
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700179 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700180 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
181 v48, v56, v60,
182 output_activation_min,
183 output_activation_max,
184 output_offset);
185 vsraqs_b_vx(v48, v48, 0);
186 vst_b_x(v48, p_output);
187
188 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700189 }
190 for (; out_x < output_width - pad_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700191 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700192 vmv_v_m(v48, v52);
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700193 vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth);
194 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
195 vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth);
196 vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth);
197 vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth);
198 vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700199
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700200 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700201 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
202 v48, v56, v60,
203 output_activation_min,
204 output_activation_max,
205 output_offset);
206 vsraqs_b_vx(v48, v48, 0);
207 vst_b_x(v48, p_output);
208 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700209 }
210 for (; out_x < output_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700211 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700212 vmv_v_m(v48, v52);
213
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700214 vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth);
215 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
216 vdup_b_x(INPUT_1_2, -input_offset);
217 vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth);
218 vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth);
219 vdup_b_x(INPUT_2_2, -input_offset);
Alex Van Damme088841b2024-06-03 16:16:54 -0700220
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700221 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700222 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
223 v48, v56, v60,
224 output_activation_min,
225 output_activation_max,
226 output_offset);
227 vsraqs_b_vx(v48, v48, 0);
228 vst_b_x(v48, p_output);
229
230 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700231 }
232 }
233 for (; out_y < output_height - pad_height; ++out_y) {
Alex Van Damme088841b2024-06-03 16:16:54 -0700234 int out_x = 0;
Alex Van Damme088841b2024-06-03 16:16:54 -0700235 for (; out_x < pad_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700236 INPUT_PTRS(1);
Alex Van Damme088841b2024-06-03 16:16:54 -0700237 vmv_v_m(v48, v52);
238
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700239 vdup_b_x(INPUT_0_0, -input_offset);
240 vdup_b_x(INPUT_1_0, -input_offset);
241 vdup_b_x(INPUT_2_0, -input_offset);
242 vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth);
243 vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth);
244 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
245 vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth);
246 vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth);
247 vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700248
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700249 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700250 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
251 v48, v56, v60,
252 output_activation_min,
253 output_activation_max,
254 output_offset);
255 vsraqs_b_vx(v48, v48, 0);
256 vst_b_x(v48, p_output);
257
258 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700259 }
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700260 for (; out_x + 4 <= output_width - pad_width; out_x += 4) {
261 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700262 // Initialize accumulators w/ bias data.
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700263 vmv_v_m(v36, v52);
264 vmv_v_m(v40, v52);
Alex Van Damme088841b2024-06-03 16:16:54 -0700265 vmv_v_m(v44, v52);
266 vmv_v_m(v48, v52);
267
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700268 vld_b_sp_xx(INPUT_0_0, p_in_0, stride_width * input_depth);
269 vld_b_sp_xx(INPUT_1_0, p_in_1, stride_width * input_depth);
270 vld_b_sp_xx(INPUT_2_0, p_in_2, stride_width * input_depth);
271 vld_b_sp_xx(INPUT_0_1, p_in_0, stride_width * input_depth);
272 vld_b_sp_xx(INPUT_1_1, p_in_1, stride_width * input_depth);
273 vld_b_sp_xx(INPUT_2_1, p_in_2, stride_width * input_depth);
274 vld_b_sp_xx(INPUT_0_2, p_in_0, stride_width * input_depth);
275 vld_b_sp_xx(INPUT_1_2, p_in_1, stride_width * input_depth);
276 vld_b_sp_xx(INPUT_2_2, p_in_2, stride_width * input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700277
278 adwinit_v(v48, v48);
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700279 adwconv_vxv(v48, INPUT_0_0, cmds, FLT_0_0);
280 adwconv_vxv(v48, INPUT_0_1, cmds, FLT_0_1);
281 vdwconv_vxv(v48, INPUT_0_2, cmds, FLT_0_2);
282
283 vld_b_sp_xx(INPUT_0_3, p_in_0, stride_width * input_depth);
284 vld_b_sp_xx(INPUT_1_3, p_in_1, stride_width * input_depth);
285 vld_b_sp_xx(INPUT_2_3, p_in_2, stride_width * input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700286
287 adwinit_v(v44, v44);
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700288 adwconv_vxv(v44, INPUT_0_1, cmds, FLT_0_0);
289 adwconv_vxv(v44, INPUT_0_2, cmds, FLT_0_1);
290 vdwconv_vxv(v44, INPUT_0_3, cmds, FLT_0_2);
291
292 vld_b_sp_xx(INPUT_0_4, p_in_0, stride_width * input_depth);
293 vld_b_sp_xx(INPUT_1_4, p_in_1, stride_width * input_depth);
294 vld_b_sp_xx(INPUT_2_4, p_in_2, stride_width * input_depth);
295
296 adwinit_v(v40, v40);
297 adwconv_vxv(v40, INPUT_0_2, cmds, FLT_0_0);
298 adwconv_vxv(v40, INPUT_0_3, cmds, FLT_0_1);
299 vdwconv_vxv(v40, INPUT_0_4, cmds, FLT_0_2);
300
301 vld_b_sp_xx(INPUT_0_5, p_in_0, stride_width * input_depth);
302 vld_b_sp_xx(INPUT_1_5, p_in_1, stride_width * input_depth);
303 vld_b_sp_xx(INPUT_2_5, p_in_2, stride_width * input_depth);
304
305 adwinit_v(v36, v36);
306 adwconv_vxv(v36, INPUT_0_3, cmds, FLT_0_0);
307 adwconv_vxv(v36, INPUT_0_4, cmds, FLT_0_1);
308 vdwconv_vxv(v36, INPUT_0_5, cmds, FLT_0_2);
309
310 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
311 v48, v56, v60,
312 output_activation_min,
313 output_activation_max,
314 output_offset);
315 vsraqs_b_vx(v48, v48, 0);
316 vst_b_x(v48, p_output);
317 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700318
319 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
320 v44, v56, v60,
321 output_activation_min,
322 output_activation_max,
323 output_offset);
Alex Van Damme088841b2024-06-03 16:16:54 -0700324 vsraqs_b_vx(v44, v44, 0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700325 vst_b_x(v44, p_output);
326 p_output += output_depth;
327
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700328 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
329 v40, v56, v60,
330 output_activation_min,
331 output_activation_max,
332 output_offset);
333 vsraqs_b_vx(v40, v40, 0);
334 vst_b_x(v40, p_output);
335 p_output += output_depth;
336
337 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
338 v36, v56, v60,
339 output_activation_min,
340 output_activation_max,
341 output_offset);
342 vsraqs_b_vx(v36, v36, 0);
343 vst_b_x(v36, p_output);
344 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700345 }
346 for (; out_x < output_width - pad_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700347 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700348 vmv_v_m(v48, v52);
349
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700350 vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth);
351 vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth);
352 vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth);
353 vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth);
354 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
355 vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth);
356 vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth);
357 vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth);
358 vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700359
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700360 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700361 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
362 v48, v56, v60,
363 output_activation_min,
364 output_activation_max,
365 output_offset);
366 vsraqs_b_vx(v48, v48, 0);
367 vst_b_x(v48, p_output);
368 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700369 }
370 for (; out_x < output_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700371 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700372 vmv_v_m(v48, v52);
373
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700374 vdup_b_x(INPUT_0_2, -input_offset);
375 vdup_b_x(INPUT_1_2, -input_offset);
376 vdup_b_x(INPUT_2_2, -input_offset);
377 vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth);
378 vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth);
379 vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth);
380 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
381 vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth);
382 vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700383
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700384 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700385 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
386 v48, v56, v60,
387 output_activation_min,
388 output_activation_max,
389 output_offset);
390 vsraqs_b_vx(v48, v48, 0);
391 vst_b_x(v48, p_output);
392
393 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700394 }
395 }
396 for (; out_y < output_height; ++out_y) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700397 vdup_b_x(INPUT_2_0, -input_offset);
398 vdup_b_x(INPUT_2_1, -input_offset);
399 vdup_b_x(INPUT_2_2, -input_offset);
Alex Van Damme088841b2024-06-03 16:16:54 -0700400 int out_x = 0;
Alex Van Damme088841b2024-06-03 16:16:54 -0700401 for (; out_x < pad_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700402 INPUT_PTRS(1);
Alex Van Damme088841b2024-06-03 16:16:54 -0700403 vmv_v_m(v48, v52);
404
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700405 vdup_b_x(INPUT_0_0, -input_offset);
406 vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth);
407 vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth);
408 vdup_b_x(INPUT_1_0, -input_offset);
409 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
410 vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700411
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700412 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700413 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
414 v48, v56, v60,
415 output_activation_min,
416 output_activation_max,
417 output_offset);
418 vsraqs_b_vx(v48, v48, 0);
419 vst_b_x(v48, p_output);
420 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700421 }
422 for (; out_x < output_width - pad_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700423 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700424 vmv_v_m(v48, v52);
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700425 vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth);
426 vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth);
427 vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth);
428 vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth);
429 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
430 vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth);
Alex Van Damme088841b2024-06-03 16:16:54 -0700431
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700432 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700433 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
434 v48, v56, v60,
435 output_activation_min,
436 output_activation_max,
437 output_offset);
438 vsraqs_b_vx(v48, v48, 0);
439 vst_b_x(v48, p_output);
440 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700441 }
442 for (; out_x < output_width; ++out_x) {
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700443 INPUT_PTRS(0);
Alex Van Damme088841b2024-06-03 16:16:54 -0700444 vmv_v_m(v48, v52);
445
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700446 vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth);
447 vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth);
448 vdup_b_x(INPUT_0_2, -input_offset);
449 vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth);
450 vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth);
451 vdup_b_x(INPUT_1_2, -input_offset);
Alex Van Damme088841b2024-06-03 16:16:54 -0700452
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700453 COMPUTE();
Alex Van Damme088841b2024-06-03 16:16:54 -0700454 INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
455 v48, v56, v60,
456 output_activation_min,
457 output_activation_max,
458 output_offset);
459 vsraqs_b_vx(v48, v48, 0);
460 vst_b_x(v48, p_output);
461 p_output += output_depth;
Alex Van Damme088841b2024-06-03 16:16:54 -0700462 }
463 }
464 }
465 }
Alex Van Dammed22d7d62024-06-07 10:53:00 -0700466#undef FLT_0_0
467#undef FLT_0_1
468#undef FLT_0_2
469#undef FLT_1_0
470#undef FLT_1_1
471#undef FLT_1_2
472#undef FLT_2_0
473#undef FLT_2_1
474#undef FLT_2_2
475#undef COMPUTE
Alex Van Damme088841b2024-06-03 16:16:54 -0700476}
477
// special case of input depth = 32n, filter shape of 3x3
//
// General-stride variant: unlike the Stride1 routine above, every output
// pixel recomputes its window origin and checks all four padding edges, so
// it works for any stride/padding a 3x3 depthwise conv can be configured
// with (window extends 2 past the origin, see the +2 bounds tests below).
// Per 32-channel group, the 9 filter taps stay resident in v6-v14 and the
// swizzled bias / output multiplier / negated shift in v52 / v56 / v60.
// Padded taps are filled with -input_offset, which cancels the
// +input_offset (sbias1) the dwconv op adds to each data element, so
// padding contributes zero to the accumulation.
void DepthwiseConvS83x3D32(
    const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
    int8_t* output_data
) {
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  const int32_t input_offset = params.input_offset;
  const int32_t output_offset = params.output_offset;
  const int32_t output_activation_min = params.quantized_activation_min;
  const int32_t output_activation_max = params.quantized_activation_max;
  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int input_depth = input_shape.Dims(3);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  const int output_depth = output_shape.Dims(3);
  int32_t swizzled_bias_data[32];
  int32_t swizzled_shift_multi[32];
  int32_t swizzled_output_multi[32];

  // Process the depth dimension in groups of 32 channels.
  for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
    const int output_channel = in_channel;
    // Swizzle the per-channel quantization data into the lane order the
    // widening dwconv produces (see VectorSwizzle).
    VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
    VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
    VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);

    vld_w_x_m(v52, swizzled_bias_data);
    vld_w_x_m(v56, swizzled_output_multi);
    vld_w_x_m(v60, swizzled_shift_multi);
    // Negate the shifts once; they are applied as right shifts below.
    vrsub_w_vx_m(v60, v60, 0);

    // Command word for the dwconv ops: add input_offset (sbias1) to each
    // data element, leave the filter side unbiased.
    union {
      vdwconv_u8_t dwconv;
      uint32_t raw;
    } cmds;
    cmds.raw = 0;
    cmds.dwconv.sdata1 = true;
    cmds.dwconv.sbias1 = input_offset;
    cmds.dwconv.sdata2 = true;
    cmds.dwconv.sbias2 = 0;
    cmds.dwconv.mode = 0;
    cmds.dwconv.sparsity = 0;
    cmds.dwconv.regbase = 0;

    // Load the 3x3 filter taps into v6..v14 (ping-ponging p_flt0/p_flt1
    // with a 2-row post-increment).
    // Don't reorder me, otherwise data will not be
    // loaded in the correct order
    // (we can reuse the p_flt* due to the `p` vld variant).
    const int8_t* p_flt0 = filter_data + in_channel;
    const int8_t* p_flt1 = p_flt0 + input_depth;
    const int32_t stride = 2 * input_depth;
    vld_b_sp_xx(v6, p_flt0, stride);
    vld_b_sp_xx(v7, p_flt1, stride);
    vld_b_sp_xx(v8, p_flt0, stride);
    vld_b_sp_xx(v9, p_flt1, stride);
    vld_b_sp_xx(v10, p_flt0, stride);
    vld_b_sp_xx(v11, p_flt1, stride);
    vld_b_sp_xx(v12, p_flt0, stride);
    vld_b_sp_xx(v13, p_flt1, stride);
    vld_b_sp_xx(v14, p_flt0, stride);

    for (int batch = 0; batch < batches; ++batch) {
      const int8_t* p_output = output_data + (batch * output_width * output_height * output_depth) + output_channel;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int in_y_origin = (out_y * stride_height) - pad_height;
        const int y_offset = (output_depth * output_width * out_y);
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_width) - pad_width;

          // Initialize accumulators w/ bias data.
          vmv_v_m(v48, v52);

          // Edge tests for the 3x3 window anchored at
          // (in_y_origin, in_x_origin); +2 is the far edge of the window.
          bool top_pad = in_y_origin < 0;
          bool left_pad = in_x_origin < 0;
          bool bottom_pad = (in_y_origin + 2) >= input_height;
          bool right_pad = (in_x_origin + 2) >= input_width;
          bool padding_required = top_pad || left_pad || bottom_pad || right_pad;
          // Row pointers for the three window rows; may point out of
          // bounds when the corresponding row is padding (never loaded
          // through in that case).
          const int8_t* p_in_0 = input_data +
              (batch * input_height * input_width * input_depth) +
              (in_y_origin * input_width * input_depth) +
              (in_x_origin * input_depth) +
              in_channel;
          const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
          const int8_t* p_in_2 = p_in_1 + (input_width * input_depth);
          if (!padding_required) {
            // Fast path: load all 9 taps (v15..v23, row-major) directly.
            vld_b_sp_xx(v15, p_in_0, input_depth);
            vld_b_sp_xx(v16, p_in_0, input_depth);
            vld_b_sp_xx(v17, p_in_0, input_depth);
            vld_b_sp_xx(v18, p_in_1, input_depth);
            vld_b_sp_xx(v19, p_in_1, input_depth);
            vld_b_sp_xx(v20, p_in_1, input_depth);
            vld_b_sp_xx(v21, p_in_2, input_depth);
            vld_b_sp_xx(v22, p_in_2, input_depth);
            vld_b_sp_xx(v23, p_in_2, input_depth);
          } else {
            // Slow path: per-tap guards; padded taps get -input_offset.
            // Top row
            if (top_pad || left_pad) {
              vdup_b_x(v15, -input_offset);
            } else {
              vld_b_x(v15, p_in_0);
            }
            if (top_pad) {
              vdup_b_x(v16, -input_offset);
            } else {
              vld_b_x(v16, p_in_0 + input_depth);
            }
            if (top_pad || right_pad) {
              vdup_b_x(v17, -input_offset);
            } else {
              vld_b_x(v17, p_in_0 + (2 * input_depth));
            }
            // Middle row
            if (left_pad) {
              vdup_b_x(v18, -input_offset);
            } else {
              vld_b_x(v18, p_in_1);
            }
            // Center tap is always in bounds for a 3x3 window with
            // pad <= 1.
            vld_b_x(v19, p_in_1 + input_depth);
            if (right_pad) {
              vdup_b_x(v20, -input_offset);
            } else {
              vld_b_x(v20, p_in_1 + (2 * input_depth));
            }
            // Bottom row
            if (bottom_pad || left_pad) {
              vdup_b_x(v21, -input_offset);
            } else {
              vld_b_x(v21, p_in_2);
            }
            if (bottom_pad) {
              vdup_b_x(v22, -input_offset);
            } else {
              vld_b_x(v22, p_in_2 + input_depth);
            }
            if (bottom_pad || right_pad) {
              vdup_b_x(v23, -input_offset);
            } else {
              vld_b_x(v23, p_in_2 + (2 * input_depth));
            }
          }

          // Accumulate the three filter rows (v6/v9/v12 lead each
          // 3-register group) against the three input rows.
          adwinit_v(v48, v48);
          adwconv_vxv(v48, v15, cmds, v6);
          adwconv_vxv(v48, v18, cmds, v9);
          vdwconv_vxv(v48, v21, cmds, v12);

          // Requantize: multiply, shift, add zero-point, clamp to the
          // activation range, then narrow to int8 and store.
          vdmulh_w_rn_vv_m(v48, v48, v56);
          vsha_w_r_vv_m(v48, v48, v60);
          vadd_w_vx_m(v48, v48, output_offset);
          vmax_w_vx_m(v48, v48, output_activation_min);
          vmin_w_vx_m(v48, v48, output_activation_max);
          vsraqs_b_vx(v48, v48, 0);
          vst_b_x(v48, p_output + (out_x * output_depth) + y_offset);
        }
      }
    }
  }
}
Alex Van Dammeb1afda62024-05-09 16:48:40 -0700643
644// special case of input depth = 32n, filter shape of 5x5, stride == 1
645void DepthwiseConvS85x5D32_Stride1(
646 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
647 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
648 const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
649 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
650 const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
651 int8_t* output_data
652) {
653 const int stride_width = params.stride_width;
654 const int stride_height = params.stride_height;
655 const int pad_width = params.padding_values.width;
656 const int pad_height = params.padding_values.height;
657 const int32_t input_offset = params.input_offset;
658 const int32_t output_offset = params.output_offset;
659 const int32_t output_activation_min = params.quantized_activation_min;
660 const int32_t output_activation_max = params.quantized_activation_max;
661 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
662 const int input_height = input_shape.Dims(1);
663 const int input_width = input_shape.Dims(2);
664 const int input_depth = input_shape.Dims(3);
665 const int output_height = output_shape.Dims(1);
666 const int output_width = output_shape.Dims(2);
667 const int output_depth = output_shape.Dims(3);
668 int32_t swizzled_bias_data[32];
669 int32_t swizzled_shift_multi[32];
670 int32_t swizzled_output_multi[32];
671
672 for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
673 const int output_channel = in_channel;
674 VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
675 VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
676 VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
677
678 union {
679 vdwconv_u8_t dwconv;
680 uint32_t raw;
681 } cmds;
682 cmds.raw = 0;
683 cmds.dwconv.sdata1 = true;
684 cmds.dwconv.sbias1 = input_offset;
685 cmds.dwconv.sdata2 = true;
686 cmds.dwconv.sbias2 = 0;
687 cmds.dwconv.mode = 0;
688 cmds.dwconv.sparsity = 0;
689 cmds.dwconv.regbase = 0;
690
691 // Don't reorder me!
692 const int8_t* p_flt0 = filter_data + in_channel;
693 const int32_t stride = input_depth;
694 vld_b_sp_xx_m(v0, p_flt0, stride);
695 vld_b_sp_xx_m(v4, p_flt0, stride);
696 vld_b_sp_xx_m(v8, p_flt0, stride);
697 vld_b_sp_xx_m(v12, p_flt0, stride);
698 vld_b_sp_xx_m(v16, p_flt0, stride);
699 vld_b_sp_xx_m(v20, p_flt0, stride);
700 vld_b_sp_xx(v24, p_flt0, stride);
701
702 // Extra two registers to get our
703 // total usage to a multiple of 3 for dwconv.
704 vdup_b_x(v25, 0);
705 vdup_b_x(v26, 0);
706
707 for (int batch = 0; batch < batches; ++batch) {
708 const int8_t* p_output = output_data + (batch * output_height * output_width * output_depth) + output_channel;
709 for (int out_y = 0; out_y < output_height; ++out_y) {
710 const int y_offset = out_y * output_width * output_depth;
711 for (int out_x = 0; out_x < output_width; ++out_x) {
712 const int in_x_origin = (out_x * stride_width) - pad_width;
713 const int in_y_origin = (out_y * stride_height) - pad_height;
714
715 bool top_pad = in_y_origin < 0;
716 bool left_pad = in_x_origin < 0;
717 int top_pad_count = top_pad ? 0 - in_y_origin : 0;
718 int left_pad_count = left_pad ? 0 - in_x_origin : 0;
719 bool bottom_pad = (in_y_origin + 4) >= input_height;
720 bool right_pad = (in_x_origin + 4) >= input_width;
721 int bottom_pad_count = std::abs(bottom_pad ? (in_y_origin + 4) - input_height + 1: 0);
722 int right_pad_count = std::abs(right_pad ? (in_x_origin + 4) - input_width + 1 : 0);
723 bool padding_required = top_pad || left_pad || bottom_pad || right_pad;
724 assert(top_pad_count <= pad_height);
725 assert(bottom_pad_count <= pad_height);
726 assert(left_pad_count <= pad_width);
727 assert(right_pad_count <= pad_width);
728 assert(!(left_pad && right_pad));
729 const int8_t* p_in_0 = input_data +
730 (batch * input_height * input_width * input_depth) +
731 (in_y_origin * input_width * input_depth) +
732 (in_x_origin * input_depth) +
733 in_channel;
734 const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
735 const int8_t* p_in_2 = p_in_1 + (input_width * input_depth);
736 const int8_t* p_in_3 = p_in_2 + (input_width * input_depth);
737 const int8_t* p_in_4 = p_in_3 + (input_width * input_depth);
738 // Extra two registers to get our
739 // total usage to a multiple of 3 for dwconv.
740 vdup_b_x(v52, -input_offset);
741 vdup_b_x(v53, -input_offset);
742 if (!padding_required) {
743 vld_b_sp_xx(v27, p_in_0, input_depth);
744 vld_b_sp_xx_m(v28, p_in_0, input_depth);
745 vld_b_sp_xx_m(v32, p_in_1, input_depth);
746 vld_b_sp_xx(v36, p_in_1, input_depth);
747 vld_b_sp_xx(v37, p_in_2, input_depth);
748 vld_b_sp_xx(v38, p_in_2, input_depth);
749 vld_b_sp_xx(v39, p_in_2, input_depth);
750 vld_b_sp_xx(v40, p_in_2, input_depth);
751 vld_b_sp_xx(v41, p_in_2, input_depth);
752 vld_b_sp_xx(v42, p_in_3, input_depth);
753 vld_b_sp_xx(v43, p_in_3, input_depth);
754 vld_b_sp_xx(v44, p_in_3, input_depth);
755 vld_b_sp_xx(v45, p_in_3, input_depth);
756 vld_b_sp_xx(v46, p_in_3, input_depth);
757 vld_b_sp_xx(v47, p_in_4, input_depth);
758 vld_b_sp_xx_m(v48, p_in_4, input_depth);
759 } else {
760 // Top row
761 if (top_pad_count >= 1) {
762 vdup_b_x(v27, -input_offset);
763 vdup_b_x_m(v28, -input_offset);
764 } else {
765 switch (left_pad_count) {
766 case 2:
767 vdup_b_x(v28, -input_offset);
768 case 1:
769 vdup_b_x(v27, -input_offset);
770 }
771 switch (left_pad_count) {
772 case 0:
773 vld_b_x(v27, p_in_0);
774 case 1:
775 vld_b_x(v28, p_in_0 + input_depth);
776 }
777 vld_b_x(v29, p_in_0 + (2 * input_depth));
778 switch (right_pad_count) {
779 case 2:
780 vdup_b_x(v30, -input_offset);
781 case 1:
782 vdup_b_x(v31, -input_offset);
783 }
784 switch (right_pad_count) {
785 case 0:
786 vld_b_x(v31, p_in_0 + (4 * input_depth));
787 case 1:
788 vld_b_x(v30, p_in_0 + (3 * input_depth));
789 }
790 }
791
792 // 2nd row
793 if (top_pad_count == 2) {
794 vdup_b_x_m(v32, -input_offset);
795 vdup_b_x(v36, -input_offset);
796 } else {
797 switch (left_pad_count) {
798 case 2:
799 vdup_b_x(v33, -input_offset);
800 case 1:
801 vdup_b_x(v32, -input_offset);
802 }
803 switch (left_pad_count) {
804 case 0:
805 vld_b_x(v32, p_in_1);
806 case 1:
807 vld_b_x(v33, p_in_1 + input_depth);
808 }
809 vld_b_x(v34, p_in_1 + (2 * input_depth));
810 switch (right_pad_count) {
811 case 2:
812 vdup_b_x(v35, -input_offset);
813 case 1:
814 vdup_b_x(v36, -input_offset);
815 }
816 switch (right_pad_count) {
817 case 0:
818 vld_b_x(v36, p_in_1 + (4 * input_depth));
819 case 1:
820 vld_b_x(v35, p_in_1 + (3 * input_depth));
821 }
822 }
823
824 // 3rd row
825 switch (left_pad_count) {
826 case 2:
827 vdup_b_x(v38, -input_offset);
828 case 1:
829 vdup_b_x(v37, -input_offset);
830 }
831 switch (left_pad_count) {
832 case 0:
833 vld_b_x(v37, p_in_2);
834 case 1:
835 vld_b_x(v38, p_in_2 + input_depth);
836 }
837 vld_b_x(v39, p_in_2 + (2 * input_depth));
838 switch (right_pad_count) {
839 case 2:
840 vdup_b_x(v40, -input_offset);
841 case 1:
842 vdup_b_x(v41, -input_offset);
843 }
844 switch (right_pad_count) {
845 case 0:
846 vld_b_x(v41, p_in_2 + (4 * input_depth));
847 case 1:
848 vld_b_x(v40, p_in_2 + (3 * input_depth));
849 }
850
851 // 4th row
852 if (bottom_pad_count == 2) {
853 vdup_b_x(v42, -input_offset);
854 vdup_b_x(v43, -input_offset);
855 vdup_b_x(v44, -input_offset);
856 vdup_b_x(v45, -input_offset);
857 vdup_b_x(v46, -input_offset);
858 } else {
859 switch (left_pad_count) {
860 case 2:
861 vdup_b_x(v43, -input_offset);
862 case 1:
863 vdup_b_x(v42, -input_offset);
864 }
865 switch (left_pad_count) {
866 case 0:
867 vld_b_x(v42, p_in_3);
868 case 1:
869 vld_b_x(v43, p_in_3 + input_depth);
870 }
871 switch (right_pad_count) {
872 case 2:
873 vdup_b_x(v45, -input_offset);
874 case 1:
875 vdup_b_x(v46, -input_offset);
876 }
877 vld_b_x(v44, p_in_3 + (2 * input_depth));
878 switch (right_pad_count) {
879 case 0:
880 vld_b_x(v46, p_in_3 + (4 * input_depth));
881 case 1:
882 vld_b_x(v45, p_in_3 + (3 * input_depth));
883 }
884 }
885
886 // 5th row
887 if (bottom_pad_count >= 1) {
888 vdup_b_x(v47, -input_offset);
889 vdup_b_x(v48, -input_offset);
890 vdup_b_x(v49, -input_offset);
891 vdup_b_x(v50, -input_offset);
892 vdup_b_x(v51, -input_offset);
893 } else {
894 switch (left_pad_count) {
895 case 2:
896 vdup_b_x(v48, -input_offset);
897 case 1:
898 vdup_b_x(v47, -input_offset);
899 }
900 switch (left_pad_count) {
901 case 0:
902 vld_b_x(v47, p_in_4);
903 case 1:
904 vld_b_x(v48, p_in_4 + input_depth);
905 }
906 vld_b_x(v49, p_in_4 + (2 * input_depth));
907 switch (right_pad_count) {
908 case 2:
909 vdup_b_x(v50, -input_offset);
910 case 1:
911 vdup_b_x(v51, -input_offset);
912 }
913 switch (right_pad_count) {
914 case 0:
915 vld_b_x(v51, p_in_4 + (4 * input_depth));
916 case 1:
917 vld_b_x(v50, p_in_4 + (3 * input_depth));
918 }
919 }
920 }
921
922 vld_w_x_m(v60, swizzled_bias_data);
923 adwinit_v(v60, v60);
924 adwconv_vxv(v60, v27, cmds, v0);
925 adwconv_vxv(v60, v30, cmds, v3);
926 adwconv_vxv(v60, v33, cmds, v6);
927 adwconv_vxv(v60, v36, cmds, v9);
928 adwconv_vxv(v60, v39, cmds, v12);
929 adwconv_vxv(v60, v42, cmds, v15);
930 adwconv_vxv(v60, v45, cmds, v18);
931 adwconv_vxv(v60, v48, cmds, v21);
932 vdwconv_vxv(v60, v51, cmds, v24);
933
934 vld_w_x_m(v56, swizzled_output_multi);
935 vdmulh_w_rn_vv_m(v60, v60, v56);
936 vld_w_x_m(v56, swizzled_shift_multi);
937 vrsub_w_vx_m(v56, v56, 0);
938 vsha_w_r_vv_m(v60, v60, v56);
939 vadd_w_vx_m(v60, v60, output_offset);
940 vmax_w_vx_m(v60, v60, output_activation_min);
941 vmin_w_vx_m(v60, v60, output_activation_max);
942 vsraqs_b_vx(v60, v60, 0);
943 vst_b_x(v60, p_output + y_offset + (out_x * output_depth));
944 }
945 }
Alex Van Damme40a83002024-05-08 16:47:03 -0700946 }
947 }
Alex Van Dammeb1afda62024-05-09 16:48:40 -0700948}
949
// special case of input depth = 32n, filter shape of 5x5
//
// Quantized (s8 input/filter, s32 bias) depthwise convolution specialized for
// 5x5 filters with a channel count that is a multiple of 32. Each pass of the
// outer channel loop handles one 32-channel slice; all 25 filter taps for the
// slice are pinned in vector registers v6..v30 across the whole image.
//
// output_multiplier/output_shift: per-output-channel requantization params.
// bias_data: per-channel s32 bias, copied into the accumulators per pixel.
void DepthwiseConvS85x5D32(
    const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
    int8_t* output_data
) {
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  const int32_t input_offset = params.input_offset;
  const int32_t output_offset = params.output_offset;
  const int32_t output_activation_min = params.quantized_activation_min;
  const int32_t output_activation_max = params.quantized_activation_max;
  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int input_depth = input_shape.Dims(3);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  const int output_depth = output_shape.Dims(3);
  // Scratch copies of the per-channel quantization data, reordered to match
  // the lane layout produced by the double-widening vector ops (see
  // VectorSwizzle at the top of this file).
  int32_t swizzled_bias_data[32];
  int32_t swizzled_shift_multi[32];
  int32_t swizzled_output_multi[32];

  // Walk the channel dimension in slices of 32 lanes.
  for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
    const int output_channel = in_channel;
    VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
    VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
    VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);

    // v52..v55: bias, v56..v59: output multipliers, v60..v63: output shifts,
    // negated (0 - shift) so vsha_w_r_vv_m below shifts right.
    vld_w_x_m(v52, swizzled_bias_data);
    vld_w_x_m(v56, swizzled_output_multi);
    vld_w_x_m(v60, swizzled_shift_multi);
    vrsub_w_vx_m(v60, v60, 0);

    // Don't reorder me!
    // Load the 25 filter taps (5x5) of this 32-channel slice into v6..v30,
    // one register per tap; p_flt is post-incremented by input_depth on each
    // load, so the loads must issue in exactly this order.
    const int8_t* p_flt = filter_data + in_channel;
    vld_b_sp_xx(v6, p_flt, input_depth);
    vld_b_sp_xx(v7, p_flt, input_depth);
    vld_b_sp_xx_m(v8, p_flt, input_depth);
    vld_b_sp_xx_m(v12, p_flt, input_depth);
    vld_b_sp_xx_m(v16, p_flt, input_depth);
    vld_b_sp_xx_m(v20, p_flt, input_depth);
    vld_b_sp_xx_m(v24, p_flt, input_depth);
    vld_b_sp_xx(v28, p_flt, input_depth);
    vld_b_sp_xx(v29, p_flt, input_depth);
    vld_b_sp_xx(v30, p_flt, input_depth);


    for (int batch = 0; batch < batches; ++batch) {
      const int8_t* p_input = input_data + (batch * input_width * input_height * input_depth) + in_channel;
      const int8_t* p_output = output_data + (batch * output_width * output_height * output_depth) + output_channel;
      for (int out_y = 0; out_y < output_height; ++out_y) {
        const int out_y_offset = (out_y * output_width * output_depth);
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = (out_x * stride_width) - pad_width;
          const int in_y_origin = (out_y * stride_height) - pad_height;

          // Initialize accumulators w/ bias_data
          vmv_v_m(v48, v52);

          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            const int in_y = in_y_origin + filter_y;
            // Skip rows that fall in the padding region.
            if ((in_y < 0) || (in_y >= input_height)) {
              continue;
            }
            // Widen row `filter_y`'s five s8 taps into 16-bit register pairs
            // v31/v32 .. v39/v40 for the widening multiplies below.
            switch (filter_y) {
              case 0:
                vaddw_h_vx(v31, v6, 0);
                vaddw_h_vx(v33, v7, 0);
                vaddw_h_vx(v35, v8, 0);
                vaddw_h_vx(v37, v9, 0);
                vaddw_h_vx(v39, v10, 0);
                break;
              case 1:
                vaddw_h_vx(v31, v11, 0);
                vaddw_h_vx(v33, v12, 0);
                vaddw_h_vx(v35, v13, 0);
                vaddw_h_vx(v37, v14, 0);
                vaddw_h_vx(v39, v15, 0);
                break;
              case 2:
                vaddw_h_vx(v31, v16, 0);
                vaddw_h_vx(v33, v17, 0);
                vaddw_h_vx(v35, v18, 0);
                vaddw_h_vx(v37, v19, 0);
                vaddw_h_vx(v39, v20, 0);
                break;
              case 3:
                vaddw_h_vx(v31, v21, 0);
                vaddw_h_vx(v33, v22, 0);
                vaddw_h_vx(v35, v23, 0);
                vaddw_h_vx(v37, v24, 0);
                vaddw_h_vx(v39, v25, 0);
                break;
              case 4:
                vaddw_h_vx(v31, v26, 0);
                vaddw_h_vx(v33, v27, 0);
                vaddw_h_vx(v35, v28, 0);
                vaddw_h_vx(v37, v29, 0);
                vaddw_h_vx(v39, v30, 0);
                break;
            }
            const int in_y_offset = in_y * input_width * input_depth;
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              const int in_x = in_x_origin + filter_x;
              // Skip columns that fall in the padding region.
              if ((in_x < 0) || (in_x >= input_width)) {
                continue;
              }

              vld_b_x(v0, p_input + (in_x * input_depth) + in_y_offset);

              // Widen the 32 input bytes to 16-bit (v0/v1) and add the input
              // zero point before multiplying.
              vaddw_h_vx(v0, v0, 0);
              vadd_h_vx(v0, v0, static_cast<int16_t>(input_offset));
              vadd_h_vx(v1, v1,
                        static_cast<int16_t>(input_offset)); // v0 v1 input
              // Widening multiply by the tap for this column: results land in
              // 32-bit pairs v0/v1 and v2/v3, accumulated below.
              switch (filter_x) {
                case 0:
                  vmulw_w_vv(v2, v1, v32);
                  vmulw_w_vv(v0, v0, v31);
                  break;
                case 1:
                  vmulw_w_vv(v2, v1, v34);
                  vmulw_w_vv(v0, v0, v33);
                  break;
                case 2:
                  vmulw_w_vv(v2, v1, v36);
                  vmulw_w_vv(v0, v0, v35);
                  break;
                case 3:
                  vmulw_w_vv(v2, v1, v38);
                  vmulw_w_vv(v0, v0, v37);
                  break;
                case 4:
                  vmulw_w_vv(v2, v1, v40);
                  vmulw_w_vv(v0, v0, v39);
                  break;
              }
              vadd_w_vv_m(v48, v48, v0);
            }
          }

          // Requantize: rounding multiply-high by the per-channel multiplier,
          // arithmetic right shift (v60 holds negated shifts), add the output
          // zero point, clamp to the activation range, saturating-narrow the
          // 32-bit lanes to s8, and store 32 output bytes.
          vdmulh_w_rn_vv_m(v48, v48, v56);
          vsha_w_r_vv_m(v48, v48, v60);
          vadd_w_vx_m(v48, v48, output_offset);
          vmax_w_vx_m(v48, v48, output_activation_min);
          vmin_w_vx_m(v48, v48, output_activation_max);
          vsraqs_b_vx(v48, v48, 0);
          vst_b_x(v48, p_output + out_y_offset + (out_x * output_depth));
        }
      }
    }
  }
}
Alex Van Damme40a83002024-05-08 16:47:03 -07001110
Lun Dong3b8d3cb2024-05-07 01:50:35 -07001111// special case of input depth = 32n
1112void DepthwiseConvS8D32(
Naveen Doddabe4ab972024-04-17 17:47:46 +00001113 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
1114 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
1115 const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
1116 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
1117 const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
1118 int8_t* output_data
1119
1120) {
1121 const int stride_width = params.stride_width;
1122 const int stride_height = params.stride_height;
Naveen Doddabe4ab972024-04-17 17:47:46 +00001123 const int pad_width = params.padding_values.width;
1124 const int pad_height = params.padding_values.height;
1125 const int32_t input_offset = params.input_offset;
1126 const int32_t output_offset = params.output_offset;
1127 const int32_t output_activation_min = params.quantized_activation_min;
1128 const int32_t output_activation_max = params.quantized_activation_max;
1129 const int batches = MatchingDim(input_shape, 0, output_shape, 0);
1130 const int input_height = input_shape.Dims(1);
1131 const int input_width = input_shape.Dims(2);
1132 const int input_depth = input_shape.Dims(3);
1133 const int filter_height = filter_shape.Dims(1);
1134 const int filter_width = filter_shape.Dims(2);
1135 const int output_height = output_shape.Dims(1);
1136 const int output_width = output_shape.Dims(2);
Alex Van Damme40a83002024-05-08 16:47:03 -07001137 int32_t swizzled_bias_data[32];
1138 int32_t swizzled_shift_multi[32];
1139 int32_t swizzled_output_multi[32];
Naveen Doddabe4ab972024-04-17 17:47:46 +00001140
1141 for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
1142 const int output_channel = in_channel;
Alex Van Damme40a83002024-05-08 16:47:03 -07001143 VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
1144 VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
1145 VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001146
1147 vld_w_x_m(v20, swizzled_bias_data);
1148 vld_w_x_m(v24, swizzled_output_multi);
1149 vld_w_x_m(v28, swizzled_shift_multi);
1150 vrsub_w_vx_m(v28, v28, 0);
1151
1152 for (int batch = 0; batch < batches; ++batch) {
1153 for (int out_y = 0; out_y < output_height; ++out_y) {
1154 for (int out_x = 0; out_x < output_width; ++out_x) {
1155 const int in_x_origin = (out_x * stride_width) - pad_width;
1156 const int in_y_origin = (out_y * stride_height) - pad_height;
1157
Alex Van Damme40a83002024-05-08 16:47:03 -07001158 vdup_w_x_m(v48, 0);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001159 for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
1160 const int in_y = in_y_origin + filter_y;
1161 if ((in_y < 0) || (in_y >= input_height)) {
1162 continue;
1163 }
1164 for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1165 const int in_x = in_x_origin + filter_x;
1166 if ((in_x < 0) || (in_x >= input_width)) {
1167 continue;
1168 }
1169
1170 vld_b_x(v0, &input_data[tflite::Offset(input_shape, batch, in_y,
1171 in_x, in_channel)]); // xp
1172 vld_b_x(v4, &filter_data[tflite::Offset(filter_shape, 0, filter_y,
1173 filter_x, in_channel)]);
1174
1175 vaddw_h_vx(v0, v0, 0);
1176 vadd_h_vx(v0, v0, static_cast<int16_t>(input_offset));
1177 vadd_h_vx(v1, v1,
1178 static_cast<int16_t>(input_offset)); // v0 v1 input
1179
1180 vaddw_h_vx(v4, v4, static_cast<int16_t>(0));
1181 vmulw_w_vv(v8, v0, v4);
1182 vmulw_w_vv(v10, v1, v5);
1183
1184 vadd_w_vv_m(v48, v48, v8);
1185 }
1186 }
1187
1188 vadd_w_vv_m(v48, v48, v20); // add bias
Alex Van Damme40a83002024-05-08 16:47:03 -07001189 vdmulh_w_rn_vv_m(v48, v48, v24);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001190 vsha_w_r_vv_m(v48, v48, v28);
1191 vadd_w_vx_m(v48, v48, output_offset);
1192 vmax_w_vx_m(v48, v48, output_activation_min);
1193 vmin_w_vx_m(v48, v48, output_activation_max);
1194 vsraqs_b_vx(v48, v48, 0);
1195 vst_b_x(v48, &output_data[tflite::Offset(output_shape, batch, out_y,
1196 out_x, output_channel)]);
1197 }
1198 }
1199 }
1200 }
1201}
1202
Lun Dong3b8d3cb2024-05-07 01:50:35 -07001203// generic implementation based on Kelvin ops
1204void DepthwiseConvS8Generic(
1205 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
1206 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
1207 const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
1208 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
1209 const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
1210 int8_t* output_data) {
1211 // TBD: Use Kelvin implementation to replace the below
1212 tflite::reference_integer_ops::DepthwiseConvPerChannel(
1213 params, output_multiplier, output_shift, input_shape, input_data,
1214 filter_shape, filter_data, bias_shape, bias_data, output_shape,
1215 output_data);
1216 return;
1217}
1218} // namespace
1219
1220void DepthwiseConvS8(
Naveen Doddabe4ab972024-04-17 17:47:46 +00001221 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
1222 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
1223 const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
1224 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
1225 const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
1226 int8_t* output_data) {
1227 // Get parameters.
1228 // TODO(b/141565753): Re-introduce ScopedProfilingLabel on Micro.
1229 const int stride_width = params.stride_width;
1230 const int stride_height = params.stride_height;
Alex Van Dammecd3d0e32024-05-10 15:27:06 -07001231 const int pad_width = params.padding_values.width;
1232 const int pad_height = params.padding_values.height;
Alex Van Dammeb1afda62024-05-09 16:48:40 -07001233 const int filter_height = filter_shape.Dims(1);
1234 const int filter_width = filter_shape.Dims(2);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001235 const int dilation_width_factor = params.dilation_width_factor;
1236 const int dilation_height_factor = params.dilation_height_factor;
Naveen Doddabe4ab972024-04-17 17:47:46 +00001237 const int depth_multiplier = params.depth_multiplier;
Naveen Doddabe4ab972024-04-17 17:47:46 +00001238 const int32_t output_activation_min = params.quantized_activation_min;
1239 const int32_t output_activation_max = params.quantized_activation_max;
1240
1241 // Check dimensions of the tensors.
1242 TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1243 TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1244 TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1245
1246 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001247 const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001248 const int input_depth = input_shape.Dims(3);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001249 TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
1250 TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
1251
Alex Van Damme13fe02a2024-06-06 16:27:54 -07001252#define RUN_KERNEL(kernel) { \
1253 kernel(params, output_multiplier, output_shift, input_shape, input_data, \
1254 filter_shape, filter_data, bias_shape, bias_data, output_shape, \
1255 output_data \
1256 ); \
1257 return; \
1258}
1259
Alex Van Damme40a83002024-05-08 16:47:03 -07001260 if (depth_multiplier == 1 &&
Naveen Doddabe4ab972024-04-17 17:47:46 +00001261 dilation_height_factor == 1 && dilation_width_factor == 1 &&
Alex Van Damme40a83002024-05-08 16:47:03 -07001262 stride_height <= 2 && stride_width <= 2) {
Lun Dong3b8d3cb2024-05-07 01:50:35 -07001263 // special case of output depth = 32n
1264 if (output_depth % 32 == 0) {
Alex Van Dammeb1afda62024-05-09 16:48:40 -07001265 if (filter_width == 5 && filter_height == 5) {
1266 if (stride_width <= 1 && stride_height <= 1) {
Alex Van Damme13fe02a2024-06-06 16:27:54 -07001267 RUN_KERNEL(DepthwiseConvS85x5D32_Stride1);
Alex Van Dammeb1afda62024-05-09 16:48:40 -07001268 }
Alex Van Damme13fe02a2024-06-06 16:27:54 -07001269 RUN_KERNEL(DepthwiseConvS85x5D32);
1270 } if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1 && stride_width == 1 && stride_height == 1) {
1271 RUN_KERNEL(DepthwiseConvS83x3D32_Stride1);
1272 } if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1) {
1273 RUN_KERNEL(DepthwiseConvS83x3D32);
Alex Van Dammeb1afda62024-05-09 16:48:40 -07001274 }
Alex Van Damme13fe02a2024-06-06 16:27:54 -07001275 RUN_KERNEL(DepthwiseConvS8D32);
Lun Dong3b8d3cb2024-05-07 01:50:35 -07001276 }
1277
Alex Van Damme13fe02a2024-06-06 16:27:54 -07001278 RUN_KERNEL(DepthwiseConvS8Generic);
Naveen Doddabe4ab972024-04-17 17:47:46 +00001279 }
Lun Dong3b8d3cb2024-05-07 01:50:35 -07001280
Alex Van Damme13fe02a2024-06-06 16:27:54 -07001281 RUN_KERNEL(tflite::reference_integer_ops::DepthwiseConvPerChannel);
1282
1283#undef RUN_KERNEL
Naveen Doddabe4ab972024-04-17 17:47:46 +00001284}
Lun Dong3b8d3cb2024-05-07 01:50:35 -07001285
1286} // namespace kelvin::opt