/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Depthwise convolution based on Kelvin ops
// Data types: input: s16, filter: s8, bias s64

#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
#include "tflm/opt/conv_util.h"

namespace kelvin::opt {
namespace {
// Specialized depthwise conv for a 1x3 (HxW) filter on s16 activations,
// s8 weights and s64 biases, written against the Kelvin vector ISA
// (v0-v63 register mnemonics). Only full stripes of 32 channels are
// processed (see the channel loop condition and the trailing TODO), so the
// caller must ensure the channel count is a multiple of 32. Dilation > 1 is
// handled by computing each dilation phase independently.
void DepthwiseConvS16K3x1(
    const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
    int16_t* output_data) {
  const int16_t* activations = input_data;
  const int8_t* weights = filter_data;
  const int64_t* biases = bias_data;
  // Channel count comes from the filter's last dimension; frames from input
  // dim 2 — presumably the W axis of an NHWC shape with H==1 (TODO confirm
  // against the dispatcher's 1x3-filter precondition).
  int channels = filter_shape.Dims(3);
  int frames = input_shape.Dims(2);
  int dilation = params.dilation_width_factor;
  const int32_t* output_mult = output_multiplier;
  int32_t output_activation_min = params.quantized_activation_min;
  int32_t output_activation_max = params.quantized_activation_max;
  int16_t* output = output_data;

  // Process one stripe of 32 channels per iteration; channels % 32 tail is
  // intentionally unhandled (see TODO at the end).
  for (int c = 0; c + 32 <= channels; c += 32) {
    // Load weights and interleave into correct order [v58-v63].
    // Because there are more activations than weights, interleave weights.
    // Each of the three filter taps is loaded as s8, widened to s16
    // (vaddw with 0), then zipped so tap k lands in v58+2k/v59+2k.
    const int8_t* local_weights0 = weights + c;
    vld_b_p_xx(v0, local_weights0, channels);
    vaddw_h_vx(v48, v0, 0);
    vzip_h_vv(v58, v48, v49);

    vld_b_p_xx(v1, local_weights0, channels);
    vaddw_h_vx(v50, v1, 0);
    vzip_h_vv(v60, v50, v51);

    vld_b_x(v2, local_weights0);
    vaddw_h_vx(v52, v2, 0);
    vzip_h_vv(v62, v52, v53);

    // Assume biases fit in 32-bit. This assumption is verified offline.
    // Load biases and swizzle [v52-v55].
    int32_t local_biases[32];
    for (int j = 0; j < 32; j++) {
      local_biases[j] = static_cast<int32_t>(biases[c + j]);
    }
    vld_w_x_m(v4, local_biases);
    vzip_w_vv(v52, v4, v5);
    vzip_w_vv(v54, v6, v7);

    // With dilation, consecutive taps of one output are `step` elements
    // apart in memory; each dilation phase d is an independent 1-D conv.
    const int32_t step = dilation * channels;
    const int32_t* local_output_mult = output_mult + c;
    const int32_t* local_output_shift = output_shift + c;
    for (int d = 0; d < dilation; d++) {
      // Accumulators will be [v48 - v51].
      const int16_t* local_activations0 = activations + (d * channels) + c;
      const int16_t* local_activations1 = local_activations0 + 16;
      int16_t* local_output = output + (d * channels) + c;

      // Registers [v0-v5 will be for loading activations]
      // Preload for valid padding:
      // v0/v1 = frame 0 (low/high 16 channels), v2/v3 = frame 1. The main
      // loop loads frame 2 into v4/v5, so the first output uses taps 0..2.
      vld_h_p_xx(v0, local_activations0, step);
      vld_h_p_xx(v1, local_activations1, step);
      vld_h_p_xx(v2, local_activations0, step);
      vld_h_p_xx(v3, local_activations1, step);

      int frames_idx = (2 * dilation) + d;
      int32_t accumulators[32];
      for (; frames_idx < frames; frames_idx += dilation) {
        vld_h_p_xx(v4, local_activations0, step);
        vld_h_p_xx(v5, local_activations1, step);
        vmulw_w_vv(v48, v58, v0);  // Clobber accumulator
        vmulw_w_vv(v50, v59, v1);  // Clobber accumulator
        vadd_w_vv_m(v48, v48, v52);  // Add bias.
        vmulw_w_vv(v40, v60, v2);
        vmulw_w_vv(v42, v61, v3);
        vadd_w_vv_m(v48, v48, v40);
        vmulw_w_vv(v44, v62, v4);
        vmulw_w_vv(v46, v63, v5);
        vadd_w_vv_m(v48, v48, v44);

        vzip_w_vv(v48, v48, v49);  // Swizzle accumulators
        vzip_w_vv(v50, v50, v51);

        vst_w_x_m(v48, accumulators);  // Store accumulators

        // Output pipeline in scalar, to preserve bit accuracy with the ARM CPU
        // implementation.
        for (int i = 0; i < 32; i++) {
          int32_t result = tflite::MultiplyByQuantizedMultiplier(
              static_cast<int64_t>(accumulators[i]), local_output_mult[i],
              local_output_shift[i]);

          local_output[i] = static_cast<int16_t>(
              std::clamp(result, output_activation_min, output_activation_max));
        }

        // Slide registers
        // Shift the 3-frame activation window left by one frame so only one
        // new frame (v4/v5) needs loading per iteration.
        vmvp_vv(v0, v2, v3);
        vmvp_vv(v2, v4, v5);

        local_output += step;
      }
    }
  }
  // TODO(derekjchow): Handle channels % 32 cases.
  // Break it down into:
  // - one loop looking for 16 byte stripes
  // - one final loop handling remainder
}
129
Lun Dong3b8d3cb2024-05-07 01:50:35 -0700130// generic implementation based on Kelvin ops
131void DepthwiseConvS16Generic(
132 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
133 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
134 const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
135 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
136 const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
137 int16_t* output_data) {
138 // TBD: Use Kelvin implementation to replace the below
139 tflite::reference_integer_ops::DepthwiseConvPerChannel(
140 params, output_multiplier, output_shift, input_shape, input_data,
141 filter_shape, filter_data, bias_shape, bias_data, output_shape,
142 output_data);
143 return;
144}
}  // namespace

147void DepthwiseConvS16(
148 const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
149 const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
150 const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
151 const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
152 const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
153 int16_t* output_data) {
154 // Get parameters.
155 const int stride_width = params.stride_width;
156 const int stride_height = params.stride_height;
157 const int dilation_width_factor = params.dilation_width_factor;
158 const int dilation_height_factor = params.dilation_height_factor;
159 const int filter_height = filter_shape.Dims(1);
160 const int filter_width = filter_shape.Dims(2);
161
162 if (params.padding_type == tflite::PaddingType::kValid && stride_width == 1 &&
163 stride_height == 1 && dilation_width_factor == 1 &&
164 dilation_height_factor == 1) {
165 // generic implementation by default
166 auto fn = DepthwiseConvS16Generic;
167
168 // special case of filter size 3x1
169 if (filter_height == 1 && filter_width == 3) {
170 fn = DepthwiseConvS16K3x1;
171 }
172
173 fn(params, output_multiplier, output_shift, input_shape, input_data,
174 filter_shape, filter_data, bias_shape, bias_data, output_shape,
175 output_data);
176 return;
177 }
178
179 // Use reference implementation
180 tflite::reference_integer_ops::DepthwiseConvPerChannel(
181 params, output_multiplier, output_shift, input_shape, input_data,
182 filter_shape, filter_data, bias_shape, bias_data, output_shape,
183 output_data);
184}

}  // namespace kelvin::opt