/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Convolution based on Kelvin ops
// Data types: input: s16, filter: s8, bias: s64

#include "tflm/opt/conv_util.h"

namespace kelvin::opt {
namespace {
// Accumulates into v0-v7; [v0-v3] and [v4-v7] are sub-accumulators for two
// output channels. Filter load/swizzle uses [v52-v63], input activations use
// [v32-v33], and the widened products are staged in [v16-v23]. All other
// vector registers are left untouched.
void ConvUkernelS8S16(const int16_t* input_data0, const int8_t* filter_data0,
                      const int8_t* filter_data1, size_t n) {
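  // n counts input channels; the loop below consumes them in blocks of 32,
  // so any remainder past the last full block is ignored.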
  n = n >> 5;
  while (n > 0) {
    // Load filters 0 to v58, v59
    vld_b_p_x(v52, filter_data0);
    vaddw_h_vx(v56, v52, 0);
    vzip_h_vv(v58, v56, v57);

    // Load activations
    vld_h_p_x(v32, input_data0);
    vld_h_p_x(v33, input_data0);

    // Multiply filters0 * activations
    vmulw_w_vv(v16, v58, v32);
    vmulw_w_vv(v18, v59, v33);

    // Accumulate v0
    vadd_w_vv_m(v0, v0, v16);

    // Load filters 1 to v62, v63
    vld_b_p_x(v53, filter_data1);
    vaddw_h_vx(v60, v53, 0);
    vzip_h_vv(v62, v60, v61);

    // Multiply filters1 * activations
    vmulw_w_vv(v20, v62, v32);
    vmulw_w_vv(v22, v63, v33);

    // Accumulate v4
    vadd_w_vv_m(v4, v4, v20);
    n--;
  }
}
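
// For reference: a plain-C sketch of the arithmetic ConvUkernelS8S16
// performs, ignoring the vector register layout. Illustrative only; the
// helper name and [[maybe_unused]] guard are ours, not part of the kernel.
[[maybe_unused]] void ConvUkernelS8S16Reference(
    const int16_t* input_data0, const int8_t* filter_data0,
    const int8_t* filter_data1, size_t n, int32_t* acc0, int32_t* acc1) {
  // Two dot products over the same activations, one per output channel.
  // Like the vector kernel, any tail below a multiple of 32 channels is
  // dropped.
  n &= ~size_t{31};
  for (size_t i = 0; i < n; ++i) {
    *acc0 += static_cast<int32_t>(filter_data0[i]) * input_data0[i];
    *acc1 += static_cast<int32_t>(filter_data1[i]) * input_data0[i];
  }
}
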
void ConvS16B64K1x1(
    const tflite::ConvParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
    int16_t* output_data) {
  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
  const auto input_height = input_shape.Dims(1);
  const auto input_width = input_shape.Dims(2);
  const auto input_depth = input_shape.Dims(3);
  const auto filter_input_depth = filter_shape.Dims(3);
  const auto output_depth = output_shape.Dims(3);
  const auto output_offset = params.output_offset;
  const auto output_activation_min = params.quantized_activation_min;
  const auto output_activation_max = params.quantized_activation_max;
  const auto groups = input_depth / filter_input_depth;
  const auto output_filters_per_group = output_depth / groups;

  int32_t accumulators[8];
  for (int bhw = 0; bhw < batches * input_height * input_width; bhw++) {
    const int16_t* local_input = input_data + (bhw * input_depth);
    int16_t* local_output = output_data + (bhw * output_depth);
    for (int g = 0; g < groups; g++) {
      const int16_t* group_input = local_input + (g * filter_input_depth);
      for (int gc = 0; gc + 2 <= output_filters_per_group; gc += 2) {
        int oc = (g * output_filters_per_group) + gc;
        const int8_t* local_filters0 = filter_data + (oc * filter_input_depth);
        const int8_t* local_filters1 = local_filters0 + filter_input_depth;

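        // Zero the two accumulator register groups (v0-v3 and v4-v7).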
        vdup_w_x_m(v0, 0);
        vdup_w_x_m(v4, 0);
        ConvUkernelS8S16(group_input, local_filters0, local_filters1,
                         filter_input_depth);
        // Sum the sub-accumulators into v0 and v4.
        vadd_w_vv(v0, v0, v1);
        vadd_w_vv(v2, v2, v3);
        vadd_w_vv(v0, v0, v2);
        vadd_w_vv(v4, v4, v5);
        vadd_w_vv(v6, v6, v7);
        vadd_w_vv(v4, v4, v6);

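        // Requantize each of the two outputs: fold the eight remaining
        // 32-bit lanes into the 64-bit bias, rescale to the output scale via
        // MultiplyByQuantizedMultiplier (per-channel multiplier and shift),
        // then apply the output offset and activation clamp.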
        {
          vst_w_x(v0, accumulators);
          int64_t acc64 = bias_data[oc];
          for (int i = 0; i < 8; i++) {
            acc64 += accumulators[i];
          }
          int32_t acc = tflite::MultiplyByQuantizedMultiplier(
              acc64, output_multiplier[oc], output_shift[oc]);
          acc += output_offset;
          acc = std::clamp(acc, output_activation_min, output_activation_max);
          local_output[oc] = static_cast<int16_t>(acc);
        }

        {
          vst_w_x(v4, accumulators);
          int64_t acc64 = bias_data[oc + 1];
          for (int i = 0; i < 8; i++) {
            acc64 += accumulators[i];
          }
          int32_t acc = tflite::MultiplyByQuantizedMultiplier(
              acc64, output_multiplier[oc + 1], output_shift[oc + 1]);
          acc += output_offset;
          acc = std::clamp(acc, output_activation_min, output_activation_max);
          local_output[oc + 1] = static_cast<int16_t>(acc);
        }
      }
    }
  }
}

// Optimized for grouped convolutions with no dilation and a 1xn filter.
void ConvS16B64K1xnGroup(
    const tflite::ConvParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
    int16_t* output_data) {
  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
  const auto stride_width = params.stride_width;
  const auto pad_width = params.padding_values.width;
  const auto input_width = input_shape.Dims(2);
  const auto input_depth = input_shape.Dims(3);
  const auto filter_width = filter_shape.Dims(2);
  const auto filter_depth = filter_shape.Dims(3);
  const auto output_width = output_shape.Dims(2);
  const auto output_depth = output_shape.Dims(3);
  const auto output_offset = params.output_offset;
  const auto output_activation_min = params.quantized_activation_min;
  const auto output_activation_max = params.quantized_activation_max;

  const auto groups = input_depth / filter_depth;
  const auto output_filters_per_group = output_depth / groups;

  int32_t accumulators[8];
  for (int g = 0; g < groups; g++) {
    for (int gc = 0; gc + 2 <= output_filters_per_group; gc += 2) {
      int oc = (g * output_filters_per_group) + gc;
      for (int b = 0; b < batches; ++b) {
        for (int out_x = 0; out_x < output_width; ++out_x) {
          const int in_x_origin = out_x * stride_width - pad_width;
          const int8_t* local_filters0 =
              filter_data + (oc * filter_width * filter_depth);
          const int8_t* local_filters1 =
              local_filters0 + (filter_width * filter_depth);
          const int16_t* local_input =
              input_data + (b * input_width * input_depth) +
              (in_x_origin * input_depth) + (g * filter_depth);
          int16_t* local_output = output_data +
                                  (b * output_width * output_depth) +
                                  (out_x * output_depth);

          int64_t acc64_0 = 0;
          int64_t acc64_1 = 0;
          vdup_w_x_m(v0, 0);
          vdup_w_x_m(v4, 0);
          for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
            const int8_t* local_filters0x =
                local_filters0 + (filter_x * filter_depth);
            const int8_t* local_filters1x =
                local_filters1 + (filter_x * filter_depth);
            const int16_t* local_inputx =
                local_input + (filter_x * input_depth);

            ConvUkernelS8S16(local_inputx, local_filters0x, local_filters1x,
                             filter_depth);
          }

          // Sum the sub-accumulators into v0 and v4.
          vadd_w_vv(v0, v0, v1);
          vadd_w_vv(v2, v2, v3);
          vadd_w_vv(v0, v0, v2);
          vadd_w_vv(v4, v4, v5);
          vadd_w_vv(v6, v6, v7);
          vadd_w_vv(v4, v4, v6);

          {
            vst_w_x(v0, accumulators);
            for (int i = 0; i < 8; i++) {
              acc64_0 += accumulators[i];
            }
            acc64_0 += bias_data[oc];
            int32_t acc = tflite::MultiplyByQuantizedMultiplier(
                acc64_0, output_multiplier[oc], output_shift[oc]);
            acc += output_offset;
            acc = std::clamp(acc, output_activation_min,
                             output_activation_max);
            local_output[oc] = static_cast<int16_t>(acc);
          }

          {
            vst_w_x(v4, accumulators);
            for (int i = 0; i < 8; i++) {
              acc64_1 += accumulators[i];
            }
            acc64_1 += bias_data[oc + 1];
            int32_t acc = tflite::MultiplyByQuantizedMultiplier(
                acc64_1, output_multiplier[oc + 1], output_shift[oc + 1]);
            acc += output_offset;
            acc = std::clamp(acc, output_activation_min,
                             output_activation_max);
            local_output[oc + 1] = static_cast<int16_t>(acc);
          }
        }
      }
    }
  }
}

// Optimized for non-grouped convolutions with no dilation and a 1xn filter.
void ConvS16B64K1xnNonGroup(
    const tflite::ConvParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
    int16_t* output_data) {
  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
  const auto stride_width = params.stride_width;
  const auto pad_width = params.padding_values.width;
  const auto input_width = input_shape.Dims(2);
  const auto input_depth = input_shape.Dims(3);
  const auto filter_width = filter_shape.Dims(2);
  const auto filter_depth = filter_shape.Dims(3);
  const auto output_width = output_shape.Dims(2);
  const auto output_depth = output_shape.Dims(3);
  const auto output_offset = params.output_offset;
  const auto output_activation_min = params.quantized_activation_min;
  const auto output_activation_max = params.quantized_activation_max;
  int32_t accumulators[8];
  for (int oc = 0; oc + 2 <= output_depth; oc += 2) {
    for (int batch = 0; batch < batches; ++batch) {
      for (int out_x = 0; out_x < output_width; ++out_x) {
        const int in_x_origin = out_x * stride_width - pad_width;

        const int8_t* local_filters0 =
            filter_data + (oc * filter_width * filter_depth);
        const int8_t* local_filters1 =
            local_filters0 + (filter_width * filter_depth);
        const int16_t* local_input = input_data +
                                     (batch * input_width * input_depth) +
                                     (in_x_origin * input_depth);
        int16_t* local_output = output_data +
                                (batch * output_width * output_depth) +
                                (out_x * output_depth);

        vdup_w_x_m(v0, 0);
        vdup_w_x_m(v4, 0);
        ConvUkernelS8S16(local_input, local_filters0, local_filters1,
                         filter_width * filter_depth);
        // Sum the sub-accumulators into v0 and v4.
        vadd_w_vv(v0, v0, v1);
        vadd_w_vv(v2, v2, v3);
        vadd_w_vv(v0, v0, v2);
        vadd_w_vv(v4, v4, v5);
        vadd_w_vv(v6, v6, v7);
        vadd_w_vv(v4, v4, v6);
        {
          vst_w_x(v0, accumulators);
          int64_t acc64 = bias_data[oc];
          for (int i = 0; i < 8; i++) {
            acc64 += accumulators[i];
          }
          int32_t acc = tflite::MultiplyByQuantizedMultiplier(
              acc64, output_multiplier[oc], output_shift[oc]);
          acc += output_offset;
          acc = std::clamp(acc, output_activation_min, output_activation_max);
          local_output[oc] = static_cast<int16_t>(acc);
        }

        {
          vst_w_x(v4, accumulators);
          int64_t acc64 = bias_data[oc + 1];
          for (int i = 0; i < 8; i++) {
            acc64 += accumulators[i];
          }
          int32_t acc = tflite::MultiplyByQuantizedMultiplier(
              acc64, output_multiplier[oc + 1], output_shift[oc + 1]);
          acc += output_offset;
          acc = std::clamp(acc, output_activation_min, output_activation_max);
          local_output[oc + 1] = static_cast<int16_t>(acc);
        }
      }
    }
  }
}

void ConvS16B64Generic(
    const tflite::ConvParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
    int16_t* output_data) {
  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
  const auto stride_width = params.stride_width;
  const auto stride_height = params.stride_height;
  const auto dilation_width_factor = params.dilation_width_factor;
  const auto dilation_height_factor = params.dilation_height_factor;
  const auto pad_width = params.padding_values.width;
  const auto pad_height = params.padding_values.height;
  const auto input_height = input_shape.Dims(1);
  const auto input_width = input_shape.Dims(2);
  const auto input_depth = input_shape.Dims(3);
  const auto input_offset = params.input_offset;
  const auto filter_height = filter_shape.Dims(1);
  const auto filter_width = filter_shape.Dims(2);
  const auto filter_depth = filter_shape.Dims(3);
  const auto output_height = output_shape.Dims(1);
  const auto output_width = output_shape.Dims(2);
  const auto output_depth = output_shape.Dims(3);
  const auto output_offset = params.output_offset;
  const auto output_activation_min = params.quantized_activation_min;
  const auto output_activation_max = params.quantized_activation_max;
  const auto groups = input_depth / filter_depth;
  const auto filters_per_group = output_depth / groups;
  for (int batch = 0; batch < batches; ++batch) {
    for (int out_y = 0; out_y < output_height; ++out_y) {
      const int in_y_origin = out_y * stride_height - pad_height;
      for (int out_x = 0; out_x < output_width; ++out_x) {
        const int in_x_origin = out_x * stride_width - pad_width;
        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
          auto group = out_channel / filters_per_group;
          int64_t acc64 = 0;
          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
            const int in_y = in_y_origin + dilation_height_factor * filter_y;
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              const int in_x = in_x_origin + dilation_width_factor * filter_x;
              const bool inside = (in_x >= 0) && (in_x < input_width) &&
                                  (in_y >= 0) && (in_y < input_height);
              if (!inside) {
                continue;
              }

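              // Walk the filter depth 16 channels at a time. Inputs are
              // gathered through the swizzle table (from conv_util.h) so the
              // 32-bit input lanes line up with the widened filter lanes
              // below; lanes past load_count are zero-filled.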
              int in_channel = 0;
              do {
                int load_count = std::min(filter_depth - in_channel, 16L);
                int32_t input_swizzled[16];
                const int16_t* p_input = &input_data[tflite::Offset(
                    input_shape, batch, in_y, in_x,
                    in_channel + group * filter_depth)];
                for (int i = 0; i < 16; ++i) {
                  int swizzle_idx = swizzle[i];
                  if (swizzle_idx < load_count)
                    input_swizzled[i] = *(p_input + swizzle_idx) + input_offset;
                  else
                    input_swizzled[i] = 0;
                }
                vld_w_l_xx(v0, input_swizzled, 4);
                vld_w_l_xx(v1, input_swizzled + 4, 4);
                vld_w_l_xx(v2, input_swizzled + 8, 4);
                vld_w_l_xx(v3, input_swizzled + 12, 4);
                vld_b_l_xx(v4,
                           &filter_data[tflite::Offset(filter_shape,
                                                       out_channel, filter_y,
                                                       filter_x, in_channel)],
                           load_count);
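                // Widen the s8 filter taps to s32 across v4-v7 (bytes to
                // halfwords, then halfwords to words; v5 is widened into
                // v6/v7 first so it is not clobbered), multiply all 16 lanes
                // against the swizzled inputs, and reduce into v0.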
                vaddw_h_vx(v4, v4, 0);
                vaddw_w_vx(v6, v5, 0);
                vaddw_w_vx(v4, v4, 0);

                vmul_w_vv_m(vm0, vm0, vm1);
                vadd_w_vv(v0, v0, v1);
                vadd_w_vv(v0, v0, v2);
                vadd_w_vv(v0, v0, v3);
                int32_t acc32[4];
                vst_w_l_xx(v0, acc32, 4);
                for (int i = 0; i < 4; ++i) {
                  acc64 += acc32[i];
                }
                in_channel += 16;
              } while (in_channel + 16 <= filter_depth);
            }
          }
          if (bias_data) {
            acc64 = acc64 + bias_data[out_channel];
          }
          int32_t acc = tflite::MultiplyByQuantizedMultiplier(
              acc64, output_multiplier[out_channel], output_shift[out_channel]);
          acc += output_offset;
          acc = std::clamp(acc, output_activation_min, output_activation_max);
          output_data[tflite::Offset(output_shape, batch, out_y, out_x,
                                     out_channel)] = static_cast<int16_t>(acc);
        }
      }
    }
  }
}
} // namespace

void ConvS16B64(
    const tflite::ConvParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
    int16_t* output_data) {
  const auto input_height = input_shape.Dims(1);
  const auto input_depth = input_shape.Dims(3);
  const auto filter_height = filter_shape.Dims(1);
  const auto filter_width = filter_shape.Dims(2);
  const auto filter_depth = filter_shape.Dims(3);
  const auto output_depth = output_shape.Dims(3);

  // generic implementation by default
  auto fn = ConvS16B64Generic;

  // special cases
  if (filter_height == 1 && output_depth % 2 == 0) {
    // 1x1 filter, filter depth = 32n
    if (filter_width == 1 && filter_depth % 32 == 0) {
      fn = ConvS16B64K1x1;
    }

    // 1xn non-group filter
    bool group_conv = !(input_depth == filter_depth);
    int32_t fan_in = filter_width * filter_depth;
    if (!group_conv && fan_in % 32 == 0 && input_height == 1) {
      fn = ConvS16B64K1xnNonGroup;
    }

    // 1xn group filter. Guard on group_conv so the non-group special case
    // above is not overridden by the grouped kernel.
    if (group_conv && fan_in % 32 == 0 && input_height == 1) {
      fn = ConvS16B64K1xnGroup;
    }
  }

  fn(params, output_multiplier, output_shift, input_shape, input_data,
     filter_shape, filter_data, bias_shape, bias_data, output_shape,
     output_data);
}
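
// A minimal usage sketch (hypothetical shapes and quantization parameters;
// in TFLM these normally come from the op's prepare step). With a 1x1
// filter, filter_depth % 32 == 0, an even output depth, and input_height
// greater than one, the dispatch above selects ConvS16B64K1x1. The function
// name and all values here are illustrative only.
[[maybe_unused]] static void ConvS16B64ExampleUsage() {
  const tflite::RuntimeShape input_shape({1, 4, 8, 32});   // NHWC
  const tflite::RuntimeShape filter_shape({4, 1, 1, 32});  // OHWI
  const tflite::RuntimeShape bias_shape({4});
  const tflite::RuntimeShape output_shape({1, 4, 8, 4});

  static int16_t input[1 * 4 * 8 * 32] = {0};
  static int8_t filter[4 * 1 * 1 * 32] = {0};
  static int64_t bias[4] = {0};
  static int16_t output[1 * 4 * 8 * 4];
  // Per-channel rescale of roughly 0.5 * 2^-7 (illustrative values only).
  static const int32_t output_multiplier[4] = {1 << 30, 1 << 30, 1 << 30,
                                               1 << 30};
  static const int32_t output_shift[4] = {-7, -7, -7, -7};

  tflite::ConvParams params = {};
  params.stride_width = 1;
  params.stride_height = 1;
  params.dilation_width_factor = 1;
  params.dilation_height_factor = 1;
  params.input_offset = 0;  // s16 tensors are symmetrically quantized.
  params.output_offset = 0;
  params.quantized_activation_min = -32768;
  params.quantized_activation_max = 32767;

  ConvS16B64(params, output_multiplier, output_shift, input_shape, input,
             filter_shape, filter, bias_shape, bias, output_shape, output);
}
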
} // namespace kelvin::opt