TFLM compression changes (3rd) (#2658)

@tensorflow/micro

Updates to support TFLM compression:

- `MicroContext`
- `MicroInterpreterContext`
- `FakeMicroContext`
- `KernelRunner`

bug=#2657
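
For context, the kernel-facing hooks added below are `IsTensorCompressed`, `AllocateDecompressionScratchBuffer`, `GetTensorCompressionData`, and `DecompressTensorToScratchBuffer`. A rough sketch of how a kernel might consume them follows; the op, the `OpData` fields, and the tensor index are illustrative only, and kernel integration itself is not part of this change:

```cpp
// Illustrative only: assumes a filter tensor at input index 1 and a
// kernel-owned OpData struct; these names are not part of this change.
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"

namespace {

constexpr int kFilterTensor = 1;  // hypothetical input index

struct OpData {
  int filter_scratch_index = -1;  // scratch handle reserved in Prepare
};

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
  auto* data = static_cast<OpData*>(node->user_data);
  // Only valid during Prepare; returns -1 if the tensor is not compressed.
  data->filter_scratch_index =
      micro_context->AllocateDecompressionScratchBuffer(node, kFilterTensor);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
  auto* data = static_cast<OpData*>(node->user_data);
  const TfLiteEvalTensor* filter =
      tflite::micro::GetEvalInput(context, node, kFilterTensor);

  const int8_t* filter_data = tflite::micro::GetTensorData<int8_t>(filter);
  const tflite::CompressionTensorData* comp =
      micro_context->GetTensorCompressionData(node, kFilterTensor);
  if (comp != nullptr) {
    // Only valid during Eval: decompress into the scratch buffer reserved in
    // Prepare and use that copy instead of the compressed tensor data.
    filter_data = static_cast<const int8_t*>(
        micro_context->DecompressTensorToScratchBuffer(
            *filter, *comp, data->filter_scratch_index));
  }
  // ... run the reference kernel with filter_data ...
  return kTfLiteOk;
}

}  // namespace
```
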
diff --git a/tensorflow/lite/micro/fake_micro_context.cc b/tensorflow/lite/micro/fake_micro_context.cc
index 5787ffd..8874798 100644
--- a/tensorflow/lite/micro/fake_micro_context.cc
+++ b/tensorflow/lite/micro/fake_micro_context.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,10 +23,23 @@
 
 namespace tflite {
 
-FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
-                                   SingleArenaBufferAllocator* allocator,
-                                   MicroGraph* micro_graph)
-    : graph_(*micro_graph), tensors_(tensors), allocator_(allocator) {}
+FakeMicroContext::FakeMicroContext(
+    TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
+    MicroGraph* micro_graph
+#ifdef USE_TFLM_COMPRESSION
+    ,
+    const CompressedTensorList* compressed_tensors
+#endif  // USE_TFLM_COMPRESSION
+    )
+    : graph_(*micro_graph),
+      tensors_(tensors),
+      allocator_(allocator)
+#ifdef USE_TFLM_COMPRESSION
+      ,
+      compressed_tensors_(compressed_tensors)
+#endif  // USE_TFLM_COMPRESSION
+{
+}
 
 TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
   allocated_temp_count_++;
@@ -112,4 +125,60 @@
 
 MicroGraph& FakeMicroContext::graph() { return graph_; }
 
+#ifdef USE_TFLM_COMPRESSION
+
+// Available during Prepare & Eval. Returns false if tensor is not
+// compressed.
+bool FakeMicroContext::IsTensorCompressed(const TfLiteNode* node,
+                                          int tensor_idx) {
+  if (compressed_tensors_ != nullptr && tensor_idx < node->inputs->size) {
+    int index = node->inputs->data[tensor_idx];
+    if (index >= 0 && compressed_tensors_->tensors[index] != nullptr) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// Only available during Prepare. The kernel is responsible for storing the
+// scratch buffer handle.
+int FakeMicroContext::AllocateDecompressionScratchBuffer(const TfLiteNode* node,
+                                                         int tensor_idx) {
+  if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
+    return -1;
+  }
+  int index = node->inputs->data[tensor_idx];
+  if (index < 0 || compressed_tensors_->tensors[index] == nullptr) {
+    return -1;
+  }
+  TfLiteTensor* tensor = &tensors_[index];
+  int scratch_index = -1;
+  TfLiteStatus result =
+      RequestScratchBufferInArena(tensor->bytes, &scratch_index);
+  if (result != kTfLiteOk) {
+    return -1;
+  }
+
+  return scratch_index;
+}
+
+// Available during Prepare & Eval. Returns nullptr if tensor is not
+// compressed.
+const CompressionTensorData* FakeMicroContext::GetTensorCompressionData(
+    const TfLiteNode* node, int tensor_idx) {
+  if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
+    return nullptr;
+  }
+
+  int index = node->inputs->data[tensor_idx];
+  if (index < 0) {
+    return nullptr;
+  }
+
+  return compressed_tensors_->tensors[index];
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/kernel_runner.cc b/tensorflow/lite/micro/kernels/kernel_runner.cc
index 602778d..79824ef 100644
--- a/tensorflow/lite/micro/kernels/kernel_runner.cc
+++ b/tensorflow/lite/micro/kernels/kernel_runner.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
 #include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
 #include "tensorflow/lite/micro/micro_arena_constants.h"
 #include "tensorflow/lite/micro/micro_log.h"
-#include "tensorflow/lite/micro/test_helpers.h"
 
 namespace tflite {
 namespace micro {
@@ -38,12 +37,22 @@
                            TfLiteTensor* tensors, int tensors_size,
                            TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                            const void* builtin_data,
-                           TfLiteIntArray* intermediates)
+                           TfLiteIntArray* intermediates
+#ifdef USE_TFLM_COMPRESSION
+                           ,
+                           const CompressedTensorList* compressed_tensors
+#endif  // USE_TFLM_COMPRESSION
+                           )
     : registration_(registration),
       allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_,
                                                     kKernelRunnerBufferSize_)),
       mock_micro_graph_(allocator_),
-      fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
+      fake_micro_context_(tensors, allocator_, &mock_micro_graph_
+#ifdef USE_TFLM_COMPRESSION
+                          ,
+                          compressed_tensors
+#endif  // USE_TFLM_COMPRESSION
+      ) {
   // Prepare TfLiteContext:
   context_.impl_ = static_cast<void*>(&fake_micro_context_);
   context_.ReportError = MicroContextReportOpError;
diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc
index 295b3c3..680dee8 100644
--- a/tensorflow/lite/micro/micro_context.cc
+++ b/tensorflow/lite/micro/micro_context.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,8 +18,10 @@
 #include <cstdarg>
 #include <cstddef>
 
+#include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/micro/micro_common.h"
 #include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
 namespace {
@@ -34,6 +36,103 @@
   return -1;
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+struct DecompressionState {
+  DecompressionState() = delete;
+
+  DecompressionState(const uint8_t* compressed_indices,
+                     const size_t count_indices,
+                     const CompressionTensorData& comp_data,
+                     const size_t num_channels)
+      : compressed_indices_(compressed_indices),
+        count_indices_(count_indices),
+        comp_data_(comp_data),
+        num_channels_(num_channels) {}
+
+  template <typename T>
+  T* DecompressToBuffer(void* buffer);
+
+  size_t GetNextTableIndex();
+  void UpdateBufferAndChannelIndex();
+
+ private:
+  const uint8_t* compressed_indices_;
+  const size_t count_indices_;
+  const CompressionTensorData& comp_data_;
+  const size_t num_channels_;
+  const size_t compressed_bit_width_ =
+      comp_data_.data.lut_data->compressed_bit_width;
+  size_t channel_ = 0;
+  size_t index_in_channel_ = 0;
+  const size_t elements_per_channel_ =
+      comp_data_.data.lut_data->use_alternate_axis
+          ? 1
+          : count_indices_ / num_channels_;
+  size_t buffer_index_ = 0;
+  size_t current_offset_ = 0;
+  size_t current_bits_remaining_ = 8;
+  uint8_t current_byte_ = compressed_indices_[0];
+};
+
+template <typename T>
+T* DecompressionState::DecompressToBuffer(void* buffer) {
+  while (buffer_index_ < count_indices_) {
+    const size_t table_index = GetNextTableIndex();
+    static_cast<T*>(buffer)[buffer_index_] =
+        static_cast<const T*>(comp_data_.data.lut_data->value_table)
+            [table_index +
+             (channel_ * comp_data_.data.lut_data->value_table_channel_stride)];
+    UpdateBufferAndChannelIndex();
+  }
+
+  return static_cast<T*>(buffer);
+}
+
+size_t DecompressionState::GetNextTableIndex() {
+  TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+  TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+  size_t table_index_bits_to_fill = compressed_bit_width_;
+  size_t table_index = 0;
+
+  while (table_index_bits_to_fill > 0) {
+    if (current_bits_remaining_ == 0) {
+      current_offset_++;
+      current_byte_ = compressed_indices_[current_offset_];
+      current_bits_remaining_ = 8;
+    }
+
+    const uint8_t mask_bit_count =
+        std::min(table_index_bits_to_fill,
+                 std::min(compressed_bit_width_, current_bits_remaining_));
+    const uint8_t current_byte_mask = (1 << mask_bit_count) - 1;
+    table_index <<= mask_bit_count;
+    table_index |=
+        (current_byte_ >> (current_bits_remaining_ - mask_bit_count)) &
+        current_byte_mask;
+
+    table_index_bits_to_fill -= mask_bit_count;
+    current_bits_remaining_ -= mask_bit_count;
+  }
+
+  return table_index;
+}
+
+void DecompressionState::UpdateBufferAndChannelIndex() {
+  buffer_index_++;
+  index_in_channel_++;
+  if (index_in_channel_ == elements_per_channel_) {
+    index_in_channel_ = 0;
+    channel_++;
+    if (channel_ == num_channels_) {
+      channel_ = 0;
+    }
+  }
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 }  // namespace
 
 TfLiteTensor* MicroContext::AllocateTempInputTensor(const TfLiteNode* node,
@@ -74,4 +173,56 @@
   va_end(args);
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+void* MicroContext::DecompressTensorToScratchBuffer(
+    const TfLiteEvalTensor& tensor,
+    const CompressionTensorData& compression_data, int scratch_buffer_handle) {
+  TFLITE_DCHECK(compression_data.scheme == CompressionScheme::kBinQuant);
+  TFLITE_DCHECK(scratch_buffer_handle != -1);
+  void* scratch_buffer = GetScratchBuffer(scratch_buffer_handle);
+  TFLITE_DCHECK(scratch_buffer != nullptr);
+  size_t count = ElementCount(*tensor.dims);
+  size_t num_channels = 1;
+
+  if (compression_data.data.lut_data->is_per_channel_quantized) {
+    const size_t channel_axis =
+        compression_data.data.lut_data->use_alternate_axis
+            ? tensor.dims->size - 1
+            : 0;
+    num_channels = tensor.dims->data[channel_axis];
+  }
+
+  DecompressionState ds(static_cast<uint8_t*>(tensor.data.data), count,
+                        compression_data, num_channels);
+
+  switch (tensor.type) {
+    case kTfLiteBool: {
+      return ds.DecompressToBuffer<bool>(scratch_buffer);
+    } break;
+    case kTfLiteInt8: {
+      return ds.DecompressToBuffer<int8_t>(scratch_buffer);
+    } break;
+    case kTfLiteInt16: {
+      return ds.DecompressToBuffer<int16_t>(scratch_buffer);
+    } break;
+    case kTfLiteInt32: {
+      return ds.DecompressToBuffer<int32_t>(scratch_buffer);
+    } break;
+    case kTfLiteInt64: {
+      return ds.DecompressToBuffer<int64_t>(scratch_buffer);
+    } break;
+    case kTfLiteFloat32: {
+      return ds.DecompressToBuffer<float>(scratch_buffer);
+    } break;
+    default: {
+      MicroPrintf("Unsupported decompression tensor type %d", tensor.type);
+    } break;
+  }
+
+  return nullptr;
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/micro_interpreter_context.cc b/tensorflow/lite/micro/micro_interpreter_context.cc
index 098df15..0ba461f 100644
--- a/tensorflow/lite/micro/micro_interpreter_context.cc
+++ b/tensorflow/lite/micro/micro_interpreter_context.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -18,8 +18,28 @@
 #include <cstdint>
 
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/micro_utils.h"
 
 namespace tflite {
+
+namespace {
+
+#ifdef USE_TFLM_COMPRESSION
+
+int GetInputTensorIndex(const TfLiteNode* node, const int index) {
+  if (index >= 0 && index < node->inputs->size) {
+    const int tensor_index = node->inputs->data[index];
+    if (tensor_index != kTfLiteOptionalTensor) {
+      return tensor_index;
+    }
+  }
+  return -1;
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
+}  // namespace
+
 MicroInterpreterContext::MicroInterpreterContext(MicroAllocator* allocator,
                                                  const Model* model,
                                                  MicroInterpreterGraph* graph)
@@ -106,4 +126,83 @@
   return state_;
 }
 
+#ifdef USE_TFLM_COMPRESSION
+
+// Available during Prepare & Eval. Returns false if tensor is not
+// compressed.
+bool MicroInterpreterContext::IsTensorCompressed(const TfLiteNode* node,
+                                                 int tensor_idx) {
+  TFLITE_DCHECK(state_ == InterpreterState::kPrepare ||
+                state_ == InterpreterState::kInvoke);
+
+  const SubgraphAllocations* allocations =
+      &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()];
+  if (allocations->compressed.tensors == nullptr) {
+    return false;
+  }
+  int index = GetInputTensorIndex(node, tensor_idx);
+  if (index == -1) {
+    return false;
+  }
+  return allocations->compressed.tensors[index] != nullptr;
+}
+
+// Only available during Prepare. The kernel is responsible for storing the
+// scratch buffer handle.
+int MicroInterpreterContext::AllocateDecompressionScratchBuffer(
+    const TfLiteNode* node, int tensor_idx) {
+  TFLITE_DCHECK(state_ == InterpreterState::kPrepare);
+
+  const SubgraphAllocations* allocations =
+      &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()];
+  if (allocations->compressed.tensors == nullptr) {
+    return -1;
+  }
+  int index = GetInputTensorIndex(node, tensor_idx);
+  if (index == -1 || allocations->compressed.tensors[index] == nullptr) {
+    return -1;
+  }
+  const TfLiteEvalTensor* tensor = &allocations->tensors[index];
+  const size_t byte_count = EvalTensorBytes(tensor);
+  int scratch_index = -1;
+  TfLiteStatus result = RequestScratchBufferInArena(byte_count, &scratch_index);
+  if (result != kTfLiteOk) {
+    return -1;
+  }
+
+  return scratch_index;
+}
+
+// Available during Prepare & Eval. Returns nullptr if tensor is not
+// compressed.
+const CompressionTensorData* MicroInterpreterContext::GetTensorCompressionData(
+    const TfLiteNode* node, int tensor_idx) {
+  TFLITE_DCHECK(state_ == InterpreterState::kPrepare ||
+                state_ == InterpreterState::kInvoke);
+
+  const SubgraphAllocations* allocations =
+      &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()];
+  if (allocations->compressed.tensors == nullptr) {
+    return nullptr;
+  }
+  int index = GetInputTensorIndex(node, tensor_idx);
+  if (index == -1) {
+    return nullptr;
+  }
+  return allocations->compressed.tensors[index];
+}
+
+// Only available during Eval. Returns nullptr on failure, otherwise returns a
+// pointer to the scratch buffer.
+void* MicroInterpreterContext::DecompressTensorToScratchBuffer(
+    const TfLiteEvalTensor& tensor,
+    const CompressionTensorData& compression_data, int scratch_buffer_handle) {
+  TFLITE_DCHECK(state_ == InterpreterState::kInvoke);
+
+  return MicroContext::DecompressTensorToScratchBuffer(tensor, compression_data,
+                                                       scratch_buffer_handle);
+}
+
+#endif  // USE_TFLM_COMPRESSION
+
 }  // namespace tflite
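
As a side note for reviewers: `DecompressionState` in micro_context.cc walks a packed bitstream of fixed-width lookup-table indices and maps each index through a (possibly per-channel) value table. A minimal standalone illustration of that scheme, with made-up values and without the per-channel handling:

```cpp
// Standalone sketch of the LUT decompression idea used above.
// Not part of the change; bit width and table values are made up.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Four indices packed MSB-first at 2 bits each: 0b01'10'00'11 -> 1, 2, 0, 3.
  const uint8_t packed[] = {0b01100011};
  const float value_table[] = {-1.0f, 0.0f, 0.5f, 1.0f};
  const int bit_width = 2;
  const int count = 4;

  std::vector<float> out(count);
  size_t offset = 0, bits_remaining = 8;
  uint8_t byte = packed[0];
  for (int i = 0; i < count; ++i) {
    // Extract the next `bit_width` bits, refilling from the stream as needed.
    size_t index = 0, bits_to_fill = bit_width;
    while (bits_to_fill > 0) {
      if (bits_remaining == 0) {
        byte = packed[++offset];
        bits_remaining = 8;
      }
      const size_t take =
          bits_to_fill < bits_remaining ? bits_to_fill : bits_remaining;
      index = (index << take) |
              ((byte >> (bits_remaining - take)) & ((1u << take) - 1));
      bits_to_fill -= take;
      bits_remaining -= take;
    }
    out[i] = value_table[index];  // map the index through the value table
  }
  for (float v : out) printf("%g ", v);  // prints: 0 0.5 -1 1
  return 0;
}
```
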