Create RUN_KERNEL macro for DepthwiseConvS8
- As in ConvS8, make a little macro for invoking the different flavours
of the kernel.
Change-Id: Iac354ca2928b64423f79ef1a8e5cf3a70d98a96e
diff --git a/tflm/opt/depthwise_conv_s8.cc b/tflm/opt/depthwise_conv_s8.cc
index 1b450b5..b839c6f 100644
--- a/tflm/opt/depthwise_conv_s8.cc
+++ b/tflm/opt/depthwise_conv_s8.cc
@@ -1236,40 +1236,38 @@
TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+#define RUN_KERNEL(kernel) { \
+ kernel(params, output_multiplier, output_shift, input_shape, input_data, \
+ filter_shape, filter_data, bias_shape, bias_data, output_shape, \
+ output_data \
+ ); \
+ return; \
+}
+
if (depth_multiplier == 1 &&
dilation_height_factor == 1 && dilation_width_factor == 1 &&
stride_height <= 2 && stride_width <= 2) {
- // generic implementation by default
- auto fn = DepthwiseConvS8Generic;
-
// special case of output depth = 32n
if (output_depth % 32 == 0) {
if (filter_width == 5 && filter_height == 5) {
if (stride_width <= 1 && stride_height <= 1) {
- fn = DepthwiseConvS85x5D32_Stride1;
- } else {
- fn = DepthwiseConvS85x5D32;
+ RUN_KERNEL(DepthwiseConvS85x5D32_Stride1);
}
- } else if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1 && stride_width == 1 && stride_height == 1) {
- fn = DepthwiseConvS83x3D32_Stride1;
- } else if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1) {
- fn = DepthwiseConvS83x3D32;
- } else {
- fn = DepthwiseConvS8D32;
+ RUN_KERNEL(DepthwiseConvS85x5D32);
+ } if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1 && stride_width == 1 && stride_height == 1) {
+ RUN_KERNEL(DepthwiseConvS83x3D32_Stride1);
+ } if (filter_width == 3 && filter_height == 3 && pad_width <= 1 && pad_height <= 1) {
+ RUN_KERNEL(DepthwiseConvS83x3D32);
}
+ RUN_KERNEL(DepthwiseConvS8D32);
}
- fn(params, output_multiplier, output_shift, input_shape, input_data,
- filter_shape, filter_data, bias_shape, bias_data, output_shape,
- output_data);
- return;
+ RUN_KERNEL(DepthwiseConvS8Generic);
}
- // Use reference implementation
- tflite::reference_integer_ops::DepthwiseConvPerChannel(
- params, output_multiplier, output_shift, input_shape, input_data,
- filter_shape, filter_data, bias_shape, bias_data, output_shape,
- output_data);
+ RUN_KERNEL(tflite::reference_integer_ops::DepthwiseConvPerChannel);
+
+#undef RUN_KERNEL
}
} // namespace kelvin::opt