Restrict ConvS16B64K1xn kernels to input height 1
- ConvS16B64K1xn{Group,NonGroup} are designed to evaluate inputs with a
height of 1; only dispatch to them when input_shape.Dims(1) == 1.
Change-Id: I262c5dbcc62af7e5235caf1815539939f64a9eba
diff --git a/tflm/opt/conv_s16_b64.cc b/tflm/opt/conv_s16_b64.cc
index 48823dd..57c6eb5 100644
--- a/tflm/opt/conv_s16_b64.cc
+++ b/tflm/opt/conv_s16_b64.cc
@@ -417,6 +417,7 @@
const int8_t* filter_data, const tflite::RuntimeShape& bias_shape,
const int64_t* bias_data, const tflite::RuntimeShape& output_shape,
int16_t* output_data) {
+ const auto input_height = input_shape.Dims(1);
const auto input_depth = input_shape.Dims(3);
const auto filter_height = filter_shape.Dims(1);
const auto filter_width = filter_shape.Dims(2);
@@ -436,12 +437,12 @@
// 1xn non group filter
bool group_conv = !(input_depth == filter_depth);
int32_t fan_in = filter_width * filter_depth;
- if (!group_conv && fan_in % 32 == 0) {
+ if (!group_conv && fan_in % 32 == 0 && input_height == 1) {
fn = ConvS16B64K1xnNonGroup;
}
// 1xn group filter
- if (fan_in % 32 == 0) {
+ if (fan_in % 32 == 0 && input_height == 1) {
fn = ConvS16B64K1xnGroup;
}
}