blob: 8944c24ce379a7f2b3011f6f04043aeed0912c02 [file]
# Copyright 2026 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Generates CTS test binaries for a HAL driver.
Usage in a driver's cts/BUILD.bazel:
load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library")
load("//build_tools/bazel:iree_hal_cts_test_suite.bzl", "iree_hal_cts_test_suite")
iree_runtime_cc_library(
name = "backends",
testonly = True,
srcs = ["backends.cc"],
deps = [...],
alwayslink = True,
)
iree_hal_cts_test_suite(
backends_lib = ":backends",
executable_formats = {
"vmvx": {
"target_device": "local",
"flags": ["--iree-hal-local-target-device-backends=vmvx"],
"identifier": "iree_cts_testdata_vmvx",
"backend_name": "local_task",
"format_string": '"vmvx-bytecode-fb"',
},
},
testdata = "//runtime/src/iree/hal/cts/testdata:executable_srcs",
)
"""
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
load("//build_tools/bazel:build_defs.oss.bzl", "iree_runtime_cc_library", "iree_runtime_cc_test")
load("//build_tools/bazel:iree_hal_executable.bzl", "iree_hal_executables")
# Test categories that need no compiled executables. Each entry is a
# (test binary name suffix, aggregate test library label) pair; one test
# binary per entry is generated for every driver.
_NON_EXECUTABLE_SUITES = [
    (
        "buffer_tests",
        "//runtime/src/iree/hal/cts/buffer:all_tests",
    ),
    (
        "command_buffer_tests",
        "//runtime/src/iree/hal/cts/command_buffer:all_tests",
    ),
    (
        "core_tests",
        "//runtime/src/iree/hal/cts/core:all_tests",
    ),
    (
        "file_tests",
        "//runtime/src/iree/hal/cts/file:all_tests",
    ),
    (
        "queue_tests",
        "//runtime/src/iree/hal/cts/queue:all_tests",
    ),
]

# Test categories that exercise compiled executables. Same
# (suffix, aggregate library label) pair layout as above; these binaries are
# only generated when testdata libraries are available.
_EXECUTABLE_SUITES = [
    (
        "dispatch_tests",
        "//runtime/src/iree/hal/cts/command_buffer:all_dispatch_tests",
    ),
    (
        "executable_tests",
        "//runtime/src/iree/hal/cts/core:all_executable_tests",
    ),
]
def _camel_case(snake_str):
    """Returns snake_str rewritten from snake_case to CamelCase.

    Every underscore-separated segment is capitalized (first character upper,
    remainder lower) and the segments are concatenated without separators,
    e.g. 'llvm_cpu' -> 'LlvmCpu'.
    """
    return "".join([segment.capitalize() for segment in snake_str.split("_")])
def _cts_testdata_gen_impl(ctx):
    """Expands the testdata_format.cc.tpl template with build setting resolution.

    Behaves like expand_template, except that every substitution value first
    has its "{PLACEHOLDER}" references replaced by the value of the build
    setting mapped to that placeholder in flag_values. Entries whose target
    does not provide BuildSettingInfo (file targets) are skipped here; they
    only matter for compiler flag resolution in iree_hal_executables, not for
    C++ template expansion.
    """

    # (pattern, replacement) pairs, in flag_values order, for build settings only.
    replacements = [
        ("{%s}" % placeholder, target[BuildSettingInfo].value)
        for target, placeholder in ctx.attr.flag_values.items()
        if BuildSettingInfo in target
    ]

    resolved = {}
    for key, text in ctx.attr.substitutions.items():
        for pattern, replacement in replacements:
            text = text.replace(pattern, replacement)
        resolved[key] = text

    ctx.actions.expand_template(
        template = ctx.file.template,
        output = ctx.outputs.out,
        substitutions = resolved,
    )
    return [DefaultInfo(files = depset([ctx.outputs.out]))]
# Private rule that writes the per-format testdata registration source (.cc)
# by expanding a template; see _cts_testdata_gen_impl for the expansion logic.
_cts_testdata_gen = rule(
    implementation = _cts_testdata_gen_impl,
    attrs = {
        # Template to expand; exactly one file.
        "template": attr.label(mandatory = True, allow_single_file = True),
        # Generated output file.
        "out": attr.output(mandatory = True),
        # "{KEY}" -> value substitutions for the template. Values may contain
        # "{PLACEHOLDER}" references that are resolved from flag_values.
        "substitutions": attr.string_dict(),
        # Target -> placeholder name. Only targets providing BuildSettingInfo
        # are used by the implementation; allow_files lets file targets be
        # passed through as well (they are skipped during expansion).
        "flag_values": attr.label_keyed_string_dict(
            allow_files = True,
        ),
    },
)
def iree_hal_cts_testdata(
        format_name,
        target_device,
        identifier,
        backend_name,
        format_string,
        testdata,
        flags = [],
        flag_values = {},
        data = [],
        testonly = True,
        **kwargs):
    """Compiles CTS test executables and creates a testdata registration library.

    Use this directly when multiple iree_hal_cts_test_suite() calls need to
    share the same compiled executables (e.g., CUDA graph/stream variants).
    For single-variant drivers, use executable_formats in iree_hal_cts_test_suite
    instead -- it calls this internally.

    Creates these targets in the current package:
      testdata_<format_name>:      compiled executables (iree_hal_executables).
      testdata_<format_name>_gen:  generated testdata_<format_name>.cc.
      testdata_<format_name>_lib:  alwayslink registration library.

    Returns the label of the generated testdata library (e.g., ":testdata_cuda_lib").

    Args:
      format_name: Short name (e.g., "vmvx", "cuda", "hip"). Also used to
        derive target names and the generated C function name.
      target_device: Target device for iree-compile.
      identifier: C identifier for the embedded data.
      backend_name: Backend name for CtsRegistry registration.
      format_string: C expression for the format string. May contain
        {PLACEHOLDER} template variables resolved from flag_values.
      testdata: Filegroup label for MLIR test sources (e.g.,
        "//runtime/src/iree/hal/cts/testdata:executable_srcs").
      flags: Compiler flags. May contain {PLACEHOLDER} template variables
        resolved from flag_values.
      flag_values: Dict mapping placeholder names to target labels.
        See iree_hal_executable() for details.
      data: Additional files for the compile action inputs.
      testonly: Defaults to True.
      **kwargs: Forwarded to iree_hal_executables() only. The generated .cc
        rule and the registration library do not receive them.
    """
    testdata_name = "testdata_%s" % format_name

    # iree_hal_executables() is a macro that inverts flag_values internally,
    # so pass the user-facing form directly.
    iree_hal_executables(
        name = testdata_name,
        srcs = [testdata],
        target_device = target_device,
        flags = flags,
        flag_values = flag_values,
        data = data,
        identifier = identifier,
        testonly = testonly,
        **kwargs
    )

    gen_cc_name = "%s_gen" % testdata_name
    gen_cc_file = "%s.cc" % testdata_name

    # Workspace-relative include path for "<testdata_name>.h". In the root
    # package native.package_name() is "", so no directory component is
    # prepended; the old "%s/%s.h" form produced an absolute "/..." path there.
    pkg = native.package_name()
    if pkg:
        header_path = "%s/%s.h" % (pkg, testdata_name)
    else:
        header_path = "%s.h" % testdata_name
    func_name = _camel_case(format_name)

    # Invert to {"//label": "PLACEHOLDER"} for label_keyed_string_dict.
    # File targets pass through to the rule but are ignored during template
    # expansion (only BuildSettingInfo entries apply to format_string).
    rule_flag_values = {v: k for k, v in flag_values.items()}
    _cts_testdata_gen(
        name = gen_cc_name,
        template = "//runtime/src/iree/hal/cts/util:testdata_format.cc.tpl",
        out = gen_cc_file,
        substitutions = {
            "{HEADER_PATH}": header_path,
            "{FORMAT_FUNC_NAME}": func_name,
            "{IDENTIFIER}": identifier,
            "{FORMAT_VAR_NAME}": "%s_format" % format_name,
            "{BACKEND_NAME}": backend_name,
            "{FORMAT_NAME}": format_name,
            "{FORMAT_STRING}": format_string,
        },
        flag_values = rule_flag_values,
        testonly = testonly,
    )

    testdata_lib_name = "%s_lib" % testdata_name

    # alwayslink so the generated registration code is kept even though no
    # test references its symbols directly.
    iree_runtime_cc_library(
        name = testdata_lib_name,
        testonly = testonly,
        srcs = [gen_cc_file],
        deps = [
            ":%s" % testdata_name,
            "//runtime/src/iree/hal/cts/util:registry",
        ],
        alwayslink = True,
    )
    return ":%s" % testdata_lib_name
def iree_hal_cts_test_suite(
        backends_lib,
        executable_formats = {},
        testdata_libs = [],
        testdata = None,
        flag_values = {},
        name = "",
        args = [],
        tags = [],
        testonly = True,
        **kwargs):
    """Generates CTS test binaries for a HAL driver.

    Creates non-executable test binaries (core, buffer, command_buffer, queue,
    file) that link against the provided backends library. If executable_formats
    (or testdata_libs) is provided, also creates executable and dispatch test
    binaries linking the corresponding testdata libraries.

    Args:
      backends_lib: Label of the hand-written backends.cc library that
        registers the driver with CtsRegistry.
      executable_formats: Dict mapping format names to config dicts. Each
        config dict has keys:
          target_device: (required) Target device for iree-compile
            (e.g., "local").
          identifier: (required) C identifier for the embedded data (used to
            derive the _create() function name in the generated header).
          backend_name: (required) Backend name string for CtsRegistry
            registration.
          format_string: (required) C expression for the executable format
            string (e.g., '"vmvx-bytecode-fb"' or '"embedded-elf-" IREE_ARCH').
          flags: (optional) List of compiler flags.
          data: (optional) Additional files for the compile action inputs.
        Testdata targets are named after the format name (testdata_<format>)
        and are NOT prefixed by `name`, so two suites in one package cannot
        both compile the same format; use iree_hal_cts_testdata() plus
        testdata_libs for that. Mutually exclusive with testdata_libs.
      testdata_libs: Pre-built testdata library labels for multi-variant
        drivers. When multiple iree_hal_cts_test_suite() calls share
        the same compiled executables (e.g., CUDA graph/stream variants),
        define the testdata targets once and pass them here instead of
        using executable_formats. Mutually exclusive with executable_formats.
      testdata: Filegroup label for MLIR test sources (e.g.,
        "//runtime/src/iree/hal/cts/testdata:executable_srcs").
        Required when executable_formats is provided.
      flag_values: Dict mapping placeholder names to target labels
        (e.g. string_flag build settings). Forwarded to iree_hal_cts_testdata
        for every entry of executable_formats.
      name: Optional name prefix for generated test targets. When empty,
        targets are named directly (core_tests, buffer_tests, etc.). When set,
        targets are prefixed (stream_core_tests, graph_buffer_tests, etc.).
        Use a prefix for multi-variant drivers (e.g., CUDA graph/stream).
      args: Runtime arguments passed to all test binaries.
      tags: Additional tags for test targets.
      testonly: Applied to the testdata targets created from
        executable_formats. Defaults to True.
      **kwargs: Forwarded to the test targets and to iree_hal_cts_testdata
        (e.g., target_compatible_with).
    """

    # Build the name prefix: "name_" if set, "" otherwise.
    prefix = ("%s_" % name) if name else ""

    if executable_formats and testdata_libs:
        fail("iree_hal_cts_test_suite: executable_formats and testdata_libs are mutually exclusive")
    if executable_formats and not testdata:
        fail("iree_hal_cts_test_suite: testdata is required when executable_formats is provided")

    # Either the pre-built testdata libs, or libs compiled from the formats.
    all_testdata_libs = list(testdata_libs)
    for format_name, config in executable_formats.items():
        missing = [
            key
            for key in ("target_device", "identifier", "backend_name", "format_string")
            if key not in config
        ]
        if missing:
            fail("iree_hal_cts_test_suite: executable_formats[%r] is missing required key(s): %s" % (
                format_name,
                ", ".join(missing),
            ))
        all_testdata_libs.append(iree_hal_cts_testdata(
            format_name = format_name,
            target_device = config["target_device"],
            identifier = config["identifier"],
            backend_name = config["backend_name"],
            format_string = config["format_string"],
            testdata = testdata,
            flags = config.get("flags", []),
            flag_values = flag_values,
            data = config.get("data", []),
            testonly = testonly,
            **kwargs
        ))

    # Common deps for all test binaries.
    common_deps = [
        backends_lib,
        "//runtime/src/iree/hal/cts/util:registry",
        "//runtime/src/iree/hal/cts/util:test_base",
        "//runtime/src/iree/testing:gtest",
    ]

    # Non-executable test binaries.
    for suffix, test_lib in _NON_EXECUTABLE_SUITES:
        iree_runtime_cc_test(
            name = "%s%s" % (prefix, suffix),
            srcs = ["//runtime/src/iree/hal/cts/util:test_main.cc"],
            args = args,
            deps = common_deps + [test_lib],
            tags = tags,
            **kwargs
        )

    # Executable-dependent test binaries (only if testdata is available).
    if all_testdata_libs:
        for suffix, test_lib in _EXECUTABLE_SUITES:
            iree_runtime_cc_test(
                name = "%s%s" % (prefix, suffix),
                srcs = ["//runtime/src/iree/hal/cts/util:test_main.cc"],
                args = args,
                deps = common_deps + all_testdata_libs + [test_lib],
                tags = tags,
                **kwargs
            )