Specialized int8 conv kernels

- Add int8 conv kernels specialized for the following shapes:
    - filter_input_depth % 32
    - filter_input_depth % 32 && input_channels % 8 && output_width % 8
      && pad_width == 1
- Move the selection of kernels into the toplevel conv_per_channel_b8,
  instead of inside the TF kernel.

Change-Id: Ie10bcff04cc3394aea630f7e073e63cbf239eb68
diff --git a/tflm/opt/BUILD b/tflm/opt/BUILD
index e4d533b..19eb017 100644
--- a/tflm/opt/BUILD
+++ b/tflm/opt/BUILD
@@ -18,6 +18,7 @@
     name = "opt",
     srcs = [
         "conv.cc",
+        "conv_s8.cc",
         "depthwise_conv_s16.cc",
         "depthwise_conv_s8.cc",
         "elementwise_add_s16.cc",
diff --git a/tflm/opt/conv.cc b/tflm/opt/conv.cc
index 8d33848..49d32d5 100644
--- a/tflm/opt/conv.cc
+++ b/tflm/opt/conv.cc
@@ -196,7 +196,6 @@
   const auto input_height = input_shape.Dims(1);
   const auto input_width = input_shape.Dims(2);
   const auto input_depth = input_shape.Dims(3);
-  const auto input_offset = params.input_offset;
   const auto filter_input_depth = filter_shape.Dims(3);
   const auto output_depth = output_shape.Dims(3);
   const auto output_offset = params.output_offset;
@@ -271,7 +270,6 @@
   const auto pad_width = params.padding_values.width;
   const auto input_width = input_shape.Dims(2);
   const auto input_depth = input_shape.Dims(3);
-  const auto input_offset = params.input_offset;
   const auto filter_width = filter_shape.Dims(2);
   const auto filter_depth = filter_shape.Dims(3);
   const auto output_width = output_shape.Dims(2);
@@ -370,7 +368,6 @@
   const auto pad_width = params.padding_values.width;
   const auto input_width = input_shape.Dims(2);
   const auto input_depth = input_shape.Dims(3);
-  const auto input_offset = params.input_offset;
   const auto filter_width = filter_shape.Dims(2);
   const auto filter_depth = filter_shape.Dims(3);
   const auto output_width = output_shape.Dims(2);
@@ -584,148 +581,4 @@
       output_data);
 }
 
-#define INA0 v0
-#define FLTA0 v8
-#define FLTA1 v9
-#define FLTA2 v10
-#define FLTA3 v11
-#define FLTA4 v12
-#define FLTA5 v13
-#define FLTA6 v14
-#define FLTA7 v15
-#define ACC v48
-#define ACC0 v48
-#define OUT0 v56
-void conv_per_channel_b8(
-    const tflite::ConvParams& params, const int32_t* output_multiplier,
-    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
-    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
-    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
-    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
-    int8_t* output_data) {
-  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
-  const auto stride_width = params.stride_width;
-  const auto stride_height = params.stride_height;
-  const auto dilation_width_factor = params.dilation_width_factor;
-  const auto dilation_height_factor = params.dilation_height_factor;
-  const auto pad_width = params.padding_values.width;
-  const auto pad_height = params.padding_values.height;
-  const auto input_height = input_shape.Dims(1);
-  const auto input_width = input_shape.Dims(2);
-  const auto input_depth = input_shape.Dims(3);
-  const auto input_offset = params.input_offset;
-  const auto filter_height = filter_shape.Dims(1);
-  const auto filter_width = filter_shape.Dims(2);
-  const auto filter_depth = filter_shape.Dims(3);
-  const auto output_height = output_shape.Dims(1);
-  const auto output_width = output_shape.Dims(2);
-  const auto output_depth = output_shape.Dims(3);
-  const auto output_offset = params.output_offset;
-  const auto output_activation_min = params.quantized_activation_min;
-  const auto output_activation_max = params.quantized_activation_max;
-  const auto groups = input_depth / filter_depth;
-  const auto filters_per_group = output_depth / groups;
-  union {
-    vconv_u8_t conv;
-    uint32_t raw;
-  } cmds;
-  cmds.conv.mode = 0;
-  cmds.conv.start = 0;
-  cmds.conv.stop = 7;
-  cmds.conv.sbias1 = input_offset;
-  cmds.conv.sdata1 = true;
-  cmds.conv.sbias2 = 0;
-  cmds.conv.sdata2 = true;
-
-  // Zero out accumulators.
-  vdup_b_x(v0, 0);
-  acset_v(ACC, v0);
-  vdup_b_x_m(ACC0, 0);
-  for (int batch = 0; batch < batches; ++batch) {
-    for (int out_y = 0; out_y < output_height; ++out_y) {
-      const int in_y_origin = (out_y * stride_height) - pad_height;
-      for (int out_x = 0; out_x < output_width; /*out_x += 32*/ ++out_x) {
-        const int in_x_origin = (out_x * stride_width) - pad_width;
-        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
-          auto group = out_channel / filters_per_group;
-
-          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
-            const int in_y = in_y_origin + dilation_height_factor * filter_y;
-            const int in_x = in_x_origin + dilation_width_factor * 0;
-
-            // Zero padding by omitting the areas outside the image.
-            const bool is_point_inside_image =
-                (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
-                (in_y < input_height);
-            if (!is_point_inside_image) {
-              continue;
-            }
-
-            int q = filter_width * filter_depth;
-            for (int i = 0; i < q; i += 32) {
-              int count = std::min(q - i, 32);
-              count = std::min(
-                  count, static_cast<int>((input_width - in_x) * filter_depth));
-              int input_offset = tflite::Offset(input_shape, batch, in_y, in_x,
-                                                group * filter_depth) +
-                                 i;
-              vdup_w_x_m(vm0, 0);
-              vdup_w_x_m(vm1, 0);
-              vld_b_l_xx(INA0, &input_data[input_offset], count);
-              int filter_offset =
-                  tflite::Offset(filter_shape, out_channel, filter_y, 0, 0) + i;
-              vdup_w_x_m(FLTA0, 0);
-              vdup_w_x_m(FLTA4, 0);
-              if (count > 0) {
-                vld_b_l_xx(FLTA0, &filter_data[filter_offset],
-                           std::min(count, 4));
-              }
-              if (count > 4) {
-                vld_b_l_xx(FLTA1, &filter_data[filter_offset + 4],
-                           std::min(count - 4, 4));
-              }
-              if (count > 8) {
-                vld_b_l_xx(FLTA2, &filter_data[filter_offset + 8],
-                           std::min(count - 8, 4));
-              }
-              if (count > 12) {
-                vld_b_l_xx(FLTA3, &filter_data[filter_offset + 12],
-                           std::min(count - 12, 4));
-              }
-              if (count > 16) {
-                vld_b_l_xx(FLTA4, &filter_data[filter_offset + 16],
-                           std::min(count - 16, 4));
-              }
-              if (count > 20) {
-                vld_b_l_xx(FLTA5, &filter_data[filter_offset + 20],
-                           std::min(count - 20, 4));
-              }
-              if (count > 24) {
-                vld_b_l_xx(FLTA6, &filter_data[filter_offset + 24],
-                           std::min(count - 24, 4));
-              }
-              if (count > 28) {
-                vld_b_l_xx(FLTA7, &filter_data[filter_offset + 28],
-                           std::min(count - 28, 4));
-              }
-              aconv_vxv(ACC, INA0, cmds, FLTA0);
-            }
-          }
-          vcget(ACC);
-          vadd_w_vx_m(ACC0, ACC0, bias_data[out_channel]);
-          vsll_w_vx_m(ACC0, ACC0, LEFT_SHIFT(output_shift[out_channel]));
-          vdmulh_w_r_vx_m(ACC0, ACC0, output_multiplier[out_channel]);
-          vsha_w_r_vx_m(ACC0, ACC0, RIGHT_SHIFT(output_shift[out_channel]));
-          vadd_w_vx_m(ACC0, ACC0, output_offset);
-          vmin_w_vx_m(ACC0, ACC0, output_activation_max);
-          vmax_w_vx_m(ACC0, ACC0, output_activation_min);
-          vsraqs_b_vx(OUT0, ACC0, 0);
-          size_t output_offset =
-              tflite::Offset(output_shape, batch, out_y, out_x, out_channel);
-          vst_b_l_xx(OUT0, &output_data[output_offset], 1);
-        }
-      }
-    }
-  }
-}
 }  // namespace kelvin::opt
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
new file mode 100644
index 0000000..2da0028
--- /dev/null
+++ b/tflm/opt/conv_s8.cc
@@ -0,0 +1,599 @@
+/*
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cstdlib>
+#include <memory>
+
+#include "crt/kelvin.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+#include "tflm/opt/opt.h"
+#include "tflm/opt/util.h"
+
+namespace kelvin::opt {
+
+namespace {
+constexpr int kFilterInputChannelIndex = 3;
+constexpr int kOutputWidthIndex = 2;
+constexpr int kOutputChannelIndex = 3;
+
+// Convert: input  [zo][ky][kx][zi] (N,4,4,M)
+//          output [ky][kx][zi_hi=M/4][zo=8][zi_lo=4]
+//          output [3][3][16][8][4]
+void Filter_8_H_W_M(const int8_t* input, int8_t* output, int H, int W, int M) {
+  const int8_t(&in)[8][H][W][M] = *(int8_t(*)[8][H][W][M])input;
+  int8_t(&out)[H][W][M / 4][8][4] = *(int8_t(*)[H][W][M / 4][8][4]) output;
+  assert(M >= 4);
+  for (int zo = 0; zo < 8; ++zo) {
+    for (int ky = 0; ky < H; ++ky) {
+      for (int kx = 0; kx < W; ++kx) {
+        for (int zi = 0; zi < M; ++zi) {
+          const int zi_hi = zi >> 2;  // div4
+          const int zi_lo = zi & 3;   // rem4
+          out[ky][kx][zi_hi][zo][zi_lo] = in[zo][ky][kx][zi];
+        }
+      }
+    }
+  }
+}
+
+void Swizzle(const int32_t* input, int32_t* output, int N) {
+  const int32_t(&in)[N] = *(int32_t(*)[N])input;
+  int32_t(&out)[N * 4] = *(int32_t(*)[N * 4]) output;
+  // Convert to accumulator swizzle pattern.
+  for (int i = 0; i < N / 8; ++i) {
+    int32_t* out0 = out + i * 32 + 0;
+    int32_t* out1 = out + i * 32 + 16;
+    int32_t* out2 = out + i * 32 + 8;
+    int32_t* out3 = out + i * 32 + 24;
+    for (int j = 0; j < 4; ++j) {
+      const int32_t* p_in = in + i * 8;
+      for (int k = 0; k < 2; ++k) {
+        *out0++ = *p_in++;
+        *out1++ = *p_in++;
+        *out2++ = *p_in++;
+        *out3++ = *p_in++;
+      }
+    }
+  }
+}
+
+}  // namespace
+
+void conv_per_channel_pw1_ow8_id8_filterd32(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // Get parameters.
+  const int32_t input_offset = params.input_offset;  // r = s(q - Z)
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t output_offset = params.output_offset;
+
+  // Set min and max value of the output.
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Check dimensions of the tensors.
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_input_depth = filter_shape.Dims(3);
+  const int groups = input_depth / filter_input_depth;
+  TFLITE_DCHECK_NE(groups, 0);
+  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  const int filters_per_group = output_depth / groups;
+  TFLITE_DCHECK_NE(filters_per_group, 0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+
+  union {
+    vconv_u8_t conv;
+    uint32_t raw;
+  } cmds;
+  cmds.conv.mode = 0;
+  cmds.conv.start = 0;
+  cmds.conv.stop = 7;
+  cmds.conv.sbias1 = input_offset;
+  cmds.conv.sdata1 = true;
+  cmds.conv.sbias2 = 0;
+  cmds.conv.sdata2 = true;
+
+  const size_t swizzled_filter_data_size =
+      8 * filter_height * filter_width * filter_input_depth;
+  std::unique_ptr<int8_t> swizzled_filter_data(reinterpret_cast<int8_t*>(
+      ::aligned_alloc(32, swizzled_filter_data_size)));
+  int8_t* p_swizzled_filter_data = swizzled_filter_data.get();
+  int32_t swizzled_bias_data[32];
+  int32_t swizzled_mult_data[32];
+  int32_t swizzled_shift_data[32];
+
+  for (int out_channel = 0; out_channel + 8 <= output_depth; out_channel += 8) {
+    Filter_8_H_W_M(filter_data + (out_channel * filter_height * filter_width *
+                                  filter_input_depth),
+                   p_swizzled_filter_data, filter_height, filter_width,
+                   filter_input_depth);
+    Swizzle(bias_data + out_channel, swizzled_bias_data, 8);
+    Swizzle(output_multiplier + out_channel, swizzled_mult_data, 8);
+    Swizzle(output_shift + out_channel, swizzled_shift_data, 8);
+    vld_w_x_m(v16, swizzled_bias_data);
+    vld_w_x_m(v20, swizzled_mult_data);
+    vld_w_x_m(v24, swizzled_shift_data);
+    vrsub_w_vx_m(v24, v24, 0);
+
+    for (int batch = 0; batch < batches; ++batch) {
+      for (int out_y = 0; out_y < output_height; ++out_y) {
+        const int in_y_origin = (out_y * stride_height) - pad_height;
+        for (int out_x = 0; out_x + 8 <= output_width; out_x += 8) {
+          // 8x accumulators
+          vdup_w_x_m(v48, 0);
+          vdup_w_x_m(v52, 0);
+          acset_v(v48, v48);
+          for (int in_channel = 0; in_channel + 32 <= filter_input_depth;
+               in_channel += 32) {
+            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+              const int in_y = in_y_origin + dilation_height_factor * filter_y;
+              const bool is_row_inside_input =
+                  (in_y >= 0) && (in_y < input_height);
+              if (!is_row_inside_input) {
+                continue;
+              }
+
+              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+                int in_x[8];
+                bool left_pad = false;
+                bool right_pad = false;
+                for (int i = 0; i < 8; ++i) {
+                  const int in_x_origin =
+                      ((out_x + i) * stride_width) - pad_width;
+                  in_x[i] = in_x_origin + dilation_width_factor * filter_x;
+                  if (in_x[i] < 0) {
+                    left_pad = true;
+                  }
+                  if (in_x[i] >= input_width) {
+                    right_pad = true;
+                  }
+                }
+
+                if (left_pad) {
+                  vdup_b_x(v0, -input_offset);
+                  vld_b_s_xx(
+                      v1,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[1], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v2,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[2], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v3,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[3], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx_m(
+                      v4,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[4], in_channel)],
+                      input_depth * stride_width);
+                } else if (right_pad) {
+                  vld_b_s_xx_m(
+                      v0,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[0], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v4,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[4], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v5,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[5], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v6,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[6], in_channel)],
+                      input_depth * stride_width);
+                  vdup_b_x(v7, -input_offset);
+                } else if (!left_pad && !right_pad) {
+                  // Inputs
+                  vld_b_s_xx_m(
+                      v0,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[0], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx_m(
+                      v4,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[4], in_channel)],
+                      input_depth * stride_width);
+                } else {
+                  vdup_b_x(v0, -input_offset);
+                  vdup_b_x(v7, -input_offset);
+                  vld_b_s_xx_m(
+                      v1,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[1], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v5,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[5], in_channel)],
+                      input_depth * stride_width);
+                  vld_b_s_xx(
+                      v6,
+                      &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                 in_x[6], in_channel)],
+                      input_depth * stride_width);
+                }
+                size_t local_filter_offset =
+                    (filter_y * filter_width * 8 * input_depth) +
+                    (filter_x * 8 * input_depth) + (in_channel * 8);
+                int8_t* p_local_filter_start =
+                    p_swizzled_filter_data + local_filter_offset;
+                vld_b_p_x_m(v8, p_local_filter_start);
+                vld_b_x_m(v12, p_local_filter_start);
+
+                aconv_vxv(v48, v0, cmds, v8);
+              }
+            }
+          }
+          vcget(v48);
+          vadd_w_vv_m(v48, v48, v16);
+          vadd_w_vv_m(v52, v52, v16);
+          vdmulh_w_r_vv_m(v48, v48, v20);
+          vdmulh_w_r_vv_m(v52, v52, v20);
+          vsha_w_r_vv_m(v48, v48, v24);
+          vsha_w_r_vv_m(v52, v52, v24);
+          vadd_w_vx_m(v48, v48, output_offset);
+          vadd_w_vx_m(v52, v52, output_offset);
+          vmin_w_vx_m(v48, v48, output_activation_max);
+          vmin_w_vx_m(v52, v52, output_activation_max);
+          vmax_w_vx_m(v48, v48, output_activation_min);
+          vmax_w_vx_m(v52, v52, output_activation_min);
+          vsraqs_b_vx(v56, v48, 0);
+          vsraqs_b_vx(v57, v52, 0);
+          vstq_b_s_xx(v56,
+                      &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                  out_x, out_channel)],
+                      output_depth);
+          vstq_b_s_xx(v57,
+                      &output_data[tflite::Offset(output_shape, batch, out_y,
+                                                  out_x + 4, out_channel)],
+                      output_depth);
+        }
+      }
+    }
+  }
+}
+
+// Fixed-point per-channel-quantization convolution reference kernel.
+void conv_per_channel_filterd32(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  // Get parameters.
+  const int32_t input_offset = params.input_offset;  // r = s(q - Z)
+  const int stride_width = params.stride_width;
+  const int stride_height = params.stride_height;
+  const int dilation_width_factor = params.dilation_width_factor;
+  const int dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+  const int32_t output_offset = params.output_offset;
+
+  // Set min and max value of the output.
+  const int32_t output_activation_min = params.quantized_activation_min;
+  const int32_t output_activation_max = params.quantized_activation_max;
+
+  // Consistency check.
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+  if (bias_data) {
+    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+  }
+
+  // Check dimensions of the tensors.
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int filter_input_depth = filter_shape.Dims(3);
+  const int groups = input_depth / filter_input_depth;
+  TFLITE_DCHECK_NE(groups, 0);
+  TFLITE_DCHECK_EQ(input_depth % filter_input_depth, 0);
+  const int filters_per_group = output_depth / groups;
+  TFLITE_DCHECK_NE(filters_per_group, 0);
+  const int output_height = output_shape.Dims(1);
+  const int output_width = output_shape.Dims(2);
+  for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+    for (int batch = 0; batch < batches; ++batch) {
+      for (int out_y = 0; out_y < output_height; ++out_y) {
+        const int in_y_origin = (out_y * stride_height) - pad_height;
+        for (int out_x = 0; out_x < output_width; ++out_x) {
+          const int in_x_origin = (out_x * stride_width) - pad_width;
+          vdup_w_x_m(v60, 0);
+          int32_t acc = 0;
+          for (int in_channel = 0; in_channel + 32 <= filter_input_depth;
+               in_channel += 32) {
+            for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+              const int in_y = in_y_origin + dilation_height_factor * filter_y;
+              for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+                const int in_x = in_x_origin + dilation_width_factor * filter_x;
+
+                // Zero padding by omitting the areas outside the image.
+                const bool is_point_inside_image =
+                    (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                    (in_y < input_height);
+
+                if (!is_point_inside_image) {
+                  continue;
+                }
+
+                vld_b_x(v0, &input_data[tflite::Offset(input_shape, batch, in_y,
+                                                       in_x, in_channel)]);
+                vaddw_h_vx(v0, v0, 0);
+                vadd_h_vx(v0, v0, static_cast<int16_t>(input_offset));
+                vadd_h_vx(v1, v1, static_cast<int16_t>(input_offset));
+                vld_b_x(v2, &filter_data[tflite::Offset(filter_shape,
+                                                        out_channel, filter_y,
+                                                        filter_x, in_channel)]);
+                vaddw_h_vx(v2, v2, 0);
+                vmulw_w_vv(v48, v0, v2);
+                vmulw_w_vv(v50, v1, v3);
+                vadd_w_vv_m(v60, v60, v48);
+              }
+            }
+          }
+          int32_t accumulators[32];
+          vst_w_x_m(v60, accumulators);
+          for (int i = 0; i < 32; ++i) {
+            acc += accumulators[i];
+          }
+
+          if (bias_data) {
+            acc += bias_data[out_channel];
+          }
+          acc = tflite::MultiplyByQuantizedMultiplier(
+              acc, output_multiplier[out_channel], output_shift[out_channel]);
+          acc += output_offset;
+          acc = std::max(acc, output_activation_min);
+          acc = std::min(acc, output_activation_max);
+          output_data[tflite::Offset(output_shape, batch, out_y, out_x,
+                                     out_channel)] = static_cast<int8_t>(acc);
+        }
+      }
+    }
+  }
+}
+
+void conv_per_channel_generic(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  const auto batches = MatchingDim(input_shape, 0, output_shape, 0);
+  const auto stride_width = params.stride_width;
+  const auto stride_height = params.stride_height;
+  const auto dilation_width_factor = params.dilation_width_factor;
+  const auto dilation_height_factor = params.dilation_height_factor;
+  const auto pad_width = params.padding_values.width;
+  const auto pad_height = params.padding_values.height;
+  const auto input_height = input_shape.Dims(1);
+  const auto input_width = input_shape.Dims(2);
+  const auto input_depth = input_shape.Dims(3);
+  const auto input_offset = params.input_offset;
+  const auto filter_height = filter_shape.Dims(1);
+  const auto filter_width = filter_shape.Dims(2);
+  const auto filter_depth = filter_shape.Dims(3);
+  const auto output_height = output_shape.Dims(1);
+  const auto output_width = output_shape.Dims(2);
+  const auto output_depth = output_shape.Dims(3);
+  const auto output_offset = params.output_offset;
+  const auto output_activation_min = params.quantized_activation_min;
+  const auto output_activation_max = params.quantized_activation_max;
+  const auto groups = input_depth / filter_depth;
+  const auto filters_per_group = output_depth / groups;
+  union {
+    vconv_u8_t conv;
+    uint32_t raw;
+  } cmds;
+  cmds.conv.mode = 0;
+  cmds.conv.start = 0;
+  cmds.conv.stop = 7;
+  cmds.conv.sbias1 = input_offset;
+  cmds.conv.sdata1 = true;
+  cmds.conv.sbias2 = 0;
+  cmds.conv.sdata2 = true;
+
+  // Zero out accumulators.
+  vdup_b_x(v0, 0);
+  acset_v(v48, v0);
+  vdup_b_x_m(v48, 0);
+  for (int batch = 0; batch < batches; ++batch) {
+    for (int out_y = 0; out_y < output_height; ++out_y) {
+      const int in_y_origin = (out_y * stride_height) - pad_height;
+      for (int out_x = 0; out_x < output_width; /*out_x += 32*/ ++out_x) {
+        const int in_x_origin = (out_x * stride_width) - pad_width;
+        for (int out_channel = 0; out_channel < output_depth; ++out_channel) {
+          auto group = out_channel / filters_per_group;
+
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int in_y = in_y_origin + dilation_height_factor * filter_y;
+            const int in_x = in_x_origin + dilation_width_factor * 0;
+
+            // Zero padding by omitting the areas outside the image.
+            const bool is_point_inside_image =
+                (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+                (in_y < input_height);
+            if (!is_point_inside_image) {
+              continue;
+            }
+
+            int q = filter_width * filter_depth;
+            for (int i = 0; i < q; i += 32) {
+              int count = std::min(q - i, 32);
+              count = std::min(
+                  count, static_cast<int>((input_width - in_x) * filter_depth));
+              int input_offset = tflite::Offset(input_shape, batch, in_y, in_x,
+                                                group * filter_depth) +
+                                 i;
+              vdup_w_x_m(vm0, 0);
+              vdup_w_x_m(vm1, 0);
+              vld_b_l_xx(v0, &input_data[input_offset], count);
+              int filter_offset =
+                  tflite::Offset(filter_shape, out_channel, filter_y, 0, 0) + i;
+              vdup_w_x_m(v8, 0);
+              vdup_w_x_m(v12, 0);
+              if (count > 0) {
+                vld_b_l_xx(v8, &filter_data[filter_offset], std::min(count, 4));
+              }
+              if (count > 4) {
+                vld_b_l_xx(v9, &filter_data[filter_offset + 4],
+                           std::min(count - 4, 4));
+              }
+              if (count > 8) {
+                vld_b_l_xx(v10, &filter_data[filter_offset + 8],
+                           std::min(count - 8, 4));
+              }
+              if (count > 12) {
+                vld_b_l_xx(v11, &filter_data[filter_offset + 12],
+                           std::min(count - 12, 4));
+              }
+              if (count > 16) {
+                vld_b_l_xx(v12, &filter_data[filter_offset + 16],
+                           std::min(count - 16, 4));
+              }
+              if (count > 20) {
+                vld_b_l_xx(v13, &filter_data[filter_offset + 20],
+                           std::min(count - 20, 4));
+              }
+              if (count > 24) {
+                vld_b_l_xx(v14, &filter_data[filter_offset + 24],
+                           std::min(count - 24, 4));
+              }
+              if (count > 28) {
+                vld_b_l_xx(v15, &filter_data[filter_offset + 28],
+                           std::min(count - 28, 4));
+              }
+              aconv_vxv(v48, v0, cmds, v8);
+            }
+          }
+          vcget(v48);
+          vadd_w_vx_m(v48, v48, bias_data[out_channel]);
+          vsll_w_vx_m(v48, v48, LEFT_SHIFT(output_shift[out_channel]));
+          vdmulh_w_r_vx_m(v48, v48, output_multiplier[out_channel]);
+          vsha_w_r_vx_m(v48, v48, RIGHT_SHIFT(output_shift[out_channel]));
+          vadd_w_vx_m(v48, v48, output_offset);
+          vmin_w_vx_m(v48, v48, output_activation_max);
+          vmax_w_vx_m(v48, v48, output_activation_min);
+          vsraqs_b_vx(v56, v48, 0);
+          size_t output_offset =
+              tflite::Offset(output_shape, batch, out_y, out_x, out_channel);
+          vst_b_l_xx(v56, &output_data[output_offset], 1);
+        }
+      }
+    }
+  }
+}
+
+void conv_per_channel_b8(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int8_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int8_t* output_data) {
+  const auto stride_width = params.stride_width;
+  const auto stride_height = params.stride_height;
+  const auto dilation_width_factor = params.dilation_width_factor;
+  const auto dilation_height_factor = params.dilation_height_factor;
+  const int pad_width = params.padding_values.width;
+  const int pad_height = params.padding_values.height;
+
+  if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
+      stride_width <= 2 && stride_height <= 2) {
+    if (filter_shape.Dims(kFilterInputChannelIndex) % 32 == 0 &&
+        output_shape.Dims(kOutputChannelIndex) % 8 == 0 &&
+        output_shape.Dims(kOutputWidthIndex) % 8 == 0 && pad_width <= 1) {
+      conv_per_channel_pw1_ow8_id8_filterd32(
+          params, output_multiplier, output_shift, input_shape, input_data,
+          filter_shape, filter_data, bias_shape, bias_data, output_shape,
+          output_data);
+      return;
+    } else if (filter_shape.Dims(kFilterInputChannelIndex) % 32 == 0) {
+      conv_per_channel_filterd32(params, output_multiplier, output_shift,
+                                 input_shape, input_data, filter_shape,
+                                 filter_data, bias_shape, bias_data,
+                                 output_shape, output_data);
+      return;
+    }
+  }
+
+  if (stride_width == 1 && stride_height == 1 && dilation_width_factor == 1 &&
+      dilation_height_factor == 1) {
+    if (pad_width == 0 && pad_height == 0) {
+      conv_per_channel_generic(params, output_multiplier, output_shift,
+                               input_shape, input_data, filter_shape,
+                               filter_data, bias_shape, bias_data, output_shape,
+                               output_data);
+      return;
+    }
+  }
+
+  tflite::reference_integer_ops::ConvPerChannel(
+      params, output_multiplier, output_shift, input_shape, input_data,
+      filter_shape, filter_data, bias_shape, bias_data, output_shape,
+      output_data);
+}
+
+}  // namespace kelvin::opt