Add e2e test suite for attention on the CPU backend (#17751)
Adds an e2e test suite for the attention op. For now it only covers the CPU
(llvm-cpu) backend with f16 inputs, checked against an f32 reference
implementation.
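
For reference, the generated MLIR can also be produced by driving the new
generator module directly from Python; a minimal sketch, assuming the script
is importable and with illustrative output file names:

    from generate_e2e_attention_tests import (
        KeyElemTypeId, QueryElemTypeId, ShapesId, ValueElemTypeId,
        generate, write_calls_file, write_code_file)

    # Generate the f16 "small" test functions and the module that calls them.
    functions, calls = generate(QueryElemTypeId.F16, KeyElemTypeId.F16,
                                ValueElemTypeId.F16, ShapesId.SMALL)
    write_code_file(functions, "attention_f16_small.mlir")
    write_calls_file(functions, calls, "attention_f16_small_calls.mlir",
                     requirements=None)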
---------
Signed-off-by: ERMAN GURSES <erman@nod-labs.com>
Co-authored-by: ian <ian.nordeng@amd.com>
diff --git a/tests/e2e/attention/BUILD.bazel b/tests/e2e/attention/BUILD.bazel
new file mode 100644
index 0000000..3e9e41d
--- /dev/null
+++ b/tests/e2e/attention/BUILD.bazel
@@ -0,0 +1,53 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# End-to-end attention tests.
+
+load("//build_tools/bazel:iree_e2e_generated_runner_test.bzl", "iree_generated_e2e_runner_test")
+
+package(
+ features = ["layering_check"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+py_binary(
+ name = "generate_e2e_attention_tests",
+ srcs = ["generate_e2e_attention_tests.py"],
+)
+
+###########################################################################
+##
+## LLVMCPU backend
+##
+###########################################################################
+
+# Default CPU backend.
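+#
+# The list comprehension below expands to one generated test suite per
+# (dtype, size) pair, e.g. :e2e_attention_cpu_f16_f16_f16_small; the
+# iree_generated_e2e_runner_test macro may append backend/driver-specific
+# suffixes to the individual test targets.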
+[iree_generated_e2e_runner_test(
+ name = "e2e_attention_cpu_%s_%s_%s_%s" % (dtype, dtype, dtype, size),
+ generator = ":generate_e2e_attention_tests",
+ generator_args = [
+ "--query_type=%s" % dtype,
+ "--key_type=%s" % dtype,
+ "--value_type=%s" % dtype,
+ "--shapes=%s" % size,
+ ],
+ tags = [
+ "hostonly",
+ "local",
+ ],
+ target_backends_and_drivers = [
+ ("llvm-cpu", "local-task"),
+ ],
+ target_cpu_features_variants = ["default"],
+ test_runner = "//tools/testing/e2e:iree-e2e-attention-test",
+ test_type = "attention",
+) for dtype in [
+ "f16",
+] for size in [
+ "small",
+ "medium",
+ "large",
+]]
diff --git a/tests/e2e/attention/CMakeLists.txt b/tests/e2e/attention/CMakeLists.txt
new file mode 100644
index 0000000..f793784
--- /dev/null
+++ b/tests/e2e/attention/CMakeLists.txt
@@ -0,0 +1,88 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
+# tests/e2e/attention/BUILD.bazel #
+# #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
+# CMake-only content. #
+# #
+# To disable autogeneration for this file entirely, delete this header. #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_attention_cpu_f16_f16_f16_small
+ TEST_TYPE
+ attention
+ GENERATOR
+ "generate_e2e_attention_tests.py"
+ GENERATOR_ARGS
+ "--query_type=f16"
+ "--key_type=f16"
+ "--value_type=f16"
+ "--shapes=small"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-attention-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ LABELS
+ "hostonly"
+ "local"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_attention_cpu_f16_f16_f16_medium
+ TEST_TYPE
+ attention
+ GENERATOR
+ "generate_e2e_attention_tests.py"
+ GENERATOR_ARGS
+ "--query_type=f16"
+ "--key_type=f16"
+ "--value_type=f16"
+ "--shapes=medium"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-attention-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ LABELS
+ "hostonly"
+ "local"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+)
+
+iree_generated_e2e_runner_test(
+ NAME
+ e2e_attention_cpu_f16_f16_f16_large
+ TEST_TYPE
+ attention
+ GENERATOR
+ "generate_e2e_attention_tests.py"
+ GENERATOR_ARGS
+ "--query_type=f16"
+ "--key_type=f16"
+ "--value_type=f16"
+ "--shapes=large"
+ TEST_RUNNER
+ iree_tools_testing_e2e_iree-e2e-attention-test
+ TARGET_BACKENDS
+ "llvm-cpu"
+ DRIVERS
+ "local-task"
+ LABELS
+ "hostonly"
+ "local"
+ TARGET_CPU_FEATURES_VARIANTS
+ "default"
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/tests/e2e/attention/generate_e2e_attention_tests.py b/tests/e2e/attention/generate_e2e_attention_tests.py
new file mode 100644
index 0000000..f567a16
--- /dev/null
+++ b/tests/e2e/attention/generate_e2e_attention_tests.py
@@ -0,0 +1,499 @@
+#!/usr/bin/env python3
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Generator for e2e attention tests.
+"""
+
+import argparse
+import enum
+import dataclasses
+import typing
+import math
+
+
+# Data type of query tensor entries. The string values must match MLIR data types.
+@enum.unique
+class QueryElemTypeId(enum.Enum):
+ NONE = ""
+ F16 = "f16"
+
+
+# Data type of key tensor entries. The string values must match MLIR data types.
+@enum.unique
+class KeyElemTypeId(enum.Enum):
+ NONE = ""
+ F16 = "f16"
+
+
+# Data type of value tensor entries. The string values must match MLIR data types.
+@enum.unique
+class ValueElemTypeId(enum.Enum):
+ NONE = ""
+ F16 = "f16"
+
+
+# Data type of result tensor entries. The string values must match MLIR data types.
+@enum.unique
+class ResultElemTypeId(enum.Enum):
+ NONE = ""
+ F16 = "f16"
+
+
+# Enumerates the collections of shapes that we can generate tests for.
+# The values are the accepted values for the --shapes_scale= flag.
+@enum.unique
+class ShapesId(enum.Enum):
+ SMALL = "small"
+ MEDIUM = "medium"
+ LARGE = "large"
+
+
+# batch: Batch dimension
+# m: M dimension of the first and second matmuls
+# n: N dimension of the second matmul
+# k1: K (reduction) dimension of the first matmul
+# k2: K (reduction) dimension of the second matmul
+# scale: scale factor applied to the first matmul before the softmax
+@dataclasses.dataclass
+class TestShapeAndScale:
+ batch: int
+ m: int
+ k1: int
+ k2: int
+ n: int
+ scale: float
+
+
+# Returns the list of TestShapeAndScale's to use for the collection of shapes
+# identified by shapes_id.
+def get_test_shapes(shapes_id: ShapesId):
+ if shapes_id == ShapesId.SMALL:
+ return [
+ TestShapeAndScale(batch=2, m=512, k1=64, k2=128, n=32, scale=1.0),
+ ]
+ if shapes_id == ShapesId.MEDIUM:
+ return [
+ TestShapeAndScale(batch=2, m=1024, k1=128, k2=256, n=64, scale=1.0),
+ ]
+ if shapes_id == ShapesId.LARGE:
+ return [
+ TestShapeAndScale(batch=2, m=2048, k1=256, k2=512, n=128, scale=1.0),
+ ]
+
+ raise ValueError(shapes_id)
+
+
+# Holds the dimension sizes and scale used to build a test function's tensors.
+@dataclasses.dataclass
+class TestInputTensorShapes:
+ batch: int
+ m: int
+ k1: int
+ k2: int
+ n: int
+ scale: float
+
+
+# Helper for generate_function. Converts the dimensions and scale in a
+# TestShapeAndScale into the TestInputTensorShapes used to build a test
+# function's input tensors.
+def generate_shapes_and_scale(shape: TestShapeAndScale):
+ batch = shape.batch
+ m = shape.m
+ k1 = shape.k1
+ k2 = shape.k2
+ n = shape.n
+ scale = shape.scale
+
+ shapes_scale = TestInputTensorShapes(
+ batch=batch,
+ m=m,
+ k1=k1,
+ k2=k2,
+ n=n,
+ scale=scale,
+ )
+ return shapes_scale
+
+
+# Helper to return the query, key, value and result tensor shapes for the given attention parameters.
+def get_tensor_shapes(
+ shapes_scale: TestShapeAndScale,
+):
+ batch = shapes_scale.batch
+ m = shapes_scale.m
+ k1 = shapes_scale.k1
+ k2 = shapes_scale.k2
+ n = shapes_scale.n
+ scale = shapes_scale.scale
+
+ query_tensor_shape = [batch, m, k1]
+ key_tensor_shape = [batch, k2, k1]
+ value_tensor_shape = [batch, k2, n]
+ result_tensor_shape = [batch, m, n]
+
+ return query_tensor_shape, key_tensor_shape, value_tensor_shape, result_tensor_shape
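+
+# For example, the "small" collection (batch=2, m=512, k1=64, k2=128, n=32)
+# yields query 2x512x64, key 2x128x64, value 2x128x32 and result 2x512x32.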
+
+
+# Helper for generate_function.
+# Generates a name for a test function in the generated MLIR code.
+def generate_function_name(
+ query_type: QueryElemTypeId,
+ key_type: KeyElemTypeId,
+ value_type: ValueElemTypeId,
+ shapes_scale: TestInputTensorShapes,
+):
+ query_t = query_type.value
+ key_t = key_type.value
+ value_t = value_type.value
+ result_t = value_type.value
+
+ batch = shapes_scale.batch
+ m = shapes_scale.m
+ k1 = shapes_scale.k1
+ k2 = shapes_scale.k2
+ n = shapes_scale.n
+
+ attention = "attention"
+ return (
+ f"{attention}_{batch}_{m}_{k1}_{k2}_{n}"
+ + f"_dtype_{query_t}_{key_t}_{value_t}_{result_t}"
+ )
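+
+# For the "small" f16 case above this produces, for example,
+# "attention_2_512_64_128_32_dtype_f16_f16_f16_f16".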
+
+
+# Represents a generated test function.
+@dataclasses.dataclass
+class MLIRFunction:
+ name: str
+ signature: str
+ import_declaration: str
+ definition: str
+
+
+# Generates a test function in the generated MLIR code.
+# The generated function takes the same arguments as the iree_linalg_ext.attention
+# op and simply calls it, returning its result.
+def generate_function(
+ query_type: QueryElemTypeId,
+ key_type: KeyElemTypeId,
+ value_type: ValueElemTypeId,
+ shape_scale: TestShapeAndScale,
+):
+ shapes_scale = generate_shapes_and_scale(shape_scale)
+ func_name = generate_function_name(
+ query_type,
+ key_type,
+ value_type,
+ shapes_scale,
+ )
+
+ query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(shapes_scale)
+ query_tensor_type = (
+ f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>"
+ )
+ key_tensor_type = (
+ f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>"
+ )
+ value_tensor_type = (
+ f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>"
+ )
+ result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>"
+ F32 = "f32"
+ F16 = "f16"
+ op_name = "iree_linalg_ext.attention"
+
+ # The function definition is built up as a single string below.
+ func_definition = ""
+
+ signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}"
+ import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view"
+ func_definition = func_definition + (
+ f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n"
+ f" %result0 = tensor.empty(): {result_tensor_type}\n"
+ f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n"
+ f" %result1 = {op_name} {{\n"
+ f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n"
+ f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n"
+ f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n"
+ f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}"
+ f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n"
+ f" outs(%result0: {result_tensor_type}) -> {result_tensor_type}\n"
+ f" return %result1: {result_tensor_type}\n"
+ f"}}\n"
+ )
+ return MLIRFunction(
+ name=func_name,
+ signature=signature,
+ import_declaration=import_declaration,
+ definition=func_definition,
+ )
+
+
+# Represents a call to a generated test function.
+@dataclasses.dataclass
+class TestCall:
+ function: MLIRFunction
+ op: str
+
+
+# Enumerates ways to initialize tensor buffer contents.
+@enum.unique
+class TensorGenerator(enum.Enum):
+ ZERO = "zero" # Fill with zeros
+ RANDOM = "random" # Fill with (deterministic) pseudorandom values.
+
+
+# Intentionally fixed seed! We want full reproducibility here, both across runs
+# and across machines.
+# Intentionally not shared with local_pseudorandom_state to limit the ways
+# in which shuffling testcases changes which random values are generated.
+pseudorandom_generator_seed = 1
+
+
+def contents_generator_tag(generator: TensorGenerator):
+ if generator == TensorGenerator.ZERO:
+ return ""
+ elif generator == TensorGenerator.RANDOM:
+ global pseudorandom_generator_seed
+ pseudorandom_generator_seed = pseudorandom_generator_seed + 1
+ return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}"
+ else:
+ raise ValueError(generator)
+
+
+# Generate a 3d tensor function argument of the given size as `%name`.
+def generate_random_3d_tensor(
+ name: str,
+ tensor_shape: list,
+ element_type: typing.Union[QueryElemTypeId, ResultElemTypeId],
+):
+ global pseudorandom_generator_seed
+ pseudorandom_generator_seed = pseudorandom_generator_seed + 1
+ return (
+ f" %{name}_dim0 = arith.constant {tensor_shape[0]} : i64\n"
+ f" %{name}_dim1 = arith.constant {tensor_shape[1]} : i64\n"
+ f" %{name}_dim2 = arith.constant {tensor_shape[2]} : i64\n"
+ f" %{name}_element_type = hal.element_type<{element_type.value}> : i32\n"
+ f" %{name}_seed = arith.constant {pseudorandom_generator_seed} : i32\n"
+ f" %{name} = call @attention_test.generate_random_tensor(%device, %{name}_dim0, %{name}_dim1, %{name}_dim2, %{name}_element_type, %{name}_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view\n"
+ )
+
+
+call_id = 0
+
+
+def generate_call(
+ function: MLIRFunction,
+ query_type: QueryElemTypeId,
+ key_type: KeyElemTypeId,
+ value_type: ValueElemTypeId,
+ shapes_scale: TestShapeAndScale,
+):
+ global call_id
+ func_name = f"{function.name}_{shapes_scale.batch}_{shapes_scale.m}_{shapes_scale.k1}_{shapes_scale.k2}_{shapes_scale.n}_{shapes_scale.k1}_{shapes_scale.scale}"
+ func_name = f"{func_name}_{call_id}"
+ call_id = call_id + 1
+
+ description = f"Attention shape (BATCHxMxK1xK2xN): {shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}x{shapes_scale.k2}x{shapes_scale.k1}x{shapes_scale.n}"
+ op = (
+ f"func.func @{func_name}() attributes {{\n"
+ f' iree.reflection = {{description = "{description}"}}\n'
+ "} {\n"
+ " %device_index = arith.constant 0 : index\n"
+ " %device = hal.devices.get %device_index : !hal.device\n"
+ )
+
+ query_shape, key_shape, value_shape, result_shape = get_tensor_shapes(
+ shapes_scale,
+ )
+
+ op = op + generate_random_3d_tensor("query", query_shape, query_type)
+ op = op + generate_random_3d_tensor("key", key_shape, key_type)
+ op = op + generate_random_3d_tensor("value", value_shape, value_type)
+
+ global pseudorandom_generator_seed
+ pseudorandom_generator_seed = pseudorandom_generator_seed - 1
+ op = op + (
+ f" %scale = arith.constant {shapes_scale.scale} : f32\n"
+ f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n"
+ )
+
+ op = op + (
+ f" %batch = arith.constant {shapes_scale.batch} : i64 \n"
+ f" %m = arith.constant {shapes_scale.m} : i64 \n"
+ f" %k1 = arith.constant {shapes_scale.k1} : i64 \n"
+ f" %k2 = arith.constant {shapes_scale.k2} : i64 \n"
+ f" %n = arith.constant {shapes_scale.n} : i64 \n"
+ f" %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> \n"
+ f" %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> \n"
+ f" %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> \n"
+ f" %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> \n"
+ f" %queryExt = arith.extf %queryTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> \n"
+ f" %keyExt = arith.extf %keyTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> \n"
+ f" %valueExt = arith.extf %valueTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> \n"
+ f" %resultExt = arith.extf %resultTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> \n"
+ f" %queryExtBufferView = hal.tensor.export %queryExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n"
+ f" %keyExtBufferView = hal.tensor.export %keyExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n"
+ f" %valueExtBufferView = hal.tensor.export %valueExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> -> !hal.buffer_view \n"
+ f" %resultExtBufferView = hal.tensor.export %resultExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> -> !hal.buffer_view \n"
+ f" call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
+ )
+
+ op = op + " return\n"
+ op = op + "}\n"
+
+ return TestCall(function=function, op=op)
+
+
+# Generates all output files' contents as strings.
+def generate(
+ query_type: QueryElemTypeId,
+ key_type: KeyElemTypeId,
+ value_type: ValueElemTypeId,
+ shapes_id: ShapesId,
+):
+ functions = {}
+ calls = []
+
+ for shape in get_test_shapes(shapes_id):
+ function = generate_function(
+ query_type,
+ key_type,
+ value_type,
+ shape,
+ )
+ if function.name not in functions:
+ functions[function.name] = function
+ calls.append(
+ generate_call(
+ function,
+ query_type,
+ key_type,
+ value_type,
+ shape,
+ )
+ )
+
+ return (functions, calls)
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="Generator of e2e Attention tests")
+ parser.add_argument(
+ "--output_attention_mlir",
+ type=str,
+ help="Path of output .mlir file containing the generated Attention functions",
+ required=True,
+ )
+ parser.add_argument(
+ "--output_calls_mlir",
+ type=str,
+ help="Path of output .mlir file containing the calls",
+ required=True,
+ )
+ parser.add_argument(
+ "--query_type",
+ type=str,
+ choices=["f16"],
+ help="Numeric type of query tensors ",
+ required=True,
+ )
+ parser.add_argument(
+ "--key_type",
+ type=str,
+ choices=["f16"],
+ help="Numeric type of key tensors ",
+ required=True,
+ )
+ parser.add_argument(
+ "--value_type",
+ type=str,
+ choices=["f16"],
+ help="Numeric type of value tensors ",
+ required=True,
+ )
+ parser.add_argument(
+ "--shapes_scale",
+ type=str,
+ choices=[s.value for s in ShapesId],
+ help="Collection of tensor shapes to test",
+ required=True,
+ )
+ parser.add_argument(
+ "--requirements",
+ type=str,
+ help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.",
+ required=False,
+ )
+ return parser.parse_args()
+
+
+def write_code_file(functions, filename):
+ with open(filename, "w") as file:
+ for function in functions.values():
+ file.write(function.definition + "\n")
+
+
+def write_calls_file(functions, calls, filename, requirements):
+ # Module-level reflection information used to control the test tool.
+ reflection = ""
+ if requirements:
+ reflection = (
+ "iree.reflection = {"
+ 'target_features = "'
+ + ",".join([req.lstrip("+") for req in requirements.split(",")])
+ + '"'
+ "}"
+ )
+ module_definition = (
+ f"builtin.module @calls attributes {{\n" f" {reflection}\n" f"}} {{\n\n"
+ )
+
+ # Declare the custom module that generates arguments.
+ module_definition = module_definition + (
+ "func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
+ "func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n"
+ "\n"
+ )
+
+ # Declare the functions that will be called.
+ for function in functions.values():
+ module_definition = module_definition + function.import_declaration + "\n"
+ module_definition = module_definition + "\n"
+
+ # Emit the test cases for each call.
+ for call in calls:
+ module_definition = module_definition + call.op + "\n"
+
+ module_definition = module_definition + "\n}\n"
+
+ with open(filename, "w") as file:
+ file.write(module_definition)
+
+
+def main(args):
+ query_type = QueryElemTypeId(args.query_type)
+ key_type = KeyElemTypeId(args.key_type)
+ value_type = ValueElemTypeId(args.value_type)
+ shapes_id = ShapesId(args.shapes_scale)
+
+ (functions, calls) = generate(
+ query_type,
+ key_type,
+ value_type,
+ shapes_id,
+ )
+
+ write_code_file(functions, args.output_attention_mlir)
+ write_calls_file(
+ functions,
+ calls,
+ args.output_calls_mlir,
+ args.requirements,
+ )
+
+
+if __name__ == "__main__":
+ main(parse_arguments())
diff --git a/tools/testing/e2e/BUILD.bazel b/tools/testing/e2e/BUILD.bazel
index 0c510a9..3976279 100644
--- a/tools/testing/e2e/BUILD.bazel
+++ b/tools/testing/e2e/BUILD.bazel
@@ -68,3 +68,22 @@
"//runtime/src/iree/vm:cc",
],
)
+
+iree_runtime_cc_binary(
+ name = "iree-e2e-attention-test",
+ srcs = ["iree-e2e-attention-test.cc"],
+ deps = [
+ ":e2e_test_util",
+ "//runtime/src/iree/base",
+ "//runtime/src/iree/base/internal",
+ "//runtime/src/iree/base/internal:cpu",
+ "//runtime/src/iree/base/internal:flags",
+ "//runtime/src/iree/base/internal:path",
+ "//runtime/src/iree/hal",
+ "//runtime/src/iree/modules/hal",
+ "//runtime/src/iree/tooling:context_util",
+ "//runtime/src/iree/tooling:device_util",
+ "//runtime/src/iree/vm",
+ "//runtime/src/iree/vm:cc",
+ ],
+)
diff --git a/tools/testing/e2e/CMakeLists.txt b/tools/testing/e2e/CMakeLists.txt
index e4fc8fb..ece0c59 100644
--- a/tools/testing/e2e/CMakeLists.txt
+++ b/tools/testing/e2e/CMakeLists.txt
@@ -77,4 +77,24 @@
iree::vm::cc
)
+iree_cc_binary(
+ NAME
+ iree-e2e-attention-test
+ SRCS
+ "iree-e2e-attention-test.cc"
+ DEPS
+ ::e2e_test_util
+ iree::base
+ iree::base::internal
+ iree::base::internal::cpu
+ iree::base::internal::flags
+ iree::base::internal::path
+ iree::hal
+ iree::modules::hal
+ iree::tooling::context_util
+ iree::tooling::device_util
+ iree::vm
+ iree::vm::cc
+)
+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
diff --git a/tools/testing/e2e/iree-e2e-attention-test.cc b/tools/testing/e2e/iree-e2e-attention-test.cc
new file mode 100644
index 0000000..4b0464b
--- /dev/null
+++ b/tools/testing/e2e/iree-e2e-attention-test.cc
@@ -0,0 +1,486 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <float.h>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "iree/base/api.h"
+#include "iree/base/internal/cpu.h"
+#include "iree/base/internal/flags.h"
+#include "iree/base/internal/math.h"
+#include "iree/base/internal/path.h"
+#include "iree/hal/api.h"
+#include "iree/modules/hal/module.h"
+#include "iree/tooling/context_util.h"
+#include "iree/tooling/device_util.h"
+#include "iree/vm/api.h"
+#include "iree/vm/native_module_cc.h"
+#include "tools/testing/e2e/test_utils.h"
+
+//===----------------------------------------------------------------------===//
+// Reference Attention
+//===----------------------------------------------------------------------===//
+
+// Helper for reference_attention.
+// Function to allocate and initialize tensors
+float* allocate_tensor(int dim1, int dim2, int dim3) {
+ const int size = dim1 * dim2 * dim3;
+ float* tensor = (float*)malloc(size * sizeof(float));
+ for (int i = 0; i < size; ++i) {
+ tensor[i] = 0.0f;
+ }
+ return tensor;
+}
+
+// Function to free allocated tensors
+void free_tensor(float* tensor) {
+ if (tensor != nullptr) free(tensor);
+}
+
+// Function to calculate 1D index for a 3D array
+int index_3d(int i, int j, int k, int dim2, int dim3) {
+ return i * dim2 * dim3 + j * dim3 + k;
+}
+
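+// Computes one batch slice of the reference result in f32:
+//   S = (Q * K^T) / sqrt(K1)           (M x K2 scores)
+//   P = row-wise softmax of S over K2  (numerically stabilized)
+//   O = P * V                          (M x N result)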
+static void reference_attention_f32_f32_f32_f32(
+ iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N,
+ iree_hal_dim_t B, const float* query_data, const float* key_data,
+ const float* value_data, float* result_data, iree_hal_dim_t b,
+ float* Attention) {
+ // Compute Q * K^T
+ for (int m = 0; m < M; ++m) {
+ for (int k2 = 0; k2 < K2; ++k2) {
+ float sum = 0.0;
+ for (int k1 = 0; k1 < K1; ++k1) {
+ int q_idx = index_3d(b, m, k1, M, K1);
+ int k_idx = index_3d(b, k2, k1, K2, K1);
+
+ sum += query_data[q_idx] * key_data[k_idx];
+ }
+ int att_idx = index_3d(0, m, k2, M, K2);
+ Attention[att_idx] = sum / sqrt(K1); // Scale by 1/sqrt(K1).
+ }
+ }
+
+ // Compute softmax on Attention
+ for (int m = 0; m < M; ++m) {
+ // Find the maximum value for the current sequence
+ float max_val = -FLT_MAX;
+ for (int k2 = 0; k2 < K2; ++k2) {
+ int att_idx = index_3d(0, m, k2, M, K2);
+ max_val = iree_max(max_val, Attention[att_idx]);
+ }
+
+ // Calculate the softmax denominator
+ float sum = 0.0f;
+ for (int k2 = 0; k2 < K2; ++k2) {
+ int att_idx = index_3d(0, m, k2, M, K2);
+ sum += exp(Attention[att_idx] - max_val);
+ }
+
+ // Apply softmax, subtracting max_val for numerical stability to match the
+ // denominator computed above.
+ for (int k2 = 0; k2 < K2; ++k2) {
+ int att_idx = index_3d(0, m, k2, M, K2);
+ Attention[att_idx] = exp(Attention[att_idx] - max_val) / sum;
+ }
+ }
+
+ // Compute Attention * V
+ for (int m = 0; m < M; ++m) {
+ for (int n = 0; n < N; ++n) {
+ float sum = 0.0;
+ for (int k2 = 0; k2 < K2; ++k2) {
+ int att_idx = index_3d(0, m, k2, M, K2);
+ int v_idx = index_3d(b, k2, n, K2, N);
+ sum += Attention[att_idx] * value_data[v_idx];
+ }
+ int o_idx = index_3d(b, m, n, M, N);
+ result_data[o_idx] = sum;
+ }
+ }
+}
+
+static iree_status_t reference_attention_element(
+ iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2, iree_hal_dim_t N,
+ iree_hal_dim_t B, iree_hal_element_type_t query_elem_type,
+ iree_hal_element_type_t key_elem_type,
+ iree_hal_element_type_t value_elem_type, void* query_data, void* key_data,
+ void* value_data, void* actual_data, void* result_data, iree_hal_dim_t b,
+ float* Attention) {
+ if (query_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
+ key_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32 &&
+ value_elem_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
+ reference_attention_f32_f32_f32_f32(
+ M, K1, K2, N, B, (const float*)query_data, (const float*)key_data,
+ (const float*)value_data, (float*)result_data, b, Attention);
+
+ } else {
+ return iree_make_status(
+ IREE_STATUS_INVALID_ARGUMENT,
+ "unhandled combination of element types in attention");
+ }
+ return iree_ok_status();
+}
+
+// Reference attention implementation, used to compare attention results
+// against.
+static iree_status_t reference_attention(
+ iree_hal_dim_t B, iree_hal_dim_t M, iree_hal_dim_t K1, iree_hal_dim_t K2,
+ iree_hal_dim_t N, iree_hal_element_type_t query_elem_type,
+ iree_hal_element_type_t key_elem_type,
+ iree_hal_element_type_t value_elem_type, iree_byte_span_t query_contents,
+ iree_byte_span_t key_contents, iree_byte_span_t value_contents,
+ iree_byte_span_t actual_contents, iree_byte_span_t result_contents,
+ int compute_every) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, B);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, M);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K1);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, K2);
+ IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, N);
+
+ iree_host_size_t count = 0;
+ float* Attention = allocate_tensor(1, M, K2);
+ for (iree_hal_dim_t b = 0; b < B; ++b) {
+ if (++count < compute_every) continue;
+ count = 0;
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0,
+ reference_attention_element(
+ M, K1, K2, N, B, query_elem_type, key_elem_type, value_elem_type,
+ query_contents.data, key_contents.data, value_contents.data,
+ actual_contents.data, result_contents.data, b, Attention));
+ }
+ free_tensor(Attention);
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+//===----------------------------------------------------------------------===//
+// Attention comparison/logging
+//===----------------------------------------------------------------------===//
+
+typedef struct {
+ iree_allocator_t host_allocator;
+ iree_hal_dim_t b;
+ iree_hal_dim_t m;
+ iree_hal_dim_t k1;
+ iree_hal_dim_t k2;
+ iree_hal_dim_t n;
+ iree_hal_element_type_t query_elem_type;
+ iree_hal_element_type_t key_elem_type;
+ iree_hal_element_type_t value_elem_type;
+ iree_hal_element_type_t result_elem_type;
+ iree_byte_span_t query_contents;
+ iree_byte_span_t key_contents;
+ iree_byte_span_t value_contents;
+ iree_byte_span_t actual_contents;
+ iree_byte_span_t expected_contents;
+} attention_results_t;
+
+static void attention_results_deinitialize(attention_results_t* results);
+
+static iree_status_t attention_results_initialize(
+ iree_hal_device_t* device, iree_hal_dim_t b_size, iree_hal_dim_t m_size,
+ iree_hal_dim_t k1_size, iree_hal_dim_t k2_size, iree_hal_dim_t n_size,
+ iree_hal_buffer_view_t* query, iree_hal_buffer_view_t* key,
+ iree_hal_buffer_view_t* value, iree_hal_buffer_view_t* result,
+ iree_allocator_t host_allocator, attention_results_t* out_results) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ memset(out_results, 0, sizeof(*out_results));
+ out_results->host_allocator = host_allocator;
+
+ out_results->b = b_size;
+ out_results->m = m_size;
+ out_results->k1 = k1_size;
+ out_results->k2 = k2_size;
+ out_results->n = n_size;
+
+ out_results->query_elem_type = iree_hal_buffer_view_element_type(query);
+ out_results->key_elem_type = iree_hal_buffer_view_element_type(key);
+ out_results->value_elem_type = iree_hal_buffer_view_element_type(value);
+ out_results->result_elem_type = iree_hal_buffer_view_element_type(result);
+
+ iree_hal_buffer_t* query_buffer = iree_hal_buffer_view_buffer(query);
+ iree_hal_buffer_t* key_buffer = iree_hal_buffer_view_buffer(key);
+ iree_hal_buffer_t* value_buffer = iree_hal_buffer_view_buffer(value);
+ iree_hal_buffer_t* result_buffer = iree_hal_buffer_view_buffer(result);
+
+ iree_status_t status = iree_ok_status();
+
+ if (iree_status_is_ok(status)) {
+ out_results->query_contents.data_length =
+ iree_hal_buffer_byte_length(query_buffer);
+ status = iree_allocator_malloc(host_allocator,
+ out_results->query_contents.data_length,
+ (void**)&out_results->query_contents.data);
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_transfer_d2h(
+ device, query_buffer, 0, out_results->query_contents.data,
+ out_results->query_contents.data_length,
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
+ }
+ if (iree_status_is_ok(status)) {
+ out_results->key_contents.data_length =
+ iree_hal_buffer_byte_length(key_buffer);
+ status = iree_allocator_malloc(host_allocator,
+ out_results->key_contents.data_length,
+ (void**)&out_results->key_contents.data);
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_transfer_d2h(
+ device, key_buffer, 0, out_results->key_contents.data,
+ out_results->key_contents.data_length,
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
+ }
+ if (iree_status_is_ok(status)) {
+ out_results->value_contents.data_length =
+ iree_hal_buffer_byte_length(value_buffer);
+ status = iree_allocator_malloc(host_allocator,
+ out_results->value_contents.data_length,
+ (void**)&out_results->value_contents.data);
+ }
+
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_transfer_d2h(
+ device, value_buffer, 0, out_results->value_contents.data,
+ out_results->value_contents.data_length,
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
+ }
+ if (iree_status_is_ok(status)) {
+ out_results->actual_contents.data_length =
+ iree_hal_buffer_byte_length(result_buffer);
+ status = iree_allocator_malloc(host_allocator,
+ out_results->actual_contents.data_length,
+ (void**)&out_results->actual_contents.data);
+ }
+ if (iree_status_is_ok(status)) {
+ status = iree_hal_device_transfer_d2h(
+ device, result_buffer, 0, out_results->actual_contents.data,
+ out_results->actual_contents.data_length,
+ IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout());
+ }
+ if (iree_status_is_ok(status)) {
+ out_results->expected_contents.data_length =
+ iree_hal_buffer_byte_length(result_buffer);
+ status = iree_allocator_malloc(
+ host_allocator, out_results->expected_contents.data_length,
+ (void**)&out_results->expected_contents.data);
+ }
+ if (!iree_status_is_ok(status)) {
+ attention_results_deinitialize(out_results);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+static void attention_results_deinitialize(attention_results_t* results) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ iree_allocator_free(results->host_allocator, results->query_contents.data);
+ iree_allocator_free(results->host_allocator, results->key_contents.data);
+ iree_allocator_free(results->host_allocator, results->value_contents.data);
+ iree_allocator_free(results->host_allocator, results->actual_contents.data);
+ iree_allocator_free(results->host_allocator, results->expected_contents.data);
+
+ IREE_TRACE_ZONE_END(z0);
+}
+
+// Helper for check_attention_results: the actual interesting part once we've
+// obtained and validated the {b,m,k1,k2,n}_size values. On error, detailed
+// logging is written to |file| if it is not NULL.
+static iree_status_t check_attention_results_impl(
+ FILE* file, const attention_results_t* results, int check_every) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+
+ IREE_RETURN_AND_END_ZONE_IF_ERROR(
+ z0, reference_attention(results->b, results->m, results->k1, results->k2,
+ results->n, results->query_elem_type,
+ results->key_elem_type, results->value_elem_type,
+ results->query_contents, results->key_contents,
+ results->value_contents, results->actual_contents,
+ results->expected_contents, check_every));
+
+ IREE_TRACE_ZONE_END(z0);
+ return iree_ok_status();
+}
+
+// Given an actual attention's inputs and output (all host-local), uses a
+// reference attention implementation on the same inputs to check if the output
+// is correct. On error, detailed logging is written to |file| if it is not
+// NULL.
+static iree_status_t check_attention_results(
+ FILE* file, const attention_results_t* results) {
+ IREE_TRACE_ZONE_BEGIN(z0);
+ // TODO: Increase the check_every param to reduce the number of comparisons.
+ int check_every = 1;
+ iree_status_t status =
+ check_attention_results_impl(file, results, check_every);
+ if (!iree_status_is_ok(status) && check_every > 1) {
+ // If we got a failure with check_every>1, that didn't log a useful
+ // numerical summary, as most of the reference matrix entries hadn't been
+ // computed. Rerun now with check_every=1 to get that numerical logging.
+ iree_status_ignore(status);
+ status = check_attention_results_impl(file, results, 1);
+ }
+ IREE_TRACE_ZONE_END(z0);
+ return status;
+}
+
+//===----------------------------------------------------------------------===//
+// `attention_test` custom module
+//===----------------------------------------------------------------------===//
+// This uses the C++ wrapper to keep things simple. Though easier to use, it
+// has additional overhead/code-size bloat that doesn't matter in a test like
+// this. Making a C module builder API that removes the boilerplate is TBD, so
+// this file is written in C except for this module so that we can swap it
+// back to pure C in the future.
+
+namespace iree {
+
+class AttentionTestModuleState final {
+ public:
+ explicit AttentionTestModuleState(iree_allocator_t host_allocator)
+ : host_allocator_(host_allocator) {}
+ ~AttentionTestModuleState() = default;
+
+ // Fills the destination span with pseudorandom values of the given
+ // |element_type|. The given |seed| is passed to the pseudorandom generator.
+ // The pseudorandom values are reproducible both across runs and across
+ // machines.
+ StatusOr<vm::ref<iree_hal_buffer_view_t>> GenerateRandom3dTensor(
+ const vm::ref<iree_hal_device_t> device, int64_t dim0, int64_t dim1,
+ int64_t dim2, iree_hal_element_type_t element_type, int32_t seed) {
+ iree_hal_dim_t dims[3] = {
+ (iree_hal_dim_t)dim0,
+ (iree_hal_dim_t)dim1,
+ (iree_hal_dim_t)dim2,
+ };
+ iree_hal_buffer_params_t buffer_params = {0};
+ buffer_params.usage = IREE_HAL_BUFFER_USAGE_DEFAULT;
+ buffer_params.access = IREE_HAL_MEMORY_ACCESS_ALL;
+ buffer_params.type = IREE_HAL_MEMORY_TYPE_OPTIMAL_FOR_DEVICE;
+ vm::ref<iree_hal_buffer_view_t> result_view;
+ struct callback_state_t {
+ iree_hal_element_type_t element_type;
+ int32_t seed;
+ } callback_state = {
+ element_type,
+ seed,
+ };
+ IREE_RETURN_IF_ERROR(iree_hal_buffer_view_generate_buffer(
+ device.get(), iree_hal_device_allocator(device.get()),
+ IREE_ARRAYSIZE(dims), dims, element_type,
+ IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, buffer_params,
+ +[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
+ callback_state_t callback_state = *(callback_state_t*)user_data;
+ iree_byte_span_t span = mapping->contents;
+ // Generate "uniform" integer-valued numbers in the range [min, max].
+ int32_t min = 0;
+ int32_t max = 0;
+ iree_test_utils_get_min_max_for_element_type(
+ callback_state.element_type, &min, &max);
+ uint32_t range = (max - min + 1);
+ iree_host_size_t element_byte_count =
+ iree_hal_element_dense_byte_count(callback_state.element_type);
+ uint8_t* data_end = span.data + span.data_length;
+ uint32_t state = callback_state.seed;
+ for (uint8_t* data = span.data; data < data_end;
+ data += element_byte_count) {
+ int32_t value =
+ (int32_t)iree_test_utils_pseudorandom_range(&state, range) +
+ min;
+ iree_test_utils_write_element(callback_state.element_type, value,
+ data);
+ }
+ return iree_ok_status();
+ },
+ &callback_state, &result_view));
+ return std::move(result_view);
+ }
+
+ Status CheckAttentionResults(
+ const vm::ref<iree_hal_device_t> device, int64_t b, int64_t m, int64_t k1,
+ int64_t k2, int64_t n, const vm::ref<iree_hal_buffer_view_t> query,
+ const vm::ref<iree_hal_buffer_view_t> key,
+ const vm::ref<iree_hal_buffer_view_t> value,
+ const vm::ref<iree_hal_buffer_view_t> actual_result) {
+ attention_results_t results = {};
+ IREE_RETURN_IF_ERROR(attention_results_initialize(
+ device.get(), (iree_hal_dim_t)b, (iree_hal_dim_t)m, (iree_hal_dim_t)k1,
+ (iree_hal_dim_t)k2, (iree_hal_dim_t)n, query.get(), key.get(),
+ value.get(), actual_result.get(), host_allocator_, &results));
+ iree_status_t status = check_attention_results(stderr, &results);
+ attention_results_deinitialize(&results);
+ return status;
+ }
+
+ private:
+ iree_allocator_t host_allocator_;
+};
+
+static const vm::NativeFunction<AttentionTestModuleState>
+ kAttentionTestModuleFunctions[] = {
+ vm::MakeNativeFunction(
+ "generate_random_tensor",
+ &AttentionTestModuleState::GenerateRandom3dTensor),
+ vm::MakeNativeFunction(
+ "check_attention_results",
+ &AttentionTestModuleState::CheckAttentionResults),
+};
+
+struct AttentionTestModule final
+ : public vm::NativeModule<AttentionTestModuleState> {
+ using vm::NativeModule<AttentionTestModuleState>::NativeModule;
+ StatusOr<std::unique_ptr<AttentionTestModuleState>> CreateState(
+ iree_allocator_t host_allocator) override {
+ return std::make_unique<AttentionTestModuleState>(host_allocator);
+ }
+};
+
+} // namespace iree
+
+static iree_status_t attention_test_module_create(
+ iree_vm_instance_t* instance, iree_allocator_t host_allocator,
+ iree_vm_module_t** out_module) {
+ IREE_ASSERT_ARGUMENT(out_module);
+ *out_module = NULL;
+ auto module = std::make_unique<iree::AttentionTestModule>(
+ "attention_test", /*version=*/0, instance, host_allocator,
+ iree::span<
+ const iree::vm::NativeFunction<iree::AttentionTestModuleState>>(
+ iree::kAttentionTestModuleFunctions));
+ *out_module = module.release()->interface();
+ return iree_ok_status();
+}
+
+int main(int argc, char** argv) {
+ IREE_TRACE_APP_ENTER();
+
+ iree_flags_parse_checked(IREE_FLAGS_PARSE_MODE_DEFAULT, &argc, &argv);
+ if (argc != 1) {
+ fprintf(stderr, "use --module= flags to specify the modules to run\n");
+ IREE_TRACE_APP_EXIT(EXIT_FAILURE);
+ return EXIT_FAILURE;
+ }
+
+ iree_status_t status = iree_test_utils_load_and_run_e2e_tests(
+ iree_allocator_system(), attention_test_module_create);
+ int exit_code = EXIT_SUCCESS;
+ if (!iree_status_is_ok(status)) {
+ iree_status_fprint(stderr, status);
+ bool is_unavailable = iree_status_is_unavailable(status);
+ iree_status_free(status);
+ exit_code = is_unavailable ? EXIT_SUCCESS : EXIT_FAILURE;
+ }
+
+ IREE_TRACE_APP_EXIT(exit_code);
+ return exit_code;
+}