blob: 2dd8025ab4bbfac16daf5bbda56f156f5d8ea1e1 [file]
# Copyright 2021 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Helper functions for configuring IREE and dependent project WORKSPACE files."""
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
# Env var consulted by torch_mlir_auto_configure: torch-mlir support is
# enabled when this is set to ON/TRUE/1/YES (case-insensitive).
IREE_INPUT_TORCH_ENV_KEY = "IREE_INPUT_TORCH"

# Env var pointing at a locally installed CUDA toolkit root directory.
# Takes precedence over IREE_CUDA_DEPS_DIR below.
CUDA_TOOLKIT_ROOT_ENV_KEY = "IREE_CUDA_TOOLKIT_ROOT"

# Our CI docker images use a stripped down CUDA directory tree in some
# images, and it is tailored just to support building key elements.
# When this is done, the IREE_CUDA_DEPS_DIR env var is set, and we
# respect that here in order to match the CMake side (which needs it
# because CUDA toolkit detection differs depending on whether it is
# stripped down or not).
# TODO: Simplify this on the CMake/docker side and update here to match.
# TODO(#15332): Dockerfiles no longer include these deps. Simplify.
CUDA_DEPS_DIR_FOR_CI_ENV_KEY = "IREE_CUDA_DEPS_DIR"
def cuda_auto_configure_impl(repository_ctx):
    """Implementation for the cuda_auto_configure repository rule.

    Probes the environment for a CUDA toolkit location (explicit root first,
    then the CI deps dir), symlinks the pieces IREE needs into this
    repository, and expands the BUILD template with the detection result.
    """
    environ = repository_ctx.os.environ
    iree_repo_alias = repository_ctx.attr.iree_repo_alias

    # Probe environment for CUDA toolkit location. An explicit toolkit root
    # wins over the stripped-down CI deps directory.
    toolkit_root = environ.get(CUDA_TOOLKIT_ROOT_ENV_KEY)
    if not toolkit_root:
        toolkit_root = environ.get(CUDA_DEPS_DIR_FOR_CI_ENV_KEY)

    libdevice_rel_path = "iree_local/libdevice.bc"
    if toolkit_root != None:
        # Symlink top-level directories we care about.
        repository_ctx.symlink(toolkit_root + "/include", "include")

        # TODO: Should be probing for the libdevice, as it can change from
        # version to version.
        repository_ctx.symlink(
            toolkit_root + "/nvvm/libdevice/libdevice.10.bc",
            libdevice_rel_path,
        )

    # Expand the BUILD template. When disabled, point %LIBDEVICE_REL_PATH%
    # at an existing file ("BUILD") so the template stays well-formed.
    repository_ctx.template(
        "BUILD",
        Label("%s//:build_tools/third_party/cuda/BUILD.template" % iree_repo_alias),
        {
            "%ENABLED%": "True" if toolkit_root else "False",
            "%LIBDEVICE_REL_PATH%": libdevice_rel_path if toolkit_root else "BUILD",
            "%IREE_REPO_ALIAS%": iree_repo_alias,
        },
    )
# Repository rule that materializes an @iree_cuda repository from a locally
# installed CUDA toolkit (or a stub BUILD file when none is found).
# Listing the env vars in `environ` makes Bazel re-run the rule when they
# change.
cuda_auto_configure = repository_rule(
    environ = [
        CUDA_DEPS_DIR_FOR_CI_ENV_KEY,
        CUDA_TOOLKIT_ROOT_ENV_KEY,
    ],
    implementation = cuda_auto_configure_impl,
    attrs = {
        # Workspace alias under which the IREE repo (and its BUILD template)
        # is visible to the consuming workspace.
        "iree_repo_alias": attr.string(default = "@iree_core"),
    },
)
def torch_mlir_auto_configure_impl(repository_ctx):
    """Conditionally configures torch-mlir based on IREE_INPUT_TORCH env var.

    When enabled, runs torch-mlir's overlay_directories.py to assemble the
    repository contents; otherwise writes a stub BUILD file.
    """
    environ = repository_ctx.os.environ
    iree_repo_alias = repository_ctx.attr.iree_repo_alias

    enabled = environ.get(IREE_INPUT_TORCH_ENV_KEY, "OFF").upper() in ["ON", "TRUE", "1", "YES"]
    if not enabled:
        # Create stub repository when torch-mlir is disabled.
        repository_ctx.file(
            "BUILD.bazel",
            content = """# Stub: torch-mlir disabled (IREE_INPUT_TORCH != ON)
package(default_visibility = ["//visibility:public"])
# Provide empty targets that dependent code can reference.
# These will fail at build time if actually used.
""",
        )
        return

    # Run torch-mlir's configure to create the overlay.
    # Locate the torch-mlir submodule through a file it is known to contain,
    # then derive the overlay script/directory paths from there.
    src_dir = repository_ctx.path(
        Label("%s//:third_party/torch-mlir/CMakeLists.txt" % iree_repo_alias),
    ).dirname
    bazel_dir = src_dir.get_child("utils").get_child("bazel")
    overlay_dir = bazel_dir.get_child("torch-mlir-overlay")
    overlay_script = bazel_dir.get_child("overlay_directories.py")

    # Prefer python3 but fall back to a plain `python` on PATH.
    python_bin = repository_ctx.which("python3") or repository_ctx.which("python")
    if not python_bin:
        fail("Failed to find python3 binary for torch-mlir configuration")

    command = [
        python_bin,
        overlay_script,
        "--src",
        src_dir,
        "--overlay",
        overlay_dir,
        "--target",
        ".",
    ]
    result = repository_ctx.execute(command, timeout = 60)
    if result.return_code != 0:
        fail(("Failed to configure torch-mlir: '{cmd}'\n" +
              "Exited with code {return_code}\n" +
              "stdout:\n{stdout}\n" +
              "stderr:\n{stderr}\n").format(
            cmd = " ".join([str(arg) for arg in command]),
            return_code = result.return_code,
            stdout = result.stdout,
            stderr = result.stderr,
        ))
# Repository rule wrapping torch_mlir_auto_configure_impl. Marked `local`
# since it depends on the checked-out source tree; re-runs whenever the
# IREE_INPUT_TORCH env var changes.
torch_mlir_auto_configure = repository_rule(
    environ = [IREE_INPUT_TORCH_ENV_KEY],
    implementation = torch_mlir_auto_configure_impl,
    local = True,
    attrs = {
        # Workspace alias under which the IREE repo is visible.
        "iree_repo_alias": attr.string(default = "@iree_core"),
    },
)
def configure_iree_torch_mlir_deps(iree_repo_alias = None):
    """Declares the @torch-mlir repository (no-op if it already exists).

    Args:
        iree_repo_alias: The alias for the IREE repository.
    """
    maybe(
        torch_mlir_auto_configure,
        iree_repo_alias = iree_repo_alias,
        name = "torch-mlir",
    )
def configure_iree_cuda_deps(iree_repo_alias = None):
    """Declares the @iree_cuda repository (no-op if it already exists).

    Args:
        iree_repo_alias: The alias for the IREE repository.
    """
    maybe(
        cuda_auto_configure,
        iree_repo_alias = iree_repo_alias,
        name = "iree_cuda",
    )
def configure_iree_submodule_deps(iree_repo_alias = "@", iree_path = "./"):
    """Configure all of IREE's simple repository dependencies that come from submodules.

    Simple is defined here as just calls to `local_repository` or
    `new_local_repository`. This assumes you have a directory that includes IREE
    and all its submodules. Note that fetching a GitHub archive does not include
    submodules.
    Yes it is necessary to have both the workspace alias and path argument...

    Args:
      iree_repo_alias: The alias for the IREE repository.
      iree_path: The path to the IREE repository containing submodules
    """

    # (repo name, submodule path under iree_path, overlay BUILD file or None).
    # A None overlay means the submodule carries its own BUILD files and gets
    # a plain local_repository; otherwise we overlay IREE's BUILD file via
    # new_local_repository.
    submodule_repos = [
        ("com_google_googletest", "third_party/googletest", None),
        ("com_github_dvidelabs_flatcc", "third_party/flatcc", "build_tools/third_party/flatcc/BUILD.overlay"),
        ("vulkan_headers", "third_party/vulkan_headers", "build_tools/third_party/vulkan_headers/BUILD.overlay"),
        ("stablehlo", "third_party/stablehlo", None),
        ("com_google_benchmark", "third_party/benchmark", None),
        ("spirv_cross", "third_party/spirv_cross", "build_tools/third_party/spirv_cross/BUILD.overlay"),
        ("tracy_client", "third_party/tracy", "build_tools/third_party/tracy_client/BUILD.overlay"),
        ("nccl", "third_party/nccl", "build_tools/third_party/nccl/BUILD.overlay"),
        ("hsa_runtime_headers", "third_party/hsa-runtime-headers", "build_tools/third_party/hsa-runtime-headers/BUILD.overlay"),
        ("webgpu_headers", "third_party/webgpu-headers", "build_tools/third_party/webgpu-headers/BUILD.overlay"),
    ]

    for repo_name, submodule_path, overlay in submodule_repos:
        if overlay == None:
            maybe(
                native.local_repository,
                name = repo_name,
                path = paths.join(iree_path, submodule_path),
            )
        else:
            maybe(
                native.new_local_repository,
                name = repo_name,
                build_file = iree_repo_alias + "//:" + overlay,
                path = paths.join(iree_path, submodule_path),
            )