blob: 5e3fbcf156778ae21801b9a8df712a7cef0fffef [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for defining TensorFlow Bazel dependencies."""
def _get_env_var(ctx, name):
    """Looks up an environment variable through the repository context.

    Args:
      ctx: The repository context; only `ctx.os.environ` is read.
      name: Name of the environment variable to look up.

    Returns:
      The value stored for `name`, or None when the variable is not set.
    """
    return ctx.os.environ.get(name)
def _use_system_lib(ctx, name):
    """Checks if the system lib should be used instead of the bundled one.

    The TF_SYSTEM_LIBS environment variable is read as a comma-separated
    list of names; whitespace around each entry is ignored.

    Args:
      ctx: The repository context used to read the environment.
      name: The name to look for in TF_SYSTEM_LIBS.

    Returns:
      True if `name` equals one of the listed entries, False otherwise
      (including when TF_SYSTEM_LIBS is unset or empty).
    """
    requested = _get_env_var(ctx, "TF_SYSTEM_LIBS")
    if requested:
        for entry in requested.split(","):
            if entry.strip() == name:
                return True
    return False
def _get_link_dict(ctx, link_files, build_file):
    """Builds the mapping of repository paths to the labels linked there.

    Args:
      ctx: The repository context; used to turn destinations into paths.
      link_files: Dict whose keys are label strings and whose values are
        the destination paths for the corresponding links.
      build_file: Optional label string of a build file. When set it is
        added to the mapping with "BUILD.bazel" as its destination,
        replacing any `link_files` entry that has the same key.

    Returns:
      A new dict mapping `ctx.path(destination)` to `Label(source)`.
    """
    sources = dict(link_files)
    if build_file:
        # Use BUILD.bazel because it takes precedence over BUILD.
        sources[build_file] = "BUILD.bazel"

    resolved = {}
    for source, destination in sources.items():
        target = ctx.path(destination)
        resolved[target] = Label(source)
    return resolved
def _tf_http_archive_impl(ctx):
    """Implementation function of the `_tf_http_archive` repository rule.

    When the rule's name is listed in TF_SYSTEM_LIBS, nothing is
    downloaded and the system build/link files are added on top of the
    regular ones. Otherwise the archive is downloaded and extracted and
    the optional patch is applied. In both cases every entry of the link
    dict is then symlinked into the repository, replacing whatever was at
    that path.

    Args:
      ctx: The repository context.
    """

    # Construct all labels early on to prevent rule restart. We want the
    # attributes to be strings instead of labels because they refer to files
    # in the TensorFlow repository, not files in repos depending on TensorFlow.
    # See also https://github.com/bazelbuild/bazel/issues/10515.
    links = _get_link_dict(ctx, ctx.attr.link_files, ctx.attr.build_file)

    if not _use_system_lib(ctx, ctx.attr.name):
        # The patch label is resolved before the download starts.
        patch_label = None
        if ctx.attr.patch_file:
            patch_label = Label(ctx.attr.patch_file)

        ctx.download_and_extract(
            url = ctx.attr.urls,
            sha256 = ctx.attr.sha256,
            type = ctx.attr.type,
            stripPrefix = ctx.attr.strip_prefix,
        )
        if patch_label:
            ctx.patch(patch_label, strip = 1)
    else:
        # System entries are merged over the bundled ones, so they win
        # whenever both map the same path.
        system_links = _get_link_dict(
            ctx,
            ctx.attr.system_link_files,
            ctx.attr.system_build_file,
        )
        links.update(system_links)

    for destination, source in links.items():
        # Remove the existing entry at the destination before linking.
        ctx.delete(destination)
        ctx.symlink(source, destination)
# Attribute schema of the `_tf_http_archive` repository rule. File references
# are declared as plain strings; the implementation function converts them
# to labels itself.
_TF_HTTP_ARCHIVE_ATTRS = {
    # Passed to the download step together with `urls`.
    "sha256": attr.string(mandatory = True),
    "urls": attr.string_list(mandatory = True),
    # Forwarded to the download step as `stripPrefix` and `type`.
    "strip_prefix": attr.string(),
    "type": attr.string(),
    # Label string of a patch applied with strip = 1 after extraction.
    "patch_file": attr.string(),
    # Label strings of the build files linked as BUILD.bazel; the system
    # variant is used when the rule name is listed in TF_SYSTEM_LIBS.
    "build_file": attr.string(),
    "system_build_file": attr.string(),
    # Label string -> destination path of additional files to symlink; the
    # system variant is merged in when the rule name is in TF_SYSTEM_LIBS.
    "link_files": attr.string_dict(),
    "system_link_files": attr.string_dict(),
}

_tf_http_archive = repository_rule(
    implementation = _tf_http_archive_impl,
    attrs = _TF_HTTP_ARCHIVE_ATTRS,
    # The implementation reads TF_SYSTEM_LIBS, so it is declared here as an
    # environment variable this rule depends on.
    environ = ["TF_SYSTEM_LIBS"],
)
def tf_http_archive(name, sha256, urls, **kwargs):
    """Downloads and creates Bazel repos for dependencies.

    If a rule named `name` already exists, a warning is printed and no
    repository is declared. Otherwise all arguments are forwarded to the
    `_tf_http_archive` repository rule.

    Args:
      name: Name of the repository to create.
      sha256: Checksum string forwarded to the rule's `sha256` attribute.
      urls: List of URLs forwarded to the rule's `urls` attribute.
      **kwargs: Any further attributes, forwarded to the rule unchanged.
    """
    if native.existing_rule(name):
        warning = (
            "\n\033[1;33mWarning:\033[0m skipping import of repository '" +
            name + "' because it already exists.\n"
        )
        print(warning)
    else:
        _tf_http_archive(
            name = name,
            sha256 = sha256,
            urls = urls,
            **kwargs
        )