Fix in_channel calculation in ConvS8D4 BUG=345095360 Change-Id: I9692c4d9c1c3c0e8e80b2084bcfc02ae1220e525
diff --git a/tflm/opt/conv_s8_d4.cc b/tflm/opt/conv_s8_d4.cc index 4ee4d33..9ee6edb 100644 --- a/tflm/opt/conv_s8_d4.cc +++ b/tflm/opt/conv_s8_d4.cc
@@ -169,7 +169,7 @@ acset_v(v48, v48); int in_channel = 0; do { - int in_channels_this_iter = std::min(filter_input_depth, 32); + int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32); // Calculate first valid filter_y int filter_y = 0; { @@ -456,7 +456,7 @@ int in_channel = 0; while (in_channel < filter_input_depth) { - int in_channels_this_iter = std::min(filter_input_depth, 32); + int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32); // Calculate first valid filter_y int filter_y = 0; { @@ -556,7 +556,7 @@ acset_v(v48, v48); int in_channel = 0; while (in_channel < filter_input_depth) { - int in_channels_this_iter = std::min(filter_input_depth, 32); + int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32); cmds.conv.stop = (in_channels_this_iter / 4) - 1; // Calculate first valid filter_y @@ -693,7 +693,7 @@ int in_channel = 0; while (in_channel < filter_input_depth) { - int in_channels_this_iter = std::min(filter_input_depth, 32); + int in_channels_this_iter = std::min(filter_input_depth - in_channel, 32); // Calculate first valid filter_y int filter_y = 0; {