Fix swizzling and rounding for depthwise conv
- DepthwiseConv requires a different swizzling method than the original
copied from Conv2D.
- Use the _rn rounding mode for vdmulh, to match the reference kernel
behaviour.
Change-Id: Iad1a75258d64ae968123559d7130dbc250902c89
diff --git a/tflm/opt/depthwise_conv_s8.cc b/tflm/opt/depthwise_conv_s8.cc
index c94b38a..111fdc1 100644
--- a/tflm/opt/depthwise_conv_s8.cc
+++ b/tflm/opt/depthwise_conv_s8.cc
@@ -22,6 +22,25 @@
namespace kelvin::opt {
namespace {
+// Reorders a vector to match the pattern after double-widening.
+// N must be a multiple of 4.
+void VectorSwizzle(const int32_t* input, int32_t* output, int N) {
+ assert(N >= 4 && N % 4 == 0);
+ const int32_t(&in)[N] = *(int32_t(*)[N])input;
+ int32_t(&out)[N] = *(int32_t(*)[N]) output;
+ const int32_t* p_in = in;
+ for (int i = 0; i < N / 4; ++i) {
+ int32_t* out0 = out + i + 0;
+ int32_t* out1 = out + i + 16;
+ int32_t* out2 = out + i + 8;
+ int32_t* out3 = out + i + 24;
+ *out0 = *p_in++;
+ *out1 = *p_in++;
+ *out2 = *p_in++;
+ *out3 = *p_in++;
+ }
+ }
+
// special case of input depth = 32n
void DepthwiseConvS8D32(
const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
@@ -48,15 +67,15 @@
const int filter_width = filter_shape.Dims(2);
const int output_height = output_shape.Dims(1);
const int output_width = output_shape.Dims(2);
- int32_t swizzled_bias_data[32 * 4];
- int32_t swizzled_shift_multi[32 * 4];
- int32_t swizzled_output_multi[32 * 4];
+ int32_t swizzled_bias_data[32];
+ int32_t swizzled_shift_multi[32];
+ int32_t swizzled_output_multi[32];
for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
const int output_channel = in_channel;
- Swizzle(bias_data + output_channel, swizzled_bias_data, 32);
- Swizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
- Swizzle(output_shift + output_channel, swizzled_shift_multi, 32);
+ VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
+ VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
+ VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
vld_w_x_m(v20, swizzled_bias_data);
vld_w_x_m(v24, swizzled_output_multi);
@@ -69,6 +88,7 @@
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
+ vdup_w_x_m(v48, 0);
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y = in_y_origin + filter_y;
if ((in_y < 0) || (in_y >= input_height)) {
@@ -99,7 +119,7 @@
}
vadd_w_vv_m(v48, v48, v20); // add bias
- vdmulh_w_r_vv_m(v48, v48, v24);
+ vdmulh_w_rn_vv_m(v48, v48, v24);
vsha_w_r_vv_m(v48, v48, v28);
vadd_w_vx_m(v48, v48, output_offset);
vmax_w_vx_m(v48, v48, output_activation_min);
@@ -143,8 +163,6 @@
const int stride_height = params.stride_height;
const int dilation_width_factor = params.dilation_width_factor;
const int dilation_height_factor = params.dilation_height_factor;
- const int pad_width = params.padding_values.width;
- const int pad_height = params.padding_values.height;
const int depth_multiplier = params.depth_multiplier;
const int32_t output_activation_min = params.quantized_activation_min;
const int32_t output_activation_max = params.quantized_activation_max;
@@ -160,9 +178,9 @@
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
- if (depth_multiplier == 1 && pad_height < 2 && pad_width < 2 &&
+ if (depth_multiplier == 1 &&
dilation_height_factor == 1 && dilation_width_factor == 1 &&
- stride_height == 1 && stride_width == 1) {
+ stride_height <= 2 && stride_width <= 2) {
// generic implementation by default
auto fn = DepthwiseConvS8Generic;