TFLM compression changes (#2647)
@tensorflow/micro
Initial header file changes and additions for TFLM compression support.
bug=#2646
diff --git a/tensorflow/lite/micro/compression.h b/tensorflow/lite/micro/compression.h
new file mode 100644
index 0000000..43965c2
--- /dev/null
+++ b/tensorflow/lite/micro/compression.h
@@ -0,0 +1,75 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
+#define TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
+
+#ifdef USE_TFLM_COMPRESSION
+
+#include <cstddef>
+#include <cstdint>
+
+#include "tensorflow/lite/c/common.h"
+
+namespace tflite {
+
+//
+// Compressed tensors
+//
+
+// Identifier for TFLM compression metadata (see bug=#2646).
+static constexpr const char* kCompressionMetadataString = "TFLM_COMPRESSION";
+
+enum class CompressionScheme : uint8_t {
+  kBinQuant,
+};
+
+// Per-tensor data for look-up-table (kBinQuant) compression.
+// TODO(ddavis-2015): pack struct
+struct LookupTableData {
+  static constexpr size_t kMaxBitWidth = 7;
+  static constexpr size_t kMaxValueTableChannelStride = 128;
+
+  const void* value_table;             // Pointer into FlatBuffer Values.
+  uint8_t value_table_channel_stride;  // elements per channel
+  uint8_t compressed_bit_width : 3;    // 1 to 7 bits
+  bool is_per_channel_quantized : 1;   // tensor is per-channel quantized
+  bool use_alternate_axis : 1;         // shape default channel:
+                                       // 0 = first, 1 = last
+  uint8_t reserved : 3;
+};
+
+union CompressionData {
+  LookupTableData* lut_data;
+};
+
+// TODO(ddavis-2015): pack struct
+struct CompressionTensorData {
+  CompressionScheme scheme;
+  CompressionData data;
+};
+
+// TODO(ddavis-2015): pack struct
+struct CompressedTensorList {
+  // Sparsely populated array with the same number of elements as there are
+  // tensors in the Subgraph. An alternative would include a tensor index in
+  // the struct for each and walk the list on look up. This could be slow.
+  CompressionTensorData** tensors;
+};
+
+}  // namespace tflite
+
+#endif  // USE_TFLM_COMPRESSION
+#endif  // TENSORFLOW_LITE_MICRO_MICRO_COMPRESSION_H_
diff --git a/tensorflow/lite/micro/fake_micro_context.h b/tensorflow/lite/micro/fake_micro_context.h
index 46d8a9b..7cf9c68 100644
--- a/tensorflow/lite/micro/fake_micro_context.h
+++ b/tensorflow/lite/micro/fake_micro_context.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -30,7 +30,12 @@
~FakeMicroContext() = default;
FakeMicroContext(TfLiteTensor* tensors, SingleArenaBufferAllocator* allocator,
- MicroGraph* micro_graph);
+ MicroGraph* micro_graph
+#ifdef USE_TFLM_COMPRESSION
+ ,
+ const CompressedTensorList* compressed_tensors = nullptr
+#endif // USE_TFLM_COMPRESSION
+ );
void* AllocatePersistentBuffer(size_t bytes) override;
TfLiteStatus RequestScratchBufferInArena(size_t bytes,
@@ -50,6 +55,24 @@
void* external_context() override;
MicroGraph& graph() override;
+#ifdef USE_TFLM_COMPRESSION
+
+ // Available during Prepare & Eval. Returns false if tensor is not
+ // compressed.
+ bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) override;
+
+ // Only available during Prepare. The kernel is responsible for storing the
+ // scratch buffer handle.
+ int AllocateDecompressionScratchBuffer(const TfLiteNode* node,
+ int tensor_idx) override;
+
+ // Available during Prepare & Eval. Returns nullptr if tensor is not
+ // compressed.
+ const CompressionTensorData* GetTensorCompressionData(
+ const TfLiteNode* node, int tensor_idx) override;
+
+#endif // USE_TFLM_COMPRESSION
+
private:
static constexpr int kNumScratchBuffers_ = 12;
@@ -62,6 +85,15 @@
SingleArenaBufferAllocator* allocator_;
+#ifdef USE_TFLM_COMPRESSION
+
+ //
+ // Compression
+ //
+ const CompressedTensorList* compressed_tensors_;
+
+#endif // USE_TFLM_COMPRESSION
+
TF_LITE_REMOVE_VIRTUAL_DELETE
};
diff --git a/tensorflow/lite/micro/kernels/kernel_runner.h b/tensorflow/lite/micro/kernels/kernel_runner.h
index 25b97c1..8dbd7f8 100644
--- a/tensorflow/lite/micro/kernels/kernel_runner.h
+++ b/tensorflow/lite/micro/kernels/kernel_runner.h
@@ -1,4 +1,4 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,7 +36,12 @@
KernelRunner(const TFLMRegistration& registration, TfLiteTensor* tensors,
int tensors_size, TfLiteIntArray* inputs,
TfLiteIntArray* outputs, const void* builtin_data,
- TfLiteIntArray* intermediates = nullptr);
+ TfLiteIntArray* intermediates = nullptr
+#ifdef USE_TFLM_COMPRESSION
+ ,
+ const CompressedTensorList* compressed_tensors = nullptr
+#endif // USE_TFLM_COMPRESSION
+ );
// Calls init and prepare on the kernel (i.e. TFLMRegistration) struct.
// Any exceptions will be DebugLog'd and returned as a status code.
diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h
index f14c927..977ed95 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.h
+++ b/tensorflow/lite/micro/kernels/kernel_util.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -91,6 +91,31 @@
: reinterpret_cast<const T*>(tensor->data.raw);
}
+#ifdef USE_TFLM_COMPRESSION
+
+// Overloads existing GetTensorData. If not compressed, this will return
+// tensor->data.
+//
+// TODO(ddavis-2015): make micro_context a const pointer
+template <typename T>
+const T* GetTensorData(MicroContext* micro_context,
+ const TfLiteEvalTensor* tensor,
+ const CompressionTensorData* compression_data,
+ int scratch_buffer_handle) {
+ if (tensor == nullptr) {
+ return nullptr;
+ }
+ if (compression_data == nullptr) {
+ return reinterpret_cast<const T*>(tensor->data.data);
+ }
+
+ void* uncompressed_data = micro_context->DecompressTensorToScratchBuffer(
+ *tensor, *compression_data, scratch_buffer_handle);
+ return reinterpret_cast<const T*>(uncompressed_data);
+}
+
+#endif // USE_TFLM_COMPRESSION
+
// Returns the shape of a TfLiteEvalTensor struct.
const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
diff --git a/tensorflow/lite/micro/micro_context.h b/tensorflow/lite/micro/micro_context.h
index 2dd3233..33cad89 100644
--- a/tensorflow/lite/micro/micro_context.h
+++ b/tensorflow/lite/micro/micro_context.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -19,6 +19,12 @@
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_graph.h"
+#ifdef USE_TFLM_COMPRESSION
+
+#include "tensorflow/lite/micro/compression.h"
+
+#endif // USE_TFLM_COMPRESSION
+
namespace tflite {
// TODO(b/149795762): kTfLiteAbort cannot be part of the tflite TfLiteStatus.
const TfLiteStatus kTfLiteAbort = static_cast<TfLiteStatus>(15);
@@ -95,6 +101,30 @@
virtual MicroGraph& graph() = 0;
+#ifdef USE_TFLM_COMPRESSION
+
+ // Available during Prepare & Eval. Returns false if tensor is not
+ // compressed.
+ virtual bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) = 0;
+
+ // Only available during Prepare. The kernel is responsible for storing the
+ // scratch buffer handle.
+ virtual int AllocateDecompressionScratchBuffer(const TfLiteNode* node,
+ int tensor_idx) = 0;
+
+ // Available during Prepare & Eval. Returns nullptr if tensor is not
+ // compressed.
+ virtual const CompressionTensorData* GetTensorCompressionData(
+ const TfLiteNode* node, int tensor_idx) = 0;
+
+ // Only available during Eval. Returns nullptr on failure, otherwise returns a
+ // pointer to the scratch buffer.
+  // Only a pure-virtual declaration here: every other compression method on
+  // this abstract interface is `= 0`; a plain virtual with no definition in
+  // the patch would leave an undefined vtable entry (link error). Fixed below.
+
+#endif // USE_TFLM_COMPRESSION
+
private:
TF_LITE_REMOVE_VIRTUAL_DELETE
};
diff --git a/tensorflow/lite/micro/micro_interpreter_context.h b/tensorflow/lite/micro/micro_interpreter_context.h
index 5986dc3..7b336aa 100644
--- a/tensorflow/lite/micro/micro_interpreter_context.h
+++ b/tensorflow/lite/micro/micro_interpreter_context.h
@@ -1,4 +1,4 @@
-/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -106,6 +106,31 @@
// housekeeping in MicroInterpreterContext.
void SetScratchBufferHandles(ScratchBufferHandle* scratch_buffer_handles);
+#ifdef USE_TFLM_COMPRESSION
+
+ // Available during Prepare & Eval. Returns false if tensor is not
+ // compressed.
+ bool IsTensorCompressed(const TfLiteNode* node, int tensor_idx) override;
+
+ // Only available during Prepare. The kernel is responsible for storing the
+ // scratch buffer handle.
+ int AllocateDecompressionScratchBuffer(const TfLiteNode* node,
+ int tensor_idx) override;
+
+ // Available during Prepare & Eval. Returns nullptr if tensor is not
+ // compressed.
+ const CompressionTensorData* GetTensorCompressionData(
+ const TfLiteNode* node, int tensor_idx) override;
+
+ // Only available during Eval. Returns nullptr on failure, otherwise returns a
+ // pointer to the scratch buffer.
+ void* DecompressTensorToScratchBuffer(
+ const TfLiteEvalTensor& tensor,
+ const CompressionTensorData& compression_data,
+ int scratch_buffer_handle) override;
+
+#endif // USE_TFLM_COMPRESSION
+
private:
MicroAllocator& allocator_;
MicroInterpreterGraph& graph_;