Improve Conv2dD32
- Specialize Conv2D for exact depths of 32, improves performance on
hardware by ~40%
Change-Id: Ic4f193d2c9104b91abd11ddcba6616b1d5bd0710
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 0dc432b..7842b0d 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -184,6 +184,8 @@
const auto dilation_height_factor = params.dilation_height_factor;
const auto pad_width = params.padding_values.width;
const auto pad_height = params.padding_values.height;
+ const auto input_batch = input_shape.Dims(0);
+ const auto input_height = input_shape.Dims(1);
const auto input_width = input_shape.Dims(2);
const auto input_depth = input_shape.Dims(3);
const auto filter_height = filter_shape.Dims(1);
@@ -205,13 +207,18 @@
stride_width == 1 && dilation_height_factor == 1 &&
dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
(input_depth == filter_depth)) {
- if ((output_depth % 8) == 0 && (input_depth % 32) == 0) {
+
+ if ((input_depth == 32) && (input_batch * input_height * input_width) >= 4) {
RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
}
+ if ((output_depth % 8) == 0 && (input_depth % 32) == 0) {
+ RUN_KERNEL(kelvin::opt::ConvS8K1x1DMod32);
+ }
+
// TODO: Relax this kernel for all output_depths
if ((output_depth < 8) && (input_depth % 32) == 0) {
- RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
+ RUN_KERNEL(kelvin::opt::ConvS8K1x1DMod32);
}
if ((output_depth % 16) == 0 && (input_depth == 16)) {
diff --git a/tflm/opt/conv_s8.h b/tflm/opt/conv_s8.h
index b79bd65..96cea2e 100644
--- a/tflm/opt/conv_s8.h
+++ b/tflm/opt/conv_s8.h
@@ -23,6 +23,17 @@
namespace kelvin::opt {
// filter 1x1 d%32==0
+void ConvS8K1x1DMod32(const tflite::ConvParams& params,
+ const int32_t* output_multiplier, const int32_t* output_shift,
+ const tflite::RuntimeShape& input_shape,
+ const int8_t* input_data,
+ const tflite::RuntimeShape& filter_shape,
+ const int8_t* filter_data,
+ const tflite::RuntimeShape& bias_shape,
+ const int32_t* bias_data,
+ const tflite::RuntimeShape& output_shape, int8_t* output_data);
+
+// filter 1x1 d==32
void ConvS8K1x1D32(const tflite::ConvParams& params,
const int32_t* output_multiplier, const int32_t* output_shift,
const tflite::RuntimeShape& input_shape,
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index 18046e9..600f765 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -22,7 +22,7 @@
namespace kelvin::opt {
-void ConvS8K1x1D32(
+void ConvS8K1x1DMod32(
const tflite::ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
@@ -138,8 +138,8 @@
}
int8_t* p_local_filter = p_swizzled_filter_data + (in_channel * 8);
- vld_b_p_x_m(v8, p_local_filter);
- vld_b_x_m(v12, p_local_filter);
+ vld_b_x_m(v8, p_local_filter);
+ vld_b_x_m(v12, p_local_filter + (4 * 32));
aconv_vxv(v48, v0, cmds, v8);
}
@@ -152,13 +152,13 @@
int i = 0;
for (; i < std::min(4, out_this_iter); i++) {
- vst_b_l_xx(v48, p_out, out_channels_this_iter);
- p_out += output_depth;
+ vst_b_l_xx(v48, p_out + (i * output_depth), out_channels_this_iter);
+ // p_out += output_depth;
vsliden_h_4_vv(v48, v48, v48);
}
for (; i < out_this_iter; i++) {
- vst_b_l_xx(v52, p_out, out_channels_this_iter);
- p_out += output_depth;
+ vst_b_l_xx(v52, p_out + (i * output_depth), out_channels_this_iter);
+ // p_out += output_depth;
vsliden_h_4_vv(v52, v52, v52);
}
}
@@ -167,6 +167,261 @@
} while (out_channel < output_depth);
}
+void ConvS8K1x1D32(
+ const tflite::ConvParams& params, const int32_t* output_multiplier,
+ const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+ const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+ const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+ const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+ int8_t* output_data) {
+ // Get parameters.
+ const int32_t input_offset = params.input_offset; // r = s(q - Z)
+ const int32_t output_offset = params.output_offset;
+
+ // Set min and max value of the output.
+ const int32_t output_activation_min = params.quantized_activation_min;
+ const int32_t output_activation_max = params.quantized_activation_max;
+
+ // Consistency check.
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ assert(input_depth == 32);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+
+ // Check dimensions of the tensors.
+ const int filter_input_depth = filter_shape.Dims(3);
+ assert(filter_input_depth == 32);
+ const int groups = input_depth / filter_input_depth;
+ TFLITE_DCHECK_NE(groups, 0);
+ TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+ const int filters_per_group = output_depth / groups;
+ TFLITE_DCHECK_NE(filters_per_group, 0);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+
+ union {
+ vconv_u8_t conv;
+ uint32_t raw;
+ } cmds;
+ cmds.conv.mode = 0;
+ cmds.conv.start = 0;
+ cmds.conv.stop = 7;
+ cmds.conv.sbias1 = input_offset;
+ cmds.conv.sdata1 = true;
+ cmds.conv.sbias2 = 0;
+ cmds.conv.sdata2 = true;
+ const size_t swizzled_filter_data_size =
+ 8 * filter_input_depth;
+ std::unique_ptr<int8_t> swizzled_filter_data(reinterpret_cast<int8_t*>(
+ ::aligned_alloc(32, swizzled_filter_data_size)));
+ int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
+ int32_t swizzled_bias_data[32];
+ int32_t swizzled_mult_data[32];
+ int32_t swizzled_shift_data[32];
+
+// v0-v7 INPUT0
+// v8-v15 FLT
+// v16-v23 INPUT1
+// v24-v27 ACCB
+// v32-v43 unused
+// v44-47 BIAS
+// v48-v55 aconv ACC
+// v56-v59 MULT
+// v60-v63 BIAS
+#define INPUT0_0 v0
+#define INPUT0_1 v4
+#define FLT0_0 v8
+#define FLT0_1 v12
+#define INPUT1_0 v16
+#define INPUT1_1 v20
+#define ACCB0 v24
+#define ACCB1 v25
+#define ACCB2 v26
+#define ACCB3 v27
+#define ACCB4 v28
+#define ACCB5 v29
+#define ACCB6 v30
+#define ACCB7 v31
+#define BIAS0 v44
+#define ACC0 v48
+#define ACC1 v52
+#define MULT0 v56
+#define SHFT0 v60
+
+ const int n_elems = (output_width * batches * output_height);
+ int out_channel = 0;
+ do { // out_channel
+ int out_channels_this_iter = std::min(8, output_depth - out_channel);
+ assert(out_channels_this_iter == 8);
+ Filter_N_H_W_M(filter_data + (out_channel * filter_input_depth),
+ p_swizzled_filter_data, out_channels_this_iter, 1, 1,
+ filter_input_depth);
+ if (bias_data) {
+ Swizzle(bias_data + out_channel, swizzled_bias_data, out_channels_this_iter);
+ vld_w_x_m(BIAS0, swizzled_bias_data);
+ } else {
+ vdup_w_x_m(BIAS0, 0);
+ }
+ Swizzle(output_multiplier + out_channel, swizzled_mult_data, out_channels_this_iter);
+ Swizzle(output_shift + out_channel, swizzled_shift_data, out_channels_this_iter);
+
+ vld_w_x_m(MULT0, swizzled_mult_data);
+ vld_w_x_m(SHFT0, swizzled_shift_data);
+ vrsub_w_vx_m(SHFT0, SHFT0, 0);
+
+ int out = 0;
+ do {
+ const int8_t* p_in = input_data + (out * input_depth);
+ int8_t* p_out = output_data + (out * output_depth) + out_channel;
+
+ // 8x accumulators
+ vmv_v_m(ACC0, BIAS0);
+ vmv_v_m(ACC1, BIAS0);
+
+ acset_v(ACC0, ACC0);
+ const int8_t* p_input = p_in;
+ // Inputs
+ vld_b_s_xx_m(INPUT0_0, p_input, input_depth);
+ vld_b_s_xx_m(INPUT0_1, p_input + (4 * input_depth), input_depth);
+
+ int8_t* p_local_filter = p_swizzled_filter_data;
+ vld_b_x_m(FLT0_0, p_local_filter);
+ vld_b_x_m(FLT0_1, p_local_filter + (4 * 32));
+
+ aconv_vxv(ACC0, INPUT0_0, cmds, FLT0_0);
+ vld_b_s_xx_m(INPUT1_0, p_input + (8 * input_depth), input_depth);
+ vld_b_s_xx_m(INPUT1_1, p_input + ( 12 * input_depth), input_depth);
+
+ vcget(ACC0);
+ vmv_v_m(INPUT0_0, ACC0);
+ vmv_v_m(INPUT0_1, ACC1);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ INPUT0_0, INPUT0_1, MULT0, SHFT0, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(ACCB0, INPUT0_0, 0);
+ vsraqs_b_vx(ACCB1, INPUT0_1, 0);
+
+ vmv_v_m(ACC0, BIAS0);
+ vstq_b_s_xx(ACCB0, p_out, output_depth);
+ vstq_b_s_xx(ACCB1, p_out + (4 * output_depth), output_depth);
+ vmv_v_m(ACC1, BIAS0);
+ acset_v(ACC0, ACC0);
+
+ aconv_vxv(ACC0, INPUT1_0, cmds, FLT0_0);
+ vld_b_s_xx_m(INPUT0_0, p_input + (16 * input_depth), input_depth);
+ vld_b_s_xx_m(INPUT0_1, p_input + (20 * input_depth), input_depth);
+ vcget(ACC0);
+ vmv_v_m(INPUT1_0, ACC0);
+ vmv_v_m(INPUT1_1, ACC1);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ INPUT1_0, INPUT1_1, MULT0, SHFT0, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(ACCB2, INPUT1_0, 0);
+ vsraqs_b_vx(ACCB3, INPUT1_1, 0);
+
+ vld_b_s_xx_m(INPUT1_0, p_input + (24 * input_depth), input_depth);
+ vld_b_s_xx_m(INPUT1_1, p_input + (28 * input_depth), input_depth);
+
+ vmv_v_m(ACC0, BIAS0);
+ vstq_b_s_xx(ACCB2, p_out + (8 * output_depth), output_depth);
+ vstq_b_s_xx(ACCB3, p_out + (12 * output_depth), output_depth);
+ vmv_v_m(ACC1, BIAS0);
+ acset_v(ACC0, ACC0);
+ aconv_vxv(ACC0, INPUT0_0, cmds, FLT0_0);
+ vcget(ACC0);
+ vmv_v_m(INPUT0_0, ACC0);
+ vmv_v_m(INPUT0_1, ACC1);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ INPUT0_0, INPUT0_1, MULT0, SHFT0, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(ACCB4, INPUT0_0, 0);
+ vsraqs_b_vx(ACCB5, INPUT0_1, 0);
+
+ vmv_v_m(ACC0, BIAS0);
+ vstq_b_s_xx(ACCB4, p_out + (16 * output_depth), output_depth);
+ vstq_b_s_xx(ACCB5, p_out + (20 * output_depth), output_depth);
+ vmv_v_m(ACC1, BIAS0);
+ acset_v(ACC0, ACC0);
+ aconv_vxv(ACC0, INPUT1_0, cmds, FLT0_0);
+ vcget(ACC0);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ ACC0, ACC1, MULT0, SHFT0, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(ACCB6, ACC0, 0);
+ vsraqs_b_vx(ACCB7, ACC1, 0);
+
+ vstq_b_s_xx(ACCB6, p_out + (24 * output_depth), output_depth);
+ vstq_b_s_xx(ACCB7, p_out + (28 * output_depth), output_depth);
+
+ out += 32;
+ } while ((n_elems - out) >= 32);
+ do {// remainder loop
+ int out_this_iter = std::min(8, n_elems - out);
+ const int8_t* p_in = input_data + (out * input_depth);
+ int8_t* p_out = output_data + (out * output_depth) + out_channel;
+
+ // 8x accumulators
+ vmv_v_m(ACC0, BIAS0);
+ vmv_v_m(ACC1, BIAS0);
+ acset_v(ACC0, ACC0);
+ const int8_t* p_input = p_in;
+ if (out_this_iter < 8) {
+ switch (out_this_iter) {
+ case 7:
+ vld_b_x(v6, p_input + (6 * input_depth));
+ case 6:
+ vld_b_x(v5, p_input + (5 * input_depth));
+ case 5:
+ vld_b_x(v4, p_input + (4 * input_depth));
+ case 4:
+ vld_b_x(v3, p_input + (3 * input_depth));
+ case 3:
+ vld_b_x(v2, p_input + (2 * input_depth));
+ case 2:
+ vld_b_x(v1, p_input + input_depth);
+ case 1:
+ vld_b_x(INPUT0_0, p_input);
+ }
+ } else {
+ // Inputs
+ vld_b_s_xx_m(INPUT0_0, p_input, input_depth);
+ vld_b_s_xx_m(INPUT0_1, p_input + (4 * input_depth), input_depth);
+ }
+
+ int8_t* p_local_filter = p_swizzled_filter_data;
+ vld_b_x_m(FLT0_0, p_local_filter);
+ vld_b_x_m(FLT0_1, p_local_filter + (4 * 32));
+ aconv_vxv(ACC0, INPUT0_0, cmds, FLT0_0);
+
+ vcget(v48);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ ACC0, ACC1, MULT0, SHFT0, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(ACC0, ACC0, 0);
+ vsraqs_b_vx(ACC1, ACC1, 0);
+
+ int i = 0;
+ for (; i < std::min(4, out_this_iter); i++) {
+ vst_b_l_xx(ACC0, p_out + (i * output_depth), out_channels_this_iter);
+ vsliden_h_4_vv(ACC0, ACC0, ACC0);
+ }
+ for (; i < out_this_iter; i++) {
+ vst_b_l_xx(ACC1, p_out + (i * output_depth), out_channels_this_iter);
+ vsliden_h_4_vv(ACC1, ACC1, ACC1);
+ }
+ out += out_this_iter;
+ } while (out < n_elems);
+ out_channel += out_channels_this_iter;
+ } while (out_channel < output_depth);
+}
+
void ConvS8K1x1D16(
const tflite::ConvParams& params, const int32_t* output_multiplier,
const int32_t* output_shift, const tflite::RuntimeShape& input_shape,