feat: add fixed-capacity, statically-allocated type tflite::StaticVector (#2642)
Add a type, tflite::StaticVector, which behaves like std::vector, but
which avoids heap memory allocation.
BUG=#2636
diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD
index 527b85a..97d6897 100644
--- a/tensorflow/lite/micro/BUILD
+++ b/tensorflow/lite/micro/BUILD
@@ -408,6 +408,15 @@
)
cc_library(
+ name = "static_vector",
+ hdrs = ["static_vector.h"],
+ copts = micro_copts(),
+ deps = [
+ "//tensorflow/lite/kernels:op_macros",
+ ],
+)
+
+cc_library(
name = "system_setup",
srcs = [
"system_setup.cc",
@@ -616,6 +625,18 @@
],
)
+cc_test(
+ name = "static_vector_test",
+ size = "small",
+ srcs = [
+ "static_vector_test.cc",
+ ],
+ deps = [
+ ":static_vector",
+ "//tensorflow/lite/micro/testing:micro_test",
+ ],
+)
+
bzl_library(
name = "build_def_bzl",
srcs = ["build_def.bzl"],
diff --git a/tensorflow/lite/micro/static_vector.h b/tensorflow/lite/micro/static_vector.h
new file mode 100644
index 0000000..8b9e063
--- /dev/null
+++ b/tensorflow/lite/micro/static_vector.h
@@ -0,0 +1,83 @@
+// Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_
+#define TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_
+
+#include <array>
+#include <cassert>
+#include <cstddef>
+
+#include "tensorflow/lite/kernels/op_macros.h" // for TF_LITE_ASSERT
+
+namespace tflite {
+
+template <typename T, std::size_t MaxSize>
+class StaticVector {
+ // A staticlly-allocated vector. Add to the interface as needed.
+
+ private:
+ std::array<T, MaxSize> array_;
+ std::size_t size_{0};
+
+ public:
+ using iterator = typename decltype(array_)::iterator;
+ using const_iterator = typename decltype(array_)::const_iterator;
+ using pointer = typename decltype(array_)::pointer;
+ using reference = typename decltype(array_)::reference;
+ using const_reference = typename decltype(array_)::const_reference;
+
+ StaticVector() {}
+
+ StaticVector(std::initializer_list<T> values) {
+ for (const T& v : values) {
+ push_back(v);
+ }
+ }
+
+ static constexpr std::size_t max_size() { return MaxSize; }
+ std::size_t size() const { return size_; }
+ bool full() const { return size() == max_size(); }
+ iterator begin() { return array_.begin(); }
+ const_iterator begin() const { return array_.begin(); }
+ iterator end() { return begin() + size(); }
+ const_iterator end() const { return begin() + size(); }
+ pointer data() { return array_.data(); }
+ reference operator[](int i) { return array_[i]; }
+ const_reference operator[](int i) const { return array_[i]; }
+ void clear() { size_ = 0; }
+
+ template <std::size_t N>
+ bool operator==(const StaticVector<T, N>& other) const {
+ return std::equal(begin(), end(), other.begin(), other.end());
+ }
+
+ template <std::size_t N>
+ bool operator!=(const StaticVector<T, N>& other) const {
+ return !(*this == other);
+ }
+
+ void push_back(const T& t) {
+ TF_LITE_ASSERT(!full());
+ *end() = t;
+ ++size_;
+ }
+};
+
+template <typename T, typename... U>
+StaticVector(T, U...) -> StaticVector<T, 1 + sizeof...(U)>;
+
+} // end namespace tflite
+
+#endif // TENSORFLOW_LITE_MICRO_STATIC_VECTOR_H_
diff --git a/tensorflow/lite/micro/static_vector_test.cc b/tensorflow/lite/micro/static_vector_test.cc
new file mode 100644
index 0000000..6d601bc
--- /dev/null
+++ b/tensorflow/lite/micro/static_vector_test.cc
@@ -0,0 +1,82 @@
+// Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensorflow/lite/micro/static_vector.h"
+
+#include "tensorflow/lite/micro/testing/micro_test.h"
+
+using tflite::StaticVector;
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(StaticVectorPushBack) {
+ StaticVector<int, 4> a;
+ TF_LITE_MICRO_EXPECT(a.max_size() == 4);
+ TF_LITE_MICRO_EXPECT(a.size() == 0);
+
+ a.push_back(1);
+ TF_LITE_MICRO_EXPECT(a.size() == 1);
+ TF_LITE_MICRO_EXPECT(a[0] == 1);
+
+ a.push_back(2);
+ TF_LITE_MICRO_EXPECT(a.size() == 2);
+ TF_LITE_MICRO_EXPECT(a[1] == 2);
+
+ a.push_back(3);
+ TF_LITE_MICRO_EXPECT(a.size() == 3);
+ TF_LITE_MICRO_EXPECT(a[2] == 3);
+}
+
+TF_LITE_MICRO_TEST(StaticVectorInitializationPartial) {
+ const StaticVector<int, 4> a{1, 2, 3};
+ TF_LITE_MICRO_EXPECT(a.max_size() == 4);
+ TF_LITE_MICRO_EXPECT(a.size() == 3);
+ TF_LITE_MICRO_EXPECT(a[0] == 1);
+ TF_LITE_MICRO_EXPECT(a[1] == 2);
+ TF_LITE_MICRO_EXPECT(a[2] == 3);
+}
+
+TF_LITE_MICRO_TEST(StaticVectorInitializationFull) {
+ const StaticVector b{1, 2, 3};
+ TF_LITE_MICRO_EXPECT(b.max_size() == 3);
+ TF_LITE_MICRO_EXPECT(b.size() == 3);
+}
+
+TF_LITE_MICRO_TEST(StaticVectorEquality) {
+ const StaticVector a{1, 2, 3};
+ const StaticVector b{1, 2, 3};
+ TF_LITE_MICRO_EXPECT(a == b);
+ TF_LITE_MICRO_EXPECT(!(a != b));
+}
+
+TF_LITE_MICRO_TEST(StaticVectorInequality) {
+ const StaticVector a{1, 2, 3};
+ const StaticVector b{3, 2, 1};
+ TF_LITE_MICRO_EXPECT(a != b);
+ TF_LITE_MICRO_EXPECT(!(a == b));
+}
+
+TF_LITE_MICRO_TEST(StaticVectorSizeInequality) {
+ const StaticVector a{1, 2};
+ const StaticVector b{1, 2, 3};
+ TF_LITE_MICRO_EXPECT(a != b);
+}
+
+TF_LITE_MICRO_TEST(StaticVectorPartialSizeInequality) {
+ const StaticVector<int, 3> a{1, 2};
+ const StaticVector<int, 3> b{1, 2, 3};
+ TF_LITE_MICRO_EXPECT(a != b);
+}
+
+TF_LITE_MICRO_TESTS_END