Add specialized pointwise convolution.

Change-Id: Ied82f85003980759417760572d4f5601a21369c6
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index 3f77764..a3237a6 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -200,6 +200,15 @@
   return; \
 }
 
+  // special case of filter size 1x1
+  if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
+      stride_width == 1 && dilation_height_factor == 1 &&
+      dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
+      (input_depth == filter_depth) && (output_depth % 8) == 0 &&
+      (input_depth % 32) == 0) {
+    RUN_KERNEL(kelvin::opt::ConvS8K1x1);
+  }
+
   if (input_depth == 1 && filter_width == 5 && filter_height == 5 &&
       output_depth == 24) {
     RUN_KERNEL(kelvin::opt::ConvPerChannelD1OD24_5x5);
@@ -218,15 +227,6 @@
     RUN_KERNEL(kelvin::opt::ConvS8D32);
   }
 
-  // special case of filter size 1x1
-  if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
-      stride_width == 1 && dilation_height_factor == 1 &&
-      dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
-      (output_depth % 8) == 0 && (input_depth % 32) == 0) {
-    // TODO(ndodda): uncomment it when all tests are passed
-    // RUN_KERNEL(kelvin::opt::ConvS8K1x1);
-  }
-
   // special case of filter size 48x3x1x48
   if (batches == 1 && filter_height == 3 && filter_width == 1 &&
       input_width == 1 && input_depth == 48 && output_depth == 48 &&
diff --git a/tflm/opt/conv_s8_1x1.cc b/tflm/opt/conv_s8_1x1.cc
index 9da99c3..bc61ddf 100644
--- a/tflm/opt/conv_s8_1x1.cc
+++ b/tflm/opt/conv_s8_1x1.cc
@@ -22,28 +22,43 @@
 
 namespace kelvin::opt {
 
-void ConvS8K1x1(const tflite::ConvParams& params,
-                const int32_t* output_multiplier, const int32_t* output_shift,
-                const tflite::RuntimeShape& input_shape,
-                const int8_t* input_data,
-                const tflite::RuntimeShape& filter_shape,
-                const int8_t* filter_data,
-                const tflite::RuntimeShape& bias_shape,
-                const int32_t* bias_data,
-                const tflite::RuntimeShape& output_shape, int8_t* output_data) {
-  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
-  const auto input_depth = input_shape.Dims(3);
-  const auto input_offset = params.input_offset;
-  const auto output_height = output_shape.Dims(1);
-  const auto output_width = output_shape.Dims(2);
-  const auto output_depth = output_shape.Dims(3);
-  const auto output_offset = params.output_offset;
-  const auto output_activation_min = params.quantized_activation_min;
-  const auto output_activation_max = params.quantized_activation_max;
-  //  ToDo : support group convolutions.
-  int32_t bias[8 * 4];
-  int32_t mult[8 * 4];
-  int32_t shft[8 * 4];
+void ConvS8K1x1(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // Get parameters.
+  const int32_t input_offset = params.input_offset;  // r = s(q - Z)
+  const int32_t output_offset = params.output_offset;
+
+  // Set min and max value of the output.
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Check dimensions of the tensors.
+  const int filter_input_depth = filter_shape.Dims(3);
+  const int groups = input_depth / filter_input_depth;
+  TFLITE_DCHECK_NE(groups, 0);
+  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  const int filters_per_group = output_depth / groups;
+  TFLITE_DCHECK_NE(filters_per_group, 0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+
   union {
     vconv_u8_t conv;
     uint32_t raw;
@@ -55,43 +70,105 @@
   cmds.conv.sdata1 = true;
   cmds.conv.sbias2 = 0;
   cmds.conv.sdata2 = true;
-  for (int zo_hi = 0; zo_hi < output_depth; zo_hi += 8) {
-    // transpose filter weigths to support outer prodcut multiplication
-    int8_t juggled_filter_data[1][1][1][input_depth / 4][8][4];
-    Filter_N_H_W_M<8>(filter_data, juggled_filter_data[0][0][0][0][0], 1, 1,
-                      32);
 
-    Swizzle(bias_data, bias, 8);
-    Swizzle(output_multiplier, mult, 8);
-    Swizzle(output_shift, shft, 8, true);
+  const size_t swizzled_filter_data_size =
+      8 * filter_input_depth;
+  std::unique_ptr<int8_t> swizzled_filter_data(reinterpret_cast<int8_t*>(
+      ::aligned_alloc(32, swizzled_filter_data_size)));
+  int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_mult_data[32];
+  int32_t swizzled_shift_data[32];
+
+  const int n_elems = (output_width * batches * output_height);
+  int out_channel = 0;
+  do {
+    int out_channels_this_iter = std::min(8, output_depth - out_channel);
+    Filter_N_H_W_M(filter_data + (out_channel * filter_input_depth),
+                   p_swizzled_filter_data, out_channels_this_iter, 1, 1,
+                   filter_input_depth);
+    if (bias_data) {
+      Swizzle(bias_data + out_channel, swizzled_bias_data, out_channels_this_iter);
+      vld_w_x_m(v16, swizzled_bias_data);
+    } else {
+      vdup_w_x_m(v16, 0);
+    }
+    Swizzle(output_multiplier + out_channel, swizzled_mult_data, out_channels_this_iter);
+    Swizzle(output_shift + out_channel, swizzled_shift_data, out_channels_this_iter);
+
+    vld_w_x_m(v20, swizzled_mult_data);
+    vld_w_x_m(v24, swizzled_shift_data);
+    vrsub_w_vx_m(v24, v24, 0);
+
     int out = 0;
-    for (; out + 8 <= output_height * output_width * batches; out += 8) {
-      // resetting accumulators to clean up old output
-      vdup_b_x_m(v48, 0);
-      vdup_b_x_m(v52, 0);
+    for (; out < n_elems; out += 8) {
+      int out_this_iter = std::min(8, n_elems - out);
 
-      int in = 0;
-      for (; in <= input_depth; in += 32) {
-        vld_b_s_xx_m(v0, input_data + out * input_depth + in, input_depth);
-        vld_b_s_xx_m(v4, input_data + out * input_depth + in + 4 * input_depth,
-                     input_depth);
+      const int8_t* p_in = input_data + (out * input_depth);
+      int8_t* p_out = output_data + (out * output_depth) + out_channel;
 
-        vld_b_x_m(v8, juggled_filter_data[0][0][0][in / 32][0][0]);
-        vld_b_x_m(v12, juggled_filter_data[0][0][0][(in / 32) + 4][0][0]);
+      // 8x accumulators
+      vmv_v_m(v48, v16);
+      vmv_v_m(v52, v16);
+      acset_v(v48, v48);
+      int in_channel = 0;
+      for (; in_channel < filter_input_depth; in_channel += 32) {
+        const int8_t* p_input = p_in + in_channel;
+        if (out_this_iter < 8) {
+          switch (out_this_iter) {
+            case 7:
+              vld_b_x(v6, p_input + (6 * input_depth));
+            case 6:
+              vld_b_x(v5, p_input + (5 * input_depth));
+            case 5:
+              vld_b_x(v4, p_input + (4 * input_depth));
+            case 4:
+              vld_b_x(v3, p_input + (3 * input_depth));
+            case 3:
+              vld_b_x(v2, p_input + (2 * input_depth));
+            case 2:
+              vld_b_x(v1, p_input + input_depth);
+            case 1:
+              vld_b_x(v0, p_input);
+          }
+        } else {
+          // Inputs
+          vld_b_s_xx_m(v0, p_input, input_depth);
+          vld_b_s_xx_m(v4, p_input + (4 * input_depth), input_depth);
+        }
+
+        int8_t* p_local_filter = p_swizzled_filter_data + (in_channel * 8);
+        vld_b_p_x_m(v8, p_local_filter);
+        vld_b_x_m(v12, p_local_filter);
 
         aconv_vxv(v48, v0, cmds, v8);
       }
 
-      INT32_TO_INT8_OUTPUT_PIPELINE(bias, mult, shft, output_activation_min,
-                                    output_activation_max, output_offset, v16,
-                                    v20, v24);
+      vcget(v48);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+          v48, v20, v24, output_activation_min, output_activation_max,
+          output_offset);
+      INT32_TO_INT8_OUTPUT_PIPELINE_INPLACE(
+          v52, v20, v24, output_activation_min, output_activation_max,
+          output_offset);
+      vsraqs_b_vx(v48, v48, 0);
+      vsraqs_b_vx(v52, v52, 0);
 
-      // store the results to ouput memory
-      int8_t* p_out = output_data + (out * output_depth) + zo_hi;
-      vstq_b_sp_xx(v48, p_out, output_depth);
-      vstq_b_sp_xx(v52, p_out, output_depth);
+      int i = 0;
+      for (; i < std::min(4, out_this_iter); i++) {
+        vst_b_l_xx(v48, p_out, out_channels_this_iter);
+        p_out += output_depth;
+        vsliden_h_4_vv(v48, v48, v48);
+      }
+      for (; i < out_this_iter; i++) {
+        vst_b_l_xx(v52, p_out, out_channels_this_iter);
+        p_out += output_depth;
+        vsliden_h_4_vv(v52, v52, v52);
+      }
     }
-  }
+
+    out_channel += out_channels_this_iter;
+  } while (out_channel < output_depth);
 }
 
 }  // namespace kelvin::opt
diff --git a/tflm/opt/conv_util.h b/tflm/opt/conv_util.h
index e552d52..7c925c4 100644
--- a/tflm/opt/conv_util.h
+++ b/tflm/opt/conv_util.h
@@ -80,6 +80,36 @@
   }
 }
 
+inline void Filter_N_H_W_M(const int8_t* input, int8_t* output, int N, int H,
+                           int W, int M) {
+  const int8_t(&in)[8][H][W][M] = *(int8_t(*)[8][H][W][M])input;
+  int8_t(&out)[H][W][M / 4][8][4] = *(int8_t(*)[H][W][M / 4][8][4]) output;
+  assert(M >= 4);
+  for (int zo = 0; zo < N; ++zo) {
+    for (int ky = 0; ky < H; ++ky) {
+      for (int kx = 0; kx < W; ++kx) {
+        for (int zi = 0; zi < M; ++zi) {
+          const int zi_hi = zi >> 2;  // div4
+          const int zi_lo = zi & 3;   // rem4
+          out[ky][kx][zi_hi][zo][zi_lo] = in[zo][ky][kx][zi];
+        }
+      }
+    }
+  }
+  // Zero out the rest of the output.
+  for (int zo = N; zo < 8; ++zo) {
+    for (int ky = 0; ky < H; ++ky) {
+      for (int kx = 0; kx < W; ++kx) {
+        for (int zi = 0; zi < M; ++zi) {
+          const int zi_hi = zi >> 2;  // div4
+          const int zi_lo = zi & 3;   // rem4
+          out[ky][kx][zi_hi][zo][zi_lo] = 0;
+        }
+      }
+    }
+  }
+}
+
 // Swizzle values, and duplicate 4 times for stripmining.
 inline void Swizzle(const int32_t* input, int32_t* output, int N,
                     bool negate = false) {