Add specialized pointwise convolution.
Change-Id: Ied82f85003980759417760572d4f5601a21369c6
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 3f77764..a3237a6 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -200,6 +200,15 @@
return; \
}
+ // special case of filter size 1x1
+ if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
+ stride_width == 1 && dilation_height_factor == 1 &&
+ dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
+ (input_depth == filter_depth) && (output_depth % 8) == 0 &&
+ (input_depth % 32) == 0) {
+ RUN_KERNEL(kelvin::opt::ConvS8K1x1);
+ }
+
if (input_depth == 1 && filter_width == 5 && filter_height == 5 &&
output_depth == 24) {
RUN_KERNEL(kelvin::opt::ConvPerChannelD1OD24_5x5);
@@ -218,15 +227,6 @@
RUN_KERNEL(kelvin::opt::ConvS8D32);
}
- // special case of filter size 1x1
- if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
- stride_width == 1 && dilation_height_factor == 1 &&
- dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
- (output_depth % 8) == 0 && (input_depth % 32) == 0) {
- // TODO(ndodda): uncomment it when all tests are passed
- // RUN_KERNEL(kelvin::opt::ConvS8K1x1);
- }
-
// special case of filter size 48x3x1x48
if (batches == 1 && filter_height == 3 && filter_width == 1 &&
input_width == 1 && input_depth == 48 && output_depth == 48 &&
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index 9da99c3..bc61ddf 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -22,28 +22,43 @@
namespace kelvin::opt {
-void ConvS8K1x1(const tflite::ConvParams& params,
- const int32_t* output_multiplier, const int32_t* output_shift,
- const tflite::RuntimeShape& input_shape,
- const int8_t* input_data,
- const tflite::RuntimeShape& filter_shape,
- const int8_t* filter_data,
- const tflite::RuntimeShape& bias_shape,
- const int32_t* bias_data,
- const tflite::RuntimeShape& output_shape, int8_t* output_data) {
- const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
- const auto input_depth = input_shape.Dims(3);
- const auto input_offset = params.input_offset;
- const auto output_height = output_shape.Dims(1);
- const auto output_width = output_shape.Dims(2);
- const auto output_depth = output_shape.Dims(3);
- const auto output_offset = params.output_offset;
- const auto output_activation_min = params.quantized_activation_min;
- const auto output_activation_max = params.quantized_activation_max;
- // ToDo : support group convolutions.
- int32_t bias[8 * 4];
- int32_t mult[8 * 4];
- int32_t shft[8 * 4];
+void ConvS8K1x1(
+ const tflite::ConvParams& params, const int32_t* output_multiplier,
+ const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+ const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+ const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+ const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+ int8_t* output_data) {
+ // Get parameters.
+ const int32_t input_offset = params.input_offset; // r = s(q - Z)
+ const int32_t output_offset = params.output_offset;
+
+ // Set min and max value of the output.
+ const int32_t output_activation_min = params.quantized_activation_min;
+ const int32_t output_activation_max = params.quantized_activation_max;
+
+ // Consistency check.
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+
+ // Check dimensions of the tensors.
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_NE(groups, 0);
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
+ TFLITE_DCHECK_NE(filters_per_group, 0);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+
union {
vconv_u8_t conv;
uint32_t raw;
@@ -55,43 +70,105 @@
cmds.conv.sdata1 = true;
cmds.conv.sbias2 = 0;
cmds.conv.sdata2 = true;
- for (int zo_hi = 0; zo_hi < output_depth; zo_hi += 8) {
- // transpose filter weigths to support outer prodcut multiplication
- int8_t juggled_filter_data[1][1][1][input_depth / 4][8][4];
- Filter_N_H_W_M<8>(filter_data, juggled_filter_data[0][0][0][0][0], 1, 1,
- 32);
- Swizzle(bias_data, bias, 8);
- Swizzle(output_multiplier, mult, 8);
- Swizzle(output_shift, shft, 8, true);
+ const size_t swizzled_filter_data_size =
+ 8 * filter_input_depth;
+ std::unique_ptr<int8_t> swizzled_filter_data(reinterpret_cast<int8_t*>(
+ ::aligned_alloc(32, swizzled_filter_data_size)));
+ int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
+ int32_t swizzled_bias_data[32];
+ int32_t swizzled_mult_data[32];
+ int32_t swizzled_shift_data[32];
+
+ const int n_elems = (output_width * batches * output_height);
+ int out_channel = 0;
+ do {
+ int out_channels_this_iter = std::min(8, output_depth - out_channel);
+ Filter_N_H_W_M(filter_data + (out_channel * filter_input_depth),
+ p_swizzled_filter_data, out_channels_this_iter, 1, 1,
+ filter_input_depth);
+ if (bias_data) {
+ Swizzle(bias_data + out_channel, swizzled_bias_data, out_channels_this_iter);
+ vld_w_x_m(v16, swizzled_bias_data);
+ } else {
+ vdup_w_x_m(v16, 0);
+ }
+ Swizzle(output_multiplier + out_channel, swizzled_mult_data, out_channels_this_iter);
+ Swizzle(output_shift + out_channel, swizzled_shift_data, out_channels_this_iter);
+
+ vld_w_x_m(v20, swizzled_mult_data);
+ vld_w_x_m(v24, swizzled_shift_data);
+ vrsub_w_vx_m(v24, v24, 0);
+
int out = 0;
- for (; out + 8 <= output_height * output_width * batches; out += 8) {
- // resetting accumulators to clean up old output
- vdup_b_x_m(v48, 0);
- vdup_b_x_m(v52, 0);
+ for (; out < n_elems; out += 8) {
+ int out_this_iter = std::min(8, n_elems - out);
- int in = 0;
- for (; in <= input_depth; in += 32) {
- vld_b_s_xx_m(v0, input_data + out * input_depth + in, input_depth);
- vld_b_s_xx_m(v4, input_data + out * input_depth + in + 4 * input_depth,
- input_depth);
+ const int8_t* p_in = input_data + (out * input_depth);
+ int8_t* p_out = output_data + (out * output_depth) + out_channel;
- vld_b_x_m(v8, juggled_filter_data[0][0][0][in / 32][0][0]);
- vld_b_x_m(v12, juggled_filter_data[0][0][0][(in / 32) + 4][0][0]);
+ // 8x accumulators
+ vmv_v_m(v48, v16);
+ vmv_v_m(v52, v16);
+ acset_v(v48, v48);
+ int in_channel = 0;
+ for (; in_channel < filter_input_depth; in_channel += 32) {
+ const int8_t* p_input = p_in + in_channel;
+ if (out_this_iter < 8) {
+ switch (out_this_iter) {
+ case 7:
+ vld_b_x(v6, p_input + (6 * input_depth));
+ case 6:
+ vld_b_x(v5, p_input + (5 * input_depth));
+ case 5:
+ vld_b_x(v4, p_input + (4 * input_depth));
+ case 4:
+ vld_b_x(v3, p_input + (3 * input_depth));
+ case 3:
+ vld_b_x(v2, p_input + (2 * input_depth));
+ case 2:
+ vld_b_x(v1, p_input + input_depth);
+ case 1:
+ vld_b_x(v0, p_input);
+ }
+ } else {
+ // Inputs
+ vld_b_s_xx_m(v0, p_input, input_depth);
+ vld_b_s_xx_m(v4, p_input + (4 * input_depth), input_depth);
+ }
+
+ int8_t* p_local_filter = p_swizzled_filter_data + (in_channel * 8);
+ vld_b_p_x_m(v8, p_local_filter);
+ vld_b_x_m(v12, p_local_filter);
aconv_vxv(v48, v0, cmds, v8);
}
- INT32_TO_INT8_OUTPUT_PIPELINE(bias, mult, shft, output_activation_min,
- output_activation_max, output_offset, v16,
- v20, v24);
+ vcget(v48);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+ v48, v20, v24, output_activation_min, output_activation_max,
+ output_offset);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+ v52, v20, v24, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(v48, v48, 0);
+ vsraqs_b_vx(v52, v52, 0);
- // store the results to ouput memory
- int8_t* p_out = output_data + (out * output_depth) + zo_hi;
- vstq_b_sp_xx(v48, p_out, output_depth);
- vstq_b_sp_xx(v52, p_out, output_depth);
+ int i = 0;
+ for (; i < std::min(4, out_this_iter); i++) {
+ vst_b_l_xx(v48, p_out, out_channels_this_iter);
+ p_out += output_depth;
+ vsliden_h_4_vv(v48, v48, v48);
+ }
+ for (; i < out_this_iter; i++) {
+ vst_b_l_xx(v52, p_out, out_channels_this_iter);
+ p_out += output_depth;
+ vsliden_h_4_vv(v52, v52, v52);
+ }
}
- }
+
+ out_channel += out_channels_this_iter;
+ } while (out_channel < output_depth);
}
} // namespace kelvin::opt
diff --git a/tflm/opt/conv_util.h b/tflm/opt/conv_util.h
index e552d52..7c925c4 100644
--- a/tflm/opt/conv_util.h
+++ b/tflm/opt/conv_util.h
@@ -80,6 +80,36 @@
}
}
+inline void Filter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H,
+ int W, int M) {
+ const int8_t(&in)[8][H][W][M] = *(int8_t(*)[8][H][W][M])input;
+ int8_t(&out)[H][W][M / 4][8][4] = *(int8_t(*)[H][W][M / 4][8][4]) output;
+ assert(M >= 4);
+ for (int zo = 0; zo < N; ++zo) {
+ for (int ky = 0; ky < H; ++ky) {
+ for (int kx = 0; kx < W; ++kx) {
+ for (int zi = 0; zi < M; ++zi) {
+ const int zi_hi = zi >> 2; // div4
+ const int zi_lo = zi & 3; // rem4
+ out[ky][kx][zi_hi][zo][zi_lo] = in[zo][ky][kx][zi];
+ }
+ }
+ }
+ }
+ // Zero out the rest of the output.
+ for (int zo = N; zo < 8; ++zo) {
+ for (int ky = 0; ky < H; ++ky) {
+ for (int kx = 0; kx < W; ++kx) {
+ for (int zi = 0; zi < M; ++zi) {
+ const int zi_hi = zi >> 2; // div4
+ const int zi_lo = zi & 3; // rem4
+ out[ky][kx][zi_hi][zo][zi_lo] = 0;
+ }
+ }
+ }
+ }
+}
+
// Swizzle values, and duplicate 4 times for stripmining.
inline void Swizzle(const int32_t* input, int32_t* output, int N,
bool negate = false) {