Add dilation support for depthwise convolution.
Change-Id: I3413edda1b9679401adab6e3fe0cee1502918e8e
diff --git a/tflm/opt/depthwise_conv_s16.cc b/tflm/opt/depthwise_conv_s16.cc
index e05ef13..0d04e2a 100644
--- a/tflm/opt/depthwise_conv_s16.cc
+++ b/tflm/opt/depthwise_conv_s16.cc
@@ -13,7 +13,7 @@
void DepthwiseConv2DKelvinS16K3x1(const int16_t* activations,
const int8_t* weights,
const int64_t* biases,
- int channels, int frames,
+ int channels, int frames, int dilation,
const int32_t* output_mult,
const int32_t* output_shift,
int32_t output_activation_min,
@@ -45,58 +45,59 @@
vzip_w_vv(v52, v4, v5);
vzip_w_vv(v54, v6, v7);
- // Accumulators will be [v48 - v51].
- const int16_t* local_activations0 = activations + c;
- const int16_t* local_activations1 = local_activations0 + 16;
- int16_t* local_output = output + c;
-
- // Registers [v0-v5 will be for loading activations]
- // Preload for valid padding:
- vld_h_p_xx(v0, local_activations0, channels);
- vld_h_p_xx(v1, local_activations1, channels);
- vld_h_p_xx(v2, local_activations0, channels);
- vld_h_p_xx(v3, local_activations1, channels);
- int frames_left = frames - 2;
-
+ const int32_t step = dilation * channels;
const int32_t* local_output_mult = output_mult + c;
const int32_t* local_output_shift = output_shift + c;
+ for (int d = 0; d < dilation; d++) {
+ // Accumulators will be [v48 - v51].
+ const int16_t* local_activations0 = activations + (d * channels) + c;
+ const int16_t* local_activations1 = local_activations0 + 16;
+ int16_t* local_output = output + (d * channels) + c;
- int32_t accumulators[32];
- while (frames_left > 0) {
- vld_h_p_xx(v4, local_activations0, channels);
- vld_h_p_xx(v5, local_activations1, channels);
- vmulw_w_vv(v48, v58, v0); // Clobber accumulator
- vmulw_w_vv(v50, v59, v1); // Clobber accumulator
- vadd_w_vv_m(v48, v48, v52); // Add bias.
- vmulw_w_vv(v40, v60, v2);
- vmulw_w_vv(v42, v61, v3);
- vadd_w_vv_m(v48, v48, v40);
- vmulw_w_vv(v44, v62, v4);
- vmulw_w_vv(v46, v63, v5);
- vadd_w_vv_m(v48, v48, v44);
+ // Registers [v0 - v5] will be used for loading activations.
+ // Preload for valid padding:
+ vld_h_p_xx(v0, local_activations0, step);
+ vld_h_p_xx(v1, local_activations1, step);
+ vld_h_p_xx(v2, local_activations0, step);
+ vld_h_p_xx(v3, local_activations1, step);
- vzip_w_vv(v48, v48, v49); // Swizzle accumulators
- vzip_w_vv(v50, v50, v51);
+ int frames_idx = (2 * dilation) + d;
+ int32_t accumulators[32];
+ for (; frames_idx < frames; frames_idx += dilation) {
+ vld_h_p_xx(v4, local_activations0, step);
+ vld_h_p_xx(v5, local_activations1, step);
+ vmulw_w_vv(v48, v58, v0); // Clobber accumulator
+ vmulw_w_vv(v50, v59, v1); // Clobber accumulator
+ vadd_w_vv_m(v48, v48, v52); // Add bias.
+ vmulw_w_vv(v40, v60, v2);
+ vmulw_w_vv(v42, v61, v3);
+ vadd_w_vv_m(v48, v48, v40);
+ vmulw_w_vv(v44, v62, v4);
+ vmulw_w_vv(v46, v63, v5);
+ vadd_w_vv_m(v48, v48, v44);
- vst_w_x_m(v48, accumulators); // Store accumulators
+ vzip_w_vv(v48, v48, v49); // Swizzle accumulators
+ vzip_w_vv(v50, v50, v51);
- // Output pipeline in scalar, to preserve bit accuracy with the ARM CPU
- // implementation.
- for (int i = 0; i < 32; i++) {
- int32_t result = tflite::MultiplyByQuantizedMultiplier(
- static_cast<int64_t>(accumulators[i]), local_output_mult[i],
- local_output_shift[i]);
+ vst_w_x_m(v48, accumulators); // Store accumulators
- local_output[i] = static_cast<int16_t>(
- std::clamp(result, output_activation_min, output_activation_max));
+ // Output pipeline in scalar, to preserve bit accuracy with the ARM CPU
+ // implementation.
+ for (int i = 0; i < 32; i++) {
+ int32_t result = tflite::MultiplyByQuantizedMultiplier(
+ static_cast<int64_t>(accumulators[i]), local_output_mult[i],
+ local_output_shift[i]);
+
+ local_output[i] = static_cast<int16_t>(
+ std::clamp(result, output_activation_min, output_activation_max));
+ }
+
+ // Slide registers
+ vmvp_vv(v0, v2, v3);
+ vmvp_vv(v2, v4, v5);
+
+ local_output += step;
}
-
- // Slide registers
- vmvp_vv(v0, v2, v3);
- vmvp_vv(v2, v4, v5);
-
- local_output += channels;
- frames_left--;
}
}
// TODO(derekjchow): Handle channels % 32 cases.
@@ -105,4 +106,4 @@
// - one final loop handling remainder
}
-} // namespace kelvin::opt
\ No newline at end of file
+} // namespace kelvin::opt
diff --git a/tflm/opt/opt.h b/tflm/opt/opt.h
index f12596c..c721803 100644
--- a/tflm/opt/opt.h
+++ b/tflm/opt/opt.h
@@ -74,7 +74,7 @@
int8_t* output_data);
void DepthwiseConv2DKelvinS16K3x1(
const int16_t* activations, const int8_t* weights, const int64_t* biases,
- int channels, int frames, const int32_t* output_mult,
+ int channels, int frames, int dilation, const int32_t* output_mult,
const int32_t* output_shift, int32_t output_activation_min,
int32_t output_activation_max, int16_t* output);
} // namespace kelvin::opt