Improved strategy for filter padding of conv I3xD8
* Make the padding strategy more readable.
* Make the loop more efficient by reducing if statements and using memcpy.
Change-Id: I337efe180d7a5ea57d0b4b02ae9fa820edf74cef
diff --git a/tflm/opt/conv_s8_I3xD8.cc b/tflm/opt/conv_s8_I3xD8.cc
index 370bd97..dd17056 100644
--- a/tflm/opt/conv_s8_I3xD8.cc
+++ b/tflm/opt/conv_s8_I3xD8.cc
@@ -26,6 +26,12 @@
#include "tflm/opt/conv_util.h"
#include "tflm/opt/opt.h"
+namespace {
+
+constexpr int32_t kBytesPerRegister = 32;
+constexpr int8_t kFilterZeroPoint = 0;
+} // namespace
+
namespace kelvin::opt {
namespace {
@@ -57,65 +63,39 @@
out2[7] = in[7];
}
-void PaddedFilter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H,
- int W, int M) {
- if (M != 3) {
- MicroPrintf("Filter shuffling can only handle M(input_depth) == 3");
- exit(-1);
- }
-
- const int8_t(&in)[N][H][W][M] = *(int8_t(*)[N][H][W][M])input;
- int8_t(&out)[N / 8][3][8 * 4 * 3] = *(int8_t(*)[N / 8][3][8 * 4 * 3]) output;
- int group = 0;
+void PaddedFilter(const int8_t* input, int8_t* output, int output_channels) {
// Filter data is being reorganized into groups of 8 channels and falttening
// row. 9th element of 3x3 filter is padded (9000 9000 9000 9000) 8 channels
// are aligned this way ( c0 c1 c2 c3)
- for (int ky = 0; ky < H; ++ky) {
- int filter_element[N / 8]{0};
- for (int kx = 0; kx < W; ++kx) {
- for (int output_channel = 0; output_channel < N; ++output_channel) {
- for (int input_channel = 0; input_channel < M; ++input_channel) {
- group = output_channel >> 3;
- if (kx == 1 && input_channel == 0) {
- continue;
- }
- if (kx == 2 && (input_channel < 2)) {
- continue;
- }
- if (kx == 0 && input_channel == 2) {
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx][input_channel];
- filter_element[group] += 1;
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx + 1][0];
- filter_element[group] += 1;
- } else if (kx == 1 && input_channel == 2) {
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx][input_channel];
- filter_element[group] += 1;
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx + 1][0];
- filter_element[group] += 1;
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx + 1][1];
- filter_element[group] += 1;
- } else if (kx == 2 && input_channel == 2) {
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx][input_channel];
- filter_element[group] += 1;
- out[group][ky][filter_element[group]] = 0;
- filter_element[group] += 1;
- out[group][ky][filter_element[group]] = 0;
- filter_element[group] += 1;
- out[group][ky][filter_element[group]] = 0;
- filter_element[group] += 1;
- } else {
- out[group][ky][filter_element[group]] =
- in[output_channel][ky][kx][input_channel];
- filter_element[group] += 1;
- }
- }
+ for (int group = 0; group < output_channels / 8; group++) {
+ const int8_t* input_group_pointer = input + (group * 8 * 3 * 3 * 3);
+ int8_t* output_group_pointer = output + (group * 8 * 3 * 3 * 4);
+
+ for (int channel = 0; channel < 8; channel++) {
+ for (int row = 0; row < 3; row++) {
+ int out_row_offset = (channel * 4) + (row * 3 * kBytesPerRegister);
+
+ const int8_t* input_c1_offset =
+ input_group_pointer + (channel * 27) + (9 * row);
+ int8_t* output_c1_offset = output_group_pointer + out_row_offset;
+ memcpy(output_c1_offset, input_c1_offset, 4);
+
+ const int8_t* input_c2_offset =
+ input_group_pointer + (channel * 27) + (9 * row + 4);
+ int8_t* output_c2_offset =
+ output_group_pointer + out_row_offset + kBytesPerRegister;
+ memcpy(output_c2_offset, input_c2_offset, 4);
+
+ const int8_t* input_c3_offset =
+ input_group_pointer + (channel * 27) + (9 * row + 8);
+ int8_t* output_c3_offset =
+ output_group_pointer + out_row_offset + (2 * kBytesPerRegister);
+ memcpy(output_c3_offset, input_c3_offset, 1);
+
+ *(output_c3_offset + 1) = kFilterZeroPoint;
+ *(output_c3_offset + 2) = kFilterZeroPoint;
+ *(output_c3_offset + 3) = kFilterZeroPoint;
}
}
}
@@ -216,6 +196,7 @@
const int groups = input_depth / filter_input_depth;
TFLITE_DCHECK_NE(groups, 0);
TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ TFLITE_DCHECK_EQ(filter_input_depth, 3);
const int filters_per_group = output_depth / groups;
TFLITE_DCHECK_NE(filters_per_group, 0);
const int output_height = output_shape.Dims(1);
@@ -253,9 +234,7 @@
::aligned_alloc(32, swizzled_filter_data_size)));
int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
- PaddedFilter_N_H_W_M(filter_data, p_swizzled_filter_data, output_depth,
- filter_height, filter_width, filter_input_depth);
-
+ PaddedFilter(filter_data, p_swizzled_filter_data, output_depth);
// structure of padded filter data : 1st row 0-8 channels 0-95 , 2nd row 0-8
// channels 96-191, 3rd row 0-8 channels 192-287