Improve filter padding strategy for conv I3xD8

* Make the padding strategy more readable.
* Make the loop more efficient by reducing if statements and using memcpy.

Change-Id: I337efe180d7a5ea57d0b4b02ae9fa820edf74cef
diff --git a/tflm/opt/conv_s8_I3xD8.cc b/tflm/opt/conv_s8_I3xD8.cc index 370bd97..dd17056 100644 --- a/tflm/opt/conv_s8_I3xD8.cc +++ b/tflm/opt/conv_s8_I3xD8.cc
@@ -26,6 +26,12 @@ #include "tflm/opt/conv_util.h" #include "tflm/opt/opt.h" +namespace { + +constexpr int32_t kBytesPerRegister = 32; +constexpr int8_t kFilterZeroPoint = 0; +} // namespace + namespace kelvin::opt { namespace { @@ -57,65 +63,39 @@ out2[7] = in[7]; } -void PaddedFilter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H, - int W, int M) { - if (M != 3) { - MicroPrintf("Filter shuffling can only handle M(input_depth) == 3"); - exit(-1); - } - - const int8_t(&in)[N][H][W][M] = *(int8_t(*)[N][H][W][M])input; - int8_t(&out)[N / 8][3][8 * 4 * 3] = *(int8_t(*)[N / 8][3][8 * 4 * 3]) output; - int group = 0; +void PaddedFilter(const int8_t* input, int8_t* output, int output_channels) { // Filter data is being reorganized into groups of 8 channels and falttening // row. 9th element of 3x3 filter is padded (9000 9000 9000 9000) 8 channels // are aligned this way ( c0 c1 c2 c3) - for (int ky = 0; ky < H; ++ky) { - int filter_element[N / 8]{0}; - for (int kx = 0; kx < W; ++kx) { - for (int output_channel = 0; output_channel < N; ++output_channel) { - for (int input_channel = 0; input_channel < M; ++input_channel) { - group = output_channel >> 3; - if (kx == 1 && input_channel == 0) { - continue; - } - if (kx == 2 && (input_channel < 2)) { - continue; - } - if (kx == 0 && input_channel == 2) { - out[group][ky][filter_element[group]] = - in[output_channel][ky][kx][input_channel]; - filter_element[group] += 1; - out[group][ky][filter_element[group]] = - in[output_channel][ky][kx + 1][0]; - filter_element[group] += 1; - } else if (kx == 1 && input_channel == 2) { - out[group][ky][filter_element[group]] = - in[output_channel][ky][kx][input_channel]; - filter_element[group] += 1; - out[group][ky][filter_element[group]] = - in[output_channel][ky][kx + 1][0]; - filter_element[group] += 1; - out[group][ky][filter_element[group]] = - in[output_channel][ky][kx + 1][1]; - filter_element[group] += 1; - } else if (kx == 2 && input_channel == 2) { - 
out[group][ky][filter_element[group]] = - in[output_channel][ky][kx][input_channel]; - filter_element[group] += 1; - out[group][ky][filter_element[group]] = 0; - filter_element[group] += 1; - out[group][ky][filter_element[group]] = 0; - filter_element[group] += 1; - out[group][ky][filter_element[group]] = 0; - filter_element[group] += 1; - } else { - out[group][ky][filter_element[group]] = - in[output_channel][ky][kx][input_channel]; - filter_element[group] += 1; - } - } + for (int group = 0; group < output_channels / 8; group++) { + const int8_t* input_group_pointer = input + (group * 8 * 3 * 3 * 3); + int8_t* output_group_pointer = output + (group * 8 * 3 * 3 * 4); + + for (int channel = 0; channel < 8; channel++) { + for (int row = 0; row < 3; row++) { + int out_row_offset = (channel * 4) + (row * 3 * kBytesPerRegister); + + const int8_t* input_c1_offset = + input_group_pointer + (channel * 27) + (9 * row); + int8_t* output_c1_offset = output_group_pointer + out_row_offset; + memcpy(output_c1_offset, input_c1_offset, 4); + + const int8_t* input_c2_offset = + input_group_pointer + (channel * 27) + (9 * row + 4); + int8_t* output_c2_offset = + output_group_pointer + out_row_offset + kBytesPerRegister; + memcpy(output_c2_offset, input_c2_offset, 4); + + const int8_t* input_c3_offset = + input_group_pointer + (channel * 27) + (9 * row + 8); + int8_t* output_c3_offset = + output_group_pointer + out_row_offset + (2 * kBytesPerRegister); + memcpy(output_c3_offset, input_c3_offset, 1); + + *(output_c3_offset + 1) = kFilterZeroPoint; + *(output_c3_offset + 2) = kFilterZeroPoint; + *(output_c3_offset + 3) = kFilterZeroPoint; } } } @@ -216,6 +196,7 @@ const int groups = input_depth / filter_input_depth; TFLITE_DCHECK_NE(groups, 0); TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0); + TFLITE_DCHECK_EQ(filter_input_depth, 3); const int filters_per_group = output_depth / groups; TFLITE_DCHECK_NE(filters_per_group, 0); const int output_height = output_shape.Dims(1); @@ 
-253,9 +234,7 @@ ::aligned_alloc(32, swizzled_filter_data_size))); int8_t* p_swizzled_filter_data = swizzled_filter_data.get(); - PaddedFilter_N_H_W_M(filter_data, p_swizzled_filter_data, output_depth, - filter_height, filter_width, filter_input_depth); - + PaddedFilter(filter_data, p_swizzled_filter_data, output_depth); // structure of padded filter data : 1st row 0-8 channels 0-95 , 2nd row 0-8 // channels 96-191, 3rd row 0-8 channels 192-287