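Consolidate kernel selection behind generic conv_per_channel_b64.

conv_per_channel_b64 is now the top-level entry point: it inspects the filter, input and output shapes and dispatches to the matching specialized kernel (conv_per_channel_b64_1x1, conv_per_channel_b64_filter1xn_non_group or conv_per_channel_b64_filter1xn_group), falling back to the previous implementation, which is renamed conv_per_channel_b64_generic. This keeps selection of the correct specialized conv kernel in one repo.

Change-Id: Ib4dbe13e561b3acec81abb5c39a7fcb0a196671e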
diff --git a/tflm/opt/conv.cc b/tflm/opt/conv.cc index df4dc89..8d33848 100644 --- a/tflm/opt/conv.cc +++ b/tflm/opt/conv.cc
@@ -34,6 +34,12 @@ 3, 7, 11, 15, }; /* clang-format on */ + +constexpr int kFilterHeightIndex = 1; +constexpr int kFilterWidthIndex = 2; +constexpr int kFilterInputChannelIndex = 3; +constexpr int kInputChannelIndex = 3; +constexpr int kOutputChannelIndex = 3; } // namespace void conv_per_channel_b32( @@ -430,7 +436,7 @@ } } -void conv_per_channel_b64( +void conv_per_channel_b64_generic( const tflite::ConvParams& params, const int32_t* output_multiplier, const int32_t* output_shift, const tflite::RuntimeShape& input_shape, const int16_t* input_data, const tflite::RuntimeShape& filter_shape, @@ -532,6 +538,52 @@ } } +void conv_per_channel_b64( + const tflite::ConvParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const tflite::RuntimeShape& input_shape, + const int16_t* input_data, const tflite::RuntimeShape& filter_shape, + const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, + const int64_t* bias_data, const tflite::RuntimeShape& output_shape, + int16_t* output_data) { + if (filter_shape.Dims(kFilterHeightIndex) == 1 && + output_shape.Dims(kOutputChannelIndex) % 2 == 0) { + if (filter_shape.Dims(kFilterWidthIndex) == 1 && + filter_shape.Dims(kFilterInputChannelIndex) % 32 == 0) { + kelvin::opt::conv_per_channel_b64_1x1( + params, output_multiplier, output_shift, input_shape, input_data, + filter_shape, filter_data, bias_shape, bias_data, output_shape, + output_data); + return; + } + + // TODO(derekjchow): Check for valid padding + bool group_conv = !(input_shape.Dims(kInputChannelIndex) == + filter_shape.Dims(kFilterInputChannelIndex)); + int32_t fan_in = filter_shape.Dims(kFilterWidthIndex) * + filter_shape.Dims(kFilterInputChannelIndex); + if (!group_conv && fan_in % 32 == 0) { + kelvin::opt::conv_per_channel_b64_filter1xn_non_group( + params, output_multiplier, output_shift, input_shape, input_data, + filter_shape, filter_data, bias_shape, bias_data, output_shape, + output_data); + return; + } + + if (fan_in % 32 == 
0) { + kelvin::opt::conv_per_channel_b64_filter1xn_group( + params, output_multiplier, output_shift, input_shape, input_data, + filter_shape, filter_data, bias_shape, bias_data, output_shape, + output_data); + return; + } + } + + kelvin::opt::conv_per_channel_b64_generic( + params, output_multiplier, output_shift, input_shape, input_data, + filter_shape, filter_data, bias_shape, bias_data, output_shape, + output_data); +} + #define INA0 v0 #define FLTA0 v8 #define FLTA1 v9
diff --git a/tflm/opt/opt.h b/tflm/opt/opt.h index 6797f36..e4268c0 100644 --- a/tflm/opt/opt.h +++ b/tflm/opt/opt.h
@@ -71,30 +71,7 @@ const int32_t* bias_data, const tflite::RuntimeShape& output_shape, int16_t* output_data); -void conv_per_channel_b64_1x1( - const tflite::ConvParams& params, const int32_t* output_multiplier, - const int32_t* output_shift, const tflite::RuntimeShape& input_shape, - const int16_t* input_data, const tflite::RuntimeShape& filter_shape, - const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, - const int64_t* bias_data, const tflite::RuntimeShape& output_shape, - int16_t* output_data); - -void conv_per_channel_b64_filter1xn_non_group( - const tflite::ConvParams& params, const int32_t* output_multiplier, - const int32_t* output_shift, const tflite::RuntimeShape& input_shape, - const int16_t* input_data, const tflite::RuntimeShape& filter_shape, - const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, - const int64_t* bias_data, const tflite::RuntimeShape& output_shape, - int16_t* output_data); - -void conv_per_channel_b64_filter1xn_group( - const tflite::ConvParams& params, const int32_t* output_multiplier, - const int32_t* output_shift, const tflite::RuntimeShape& input_shape, - const int16_t* input_data, const tflite::RuntimeShape& filter_shape, - const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, - const int64_t* bias_data, const tflite::RuntimeShape& output_shape, - int16_t* output_data); - +// Top level conv function, will invoke correct variant below. 
void conv_per_channel_b64( const tflite::ConvParams& params, const int32_t* output_multiplier, const int32_t* output_shift, const tflite::RuntimeShape& input_shape, @@ -102,6 +79,35 @@ const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, const int64_t* bias_data, const tflite::RuntimeShape& output_shape, int16_t* output_data); +void conv_per_channel_b64_1x1( + const tflite::ConvParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const tflite::RuntimeShape& input_shape, + const int16_t* input_data, const tflite::RuntimeShape& filter_shape, + const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, + const int64_t* bias_data, const tflite::RuntimeShape& output_shape, + int16_t* output_data); +void conv_per_channel_b64_filter1xn_non_group( + const tflite::ConvParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const tflite::RuntimeShape& input_shape, + const int16_t* input_data, const tflite::RuntimeShape& filter_shape, + const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, + const int64_t* bias_data, const tflite::RuntimeShape& output_shape, + int16_t* output_data); +void conv_per_channel_b64_filter1xn_group( + const tflite::ConvParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const tflite::RuntimeShape& input_shape, + const int16_t* input_data, const tflite::RuntimeShape& filter_shape, + const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, + const int64_t* bias_data, const tflite::RuntimeShape& output_shape, + int16_t* output_data); +void conv_per_channel_b64_generic( + const tflite::ConvParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const tflite::RuntimeShape& input_shape, + const int16_t* input_data, const tflite::RuntimeShape& filter_shape, + const int8_t* filter_data, const tflite::RuntimeShape& bias_shape, + const int64_t* bias_data, const tflite::RuntimeShape& output_shape, + int16_t* 
output_data); + void conv_per_channel_b8( const tflite::ConvParams& params, const int32_t* output_multiplier, const int32_t* output_shift, const tflite::RuntimeShape& input_shape,