blob: 9bf773e27a81bb0ed0a3b8a85ac212a5d7364501 [file] [log] [blame]
suleshahidc306ff62023-06-13 17:25:02 -07001"""Build rule for wrapping a custom TF OP from .cc to python."""
2
Advait Jain699c5172023-07-14 14:59:36 -07003load("@rules_python//python:defs.bzl", "py_library")
4
suleshahidc306ff62023-06-13 17:25:02 -07005# TODO(b/286890280): refactor to be more generic build target for any custom OP
def py_tflm_signal_library(
        name,
        srcs = [],
        deps = [],
        visibility = None,
        cc_op_defs = [],
        cc_op_kernels = []):
    """Creates build rules for signal ops as shared libraries.

    Defines three targets:

      <name>
        Python library that exposes all ops defined in `cc_op_defs` and `srcs`.
      <name>_cc
        C++ library that registers any C++ ops in `cc_op_defs`, and includes
        the kernels from `cc_op_kernels`.
      <binary_name>
        Shared library exposing the <name>_cc library. If `cc_op_defs` is
        non-empty it is named `ops/_<name>.so`. Otherwise it is named
        `<dir>/_<kernel>.so`, where <kernel> is the target name of
        `cc_op_kernels[0]` and <dir> is the directory part of `srcs[0]`
        (`ops` when `srcs` is empty; omitted when `srcs[0]` contains no `/`).

    Args:
      name: The name for the Python library target built by this rule.
      srcs: Python source files for the Python library.
      deps: Dependencies for the Python library.
      visibility: Visibility for the Python library only; the C++ library and
        shared-library targets are created without an explicit visibility.
      cc_op_defs: A list of C++ src files containing REGISTER_OP definitions.
      cc_op_kernels: A list of C++ targets containing kernels that are used
        by the Python library. Must be non-empty when `cc_op_defs` is empty,
        because the shared library is then named after its first entry.
    """
    if cc_op_defs:
        binary_name = "ops/_" + name + ".so"
    else:
        if not cc_op_kernels:
            fail("py_tflm_signal_library(%s): cc_op_kernels must be non-empty " % name +
                 "when cc_op_defs is empty")

        binary_dir = "ops"
        if srcs:
            # Directory part of the first Python source. rfind() returns -1
            # when there is no "/"; slicing with -1 would silently drop the
            # file name's last character, so treat that as "no directory".
            slash_pos = srcs[0].rfind("/")
            binary_dir = srcs[0][:slash_pos] if slash_pos >= 0 else ""

        # Target name of the first kernel label: the part after ":", or, for
        # the "//pkg/name" / "name" shorthands, the last path component.
        # (Stripping only the first character worked just for ":name".)
        kernel_label = cc_op_kernels[0]
        if ":" in kernel_label:
            kernel_name = kernel_label.rpartition(":")[2]
        else:
            kernel_name = kernel_label.rpartition("/")[2]

        binary_name = "_" + kernel_name + ".so"
        if binary_dir:
            binary_name = binary_dir + "/" + binary_name

    library_name = name + "_cc"
    native.cc_library(
        name = library_name,
        srcs = cc_op_defs,
        # select() with only a default branch; presumably kept as a hook for
        # per-configuration flags -- confirm before simplifying.
        copts = select({
            "//conditions:default": ["-pthread"],
        }),
        # Link every object file into dependents even if no symbol in it is
        # referenced (Bazel `alwayslink` semantics).
        alwayslink = 1,
        deps =
            cc_op_kernels +
            ["@tensorflow_cc_deps//:cc_library"] +
            select({"//conditions:default": []}),
    )

    native.cc_binary(
        name = binary_name,
        copts = select({
            "//conditions:default": ["-pthread"],
        }),
        linkshared = 1,
        linkopts = [],
        deps = [
            ":" + library_name,
            "@tensorflow_cc_deps//:cc_library",
        ] + select({"//conditions:default": []}),
    )

    py_library(
        name = name,
        srcs = srcs,
        srcs_version = "PY2AND3",
        visibility = visibility,
        # The shared library is shipped as a runtime data file of the
        # Python library.
        data = [":" + binary_name],
        deps = deps,
    )
74
# A macro that wraps native.cc_library for TensorFlow OpKernel sources.
def tflm_signal_kernel_library(
        name,
        srcs = [],
        hdrs = [],
        deps = [],
        copts = [],
        alwayslink = 1,
        visibility = None):
    """Defines a C++ library intended to hold TensorFlow OpKernel code.

    This is a thin wrapper around `native.cc_library`; every argument is
    forwarded unchanged. Targets created here are presumably the ones listed
    in `cc_op_kernels` of `py_tflm_signal_library` -- confirm against callers.

    Args:
      name: Name of the cc_library target.
      srcs: C++ source files.
      hdrs: C++ header files.
      deps: Dependencies of the library.
      copts: Extra compiler options.
      alwayslink: Forwarded to cc_library. Defaults to 1 so that all object
        files in `srcs` are linked into dependents even when none of their
        symbols is referenced (presumably needed for static kernel
        registration -- confirm).
      visibility: Visibility of the cc_library target. None (the default)
        leaves the package default visibility, matching previous behavior.
    """
    native.cc_library(
        name = name,
        srcs = srcs,
        hdrs = hdrs,
        deps = deps,
        copts = copts,
        alwayslink = alwayslink,
        visibility = visibility,
    )