Fix in_channel calculation in ConvS8D4
BUG=345095360
Change-Id: I9692c4d9c1c3c0e8e80b2084bcfc02ae1220e525
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc
index 4ee4d33..9ee6edb 100644
--- a/tflm/opt/conv_s8_d4.cc
+++ b/tflm/opt/conv_s8_d4.cc
@@ -169,7 +169,7 @@
acset_v(v48, v48);
int in_channel = 0;
do {
- int in_channels_this_iter = std::min(filter_input_depth, 32);
+ int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32);
// Calculate first valid filter_y
int filter_y = 0;
{
@@ -456,7 +456,7 @@
int in_channel = 0;
while (in_channel < filter_input_depth) {
- int in_channels_this_iter = std::min(filter_input_depth, 32);
+ int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32);
// Calculate first valid filter_y
int filter_y = 0;
{
@@ -556,7 +556,7 @@
acset_v(v48, v48);
int in_channel = 0;
while (in_channel < filter_input_depth) {
- int in_channels_this_iter = std::min(filter_input_depth, 32);
+ int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32);
cmds.conv.stop = (in_channels_this_iter / 4) - 1;
// Calculate first valid filter_y
@@ -693,7 +693,7 @@
int in_channel = 0;
while (in_channel < filter_input_depth) {
- int in_channels_this_iter = std::min(filter_input_depth, 32);
+ int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32);
// Calculate first valid filter_y
int filter_y = 0;
{