Improve ConvS8D4

- Remove tflite::Offset address computations in favor of precomputed,
  incrementally updated pointers
- Lift loop-invariant offset calculations out of the inner loops
- Add loop-unroll pragmas to the fixed-width 8-pixel loops
- Remove conditional branching from the output stage (see the sketch
  below)
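
The output stage previously chose between vstq_b_s_xx and per-pixel
vst_b_l_xx stores through nested conditionals. It now rotates the two
packed result registers one pixel at a time and stores through a
fall-through switch, so the store count is selected by a single
computed jump rather than data-dependent branches. A minimal sketch of
the pattern in plain C++ (illustrative only: store_pixel() is a
hypothetical stand-in for the vst_b_l_xx intrinsic, and the register
rotation is elided):

  #include <cstdint>
  #include <cstring>

  // Stand-in for vst_b_l_xx: store one pixel's `depth` output
  // channels. Assumes depth <= 8 (one packed lane per pixel).
  static void store_pixel(const int8_t* lane, int8_t* dst, int depth) {
    std::memcpy(dst, lane, depth);
  }

  // Store the first n (1..8) pixels of a tile. Each case stores one
  // pixel and falls through to the next lower case, so exactly n
  // pixels are written with no per-pixel branch.
  void store_tile(const int8_t lanes[8][8], int8_t* out, int depth, int n) {
    switch (n) {
      case 8: store_pixel(lanes[7], out + 7 * depth, depth); [[fallthrough]];
      case 7: store_pixel(lanes[6], out + 6 * depth, depth); [[fallthrough]];
      case 6: store_pixel(lanes[5], out + 5 * depth, depth); [[fallthrough]];
      case 5: store_pixel(lanes[4], out + 4 * depth, depth); [[fallthrough]];
      case 4: store_pixel(lanes[3], out + 3 * depth, depth); [[fallthrough]];
      case 3: store_pixel(lanes[2], out + 2 * depth, depth); [[fallthrough]];
      case 2: store_pixel(lanes[1], out + 1 * depth, depth); [[fallthrough]];
      case 1: store_pixel(lanes[0], out, depth);
    }
  }

In the kernel the rotation is done with vslidep_h_4_vv and the stores
with vst_b_l_xx, but the control structure is the same.
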
Change-Id: Iabf35537ef3cdba3da4375d8b7473fecd1fa38f2
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc
index 18df3e7..54c20bd 100644
--- a/tflm/opt/conv_s8_d4.cc
+++ b/tflm/opt/conv_s8_d4.cc
@@ -27,6 +27,9 @@
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
#include "tensorflow/lite/kernels/internal/types.h"
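+
+// Branch-prediction hints: input-edge padding is the uncommon case in
+// the hot loops below.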
+#define unlikely(x) (__builtin_expect(false || (x), false))
+#define likely(x) (__builtin_expect(false || (x), true))
+
namespace kelvin::opt {
namespace {
@@ -159,8 +162,12 @@
vrsub_w_vx_m(v24, v24, 0);
for (int batch = 0; batch < batches; ++batch) {
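+          // Output base for this batch, offset to the current output
+          // channel group; all stores below index off this pointer.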
+          int8_t* p_output =
+ output_data + (batch * output_height * output_width * output_depth) +
+ out_channel;
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int out_y_offset = (out_y * output_width * output_depth);
int out_x = 0;
do {
int out_xs_this_iter = std::min(8, output_width - out_x);
@@ -171,55 +178,59 @@
int in_channel = 0;
do {
int in_channels_this_iter = std::min(filter_input_depth, 32);
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
- const int in_y = in_y_origin + dilation_height_factor * filter_y;
- const bool is_row_inside_input =
- (in_y >= 0) && (in_y < input_height);
- if (!is_row_inside_input) {
- continue;
+              // Skip filter rows that fall entirely above the input (top
+              // padding): find the first filter_y whose in_y is >= 0.
+              int filter_y = 0;
+              {
+                int in_y = in_y_origin;
+                while (in_y < 0) {
+                  ++filter_y;
+                  in_y += dilation_height_factor;
}
+ }
+ for (; filter_y < filter_height; ++filter_y) {
+ const int y_filter_offset =
+ (filter_y * filter_width * 8 * input_depth);
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
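+                // All remaining filter rows fall below the input (bottom
+                // padding), so stop instead of continuing.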
+ if (in_y >= input_height) {
+ break;
+ }
+ const int8_t* p_in =
+ input_data + in_channel + (in_y * input_width * input_depth) +
+ (batch * input_height * input_width * input_depth);
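+              // Input x coordinate for each of the 8 output pixels at
+              // filter_x == 0; advanced by dilation_width_factor per tap.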
+ int in_x[8];
+#pragma GCC unroll 8
+ for (int i = 0; i < 8; ++i) {
+ in_x[i] = ((out_x + i) * stride_width) - pad_width;
+ }
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- int in_x[8];
- bool right_pad = false;
+ const int8_t* p_in_x[8];
int first_right_pad = -1;
+
+#pragma GCC unroll 8
for (int i = 0; i < 8; ++i) {
- const int in_x_origin =
- ((out_x + i) * stride_width) - pad_width;
- in_x[i] = in_x_origin + dilation_width_factor * filter_x;
+ p_in_x[i] = p_in + (in_x[i] * input_depth);
}
- bool left_pad = (in_x[0] < 0);
+
+#pragma GCC unroll 8
for (int i = 7; i >= 0; --i) {
if (in_x[i] < input_width) {
break;
}
- right_pad = true;
first_right_pad = i;
}
+ bool left_pad = (in_x[0] < 0);
+ bool right_pad = (first_right_pad != -1);
- if (left_pad) {
+                  const int stride = input_depth * stride_width;
+
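+                  // Dispatch on edge padding: lane 0 off the left edge,
+                  // trailing lanes off the right edge, fully interior, or
+                  // padded on both sides.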
+ if (unlikely(left_pad)) {
-                    vdup_b_x(v0, -input_offset);
+                    vdup_b_x(v0, neg_input_offset);
- vld_b_s_xx(
- v1,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[1], in_channel)],
- input_depth * stride_width);
- vld_b_s_xx(
- v2,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[2], in_channel)],
- input_depth * stride_width);
- vld_b_s_xx(
- v3,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[3], in_channel)],
- input_depth * stride_width);
- vld_b_s_xx_m(
- v4,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[4], in_channel)],
- input_depth * stride_width);
- } else if (right_pad) {
+ vld_b_s_xx(v1, p_in_x[1], stride);
+ vld_b_s_xx(v2, p_in_x[2], stride);
+ vld_b_s_xx(v3, p_in_x[3], stride);
+ vld_b_s_xx_m(v4, p_in_x[4], stride);
+ } else if (unlikely(right_pad)) {
int first_pad = std::min(first_right_pad, out_xs_this_iter);
switch (first_pad) {
case 0:
@@ -241,88 +252,36 @@
}
switch (8 - first_pad) { // rest (stripmines?)
case 0:
- vld_b_s_xx(
- v7,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[7], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v7, p_in_x[7], stride);
case 1:
- vld_b_s_xx(
- v6,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[6], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v6, p_in_x[6], stride);
case 2:
- vld_b_s_xx(
- v5,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[5], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v5, p_in_x[5], stride);
case 3:
- vld_b_s_xx(
- v4,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[4], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v4, p_in_x[4], stride);
case 4:
- vld_b_s_xx(
- v3,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[3], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v3, p_in_x[3], stride);
case 5:
- vld_b_s_xx(
- v2,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[2], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v2, p_in_x[2], stride);
case 6:
- vld_b_s_xx(
- v1,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[1], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v1, p_in_x[1], stride);
case 7:
- vld_b_s_xx(
- v0,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[0], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx(v0, p_in_x[0], stride);
}
- } else if (!left_pad && !right_pad) {
+ } else if (likely(!left_pad && !right_pad)) {
// Inputs
- vld_b_s_xx_m(
- v0,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[0], in_channel)],
- input_depth * stride_width);
- vld_b_s_xx_m(
- v4,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[4], in_channel)],
- input_depth * stride_width);
+ vld_b_s_xx_m(v0, p_in_x[0], stride);
+ vld_b_s_xx_m(v4, p_in_x[4], stride);
} else {
- vdup_b_x(v0, -input_offset);
- vdup_b_x(v7, -input_offset);
- vld_b_s_xx_m(
- v1,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[1], in_channel)],
- input_depth * stride_width);
- vld_b_s_xx(
- v5,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[5], in_channel)],
- input_depth * stride_width);
- vld_b_s_xx(
- v6,
- &input_data[tflite::Offset(input_shape, batch, in_y,
- in_x[6], in_channel)],
- input_depth * stride_width);
+ vdup_b_x(v0, neg_input_offset);
+ vdup_b_x(v7, neg_input_offset);
+ vld_b_s_xx_m(v1, p_in_x[1], stride);
+ vld_b_s_xx(v5, p_in_x[5], stride);
+ vld_b_s_xx(v6, p_in_x[6], stride);
}
- size_t local_filter_offset =
- (filter_y * filter_width * 8 * input_depth) +
- (filter_x * 8 * input_depth) + (in_channel * 8);
+ size_t local_filter_offset = y_filter_offset +
+ (filter_x * 8 * input_depth) +
+ (in_channel * 8);
int8_t* p_local_filter_start =
p_swizzled_filter_data + local_filter_offset;
vld_b_p_x_m(v8, p_local_filter_start);
@@ -330,6 +289,11 @@
cmds.conv.stop = (in_channels_this_iter / 4) - 1;
aconv_vxv(v48, v0, cmds, v8);
+
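+                  // Advance the input x coordinates to the next filter tap.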
+#pragma GCC unroll 8
+ for (int i = 0; i < 8; ++i) {
+ in_x[i] += dilation_width_factor;
+ }
}
}
in_channel += in_channels_this_iter;
@@ -349,64 +313,38 @@
vmax_w_vx_m(v52, v52, output_activation_min);
vsraqs_b_vx(v56, v48, 0);
vsraqs_b_vx(v57, v52, 0);
- if (out_channels_this_iter == 8) {
- if (out_xs_this_iter >= 4) {
- vstq_b_s_xx(v56,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x, out_channel)],
- output_depth);
- } else {
- for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
- if (i > 0) {
- vsliden_b_4_vv(v58, v56, v0);
- vsliden_b_4_vv(v56, v58, v0);
- }
- vst_b_l_xx(v56,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + i, out_channel)],
- out_channels_this_iter);
- }
- }
- if (out_xs_this_iter == 8) {
- vstq_b_s_xx(v57,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + 4, out_channel)],
- output_depth);
- } else if (out_xs_this_iter > 4) {
- for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
- if (i > 4) {
- vsliden_b_4_vv(v58, v57, v0);
- vsliden_b_4_vv(v57, v58, v0);
- }
- vst_b_l_xx(v57,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + i, out_channel)],
- out_channels_this_iter);
- }
- }
- } else {
- for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
- if (i > 0) {
- vsliden_b_4_vv(v58, v56, v0);
- vsliden_b_4_vv(v56, v58, v0);
- }
- vst_b_l_xx(v56,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + i, out_channel)],
- out_channels_this_iter);
- }
- if (out_xs_this_iter > 4) {
- for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
- if (i > 4) {
- vsliden_b_4_vv(v58, v57, v0);
- vsliden_b_4_vv(v57, v58, v0);
- }
- vst_b_l_xx(v57,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + i, out_channel)],
- out_channels_this_iter);
- }
- }
+
+                int8_t* p_out_x[8];
+#pragma GCC unroll 8
+ for (int i = 0; i < 8; ++i) {
+ p_out_x[i] = p_output + out_y_offset + ((out_x + i) * output_depth);
+ }
+
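+                // Rotate the packed results one pixel at a time so that
+                // each output pixel lands in the low lane of its own
+                // register (v58..v56 below).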
+ vslidep_h_4_vv(v58, v57, v57); // x7
+ vslidep_h_4_vv(v59, v58, v58); // x6
+ vslidep_h_4_vv(v60, v59, v59); // x5
+ vslidep_h_4_vv(v61, v60, v60); // x4
+ vslidep_h_4_vv(v62, v56, v56); // x3
+ vslidep_h_4_vv(v63, v62, v62); // x2
+ vslidep_h_4_vv(v57, v63, v63); // x1
+ vslidep_h_4_vv(v56, v57, v57); // x0
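+                // Each case stores one pixel and falls through to the
+                // next, so exactly out_xs_this_iter pixels are written
+                // with no per-pixel branching.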
+ switch (out_xs_this_iter) {
+ case 8:
+ vst_b_l_xx(v58, p_out_x[7], out_channels_this_iter);
+ case 7:
+ vst_b_l_xx(v59, p_out_x[6], out_channels_this_iter);
+ case 6:
+ vst_b_l_xx(v60, p_out_x[5], out_channels_this_iter);
+ case 5:
+ vst_b_l_xx(v61, p_out_x[4], out_channels_this_iter);
+ case 4:
+ vst_b_l_xx(v62, p_out_x[3], out_channels_this_iter);
+ case 3:
+ vst_b_l_xx(v63, p_out_x[2], out_channels_this_iter);
+ case 2:
+ vst_b_l_xx(v57, p_out_x[1], out_channels_this_iter);
+ case 1:
+ vst_b_l_xx(v56, p_out_x[0], out_channels_this_iter);
}
out_x += out_xs_this_iter;
} while (out_x < output_width);