Merge "Process 16 channels per iteration in 1x1 d16"
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 64e22e1..933e9f8 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -214,7 +214,7 @@
RUN_KERNEL(kelvin::opt::ConvS8K1x1D32);
}
- if ((output_depth % 8) == 0 && (input_depth == 16)) {
+ if ((output_depth % 16) == 0 && (input_depth == 16)) {
RUN_KERNEL(kelvin::opt::ConvS8K1x1D16);
}
}
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index 26a681b..18046e9 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -223,59 +223,95 @@
const int effective_output_width = batches * output_width * output_height;
- // TODO(derekjchow): Remove this when simulator supports vx vslide ops
- vdup_b_x_m(v12, 0);
+#define INPUT v0 // v0, v1, v2, v3, v4, v5, v6, v7
+#define INPUT_SLIDE v4
+#define FLT_0 v8 // v8, v9, v10, v11
+#define FLT_1 v16 // v16, v17, v18, v19
+#define BIAS_0 v20
+#define BIAS_1 v24
+#define MULT_0 v28
+#define MULT_1 v32
+#define SHFT_0 v36
+#define SHFT_1 v40
+#define ACC_0 v48
+#define ACC_1 v52
+#define RES_0 v60
+#define RES_1 v61
+#define RES_2 v62
+#define RES_3 v63
int out_channel = 0;
- for (; out_channel < output_depth; out_channel += 8) {
+ for (; out_channel < output_depth; out_channel += 16) {
Filter_N_H_W_M(filter_data + (out_channel * 16),
swizzled_filter_data, 8, 1, 1, 16);
- vld_b_x_m(v8, swizzled_filter_data);
+ vld_b_x_m(FLT_0, swizzled_filter_data);
+ Filter_N_H_W_M(filter_data + ((out_channel + 8) * 16), swizzled_filter_data, 8, 1, 1, 16);
+ vld_b_x_m(FLT_1, swizzled_filter_data);
if (bias_data) {
Swizzle(bias_data + out_channel, swizzled_bias_data, 8);
- vld_w_x_m(v16, swizzled_bias_data);
+ vld_w_x_m(BIAS_0, swizzled_bias_data);
+ Swizzle(bias_data + out_channel + 8, swizzled_bias_data, 8);
+ vld_w_x_m(BIAS_1, swizzled_bias_data);
} else {
- vdup_w_x_m(v16, 0);
+ vdup_w_x_m(BIAS_0, 0);
+ vdup_w_x_m(BIAS_1, 0);
}
+
Swizzle(output_multiplier + out_channel, swizzled_mult_data, 8);
Swizzle(output_shift + out_channel, swizzled_shift_data, 8);
- vld_w_x_m(v20, swizzled_mult_data);
- vld_w_x_m(v24, swizzled_shift_data);
- vrsub_w_vx_m(v24, v24, 0);
+ vld_w_x_m(MULT_0, swizzled_mult_data);
+ vld_w_x_m(SHFT_0, swizzled_shift_data);
+ vrsub_w_vx_m(SHFT_0, SHFT_0, 0);
+ Swizzle(output_multiplier + out_channel + 8, swizzled_mult_data, 8);
+ Swizzle(output_shift + out_channel + 8, swizzled_shift_data, 8);
+ vld_w_x_m(MULT_1, swizzled_mult_data);
+ vld_w_x_m(SHFT_1, swizzled_shift_data);
+ vrsub_w_vx_m(SHFT_1, SHFT_1, 0);
int8_t* p_output = output_data + out_channel;
int out = 0;
for (; out + 8 <= effective_output_width; out += 8) {
// 8x accumulators
- vmv_v_m(v48, v16);
- vmv_v_m(v52, v16);
- acset_v(v48, v48);
const int8_t* p_in = input_data + (out * input_depth);
- vld_b_x_m(v0, p_in);
- vslidevp_w_4_vv_m(v4, v0, v12);
- aconv_vxv(v48, v0, cmds, v8);
- vcget(v48);
+ vld_b_x_m(INPUT, p_in);
+ vslidevp_w_4_vv_m(INPUT_SLIDE, INPUT, INPUT);
+ vmv_v_m(ACC_0, BIAS_0);
+ vmv_v_m(ACC_1, BIAS_0);
+ acset_v(ACC_0, ACC_0);
+ aconv_vxv(ACC_0, INPUT, cmds, FLT_0);
+
+ vcget(ACC_0);
INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
- v48, v52, v20, v24, output_activation_min, output_activation_max,
+ ACC_0, ACC_1, MULT_0, SHFT_0, output_activation_min, output_activation_max,
output_offset);
- vsraqs_b_vx(v48, v48, 0);
- vsraqs_b_vx(v52, v52, 0);
- vstq_b_s_xx(v48, p_output, 2 * output_depth);
- vstq_b_s_xx(v52, p_output + output_depth, 2 * output_depth);
+ vsraqs_b_vx(RES_0, ACC_0, 0);
+ vsraqs_b_vx(RES_1, ACC_1, 0);
+ vstq_b_s_xx(RES_0, p_output, 2 * output_depth);
+ vstq_b_s_xx(RES_1, p_output + output_depth, 2 * output_depth);
+ vmv_v_m(ACC_0, BIAS_1);
+ vmv_v_m(ACC_1, BIAS_1);
+ acset_v(ACC_0, ACC_0);
+ aconv_vxv(ACC_0, INPUT, cmds, FLT_1);
+ vcget(ACC_0);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ ACC_0, ACC_1, MULT_1, SHFT_1, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(RES_2, ACC_0, 0);
+ vsraqs_b_vx(RES_3, ACC_1, 0);
+ vstq_b_s_xx(RES_2, p_output + 8, 2 * output_depth);
+ vstq_b_s_xx(RES_3, p_output + output_depth + 8, 2 * output_depth);
p_output += (8 * output_depth);
+
} // out_x
// Remainder
int remainder_x = (effective_output_width - out);
if (remainder_x != 0) {
- vmv_v_m(v48, v16);
- vmv_v_m(v52, v16);
- acset_v(v48, v48);
const int8_t* p_in = input_data + (out * input_depth);
// Load inputs
@@ -298,29 +334,62 @@
vld_b_l_xx(v0, p_in, 16);
}
- aconv_vxv(v48, v0, cmds, v8);
-
- vcget(v48);
+ vmv_v_m(ACC_0, BIAS_0);
+ vmv_v_m(ACC_1, BIAS_0);
+ acset_v(ACC_0, ACC_0);
+ aconv_vxv(ACC_0, INPUT, cmds, FLT_0);
+ vcget(ACC_0);
INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
- v48, v52, v20, v24, output_activation_min, output_activation_max,
+ ACC_0, ACC_1, MULT_0, SHFT_0, output_activation_min, output_activation_max,
output_offset);
- vsraqs_b_vx(v48, v48, 0);
- vsraqs_b_vx(v52, v52, 0);
+ vsraqs_b_vx(RES_0, ACC_0, 0);
+ vsraqs_b_vx(RES_1, ACC_1, 0);
+
+ vmv_v_m(ACC_0, BIAS_1);
+ vmv_v_m(ACC_1, BIAS_1);
+ acset_v(ACC_0, ACC_0);
+ aconv_vxv(ACC_0, INPUT, cmds, FLT_1);
+ vcget(ACC_0);
+ INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE2(
+ ACC_0, ACC_1, MULT_1, SHFT_1, output_activation_min, output_activation_max,
+ output_offset);
+ vsraqs_b_vx(RES_2, ACC_0, 0);
+ vsraqs_b_vx(RES_3, ACC_1, 0);
int i = 0;
for (; i < std::min(4, remainder_x); i++) {
- vst_b_l_xx(v48, p_output, 8);
+ vst_b_l_xx(RES_0, p_output, 8);
+ vsliden_w_2_vv(RES_0, RES_0, RES_0);
+ vst_b_l_xx(RES_2, p_output + 8, 8);
+ vsliden_w_2_vv(RES_2, RES_2, RES_2);
p_output += output_depth;
- vsliden_w_2_vv(v48, v48, v12);
}
for (; i < remainder_x; i++) {
- vst_b_l_xx(v52, p_output, 8);
+ vst_b_l_xx(RES_1, p_output, 8);
+ vsliden_w_2_vv(RES_1, RES_1, RES_1);
+ vst_b_l_xx(RES_3, p_output + 8, 8);
+ vsliden_w_2_vv(RES_3, RES_3, RES_3);
p_output += output_depth;
- vsliden_w_2_vv(v52, v52, v12);
}
}
}
+#undef INPUT
+#undef INPUT_SLIDE
+#undef FLT_0
+#undef FLT_1
+#undef BIAS_0
+#undef BIAS_1
+#undef MULT_0
+#undef MULT_1
+#undef SHFT_0
+#undef SHFT_1
+#undef ACC_0
+#undef ACC_1
+#undef RES_0
+#undef RES_1
+#undef RES_2
+#undef RES_3
}
} // namespace kelvin::opt