# blob: 3e9e41d5b9ee2e46056dff3f412b4afa3eb38bbc
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# End-to-end attention tests.
load("//build_tools/bazel:iree_e2e_generated_runner_test.bzl", "iree_generated_e2e_runner_test")
# Package-level settings applied to the targets declared in this BUILD file:
# turns on the "layering_check" feature and declares "notice" licensing.
package(
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)
# Python generator script, referenced by label as the `generator` of the
# attention runner tests declared further down in this file.
py_binary(
    name = "generate_e2e_attention_tests",
    srcs = ["generate_e2e_attention_tests.py"],
)
###########################################################################
##
## LLVMCPU backend
##
###########################################################################
# Default CPU backend.
# Declares one runner test per (element type, shape class) combination.
# The same element type is forwarded to the query, key and value flags, which
# is why it appears three times in each target name, e.g.
# "e2e_attention_cpu_f16_f16_f16_small".
[
    iree_generated_e2e_runner_test(
        name = "e2e_attention_cpu_%s_%s_%s_%s" % (elem_type, elem_type, elem_type, shape_class),
        generator = ":generate_e2e_attention_tests",
        generator_args = [
            "--query_type=%s" % elem_type,
            "--key_type=%s" % elem_type,
            "--value_type=%s" % elem_type,
            "--shapes=%s" % shape_class,
        ],
        tags = [
            "hostonly",
            "local",
        ],
        # Each entry is a (compiler target backend, runtime driver) pair.
        target_backends_and_drivers = [
            ("llvm-cpu", "local-task"),
        ],
        target_cpu_features_variants = ["default"],
        test_runner = "//tools/testing/e2e:iree-e2e-attention-test",
        test_type = "attention",
    )
    for elem_type in [
        "f16",
    ]
    for shape_class in [
        "small",
        "medium",
        "large",
    ]
]