"""Build rule for generating ML inference code from a TFLite model."""
| |
| load("//tensorflow/lite/micro:build_def.bzl", "micro_copts") |
| |
def tflm_inference_library(
        name,
        tflite_model,
        visibility = None):
    """Creates a C++ library capable of performing ML inference of the provided
    model.

    The macro declares two targets:

      * `<name>_gen`: a private genrule that runs `//codegen:code_generator`
        on `tflite_model` and writes `<name>.h` and `<name>.cc` into this
        package's output directory (`$(RULEDIR)`).
      * `<name>`: a `cc_library` that compiles those generated files with
        `micro_copts()` against the TFLM runtime dependencies listed below.

    Args:
      name: Target name. Also used as the base name of the generated
        `<name>.h` / `<name>.cc` files (passed as `--output_name`).
      tflite_model: TFLite Model to generate inference from. Passed to the
        generator as `--model`. It is expected to be a single file, because
        the command refers to it with `$<`.
      visibility: Visibility for the C++ library. The intermediate genrule
        is always private.
    """
    generated_target = name + "_gen"
    native.genrule(
        name = generated_target,
        srcs = [tflite_model],
        outs = [name + ".h", name + ".cc"],
        tools = ["//codegen:code_generator"],
        # `%` binds tighter than `+`, so the format applies only to the second
        # string literal, which is the one containing `%s`.
        cmd = "$(location //codegen:code_generator) " +
              "--model=$< --output_dir=$(RULEDIR) --output_name=%s" % name,
        visibility = ["//visibility:private"],
    )

    native.cc_library(
        name = name,
        # Naming the genrule outputs here is what makes this library depend on
        # `generated_target`. The genrule itself must NOT be listed in `deps`:
        # it provides no CcInfo, so cc_library rejects it (or, on older Bazel,
        # warns that it is misplaced).
        hdrs = [name + ".h"],
        srcs = [name + ".cc"],
        deps = [
            "//codegen/runtime:micro_codegen_context",
            "//tensorflow/lite/c:common",
            "//tensorflow/lite/c:c_api_types",
            "//tensorflow/lite/kernels/internal:compatibility",
            "//tensorflow/lite/micro/kernels:micro_ops",
            "//tensorflow/lite/micro:micro_common",
            "//tensorflow/lite/micro:micro_context",
        ],
        copts = micro_copts(),
        visibility = visibility,
    )