blob: f58e9a10964a3f384387903e1b479e76087db817 [file]
# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import platform
import re
import shlex
import shutil
import sys
# Environment variables that may name a CUDA toolkit root. When CUDA support
# is enabled, the detected toolkit is forwarded to Bazel via --repo_env under
# one of these two names.
IREE_CUDA_DEPS_DIR_ENV_KEY: str = "IREE_CUDA_DEPS_DIR"
IREE_CUDA_TOOLKIT_ROOT_ENV_KEY: str = "IREE_CUDA_TOOLKIT_ROOT"
# Environment variable consulted first when searching for a Vulkan SDK; the
# detected SDK root is forwarded to Bazel via --repo_env under this name.
VULKAN_SDK_ENV_KEY: str = "VULKAN_SDK"
def detect_unix_platform_config(bazelrc):
    """Write compiler-specific Bazel configs for a non-Windows host.

    macOS always selects the ``macos_clang`` configs. On other hosts, CC/CXX
    are persisted as ``--action_env`` entries when both are set, and CXX
    decides between the ``generic_clang`` and ``generic_gcc`` configs.
    Diagnostics about CC/CXX are printed to stdout.

    Args:
      bazelrc: Writable text stream receiving the bazelrc lines.
    """
    # This is hokey. Ideally, bazel had any kind of rational way of selecting
    # options from within its environment (key word: "rational"), but sadly, it
    # is unintelligible to mere mortals. Why should a build system have a way for
    # people to condition their build options on what compiler they are using
    # (without descending down the hole of deciphering what a Bazel toolchain is)?
    # All I want to do is set a couple of project specific warning options!
    if platform.system() == "Darwin":
        print("build --config=macos_clang", file=bazelrc)
        print("build:release --config=macos_clang_release", file=bazelrc)
        return

    # If the user specified a CXX environment var, bazel will later respect that,
    # so we just see if it says "clang".
    cxx = os.environ.get("CXX")
    cc = os.environ.get("CC")
    if cc is not None and cxx is not None:
        # Persist the variables.
        print(f'build --action_env CC="{cc}"', file=bazelrc)
        print(f'build --action_env CXX="{cxx}"', file=bazelrc)
    elif cc is None and cxx is None:
        print(
            "WARNING: CC and CXX are not set, which can cause mismatches between "
            "flag configurations and compiler. Recommend setting them explicitly."
        )
    else:
        # Exactly one is set; only this warning applies (the "not set" warning
        # would be false here).
        print(
            "WARNING: Only one of CXX or CC is set, which can confuse bazel. "
            "Recommend: set both appropriately (or none)"
        )

    if cxx is not None and "clang" in cxx:
        print(f"Choosing generic_clang config because CXX is set to clang ({cxx})")
        print("build --config=generic_clang", file=bazelrc)
        print("build:release --config=generic_clang_release", file=bazelrc)
    else:
        print(
            f"Choosing generic_gcc config by default because no CXX set or "
            f"not recognized as clang ({cxx})"
        )
        print("build --config=generic_gcc", file=bazelrc)
        print("build:release --config=generic_gcc_release", file=bazelrc)
def write_platform(bazelrc):
    """Emit the host-platform Bazel configs.

    Windows hosts get the MSVC configs directly; every other host is delegated
    to detect_unix_platform_config().

    Args:
      bazelrc: Writable text stream receiving the bazelrc lines.
    """
    if platform.system() != "Windows":
        detect_unix_platform_config(bazelrc)
        return
    print("build --config=msvc", file=bazelrc)
    print("build:release --config=msvc_release", file=bazelrc)
def cmake_bool_is_true(value):
    """Check if a CMake-style bool value is true.

    Mirrors CMake's ``if(<constant>)`` rules: ON, YES, TRUE, Y (any case) and
    any non-zero number are true. Everything else, including None, the empty
    string, OFF/NO/FALSE/N and "0", is false.

    Args:
      value: The raw string (typically an environment variable value) or None.

    Returns:
      True if the value is a CMake true constant, otherwise False.
    """
    if not value:
        return False
    # Environment values are easy to set with stray whitespace; ignore it.
    normalized = value.strip().upper()
    if normalized in ("ON", "YES", "TRUE", "Y"):
        return True
    # CMake also treats any non-zero number (e.g. "1", "2", "1.0") as true.
    try:
        return float(normalized) != 0
    except ValueError:
        return False
def is_cuda_toolkit_root(path):
    """Return True if *path* holds ``nvvm/libdevice/libdevice.10.bc``.

    That bitcode file is what @iree_cuda needs from a toolkit; a falsy *path*
    (None or "") is never a toolkit root.
    """
    if not path:
        return False
    libdevice_bc = os.path.join(path, "nvvm", "libdevice", "libdevice.10.bc")
    return os.path.isfile(libdevice_bc)
def find_cuda_toolkit_root():
    """Find a CUDA toolkit root usable by the Bazel @iree_cuda repository.

    Environment variables are consulted first, in priority order, then the
    toolkit two directories above the resolved ``nvcc`` found on PATH. A
    candidate is accepted only if is_cuda_toolkit_root() approves it.

    Returns:
      A ``(toolkit_root, source)`` tuple, where ``source`` is the environment
      variable name or ``"nvcc"`` that produced the root, or ``(None, None)``
      when no usable toolkit is found. Roots are absolute paths because the
      result is forwarded to Bazel via --repo_env, and Bazel's repository
      rules need not run from this script's working directory.
    """
    for env_var in [
        IREE_CUDA_TOOLKIT_ROOT_ENV_KEY,
        IREE_CUDA_DEPS_DIR_ENV_KEY,
        "CUDA_ROOT",
        "CUDA_HOME",
        "CUDA_PATH",
        "CUDA_TOOLKIT_ROOT_DIR",
    ]:
        env_path = os.environ.get(env_var)
        if not env_path:
            continue
        # Validate and return the same absolute path, so a relative or
        # "~"-prefixed value still names the directory that was checked.
        env_path = os.path.abspath(os.path.expanduser(env_path))
        if is_cuda_toolkit_root(env_path):
            return env_path, env_var
    nvcc_path = shutil.which("nvcc")
    if nvcc_path:
        # Assumes nvcc sits in <toolkit>/bin; try its grandparent directory.
        nvcc_toolkit_root = os.path.dirname(
            os.path.dirname(os.path.realpath(nvcc_path))
        )
        if is_cuda_toolkit_root(nvcc_toolkit_root):
            return nvcc_toolkit_root, "nvcc"
    return None, None
def detect_cuda_toolkit():
    """Report whether a CUDA toolkit usable by Bazel plugin builds exists.

    Returns:
      True when find_cuda_toolkit_root() locates a toolkit, else False.
    """
    toolkit_root, _source = find_cuda_toolkit_root()
    if toolkit_root is None:
        return False
    return True
def is_vulkan_sdk_root(path):
    """Return True if *path* contains the Vulkan loader library @vulkan_sdk uses.

    On Windows the library is ``Lib/vulkan-1.lib``; on every other host it is
    ``lib/libvulkan.so.1``. A falsy *path* is never an SDK root.
    """
    if not path:
        return False
    if platform.system() == "Windows":
        loader_library = os.path.join(path, "Lib", "vulkan-1.lib")
    else:
        loader_library = os.path.join(path, "lib", "libvulkan.so.1")
    return os.path.isfile(loader_library)
def normalize_vulkan_sdk_root(path):
    """Return the SDK subdirectory shape expected by LLVM's Vulkan SDK rule.

    *path* is expanded (``~``) and symlink-resolved. Then the path itself and
    its ``x86_64`` subdirectory are tried, in that order.

    Returns:
      The first candidate accepted by is_vulkan_sdk_root(), or None if *path*
      is falsy or neither candidate qualifies.
    """
    if not path:
        return None
    resolved = os.path.realpath(os.path.expanduser(path))
    candidates = (resolved, os.path.join(resolved, "x86_64"))
    return next(
        (candidate for candidate in candidates if is_vulkan_sdk_root(candidate)),
        None,
    )
def _version_sort_key(name):
    """Sort key that compares embedded digit runs numerically ("1.3.280" > "1.3.96")."""
    # With one capturing group, re.split alternates text (even indices) and
    # digit runs (odd indices), so two keys always compare str-to-str and
    # int-to-int at each position.
    return [
        int(part) if index % 2 else part
        for index, part in enumerate(re.split(r"([0-9]+)", name))
    ]


def find_vulkan_sdk_root():
    """Find a Vulkan SDK root usable by LLVM's Bazel @vulkan_sdk repository.

    Search order:
      1. The VULKAN_SDK environment variable. A warning is printed if it is set
         but does not normalize to a usable SDK.
      2. Two directories above the resolved ``vulkaninfo`` found on PATH.
      3. Well-known install roots (``C:\\VulkanSDK`` on Windows; ``/opt/vulkan``,
         ``/opt/VulkanSDK`` and ``~/VulkanSDK`` elsewhere). Each root is tried
         directly, then via its child directories with the highest version
         first.

    Returns:
      A ``(sdk_root, source)`` tuple, where ``source`` names how the SDK was
      found (the environment variable name, ``"vulkaninfo"``, or the candidate
      install root), or ``(None, None)`` if nothing usable is found.
    """
    env_vulkan_sdk = os.environ.get(VULKAN_SDK_ENV_KEY)
    vulkan_sdk_root = normalize_vulkan_sdk_root(env_vulkan_sdk)
    if vulkan_sdk_root:
        return vulkan_sdk_root, VULKAN_SDK_ENV_KEY
    if env_vulkan_sdk:
        print(
            f"WARNING: {VULKAN_SDK_ENV_KEY} is set but does not contain a "
            "Vulkan SDK layout usable by Bazel"
        )

    vulkaninfo_path = shutil.which("vulkaninfo")
    if vulkaninfo_path:
        # Assumes vulkaninfo sits in <sdk>/bin; try its grandparent directory.
        vulkaninfo_sdk_root = os.path.dirname(
            os.path.dirname(os.path.realpath(vulkaninfo_path))
        )
        vulkan_sdk_root = normalize_vulkan_sdk_root(vulkaninfo_sdk_root)
        if vulkan_sdk_root:
            return vulkan_sdk_root, "vulkaninfo"

    if platform.system() == "Windows":
        candidate_roots = [r"C:\VulkanSDK"]
    else:
        candidate_roots = [
            "/opt/vulkan",
            "/opt/VulkanSDK",
            os.path.expanduser("~/VulkanSDK"),
        ]

    for root in candidate_roots:
        vulkan_sdk_root = normalize_vulkan_sdk_root(root)
        if vulkan_sdk_root:
            return vulkan_sdk_root, root
        if not os.path.isdir(root):
            continue
        try:
            children = os.listdir(root)
        except OSError:
            # An unreadable install root is not fatal; keep searching.
            continue
        # Children are presumably versioned SDK directories (e.g. "1.3.280.0").
        # Compare version numbers numerically so the newest is tried first;
        # plain string sorting would rank "1.3.96.0" above "1.3.280.0".
        for child in sorted(children, key=_version_sort_key, reverse=True):
            vulkan_sdk_root = normalize_vulkan_sdk_root(os.path.join(root, child))
            if vulkan_sdk_root:
                return vulkan_sdk_root, root
    return None, None
def detect_rocm_toolkit():
    """Check if ROCm toolkit is available.

    Mirrors MLIR's search order: a non-empty ROCM_PATH, ROCM_ROOT, ROCM_HOME
    or HIP_PATH environment variable, otherwise ``hipcc`` on PATH. The paths
    themselves are not validated.

    Returns:
      True if any of those signals is present, else False.
    """
    rocm_env_vars = ("ROCM_PATH", "ROCM_ROOT", "ROCM_HOME", "HIP_PATH")
    if any(os.environ.get(name) for name in rocm_env_vars):
        return True
    return bool(shutil.which("hipcc"))
def detect_submodule(submodule_path):
    """Check if a submodule is initialized by looking for a marker file.

    Args:
      submodule_path: Path relative to the directory containing this script,
        which is expected to be the repo root (e.g. "third_party/torch-mlir").

    Returns:
      True if ``<submodule_path>/CMakeLists.txt`` exists, i.e. the submodule
      appears to have content.
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    marker_file = os.path.join(repo_root, submodule_path, "CMakeLists.txt")
    return os.path.isfile(marker_file)
def get_hal_driver_defaults():
    """Get HAL driver defaults matching CMake option(IREE_HAL_DRIVER_*) definitions.

    Returns:
      An ordered dict mapping driver name (e.g. "LOCAL_TASK") to its default
      enabled state. The key order determines the order drivers are emitted.
    """
    defaults_enabled = True  # Matches IREE_HAL_DRIVER_DEFAULTS in CMakeLists.txt
    host_system = platform.system()
    drivers = {}
    # Opt-in drivers.
    drivers["AMDGPU"] = False
    drivers["CUDA"] = False
    drivers["HIP"] = False
    # Drivers enabled by default.
    drivers["LOCAL_SYNC"] = defaults_enabled
    drivers["LOCAL_TASK"] = defaults_enabled
    drivers["METAL"] = host_system == "Darwin" and defaults_enabled
    drivers["NULL"] = False  # Special: OFF in tests, ON otherwise
    drivers["VULKAN"] = defaults_enabled and host_system not in ("Android", "iOS")
    return drivers
def env_var_to_bazel_tag(name):
    """Convert env var name to Bazel tag format.

    The IREE_HAL_DRIVER_ prefix, if present, is dropped. The rest is lowercased
    and underscores become hyphens:
      IREE_HAL_DRIVER_LOCAL_TASK -> local-task
      VULKAN_SPIRV -> vulkan-spirv
    """
    driver_prefix = "IREE_HAL_DRIVER_"
    tag = name[len(driver_prefix) :] if name.startswith(driver_prefix) else name
    return tag.replace("_", "-").lower()
def get_plugin_defaults():
    """Get compiler plugin defaults matching CMake option definitions.

    Returns:
      An ordered dict mapping plugin_id -> (default_enabled, env_var_name,
      can_build). ``env_var_name`` is None for plugins without an override
      variable. ``can_build`` says whether the plugin can be built on this
      system (CUDA/ROCm need a detected toolkit; StableHLO/Torch input need
      their submodules). The key order determines the emitted plugin order.
    """
    cuda_ok = detect_cuda_toolkit()
    rocm_ok = detect_rocm_toolkit()
    stablehlo_ok = detect_submodule("third_party/stablehlo")
    torch_mlir_ok = detect_submodule("third_party/torch-mlir")
    table = [
        # Input plugins (TOSA is part of MLIR, so it needs no submodule).
        ("input_stablehlo", True, "IREE_INPUT_STABLEHLO", stablehlo_ok),
        ("input_tosa", True, "IREE_INPUT_TOSA", True),
        ("input_torch", False, "IREE_INPUT_TORCH", torch_mlir_ok),
        # Target plugins.
        ("hal_target_cuda", False, "IREE_TARGET_BACKEND_CUDA", cuda_ok),
        ("hal_target_llvm_cpu", True, "IREE_TARGET_BACKEND_LLVM_CPU", True),
        ("hal_target_local", True, "IREE_TARGET_BACKEND_LOCAL", True),
        ("hal_target_metal_spirv", True, "IREE_TARGET_BACKEND_METAL_SPIRV", True),
        ("hal_target_rocm", False, "IREE_TARGET_BACKEND_ROCM", rocm_ok),
        ("hal_target_vmvx", True, "IREE_TARGET_BACKEND_VMVX", True),
        ("hal_target_vulkan_spirv", True, "IREE_TARGET_BACKEND_VULKAN_SPIRV", True),
        # Sample plugins (always buildable, no override variable).
        ("example", True, None, True),
        ("simple_io_sample", True, None, True),
    ]
    return {
        plugin_id: (enabled_by_default, env_name, buildable)
        for plugin_id, enabled_by_default, env_name, buildable in table
    }
def parse_plugin_spec(spec, plugin_defaults):
    """Parse a plugin specification that may include 'all' and exclusions.

    Supports:
      - "all" - all plugins that can be built on this system
      - "all,-plugin1,-plugin2" - all except specified plugins
      - "plugin1,plugin2" - explicit list (duplicates are dropped; the first
        occurrence keeps its position)

    Entries are case-insensitive and surrounding whitespace is ignored.

    Args:
      spec: Raw specification string (the IREE_COMPILER_PLUGINS value), or None.
      plugin_defaults: Mapping plugin_id -> (default, env_var, can_build), as
        returned by get_plugin_defaults().

    Returns:
      A tuple (plugin_list, used_all). used_all indicates whether "all"
      expansion was performed. Returns (None, False) if no spec is provided,
      and ([], False) if the spec holds only separators/whitespace. "all"
      results are sorted; explicit lists keep the given order.

    Exits the process with status 1, after printing an ERROR line, on an
    unknown plugin, a non-exclusion entry after "all", or an explicitly
    requested plugin that cannot be built.
    """
    if not spec:
        return None, False
    parts = [p.strip().lower() for p in spec.split(",") if p.strip()]
    if not parts:
        return [], False

    if parts[0] == "all":
        # Start with all buildable plugins.
        enabled = {
            plugin_id
            for plugin_id, (_, _, can_build) in plugin_defaults.items()
            if can_build
        }
        # Process exclusions.
        for part in parts[1:]:
            if not part.startswith("-"):
                print(
                    f"ERROR: Expected plugin exclusion after 'all' but got '{part}'. "
                    f"Use 'all,-{part}' to exclude it."
                )
                sys.exit(1)
            plugin_to_exclude = part[1:]
            if plugin_to_exclude not in plugin_defaults:
                print(f"ERROR: Unknown plugin in exclusion: {plugin_to_exclude}")
                sys.exit(1)
            enabled.discard(plugin_to_exclude)
        return sorted(enabled), True

    # Explicit list - validate plugins exist and can be built.
    validated = []
    for plugin_id in parts:
        if plugin_id in validated:
            # Already validated; a repeat would otherwise appear twice in the
            # emitted --iree_compiler_plugins flag.
            continue
        if plugin_id not in plugin_defaults:
            print(f"ERROR: Unknown plugin: {plugin_id}")
            sys.exit(1)
        _, _, can_build = plugin_defaults[plugin_id]
        if not can_build:
            print(
                f"ERROR: Plugin '{plugin_id}' requested but cannot be built "
                f"(missing toolkit/submodule). Either install prerequisites or "
                f"remove from IREE_COMPILER_PLUGINS."
            )
            sys.exit(1)
        validated.append(plugin_id)
    return validated, False
def write_iree_plugin_options(bazelrc):
    """Write compiler plugin configuration to bazelrc.

    The plugin set comes from IREE_COMPILER_PLUGINS when it is set (see
    parse_plugin_spec). Otherwise each plugin's own environment variable is
    read as a CMake-style bool, falling back to its default. Enabled plugins
    that cannot be built on this host are skipped with a warning. When the
    CUDA target is enabled, the detected toolkit root is also forwarded to
    Bazel as a --repo_env entry.

    Args:
      bazelrc: Writable text stream receiving the bazelrc lines.
    """
    plugin_defaults = get_plugin_defaults()

    # Check for IREE_COMPILER_PLUGINS env var with "all" support.
    plugins_spec = os.environ.get("IREE_COMPILER_PLUGINS")
    parsed_plugins, used_all = parse_plugin_spec(plugins_spec, plugin_defaults)
    if parsed_plugins is not None:
        # Explicit specification via IREE_COMPILER_PLUGINS.
        enabled_plugins = parsed_plugins
        if used_all:
            print(
                f"IREE_COMPILER_PLUGINS=all resolved to: {', '.join(enabled_plugins)}"
            )
        else:
            print(f"IREE_COMPILER_PLUGINS set to: {', '.join(enabled_plugins)}")
    else:
        # Standard per-plugin env var processing.
        enabled_plugins = []
        for plugin_id, (default, env_var, can_build) in plugin_defaults.items():
            env_value = os.environ.get(env_var) if env_var else None
            enabled = (
                cmake_bool_is_true(env_value) if env_value is not None else default
            )
            if not enabled:
                continue
            if can_build:
                enabled_plugins.append(plugin_id)
            else:
                # Input plugins are gated on submodules, not toolkits.
                print(
                    f"WARNING: {plugin_id} enabled but toolkit/submodule not "
                    f"detected, skipping"
                )

    # Write --iree_compiler_plugins flag (controls what gets built and linked).
    # Always emit the flag, even for empty list, so Bazel doesn't fall back to defaults.
    print(f'build --iree_compiler_plugins={",".join(enabled_plugins)}', file=bazelrc)

    # Persist the CUDA repository environment variable so Bazel sees the same
    # toolkit root that plugin selection detected.
    if "hal_target_cuda" in enabled_plugins:
        cuda_toolkit_root, cuda_source = find_cuda_toolkit_root()
        if cuda_toolkit_root is not None:
            if cuda_source == IREE_CUDA_DEPS_DIR_ENV_KEY:
                cuda_env_key = IREE_CUDA_DEPS_DIR_ENV_KEY
            else:
                cuda_env_key = IREE_CUDA_TOOLKIT_ROOT_ENV_KEY
            # bazelrc lines are tokenized, so quote the option: toolkit paths
            # may contain spaces or backslashes (typical on Windows). Ordinary
            # POSIX paths are emitted unchanged.
            repo_env = shlex.quote(f"--repo_env={cuda_env_key}={cuda_toolkit_root}")
            print(f"common {repo_env}", file=bazelrc)
def write_vulkan_sdk_options(bazelrc):
    """Write Vulkan SDK repository configuration to bazelrc when available.

    Nothing is written when find_vulkan_sdk_root() finds no usable SDK.
    Otherwise the detection source is reported on stdout and the SDK root is
    forwarded to Bazel as a VULKAN_SDK --repo_env entry.

    Args:
      bazelrc: Writable text stream receiving the bazelrc lines.
    """
    vulkan_sdk_root, vulkan_source = find_vulkan_sdk_root()
    if not vulkan_sdk_root:
        return
    print(f"Detected Vulkan SDK at {vulkan_sdk_root} via {vulkan_source}")
    # bazelrc lines are tokenized, so quote the option: SDK paths may contain
    # spaces or backslashes (typical on Windows). Ordinary POSIX paths are
    # emitted unchanged.
    repo_env = shlex.quote(f"--repo_env={VULKAN_SDK_ENV_KEY}={vulkan_sdk_root}")
    print(f"common {repo_env}", file=bazelrc)
def write_iree_hal_driver_options(bazelrc):
    """Write HAL driver configuration to bazelrc.

    Each driver's default (see get_hal_driver_defaults) can be overridden by an
    IREE_HAL_DRIVER_<NAME> environment variable holding a CMake-style bool.
    The enabled drivers are written as Bazel tags (e.g. "local-task") for both
    build and test.

    Args:
      bazelrc: Writable text stream receiving the bazelrc lines.
    """
    # Get defaults matching CMake.
    hal_drivers = get_hal_driver_defaults()

    # Apply environment overrides.
    enabled_drivers = []
    for driver, default in hal_drivers.items():
        env_var = f"IREE_HAL_DRIVER_{driver}"
        env_value = os.environ.get(env_var)
        enabled = cmake_bool_is_true(env_value) if env_value is not None else default
        if enabled:
            enabled_drivers.append(env_var_to_bazel_tag(env_var))

    # Write --iree_runtime_drivers flag (controls what gets built and linked).
    # Always emit it, even for an empty list, so disabling every driver is
    # honored instead of Bazel falling back to defaults. This is the same
    # policy write_iree_plugin_options applies to --iree_compiler_plugins.
    drivers_value = ",".join(enabled_drivers)
    print(f"build --iree_runtime_drivers={drivers_value}", file=bazelrc)
    print(f"test --iree_runtime_drivers={drivers_value}", file=bazelrc)
# Script entry: write the generated config to the path given as the first
# argument, or to configured.bazelrc next to this script by default.
local_bazelrc = (
    sys.argv[1]
    if len(sys.argv) > 1
    else os.path.join(os.path.dirname(__file__), "configured.bazelrc")
)
with open(local_bazelrc, "wt") as bazelrc:
    # Sections are written in this fixed order.
    for write_section in (
        write_platform,
        write_iree_hal_driver_options,
        write_vulkan_sdk_options,
        write_iree_plugin_options,
    ):
        write_section(bazelrc)
print("Wrote", local_bazelrc)