Consolidate kernel selection behind generic conv_per_channel_b64.

This keeps selection of the correct specialized conv kernel in one
repo.

Change-Id: Ib4dbe13e561b3acec81abb5c39a7fcb0a196671e
diff --git a/tflm/opt/conv.cc b/tflm/opt/conv.cc
index df4dc89..8d33848 100644
--- a/tflm/opt/conv.cc
+++ b/tflm/opt/conv.cc
@@ -34,6 +34,12 @@
     3, 7, 11, 15,
 };
 /* clang-format on */
+
+constexpr int kFilterHeightIndex = 1;
+constexpr int kFilterWidthIndex = 2;
+constexpr int kFilterInputChannelIndex = 3;
+constexpr int kInputChannelIndex = 3;
+constexpr int kOutputChannelIndex = 3;
 }  // namespace
 
 void conv_per_channel_b32(
@@ -430,7 +436,7 @@
   }
 }
 
-void conv_per_channel_b64(
+void conv_per_channel_b64_generic(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
     const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
@@ -532,6 +538,52 @@
   }
 }
 
+void conv_per_channel_b64(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int16_t* output_data) {
+  if (filter_shape.Dims(kFilterHeightIndex) == 1 &&
+      output_shape.Dims(kOutputChannelIndex) % 2 == 0) {
+    if (filter_shape.Dims(kFilterWidthIndex) == 1 &&
+        filter_shape.Dims(kFilterInputChannelIndex) % 32 == 0) {
+      kelvin::opt::conv_per_channel_b64_1x1(
+          params, output_multiplier, output_shift, input_shape, input_data,
+          filter_shape, filter_data, bias_shape, bias_data, output_shape,
+          output_data);
+      return;
+    }
+
+    // TODO(derekjchow): Check for valid padding
+    bool group_conv = !(input_shape.Dims(kInputChannelIndex) ==
+        filter_shape.Dims(kFilterInputChannelIndex));
+    int32_t fan_in = filter_shape.Dims(kFilterWidthIndex) *
+        filter_shape.Dims(kFilterInputChannelIndex);
+    if (!group_conv && fan_in % 32 == 0) {
+      kelvin::opt::conv_per_channel_b64_filter1xn_non_group(
+          params, output_multiplier, output_shift, input_shape, input_data,
+          filter_shape, filter_data, bias_shape, bias_data, output_shape,
+          output_data);
+      return;
+    }
+
+    if (fan_in % 32 == 0) {
+      kelvin::opt::conv_per_channel_b64_filter1xn_group(
+          params, output_multiplier, output_shift, input_shape, input_data,
+          filter_shape, filter_data, bias_shape, bias_data, output_shape,
+          output_data);
+      return;
+    }
+  }
+
+  kelvin::opt::conv_per_channel_b64_generic(
+      params, output_multiplier, output_shift, input_shape, input_data,
+      filter_shape, filter_data, bias_shape, bias_data, output_shape,
+      output_data);
+}
+
 #define INA0 v0
 #define FLTA0 v8
 #define FLTA1 v9
diff --git a/tflm/opt/opt.h b/tflm/opt/opt.h
index 6797f36..e4268c0 100644
--- a/tflm/opt/opt.h
+++ b/tflm/opt/opt.h
@@ -71,30 +71,7 @@
     const int32_t* bias_data, const tflite::RuntimeShape& output_shape,
     int16_t* output_data);
 
-void conv_per_channel_b64_1x1(
-    const tflite::ConvParams& params, const int32_t* output_multiplier,
-    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
-    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
-    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
-    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
-    int16_t* output_data);
-
-void conv_per_channel_b64_filter1xn_non_group(
-    const tflite::ConvParams& params, const int32_t* output_multiplier,
-    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
-    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
-    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
-    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
-    int16_t* output_data);
-
-void conv_per_channel_b64_filter1xn_group(
-    const tflite::ConvParams& params, const int32_t* output_multiplier,
-    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
-    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
-    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
-    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
-    int16_t* output_data);
-
+// Top level conv function, will invoke correct variant below.
 void conv_per_channel_b64(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
@@ -102,6 +79,35 @@
     const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
     const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
     int16_t* output_data);
+void conv_per_channel_b64_1x1(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int16_t* output_data);
+void conv_per_channel_b64_filter1xn_non_group(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int16_t* output_data);
+void conv_per_channel_b64_filter1xn_group(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int16_t* output_data);
+void conv_per_channel_b64_generic(
+    const tflite::ConvParams& params, const int32_t* output_multiplier,
+    const int32_t* output_shift, const tflite::RuntimeShape& input_shape,
+    const int16_t* input_data, const tflite::RuntimeShape& filter_shape,
+    const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
+    const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
+    int16_t* output_data);
+
 void conv_per_channel_b8(
     const tflite::ConvParams& params, const int32_t* output_multiplier,
     const int32_t* output_shift, const tflite::RuntimeShape& input_shape,