Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1 | /* |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 2 | * Copyright 2024 Google LLC |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
// Depthwise convolution kernels built on Kelvin vector ops.
// Data types: input: s8, filter: s8, bias: s32.
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 19 | |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 20 | #include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h" |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 21 | #include "tflm/opt/conv_util.h" |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 22 | |
| 23 | namespace kelvin::opt { |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 24 | namespace { |
Alex Van Damme | cd3d0e3 | 2024-05-10 15:27:06 -0700 | [diff] [blame] | 25 | |
// Reorders a vector of N int32 values to match the lane pattern produced by
// double-widening.
//
// The input is consumed in groups of four consecutive elements. Element j
// (j = 0..3) of group i is written to output[i + k_j * (N / 4)], where
// k = {0, 2, 1, 3}. For N == 32 (the size passed by the callers visible in
// this file) this yields the output offsets 0, 16, 8, 24.
//
// N must be a positive multiple of 4. `input` and `output` must each hold at
// least N elements and must not overlap.
void VectorSwizzle(const int32_t* input, int32_t* output, int N) {
  assert(N >= 4 && N % 4 == 0);
  // Distance between the four output sub-ranges. The original hard-coded
  // 8/16/24, which is only in bounds (and non-overlapping) for N == 32.
  const int quarter = N / 4;
  for (int i = 0; i < quarter; ++i) {
    const int32_t* group = input + 4 * i;
    output[i + 0 * quarter] = group[0];
    output[i + 2 * quarter] = group[1];
    output[i + 1 * quarter] = group[2];
    output[i + 3 * quarter] = group[3];
  }
}
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 44 | |
| 45 | // special case of input depth = 32n, filter shape of 3x3, strides of 1 |
| 46 | void DepthwiseConvS83x3D32_Stride1( |
| 47 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 48 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 49 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 50 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 51 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 52 | int8_t* output_data |
| 53 | ) { |
| 54 | const int stride_width = params.stride_width; |
| 55 | const int stride_height = params.stride_height; |
| 56 | const int pad_width = params.padding_values.width; |
| 57 | const int pad_height = params.padding_values.height; |
| 58 | const int32_t input_offset = params.input_offset; |
| 59 | const int32_t output_offset = params.output_offset; |
| 60 | const int32_t output_activation_min = params.quantized_activation_min; |
| 61 | const int32_t output_activation_max = params.quantized_activation_max; |
| 62 | const int batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 63 | const int input_height = input_shape.Dims(1); |
| 64 | const int input_width = input_shape.Dims(2); |
| 65 | const int input_depth = input_shape.Dims(3); |
| 66 | const int output_height = output_shape.Dims(1); |
| 67 | const int output_width = output_shape.Dims(2); |
| 68 | const int output_depth = output_shape.Dims(3); |
| 69 | int32_t swizzled_bias_data[32]; |
| 70 | int32_t swizzled_shift_multi[32]; |
| 71 | int32_t swizzled_output_multi[32]; |
| 72 | |
| 73 | for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) { |
| 74 | const int output_channel = in_channel; |
| 75 | int8_t* p_output = output_data + output_channel; |
| 76 | VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32); |
| 77 | VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32); |
| 78 | VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32); |
| 79 | |
| 80 | vld_w_x_m(v52, swizzled_bias_data); |
| 81 | vld_w_x_m(v56, swizzled_output_multi); |
| 82 | vld_w_x_m(v60, swizzled_shift_multi); |
| 83 | vrsub_w_vx_m(v60, v60, 0); |
| 84 | |
| 85 | union { |
| 86 | vdwconv_u8_t dwconv; |
| 87 | uint32_t raw; |
| 88 | } cmds; |
| 89 | cmds.raw = 0; |
| 90 | cmds.dwconv.sdata1 = true; |
| 91 | cmds.dwconv.sbias1 = input_offset; |
| 92 | cmds.dwconv.sdata2 = true; |
| 93 | cmds.dwconv.sbias2 = 0; |
| 94 | cmds.dwconv.mode = 0; |
| 95 | cmds.dwconv.sparsity = 0; |
| 96 | cmds.dwconv.regbase = 0; |
| 97 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 98 | #define FLT_0_0 v0 |
| 99 | #define FLT_0_1 v3 |
| 100 | #define FLT_0_2 v6 |
| 101 | #define FLT_1_0 v1 |
| 102 | #define FLT_1_1 v4 |
| 103 | #define FLT_1_2 v7 |
| 104 | #define FLT_2_0 v2 |
| 105 | #define FLT_2_1 v5 |
| 106 | #define FLT_2_2 v8 |
| 107 | |
| 108 | #define INPUT_0_0 v9 |
| 109 | #define INPUT_0_1 v12 |
| 110 | #define INPUT_0_2 v15 |
| 111 | #define INPUT_0_3 v18 |
| 112 | #define INPUT_0_4 v21 |
| 113 | #define INPUT_0_5 v24 |
| 114 | #define INPUT_1_0 v10 |
| 115 | #define INPUT_1_1 v13 |
| 116 | #define INPUT_1_2 v16 |
| 117 | #define INPUT_1_3 v19 |
| 118 | #define INPUT_1_4 v22 |
| 119 | #define INPUT_1_5 v25 |
| 120 | #define INPUT_2_0 v11 |
| 121 | #define INPUT_2_1 v14 |
| 122 | #define INPUT_2_2 v17 |
| 123 | #define INPUT_2_3 v20 |
| 124 | #define INPUT_2_4 v23 |
| 125 | #define INPUT_2_5 v26 |
| 126 | |
| 127 | #define INPUT_PTRS(_strides) \ |
| 128 | const int in_y_origin = (out_y * stride_height) - pad_height; \ |
| 129 | const int in_x_origin = (out_x * stride_width) - pad_width; \ |
| 130 | const int8_t* p_in_0 = input_data + \ |
| 131 | (batch * input_height * input_width * input_depth) + \ |
| 132 | (in_y_origin * input_width * input_depth) + \ |
| 133 | ((in_x_origin + _strides) * input_depth) + \ |
| 134 | in_channel; \ |
| 135 | const int8_t* p_in_1 = p_in_0 + (input_width * input_depth); \ |
| 136 | const int8_t* p_in_2 = p_in_1 + (input_width * input_depth); \ |
| 137 | (void)p_in_2; |
| 138 | |
| 139 | #define COMPUTE() \ |
| 140 | adwinit_v(v48, v48); \ |
| 141 | adwconv_vxv(v48, INPUT_0_0, cmds, FLT_0_0); \ |
| 142 | adwconv_vxv(v48, INPUT_0_1, cmds, FLT_0_1); \ |
| 143 | vdwconv_vxv(v48, INPUT_0_2, cmds, FLT_0_2); |
| 144 | |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 145 | // Don't reorder me, otherwise data will not be |
| 146 | // loaded in the correct order |
| 147 | // (we can reuse the p_flt* due to the `p` vld variant). |
| 148 | const int8_t* p_flt0 = filter_data + in_channel; |
| 149 | const int8_t* p_flt1 = p_flt0 + input_depth; |
| 150 | const int32_t stride = 2 * input_depth; |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 151 | vld_b_sp_xx(FLT_0_0, p_flt0, stride); |
| 152 | vld_b_sp_xx(FLT_0_1, p_flt1, stride); |
| 153 | vld_b_sp_xx(FLT_0_2, p_flt0, stride); |
| 154 | vld_b_sp_xx(FLT_1_0, p_flt1, stride); |
| 155 | vld_b_sp_xx(FLT_1_1, p_flt0, stride); |
| 156 | vld_b_sp_xx(FLT_1_2, p_flt1, stride); |
| 157 | vld_b_sp_xx(FLT_2_0, p_flt0, stride); |
| 158 | vld_b_sp_xx(FLT_2_1, p_flt1, stride); |
| 159 | vld_b_sp_xx(FLT_2_2, p_flt0, stride); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 160 | |
| 161 | for (int batch = 0; batch < batches; ++batch) { |
| 162 | int out_y = 0; |
| 163 | for (; out_y < pad_height; ++out_y) { |
| 164 | int out_x = 0; |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 165 | vdup_b_x(INPUT_0_0, -input_offset); |
| 166 | vdup_b_x(INPUT_0_1, -input_offset); |
| 167 | vdup_b_x(INPUT_0_2, -input_offset); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 168 | for (; out_x < pad_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 169 | INPUT_PTRS(1); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 170 | vmv_v_m(v48, v52); |
| 171 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 172 | vdup_b_x(INPUT_1_0, -input_offset); |
| 173 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 174 | vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth); |
| 175 | vdup_b_x(INPUT_2_0, -input_offset); |
| 176 | vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth); |
| 177 | vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 178 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 179 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 180 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 181 | v48, v56, v60, |
| 182 | output_activation_min, |
| 183 | output_activation_max, |
| 184 | output_offset); |
| 185 | vsraqs_b_vx(v48, v48, 0); |
| 186 | vst_b_x(v48, p_output); |
| 187 | |
| 188 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 189 | } |
| 190 | for (; out_x < output_width - pad_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 191 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 192 | vmv_v_m(v48, v52); |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 193 | vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth); |
| 194 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 195 | vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth); |
| 196 | vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth); |
| 197 | vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth); |
| 198 | vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 199 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 200 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 201 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 202 | v48, v56, v60, |
| 203 | output_activation_min, |
| 204 | output_activation_max, |
| 205 | output_offset); |
| 206 | vsraqs_b_vx(v48, v48, 0); |
| 207 | vst_b_x(v48, p_output); |
| 208 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 209 | } |
| 210 | for (; out_x < output_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 211 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 212 | vmv_v_m(v48, v52); |
| 213 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 214 | vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth); |
| 215 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 216 | vdup_b_x(INPUT_1_2, -input_offset); |
| 217 | vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth); |
| 218 | vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth); |
| 219 | vdup_b_x(INPUT_2_2, -input_offset); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 220 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 221 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 222 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 223 | v48, v56, v60, |
| 224 | output_activation_min, |
| 225 | output_activation_max, |
| 226 | output_offset); |
| 227 | vsraqs_b_vx(v48, v48, 0); |
| 228 | vst_b_x(v48, p_output); |
| 229 | |
| 230 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 231 | } |
| 232 | } |
| 233 | for (; out_y < output_height - pad_height; ++out_y) { |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 234 | int out_x = 0; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 235 | for (; out_x < pad_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 236 | INPUT_PTRS(1); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 237 | vmv_v_m(v48, v52); |
| 238 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 239 | vdup_b_x(INPUT_0_0, -input_offset); |
| 240 | vdup_b_x(INPUT_1_0, -input_offset); |
| 241 | vdup_b_x(INPUT_2_0, -input_offset); |
| 242 | vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth); |
| 243 | vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth); |
| 244 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 245 | vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth); |
| 246 | vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth); |
| 247 | vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 248 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 249 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 250 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 251 | v48, v56, v60, |
| 252 | output_activation_min, |
| 253 | output_activation_max, |
| 254 | output_offset); |
| 255 | vsraqs_b_vx(v48, v48, 0); |
| 256 | vst_b_x(v48, p_output); |
| 257 | |
| 258 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 259 | } |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 260 | for (; out_x + 4 <= output_width - pad_width; out_x += 4) { |
| 261 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 262 | // Initialize accumulators w/ bias data. |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 263 | vmv_v_m(v36, v52); |
| 264 | vmv_v_m(v40, v52); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 265 | vmv_v_m(v44, v52); |
| 266 | vmv_v_m(v48, v52); |
| 267 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 268 | vld_b_sp_xx(INPUT_0_0, p_in_0, stride_width * input_depth); |
| 269 | vld_b_sp_xx(INPUT_1_0, p_in_1, stride_width * input_depth); |
| 270 | vld_b_sp_xx(INPUT_2_0, p_in_2, stride_width * input_depth); |
| 271 | vld_b_sp_xx(INPUT_0_1, p_in_0, stride_width * input_depth); |
| 272 | vld_b_sp_xx(INPUT_1_1, p_in_1, stride_width * input_depth); |
| 273 | vld_b_sp_xx(INPUT_2_1, p_in_2, stride_width * input_depth); |
| 274 | vld_b_sp_xx(INPUT_0_2, p_in_0, stride_width * input_depth); |
| 275 | vld_b_sp_xx(INPUT_1_2, p_in_1, stride_width * input_depth); |
| 276 | vld_b_sp_xx(INPUT_2_2, p_in_2, stride_width * input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 277 | |
| 278 | adwinit_v(v48, v48); |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 279 | adwconv_vxv(v48, INPUT_0_0, cmds, FLT_0_0); |
| 280 | adwconv_vxv(v48, INPUT_0_1, cmds, FLT_0_1); |
| 281 | vdwconv_vxv(v48, INPUT_0_2, cmds, FLT_0_2); |
| 282 | |
| 283 | vld_b_sp_xx(INPUT_0_3, p_in_0, stride_width * input_depth); |
| 284 | vld_b_sp_xx(INPUT_1_3, p_in_1, stride_width * input_depth); |
| 285 | vld_b_sp_xx(INPUT_2_3, p_in_2, stride_width * input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 286 | |
| 287 | adwinit_v(v44, v44); |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 288 | adwconv_vxv(v44, INPUT_0_1, cmds, FLT_0_0); |
| 289 | adwconv_vxv(v44, INPUT_0_2, cmds, FLT_0_1); |
| 290 | vdwconv_vxv(v44, INPUT_0_3, cmds, FLT_0_2); |
| 291 | |
| 292 | vld_b_sp_xx(INPUT_0_4, p_in_0, stride_width * input_depth); |
| 293 | vld_b_sp_xx(INPUT_1_4, p_in_1, stride_width * input_depth); |
| 294 | vld_b_sp_xx(INPUT_2_4, p_in_2, stride_width * input_depth); |
| 295 | |
| 296 | adwinit_v(v40, v40); |
| 297 | adwconv_vxv(v40, INPUT_0_2, cmds, FLT_0_0); |
| 298 | adwconv_vxv(v40, INPUT_0_3, cmds, FLT_0_1); |
| 299 | vdwconv_vxv(v40, INPUT_0_4, cmds, FLT_0_2); |
| 300 | |
| 301 | vld_b_sp_xx(INPUT_0_5, p_in_0, stride_width * input_depth); |
| 302 | vld_b_sp_xx(INPUT_1_5, p_in_1, stride_width * input_depth); |
| 303 | vld_b_sp_xx(INPUT_2_5, p_in_2, stride_width * input_depth); |
| 304 | |
| 305 | adwinit_v(v36, v36); |
| 306 | adwconv_vxv(v36, INPUT_0_3, cmds, FLT_0_0); |
| 307 | adwconv_vxv(v36, INPUT_0_4, cmds, FLT_0_1); |
| 308 | vdwconv_vxv(v36, INPUT_0_5, cmds, FLT_0_2); |
| 309 | |
| 310 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 311 | v48, v56, v60, |
| 312 | output_activation_min, |
| 313 | output_activation_max, |
| 314 | output_offset); |
| 315 | vsraqs_b_vx(v48, v48, 0); |
| 316 | vst_b_x(v48, p_output); |
| 317 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 318 | |
| 319 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 320 | v44, v56, v60, |
| 321 | output_activation_min, |
| 322 | output_activation_max, |
| 323 | output_offset); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 324 | vsraqs_b_vx(v44, v44, 0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 325 | vst_b_x(v44, p_output); |
| 326 | p_output += output_depth; |
| 327 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 328 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 329 | v40, v56, v60, |
| 330 | output_activation_min, |
| 331 | output_activation_max, |
| 332 | output_offset); |
| 333 | vsraqs_b_vx(v40, v40, 0); |
| 334 | vst_b_x(v40, p_output); |
| 335 | p_output += output_depth; |
| 336 | |
| 337 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 338 | v36, v56, v60, |
| 339 | output_activation_min, |
| 340 | output_activation_max, |
| 341 | output_offset); |
| 342 | vsraqs_b_vx(v36, v36, 0); |
| 343 | vst_b_x(v36, p_output); |
| 344 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 345 | } |
| 346 | for (; out_x < output_width - pad_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 347 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 348 | vmv_v_m(v48, v52); |
| 349 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 350 | vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth); |
| 351 | vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth); |
| 352 | vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth); |
| 353 | vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth); |
| 354 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 355 | vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth); |
| 356 | vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth); |
| 357 | vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth); |
| 358 | vld_b_sp_xx(INPUT_2_2, p_in_2, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 359 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 360 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 361 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 362 | v48, v56, v60, |
| 363 | output_activation_min, |
| 364 | output_activation_max, |
| 365 | output_offset); |
| 366 | vsraqs_b_vx(v48, v48, 0); |
| 367 | vst_b_x(v48, p_output); |
| 368 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 369 | } |
| 370 | for (; out_x < output_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 371 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 372 | vmv_v_m(v48, v52); |
| 373 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 374 | vdup_b_x(INPUT_0_2, -input_offset); |
| 375 | vdup_b_x(INPUT_1_2, -input_offset); |
| 376 | vdup_b_x(INPUT_2_2, -input_offset); |
| 377 | vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth); |
| 378 | vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth); |
| 379 | vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth); |
| 380 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 381 | vld_b_sp_xx(INPUT_2_0, p_in_2, input_depth); |
| 382 | vld_b_sp_xx(INPUT_2_1, p_in_2, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 383 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 384 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 385 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 386 | v48, v56, v60, |
| 387 | output_activation_min, |
| 388 | output_activation_max, |
| 389 | output_offset); |
| 390 | vsraqs_b_vx(v48, v48, 0); |
| 391 | vst_b_x(v48, p_output); |
| 392 | |
| 393 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 394 | } |
| 395 | } |
| 396 | for (; out_y < output_height; ++out_y) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 397 | vdup_b_x(INPUT_2_0, -input_offset); |
| 398 | vdup_b_x(INPUT_2_1, -input_offset); |
| 399 | vdup_b_x(INPUT_2_2, -input_offset); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 400 | int out_x = 0; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 401 | for (; out_x < pad_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 402 | INPUT_PTRS(1); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 403 | vmv_v_m(v48, v52); |
| 404 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 405 | vdup_b_x(INPUT_0_0, -input_offset); |
| 406 | vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth); |
| 407 | vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth); |
| 408 | vdup_b_x(INPUT_1_0, -input_offset); |
| 409 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 410 | vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 411 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 412 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 413 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 414 | v48, v56, v60, |
| 415 | output_activation_min, |
| 416 | output_activation_max, |
| 417 | output_offset); |
| 418 | vsraqs_b_vx(v48, v48, 0); |
| 419 | vst_b_x(v48, p_output); |
| 420 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 421 | } |
| 422 | for (; out_x < output_width - pad_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 423 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 424 | vmv_v_m(v48, v52); |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 425 | vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth); |
| 426 | vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth); |
| 427 | vld_b_sp_xx(INPUT_0_2, p_in_0, input_depth); |
| 428 | vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth); |
| 429 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 430 | vld_b_sp_xx(INPUT_1_2, p_in_1, input_depth); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 431 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 432 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 433 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 434 | v48, v56, v60, |
| 435 | output_activation_min, |
| 436 | output_activation_max, |
| 437 | output_offset); |
| 438 | vsraqs_b_vx(v48, v48, 0); |
| 439 | vst_b_x(v48, p_output); |
| 440 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 441 | } |
| 442 | for (; out_x < output_width; ++out_x) { |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 443 | INPUT_PTRS(0); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 444 | vmv_v_m(v48, v52); |
| 445 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 446 | vld_b_sp_xx(INPUT_0_0, p_in_0, input_depth); |
| 447 | vld_b_sp_xx(INPUT_0_1, p_in_0, input_depth); |
| 448 | vdup_b_x(INPUT_0_2, -input_offset); |
| 449 | vld_b_sp_xx(INPUT_1_0, p_in_1, input_depth); |
| 450 | vld_b_sp_xx(INPUT_1_1, p_in_1, input_depth); |
| 451 | vdup_b_x(INPUT_1_2, -input_offset); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 452 | |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 453 | COMPUTE(); |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 454 | INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE( |
| 455 | v48, v56, v60, |
| 456 | output_activation_min, |
| 457 | output_activation_max, |
| 458 | output_offset); |
| 459 | vsraqs_b_vx(v48, v48, 0); |
| 460 | vst_b_x(v48, p_output); |
| 461 | p_output += output_depth; |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 462 | } |
| 463 | } |
| 464 | } |
| 465 | } |
Alex Van Damme | d22d7d6 | 2024-06-07 10:53:00 -0700 | [diff] [blame^] | 466 | #undef FLT_0_0 |
| 467 | #undef FLT_0_1 |
| 468 | #undef FLT_0_2 |
| 469 | #undef FLT_1_0 |
| 470 | #undef FLT_1_1 |
| 471 | #undef FLT_1_2 |
| 472 | #undef FLT_2_0 |
| 473 | #undef FLT_2_1 |
| 474 | #undef FLT_2_2 |
| 475 | #undef COMPUTE |
Alex Van Damme | 088841b | 2024-06-03 16:16:54 -0700 | [diff] [blame] | 476 | } |
| 477 | |
Alex Van Damme | cd3d0e3 | 2024-05-10 15:27:06 -0700 | [diff] [blame] | 478 | // special case of input depth = 32n, filter shape of 3x3 |
| 479 | void DepthwiseConvS83x3D32( |
| 480 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 481 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 482 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 483 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 484 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 485 | int8_t* output_data |
| 486 | ) { |
| 487 | const int stride_width = params.stride_width; |
| 488 | const int stride_height = params.stride_height; |
| 489 | const int pad_width = params.padding_values.width; |
| 490 | const int pad_height = params.padding_values.height; |
| 491 | const int32_t input_offset = params.input_offset; |
| 492 | const int32_t output_offset = params.output_offset; |
| 493 | const int32_t output_activation_min = params.quantized_activation_min; |
| 494 | const int32_t output_activation_max = params.quantized_activation_max; |
| 495 | const int batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 496 | const int input_height = input_shape.Dims(1); |
| 497 | const int input_width = input_shape.Dims(2); |
| 498 | const int input_depth = input_shape.Dims(3); |
| 499 | const int output_height = output_shape.Dims(1); |
| 500 | const int output_width = output_shape.Dims(2); |
| 501 | const int output_depth = output_shape.Dims(3); |
| 502 | int32_t swizzled_bias_data[32]; |
| 503 | int32_t swizzled_shift_multi[32]; |
| 504 | int32_t swizzled_output_multi[32]; |
| 505 | |
| 506 | for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) { |
| 507 | const int output_channel = in_channel; |
| 508 | VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32); |
| 509 | VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32); |
| 510 | VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32); |
| 511 | |
| 512 | vld_w_x_m(v52, swizzled_bias_data); |
| 513 | vld_w_x_m(v56, swizzled_output_multi); |
| 514 | vld_w_x_m(v60, swizzled_shift_multi); |
| 515 | vrsub_w_vx_m(v60, v60, 0); |
| 516 | |
| 517 | union { |
| 518 | vdwconv_u8_t dwconv; |
| 519 | uint32_t raw; |
| 520 | } cmds; |
| 521 | cmds.raw = 0; |
| 522 | cmds.dwconv.sdata1 = true; |
| 523 | cmds.dwconv.sbias1 = input_offset; |
| 524 | cmds.dwconv.sdata2 = true; |
| 525 | cmds.dwconv.sbias2 = 0; |
| 526 | cmds.dwconv.mode = 0; |
| 527 | cmds.dwconv.sparsity = 0; |
| 528 | cmds.dwconv.regbase = 0; |
| 529 | |
| 530 | // Don't reorder me, otherwise data will not be |
| 531 | // loaded in the correct order |
| 532 | // (we can reuse the p_flt* due to the `p` vld variant). |
| 533 | const int8_t* p_flt0 = filter_data + in_channel; |
| 534 | const int8_t* p_flt1 = p_flt0 + input_depth; |
| 535 | const int32_t stride = 2 * input_depth; |
| 536 | vld_b_sp_xx(v6, p_flt0, stride); |
| 537 | vld_b_sp_xx(v7, p_flt1, stride); |
| 538 | vld_b_sp_xx(v8, p_flt0, stride); |
| 539 | vld_b_sp_xx(v9, p_flt1, stride); |
| 540 | vld_b_sp_xx(v10, p_flt0, stride); |
| 541 | vld_b_sp_xx(v11, p_flt1, stride); |
| 542 | vld_b_sp_xx(v12, p_flt0, stride); |
| 543 | vld_b_sp_xx(v13, p_flt1, stride); |
| 544 | vld_b_sp_xx(v14, p_flt0, stride); |
| 545 | |
| 546 | for (int batch = 0; batch < batches; ++batch) { |
| 547 | const int8_t* p_output = output_data + (batch * output_width * output_height * output_depth) + output_channel; |
| 548 | for (int out_y = 0; out_y < output_height; ++out_y) { |
| 549 | const int in_y_origin = (out_y * stride_height) - pad_height; |
| 550 | const int y_offset = (output_depth * output_width * out_y); |
| 551 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 552 | const int in_x_origin = (out_x * stride_width) - pad_width; |
| 553 | |
| 554 | // Initialize accumulators w/ bias data. |
| 555 | vmv_v_m(v48, v52); |
| 556 | |
| 557 | bool top_pad = in_y_origin < 0; |
| 558 | bool left_pad = in_x_origin < 0; |
| 559 | bool bottom_pad = (in_y_origin + 2) >= input_height; |
| 560 | bool right_pad = (in_x_origin + 2) >= input_width; |
| 561 | bool padding_required = top_pad || left_pad || bottom_pad || right_pad; |
| 562 | const int8_t* p_in_0 = input_data + |
| 563 | (batch * input_height * input_width * input_depth) + |
| 564 | (in_y_origin * input_width * input_depth) + |
| 565 | (in_x_origin * input_depth) + |
| 566 | in_channel; |
| 567 | const int8_t* p_in_1 = p_in_0 + (input_width * input_depth); |
| 568 | const int8_t* p_in_2 = p_in_1 + (input_width * input_depth); |
| 569 | if (!padding_required) { |
| 570 | vld_b_sp_xx(v15, p_in_0, input_depth); |
| 571 | vld_b_sp_xx(v16, p_in_0, input_depth); |
| 572 | vld_b_sp_xx(v17, p_in_0, input_depth); |
| 573 | vld_b_sp_xx(v18, p_in_1, input_depth); |
| 574 | vld_b_sp_xx(v19, p_in_1, input_depth); |
| 575 | vld_b_sp_xx(v20, p_in_1, input_depth); |
| 576 | vld_b_sp_xx(v21, p_in_2, input_depth); |
| 577 | vld_b_sp_xx(v22, p_in_2, input_depth); |
| 578 | vld_b_sp_xx(v23, p_in_2, input_depth); |
| 579 | } else { |
| 580 | // Top row |
| 581 | if (top_pad || left_pad) { |
| 582 | vdup_b_x(v15, -input_offset); |
| 583 | } else { |
| 584 | vld_b_x(v15, p_in_0); |
| 585 | } |
| 586 | if (top_pad) { |
| 587 | vdup_b_x(v16, -input_offset); |
| 588 | } else { |
| 589 | vld_b_x(v16, p_in_0 + input_depth); |
| 590 | } |
| 591 | if (top_pad || right_pad) { |
| 592 | vdup_b_x(v17, -input_offset); |
| 593 | } else { |
| 594 | vld_b_x(v17, p_in_0 + (2 * input_depth)); |
| 595 | } |
| 596 | // Middle row |
| 597 | if (left_pad) { |
| 598 | vdup_b_x(v18, -input_offset); |
| 599 | } else { |
| 600 | vld_b_x(v18, p_in_1); |
| 601 | } |
| 602 | vld_b_x(v19, p_in_1 + input_depth); |
| 603 | if (right_pad) { |
| 604 | vdup_b_x(v20, -input_offset); |
| 605 | } else { |
| 606 | vld_b_x(v20, p_in_1 + (2 * input_depth)); |
| 607 | } |
| 608 | // Bottom row |
| 609 | if (bottom_pad || left_pad) { |
| 610 | vdup_b_x(v21, -input_offset); |
| 611 | } else { |
| 612 | vld_b_x(v21, p_in_2); |
| 613 | } |
| 614 | if (bottom_pad) { |
| 615 | vdup_b_x(v22, -input_offset); |
| 616 | } else { |
| 617 | vld_b_x(v22, p_in_2 + input_depth); |
| 618 | } |
| 619 | if (bottom_pad || right_pad) { |
| 620 | vdup_b_x(v23, -input_offset); |
| 621 | } else { |
| 622 | vld_b_x(v23, p_in_2 + (2 * input_depth)); |
| 623 | } |
| 624 | } |
| 625 | |
| 626 | adwinit_v(v48, v48); |
| 627 | adwconv_vxv(v48, v15, cmds, v6); |
| 628 | adwconv_vxv(v48, v18, cmds, v9); |
| 629 | vdwconv_vxv(v48, v21, cmds, v12); |
| 630 | |
| 631 | vdmulh_w_rn_vv_m(v48, v48, v56); |
| 632 | vsha_w_r_vv_m(v48, v48, v60); |
| 633 | vadd_w_vx_m(v48, v48, output_offset); |
| 634 | vmax_w_vx_m(v48, v48, output_activation_min); |
| 635 | vmin_w_vx_m(v48, v48, output_activation_max); |
| 636 | vsraqs_b_vx(v48, v48, 0); |
| 637 | vst_b_x(v48, p_output + (out_x * output_depth) + y_offset); |
| 638 | } |
| 639 | } |
| 640 | } |
| 641 | } |
| 642 | } |
Alex Van Damme | b1afda6 | 2024-05-09 16:48:40 -0700 | [diff] [blame] | 643 | |
| 644 | // special case of input depth = 32n, filter shape of 5x5, stride == 1 |
| 645 | void DepthwiseConvS85x5D32_Stride1( |
| 646 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 647 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 648 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 649 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 650 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 651 | int8_t* output_data |
| 652 | ) { |
| 653 | const int stride_width = params.stride_width; |
| 654 | const int stride_height = params.stride_height; |
| 655 | const int pad_width = params.padding_values.width; |
| 656 | const int pad_height = params.padding_values.height; |
| 657 | const int32_t input_offset = params.input_offset; |
| 658 | const int32_t output_offset = params.output_offset; |
| 659 | const int32_t output_activation_min = params.quantized_activation_min; |
| 660 | const int32_t output_activation_max = params.quantized_activation_max; |
| 661 | const int batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 662 | const int input_height = input_shape.Dims(1); |
| 663 | const int input_width = input_shape.Dims(2); |
| 664 | const int input_depth = input_shape.Dims(3); |
| 665 | const int output_height = output_shape.Dims(1); |
| 666 | const int output_width = output_shape.Dims(2); |
| 667 | const int output_depth = output_shape.Dims(3); |
| 668 | int32_t swizzled_bias_data[32]; |
| 669 | int32_t swizzled_shift_multi[32]; |
| 670 | int32_t swizzled_output_multi[32]; |
| 671 | |
| 672 | for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) { |
| 673 | const int output_channel = in_channel; |
| 674 | VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32); |
| 675 | VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32); |
| 676 | VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32); |
| 677 | |
| 678 | union { |
| 679 | vdwconv_u8_t dwconv; |
| 680 | uint32_t raw; |
| 681 | } cmds; |
| 682 | cmds.raw = 0; |
| 683 | cmds.dwconv.sdata1 = true; |
| 684 | cmds.dwconv.sbias1 = input_offset; |
| 685 | cmds.dwconv.sdata2 = true; |
| 686 | cmds.dwconv.sbias2 = 0; |
| 687 | cmds.dwconv.mode = 0; |
| 688 | cmds.dwconv.sparsity = 0; |
| 689 | cmds.dwconv.regbase = 0; |
| 690 | |
| 691 | // Don't reorder me! |
| 692 | const int8_t* p_flt0 = filter_data + in_channel; |
| 693 | const int32_t stride = input_depth; |
| 694 | vld_b_sp_xx_m(v0, p_flt0, stride); |
| 695 | vld_b_sp_xx_m(v4, p_flt0, stride); |
| 696 | vld_b_sp_xx_m(v8, p_flt0, stride); |
| 697 | vld_b_sp_xx_m(v12, p_flt0, stride); |
| 698 | vld_b_sp_xx_m(v16, p_flt0, stride); |
| 699 | vld_b_sp_xx_m(v20, p_flt0, stride); |
| 700 | vld_b_sp_xx(v24, p_flt0, stride); |
| 701 | |
| 702 | // Extra two registers to get our |
| 703 | // total usage to a multiple of 3 for dwconv. |
| 704 | vdup_b_x(v25, 0); |
| 705 | vdup_b_x(v26, 0); |
| 706 | |
| 707 | for (int batch = 0; batch < batches; ++batch) { |
| 708 | const int8_t* p_output = output_data + (batch * output_height * output_width * output_depth) + output_channel; |
| 709 | for (int out_y = 0; out_y < output_height; ++out_y) { |
| 710 | const int y_offset = out_y * output_width * output_depth; |
| 711 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 712 | const int in_x_origin = (out_x * stride_width) - pad_width; |
| 713 | const int in_y_origin = (out_y * stride_height) - pad_height; |
| 714 | |
| 715 | bool top_pad = in_y_origin < 0; |
| 716 | bool left_pad = in_x_origin < 0; |
| 717 | int top_pad_count = top_pad ? 0 - in_y_origin : 0; |
| 718 | int left_pad_count = left_pad ? 0 - in_x_origin : 0; |
| 719 | bool bottom_pad = (in_y_origin + 4) >= input_height; |
| 720 | bool right_pad = (in_x_origin + 4) >= input_width; |
| 721 | int bottom_pad_count = std::abs(bottom_pad ? (in_y_origin + 4) - input_height + 1: 0); |
| 722 | int right_pad_count = std::abs(right_pad ? (in_x_origin + 4) - input_width + 1 : 0); |
| 723 | bool padding_required = top_pad || left_pad || bottom_pad || right_pad; |
| 724 | assert(top_pad_count <= pad_height); |
| 725 | assert(bottom_pad_count <= pad_height); |
| 726 | assert(left_pad_count <= pad_width); |
| 727 | assert(right_pad_count <= pad_width); |
| 728 | assert(!(left_pad && right_pad)); |
| 729 | const int8_t* p_in_0 = input_data + |
| 730 | (batch * input_height * input_width * input_depth) + |
| 731 | (in_y_origin * input_width * input_depth) + |
| 732 | (in_x_origin * input_depth) + |
| 733 | in_channel; |
| 734 | const int8_t* p_in_1 = p_in_0 + (input_width * input_depth); |
| 735 | const int8_t* p_in_2 = p_in_1 + (input_width * input_depth); |
| 736 | const int8_t* p_in_3 = p_in_2 + (input_width * input_depth); |
| 737 | const int8_t* p_in_4 = p_in_3 + (input_width * input_depth); |
| 738 | // Extra two registers to get our |
| 739 | // total usage to a multiple of 3 for dwconv. |
| 740 | vdup_b_x(v52, -input_offset); |
| 741 | vdup_b_x(v53, -input_offset); |
| 742 | if (!padding_required) { |
| 743 | vld_b_sp_xx(v27, p_in_0, input_depth); |
| 744 | vld_b_sp_xx_m(v28, p_in_0, input_depth); |
| 745 | vld_b_sp_xx_m(v32, p_in_1, input_depth); |
| 746 | vld_b_sp_xx(v36, p_in_1, input_depth); |
| 747 | vld_b_sp_xx(v37, p_in_2, input_depth); |
| 748 | vld_b_sp_xx(v38, p_in_2, input_depth); |
| 749 | vld_b_sp_xx(v39, p_in_2, input_depth); |
| 750 | vld_b_sp_xx(v40, p_in_2, input_depth); |
| 751 | vld_b_sp_xx(v41, p_in_2, input_depth); |
| 752 | vld_b_sp_xx(v42, p_in_3, input_depth); |
| 753 | vld_b_sp_xx(v43, p_in_3, input_depth); |
| 754 | vld_b_sp_xx(v44, p_in_3, input_depth); |
| 755 | vld_b_sp_xx(v45, p_in_3, input_depth); |
| 756 | vld_b_sp_xx(v46, p_in_3, input_depth); |
| 757 | vld_b_sp_xx(v47, p_in_4, input_depth); |
| 758 | vld_b_sp_xx_m(v48, p_in_4, input_depth); |
| 759 | } else { |
| 760 | // Top row |
| 761 | if (top_pad_count >= 1) { |
| 762 | vdup_b_x(v27, -input_offset); |
| 763 | vdup_b_x_m(v28, -input_offset); |
| 764 | } else { |
| 765 | switch (left_pad_count) { |
| 766 | case 2: |
| 767 | vdup_b_x(v28, -input_offset); |
| 768 | case 1: |
| 769 | vdup_b_x(v27, -input_offset); |
| 770 | } |
| 771 | switch (left_pad_count) { |
| 772 | case 0: |
| 773 | vld_b_x(v27, p_in_0); |
| 774 | case 1: |
| 775 | vld_b_x(v28, p_in_0 + input_depth); |
| 776 | } |
| 777 | vld_b_x(v29, p_in_0 + (2 * input_depth)); |
| 778 | switch (right_pad_count) { |
| 779 | case 2: |
| 780 | vdup_b_x(v30, -input_offset); |
| 781 | case 1: |
| 782 | vdup_b_x(v31, -input_offset); |
| 783 | } |
| 784 | switch (right_pad_count) { |
| 785 | case 0: |
| 786 | vld_b_x(v31, p_in_0 + (4 * input_depth)); |
| 787 | case 1: |
| 788 | vld_b_x(v30, p_in_0 + (3 * input_depth)); |
| 789 | } |
| 790 | } |
| 791 | |
| 792 | // 2nd row |
| 793 | if (top_pad_count == 2) { |
| 794 | vdup_b_x_m(v32, -input_offset); |
| 795 | vdup_b_x(v36, -input_offset); |
| 796 | } else { |
| 797 | switch (left_pad_count) { |
| 798 | case 2: |
| 799 | vdup_b_x(v33, -input_offset); |
| 800 | case 1: |
| 801 | vdup_b_x(v32, -input_offset); |
| 802 | } |
| 803 | switch (left_pad_count) { |
| 804 | case 0: |
| 805 | vld_b_x(v32, p_in_1); |
| 806 | case 1: |
| 807 | vld_b_x(v33, p_in_1 + input_depth); |
| 808 | } |
| 809 | vld_b_x(v34, p_in_1 + (2 * input_depth)); |
| 810 | switch (right_pad_count) { |
| 811 | case 2: |
| 812 | vdup_b_x(v35, -input_offset); |
| 813 | case 1: |
| 814 | vdup_b_x(v36, -input_offset); |
| 815 | } |
| 816 | switch (right_pad_count) { |
| 817 | case 0: |
| 818 | vld_b_x(v36, p_in_1 + (4 * input_depth)); |
| 819 | case 1: |
| 820 | vld_b_x(v35, p_in_1 + (3 * input_depth)); |
| 821 | } |
| 822 | } |
| 823 | |
| 824 | // 3rd row |
| 825 | switch (left_pad_count) { |
| 826 | case 2: |
| 827 | vdup_b_x(v38, -input_offset); |
| 828 | case 1: |
| 829 | vdup_b_x(v37, -input_offset); |
| 830 | } |
| 831 | switch (left_pad_count) { |
| 832 | case 0: |
| 833 | vld_b_x(v37, p_in_2); |
| 834 | case 1: |
| 835 | vld_b_x(v38, p_in_2 + input_depth); |
| 836 | } |
| 837 | vld_b_x(v39, p_in_2 + (2 * input_depth)); |
| 838 | switch (right_pad_count) { |
| 839 | case 2: |
| 840 | vdup_b_x(v40, -input_offset); |
| 841 | case 1: |
| 842 | vdup_b_x(v41, -input_offset); |
| 843 | } |
| 844 | switch (right_pad_count) { |
| 845 | case 0: |
| 846 | vld_b_x(v41, p_in_2 + (4 * input_depth)); |
| 847 | case 1: |
| 848 | vld_b_x(v40, p_in_2 + (3 * input_depth)); |
| 849 | } |
| 850 | |
| 851 | // 4th row |
| 852 | if (bottom_pad_count == 2) { |
| 853 | vdup_b_x(v42, -input_offset); |
| 854 | vdup_b_x(v43, -input_offset); |
| 855 | vdup_b_x(v44, -input_offset); |
| 856 | vdup_b_x(v45, -input_offset); |
| 857 | vdup_b_x(v46, -input_offset); |
| 858 | } else { |
| 859 | switch (left_pad_count) { |
| 860 | case 2: |
| 861 | vdup_b_x(v43, -input_offset); |
| 862 | case 1: |
| 863 | vdup_b_x(v42, -input_offset); |
| 864 | } |
| 865 | switch (left_pad_count) { |
| 866 | case 0: |
| 867 | vld_b_x(v42, p_in_3); |
| 868 | case 1: |
| 869 | vld_b_x(v43, p_in_3 + input_depth); |
| 870 | } |
| 871 | switch (right_pad_count) { |
| 872 | case 2: |
| 873 | vdup_b_x(v45, -input_offset); |
| 874 | case 1: |
| 875 | vdup_b_x(v46, -input_offset); |
| 876 | } |
| 877 | vld_b_x(v44, p_in_3 + (2 * input_depth)); |
| 878 | switch (right_pad_count) { |
| 879 | case 0: |
| 880 | vld_b_x(v46, p_in_3 + (4 * input_depth)); |
| 881 | case 1: |
| 882 | vld_b_x(v45, p_in_3 + (3 * input_depth)); |
| 883 | } |
| 884 | } |
| 885 | |
| 886 | // 5th row |
| 887 | if (bottom_pad_count >= 1) { |
| 888 | vdup_b_x(v47, -input_offset); |
| 889 | vdup_b_x(v48, -input_offset); |
| 890 | vdup_b_x(v49, -input_offset); |
| 891 | vdup_b_x(v50, -input_offset); |
| 892 | vdup_b_x(v51, -input_offset); |
| 893 | } else { |
| 894 | switch (left_pad_count) { |
| 895 | case 2: |
| 896 | vdup_b_x(v48, -input_offset); |
| 897 | case 1: |
| 898 | vdup_b_x(v47, -input_offset); |
| 899 | } |
| 900 | switch (left_pad_count) { |
| 901 | case 0: |
| 902 | vld_b_x(v47, p_in_4); |
| 903 | case 1: |
| 904 | vld_b_x(v48, p_in_4 + input_depth); |
| 905 | } |
| 906 | vld_b_x(v49, p_in_4 + (2 * input_depth)); |
| 907 | switch (right_pad_count) { |
| 908 | case 2: |
| 909 | vdup_b_x(v50, -input_offset); |
| 910 | case 1: |
| 911 | vdup_b_x(v51, -input_offset); |
| 912 | } |
| 913 | switch (right_pad_count) { |
| 914 | case 0: |
| 915 | vld_b_x(v51, p_in_4 + (4 * input_depth)); |
| 916 | case 1: |
| 917 | vld_b_x(v50, p_in_4 + (3 * input_depth)); |
| 918 | } |
| 919 | } |
| 920 | } |
| 921 | |
| 922 | vld_w_x_m(v60, swizzled_bias_data); |
| 923 | adwinit_v(v60, v60); |
| 924 | adwconv_vxv(v60, v27, cmds, v0); |
| 925 | adwconv_vxv(v60, v30, cmds, v3); |
| 926 | adwconv_vxv(v60, v33, cmds, v6); |
| 927 | adwconv_vxv(v60, v36, cmds, v9); |
| 928 | adwconv_vxv(v60, v39, cmds, v12); |
| 929 | adwconv_vxv(v60, v42, cmds, v15); |
| 930 | adwconv_vxv(v60, v45, cmds, v18); |
| 931 | adwconv_vxv(v60, v48, cmds, v21); |
| 932 | vdwconv_vxv(v60, v51, cmds, v24); |
| 933 | |
| 934 | vld_w_x_m(v56, swizzled_output_multi); |
| 935 | vdmulh_w_rn_vv_m(v60, v60, v56); |
| 936 | vld_w_x_m(v56, swizzled_shift_multi); |
| 937 | vrsub_w_vx_m(v56, v56, 0); |
| 938 | vsha_w_r_vv_m(v60, v60, v56); |
| 939 | vadd_w_vx_m(v60, v60, output_offset); |
| 940 | vmax_w_vx_m(v60, v60, output_activation_min); |
| 941 | vmin_w_vx_m(v60, v60, output_activation_max); |
| 942 | vsraqs_b_vx(v60, v60, 0); |
| 943 | vst_b_x(v60, p_output + y_offset + (out_x * output_depth)); |
| 944 | } |
| 945 | } |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 946 | } |
| 947 | } |
Alex Van Damme | b1afda6 | 2024-05-09 16:48:40 -0700 | [diff] [blame] | 948 | } |
| 949 | |
| 950 | // special case of input depth = 32n, filter shape of 5x5 |
| 951 | void DepthwiseConvS85x5D32( |
| 952 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 953 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 954 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 955 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 956 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 957 | int8_t* output_data |
| 958 | ) { |
| 959 | const int stride_width = params.stride_width; |
| 960 | const int stride_height = params.stride_height; |
| 961 | const int pad_width = params.padding_values.width; |
| 962 | const int pad_height = params.padding_values.height; |
| 963 | const int32_t input_offset = params.input_offset; |
| 964 | const int32_t output_offset = params.output_offset; |
| 965 | const int32_t output_activation_min = params.quantized_activation_min; |
| 966 | const int32_t output_activation_max = params.quantized_activation_max; |
| 967 | const int batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 968 | const int input_height = input_shape.Dims(1); |
| 969 | const int input_width = input_shape.Dims(2); |
| 970 | const int input_depth = input_shape.Dims(3); |
| 971 | const int filter_height = filter_shape.Dims(1); |
| 972 | const int filter_width = filter_shape.Dims(2); |
| 973 | const int output_height = output_shape.Dims(1); |
| 974 | const int output_width = output_shape.Dims(2); |
| 975 | const int output_depth = output_shape.Dims(3); |
| 976 | int32_t swizzled_bias_data[32]; |
| 977 | int32_t swizzled_shift_multi[32]; |
| 978 | int32_t swizzled_output_multi[32]; |
| 979 | |
| 980 | for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) { |
| 981 | const int output_channel = in_channel; |
| 982 | VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32); |
| 983 | VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32); |
| 984 | VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32); |
| 985 | |
| 986 | vld_w_x_m(v52, swizzled_bias_data); |
| 987 | vld_w_x_m(v56, swizzled_output_multi); |
| 988 | vld_w_x_m(v60, swizzled_shift_multi); |
| 989 | vrsub_w_vx_m(v60, v60, 0); |
| 990 | |
| 991 | // Don't reorder me! |
| 992 | const int8_t* p_flt = filter_data + in_channel; |
| 993 | vld_b_sp_xx(v6, p_flt, input_depth); |
| 994 | vld_b_sp_xx(v7, p_flt, input_depth); |
| 995 | vld_b_sp_xx_m(v8, p_flt, input_depth); |
| 996 | vld_b_sp_xx_m(v12, p_flt, input_depth); |
| 997 | vld_b_sp_xx_m(v16, p_flt, input_depth); |
| 998 | vld_b_sp_xx_m(v20, p_flt, input_depth); |
| 999 | vld_b_sp_xx_m(v24, p_flt, input_depth); |
| 1000 | vld_b_sp_xx(v28, p_flt, input_depth); |
| 1001 | vld_b_sp_xx(v29, p_flt, input_depth); |
| 1002 | vld_b_sp_xx(v30, p_flt, input_depth); |
| 1003 | |
| 1004 | |
| 1005 | for (int batch = 0; batch < batches; ++batch) { |
| 1006 | const int8_t* p_input = input_data + (batch * input_width * input_height * input_depth) + in_channel; |
| 1007 | const int8_t* p_output = output_data + (batch * output_width * output_height * output_depth) + output_channel; |
| 1008 | for (int out_y = 0; out_y < output_height; ++out_y) { |
| 1009 | const int out_y_offset = (out_y * output_width * output_depth); |
| 1010 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 1011 | const int in_x_origin = (out_x * stride_width) - pad_width; |
| 1012 | const int in_y_origin = (out_y * stride_height) - pad_height; |
| 1013 | |
| 1014 | // Initialize accumulators w/ bias_data |
| 1015 | vmv_v_m(v48, v52); |
| 1016 | |
| 1017 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
| 1018 | const int in_y = in_y_origin + filter_y; |
| 1019 | if ((in_y < 0) || (in_y >= input_height)) { |
| 1020 | continue; |
| 1021 | } |
| 1022 | switch (filter_y) { |
| 1023 | case 0: |
| 1024 | vaddw_h_vx(v31, v6, 0); |
| 1025 | vaddw_h_vx(v33, v7, 0); |
| 1026 | vaddw_h_vx(v35, v8, 0); |
| 1027 | vaddw_h_vx(v37, v9, 0); |
| 1028 | vaddw_h_vx(v39, v10, 0); |
| 1029 | break; |
| 1030 | case 1: |
| 1031 | vaddw_h_vx(v31, v11, 0); |
| 1032 | vaddw_h_vx(v33, v12, 0); |
| 1033 | vaddw_h_vx(v35, v13, 0); |
| 1034 | vaddw_h_vx(v37, v14, 0); |
| 1035 | vaddw_h_vx(v39, v15, 0); |
| 1036 | break; |
| 1037 | case 2: |
| 1038 | vaddw_h_vx(v31, v16, 0); |
| 1039 | vaddw_h_vx(v33, v17, 0); |
| 1040 | vaddw_h_vx(v35, v18, 0); |
| 1041 | vaddw_h_vx(v37, v19, 0); |
| 1042 | vaddw_h_vx(v39, v20, 0); |
| 1043 | break; |
| 1044 | case 3: |
| 1045 | vaddw_h_vx(v31, v21, 0); |
| 1046 | vaddw_h_vx(v33, v22, 0); |
| 1047 | vaddw_h_vx(v35, v23, 0); |
| 1048 | vaddw_h_vx(v37, v24, 0); |
| 1049 | vaddw_h_vx(v39, v25, 0); |
| 1050 | break; |
| 1051 | case 4: |
| 1052 | vaddw_h_vx(v31, v26, 0); |
| 1053 | vaddw_h_vx(v33, v27, 0); |
| 1054 | vaddw_h_vx(v35, v28, 0); |
| 1055 | vaddw_h_vx(v37, v29, 0); |
| 1056 | vaddw_h_vx(v39, v30, 0); |
| 1057 | break; |
| 1058 | } |
| 1059 | const int in_y_offset = in_y * input_width * input_depth; |
| 1060 | for (int filter_x = 0; filter_x < filter_width; ++filter_x) { |
| 1061 | const int in_x = in_x_origin + filter_x; |
| 1062 | if ((in_x < 0) || (in_x >= input_width)) { |
| 1063 | continue; |
| 1064 | } |
| 1065 | |
| 1066 | vld_b_x(v0, p_input + (in_x * input_depth) + in_y_offset); |
| 1067 | |
| 1068 | vaddw_h_vx(v0, v0, 0); |
| 1069 | vadd_h_vx(v0, v0, static_cast<int16_t>(input_offset)); |
| 1070 | vadd_h_vx(v1, v1, |
| 1071 | static_cast<int16_t>(input_offset)); // v0 v1 input |
| 1072 | switch (filter_x) { |
| 1073 | case 0: |
| 1074 | vmulw_w_vv(v2, v1, v32); |
| 1075 | vmulw_w_vv(v0, v0, v31); |
| 1076 | break; |
| 1077 | case 1: |
| 1078 | vmulw_w_vv(v2, v1, v34); |
| 1079 | vmulw_w_vv(v0, v0, v33); |
| 1080 | break; |
| 1081 | case 2: |
| 1082 | vmulw_w_vv(v2, v1, v36); |
| 1083 | vmulw_w_vv(v0, v0, v35); |
| 1084 | break; |
| 1085 | case 3: |
| 1086 | vmulw_w_vv(v2, v1, v38); |
| 1087 | vmulw_w_vv(v0, v0, v37); |
| 1088 | break; |
| 1089 | case 4: |
| 1090 | vmulw_w_vv(v2, v1, v40); |
| 1091 | vmulw_w_vv(v0, v0, v39); |
| 1092 | break; |
| 1093 | } |
| 1094 | vadd_w_vv_m(v48, v48, v0); |
| 1095 | } |
| 1096 | } |
| 1097 | |
| 1098 | vdmulh_w_rn_vv_m(v48, v48, v56); |
| 1099 | vsha_w_r_vv_m(v48, v48, v60); |
| 1100 | vadd_w_vx_m(v48, v48, output_offset); |
| 1101 | vmax_w_vx_m(v48, v48, output_activation_min); |
| 1102 | vmin_w_vx_m(v48, v48, output_activation_max); |
| 1103 | vsraqs_b_vx(v48, v48, 0); |
| 1104 | vst_b_x(v48, p_output + out_y_offset + (out_x * output_depth)); |
| 1105 | } |
| 1106 | } |
| 1107 | } |
| 1108 | } |
| 1109 | } |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1110 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1111 | // special case of input depth = 32n |
| 1112 | void DepthwiseConvS8D32( |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1113 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 1114 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 1115 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 1116 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 1117 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 1118 | int8_t* output_data |
| 1119 | |
| 1120 | ) { |
| 1121 | const int stride_width = params.stride_width; |
| 1122 | const int stride_height = params.stride_height; |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1123 | const int pad_width = params.padding_values.width; |
| 1124 | const int pad_height = params.padding_values.height; |
| 1125 | const int32_t input_offset = params.input_offset; |
| 1126 | const int32_t output_offset = params.output_offset; |
| 1127 | const int32_t output_activation_min = params.quantized_activation_min; |
| 1128 | const int32_t output_activation_max = params.quantized_activation_max; |
| 1129 | const int batches = MatchingDim(input_shape, 0, output_shape, 0); |
| 1130 | const int input_height = input_shape.Dims(1); |
| 1131 | const int input_width = input_shape.Dims(2); |
| 1132 | const int input_depth = input_shape.Dims(3); |
| 1133 | const int filter_height = filter_shape.Dims(1); |
| 1134 | const int filter_width = filter_shape.Dims(2); |
| 1135 | const int output_height = output_shape.Dims(1); |
| 1136 | const int output_width = output_shape.Dims(2); |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1137 | int32_t swizzled_bias_data[32]; |
| 1138 | int32_t swizzled_shift_multi[32]; |
| 1139 | int32_t swizzled_output_multi[32]; |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1140 | |
| 1141 | for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) { |
| 1142 | const int output_channel = in_channel; |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1143 | VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32); |
| 1144 | VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32); |
| 1145 | VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1146 | |
| 1147 | vld_w_x_m(v20, swizzled_bias_data); |
| 1148 | vld_w_x_m(v24, swizzled_output_multi); |
| 1149 | vld_w_x_m(v28, swizzled_shift_multi); |
| 1150 | vrsub_w_vx_m(v28, v28, 0); |
| 1151 | |
| 1152 | for (int batch = 0; batch < batches; ++batch) { |
| 1153 | for (int out_y = 0; out_y < output_height; ++out_y) { |
| 1154 | for (int out_x = 0; out_x < output_width; ++out_x) { |
| 1155 | const int in_x_origin = (out_x * stride_width) - pad_width; |
| 1156 | const int in_y_origin = (out_y * stride_height) - pad_height; |
| 1157 | |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1158 | vdup_w_x_m(v48, 0); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1159 | for (int filter_y = 0; filter_y < filter_height; ++filter_y) { |
| 1160 | const int in_y = in_y_origin + filter_y; |
| 1161 | if ((in_y < 0) || (in_y >= input_height)) { |
| 1162 | continue; |
| 1163 | } |
| 1164 | for (int filter_x = 0; filter_x < filter_width; ++filter_x) { |
| 1165 | const int in_x = in_x_origin + filter_x; |
| 1166 | if ((in_x < 0) || (in_x >= input_width)) { |
| 1167 | continue; |
| 1168 | } |
| 1169 | |
| 1170 | vld_b_x(v0, &input_data[tflite::Offset(input_shape, batch, in_y, |
| 1171 | in_x, in_channel)]); // xp |
| 1172 | vld_b_x(v4, &filter_data[tflite::Offset(filter_shape, 0, filter_y, |
| 1173 | filter_x, in_channel)]); |
| 1174 | |
| 1175 | vaddw_h_vx(v0, v0, 0); |
| 1176 | vadd_h_vx(v0, v0, static_cast<int16_t>(input_offset)); |
| 1177 | vadd_h_vx(v1, v1, |
| 1178 | static_cast<int16_t>(input_offset)); // v0 v1 input |
| 1179 | |
| 1180 | vaddw_h_vx(v4, v4, static_cast<int16_t>(0)); |
| 1181 | vmulw_w_vv(v8, v0, v4); |
| 1182 | vmulw_w_vv(v10, v1, v5); |
| 1183 | |
| 1184 | vadd_w_vv_m(v48, v48, v8); |
| 1185 | } |
| 1186 | } |
| 1187 | |
| 1188 | vadd_w_vv_m(v48, v48, v20); // add bias |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1189 | vdmulh_w_rn_vv_m(v48, v48, v24); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1190 | vsha_w_r_vv_m(v48, v48, v28); |
| 1191 | vadd_w_vx_m(v48, v48, output_offset); |
| 1192 | vmax_w_vx_m(v48, v48, output_activation_min); |
| 1193 | vmin_w_vx_m(v48, v48, output_activation_max); |
| 1194 | vsraqs_b_vx(v48, v48, 0); |
| 1195 | vst_b_x(v48, &output_data[tflite::Offset(output_shape, batch, out_y, |
| 1196 | out_x, output_channel)]); |
| 1197 | } |
| 1198 | } |
| 1199 | } |
| 1200 | } |
| 1201 | } |
| 1202 | |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1203 | // generic implementation based on Kelvin ops |
| 1204 | void DepthwiseConvS8Generic( |
| 1205 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 1206 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 1207 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 1208 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 1209 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 1210 | int8_t* output_data) { |
| 1211 | // TBD: Use Kelvin implementation to replace the below |
| 1212 | tflite::reference_integer_ops::DepthwiseConvPerChannel( |
| 1213 | params, output_multiplier, output_shift, input_shape, input_data, |
| 1214 | filter_shape, filter_data, bias_shape, bias_data, output_shape, |
| 1215 | output_data); |
| 1216 | return; |
| 1217 | } |
| 1218 | } // namespace |
| 1219 | |
| 1220 | void DepthwiseConvS8( |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1221 | const tflite::DepthwiseParams& params, const int32_t* output_multiplier, |
| 1222 | const int32_t* output_shift, const tflite::RuntimeShape& input_shape, |
| 1223 | const int8_t* input_data, const tflite::RuntimeShape& filter_shape, |
| 1224 | const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, |
| 1225 | const int32_t* bias_data, const tflite::RuntimeShape& output_shape, |
| 1226 | int8_t* output_data) { |
| 1227 | // Get parameters. |
| 1228 | // TODO(b/141565753): Re-introduce ScopedProfilingLabel on Micro. |
| 1229 | const int stride_width = params.stride_width; |
| 1230 | const int stride_height = params.stride_height; |
Alex Van Damme | cd3d0e3 | 2024-05-10 15:27:06 -0700 | [diff] [blame] | 1231 | const int pad_width = params.padding_values.width; |
| 1232 | const int pad_height = params.padding_values.height; |
Alex Van Damme | b1afda6 | 2024-05-09 16:48:40 -0700 | [diff] [blame] | 1233 | const int filter_height = filter_shape.Dims(1); |
| 1234 | const int filter_width = filter_shape.Dims(2); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1235 | const int dilation_width_factor = params.dilation_width_factor; |
| 1236 | const int dilation_height_factor = params.dilation_height_factor; |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1237 | const int depth_multiplier = params.depth_multiplier; |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1238 | const int32_t output_activation_min = params.quantized_activation_min; |
| 1239 | const int32_t output_activation_max = params.quantized_activation_max; |
| 1240 | |
| 1241 | // Check dimensions of the tensors. |
| 1242 | TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); |
| 1243 | TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); |
| 1244 | TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); |
| 1245 | |
| 1246 | TFLITE_DCHECK_LE(output_activation_min, output_activation_max); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1247 | const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1248 | const int input_depth = input_shape.Dims(3); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1249 | TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); |
| 1250 | TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth); |
| 1251 | |
Alex Van Damme | 13fe02a | 2024-06-06 16:27:54 -0700 | [diff] [blame] | 1252 | #define RUN_KERNEL(kernel) { \ |
| 1253 | kernel(params, output_multiplier, output_shift, input_shape, input_data, \ |
| 1254 | filter_shape, filter_data, bias_shape, bias_data, output_shape, \ |
| 1255 | output_data \ |
| 1256 | ); \ |
| 1257 | return; \ |
| 1258 | } |
| 1259 | |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1260 | if (depth_multiplier == 1 && |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1261 | dilation_height_factor == 1 && dilation_width_factor == 1 && |
Alex Van Damme | 40a8300 | 2024-05-08 16:47:03 -0700 | [diff] [blame] | 1262 | stride_height <= 2 && stride_width <= 2) { |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1263 | // special case of output depth = 32n |
| 1264 | if (output_depth % 32 == 0) { |
Alex Van Damme | b1afda6 | 2024-05-09 16:48:40 -0700 | [diff] [blame] | 1265 | if (filter_width == 5 && filter_height == 5) { |
| 1266 | if (stride_width <= 1 && stride_height <= 1) { |
Alex Van Damme | 13fe02a | 2024-06-06 16:27:54 -0700 | [diff] [blame] | 1267 | RUN_KERNEL(DepthwiseConvS85x5D32_Stride1); |
Alex Van Damme | b1afda6 | 2024-05-09 16:48:40 -0700 | [diff] [blame] | 1268 | } |
Alex Van Damme | 13fe02a | 2024-06-06 16:27:54 -0700 | [diff] [blame] | 1269 | RUN_KERNEL(DepthwiseConvS85x5D32); |
| 1270 | } if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1 && stride_width == 1 && stride_height == 1) { |
| 1271 | RUN_KERNEL(DepthwiseConvS83x3D32_Stride1); |
| 1272 | } if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1) { |
| 1273 | RUN_KERNEL(DepthwiseConvS83x3D32); |
Alex Van Damme | b1afda6 | 2024-05-09 16:48:40 -0700 | [diff] [blame] | 1274 | } |
Alex Van Damme | 13fe02a | 2024-06-06 16:27:54 -0700 | [diff] [blame] | 1275 | RUN_KERNEL(DepthwiseConvS8D32); |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1276 | } |
| 1277 | |
Alex Van Damme | 13fe02a | 2024-06-06 16:27:54 -0700 | [diff] [blame] | 1278 | RUN_KERNEL(DepthwiseConvS8Generic); |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1279 | } |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1280 | |
Alex Van Damme | 13fe02a | 2024-06-06 16:27:54 -0700 | [diff] [blame] | 1281 | RUN_KERNEL(tflite::reference_integer_ops::DepthwiseConvPerChannel); |
| 1282 | |
| 1283 | #undef RUN_KERNEL |
Naveen Dodda | be4ab97 | 2024-04-17 17:47:46 +0000 | [diff] [blame] | 1284 | } |
Lun Dong | 3b8d3cb | 2024-05-07 01:50:35 -0700 | [diff] [blame] | 1285 | |
| 1286 | } // namespace kelvin::opt |