Relax output_depth for ConvS8D4
- ConvS8D4 can now handle output depths that are not a multiple of 8.
Change-Id: Ibfc445e09f392453ddb8c038a5b224b8bcb33c4e
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 2d49dbc..ce56fda 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -198,7 +198,7 @@
// special case of filter_depth = 4n
if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
stride_width <= 2 && stride_height <= 2 && filter_depth % 4 == 0 &&
- output_depth % 8 == 0 && output_width >= 8 && pad_width <= 1) {
+ output_depth >= 8 && output_width >= 8 && pad_width <= 1) {
fn = kelvin::opt::ConvS8D4;
}
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc
index 0dd3e50..18df3e7 100644
--- a/tflm/opt/conv_s8_d4.cc
+++ b/tflm/opt/conv_s8_d4.cc
@@ -29,11 +29,12 @@
namespace kelvin::opt {
namespace {
-void Filter_8_H_W_M(const int8_t* input, int8_t* output, int H, int W, int M) {
+
+void Filter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H, int W, int M) {
const int8_t(&in)[8][H][W][M] = *(int8_t(*)[8][H][W][M])input;
int8_t(&out)[H][W][M / 4][8][4] = *(int8_t(*)[H][W][M / 4][8][4]) output;
assert(M >= 4);
- for (int zo = 0; zo < 8; ++zo) {
+ for (int zo = 0; zo < N; ++zo) {
for (int ky = 0; ky < H; ++ky) {
for (int kx = 0; kx < W; ++kx) {
for (int zi = 0; zi < M; ++zi) {
@@ -44,28 +45,33 @@
}
}
}
-}
-
-void Swizzle(const int32_t* input, int32_t* output, int N) {
- const int32_t(&in)[N] = *(int32_t(*)[N])input;
- int32_t(&out)[N * 4] = *(int32_t(*)[N * 4]) output;
- // Convert to accumulator swizzle pattern.
- for (int i = 0; i < N / 8; ++i) {
- int32_t* out0 = out + i * 32 + 0;
- int32_t* out1 = out + i * 32 + 16;
- int32_t* out2 = out + i * 32 + 8;
- int32_t* out3 = out + i * 32 + 24;
- for (int j = 0; j < 4; ++j) {
- const int32_t* p_in = in + i * 8;
- for (int k = 0; k < 2; ++k) {
- *out0++ = *p_in++;
- *out1++ = *p_in++;
- *out2++ = *p_in++;
- *out3++ = *p_in++;
+ // Zero out the rest of the output.
+ for (int zo = N; zo < 8; ++zo) {
+ for (int ky = 0; ky < H; ++ky) {
+ for (int kx = 0; kx < W; ++kx) {
+ for (int zi = 0; zi < M; ++zi) {
+ const int zi_hi = zi >> 2; // div4
+ const int zi_lo = zi & 3; // rem4
+ out[ky][kx][zi_hi][zo][zi_lo] = 0;
+ }
}
}
}
}
+
+void Swizzle(const int32_t* input, int32_t* output, int N) {
+ assert(N <= 8);
+ const int32_t(&in)[8] = *(int32_t(*)[8])input;
+ int32_t(&out)[32] = *(int32_t(*)[32]) output;
+ // Convert to accumulator swizzle pattern.
+ memset(out, 0, 32 * sizeof(int32_t));
+ int offsets[] = {0, 16, 8, 24, 1, 17, 9, 25};
+ for (int i = 0; i < N; ++i) {
+ int offset = offsets[i];
+ out[0 + offset] = out[2 + offset] = out[4 + offset] = out[6 + offset] = in[i];
+ }
+}
+
} // namespace
void ConvS8D4(
@@ -137,14 +143,16 @@
int32_t swizzled_mult_data[32];
int32_t swizzled_shift_data[32];
- for (int out_channel = 0; out_channel + 8 <= output_depth; out_channel += 8) {
- Filter_8_H_W_M(filter_data + (out_channel * filter_height * filter_width *
+ int out_channel = 0;
+ do {
+ int out_channels_this_iter = std::min(8, output_depth - out_channel);
+ Filter_N_H_W_M(filter_data + (out_channel * filter_height * filter_width *
filter_input_depth),
- p_swizzled_filter_data, filter_height, filter_width,
+ p_swizzled_filter_data, out_channels_this_iter, filter_height, filter_width,
filter_input_depth);
- Swizzle(bias_data + out_channel, swizzled_bias_data, 8);
- Swizzle(output_multiplier + out_channel, swizzled_mult_data, 8);
- Swizzle(output_shift + out_channel, swizzled_shift_data, 8);
+ Swizzle(bias_data + out_channel, swizzled_bias_data, out_channels_this_iter);
+ Swizzle(output_multiplier + out_channel, swizzled_mult_data, out_channels_this_iter);
+ Swizzle(output_shift + out_channel, swizzled_shift_data, out_channels_this_iter);
vld_w_x_m(v16, swizzled_bias_data);
vld_w_x_m(v20, swizzled_mult_data);
vld_w_x_m(v24, swizzled_shift_data);
@@ -162,7 +170,7 @@
acset_v(v48, v48);
int in_channel = 0;
do {
- int channels_this_iter = std::min(filter_input_depth, 32);
+ int in_channels_this_iter = std::min(filter_input_depth, 32);
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
const int in_y = in_y_origin + dilation_height_factor * filter_y;
const bool is_row_inside_input =
@@ -320,11 +328,11 @@
vld_b_p_x_m(v8, p_local_filter_start);
vld_b_x_m(v12, p_local_filter_start);
- cmds.conv.stop = (channels_this_iter / 4) - 1;
+ cmds.conv.stop = (in_channels_this_iter / 4) - 1;
aconv_vxv(v48, v0, cmds, v8);
}
}
- in_channel += channels_this_iter;
+ in_channel += in_channels_this_iter;
} while (in_channel < filter_input_depth);
vcget(v48);
vadd_w_vv_m(v48, v48, v16);
@@ -341,44 +349,70 @@
vmax_w_vx_m(v52, v52, output_activation_min);
vsraqs_b_vx(v56, v48, 0);
vsraqs_b_vx(v57, v52, 0);
- if (out_xs_this_iter >= 4) {
- vstq_b_s_xx(v56,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x, out_channel)],
- output_depth);
- } else {
- for (int i = 0; i < out_xs_this_iter; ++i) {
- if (i > 0) {
- vsliden_b_4_vv(v58, v56, v0);
- vsliden_b_4_vv(v56, v58, v0);
+ if (out_channels_this_iter == 8) {
+ if (out_xs_this_iter >= 4) {
+ vstq_b_s_xx(v56,
+ &output_data[tflite::Offset(output_shape, batch, out_y,
+ out_x, out_channel)],
+ output_depth);
+ } else {
+ for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
+ if (i > 0) {
+ vsliden_b_4_vv(v58, v56, v0);
+ vsliden_b_4_vv(v56, v58, v0);
+ }
+ vst_b_l_xx(v56,
+ &output_data[tflite::Offset(output_shape, batch, out_y,
+ out_x + i, out_channel)],
+ out_channels_this_iter);
}
- vst_b_l_xx(v56,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + i, out_channel)],
- 8);
}
- }
- if (out_xs_this_iter == 8) {
- vstq_b_s_xx(v57,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + 4, out_channel)],
- output_depth);
- } else if (out_xs_this_iter > 4) {
- for (int i = 4; i < out_xs_this_iter; ++i) {
- if (i > 4) {
- vsliden_b_4_vv(v58, v57, v0);
- vsliden_b_4_vv(v57, v58, v0);
+ if (out_xs_this_iter == 8) {
+ vstq_b_s_xx(v57,
+ &output_data[tflite::Offset(output_shape, batch, out_y,
+ out_x + 4, out_channel)],
+ output_depth);
+ } else if (out_xs_this_iter > 4) {
+ for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
+ if (i > 4) {
+ vsliden_b_4_vv(v58, v57, v0);
+ vsliden_b_4_vv(v57, v58, v0);
+ }
+ vst_b_l_xx(v57,
+ &output_data[tflite::Offset(output_shape, batch, out_y,
+ out_x + i, out_channel)],
+ out_channels_this_iter);
}
- vst_b_l_xx(v57,
- &output_data[tflite::Offset(output_shape, batch, out_y,
- out_x + i, out_channel)],
- 8);
+ }
+ } else {
+ for (int i = 0; i < std::min(4, out_xs_this_iter); ++i) {
+ if (i > 0) {
+ vsliden_b_4_vv(v58, v56, v0);
+ vsliden_b_4_vv(v56, v58, v0);
+ }
+ vst_b_l_xx(v56,
+ &output_data[tflite::Offset(output_shape, batch, out_y,
+ out_x + i, out_channel)],
+ out_channels_this_iter);
+ }
+ if (out_xs_this_iter > 4) {
+ for (int i = 4; i < std::min(8, out_xs_this_iter); ++i) {
+ if (i > 4) {
+ vsliden_b_4_vv(v58, v57, v0);
+ vsliden_b_4_vv(v57, v58, v0);
+ }
+ vst_b_l_xx(v57,
+ &output_data[tflite::Offset(output_shape, batch, out_y,
+ out_x + i, out_channel)],
+ out_channels_this_iter);
+ }
}
}
out_x += out_xs_this_iter;
} while (out_x < output_width);
}
}
- }
+ out_channel += out_channels_this_iter;
+ } while (out_channel < output_depth);
}
} // namespace kelvin::opt