TFLM Person Detection
- Add WORKSPACE imports for Tensorflow Lite for Microcontrollers,
some python support repos, and patches to build for Kelvin.
- Add a sample application that executes the sample person detection
model on a fixed input.
Change-Id: Iee176e48b5316758c610b4a72a00e9ee041a4537
diff --git a/.bazelrc b/.bazelrc
index 99fc0bd..741adf4 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -6,3 +6,11 @@
build --incompatible_enable_cc_toolchain_resolution
build:kelvin --platforms=//platforms/riscv32:kelvin
+
+# Set preprocessor defines for tflite-micro.
+build --copt=-DTF_LITE_USE_GLOBAL_CMATH_FUNCTIONS
+build --copt=-DTF_LITE_USE_GLOBAL_MIN
+build --copt=-DTF_LITE_USE_GLOBAL_MAX
+build --copt=-DTF_LITE_MCU_DEBUG_LOG
+build --copt=-DTF_LITE_STATIC_MEMORY
+build:opt --copt=-DTF_LITE_STRIP_ERROR_STRINGS
diff --git a/WORKSPACE b/WORKSPACE
index 3909e03..5f7eedf 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -1,6 +1,6 @@
workspace(name = "kelvin_sw")
-load("//build_tools/bazel:repos.bzl", "kelvin_repos")
+load("//build_tools/bazel:repos.bzl", "kelvin_repos", "tflm_repos")
kelvin_repos()
@@ -8,3 +8,17 @@
load("//platforms:registration.bzl", "kelvin_register_toolchain")
kelvin_register_toolchain()
+
+tflm_repos()
+
+load("@tflite-micro//tensorflow:workspace.bzl", "tf_repositories")
+tf_repositories()
+
+load("@rules_python//python:pip.bzl", "pip_parse")
+pip_parse(
+ name = "tflm_pip_deps",
+ requirements_lock = "@tflite-micro//third_party:python_requirements.txt",
+)
+
+load("@tflm_pip_deps//:requirements.bzl", "install_deps")
+install_deps()
diff --git a/build_tools/bazel/kelvin.bzl b/build_tools/bazel/kelvin.bzl
index 6f7596d..e6fb102 100644
--- a/build_tools/bazel/kelvin.bzl
+++ b/build_tools/bazel/kelvin.bzl
@@ -223,3 +223,20 @@
"{}.elf".format(kelvin_elf),
],
)
+
+# From @tflite-micro//tensorflow/lite/micro/build_def.bzl, and paths
+# modified to point to the external repo.
+def generate_cc_arrays(name, src, out, visibility = None, tags = []):
+ native.genrule(
+ name = name,
+ srcs = [
+ src,
+ ],
+ outs = [
+ out,
+ ],
+ tags = tags,
+ cmd = "$(location @tflite-micro//tensorflow/lite/micro/tools:generate_cc_arrays) $@ $<",
+ tools = ["@tflite-micro//tensorflow/lite/micro/tools:generate_cc_arrays"],
+ visibility = visibility,
+ )
diff --git a/build_tools/bazel/repos.bzl b/build_tools/bazel/repos.bzl
index 9cad348..5cc5b76 100644
--- a/build_tools/bazel/repos.bzl
+++ b/build_tools/bazel/repos.bzl
@@ -1,7 +1,6 @@
"""Kelvin dependency repository setup."""
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
-load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def kelvin_repos():
@@ -32,3 +31,62 @@
urls = ["https://github.com/riscv-software-src/riscv-tests/archive/d4eaa5bd6674b51d3b9b24913713c4638e99cdd9.tar.gz"],
strip_prefix = "riscv-tests-d4eaa5bd6674b51d3b9b24913713c4638e99cdd9",
)
+
+def tflm_repos():
+ """Setup Tensorflow Lite For Microcontrollers repositories."""
+ # Tensorflow Lite for Microcontrollers
+ native.local_repository(
+ name = "tflite-micro",
+ path = "../../sw/tflite-micro",
+ )
+
+ maybe(
+ http_archive,
+ name = "gemmlowp",
+ sha256 = "43146e6f56cb5218a8caaab6b5d1601a083f1f31c06ff474a4378a7d35be9cfb", # SHARED_GEMMLOWP_SHA
+ strip_prefix = "gemmlowp-fda83bdc38b118cc6b56753bd540caa49e570745",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/archive/fda83bdc38b118cc6b56753bd540caa49e570745.zip",
+ "https://github.com/google/gemmlowp/archive/fda83bdc38b118cc6b56753bd540caa49e570745.zip",
+ ],
+ patches = [
+ "@kelvin_sw//third_party/gemmlowp:pthread.patch",
+ ],
+ patch_args = [
+ "-p1",
+ ],
+ )
+
+ maybe(
+ http_archive,
+ name = "ruy",
+ sha256 = "da5ec0cc07472bdb21589b0b51c8f3d7f75d2ed6230b794912adf213838d289a",
+ strip_prefix = "ruy-54774a7a2cf85963777289193629d4bd42de4a59",
+ urls = [
+ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/ruy/archive/54774a7a2cf85963777289193629d4bd42de4a59.zip",
+ "https://github.com/google/ruy/archive/54774a7a2cf85963777289193629d4bd42de4a59.zip",
+ ],
+ build_file = "@tflite-micro//third_party/ruy:BUILD",
+ patches = [
+ "@kelvin_sw//third_party/ruy:pthread.patch",
+ ],
+ patch_args = [
+ "-p1",
+ ],
+ )
+
+ maybe(
+ http_archive,
+ name = "rules_python",
+ sha256 = "497ca47374f48c8b067d786b512ac10a276211810f4a580178ee9b9ad139323a",
+ strip_prefix = "rules_python-0.16.1",
+ url = "https://github.com/bazelbuild/rules_python/archive/refs/tags/0.16.1.tar.gz",
+ )
+
+ maybe(
+ http_archive,
+ name = "pybind11_bazel",
+ strip_prefix = "pybind11_bazel-faf56fb3df11287f26dbc66fdedf60a2fc2c6631",
+ urls = ["https://github.com/pybind/pybind11_bazel/archive/faf56fb3df11287f26dbc66fdedf60a2fc2c6631.zip"],
+ sha256 = "a185aa68c93b9f62c80fcb3aadc3c83c763854750dc3f38be1dadcb7be223837",
+ )
diff --git a/examples/tflm/person_detection/BUILD b/examples/tflm/person_detection/BUILD
new file mode 100644
index 0000000..4dcbb9a
--- /dev/null
+++ b/examples/tflm/person_detection/BUILD
@@ -0,0 +1,43 @@
+load("//build_tools/bazel:kelvin.bzl", "kelvin_binary", "generate_cc_arrays")
+package(default_visibility = ["//visibility:public"])
+
+kelvin_binary(
+ name = "person_detection",
+ srcs = [
+ "person_detection.cc",
+ "person_detect_tflite.cc",
+ "person_bmp.cc",
+ ],
+ hdrs = [
+ "person_bmp.h",
+ "person_detect_tflite.h",
+ ],
+ deps = [
+ "//crt:crt_header",
+ "@tflite-micro//tensorflow/lite/micro:micro_framework",
+ "@tflite-micro//tensorflow/lite/micro:system_setup",
+ ],
+)
+
+generate_cc_arrays(
+ name = "person_bmp_cc",
+ src = "@tflite-micro//tensorflow/lite/micro/examples/person_detection:testdata/person.bmp",
+ out = "person_bmp.cc",
+)
+generate_cc_arrays(
+ name = "person_bmp_h",
+ src = "@tflite-micro//tensorflow/lite/micro/examples/person_detection:testdata/person.bmp",
+ out = "person_bmp.h",
+)
+
+generate_cc_arrays(
+ name = "person_detect_tflite_cc",
+ src = "@tflite-micro//tensorflow/lite/micro/models:person_detect.tflite",
+ out = "person_detect_tflite.cc",
+)
+
+generate_cc_arrays(
+ name = "person_detect_tflite_h",
+ src = "@tflite-micro//tensorflow/lite/micro/models:person_detect.tflite",
+ out = "person_detect_tflite.h",
+)
diff --git a/examples/tflm/person_detection/person_detection.cc b/examples/tflm/person_detection/person_detection.cc
new file mode 100644
index 0000000..4b608c6
--- /dev/null
+++ b/examples/tflm/person_detection/person_detection.cc
@@ -0,0 +1,57 @@
+// Copyright 2023 Google LLC
+// Licensed under the Apache License, Version 2.0, see LICENSE for details.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "examples/tflm/person_detection/person_bmp.h"
+#include "examples/tflm/person_detection/person_detect_tflite.h"
+#include "tensorflow/lite/micro/micro_interpreter.h"
+#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
+#include "tensorflow/lite/micro/system_setup.h"
+
+namespace {
+const tflite::Model* model = nullptr;
+tflite::MicroInterpreter* interpreter = nullptr;
+constexpr int kTensorArenaSize = 96 * 1024;
+uint8_t tensor_arena[kTensorArenaSize] __attribute__((aligned(64)));
+} // namespace
+
+extern "C" int main(int argc, char** argv) {
+ tflite::InitializeTarget();
+
+ model = tflite::GetModel(g_person_detect_model_data);
+
+ if (model->version() != TFLITE_SCHEMA_VERSION) {
+ return 1;
+ }
+
+ static tflite::MicroMutableOpResolver<5> micro_op_resolver;
+ micro_op_resolver.AddAveragePool2D();
+ micro_op_resolver.AddConv2D();
+ micro_op_resolver.AddDepthwiseConv2D();
+ micro_op_resolver.AddReshape();
+ micro_op_resolver.AddSoftmax();
+
+ static tflite::MicroInterpreter static_interpreter(
+ model, micro_op_resolver, tensor_arena, kTensorArenaSize);
+ interpreter = &static_interpreter;
+
+ TfLiteStatus allocate_status = interpreter->AllocateTensors();
+ if (allocate_status != kTfLiteOk) {
+ return 2;
+ }
+
+ TfLiteTensor* input = interpreter->input(0);
+ TfLiteTensor* output = interpreter->output(0);
+
+ memcpy(input->data.uint8, g_person_image_data, input->bytes);
+ TfLiteStatus invoke_status = interpreter->Invoke();
+ if (invoke_status != kTfLiteOk) {
+ return 3;
+ }
+
+ int8_t person = output->data.int8[1];
+ int8_t not_person = output->data.int8[0];
+ MicroPrintf("person: %d not_person: %d", person, not_person);
+
+ return 0;
+}
diff --git a/platforms/riscv32/features/BUILD b/platforms/riscv32/features/BUILD
index 98de11f..737a988 100644
--- a/platforms/riscv32/features/BUILD
+++ b/platforms/riscv32/features/BUILD
@@ -65,22 +65,51 @@
],
)
+feature(
+ name = "all_warnings",
+ enabled = True,
+ flag_sets = [
+ flag_set(
+ actions = CPP_ALL_COMPILE_ACTIONS + C_ALL_COMPILE_ACTIONS,
+ flag_groups = [
+ flag_group(
+ flags = [
+ "-Wall",
+ ],
+ ),
+ ],
+ ),
+ ],
+)
+
+feature(
+ name = "all_warnings_as_errors",
+ enabled = False,
+ flag_sets = [
+ flag_set(
+ actions = CPP_ALL_COMPILE_ACTIONS + C_ALL_COMPILE_ACTIONS,
+ flag_groups = [
+ flag_group(
+ flags = ["-Werror"],
+ ),
+ ],
+ ),
+ ],
+)
+
feature_set(
name = "rv32im",
feature = [
":architecture",
":sys_spec",
+ ":all_warnings",
+ ":all_warnings_as_errors",
+ "@crt//features/common:includes",
+ "@crt//features/common:reproducible",
+ "@crt//features/common:symbol_garbage_collection",
"@crt//features/embedded:cc_constructor_destructor",
"@crt//features/embedded:exceptions",
"@crt//features/embedded:runtime_type_information",
- "@crt//platforms/riscv32/features:all_warnings_as_errors",
"@crt//platforms/riscv32/features:fastbuild",
- "@crt//features/common:includes",
- "@crt//features/common:all_warnings",
- "@crt//features/common:all_warnings_as_errors",
- "@crt//features/common:reproducible",
- # TODO(atv): It would be nice to have the feature, but for now enabling
- # this creates the wrong program.
- # "@crt//features/common:symbol_garbage_collection",
],
)
diff --git a/third_party/gemmlowp/BUILD b/third_party/gemmlowp/BUILD
new file mode 100644
index 0000000..584f7e9
--- /dev/null
+++ b/third_party/gemmlowp/BUILD
@@ -0,0 +1,4 @@
+# Copyright 2023 Google LLC
+package(default_visibility = ["//visibility:public"])
+
+exports_files(glob(["*.patch"]))
diff --git a/third_party/gemmlowp/pthread.patch b/third_party/gemmlowp/pthread.patch
new file mode 100644
index 0000000..547dd52
--- /dev/null
+++ b/third_party/gemmlowp/pthread.patch
@@ -0,0 +1,13 @@
+diff --git a/flags.bzl b/flags.bzl
+index e35fe9e..e26a448 100644
+--- a/flags.bzl
++++ b/flags.bzl
+@@ -4,7 +4,7 @@ LIB_COPTS = []
+ LIB_LINKOPTS = select({
+ ":android": [],
+ ":windows": [],
+- "//conditions:default": ["-lpthread"],
++ "//conditions:default": [],
+ })
+
+ BIN_LINKOPTS = LIB_LINKOPTS
\ No newline at end of file
diff --git a/third_party/ruy/BUILD b/third_party/ruy/BUILD
new file mode 100644
index 0000000..584f7e9
--- /dev/null
+++ b/third_party/ruy/BUILD
@@ -0,0 +1,4 @@
+# Copyright 2023 Google LLC
+package(default_visibility = ["//visibility:public"])
+
+exports_files(glob(["*.patch"]))
diff --git a/third_party/ruy/pthread.patch b/third_party/ruy/pthread.patch
new file mode 100644
index 0000000..c8ddf4d
--- /dev/null
+++ b/third_party/ruy/pthread.patch
@@ -0,0 +1,11 @@
+diff --git a/ruy/build_defs.oss.bzl b/ruy/build_defs.oss.bzl
+index e405b41..1d7612b 100644
+--- a/ruy/build_defs.oss.bzl
++++ b/ruy/build_defs.oss.bzl
+@@ -11,5 +11,5 @@ def ruy_linkopts_thread_standard_library():
+ # https://github.com/abseil/abseil-cpp/blob/1112609635037a32435de7aa70a9188dcb591458/absl/base/BUILD.bazel#L155
+ return select({
+ "@bazel_tools//src/conditions:windows": [],
+- "//conditions:default": ["-pthread"],
++ "//conditions:default": [],
+ })
\ No newline at end of file