# blob: 28b6232b339dbc82440c21aacdb83570a9a788ab [file] [log] [blame]
""" Build rule for generating ML inference code from TFLite model. """
load("//tensorflow/lite/micro:build_def.bzl", "micro_copts")
def tflm_inference_library(
        name,
        tflite_model,
        visibility = None):
    """Creates a C++ library capable of performing ML inference of the provided model.

    A private genrule runs `//codegen:code_generator` on `tflite_model` to
    produce `<name>.h` and `<name>.cc`. Those generated files are then
    compiled into a `cc_library` called `name`, together with the codegen
    runtime and TFLite Micro dependencies listed in `deps`.

    Args:
      name: Target name. Also used as the basename of the generated
        `.h`/`.cc` files and passed to the generator as `--output_name`.
      tflite_model: Label of the TFLite model to generate inference code
        from. It must resolve to exactly one file, because the generator
        command reads it via `$<`.
      visibility: Visibility for the C++ library. The intermediate genrule
        is always private.
    """
    generated_target = name + "_gen"

    native.genrule(
        name = generated_target,
        srcs = [tflite_model],
        outs = [name + ".h", name + ".cc"],
        tools = ["//codegen:code_generator"],
        # `$<` is the single file in `srcs`. `$(RULEDIR)` is this package's
        # output directory, which is where Bazel expects `outs` to appear.
        # Note that `%` binds tighter than `+`, so only the second string
        # literal is formatted with `name`.
        cmd = "$(location //codegen:code_generator) " +
              "--model=$< --output_dir=$(RULEDIR) --output_name=%s" % name,
        visibility = ["//visibility:private"],
    )

    native.cc_library(
        name = name,
        # These are the genrule's outputs. Referencing them here is what
        # makes the library depend on the genrule.
        hdrs = [name + ".h"],
        srcs = [name + ".cc"],
        # The genrule is intentionally NOT listed here. It provides no
        # CcInfo, and its outputs are already consumed through `srcs`/`hdrs`
        # above. Listing it in `deps` is redundant, and Bazel flags it as an
        # unexpected rule kind for cc_library deps.
        deps = [
            "//codegen/runtime:micro_codegen_context",
            "//tensorflow/lite/c:common",
            "//tensorflow/lite/c:c_api_types",
            "//tensorflow/lite/kernels/internal:compatibility",
            "//tensorflow/lite/micro/kernels:micro_ops",
            "//tensorflow/lite/micro:micro_common",
            "//tensorflow/lite/micro:micro_context",
        ],
        copts = micro_copts(),
        visibility = visibility,
    )