3x3 DepthwiseConv w/ adwconv
- Specialize depthwise convolutions with a 3x3 kernel shape, using the
adwconv/vdwconv instruction set.
Change-Id: Id35ea3e13aa699eb9f0da354a18ba287decc4b11
diff --git a/tflm/opt/depthwise_conv_s8.cc b/tflm/opt/depthwise_conv_s8.cc
index 9e15b3c..a130324 100644
--- a/tflm/opt/depthwise_conv_s8.cc
+++ b/tflm/opt/depthwise_conv_s8.cc
@@ -22,6 +22,7 @@
namespace kelvin::opt {
namespace {
+
// Reorders a vector to match the pattern after double-widening.
// N must be a multiple of 4.
void VectorSwizzle(const int32_t* input, int32_t* output, int N) {
@@ -40,6 +41,171 @@
*out3 = *p_in++;
}
}
+// special case of input depth = 32n, filter shape of 3x3
+void DepthwiseConvS83x3D32(
+ const tflite::DepthwiseParams& params, const int32_t* output_multiplier,
+ const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+ const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+ const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+ const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+ int8_t* output_data
+) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int32_t input_offset = params.input_offset;
+ const int32_t output_offset = params.output_offset;
+ const int32_t output_activation_min = params.quantized_activation_min;
+ const int32_t output_activation_max = params.quantized_activation_max;
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int output_depth = output_shape.Dims(3);
+ int32_t swizzled_bias_data[32];
+ int32_t swizzled_shift_multi[32];
+ int32_t swizzled_output_multi[32];
+
+ for (int in_channel = 0; in_channel + 32 <= input_depth; in_channel += 32) {
+ const int output_channel = in_channel;
+ VectorSwizzle(bias_data + output_channel, swizzled_bias_data, 32);
+ VectorSwizzle(output_multiplier + output_channel, swizzled_output_multi, 32);
+ VectorSwizzle(output_shift + output_channel, swizzled_shift_multi, 32);
+
+ vld_w_x_m(v52, swizzled_bias_data);
+ vld_w_x_m(v56, swizzled_output_multi);
+ vld_w_x_m(v60, swizzled_shift_multi);
+ vrsub_w_vx_m(v60, v60, 0);
+
+ union {
+ vdwconv_u8_t dwconv;
+ uint32_t raw;
+ } cmds;
+ cmds.raw = 0;
+ cmds.dwconv.sdata1 = true;
+ cmds.dwconv.sbias1 = input_offset;
+ cmds.dwconv.sdata2 = true;
+ cmds.dwconv.sbias2 = 0;
+ cmds.dwconv.mode = 0;
+ cmds.dwconv.sparsity = 0;
+ cmds.dwconv.regbase = 0;
+
+ // Don't reorder me, otherwise data will not be
+ // loaded in the correct order
+ // (we can reuse the p_flt* due to the `p` vld variant).
+ const int8_t* p_flt0 = filter_data + in_channel;
+ const int8_t* p_flt1 = p_flt0 + input_depth;
+ const int32_t stride = 2 * input_depth;
+ vld_b_sp_xx(v6, p_flt0, stride);
+ vld_b_sp_xx(v7, p_flt1, stride);
+ vld_b_sp_xx(v8, p_flt0, stride);
+ vld_b_sp_xx(v9, p_flt1, stride);
+ vld_b_sp_xx(v10, p_flt0, stride);
+ vld_b_sp_xx(v11, p_flt1, stride);
+ vld_b_sp_xx(v12, p_flt0, stride);
+ vld_b_sp_xx(v13, p_flt1, stride);
+ vld_b_sp_xx(v14, p_flt0, stride);
+
+ for (int batch = 0; batch < batches; ++batch) {
+ const int8_t* p_output = output_data + (batch * output_width * output_height * output_depth) + output_channel;
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int y_offset = (output_depth * output_width * out_y);
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+
+ // Initialize accumulators w/ bias data.
+ vmv_v_m(v48, v52);
+
+ bool top_pad = in_y_origin < 0;
+ bool left_pad = in_x_origin < 0;
+ bool bottom_pad = (in_y_origin + 2) >= input_height;
+ bool right_pad = (in_x_origin + 2) >= input_width;
+ bool padding_required = top_pad || left_pad || bottom_pad || right_pad;
+ const int8_t* p_in_0 = input_data +
+ (batch * input_height * input_width * input_depth) +
+ (in_y_origin * input_width * input_depth) +
+ (in_x_origin * input_depth) +
+ in_channel;
+ const int8_t* p_in_1 = p_in_0 + (input_width * input_depth);
+ const int8_t* p_in_2 = p_in_1 + (input_width * input_depth);
+ if (!padding_required) {
+ vld_b_sp_xx(v15, p_in_0, input_depth);
+ vld_b_sp_xx(v16, p_in_0, input_depth);
+ vld_b_sp_xx(v17, p_in_0, input_depth);
+ vld_b_sp_xx(v18, p_in_1, input_depth);
+ vld_b_sp_xx(v19, p_in_1, input_depth);
+ vld_b_sp_xx(v20, p_in_1, input_depth);
+ vld_b_sp_xx(v21, p_in_2, input_depth);
+ vld_b_sp_xx(v22, p_in_2, input_depth);
+ vld_b_sp_xx(v23, p_in_2, input_depth);
+ } else {
+ // Top row
+ if (top_pad || left_pad) {
+ vdup_b_x(v15, -input_offset);
+ } else {
+ vld_b_x(v15, p_in_0);
+ }
+ if (top_pad) {
+ vdup_b_x(v16, -input_offset);
+ } else {
+ vld_b_x(v16, p_in_0 + input_depth);
+ }
+ if (top_pad || right_pad) {
+ vdup_b_x(v17, -input_offset);
+ } else {
+ vld_b_x(v17, p_in_0 + (2 * input_depth));
+ }
+ // Middle row
+ if (left_pad) {
+ vdup_b_x(v18, -input_offset);
+ } else {
+ vld_b_x(v18, p_in_1);
+ }
+ vld_b_x(v19, p_in_1 + input_depth);
+ if (right_pad) {
+ vdup_b_x(v20, -input_offset);
+ } else {
+ vld_b_x(v20, p_in_1 + (2 * input_depth));
+ }
+ // Bottom row
+ if (bottom_pad || left_pad) {
+ vdup_b_x(v21, -input_offset);
+ } else {
+ vld_b_x(v21, p_in_2);
+ }
+ if (bottom_pad) {
+ vdup_b_x(v22, -input_offset);
+ } else {
+ vld_b_x(v22, p_in_2 + input_depth);
+ }
+ if (bottom_pad || right_pad) {
+ vdup_b_x(v23, -input_offset);
+ } else {
+ vld_b_x(v23, p_in_2 + (2 * input_depth));
+ }
+ }
+
+ adwinit_v(v48, v48);
+ adwconv_vxv(v48, v15, cmds, v6);
+ adwconv_vxv(v48, v18, cmds, v9);
+ vdwconv_vxv(v48, v21, cmds, v12);
+
+ vdmulh_w_rn_vv_m(v48, v48, v56);
+ vsha_w_r_vv_m(v48, v48, v60);
+ vadd_w_vx_m(v48, v48, output_offset);
+ vmax_w_vx_m(v48, v48, output_activation_min);
+ vmin_w_vx_m(v48, v48, output_activation_max);
+ vsraqs_b_vx(v48, v48, 0);
+ vst_b_x(v48, p_output + (out_x * output_depth) + y_offset);
+ }
+ }
+ }
+ }
+}
// special case of input depth = 32n, filter shape of 5x5, stride == 1
void DepthwiseConvS85x5D32_Stride1(
@@ -628,6 +794,8 @@
// TODO(b/141565753): Re-introduce ScopedProfilingLabel on Micro.
const int stride_width = params.stride_width;
const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
const int filter_height = filter_shape.Dims(1);
const int filter_width = filter_shape.Dims(2);
const int dilation_width_factor = params.dilation_width_factor;
@@ -661,6 +829,8 @@
} else {
fn = DepthwiseConvS85x5D32;
}
+ } else if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1) {
+ fn = DepthwiseConvS83x3D32;
} else {
fn = DepthwiseConvS8D32;
}