TFLM compression changes (3rd) (#2658)
@tensorflow/micro
Updates to support TFLM compression in the following classes:
MicroContext
MicroInterpreterContext
FakeMicroContext
KernelRunner
bug=#2657
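
For context, the new APIs in this change all revolve around LUT ("BinQuant") compression metadata. The authoritative declarations live in the TFLM compression header and are not part of this diff; the sketch below is an illustrative reconstruction from the fields the diff actually touches, with types, widths, and exact layout assumed rather than copied from the header.

#include <cstddef>
#include <cstdint>

// Illustrative reconstruction only; see the TFLM compression header for the
// real declarations.
enum class CompressionScheme : uint8_t {
  kBinQuant,  // LUT-based binary quantization (the only scheme used here)
};

struct LookupTableData {
  static constexpr size_t kMaxBitWidth = 7;  // assumed; the diff only references the name
  const void* value_table;              // per-channel tables laid out back to back
  size_t value_table_channel_stride;    // elements per channel in value_table
  size_t compressed_bit_width;          // bits per packed index (1..kMaxBitWidth)
  bool is_per_channel_quantized;        // true: one table per channel
  bool use_alternate_axis;              // false: channel axis 0, true: last axis
};

struct CompressionTensorData {
  CompressionScheme scheme;
  union {
    const LookupTableData* lut_data;
  } data;
};

// Sparse array with one entry per subgraph tensor; nullptr means "not compressed".
struct CompressedTensorList {
  const CompressionTensorData** tensors;
};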
diff --git a/tensorflow/lite/micro/fake_micro_context.cc b/tensorflow/lite/micro/fake_micro_context.cc
index 5787ffd..8874798 100644
--- a/tensorflow/lite/micro/fake_micro_context.cc
+++ b/tensorflow/lite/micro/fake_micro_context.cc
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,10 +23,23 @@
namespace tflite {
-FakeMicroContext::FakeMicroContext(TfLiteTensor* tensors,
- SingleArenaBufferAllocator* allocator,
- MicroGraph* micro_graph)
- : graph_(*micro_graph), tensors_(tensors), allocator_(allocator) {}
+FakeMicroContext::FakeMicroContext(
+ TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
+ MicroGraph* micro_graph
+#ifdef USE_TFLM_COMPRESSION
+ ,
+ const CompressedTensorList* compressed_tensors
+#endif // USE_TFLM_COMPRESSION
+ )
+ : graph_(*micro_graph),
+ tensors_(tensors),
+ allocator_(allocator)
+#ifdef USE_TFLM_COMPRESSION
+ ,
+ compressed_tensors_(compressed_tensors)
+#endif // USE_TFLM_COMPRESSION
+{
+}
TfLiteTensor* FakeMicroContext::AllocateTempTfLiteTensor(int tensor_index) {
allocated_temp_count_++;
@@ -112,4 +125,60 @@
MicroGraph& FakeMicroContext::graph() { return graph_; }
+#ifdef USE_TFLM_COMPRESSION
+
+// Available during Prepare & Eval. Returns false if tensor is not
+// compressed.
+bool FakeMicroContext::IsTensorCompressed(const TfLiteNode* node,
+ int tensor_idx) {
+ if (compressed_tensors_ != nullptr && tensor_idx < node->inputs->size) {
+ int index = node->inputs->data[tensor_idx];
+ if (index >= 0 && compressed_tensors_->tensors[index] != nullptr) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+// Only available during Prepare. The kernel is responsible for storing the
+// scratch buffer handle.
+int FakeMicroContext::AllocateDecompressionScratchBuffer(const TfLiteNode* node,
+ int tensor_idx) {
+ if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
+ return -1;
+ }
+ int index = node->inputs->data[tensor_idx];
+ if (index < 0 || compressed_tensors_->tensors[index] == nullptr) {
+ return -1;
+ }
+ TfLiteTensor* tensor = &tensors_[index];
+ int scratch_index = -1;
+ TfLiteStatus result =
+ RequestScratchBufferInArena(tensor->bytes, &scratch_index);
+ if (result != kTfLiteOk) {
+ return -1;
+ }
+
+ return scratch_index;
+}
+
+// Available during Prepare & Eval. Returns nullptr if tensor is not
+// compressed.
+const CompressionTensorData* FakeMicroContext::GetTensorCompressionData(
+ const TfLiteNode* node, int tensor_idx) {
+ if (compressed_tensors_ == nullptr || tensor_idx >= node->inputs->size) {
+ return nullptr;
+ }
+
+ int index = node->inputs->data[tensor_idx];
+ if (index < 0) {
+ return nullptr;
+ }
+
+ return compressed_tensors_->tensors[index];
+}
+
+#endif // USE_TFLM_COMPRESSION
+
} // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/kernel_runner.cc b/tensorflow/lite/micro/kernels/kernel_runner.cc
index 602778d..79824ef 100644
--- a/tensorflow/lite/micro/kernels/kernel_runner.cc
+++ b/tensorflow/lite/micro/kernels/kernel_runner.cc
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@
#include "tensorflow/lite/micro/arena_allocator/single_arena_buffer_allocator.h"
#include "tensorflow/lite/micro/micro_arena_constants.h"
#include "tensorflow/lite/micro/micro_log.h"
-#include "tensorflow/lite/micro/test_helpers.h"
namespace tflite {
namespace micro {
@@ -38,12 +37,22 @@
TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs, TfLiteIntArray* outputs,
const void* builtin_data,
- TfLiteIntArray* intermediates)
+ TfLiteIntArray* intermediates
+#ifdef USE_TFLM_COMPRESSION
+ ,
+ const CompressedTensorList* compressed_tensors
+#endif // USE_TFLM_COMPRESSION
+ )
: registration_(registration),
allocator_(SingleArenaBufferAllocator::Create(kKernelRunnerBuffer_,
kKernelRunnerBufferSize_)),
mock_micro_graph_(allocator_),
- fake_micro_context_(tensors, allocator_, &mock_micro_graph_) {
+ fake_micro_context_(tensors, allocator_, &mock_micro_graph_
+#ifdef USE_TFLM_COMPRESSION
+ ,
+ compressed_tensors
+#endif // USE_TFLM_COMPRESSION
+ ) {
// Prepare TfLiteContext:
context_.impl_ = static_cast<void*>(&fake_micro_context_);
context_.ReportError = MicroContextReportOpError;
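
With the FakeMicroContext and KernelRunner changes above, a kernel test can hand a CompressedTensorList to the runner and the fake context will answer the compression queries. The helper below is a hedged sketch, not part of this change: it assumes a three-tensor subgraph whose weights tensor (subgraph index 1) is 4-bit LUT compressed, and it assumes the compression structs come from tensorflow/lite/micro/compression.h. The usual test scaffolding (registration, tensors, inputs_array, outputs_array) is unchanged and elided.

#ifdef USE_TFLM_COMPRESSION
#include <cstdint>

#include "tensorflow/lite/micro/compression.h"  // assumed location of the compression structs

// Hypothetical test helper: builds a sparse compression map for a three-tensor
// subgraph whose weights tensor (subgraph index 1) is 4-bit LUT compressed.
// Everything is static so the pointers handed to the KernelRunner stay valid
// for the lifetime of the test.
const tflite::CompressedTensorList* MakeWeightsCompressionList() {
  // 16-entry int8 value table; the packed 4-bit indices themselves are what
  // the weights tensor's data buffer holds. Values are illustrative only.
  static const int8_t kWeightValueTable[16] = {-8, -7, -6, -5, -4, -3, -2, -1,
                                               0,  1,  2,  3,  4,  5,  6,  7};
  static tflite::LookupTableData lut_data = {};
  lut_data.value_table = kWeightValueTable;
  lut_data.value_table_channel_stride = 16;  // one table of 16 entries
  lut_data.compressed_bit_width = 4;
  lut_data.is_per_channel_quantized = false;
  lut_data.use_alternate_axis = false;

  static tflite::CompressionTensorData weights_compression = {};
  weights_compression.scheme = tflite::CompressionScheme::kBinQuant;
  weights_compression.data.lut_data = &lut_data;

  // Sparse array aligned to subgraph tensor indices: only tensor 1 is
  // compressed, the rest stay nullptr.
  static const tflite::CompressionTensorData* compression_map[3] = {
      nullptr, &weights_compression, nullptr};
  static tflite::CompressedTensorList compressed_tensors = {compression_map};
  return &compressed_tensors;
}

// Usage in a kernel test, with registration/tensors/inputs_array/outputs_array
// set up exactly as in an uncompressed test:
//
//   tflite::micro::KernelRunner runner(
//       registration, tensors, kTensorsSize, inputs_array, outputs_array,
//       /*builtin_data=*/nullptr, /*intermediates=*/nullptr,
//       MakeWeightsCompressionList());
#endif  // USE_TFLM_COMPRESSION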
diff --git a/tensorflow/lite/micro/micro_context.cc b/tensorflow/lite/micro/micro_context.cc
index 295b3c3..680dee8 100644
--- a/tensorflow/lite/micro/micro_context.cc
+++ b/tensorflow/lite/micro/micro_context.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,8 +18,10 @@
#include <cstdarg>
#include <cstddef>
+#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/micro/micro_common.h"
#include "tensorflow/lite/micro/micro_log.h"
+#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
namespace {
@@ -34,6 +36,103 @@
return -1;
}
+#ifdef USE_TFLM_COMPRESSION
+
+struct DecompressionState {
+ DecompressionState() = delete;
+
+ DecompressionState(const uint8_t* compressed_indices,
+ const size_t count_indices,
+ const CompressionTensorData& comp_data,
+ const size_t num_channels)
+ : compressed_indices_(compressed_indices),
+ count_indices_(count_indices),
+ comp_data_(comp_data),
+ num_channels_(num_channels) {}
+
+ template <typename T>
+ T* DecompressToBuffer(void* buffer);
+
+ size_t GetNextTableIndex();
+ void UpdateBufferAndChannelIndex();
+
+ private:
+ const uint8_t* compressed_indices_;
+ const size_t count_indices_;
+ const CompressionTensorData& comp_data_;
+ const size_t num_channels_;
+ const size_t compressed_bit_width_ =
+ comp_data_.data.lut_data->compressed_bit_width;
+ size_t channel_ = 0;
+ size_t index_in_channel_ = 0;
+ const size_t elements_per_channel_ =
+ comp_data_.data.lut_data->use_alternate_axis
+ ? 1
+ : count_indices_ / num_channels_;
+ size_t buffer_index_ = 0;
+ size_t current_offset_ = 0;
+ size_t current_bits_remaining_ = 8;
+ uint8_t current_byte_ = compressed_indices_[0];
+};
+
+template <typename T>
+T* DecompressionState::DecompressToBuffer(void* buffer) {
+ while (buffer_index_ < count_indices_) {
+ const size_t table_index = GetNextTableIndex();
+ static_cast<T*>(buffer)[buffer_index_] =
+ static_cast<const T*>(comp_data_.data.lut_data->value_table)
+ [table_index +
+ (channel_ * comp_data_.data.lut_data->value_table_channel_stride)];
+ UpdateBufferAndChannelIndex();
+ }
+
+ return static_cast<T*>(buffer);
+}
+
+size_t DecompressionState::GetNextTableIndex() {
+ TFLITE_DCHECK(compressed_bit_width_ <= LookupTableData::kMaxBitWidth);
+ TFLITE_DCHECK(compressed_bit_width_ > 0);
+
+ size_t table_index_bits_to_fill = compressed_bit_width_;
+ size_t table_index = 0;
+
+ while (table_index_bits_to_fill > 0) {
+ if (current_bits_remaining_ == 0) {
+ current_offset_++;
+ current_byte_ = compressed_indices_[current_offset_];
+ current_bits_remaining_ = 8;
+ }
+
+ const uint8_t mask_bit_count =
+ std::min(table_index_bits_to_fill,
+ std::min(compressed_bit_width_, current_bits_remaining_));
+ const uint8_t current_byte_mask = (1 << mask_bit_count) - 1;
+ table_index <<= mask_bit_count;
+ table_index |=
+ (current_byte_ >> (current_bits_remaining_ - mask_bit_count)) &
+ current_byte_mask;
+
+ table_index_bits_to_fill -= mask_bit_count;
+ current_bits_remaining_ -= mask_bit_count;
+ }
+
+ return table_index;
+}
+
+void DecompressionState::UpdateBufferAndChannelIndex() {
+ buffer_index_++;
+ index_in_channel_++;
+ if (index_in_channel_ == elements_per_channel_) {
+ index_in_channel_ = 0;
+ channel_++;
+ if (channel_ == num_channels_) {
+ channel_ = 0;
+ }
+ }
+}
+
+#endif // USE_TFLM_COMPRESSION
+
} // namespace
TfLiteTensor* MicroContext::AllocateTempInputTensor(const TfLiteNode* node,
@@ -74,4 +173,56 @@
va_end(args);
}
+#ifdef USE_TFLM_COMPRESSION
+
+void* MicroContext::DecompressTensorToScratchBuffer(
+ const TfLiteEvalTensor& tensor,
+ const CompressionTensorData& compression_data, int scratch_buffer_handle) {
+ TFLITE_DCHECK(compression_data.scheme == CompressionScheme::kBinQuant);
+ TFLITE_DCHECK(scratch_buffer_handle != -1);
+ void* scratch_buffer = GetScratchBuffer(scratch_buffer_handle);
+ TFLITE_DCHECK(scratch_buffer != nullptr);
+ size_t count = ElementCount(*tensor.dims);
+ size_t num_channels = 1;
+
+ if (compression_data.data.lut_data->is_per_channel_quantized) {
+ const size_t channel_axis =
+ compression_data.data.lut_data->use_alternate_axis
+ ? tensor.dims->size - 1
+ : 0;
+ num_channels = tensor.dims->data[channel_axis];
+ }
+
+ DecompressionState ds(static_cast<uint8_t*>(tensor.data.data), count,
+ compression_data, num_channels);
+
+ switch (tensor.type) {
+ case kTfLiteBool: {
+ return ds.DecompressToBuffer<bool>(scratch_buffer);
+ } break;
+ case kTfLiteInt8: {
+ return ds.DecompressToBuffer<int8_t>(scratch_buffer);
+ } break;
+ case kTfLiteInt16: {
+ return ds.DecompressToBuffer<int16_t>(scratch_buffer);
+ } break;
+ case kTfLiteInt32: {
+ return ds.DecompressToBuffer<int32_t>(scratch_buffer);
+ } break;
+ case kTfLiteInt64: {
+ return ds.DecompressToBuffer<int64_t>(scratch_buffer);
+ } break;
+ case kTfLiteFloat32: {
+ return ds.DecompressToBuffer<float>(scratch_buffer);
+ } break;
+ default: {
+ MicroPrintf("Unsupported decompression tensor type %d", tensor.type);
+ } break;
+ }
+
+ return nullptr;
+}
+
+#endif // USE_TFLM_COMPRESSION
+
} // namespace tflite
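
The DecompressionState added above unpacks fixed-width indices MSB-first from the byte stream and maps each index through the (optionally per-channel) value table. As a sanity check of that bit layout, the standalone sketch below mirrors GetNextTableIndex for a single-channel tensor and decodes a small 4-bit stream by hand; it is illustrative arithmetic, not TFLM code.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Reference decode mirroring DecompressionState::GetNextTableIndex for a
// single-channel tensor: indices are packed MSB-first at a fixed bit width.
int main() {
  const uint8_t packed[] = {0x12, 0x3F};  // 4-bit indices: 1, 2, 3, 15
  const int8_t value_table[16] = {-8, -7, -6, -5, -4, -3, -2, -1,
                                  0,  1,  2,  3,  4,  5,  6,  7};
  const size_t bit_width = 4;
  const size_t count = 4;  // number of elements to decode

  size_t offset = 0;
  size_t bits_remaining = 8;
  uint8_t current = packed[0];

  for (size_t i = 0; i < count; ++i) {
    size_t bits_to_fill = bit_width;
    size_t index = 0;
    while (bits_to_fill > 0) {
      if (bits_remaining == 0) {
        current = packed[++offset];
        bits_remaining = 8;
      }
      const size_t take =
          bits_to_fill < bits_remaining ? bits_to_fill : bits_remaining;
      const uint8_t mask = (1u << take) - 1;
      index = (index << take) | ((current >> (bits_remaining - take)) & mask);
      bits_to_fill -= take;
      bits_remaining -= take;
    }
    // Prints "-7 -6 -5 7" for the stream above.
    printf("%d ", value_table[index]);
  }
  printf("\n");
  return 0;
}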
diff --git a/tensorflow/lite/micro/micro_interpreter_context.cc b/tensorflow/lite/micro/micro_interpreter_context.cc
index 098df15..0ba461f 100644
--- a/tensorflow/lite/micro/micro_interpreter_context.cc
+++ b/tensorflow/lite/micro/micro_interpreter_context.cc
@@ -1,4 +1,4 @@
-/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,8 +18,28 @@
#include <cstdint>
#include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/micro/micro_utils.h"
namespace tflite {
+
+namespace {
+
+#ifdef USE_TFLM_COMPRESSION
+
+int GetInputTensorIndex(const TfLiteNode* node, const int index) {
+ if (index >= 0 && index < node->inputs->size) {
+ const int tensor_index = node->inputs->data[index];
+ if (tensor_index != kTfLiteOptionalTensor) {
+ return tensor_index;
+ }
+ }
+ return -1;
+}
+
+#endif // USE_TFLM_COMPRESSION
+
+} // namespace
+
MicroInterpreterContext::MicroInterpreterContext(MicroAllocator* allocator,
const Model* model,
MicroInterpreterGraph* graph)
@@ -106,4 +126,83 @@
return state_;
}
+#ifdef USE_TFLM_COMPRESSION
+
+// Available during Prepare & Eval. Returns false if tensor is not
+// compressed.
+bool MicroInterpreterContext::IsTensorCompressed(const TfLiteNode* node,
+ int tensor_idx) {
+ TFLITE_DCHECK(state_ == InterpreterState::kPrepare ||
+ state_ == InterpreterState::kInvoke);
+
+ const SubgraphAllocations* allocations =
+ &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()];
+ if (allocations->compressed.tensors == nullptr) {
+ return false;
+ }
+ int index = GetInputTensorIndex(node, tensor_idx);
+ if (index == -1) {
+ return false;
+ }
+ return allocations->compressed.tensors[index] != nullptr;
+}
+
+// Only available during Prepare. The kernel is responsible for storing the
+// scratch buffer handle.
+int MicroInterpreterContext::AllocateDecompressionScratchBuffer(
+ const TfLiteNode* node, int tensor_idx) {
+ TFLITE_DCHECK(state_ == InterpreterState::kPrepare);
+
+ const SubgraphAllocations* allocations =
+ &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()];
+ if (allocations->compressed.tensors == nullptr) {
+ return -1;
+ }
+ int index = GetInputTensorIndex(node, tensor_idx);
+ if (index == -1 || allocations->compressed.tensors[index] == nullptr) {
+ return -1;
+ }
+ const TfLiteEvalTensor* tensor = &allocations->tensors[index];
+ const size_t byte_count = EvalTensorBytes(tensor);
+ int scratch_index = -1;
+ TfLiteStatus result = RequestScratchBufferInArena(byte_count, &scratch_index);
+ if (result != kTfLiteOk) {
+ return -1;
+ }
+
+ return scratch_index;
+}
+
+// Available during Prepare & Eval. Returns nullptr if tensor is not
+// compressed.
+const CompressionTensorData* MicroInterpreterContext::GetTensorCompressionData(
+ const TfLiteNode* node, int tensor_idx) {
+ TFLITE_DCHECK(state_ == InterpreterState::kPrepare ||
+ state_ == InterpreterState::kInvoke);
+
+ const SubgraphAllocations* allocations =
+ &graph_.GetAllocations()[graph_.GetCurrentSubgraphIndex()];
+ if (allocations->compressed.tensors == nullptr) {
+ return nullptr;
+ }
+ int index = GetInputTensorIndex(node, tensor_idx);
+ if (index == -1) {
+ return nullptr;
+ }
+ return allocations->compressed.tensors[index];
+}
+
+// Only available during Eval. Returns nullptr on failure, otherwise returns a
+// pointer to the scratch buffer.
+void* MicroInterpreterContext::DecompressTensorToScratchBuffer(
+ const TfLiteEvalTensor& tensor,
+ const CompressionTensorData& compression_data, int scratch_buffer_handle) {
+ TFLITE_DCHECK(state_ == InterpreterState::kInvoke);
+
+ return MicroContext::DecompressTensorToScratchBuffer(tensor, compression_data,
+ scratch_buffer_handle);
+}
+
+#endif // USE_TFLM_COMPRESSION
+
} // namespace tflite
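
Taken together, a kernel that consumes a compressed weights tensor would query the context during Prepare, keep the scratch handle in its OpData, and decompress during Eval. The fragments below are a minimal sketch, not part of this change: the kernel, its OpData, and the weights input index (assumed to be 1) are hypothetical, and they assume the new methods are reachable through the MicroContext pointer as the micro_context.cc change suggests.

#ifdef USE_TFLM_COMPRESSION
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_context.h"

// Hypothetical kernel fragments; only the compression-related calls are shown.
constexpr int kWeightsTensorIdx = 1;  // assumed input index of the weights

struct OpData {
  int weights_scratch_index = -1;  // handle stored between Prepare and Eval
};

TfLiteStatus PrepareCompression(TfLiteContext* context, TfLiteNode* node,
                                OpData* data) {
  tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
  if (micro_context->IsTensorCompressed(node, kWeightsTensorIdx)) {
    // Only legal during Prepare; the scratch buffer is sized to the
    // uncompressed tensor.
    data->weights_scratch_index =
        micro_context->AllocateDecompressionScratchBuffer(node,
                                                          kWeightsTensorIdx);
    TF_LITE_ENSURE(context, data->weights_scratch_index != -1);
  }
  return kTfLiteOk;
}

const int8_t* GetWeightsData(TfLiteContext* context, TfLiteNode* node,
                             const TfLiteEvalTensor* weights, OpData* data) {
  tflite::MicroContext* micro_context = tflite::GetMicroContext(context);
  const tflite::CompressionTensorData* compression =
      micro_context->GetTensorCompressionData(node, kWeightsTensorIdx);
  if (compression == nullptr) {
    return tflite::micro::GetTensorData<int8_t>(weights);  // not compressed
  }
  // Only legal during Eval (Invoke); returns the decompressed scratch buffer.
  return static_cast<const int8_t*>(
      micro_context->DecompressTensorToScratchBuffer(
          *weights, *compression, data->weights_scratch_index));
}
#endif  // USE_TFLM_COMPRESSION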