Check whether the bias tensor is `nullptr` before accessing the type. (#2566)
Prevents program crash due to null pointer dereference.
BUG=none
diff --git a/tensorflow/lite/micro/kernels/transpose_conv.cc b/tensorflow/lite/micro/kernels/transpose_conv.cc
index cd61660..ea0efae 100644
--- a/tensorflow/lite/micro/kernels/transpose_conv.cc
+++ b/tensorflow/lite/micro/kernels/transpose_conv.cc
@@ -15,6 +15,9 @@
#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
+#include <cstddef>
+#include <cstdint>
+
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
@@ -48,8 +51,9 @@
// A scratch buffer is required for quantized implementations.
int scratch_buffer_index;
- // TODO(b/192090531): Remove this once all 8x16 transpose conv models use
- // 64-bit biases.
+ // Index to the converted 64-bit bias buffer from 16-bit bias. This is
+ // required to handle 16x8 transpose convolutions where a 16-bit bias is
+ // provided, whereas the kernel expects 64-bit biases.
int bias_converted_buffer_index;
// Multiplier and shift arrays are required for the int8 implementation.
@@ -123,7 +127,9 @@
if (input->type == kTfLiteInt16) {
TFLITE_DCHECK(filter->type == kTfLiteInt8);
TFLITE_DCHECK(output->type == kTfLiteInt16);
- if (bias->type == kTfLiteInt16) {
+ // Handle the case where the bias is 16 bits for 16x8 transpose
+ // convolution where the kernel actually expects 64-bit biases.
+ if (bias != nullptr && bias->type == kTfLiteInt16) {
TFLITE_DCHECK(
context->RequestScratchBufferInArena(
context, GetTensorShape(bias).FlatSize() * sizeof(std::int64_t),
@@ -299,12 +305,10 @@
break;
}
case kTfLiteInt16: {
- std::int64_t* scratch_buffer = static_cast<int64_t*>(
+ auto* scratch_buffer = static_cast<int64_t*>(
context->GetScratchBuffer(context, data.scratch_buffer_index));
- // TODO(b/192090531): Remove this once all 8x16 transpose conv models use
- // 64-bit biases.
if (bias != nullptr && bias->type == kTfLiteInt16) {
- std::int64_t* bias_converted_buffer =
+ auto* bias_converted_buffer =
static_cast<int64_t*>(context->GetScratchBuffer(
context, data.bias_converted_buffer_index));
for (int i = 0; i < tflite::micro::GetTensorShape(bias).FlatSize();