#!/usr/bin/env python3
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Generator for e2e attention tests."""
import argparse
import enum
import dataclasses
import typing
# Data type of query entries. The string values must match MLIR data types.
@enum.unique
class QueryElemTypeId(enum.Enum):
NONE = ""
F16 = "f16"
# Data type of key entries. The string values must match MLIR data types.
@enum.unique
class KeyElemTypeId(enum.Enum):
NONE = ""
F16 = "f16"
# Data type of value entries. The string values must match MLIR data types.
@enum.unique
class ValueElemTypeId(enum.Enum):
NONE = ""
F16 = "f16"
# Data type of result entries. The string values must match MLIR data types.
@enum.unique
class ResultElemTypeId(enum.Enum):
NONE = ""
F16 = "f16"
# Enumerates the collections of shapes that we can generate tests for.
# The values are the accepted values for the --shapes= flag.
@enum.unique
class ShapesId(enum.Enum):
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
DECODE_SMALL = "decode_small"
DECODE_MEDIUM = "decode_medium"
DECODE_LARGE = "decode_large"
PREFILL_SMALL = "prefill_small"
PREFILL_MEDIUM = "prefill_medium"
PREFILL_LARGE = "prefill_large"
# Enumerates the types of masks that can be applied to attention.
# The values are the accepted values for the --mask_type= flag.
@enum.unique
class MaskType(enum.Enum):
NONE = "none" # No mask
ALL_ONES = "all_ones" # All positions can attend (for decode)
CAUSAL = "causal" # Lower triangular mask (for prefill)
# batch: Batch dimension
# m: M dimension of the first and second matmuls
# n: N dimension of the second matmul
# k1: K dimension of the first matmul
# k2: K dimension of the second matmul
# scale: scaling factor applied to the first matmul's result before the softmax
@dataclasses.dataclass
class TestShapeAndScale:
batch: int
m: int
k1: int
k2: int
n: int
scale: float
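# For a given TestShapeAndScale, the attention operands are shaped as follows
# (see get_tensor_shapes below):
#   query:  [batch, m,  k1]
#   key:    [batch, k2, k1]
#   value:  [batch, k2, n]
#   result: [batch, m,  n]
#   mask:   [batch, m,  k2]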
# Returns the list of TestShapeAndScale values to use for the collection of
# shapes identified by shapes_id.
def get_test_shapes(shapes_id: ShapesId):
if shapes_id == ShapesId.SMALL:
return [
TestShapeAndScale(batch=2, m=256, k1=64, k2=32, n=16, scale=1.0),
]
if shapes_id == ShapesId.MEDIUM:
return [
TestShapeAndScale(batch=2, m=512, k1=128, k2=64, n=32, scale=1.0),
]
if shapes_id == ShapesId.LARGE:
return [
TestShapeAndScale(batch=2, m=1024, k1=128, k2=128, n=64, scale=1.0),
]
# Decode: m = 1 (single token attending to cached KV)
if shapes_id == ShapesId.DECODE_SMALL:
return [
TestShapeAndScale(batch=2, m=1, k1=128, k2=128, n=128, scale=1.0),
]
if shapes_id == ShapesId.DECODE_MEDIUM:
return [
TestShapeAndScale(batch=2, m=1, k1=128, k2=2048, n=128, scale=1.0),
]
if shapes_id == ShapesId.DECODE_LARGE:
return [
TestShapeAndScale(batch=2, m=1, k1=128, k2=16384, n=128, scale=1.0),
]
# Prefill: m = k2 (self-attention on full sequence)
if shapes_id == ShapesId.PREFILL_SMALL:
return [
TestShapeAndScale(batch=2, m=128, k1=128, k2=128, n=128, scale=1.0),
]
if shapes_id == ShapesId.PREFILL_MEDIUM:
return [
TestShapeAndScale(batch=2, m=2048, k1=128, k2=2048, n=128, scale=1.0),
]
if shapes_id == ShapesId.PREFILL_LARGE:
# Currently not used due to time-out.
return [
TestShapeAndScale(batch=2, m=16384, k1=128, k2=16384, n=128, scale=1.0),
]
raise ValueError(shapes_id)
# Describes the shapes of the input tensors used by a test function.
@dataclasses.dataclass
class TestInputTensorShapes:
batch: int
m: int
k1: int
k2: int
n: int
scale: float
# Helper for generate_function. Generates TestInputTensorShapes, i.e.
# converts the runtime shape dimensions in TestShapeAndScale to the set of
# shapes to be used in a test function's input tensors.
def generate_shapes_and_scale(shape: TestShapeAndScale):
    return TestInputTensorShapes(
        batch=shape.batch,
        m=shape.m,
        k1=shape.k1,
        k2=shape.k2,
        n=shape.n,
        scale=shape.scale,
    )
# Helper to return the query, key, value, result, and mask tensor shapes
# derived from the attention parameters.
def get_tensor_shapes(
shapes_scale: TestShapeAndScale,
):
batch = shapes_scale.batch
m = shapes_scale.m
k1 = shapes_scale.k1
k2 = shapes_scale.k2
n = shapes_scale.n
query_tensor_shape = [batch, m, k1]
key_tensor_shape = [batch, k2, k1]
value_tensor_shape = [batch, k2, n]
result_tensor_shape = [batch, m, n]
mask_tensor_shape = [batch, m, k2]
return (
query_tensor_shape,
key_tensor_shape,
value_tensor_shape,
result_tensor_shape,
mask_tensor_shape,
)
# Helper for generate_function.
# Generates a name for a test function in the generated MLIR code.
def generate_function_name(
query_type: QueryElemTypeId,
key_type: KeyElemTypeId,
value_type: ValueElemTypeId,
shapes_scale: TestInputTensorShapes,
):
query_t = query_type.value
key_t = key_type.value
value_t = value_type.value
result_t = value_type.value
batch = shapes_scale.batch
m = shapes_scale.m
k1 = shapes_scale.k1
k2 = shapes_scale.k2
n = shapes_scale.n
attention = "attention"
return (
f"{attention}_{batch}_{m}_{k1}_{k2}_{n}"
+ f"_dtype_{query_t}_{key_t}_{value_t}_{result_t}"
)
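# For example, generate_function_name on the SMALL shape with all-f16 operands
# yields "attention_2_256_64_32_16_dtype_f16_f16_f16_f16" (generate_function
# appends "_masked" when a mask is used).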
# Represents a generated test function.
@dataclasses.dataclass
class MLIRFunction:
name: str
signature: str
import_declaration: str
definition: str
# Generates a test function in the generated MLIR code.
# The generated function takes the same arguments as iree_linalg_ext.attention
# and simply calls it, returning its result.
def generate_function(
query_type: QueryElemTypeId,
key_type: KeyElemTypeId,
value_type: ValueElemTypeId,
shape_scale: TestShapeAndScale,
mask_type: MaskType,
):
shapes_scale = generate_shapes_and_scale(shape_scale)
func_name = generate_function_name(
query_type,
key_type,
value_type,
shapes_scale,
)
use_mask = mask_type != MaskType.NONE
if use_mask:
func_name = func_name + "_masked"
query_shape, key_shape, value_shape, result_shape, mask_shape = get_tensor_shapes(
shapes_scale
)
query_tensor_type = (
f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>"
)
key_tensor_type = (
f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>"
)
value_tensor_type = (
f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>"
)
result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>"
mask_tensor_type = f"tensor<{mask_shape[0]}x{mask_shape[1]}x{mask_shape[2]}xi1>"
F32 = "f32"
F16 = "f16"
op_name = "iree_linalg_ext.attention"
    func_definition = ""
if use_mask:
# Function with mask parameter.
signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {mask_tensor_type}, {result_tensor_type}) -> {result_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %mask: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %mask: {mask_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n"
f" %result0 = tensor.empty(): {result_tensor_type}\n"
f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n"
f" %result1 = {op_name} {{\n"
f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> ()>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, m, k2)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}"
f" ins(%query, %key, %value, %scale_f16, %mask: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16}, {mask_tensor_type})\n"
f" outs(%result0: {result_tensor_type}) {{\n"
f" ^bb0(%score: f32): \n"
f" iree_linalg_ext.yield %score : f32\n"
f" }} -> {result_tensor_type}\n"
f" return %result1: {result_tensor_type}\n"
f"}}\n"
)
else:
# Function without mask.
signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}"
import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view"
func_definition = func_definition + (
f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n"
f" %result0 = tensor.empty(): {result_tensor_type}\n"
f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n"
f" %result1 = {op_name} {{\n"
f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n"
f" affine_map<(batch, m, n, k1, k2) -> ()>,\n"
f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}"
f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n"
f" outs(%result0: {result_tensor_type}) {{\n"
f" ^bb0(%score: f32): \n"
f" iree_linalg_ext.yield %score : f32\n"
f" }} -> {result_tensor_type}\n"
f" return %result1: {result_tensor_type}\n"
f"}}\n"
)
return MLIRFunction(
name=func_name,
signature=signature,
import_declaration=import_declaration,
definition=func_definition,
)
# Represents a call to a generated test function.
@dataclasses.dataclass
class TestCall:
function: MLIRFunction
op: str
# Enumerates ways to initialize tensor buffer contents.
@enum.unique
class TensorGenerator(enum.Enum):
ZERO = "zero" # Fill with zeros
RANDOM = "random" # Fill with (deterministic) pseudorandom values.
# Intentionally fixed seed! We want full reproducibility here, both across runs
# and across machines.
# Intentionally not shared with any other pseudorandom state, to limit the ways
# in which shuffling test cases changes which random values are generated.
pseudorandom_generator_seed = 1
def contents_generator_tag(generator: TensorGenerator):
if generator == TensorGenerator.ZERO:
return ""
elif generator == TensorGenerator.RANDOM:
global pseudorandom_generator_seed
pseudorandom_generator_seed = pseudorandom_generator_seed + 1
return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}"
else:
raise ValueError(generator)
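# For example, with the initial seed above, the first RANDOM call returns
# "!tag:iree:fully_specified_pseudorandom 2".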
# Generates MLIR that defines a pseudorandom 3-D tensor `%name` of the given
# shape and element type.
def generate_random_3d_tensor(
name: str,
tensor_shape: list,
    element_type: typing.Union[
        QueryElemTypeId, KeyElemTypeId, ValueElemTypeId, ResultElemTypeId
    ],
):
global pseudorandom_generator_seed
pseudorandom_generator_seed = pseudorandom_generator_seed + 1
return (
f" %{name}_dim0 = arith.constant {tensor_shape[0]} : i64\n"
f" %{name}_dim1 = arith.constant {tensor_shape[1]} : i64\n"
f" %{name}_dim2 = arith.constant {tensor_shape[2]} : i64\n"
f" %{name}_element_type = hal.element_type<{element_type.value}> : i32\n"
f" %{name}_seed = arith.constant {pseudorandom_generator_seed} : i32\n"
f" %{name} = call @attention_test.generate_random_tensor(%device, %{name}_dim0, %{name}_dim1, %{name}_dim2, %{name}_element_type, %{name}_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view\n"
)
# Counter suffixed to generated test case names to keep them unique.
call_id = 0
# Generate causal mask as i8 tensor (0 or 1 values) for shape [batch, m, k2]
# Causal pattern: mask[b, i, j] = 1 if j <= i, else 0
def generate_causal_mask_values(batch: int, m: int, k2: int) -> list:
mask = []
for b in range(batch):
for i in range(m):
for j in range(k2):
mask.append(1 if j <= i else 0)
return mask
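# For example, generate_causal_mask_values(1, 3, 3) returns, row-major:
#   [1, 0, 0,
#    1, 1, 0,
#    1, 1, 1]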
# Generate all-ones mask as i8 tensor for shape [batch, m, k2]
# All positions can attend to all positions.
def generate_all_ones_mask_values(batch: int, m: int, k2: int) -> list:
return [1] * (batch * m * k2)
def generate_call(
function: MLIRFunction,
query_type: QueryElemTypeId,
key_type: KeyElemTypeId,
value_type: ValueElemTypeId,
shapes_scale: TestShapeAndScale,
mask_type: MaskType,
):
global call_id
func_name = f"{function.name}_{shapes_scale.batch}_{shapes_scale.m}_{shapes_scale.k1}_{shapes_scale.k2}_{shapes_scale.n}_{shapes_scale.k1}_{shapes_scale.scale}"
func_name = f"{func_name}_{call_id}"
call_id = call_id + 1
use_mask = mask_type != MaskType.NONE
description = f"Attention shape (BATCHxMxK1xK2xN): {shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}x{shapes_scale.k2}x{shapes_scale.k1}x{shapes_scale.n}"
if mask_type == MaskType.ALL_ONES:
description = description + " (all-ones masked)"
elif mask_type == MaskType.CAUSAL:
description = description + " (causal masked)"
op = (
f"func.func @{func_name}() attributes {{\n"
f' iree.reflection = {{description = "{description}"}}\n'
"} {\n"
" %device_index = arith.constant 0 : index\n"
" %device = hal.devices.get %device_index : !hal.device\n"
)
query_shape, key_shape, value_shape, result_shape, mask_shape = get_tensor_shapes(
shapes_scale,
)
op = op + generate_random_3d_tensor("query", query_shape, query_type)
op = op + generate_random_3d_tensor("key", key_shape, key_type)
op = op + generate_random_3d_tensor("value", value_shape, value_type)
if use_mask:
batch, m, k2 = mask_shape
mask_size = batch * m * k2
# Generate mask as i8, then convert to i1 for attention op.
        if mask_type == MaskType.ALL_ONES:
            mask_values = generate_all_ones_mask_values(batch, m, k2)
        elif mask_type == MaskType.CAUSAL:
            mask_values = generate_causal_mask_values(batch, m, k2)
        else:
            raise ValueError(mask_type)
mask_values_str = ", ".join(str(v) for v in mask_values)
op = op + (
f" %mask_i8 = arith.constant dense<[{mask_values_str}]> : tensor<{mask_size}xi8>\n"
f" %mask_i8_reshaped = tensor.expand_shape %mask_i8 [[0, 1, 2]] output_shape [{batch}, {m}, {k2}] : tensor<{mask_size}xi8> into tensor<{batch}x{m}x{k2}xi8>\n"
f" %c0_i8 = arith.constant 0 : i8\n"
f" %zeros_i8 = tensor.empty() : tensor<{batch}x{m}x{k2}xi8>\n"
f" %zeros_i8_filled = linalg.fill ins(%c0_i8 : i8) outs(%zeros_i8 : tensor<{batch}x{m}x{k2}xi8>) -> tensor<{batch}x{m}x{k2}xi8>\n"
f" %mask_i1 = arith.cmpi ne, %mask_i8_reshaped, %zeros_i8_filled : tensor<{batch}x{m}x{k2}xi8>\n"
f" %mask = hal.tensor.export %mask_i1 : tensor<{batch}x{m}x{k2}xi1> -> !hal.buffer_view\n"
)
global pseudorandom_generator_seed
pseudorandom_generator_seed = pseudorandom_generator_seed - 1
if use_mask:
op = op + (
f" %scale = arith.constant {shapes_scale.scale} : f32\n"
f" %result = call @module.{function.name}(%query, %key, %value, %mask, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n"
)
else:
op = op + (
f" %scale = arith.constant {shapes_scale.scale} : f32\n"
f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n"
)
op = op + (
f" %batch = arith.constant {shapes_scale.batch} : i64 \n"
f" %m = arith.constant {shapes_scale.m} : i64 \n"
f" %k1 = arith.constant {shapes_scale.k1} : i64 \n"
f" %k2 = arith.constant {shapes_scale.k2} : i64 \n"
f" %n = arith.constant {shapes_scale.n} : i64 \n"
f" %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> \n"
f" %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> \n"
f" %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> \n"
f" %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> \n"
f" %queryExt = arith.extf %queryTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> \n"
f" %keyExt = arith.extf %keyTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> \n"
f" %valueExt = arith.extf %valueTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> \n"
f" %resultExt = arith.extf %resultTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> \n"
f" %queryExtBufferView = hal.tensor.export %queryExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n"
f" %keyExtBufferView = hal.tensor.export %keyExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n"
f" %valueExtBufferView = hal.tensor.export %valueExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> -> !hal.buffer_view \n"
f" %resultExtBufferView = hal.tensor.export %resultExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> -> !hal.buffer_view \n"
)
if use_mask:
# Export mask as i8 for the reference implementation.
op = op + (
f" %mask_i8_export = hal.tensor.export %mask_i8_reshaped : tensor<{batch}x{m}x{k2}xi8> -> !hal.buffer_view\n"
f" call @attention_test.check_attention_results_with_mask(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %mask_i8_export, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
)
else:
op = op + (
f" call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n"
)
op = op + " return\n"
op = op + "}\n"
return TestCall(function=function, op=op)
# Generates all output files' contents as strings.
def generate(
query_type: QueryElemTypeId,
key_type: KeyElemTypeId,
value_type: ValueElemTypeId,
shapes_id: ShapesId,
mask_type: MaskType,
):
functions = {}
calls = []
for shape in get_test_shapes(shapes_id):
function = generate_function(
query_type,
key_type,
value_type,
shape,
mask_type,
)
if function.name not in functions:
functions[function.name] = function
calls.append(
generate_call(
function,
query_type,
key_type,
value_type,
shape,
mask_type,
)
)
return (functions, calls)
def parse_arguments():
parser = argparse.ArgumentParser(description="Generator of e2e Attention tests")
parser.add_argument(
"--output_attention_mlir",
type=str,
help="Path of output .mlir file containing the generated Attention functions",
required=True,
)
parser.add_argument(
"--output_calls_mlir",
type=str,
help="Path of output .mlir file containing the calls",
required=True,
)
parser.add_argument(
"--query_type",
type=str,
choices=["f16"],
help="Numeric type of query tensors ",
required=True,
)
parser.add_argument(
"--key_type",
type=str,
choices=["f16"],
help="Numeric type of key tensors ",
required=True,
)
parser.add_argument(
"--value_type",
type=str,
choices=["f16"],
help="Numeric type of value tensors ",
required=True,
)
parser.add_argument(
"--shapes",
type=str,
choices=[s.value for s in ShapesId],
help="Collection of tensor shapes to test",
required=True,
)
parser.add_argument(
"--requirements",
type=str,
help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.",
required=False,
)
parser.add_argument(
"--mask_type",
type=str,
choices=[m.value for m in MaskType],
default="none",
help="Type of attention mask to generate: none, all_ones (for decode), or causal (for prefill)",
)
return parser.parse_args()
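# Example invocation (the script filename here is illustrative):
#   python generate_e2e_attention_tests.py \
#     --output_attention_mlir=attention.mlir \
#     --output_calls_mlir=attention_calls.mlir \
#     --query_type=f16 --key_type=f16 --value_type=f16 \
#     --shapes=small --mask_type=causal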
def write_code_file(functions, filename):
with open(filename, "w") as file:
for function in functions.values():
file.write(function.definition + "\n")
def write_calls_file(functions, calls, filename, requirements):
# Module-level reflection information used to control the test tool.
reflection = ""
if requirements:
reflection = (
"iree.reflection = {"
'target_features = "'
+ ",".join([req.lstrip("+") for req in requirements.split(",")])
+ '"'
"}"
)
module_definition = (
f"builtin.module @calls attributes {{\n" f" {reflection}\n" f"}} {{\n\n"
)
    # Declare the external functions provided by the custom attention_test module.
module_definition = module_definition + (
"func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n"
"func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n"
"func.func private @attention_test.check_attention_results_with_mask(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %mask: !hal.buffer_view, %result: !hal.buffer_view)\n"
"\n"
)
# Declare the functions that will be called.
for function in functions.values():
module_definition = module_definition + function.import_declaration + "\n"
module_definition = module_definition + "\n"
# Emit the test cases for each call.
for call in calls:
module_definition = module_definition + call.op + "\n"
module_definition = module_definition + "\n}\n"
with open(filename, "w") as file:
file.write(module_definition)
def main(args):
query_type = QueryElemTypeId(args.query_type)
key_type = KeyElemTypeId(args.key_type)
value_type = ValueElemTypeId(args.value_type)
shapes_id = ShapesId(args.shapes)
mask_type = MaskType(args.mask_type)
(functions, calls) = generate(
query_type,
key_type,
value_type,
shapes_id,
mask_type,
)
write_code_file(functions, args.output_attention_mlir)
write_calls_file(
functions,
calls,
args.output_calls_mlir,
args.requirements,
)
if __name__ == "__main__":
main(parse_arguments())