Refactor kernel selection.
Early return in priority selection.
Change-Id: I435665df8ca29f5e148038f3836fdf53a4d33e2d
diff --git a/tflm/opt/conv_s8.cc b/tflm/opt/conv_s8.cc
index ce56fda..d3529ef 100644
--- a/tflm/opt/conv_s8.cc
+++ b/tflm/opt/conv_s8.cc
@@ -192,46 +192,49 @@
const auto output_width = output_shape.Dims(2);
const auto output_depth = output_shape.Dims(3);
- // use generic implementation by default
- auto fn = ConvS8Generic;
+#define RUN_KERNEL(kernel) {\
+ kernel(\
+ params, output_multiplier, output_shift, input_shape, input_data,\
+ filter_shape, filter_data, bias_shape, bias_data, output_shape,\
+ output_data);\
+ return; \
+}
// special case of filter_depth = 4n
if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
stride_width <= 2 && stride_height <= 2 && filter_depth % 4 == 0 &&
output_depth >= 8 && output_width >= 8 && pad_width <= 1) {
- fn = kelvin::opt::ConvS8D4;
+ RUN_KERNEL(kelvin::opt::ConvS8D4);
}
// special case of filter depth = 32n
- else if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
+ if (dilation_width_factor == 1 && dilation_height_factor == 1 &&
stride_width <= 2 && stride_height <= 2 && filter_depth % 32 == 0) {
- fn = kelvin::opt::ConvS8D32;
+ RUN_KERNEL(kelvin::opt::ConvS8D32);
}
// special case of filter size 1x1
- else if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
+ if (filter_height == 1 && filter_width == 1 && stride_height == 1 &&
stride_width == 1 && dilation_height_factor == 1 &&
dilation_width_factor == 1 && pad_height == 0 && pad_width == 0 &&
(output_depth % 8) == 0 && (input_depth % 32) == 0) {
// TODO(ndodda): uncomment it when all tests are passed
- // fn = kelvin::opt::ConvS8K1x1;
+ // RUN_KERNEL(kelvin::opt::ConvS8K1x1);
}
// special case of filter size 48x3x1x48
- else if (batches == 1 && filter_height == 3 && filter_width == 1 &&
+ if (batches == 1 && filter_height == 3 && filter_width == 1 &&
input_width == 1 && input_depth == 48 && output_depth == 48 &&
stride_height == 1 && stride_width == 1 && dilation_height_factor == 1 &&
dilation_width_factor == 1 && pad_height == 0 && pad_width == 0) {
- fn = kelvin::opt::ConvS8K3x1D48;
+ RUN_KERNEL(kelvin::opt::ConvS8K3x1D48);
}
- else if (input_depth == 1 && ((output_depth % 4) == 0)) {
- fn = kelvin::opt::ConvPerChannelD1;
+ if (input_depth == 1 && ((output_depth % 4) == 0)) {
+ RUN_KERNEL(kelvin::opt::ConvPerChannelD1);
}
- fn(params, output_multiplier, output_shift, input_shape, input_data,
- filter_shape, filter_data, bias_shape, bias_data, output_shape,
- output_data);
+ RUN_KERNEL(ConvS8Generic);
}
} // namespace kelvin::opt