Cindy Liu | 43879e4 | 2023-10-18 11:18:03 -0700 | [diff] [blame] | 1 | /* |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 2 | * Copyright 2024 Google LLC |
Cindy Liu | 43879e4 | 2023-10-18 11:18:03 -0700 | [diff] [blame] | 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
Alex Van Damme | f39cadd | 2023-08-28 16:21:20 -0700 | [diff] [blame] | 16 | |
// Convolution based on Kelvin ops
// Data types: input: s16, filter: s8, bias: s64
Alex Van Damme | f39cadd | 2023-08-28 16:21:20 -0700 | [diff] [blame] | 19 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 20 | #include "tflm/opt/conv_util.h" |
Alex Van Damme | f39cadd | 2023-08-28 16:21:20 -0700 | [diff] [blame] | 21 | |
| 22 | namespace kelvin::opt { |
| 23 | namespace { |
// Inner multiply-accumulate micro-kernel shared by the fast paths below.
// Multiplies one run of s16 activations against two s8 filter runs (one per
// output channel) and accumulates the widened 32-bit products into vector
// registers.
//
// Register usage:
//   Accumulates in v0-v7. [v0-v3], [v4-v7] are sub accumulators for two
//   outputs. The kernel adds into them (it does not initialize them), so the
//   caller zeroes them beforehand and reduces them afterwards.
//   Load/swizzle filters use [v52-v63].
//   Input activations use [v32-v33].
//   Products are staged in [v16-v23] (v16-v19 for filter 0, v20-v23 for
//   filter 1), so those registers are overwritten as well.
//
// Params:
//   input_data0:  activations; the pointer parameter is advanced by the `_p`
//                 loads (presumably post-increment addressing — confirm
//                 against the Kelvin ISA).
//   filter_data0: s8 weights for the first output channel.
//   filter_data1: s8 weights for the second output channel.
//   n:            number of elements to reduce. Only n >> 5 iterations run,
//                 each consuming 32 elements, so any remainder (n % 32) is
//                 silently ignored; callers must pass a multiple of 32.
void ConvUkernelS8S16(const int16_t* input_data0, const int8_t* filter_data0,
                      const int8_t* filter_data1, size_t n) {
  n = n >> 5;
  while (n > 0) {
    // Load filters 0, widen s8 -> s16 and zip back into element order in
    // v58, v59.
    vld_b_p_x(v52, filter_data0);
    vaddw_h_vx(v56, v52, 0);
    vzip_h_vv(v58, v56, v57);

    // Load activations
    vld_h_p_x(v32, input_data0);
    vld_h_p_x(v33, input_data0);

    // Multiply filters0 * activations (widening to 32-bit in v16-v19)
    vmulw_w_vv(v16, v58, v32);
    vmulw_w_vv(v18, v59, v33);

    // Accumulate into [v0-v3]
    vadd_w_vv_m(v0, v0, v16);

    // Load filters 1, widen and zip into v62, v63
    vld_b_p_x(v53, filter_data1);
    vaddw_h_vx(v60, v53, 0);
    vzip_h_vv(v62, v60, v61);

    // Multiply filters1 * activations (widening to 32-bit in v20-v23)
    vmulw_w_vv(v20, v62, v32);
    vmulw_w_vv(v22, v63, v33);

    // Accumulate into [v4-v7]
    vadd_w_vv_m(v4, v4, v20);
    n--;
  }
}
| 62 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 63 | void ConvS16B64K1x1( |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 64 | const tflite::ConvParams& params, const int32_t* output_multiplier, |
| 65 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 66 | const int16_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 67 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 68 | const int64_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 69 | int16_t* output_data) { |
| 70 | const auto batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 71 | const auto input_height = input_shape.Dims(1); |
| 72 | const auto input_width = input_shape.Dims(2); |
| 73 | const auto input_depth = input_shape.Dims(3); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 74 | const auto filter_input_depth = filter_shape.Dims(3); |
| 75 | const auto output_depth = output_shape.Dims(3); |
| 76 | const auto output_offset = params.output_offset; |
| 77 | const auto output_activation_min = params.quantized_activation_min; |
| 78 | const auto output_activation_max = params.quantized_activation_max; |
| 79 | const auto groups = input_depth / filter_input_depth; |
| 80 | const auto output_filters_per_group = output_depth / groups; |
| 81 | |
| 82 | int32_t accumulators[8]; |
| 83 | for (int bhw = 0; bhw < batches * input_height * input_width; bhw++) { |
| 84 | const int16_t* local_input = input_data + (bhw * input_depth); |
| 85 | int16_t* local_output = output_data + (bhw * output_depth); |
| 86 | for (int g = 0; g < groups; g++) { |
| 87 | const int16_t* group_input = local_input + (g * filter_input_depth); |
| 88 | for (int gc = 0; gc + 2 <= output_filters_per_group; gc += 2) { |
| 89 | int oc = (g * output_filters_per_group) + gc; |
| 90 | const int8_t* local_filters0 = filter_data + (oc * filter_input_depth); |
| 91 | const int8_t* local_filters1 = local_filters0 + filter_input_depth; |
| 92 | |
| 93 | vdup_w_x_m(v0, 0); |
| 94 | vdup_w_x_m(v4, 0); |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 95 | ConvUkernelS8S16(group_input, local_filters0, local_filters1, |
| 96 | filter_input_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 97 | // sum accumulators |
| 98 | vadd_w_vv(v0, v0, v1); |
| 99 | vadd_w_vv(v2, v2, v3); |
| 100 | vadd_w_vv(v0, v0, v2); |
| 101 | vadd_w_vv(v4, v4, v5); |
| 102 | vadd_w_vv(v6, v6, v7); |
| 103 | vadd_w_vv(v4, v4, v6); |
| 104 | |
| 105 | { |
| 106 | vst_w_x(v0, accumulators); |
| 107 | int64_t acc64 = bias_data[oc]; |
| 108 | for (int i = 0; i < 8; i++) { |
| 109 | acc64 += accumulators[i]; |
| 110 | } |
| 111 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 112 | acc64, output_multiplier[oc], output_shift[oc]); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 113 | acc += output_offset; |
| 114 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 115 | local_output[oc] = static_cast<int16_t>(acc); |
| 116 | } |
| 117 | |
| 118 | { |
| 119 | vst_w_x(v4, accumulators); |
| 120 | int64_t acc64 = bias_data[oc + 1]; |
| 121 | for (int i = 0; i < 8; i++) { |
| 122 | acc64 += accumulators[i]; |
| 123 | } |
| 124 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 125 | acc64, output_multiplier[oc + 1], output_shift[oc + 1]); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 126 | acc += output_offset; |
| 127 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 128 | local_output[oc + 1] = static_cast<int16_t>(acc); |
| 129 | } |
| 130 | } |
| 131 | } |
| 132 | } |
| 133 | } |
| 134 | |
| 135 | // Optimized for grouped convolutions, no dilation, 1xn filter |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 136 | void ConvS16B64K1xnGroup( |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 137 | const tflite::ConvParams& params, const int32_t* output_multiplier, |
| 138 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 139 | const int16_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 140 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 141 | const int64_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 142 | int16_t* output_data) { |
| 143 | const auto batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 144 | const auto stride_width = params.stride_width; |
| 145 | const auto pad_width = params.padding_values.width; |
| 146 | const auto input_width = input_shape.Dims(2); |
| 147 | const auto input_depth = input_shape.Dims(3); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 148 | const auto filter_width = filter_shape.Dims(2); |
| 149 | const auto filter_depth = filter_shape.Dims(3); |
| 150 | const auto output_width = output_shape.Dims(2); |
| 151 | const auto output_depth = output_shape.Dims(3); |
| 152 | const auto output_offset = params.output_offset; |
| 153 | const auto output_activation_min = params.quantized_activation_min; |
| 154 | const auto output_activation_max = params.quantized_activation_max; |
| 155 | |
| 156 | const auto groups = input_depth / filter_depth; |
| 157 | const auto output_filters_per_group = output_depth / groups; |
| 158 | |
| 159 | int32_t accumulators[8]; |
| 160 | for (int g = 0; g < groups; g++) { |
| 161 | for (int gc = 0; gc + 2 <= output_filters_per_group; gc += 2) { |
| 162 | int oc = (g * output_filters_per_group) + gc; |
| 163 | for (int b = 0; b < batches; ++b) { |
| 164 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 165 | const int in_x_origin = out_x * stride_width - pad_width; |
| 166 | const int8_t* local_filters0 = |
| 167 | filter_data + (oc * filter_width * filter_depth); |
| 168 | const int8_t* local_filters1 = |
| 169 | local_filters0 + (filter_width * filter_depth); |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 170 | const int16_t* local_input = |
| 171 | input_data + (b * input_width * input_depth) + |
| 172 | (in_x_origin * input_depth) + (g * filter_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 173 | int16_t* local_output = output_data + |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 174 | (b * output_width * output_depth) + |
| 175 | (out_x * output_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 176 | |
| 177 | int64_t acc64_0 = 0; |
| 178 | int64_t acc64_1 = 0; |
| 179 | vdup_w_x_m(v0, 0); |
| 180 | vdup_w_x_m(v4, 0); |
| 181 | for (int filter_x = 0; filter_x < filter_width; ++filter_x) { |
| 182 | const int8_t* local_filters0x = |
| 183 | local_filters0 + (filter_x * filter_depth); |
| 184 | const int8_t* local_filters1x = |
| 185 | local_filters1 + (filter_x * filter_depth); |
| 186 | const int16_t* local_inputx = |
| 187 | local_input + (filter_x * input_depth); |
| 188 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 189 | ConvUkernelS8S16(local_inputx, local_filters0x, local_filters1x, |
| 190 | filter_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 191 | } |
| 192 | |
| 193 | // sum accumulators |
| 194 | vadd_w_vv(v0, v0, v1); |
| 195 | vadd_w_vv(v2, v2, v3); |
| 196 | vadd_w_vv(v0, v0, v2); |
| 197 | vadd_w_vv(v4, v4, v5); |
| 198 | vadd_w_vv(v6, v6, v7); |
| 199 | vadd_w_vv(v4, v4, v6); |
| 200 | |
| 201 | { |
| 202 | vst_w_x(v0, accumulators); |
| 203 | for (int i = 0; i < 8; i++) { |
| 204 | acc64_0 += accumulators[i]; |
| 205 | } |
| 206 | acc64_0 += bias_data[oc]; |
| 207 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 208 | acc64_0, output_multiplier[oc], output_shift[oc]); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 209 | acc += output_offset; |
| 210 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 211 | local_output[oc] = static_cast<int16_t>(acc); |
| 212 | } |
| 213 | |
| 214 | { |
| 215 | vst_w_x(v4, accumulators); |
| 216 | for (int i = 0; i < 8; i++) { |
| 217 | acc64_1 += accumulators[i]; |
| 218 | } |
| 219 | acc64_1 += bias_data[oc + 1]; |
| 220 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 221 | acc64_1, output_multiplier[oc + 1], output_shift[oc + 1]); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 222 | acc += output_offset; |
| 223 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 224 | local_output[oc + 1] = static_cast<int16_t>(acc); |
| 225 | } |
| 226 | } |
| 227 | } |
| 228 | } |
| 229 | } |
| 230 | } |
| 231 | |
| 232 | // Optimized for no group, no dilation, 1xn filter. |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 233 | void ConvS16B64K1xnNonGroup( |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 234 | const tflite::ConvParams& params, const int32_t* output_multiplier, |
| 235 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 236 | const int16_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 237 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 238 | const int64_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 239 | int16_t* output_data) { |
| 240 | const auto batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 241 | const auto stride_width = params.stride_width; |
| 242 | const auto pad_width = params.padding_values.width; |
| 243 | const auto input_width = input_shape.Dims(2); |
| 244 | const auto input_depth = input_shape.Dims(3); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 245 | const auto filter_width = filter_shape.Dims(2); |
| 246 | const auto filter_depth = filter_shape.Dims(3); |
| 247 | const auto output_width = output_shape.Dims(2); |
| 248 | const auto output_depth = output_shape.Dims(3); |
| 249 | const auto output_offset = params.output_offset; |
| 250 | const auto output_activation_min = params.quantized_activation_min; |
| 251 | const auto output_activation_max = params.quantized_activation_max; |
| 252 | int32_t accumulators[8]; |
| 253 | for (int oc = 0; oc + 2 <= output_depth; oc += 2) { |
| 254 | for (int batch = 0; batch < batches; ++batch) { |
| 255 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 256 | const int in_x_origin = out_x * stride_width - pad_width; |
| 257 | |
| 258 | const int8_t* local_filters0 = |
| 259 | filter_data + (oc * filter_width * filter_depth); |
| 260 | const int8_t* local_filters1 = |
| 261 | local_filters0 + (filter_width * filter_depth); |
| 262 | const int16_t* local_input = input_data + |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 263 | (batch * input_width * input_depth) + |
| 264 | (in_x_origin * input_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 265 | int16_t* local_output = output_data + |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 266 | (batch * output_width * output_depth) + |
| 267 | (out_x * output_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 268 | |
| 269 | vdup_w_x_m(v0, 0); |
| 270 | vdup_w_x_m(v4, 0); |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 271 | ConvUkernelS8S16(local_input, local_filters0, local_filters1, |
| 272 | filter_width * filter_depth); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 273 | // sum accumulators |
| 274 | vadd_w_vv(v0, v0, v1); |
| 275 | vadd_w_vv(v2, v2, v3); |
| 276 | vadd_w_vv(v0, v0, v2); |
| 277 | vadd_w_vv(v4, v4, v5); |
| 278 | vadd_w_vv(v6, v6, v7); |
| 279 | vadd_w_vv(v4, v4, v6); |
| 280 | { |
| 281 | vst_w_x(v0, accumulators); |
| 282 | int64_t acc64 = bias_data[oc]; |
| 283 | for (int i = 0; i < 8; i++) { |
| 284 | acc64 += accumulators[i]; |
| 285 | } |
| 286 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 287 | acc64, output_multiplier[oc], output_shift[oc]); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 288 | acc += output_offset; |
| 289 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 290 | local_output[oc] = static_cast<int16_t>(acc); |
| 291 | } |
| 292 | |
| 293 | { |
| 294 | vst_w_x(v4, accumulators); |
| 295 | int64_t acc64 = bias_data[oc + 1]; |
| 296 | for (int i = 0; i < 8; i++) { |
| 297 | acc64 += accumulators[i]; |
| 298 | } |
| 299 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 300 | acc64, output_multiplier[oc + 1], output_shift[oc + 1]); |
Derek Chow | a5f129c | 2023-11-03 13:48:14 -0700 | [diff] [blame] | 301 | acc += output_offset; |
| 302 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 303 | local_output[oc + 1] = static_cast<int16_t>(acc); |
| 304 | } |
| 305 | } |
| 306 | } |
| 307 | } |
| 308 | } |
| 309 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 310 | void ConvS16B64Generic( |
Alex Van Damme | f39cadd | 2023-08-28 16:21:20 -0700 | [diff] [blame] | 311 | const tflite::ConvParams& params, const int32_t* output_multiplier, |
| 312 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 313 | const int16_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 314 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 315 | const int64_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 316 | int16_t* output_data) { |
| 317 | const auto batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 318 | const auto stride_width = params.stride_width; |
| 319 | const auto stride_height = params.stride_height; |
| 320 | const auto dilation_width_factor = params.dilation_width_factor; |
| 321 | const auto dilation_height_factor = params.dilation_height_factor; |
| 322 | const auto pad_width = params.padding_values.width; |
| 323 | const auto pad_height = params.padding_values.height; |
| 324 | const auto input_height = input_shape.Dims(1); |
| 325 | const auto input_width = input_shape.Dims(2); |
| 326 | const auto input_depth = input_shape.Dims(3); |
| 327 | const auto input_offset = params.input_offset; |
| 328 | const auto filter_height = filter_shape.Dims(1); |
| 329 | const auto filter_width = filter_shape.Dims(2); |
| 330 | const auto filter_depth = filter_shape.Dims(3); |
| 331 | const auto output_height = output_shape.Dims(1); |
| 332 | const auto output_width = output_shape.Dims(2); |
| 333 | const auto output_depth = output_shape.Dims(3); |
| 334 | const auto output_offset = params.output_offset; |
| 335 | const auto output_activation_min = params.quantized_activation_min; |
| 336 | const auto output_activation_max = params.quantized_activation_max; |
| 337 | const auto groups = input_depth / filter_depth; |
| 338 | const auto filters_per_group = output_depth / groups; |
| 339 | for (int batch = 0; batch < batches; ++batch) { |
| 340 | for (int out_y = 0; out_y < output_height; ++out_y) { |
| 341 | const int in_y_origin = out_y * stride_height - pad_height; |
| 342 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 343 | const int in_x_origin = out_x * stride_width - pad_width; |
| 344 | for (int out_channel = 0; out_channel < output_depth; ++out_channel) { |
| 345 | auto group = out_channel / filters_per_group; |
| 346 | int64_t acc64 = 0; |
| 347 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
| 348 | const int in_y = in_y_origin + dilation_height_factor * filter_y; |
| 349 | for (int filter_x = 0; filter_x < filter_width; ++filter_x) { |
| 350 | const int in_x = in_x_origin + dilation_width_factor * filter_x; |
| 351 | const bool inside = (in_x >= 0) && (in_x < input_width) && |
| 352 | (in_y >= 0) && (in_y < input_height); |
| 353 | if (!inside) { |
| 354 | continue; |
| 355 | } |
| 356 | |
| 357 | int in_channel = 0; |
| 358 | do { |
| 359 | int load_count = std::min(filter_depth - in_channel, 16L); |
| 360 | int32_t input_swizzled[16]; |
| 361 | const int16_t* p_input = &input_data[tflite::Offset( |
| 362 | input_shape, batch, in_y, in_x, |
| 363 | in_channel + group * filter_depth)]; |
| 364 | for (int i = 0; i < 16; ++i) { |
| 365 | int swizzle_idx = swizzle[i]; |
| 366 | if (swizzle_idx < load_count) |
| 367 | input_swizzled[i] = *(p_input + swizzle_idx) + input_offset; |
| 368 | else |
| 369 | input_swizzled[i] = 0; |
| 370 | } |
| 371 | vld_w_l_xx(v0, input_swizzled, 4); |
| 372 | vld_w_l_xx(v1, input_swizzled + 4, 4); |
| 373 | vld_w_l_xx(v2, input_swizzled + 8, 4); |
| 374 | vld_w_l_xx(v3, input_swizzled + 12, 4); |
| 375 | vld_b_l_xx(v4, |
| 376 | &filter_data[tflite::Offset(filter_shape, |
| 377 | out_channel, filter_y, |
| 378 | filter_x, in_channel)], |
| 379 | load_count); |
| 380 | vaddw_h_vx(v4, v4, 0); |
| 381 | vaddw_w_vx(v6, v5, 0); |
| 382 | vaddw_w_vx(v4, v4, 0); |
| 383 | |
| 384 | vmul_w_vv_m(vm0, vm0, vm1); |
| 385 | vadd_w_vv(v0, v0, v1); |
| 386 | vadd_w_vv(v0, v0, v2); |
| 387 | vadd_w_vv(v0, v0, v3); |
| 388 | int32_t acc32[4]; |
| 389 | vst_w_l_xx(v0, acc32, 4); |
| 390 | for (int i = 0; i < 4; ++i) { |
| 391 | acc64 += acc32[i]; |
| 392 | } |
| 393 | in_channel += 16; |
| 394 | } while (in_channel + 16 <= filter_depth); |
| 395 | } |
| 396 | } |
| 397 | if (bias_data) { |
| 398 | acc64 = acc64 + bias_data[out_channel]; |
| 399 | } |
| 400 | int32_t acc = tflite::MultiplyByQuantizedMultiplier( |
| 401 | acc64, output_multiplier[out_channel], output_shift[out_channel]); |
| 402 | acc += output_offset; |
| 403 | acc = std::clamp(acc, output_activation_min, output_activation_max); |
| 404 | output_data[tflite::Offset(output_shape, batch, out_y, out_x, |
| 405 | out_channel)] = static_cast<int16_t>(acc); |
| 406 | } |
| 407 | } |
| 408 | } |
| 409 | } |
| 410 | } |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 411 | } // namespace |
Alex Van Damme | f39cadd | 2023-08-28 16:21:20 -0700 | [diff] [blame] | 412 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 413 | void ConvS16B64( |
Derek Chow | 903c311 | 2023-11-13 10:17:15 -0800 | [diff] [blame] | 414 | const tflite::ConvParams& params, const int32_t* output_multiplier, |
| 415 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 416 | const int16_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 417 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 418 | const int64_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 419 | int16_t* output_data) { |
Alex Van Damme | 008f0ae | 2024-06-18 14:10:02 -0700 | [diff] [blame] | 420 | const auto input_height = input_shape.Dims(1); |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 421 | const auto input_depth = input_shape.Dims(3); |
| 422 | const auto filter_height = filter_shape.Dims(1); |
| 423 | const auto filter_width = filter_shape.Dims(2); |
| 424 | const auto filter_depth = filter_shape.Dims(3); |
| 425 | const auto output_depth = output_shape.Dims(3); |
| 426 | |
| 427 | // generic implementation by default |
| 428 | auto fn = ConvS16B64Generic; |
| 429 | |
| 430 | // special cases |
| 431 | if (filter_height == 1 && output_depth % 2 == 0) { |
| 432 | // 1x1 filter, filter depth = 32n |
| 433 | if (filter_width == 1 && filter_depth % 32 == 0) { |
| 434 | fn = ConvS16B64K1x1; |
Derek Chow | 903c311 | 2023-11-13 10:17:15 -0800 | [diff] [blame] | 435 | } |
| 436 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 437 | // 1xn non group filter |
| 438 | bool group_conv = !(input_depth == filter_depth); |
| 439 | int32_t fan_in = filter_width * filter_depth; |
Alex Van Damme | 008f0ae | 2024-06-18 14:10:02 -0700 | [diff] [blame] | 440 | if (!group_conv && fan_in % 32 == 0 && input_height == 1) { |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 441 | fn = ConvS16B64K1xnNonGroup; |
Derek Chow | 903c311 | 2023-11-13 10:17:15 -0800 | [diff] [blame] | 442 | } |
| 443 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 444 | // 1xn group filter |
Alex Van Damme | 008f0ae | 2024-06-18 14:10:02 -0700 | [diff] [blame] | 445 | if (fan_in % 32 == 0 && input_height == 1) { |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 446 | fn = ConvS16B64K1xnGroup; |
Derek Chow | 903c311 | 2023-11-13 10:17:15 -0800 | [diff] [blame] | 447 | } |
| 448 | } |
| 449 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 450 | fn(params, output_multiplier, output_shift, input_shape, input_data, |
| 451 | filter_shape, filter_data, bias_shape, bias_data, output_shape, |
| 452 | output_data); |
Derek Chow | 903c311 | 2023-11-13 10:17:15 -0800 | [diff] [blame] | 453 | } |
| 454 | |
Alex Van Damme | f39cadd | 2023-08-28 16:21:20 -0700 | [diff] [blame] | 455 | } // namespace kelvin::opt |