Revert "Add e2e test suite for the Attention - CPU Backend" (#18302)
Reverts iree-org/iree#17751. A few of the new tests are failing on
various platforms:
* Timeouts (after 60 seconds) in `iree/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_large_llvm-cpu_local-task` on GitHub-hosted Windows and macOS runners:
  * https://github.com/iree-org/iree/actions/runs/10468974350/job/28990992473#step:8:2477
  * https://github.com/iree-org/iree/actions/runs/10468947894/job/28990909629#step:9:3076
```
1529/1568 Test #969: iree/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_large_llvm-cpu_local-task .............................***Timeout 60.07 sec
---
TEST[attention_2_2048_256_512_128_dtype_f16_f16_f16_f16_2_2048_256_512_128_256_1.0_0]
---
Attention shape (BATCHxMxK1xK2xN): 2x2048x256x512x256x128
```
* Compilation error on arm64 (a reduced reproduction sketch follows the log below): https://github.com/iree-org/iree/actions/runs/10468944505/job/28990909321#step:4:9815:
```
[415/1150] Generating /work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.vmfb from e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.mlir
FAILED: tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.vmfb /work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.vmfb
cd /work/build-arm64/tests/e2e/attention && /work/build-arm64/tools/iree-compile --output-format=vm-bytecode --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=llvm-cpu /work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.mlir -o /work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.vmfb --iree-hal-executable-object-search-path=\"/work/build-arm64\" --iree-llvmcpu-embedded-linker-path=\"/work/build-arm64/llvm-project/bin/lld\" --iree-llvmcpu-wasm-linker-path=\"/work/build-arm64/llvm-project/bin/lld\"
/work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.mlir:4:14: error: Yield operand #2 is not equivalent to the corresponding iter bbArg
%result1 = iree_linalg_ext.attention {
^
/work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.mlir:1:1: note: called from
func.func @attention_2_1024_128_256_64_dtype_f16_f16_f16_f16(%query: tensor<2x1024x128xf16>, %key: tensor<2x256x128xf16>, %value: tensor<2x256x64xf16>, %scale: f32) -> tensor<2x1024x64xf16> {
^
/work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.mlir:4:14: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"llvm-cpu", "embedded-elf-arm_64", {cpu = "generic", cpu_features = "+reserve-x18", data_layout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128-Fn32", native_vector_size = 16 : i64, target_triple = "aarch64-unknown-unknown-eabi-elf"}>
%result1 = iree_linalg_ext.attention {
^
/work/build-arm64/tests/e2e/attention/e2e_attention_cpu_f16_f16_f16_medium_llvm-cpu_local-task_attention.mlir:1:1: note: called from
func.func @attention_2_1024_128_256_64_dtype_f16_f16_f16_f16(%query: tensor<2x1024x128xf16>, %key: tensor<2x256x128xf16>, %value: tensor<2x256x64xf16>, %scale: f32) -> tensor<2x1024x64xf16> {
^
failed to translate executables
```
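Both failures should be reproducible in isolation. The timing-out large test can be rerun from a CMake build directory with `ctest -R e2e_attention_cpu_f16_f16_f16_large --timeout 600` to check whether it is merely slow rather than hung. For the arm64 compilation error, the input is just the generated medium-shape test function. The sketch below is reconstructed from the `generate_e2e_attention_tests.py` template deleted in this revert (using the "medium" shape: batch=2, m=1024, k1=128, k2=256, n=64), not copied from a CI artifact, so formatting is approximate. Compiling it with the `iree-compile` invocation quoted in the log (the linker and search-path flags can be dropped; adding `--iree-llvmcpu-target-triple=aarch64-unknown-unknown-eabi-elf` is assumed to stand in for the CI build's arm64 configuration) should hit the same `Yield operand #2` verifier error:

```
// Reconstruction of the generated "medium" test function from
// generate_e2e_attention_tests.py; not a verbatim CI artifact.
func.func @attention_2_1024_128_256_64_dtype_f16_f16_f16_f16(%query: tensor<2x1024x128xf16>, %key: tensor<2x256x128xf16>, %value: tensor<2x256x64xf16>, %scale: f32) -> tensor<2x1024x64xf16> {
  %result0 = tensor.empty(): tensor<2x1024x64xf16>
  %scale_f16 = arith.truncf %scale : f32 to f16
  %result1 = iree_linalg_ext.attention {
      indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,
                       affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,
                       affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,
                       affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]}
      ins(%query, %key, %value, %scale_f16: tensor<2x1024x128xf16>, tensor<2x256x128xf16>, tensor<2x256x64xf16>, f16)
      outs(%result0: tensor<2x1024x64xf16>) -> tensor<2x1024x64xf16>
  return %result1: tensor<2x1024x64xf16>
}
```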
diff --git a/tests/e2e/attention/BUILD.bazel b/tests/e2e/attention/BUILD.bazel
deleted file mode 100644
index 3e9e41d..0000000
--- a/tests/e2e/attention/BUILD.bazel
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2024 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# End-to-end attention tests.
-
-load("//build_tools/bazel:iree_e2e_generated_runner_test.bzl", "iree_generated_e2e_runner_test")
-
-package(
- features = ["layering_check"],
- licenses = ["notice"], # Apache 2.0
-)
-
-py_binary(
- name = "generate_e2e_attention_tests",
- srcs = ["generate_e2e_attention_tests.py"],
-)
-
-###########################################################################
-##
-## LLVMCPU backend
-##
-###########################################################################
-
-# Default CPU backend.
-[iree_generated_e2e_runner_test(
- name = "e2e_attention_cpu_%s_%s_%s_%s" % (dtype, dtype, dtype, size),
- generator = ":generate_e2e_attention_tests",
- generator_args = [
- "--query_type=%s" % dtype,
- "--key_type=%s" % dtype,
- "--value_type=%s" % dtype,
- "--shapes=%s" % size,
- ],
- tags = [
- "hostonly",
- "local",
- ],
- target_backends_and_drivers = [
- ("llvm-cpu", "local-task"),
- ],
- target_cpu_features_variants = ["default"],
- test_runner = "//tools/testing/e2e:iree-e2e-attention-test",
- test_type = "attention",
-) for dtype in [
- "f16",
-] for size in [
- "small",
- "medium",
- "large",
-]]
diff --git a/tests/e2e/attention/CMakeLists.txt b/tests/e2e/attention/CMakeLists.txt
deleted file mode 100644
index f793784..0000000
--- a/tests/e2e/attention/CMakeLists.txt
+++ /dev/null
@@ -1,88 +0,0 @@
-################################################################################
-# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
-# tests/e2e/attention/BUILD.bazel #
-# #
-# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
-# CMake-only content. #
-# #
-# To disable autogeneration for this file entirely, delete this header. #
-################################################################################
-
-iree_add_all_subdirs()
-
-iree_generated_e2e_runner_test(
- NAME
- e2e_attention_cpu_f16_f16_f16_small
- TEST_TYPE
- attention
- GENERATOR
- "generate_e2e_attention_tests.py"
- GENERATOR_ARGS
- "--query_type=f16"
- "--key_type=f16"
- "--value_type=f16"
- "--shapes=small"
- TEST_RUNNER
- iree_tools_testing_e2e_iree-e2e-attention-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- LABELS
- "hostonly"
- "local"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
-)
-
-iree_generated_e2e_runner_test(
- NAME
- e2e_attention_cpu_f16_f16_f16_medium
- TEST_TYPE
- attention
- GENERATOR
- "generate_e2e_attention_tests.py"
- GENERATOR_ARGS
- "--query_type=f16"
- "--key_type=f16"
- "--value_type=f16"
- "--shapes=medium"
- TEST_RUNNER
- iree_tools_testing_e2e_iree-e2e-attention-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- LABELS
- "hostonly"
- "local"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
-)
-
-iree_generated_e2e_runner_test(
- NAME
- e2e_attention_cpu_f16_f16_f16_large
- TEST_TYPE
- attention
- GENERATOR
- "generate_e2e_attention_tests.py"
- GENERATOR_ARGS
- "--query_type=f16"
- "--key_type=f16"
- "--value_type=f16"
- "--shapes=large"
- TEST_RUNNER
- iree_tools_testing_e2e_iree-e2e-attention-test
- TARGET_BACKENDS
- "llvm-cpu"
- DRIVERS
- "local-task"
- LABELS
- "hostonly"
- "local"
- TARGET_CPU_FEATURES_VARIANTS
- "default"
-)
-
-### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py
deleted file mode 100644
index f567a16..0000000
--- a/tests/e2e/attention/generate_e2e_attention_tests.py
+++ /dev/null
@@ -1,499 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 The IREE Authors
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-"""Generator for e2e attention tests.
-"""
-
-import argparse
-import enum
-import dataclasses
-import typing
-import math
-
-
-# Data type of query entries. The string values must match MLIR data types.
-@enum.unique
-class QueryElemTypeId(enum.Enum):
- NONE = ""
- F16 = "f16"
-
-
-# Data type of key entries. The string values must match MLIR data types.
-@enum.unique
-class KeyElemTypeId(enum.Enum):
- NONE = ""
- F16 = "f16"
-
-
-# Data type of value entries. The string values must match MLIR data types.
-@enum.unique
-class ValueElemTypeId(enum.Enum):
- NONE = ""
- F16 = "f16"
-
-
-# Data type of result entries. The string values must match MLIR data types.
-@enum.unique
-class ResultElemTypeId(enum.Enum):
- NONE = ""
- F16 = "f16"
-
-
-# Enumerates the collections of shapes that we can generate tests for.
-# The values are the accepted values for the --shapes= flag.
-@enum.unique
-class ShapesId(enum.Enum):
- SMALL = "small"
- MEDIUM = "medium"
- LARGE = "large"
-
-
-# batch: Batch dimension
-# m: M dimension of first and second matmul
-# n: N dimension of second matmul
-# k1: K dimension of first matmul
-# k2: K dimension of second matmul
-@dataclasses.dataclass
-class TestShapeAndScale:
- batch: int
- m: int
- k1: int
- k2: int
- n: int
- scale: float
-
-
-# Returns the list of TestShapeAndScale's to use for the collection of shapes
-# identified by shapes_id.
-def get_test_shapes(shapes_id: ShapesId):
- if shapes_id == ShapesId.SMALL:
- return [
- TestShapeAndScale(batch=2, m=512, k1=64, k2=128, n=32, scale=1.0),
- ]
- if shapes_id == ShapesId.MEDIUM:
- return [
- TestShapeAndScale(batch=2, m=1024, k1=128, k2=256, n=64, scale=1.0),
- ]
- if shapes_id == ShapesId.LARGE:
- return [
- TestShapeAndScale(batch=2, m=2048, k1=256, k2=512, n=128, scale=1.0),
- ]
-
- raise ValueError(shapes_id)
-
-
-# Determines the shape of input and kernel tensors.
-@dataclasses.dataclass
-class TestInputTensorShapes:
- batch: int
- m: int
- k1: int
- k2: int
- n: int
- scale: float
-
-
-# Helper for generate_function. Generates TestInputTensorShapes, i.e.
-# converts from the shape dimensions in TestShapeAndScale to the set of
-# shapes to be used in a test function's input tensors.
-def generate_shapes_and_scale(shape: TestShapeAndScale):
- batch = shape.batch
- m = shape.m
- k1 = shape.k1
- k2 = shape.k2
- n = shape.n
- scale = shape.scale
-
- shapes_scale = TestInputTensorShapes(
- batch=batch,
- m=m,
- k1=k1,
- k2=k2,
- n=n,
- scale=scale,
- )
- return shapes_scale
-
-
-# Helper to return query, key, value, and result shapes based on the attention parameters.
-def get_tensor_shapes(
- shapes_scale: TestShapeAndScale,
-):
- batch = shapes_scale.batch
- m = shapes_scale.m
- k1 = shapes_scale.k1
- k2 = shapes_scale.k2
- n = shapes_scale.n
- scale = shapes_scale.scale
-
- query_tensor_shape = [batch, m, k1]
- key_tensor_shape = [batch, k2, k1]
- value_tensor_shape = [batch, k2, n]
- result_tensor_shape = [batch, m, n]
-
- return query_tensor_shape, key_tensor_shape, value_tensor_shape, result_tensor_shape
-
-
-# Helper for generate_function.
-# Generates a name for a test function in the generated MLIR code.
-def generate_function_name(
- query_type: QueryElemTypeId,
- key_type: KeyElemTypeId,
- value_type: ValueElemTypeId,
- shapes_scale: TestInputTensorShapes,
-):
- query_t = query_type.value
- key_t = key_type.value
- value_t = value_type.value
- result_t = value_type.value
-
- batch = shapes_scale.batch
- m = shapes_scale.m
- k1 = shapes_scale.k1
- k2 = shapes_scale.k2
- n = shapes_scale.n
-
- attention = "attention"
- return (
- f"{attention}_{batch}_{m}_{k1}_{k2}_{n}"
- + f"_dtype_{query_t}_{key_t}_{value_t}_{result_t}"
- )
-
-
-# Represents a generated test function.
-@dataclasses.dataclass
-class MLIRFunction:
- name: str
- signature: str
- import_declaration: str
- definition: str
-
-
-# Generates a test function in the generated MLIR code.
-# The generated function takes the same arguments as iree_linalg_ext.attention
-# and just calls iree_linalg_ext.attention with them, returning its result.
-def generate_function(
- query_type: QueryElemTypeId,
- key_type: KeyElemTypeId,
- value_type: ValueElemTypeId,
- shape_scale: TestShapeAndScale,
-):
- shapes_scale = generate_shapes_and_scale(shape_scale)
- func_name = generate_function_name(
- query_type,
- key_type,
- value_type,
- shapes_scale,
- )
-
- query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(shapes_scale)
- query_tensor_type = (
- f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>"
- )
- key_tensor_type = (
- f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>"
- )
- value_tensor_type = (
- f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>"
- )
- result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>"
- F32 = "f32"
- F16 = "f16"
- op_name = "iree_linalg_ext.attention"
-
-    # The function definition is assembled into this string below.
- func_definition = ""
-
-    signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F32}) -> {result_tensor_type}"
- import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view"
- func_definition = func_definition + (
- f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n"
- f" %result0 = tensor.empty(): {result_tensor_type}\n"
- f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n"
- f" %result1 = {op_name} {{\n"
- f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n"
- f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n"
- f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n"
- f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}"
- f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n"
- f" outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n"
- f" return %result1: {result_tensor_type}\n"
- f"}}\n"
- )
- return MLIRFunction(
- name=func_name,
- signature=signature,
- import_declaration=import_declaration,
- definition=func_definition,
- )
-
-
-# Represents a call to a generated test function.
-@dataclasses.dataclass
-class TestCall:
- function: MLIRFunction
- op: str
-
-
-# Enumerates ways to initialize tensor buffer contents.
-@enum.unique
-class TensorGenerator(enum.Enum):
- ZERO = "zero" # Fill with zeros
- RANDOM = "random" # Fill with (deterministic) pseudorandom values.
-
-
-# Intentionally fixed seed! We want full reproducibility here, both across runs
-# and across machines.
-# Intentionally not shared with local_pseudorandom_state to limit the ways
-# in which shuffling testcases changes which random values are generated.
-pseudorandom_generator_seed = 1
-
-
-def contents_generator_tag(generator: TensorGenerator):
- if generator == TensorGenerator.ZERO:
- return ""
- elif generator == TensorGenerator.RANDOM:
- global pseudorandom_generator_seed
- pseudorandom_generator_seed = pseudorandom_generator_seed + 1
- return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}"
- else:
- raise ValueError(generator)
-
-
-# Generate a 3d tensor function argument of the given size as `%name`.
-def generate_random_3d_tensor(
- name: str,
- tensor_shape: list,
-    element_type: typing.Union[QueryElemTypeId, KeyElemTypeId, ValueElemTypeId],
-):
- global pseudorandom_generator_seed
- pseudorandom_generator_seed = pseudorandom_generator_seed + 1
- return (
- f" %{name}_dim0 = arith.constant {tensor_shape[0]} : i64\n"
- f" %{name}_dim1 = arith.constant {tensor_shape[1]} : i64\n"
- f" %{name}_dim2 = arith.constant {tensor_shape[2]} : i64\n"
- f" %{name}_element_type = hal.element_type<{element_type.value}> : i32\n"
- f" %{name}_seed = arith.constant {pseudorandom_generator_seed} : i32\n"
- f" %{name} = call @attention_test.generate_random_tensor(%device, %{name}_dim0, %{name}_dim1, %{name}_dim2, %{name}_element_type, %{name}_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view\n"
- )
-
-
-call_id = 0
-
-
-def generate_call(
- function: MLIRFunction,
- query_type: QueryElemTypeId,
- key_type: KeyElemTypeId,
- value_type: ValueElemTypeId,
- shapes_scale: TestShapeAndScale,
-):
- global call_id
- func_name = f"{function.name}_{shapes_scale.batch}_{shapes_scale.m}_{shapes_scale.k1}_{shapes_scale.k2}_{shapes_scale.n}_{shapes_scale.k1}_{shapes_scale.scale}"
- func_name = f"{func_name}_{call_id}"
- call_id = call_id + 1
-
-    description = f"Attention shape (BATCHxMxK1xK2xN): {shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}x{shapes_scale.k2}x{shapes_scale.n}"
- op = (
- f"func.func @{func_name}() attributes {{\n"
- f' iree.reflection = {{description = "{description}"}}\n'
- "} {\n"
- " %device_index = arith.constant 0 : index\n"
- " %device = hal.devices.get %device_index : !hal.device\n"
- )
-
- query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(
- shapes_scale,
- )
-
- op = op + generate_random_3d_tensor("query", query_shape, query_type)
- op = op + generate_random_3d_tensor("key", key_shape, key_type)
- op = op + generate_random_3d_tensor("value", value_shape, value_type)
-
- global pseudorandom_generator_seed
- pseudorandom_generator_seed = pseudorandom_generator_seed - 1
- op = op + (
- f" %scale = arith.constant {shapes_scale.scale} : f32\n"
- f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n"
- )
-
- op = op + (
- f" %batch = arith.constant {shapes_scale.batch} : i64 \n"
- f" %m = arith.constant {shapes_scale.m} : i64 \n"
- f" %k1 = arith.constant {shapes_scale.k1} : i64 \n"
- f" %k2 = arith.constant {shapes_scale.k2} : i64 \n"
- f" %n = arith.constant {shapes_scale.n} : i64 \n"
- f" %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> \n"
- f" %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> \n"
- f" %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> \n"
- f" %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> \n"
- f" %queryExt = arith.extf %queryTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> \n"
- f" %keyExt = arith.extf %keyTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> \n"
- f" %valueExt = arith.extf %valueTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> \n"
- f" %resultExt = arith.extf %resultTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> \n"
- f" %queryExtBufferView = hal.tensor.export %queryExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n"
- f" %keyExtBufferView = hal.tensor.export %keyExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n"
- f" %valueExtBufferView = hal.tensor.export %valueExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> -> !hal.buffer_view \n"
- f" %resultExtBufferView = hal.tensor.export %resultExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> -> !hal.buffer_view \n"
- f" call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
- )
-
- op = op + " return\n"
- op = op + "}\n"
-
- return TestCall(function=function, op=op)
-
-
-# Generates all output files' contents as strings.
-def generate(
- query_type: QueryElemTypeId,
- key_type: KeyElemTypeId,
- value_type: ValueElemTypeId,
- shapes_id: ShapesId,
-):
- functions = {}
- calls = []
-
- for shape in get_test_shapes(shapes_id):
- function = generate_function(
- query_type,
- key_type,
- value_type,
- shape,
- )
- if function.name not in functions:
- functions[function.name] = function
- calls.append(
- generate_call(
- function,
- query_type,
- key_type,
- value_type,
- shape,
- )
- )
-
- return (functions, calls)
-
-
-def parse_arguments():
- parser = argparse.ArgumentParser(description="Generator of e2e Attention tests")
- parser.add_argument(
- "--output_attention_mlir",
- type=str,
- help="Path of output .mlir file containing the generated Attention functions",
- required=True,
- )
- parser.add_argument(
- "--output_calls_mlir",
- type=str,
- help="Path of output .mlir file containing the calls",
- required=True,
- )
- parser.add_argument(
- "--query_type",
- type=str,
- choices=["f16"],
- help="Numeric type of query tensors ",
- required=True,
- )
- parser.add_argument(
- "--key_type",
- type=str,
- choices=["f16"],
- help="Numeric type of key tensors ",
- required=True,
- )
- parser.add_argument(
- "--value_type",
- type=str,
- choices=["f16"],
- help="Numeric type of value tensors ",
- required=True,
- )
- parser.add_argument(
- "--shapes_scale",
- type=str,
- choices=[s.value for s in ShapesId],
- help="Collection of tensor shapes to test",
- required=True,
- )
- parser.add_argument(
- "--requirements",
- type=str,
- help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.",
- required=False,
- )
- return parser.parse_args()
-
-
-def write_code_file(functions, filename):
- with open(filename, "w") as file:
- for function in functions.values():
- file.write(function.definition + "\n")
-
-
-def write_calls_file(functions, calls, filename, requirements):
- # Module-level reflection information used to control the test tool.
- reflection = ""
- if requirements:
- reflection = (
- "iree.reflection = {"
- 'target_features = "'
- + ",".join([req.lstrip("+") for req in requirements.split(",")])
- + '"'
- "}"
- )
- module_definition = (
- f"builtin.module @calls attributes {{\n" f" {reflection}\n" f"}} {{\n\n"
- )
-
- # Declare the custom module that generates arguments.
- module_definition = module_definition + (
- "func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
- "func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n"
- "\n"
- )
-
- # Declare the functions that will be called.
- for function in functions.values():
- module_definition = module_definition + function.import_declaration + "\n"
- module_definition = module_definition + "\n"
-
- # Emit the test cases for each call.
- for call in calls:
- module_definition = module_definition + call.op + "\n"
-
- module_definition = module_definition + "\n}\n"
-
- with open(filename, "w") as file:
- file.write(module_definition)
-
-
-def main(args):
- query_type = QueryElemTypeId(args.query_type)
- key_type = KeyElemTypeId(args.key_type)
- value_type = ValueElemTypeId(args.value_type)
- shapes_id = ShapesId(args.shapes_scale)
-
- (functions, calls) = generate(
- query_type,
- key_type,
- value_type,
- shapes_id,
- )
-
- write_code_file(functions, args.output_attention_mlir)
- write_calls_file(
- functions,
- calls,
- args.output_calls_mlir,
- args.requirements,
- )
-
-
-if __name__ == "__main__":
- main(parse_arguments())
diff --git a/tools/testing/e2e/BUILD.bazel b/tools/testing/e2e/BUILD.bazel
index 3976279..0c510a9 100644
--- a/tools/testing/e2e/BUILD.bazel
+++ b/tools/testing/e2e/BUILD.bazel
@@ -68,22 +68,3 @@
"//runtime/src/iree/vm:cc",
],
)
-
-iree_runtime_cc_binary(
- name = "iree-e2e-attention-test",
- srcs = ["iree-e2e-attention-test.cc"],
- deps = [
- ":e2e_test_util",
- "//runtime/src/iree/base",
- "//runtime/src/iree/base/internal",
- "//runtime/src/iree/base/internal:cpu",
- "//runtime/src/iree/base/internal:flags",
- "//runtime/src/iree/base/internal:path",
- "//runtime/src/iree/hal",
- "//runtime/src/iree/modules/hal",
- "//runtime/src/iree/tooling:context_util",
- "//runtime/src/iree/tooling:device_util",
- "//runtime/src/iree/vm",
- "//runtime/src/iree/vm:cc",
- ],
-)
diff --git a/tools/testing/e2e/CMakeLists.txt b/tools/testing/e2e/CMakeLists.txt
index ece0c59..e4fc8fb 100644
--- a/tools/testing/e2e/CMakeLists.txt
+++ b/tools/testing/e2e/CMakeLists.txt
@@ -77,24 +77,4 @@
iree::vm::cc
)

-iree_cc_binary(
- NAME
- iree-e2e-attention-test
- SRCS
- "iree-e2e-attention-test.cc"
- DEPS
- ::e2e_test_util
- iree::base
- iree::base::internal
- iree::base::internal::cpu
- iree::base::internal::flags
- iree::base::internal::path
- iree::hal
- iree::modules::hal
- iree::tooling::context_util
- iree::tooling::device_util
- iree::vm
- iree::vm::cc
-)
-
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/tools/testing/e2e/iree-e2e-attention-test.cc b/tools/testing/e2e/iree-e2e-attention-test.cc
deleted file mode 100644
index 4b0464b..0000000
--- a/tools/testing/e2e/iree-e2e-attention-test.cc
+++ /dev/null
@@ -1,486 +0,0 @@
-// Copyright 2024 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include <float.h>
-#include <math.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-
-#include "iree/base/api.h"
-#include "iree/base/internal/cpu.h"
-#include "iree/base/internal/flags.h"
-#include "iree/base/internal/math.h"
-#include "iree/base/internal/path.h"
-#include "iree/hal/api.h"
-#include "iree/modules/hal/module.h"
-#include "iree/tooling/context_util.h"
-#include "iree/tooling/device_util.h"
-#include "iree/vm/api.h"
-#include "iree/vm/native_module_cc.h"
-#include "tools/testing/e2e/test_utils.h"
-
-//===----------------------------------------------------------------------===//
-// Reference Attention
-//===----------------------------------------------------------------------===//
-
-// Helper for reference_attention.
-// Function to allocate and initialize tensors
-float* allocate_tensor(int dim1, int dim2, int dim3) {
- const int size = dim1 * dim2 * dim3;
- float* tensor = (float*)malloc(size * sizeof(float));
- for (int i = 0; i < size; ++i) {
- tensor[i] = 0.0f;
- }
- return tensor;
-}
-
-// Function to free allocated tensors
-void free_tensor(float* tensor) {
- if (tensor != nullptr) free(tensor);
-}
-
-// Function to calculate 1D index for a 3D array
-int index_3d(int i, int j, int k, int dim2, int dim3) {
- return i * dim2 * dim3 + j * dim3 + k;
-}
-
-static void reference_attention_f32_f32_f32_f32(
- iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N,
- iree_hal_dim_t B, const float* query_data, const float* key_data,
- const float* value_data, float* result_data, iree_hal_dim_t b,
- float* Attention) {
- // Compute Q * K^T
- for (int m = 0; m < M; ++m) {
- for (int k2 = 0; k2 < K2; ++k2) {
- float sum = 0.0;
- for (int k1 = 0; k1 < K1; ++k1) {
- int q_idx = index_3d(b, m, k1, M, K1);
- int k_idx = index_3d(b, k2, k1, K2, K1);
-
- sum += query_data[q_idx] * key_data[k_idx];
- }
- int att_idx = index_3d(0, m, k2, M, K2);
-      Attention[att_idx] = sum / sqrt(K1);  // Scale by 1/sqrt(K1)
- }
- }
-
- // Compute softmax on Attention
- for (int m = 0; m < M; ++m) {
- // Find the maximum value for the current sequence
- float max_val = -FLT_MAX;
- for (int k2 = 0; k2 < K2; ++k2) {
- int att_idx = index_3d(0, m, k2, M, K2);
- max_val = iree_max(max_val, Attention[att_idx]);
- }
-
- // Calculate the softmax denominator
- float sum = 0.0f;
- for (int k2 = 0; k2 < K2; ++k2) {
- int att_idx = index_3d(0, m, k2, M, K2);
- sum += exp(Attention[att_idx] - max_val);
- }
-
- // Apply softmax
- for (int k2 = 0; k2 < K2; ++k2) {
- int att_idx = index_3d(0, m, k2, M, K2);
-      Attention[att_idx] = exp(Attention[att_idx] - max_val) / sum;
- }
- }
-
- // Compute Attention * V
- for (int m = 0; m < M; ++m) {
- for (int n = 0; n < N; ++n) {
- float sum = 0.0;
- for (int k2 = 0; k2 < K2; ++k2) {
- int att_idx = index_3d(0, m, k2, M, K2);
- int v_idx = index_3d(b, k2, n, K2, N);
- sum += Attention[att_idx] * value_data[v_idx];
- }
- int o_idx = index_3d(b, m, n, M, N);
- result_data[o_idx] = sum;
- }
- }
-}
-
-static iree_status_t reference_attention_element(
- iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N,
- iree_hal_dim_t B, iree_hal_element_type_t query_elem_type,
- iree_hal_element_type_t key_elem_type,
- iree_hal_element_type_t value_elem_type, void* query_data, void* key_data,
- void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b,
- float* Attention) {
- if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
- key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
- value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
- reference_attention_f32_f32_f32_f32(
- M, K1, K2, N, B, (const float*)query_data, (const float*)key_data,
- (const float*)value_data, (float*)result_data, b, Attention);
-
- } else {
- return iree_make_status(
- IREE_STATUS_INVALID_ARGUMENT,
- "unhandled combination of element types in attention");
- }
- return iree_ok_status();
-}
-
-// Reference attention implementation, used to compare attention results
-// against.
-static iree_status_t reference_attention(
- iree_hal_dim_t B, iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2,
- iree_hal_dim_t N, iree_hal_element_type_t query_elem_type,
- iree_hal_element_type_t key_elem_type,
- iree_hal_element_type_t value_elem_type, iree_byte_span_t query_contents,
- iree_byte_span_t key_contents, iree_byte_span_t value_contents,
- iree_byte_span_t actual_contents, iree_byte_span_t result_contents,
- int compute_every) {
- IREE_TRACE_ZONE_BEGIN(z0);
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, B);
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, M);
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K1);
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K2);
- IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N);
-
- iree_host_size_t count = 0;
- float* Attention = allocate_tensor(1, M, K2);
- for (iree_hal_dim_t b = 0; b < B; ++b) {
- if (++count < compute_every) continue;
- count = 0;
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0,
- reference_attention_element(
- M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type,
- query_contents.data, key_contents.data, value_contents.data,
- actual_contents.data, result_contents.data, b, Attention));
- }
- free_tensor(Attention);
-
- IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
-}
-//===----------------------------------------------------------------------===//
-// Attention comparison/logging
-//===----------------------------------------------------------------------===//
-
-typedef struct {
- iree_allocator_t host_allocator;
- iree_hal_dim_t b;
- iree_hal_dim_t m;
- iree_hal_dim_t k1;
- iree_hal_dim_t k2;
- iree_hal_dim_t n;
- iree_hal_element_type_t query_elem_type;
- iree_hal_element_type_t key_elem_type;
- iree_hal_element_type_t value_elem_type;
- iree_hal_element_type_t result_elem_type;
- iree_byte_span_t query_contents;
- iree_byte_span_t key_contents;
- iree_byte_span_t value_contents;
- iree_byte_span_t actual_contents;
- iree_byte_span_t expected_contents;
-} attention_results_t;
-
-static void attention_results_deinitialize(attention_results_t* results);
-
-static iree_status_t attention_results_initialize(
- iree_hal_device_t* device, iree_hal_dim_t b_size, iree_hal_dim_t m_size,
- iree_hal_dim_t k1_size, iree_hal_dim_t k2_size, iree_hal_dim_t n_size,
- iree_hal_buffer_view_t* query, iree_hal_buffer_view_t* key,
- iree_hal_buffer_view_t* value, iree_hal_buffer_view_t* result,
- iree_allocator_t host_allocator, attention_results_t* out_results) {
- IREE_TRACE_ZONE_BEGIN(z0);
-
- memset(out_results, 0, sizeof(*out_results));
- out_results->host_allocator = host_allocator;
-
- out_results->b = b_size;
- out_results->m = m_size;
- out_results->k1 = k1_size;
- out_results->k2 = k2_size;
- out_results->n = n_size;
-
- out_results->query_elem_type = iree_hal_buffer_view_element_type(query);
- out_results->key_elem_type = iree_hal_buffer_view_element_type(key);
- out_results->value_elem_type = iree_hal_buffer_view_element_type(value);
- out_results->result_elem_type = iree_hal_buffer_view_element_type(result);
-
- iree_hal_buffer_t* query_buffer = iree_hal_buffer_view_buffer(query);
- iree_hal_buffer_t* key_buffer = iree_hal_buffer_view_buffer(key);
- iree_hal_buffer_t* value_buffer = iree_hal_buffer_view_buffer(value);
- iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result);
-
- iree_status_t status = iree_ok_status();
-
- if (iree_status_is_ok(status)) {
- out_results->query_contents.data_length =
- iree_hal_buffer_byte_length(query_buffer);
- status = iree_allocator_malloc(host_allocator,
- out_results->query_contents.data_length,
- (void**)&out_results->query_contents.data);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(
- device, query_buffer, 0, out_results->query_contents.data,
- out_results->query_contents.data_length,
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
- }
- if (iree_status_is_ok(status)) {
- out_results->key_contents.data_length =
- iree_hal_buffer_byte_length(key_buffer);
- status = iree_allocator_malloc(host_allocator,
- out_results->key_contents.data_length,
- (void**)&out_results->key_contents.data);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(
- device, key_buffer, 0, out_results->key_contents.data,
- out_results->key_contents.data_length,
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
- }
- if (iree_status_is_ok(status)) {
- out_results->value_contents.data_length =
- iree_hal_buffer_byte_length(value_buffer);
- status = iree_allocator_malloc(host_allocator,
- out_results->value_contents.data_length,
- (void**)&out_results->value_contents.data);
- }
-
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(
- device, value_buffer, 0, out_results->value_contents.data,
- out_results->value_contents.data_length,
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
- }
- if (iree_status_is_ok(status)) {
- out_results->actual_contents.data_length =
- iree_hal_buffer_byte_length(result_buffer);
- status = iree_allocator_malloc(host_allocator,
- out_results->actual_contents.data_length,
- (void**)&out_results->actual_contents.data);
- }
- if (iree_status_is_ok(status)) {
- status = iree_hal_device_transfer_d2h(
- device, result_buffer, 0, out_results->actual_contents.data,
- out_results->actual_contents.data_length,
- IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
- }
- if (iree_status_is_ok(status)) {
- out_results->expected_contents.data_length =
- iree_hal_buffer_byte_length(result_buffer);
- status = iree_allocator_malloc(
- host_allocator, out_results->expected_contents.data_length,
- (void**)&out_results->expected_contents.data);
- }
- if (!iree_status_is_ok(status)) {
- attention_results_deinitialize(out_results);
- }
- IREE_TRACE_ZONE_END(z0);
- return status;
-}
-
-static void attention_results_deinitialize(attention_results_t* results) {
- IREE_TRACE_ZONE_BEGIN(z0);
- iree_allocator_free(results->host_allocator, results->query_contents.data);
- iree_allocator_free(results->host_allocator, results->key_contents.data);
- iree_allocator_free(results->host_allocator, results->value_contents.data);
- iree_allocator_free(results->host_allocator, results->actual_contents.data);
- iree_allocator_free(results->host_allocator, results->expected_contents.data);
-
- IREE_TRACE_ZONE_END(z0);
-}
-
-// Helper for check_attention_results: the actual interesting part once we've
-// obtained and validated the {b,m,k1,k2,n}_size values. On error, detailed
-// logging is written to |file| if it is not NULL.
-static iree_status_t check_attention_results_impl(
- FILE* file, const attention_results_t* results, int check_every) {
- IREE_TRACE_ZONE_BEGIN(z0);
-
- IREE_RETURN_AND_END_ZONE_IF_ERROR(
- z0, reference_attention(results->b, results->m, results->k1, results->k2,
- results->n, results->query_elem_type,
- results->key_elem_type, results->value_elem_type,
- results->query_contents, results->key_contents,
- results->value_contents, results->actual_contents,
- results->expected_contents, check_every));
-
- IREE_TRACE_ZONE_END(z0);
- return iree_ok_status();
-}
-
-// Given an actual attention's inputs and output (all host-local), uses a
-// reference attention implementation on the same inputs to check if the output
-// is correct. On error, detailed logging is written to |file| if it is not
-// NULL.
-static iree_status_t check_attention_results(
- FILE* file, const attention_results_t* results) {
- IREE_TRACE_ZONE_BEGIN(z0);
-  // TODO: Increase the check_every param to reduce the number of comparisons.
- int check_every = 1;
- iree_status_t status =
- check_attention_results_impl(file, results, check_every);
- if (!iree_status_is_ok(status) && check_every > 1) {
- // If we got a failure with check_every>1, that didn't log a useful
- // numerical summary, as most of the reference matrix entries hadn't been
- // computed. Rerun now with check_every=1 to get that numerical logging.
- iree_status_ignore(status);
- status = check_attention_results_impl(file, results, 1);
- }
- IREE_TRACE_ZONE_END(z0);
- return status;
-}
-
-//===----------------------------------------------------------------------===//
-// `attention_test` custom module
-//===----------------------------------------------------------------------===//
-// This uses the C++ wrapper to keep things simple. Though easier to use, it
-// adds overhead/code-size bloat that doesn't matter in a test like this.
-// Making a C module builder API that removes the boilerplate is TBD, so the
-// rest of this file is written in C besides this module so that we can swap
-// it back to being pure C in the future.
-
-namespace iree {
-
-class AttentionTestModuleState final {
- public:
- explicit AttentionTestModuleState(iree_allocator_t host_allocator)
- : host_allocator_(host_allocator) {}
- ~AttentionTestModuleState() = default;
-
- // Fills the destination span with pseudorandom values of the given
- // |element_type|. The given |seed| is passed to the pseudorandom generator.
- // The pseudorandom values are reproducible both across runs and across
- // machines.
- StatusOr<vm::ref<iree_hal_buffer_view_t>> GenerateRandom3dTensor(
- const vm::ref<iree_hal_device_t> device, int64_t dim0, int64_t dim1,
- int64_t dim2, iree_hal_element_type_t element_type, int32_t seed) {
- iree_hal_dim_t dims[3] = {
- (iree_hal_dim_t)dim0,
- (iree_hal_dim_t)dim1,
- (iree_hal_dim_t)dim2,
- };
- iree_hal_buffer_params_t buffer_params = {0};
- buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
- buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
- buffer_params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
- vm::ref<iree_hal_buffer_view_t> result_view;
- struct callback_state_t {
- iree_hal_element_type_t element_type;
- int32_t seed;
- } callback_state = {
- element_type,
- seed,
- };
- IREE_RETURN_IF_ERROR(iree_hal_buffer_view_generate_buffer(
- device.get(), iree_hal_device_allocator(device.get()),
- IREE_ARRAYSIZE(dims), dims, element_type,
- IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
- +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
- callback_state_t callback_state = *(callback_state_t*)user_data;
- iree_byte_span_t span = mapping->contents;
- // Generate "uniform" integer-valued numbers in the range [min, max].
- int32_t min = 0;
- int32_t max = 0;
- iree_test_utils_get_min_max_for_element_type(
- callback_state.element_type, &min, &max);
- uint32_t range = (max - min + 1);
- iree_host_size_t element_byte_count =
- iree_hal_element_dense_byte_count(callback_state.element_type);
- uint8_t* data_end = span.data + span.data_length;
- uint32_t state = callback_state.seed;
- for (uint8_t* data = span.data; data < data_end;
- data += element_byte_count) {
- int32_t value =
- (int32_t)iree_test_utils_pseudorandom_range(&state, range) +
- min;
- iree_test_utils_write_element(callback_state.element_type, value,
- data);
- }
- return iree_ok_status();
- },
- &callback_state, &result_view));
- return std::move(result_view);
- }
-
- Status CheckAttentionResults(
- const vm::ref<iree_hal_device_t> device, int64_t b, int64_t m, int64_t k1,
- int64_t k2, int64_t n, const vm::ref<iree_hal_buffer_view_t> query,
- const vm::ref<iree_hal_buffer_view_t> key,
- const vm::ref<iree_hal_buffer_view_t> value,
- const vm::ref<iree_hal_buffer_view_t> actual_result) {
- attention_results_t results = {};
- IREE_RETURN_IF_ERROR(attention_results_initialize(
- device.get(), (iree_hal_dim_t)b, (iree_hal_dim_t)m, (iree_hal_dim_t)k1,
- (iree_hal_dim_t)k2, (iree_hal_dim_t)n, query.get(), key.get(),
- value.get(), actual_result.get(), host_allocator_, &results));
- iree_status_t status = check_attention_results(stderr, &results);
- attention_results_deinitialize(&results);
- return status;
- }
-
- private:
- iree_allocator_t host_allocator_;
-};
-
-static const vm::NativeFunction<AttentionTestModuleState>
- kAttentionTestModuleFunctions[] = {
- vm::MakeNativeFunction(
- "generate_random_tensor",
- &AttentionTestModuleState::GenerateRandom3dTensor),
- vm::MakeNativeFunction(
- "check_attention_results",
- &AttentionTestModuleState::CheckAttentionResults),
-};
-
-struct AttentionTestModule final
- : public vm::NativeModule<AttentionTestModuleState> {
- using vm::NativeModule<AttentionTestModuleState>::NativeModule;
- StatusOr<std::unique_ptr<AttentionTestModuleState>> CreateState(
- iree_allocator_t host_allocator) override {
- return std::make_unique<AttentionTestModuleState>(host_allocator);
- }
-};
-
-} // namespace iree
-
-static iree_status_t attention_test_module_create(
- iree_vm_instance_t* instance, iree_allocator_t host_allocator,
- iree_vm_module_t** out_module) {
- IREE_ASSERT_ARGUMENT(out_module);
- *out_module = NULL;
- auto module = std::make_unique<iree::AttentionTestModule>(
- "attention_test", /*version=*/0, instance, host_allocator,
- iree::span<
- const iree::vm::NativeFunction<iree::AttentionTestModuleState>>(
- iree::kAttentionTestModuleFunctions));
- *out_module = module.release()->interface();
- return iree_ok_status();
-}
-
-int main(int argc, char** argv) {
- IREE_TRACE_APP_ENTER();
-
- iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
- if (argc != 1) {
- fprintf(stderr, "use --module= flags to specify the modules to run\n");
- IREE_TRACE_APP_EXIT(EXIT_FAILURE);
- return EXIT_FAILURE;
- }
-
- iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
- iree_allocator_system(), attention_test_module_create);
- int exit_code = EXIT_SUCCESS;
- if (!iree_status_is_ok(status)) {
- iree_status_fprint(stderr, status);
- bool is_unavailable = iree_status_is_unavailable(status);
- iree_status_free(status);
- exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
- }
-
- IREE_TRACE_APP_EXIT(exit_code);
- return exit_code;
-}