| #!/usr/bin/env python3 |
| # Copyright 2024 The IREE Authors |
| # |
| # Licensed under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| """Generator for e2e attention tests.""" |
| |
| import argparse |
| import enum |
| import dataclasses |
| import typing |
| |
| |
# Data type of query entries. The string values must match MLIR data types.
| @enum.unique |
| class QueryElemTypeId(enum.Enum): |
| NONE = "" |
| F16 = "f16" |
| |
| |
# Data type of key entries. The string values must match MLIR data types.
| @enum.unique |
| class KeyElemTypeId(enum.Enum): |
| NONE = "" |
| F16 = "f16" |
| |
| |
# Data type of value entries. The string values must match MLIR data types.
| @enum.unique |
| class ValueElemTypeId(enum.Enum): |
| NONE = "" |
| F16 = "f16" |
| |
| |
# Data type of result entries. The string values must match MLIR data types.
| @enum.unique |
| class ResultElemTypeId(enum.Enum): |
| NONE = "" |
| F16 = "f16" |
| |
| |
# Enumerates the collections of shapes that we can generate tests for.
| # The values are the accepted values for the --shapes= flag. |
| @enum.unique |
| class ShapesId(enum.Enum): |
| SMALL = "small" |
| MEDIUM = "medium" |
| LARGE = "large" |
| DECODE_SMALL = "decode_small" |
| DECODE_MEDIUM = "decode_medium" |
| DECODE_LARGE = "decode_large" |
| PREFILL_SMALL = "prefill_small" |
| PREFILL_MEDIUM = "prefill_medium" |
| PREFILL_LARGE = "prefill_large" |
| |
| |
| # Enumerates the types of masks that can be applied to attention. |
| # The values are the accepted values for the --mask_type= flag. |
| @enum.unique |
| class MaskType(enum.Enum): |
| NONE = "none" # No mask |
| ALL_ONES = "all_ones" # All positions can attend (for decode) |
| CAUSAL = "causal" # Lower triangular mask (for prefill) |
| |
| |
# batch: Batch dimension.
# m: M dimension of the first and second matmuls.
# k1: K dimension of the first matmul.
# k2: N dimension of the first matmul and K dimension of the second matmul.
# n: N dimension of the second matmul.
# scale: Scale factor passed to the attention op.
| @dataclasses.dataclass |
| class TestShapeAndScale: |
| batch: int |
| m: int |
| k1: int |
| k2: int |
| n: int |
| scale: float |
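
# For example, ShapesId.SMALL below uses
# TestShapeAndScale(batch=2, m=256, k1=64, k2=32, n=16, scale=1.0).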
| |
| |
# Returns the list of TestShapeAndScale values to use for the collection of
# shapes identified by shapes_id.
| def get_test_shapes(shapes_id: ShapesId): |
| if shapes_id == ShapesId.SMALL: |
| return [ |
| TestShapeAndScale(batch=2, m=256, k1=64, k2=32, n=16, scale=1.0), |
| ] |
| if shapes_id == ShapesId.MEDIUM: |
| return [ |
| TestShapeAndScale(batch=2, m=512, k1=128, k2=64, n=32, scale=1.0), |
| ] |
| if shapes_id == ShapesId.LARGE: |
| return [ |
| TestShapeAndScale(batch=2, m=1024, k1=128, k2=128, n=64, scale=1.0), |
| ] |
| # Decode: m = 1 (single token attending to cached KV) |
| if shapes_id == ShapesId.DECODE_SMALL: |
| return [ |
| TestShapeAndScale(batch=2, m=1, k1=128, k2=128, n=128, scale=1.0), |
| ] |
| if shapes_id == ShapesId.DECODE_MEDIUM: |
| return [ |
| TestShapeAndScale(batch=2, m=1, k1=128, k2=2048, n=128, scale=1.0), |
| ] |
| if shapes_id == ShapesId.DECODE_LARGE: |
| return [ |
| TestShapeAndScale(batch=2, m=1, k1=128, k2=16384, n=128, scale=1.0), |
| ] |
| # Prefill: m = k2 (self-attention on full sequence) |
| if shapes_id == ShapesId.PREFILL_SMALL: |
| return [ |
| TestShapeAndScale(batch=2, m=128, k1=128, k2=128, n=128, scale=1.0), |
| ] |
| if shapes_id == ShapesId.PREFILL_MEDIUM: |
| return [ |
| TestShapeAndScale(batch=2, m=2048, k1=128, k2=2048, n=128, scale=1.0), |
| ] |
| if shapes_id == ShapesId.PREFILL_LARGE: |
        # Currently not used because it times out.
| return [ |
| TestShapeAndScale(batch=2, m=16384, k1=128, k2=16384, n=128, scale=1.0), |
| ] |
| |
| raise ValueError(shapes_id) |
| |
| |
| # Determines the shape of input and kernel tensors. |
| @dataclasses.dataclass |
| class TestInputTensorShapes: |
| batch: int |
| m: int |
| k1: int |
| k2: int |
| n: int |
| scale: float |
| |
| |
# Helper for generate_function. Converts the runtime shape dimensions and scale
# in a TestShapeAndScale into the TestInputTensorShapes used for a test
# function's input tensors.
| def generate_shapes_and_scale(shape: TestShapeAndScale): |
| batch = shape.batch |
| m = shape.m |
| k1 = shape.k1 |
| k2 = shape.k2 |
| n = shape.n |
| scale = shape.scale |
| |
| shapes_scale = TestInputTensorShapes( |
| batch=batch, |
| m=m, |
| k1=k1, |
| k2=k2, |
| n=n, |
| scale=scale, |
| ) |
| return shapes_scale |
| |
| |
# Helper returning the query, key, value, result, and mask tensor shapes for
# the given attention parameters.
| def get_tensor_shapes( |
| shapes_scale: TestShapeAndScale, |
| ): |
| batch = shapes_scale.batch |
| m = shapes_scale.m |
| k1 = shapes_scale.k1 |
| k2 = shapes_scale.k2 |
| n = shapes_scale.n |
| |
| query_tensor_shape = [batch, m, k1] |
| key_tensor_shape = [batch, k2, k1] |
| value_tensor_shape = [batch, k2, n] |
| result_tensor_shape = [batch, m, n] |
| mask_tensor_shape = [batch, m, k2] |
| |
| return ( |
| query_tensor_shape, |
| key_tensor_shape, |
| value_tensor_shape, |
| result_tensor_shape, |
| mask_tensor_shape, |
| ) |
| |
| |
| # Helper for generate_function. |
| # Generates a name for a test function in the generated MLIR code. |
| def generate_function_name( |
| query_type: QueryElemTypeId, |
| key_type: KeyElemTypeId, |
| value_type: ValueElemTypeId, |
| shapes_scale: TestInputTensorShapes, |
| ): |
| query_t = query_type.value |
| key_t = key_type.value |
| value_t = value_type.value |
| result_t = value_type.value |
| |
| batch = shapes_scale.batch |
| m = shapes_scale.m |
| k1 = shapes_scale.k1 |
| k2 = shapes_scale.k2 |
| n = shapes_scale.n |
| |
| attention = "attention" |
| return ( |
| f"{attention}_{batch}_{m}_{k1}_{k2}_{n}" |
| + f"_dtype_{query_t}_{key_t}_{value_t}_{result_t}" |
| ) |
| |
| |
| # Represents a generated test function. |
| @dataclasses.dataclass |
| class MLIRFunction: |
| name: str |
| signature: str |
| import_declaration: str |
| definition: str |
| |
| |
# Generates a test function in the generated MLIR code.
# The generated function takes the same operands as iree_linalg_ext.attention
# (query, key, value, an optional mask, and a scale) and simply calls the op
# with them, returning its result.
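# The emitted op uses six indexing maps in the masked case (query, key, value,
# scale, mask, output) and five in the unmasked case (no mask map), matching
# the operand order of ins()/outs() below.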
| def generate_function( |
| query_type: QueryElemTypeId, |
| key_type: KeyElemTypeId, |
| value_type: ValueElemTypeId, |
| shape_scale: TestShapeAndScale, |
| mask_type: MaskType, |
| ): |
| shapes_scale = generate_shapes_and_scale(shape_scale) |
| func_name = generate_function_name( |
| query_type, |
| key_type, |
| value_type, |
| shapes_scale, |
| ) |
| use_mask = mask_type != MaskType.NONE |
| if use_mask: |
| func_name = func_name + "_masked" |
| |
| query_shape, key_shape, value_shape, result_shape, mask_shape = get_tensor_shapes( |
| shapes_scale |
| ) |
| query_tensor_type = ( |
| f"tensor<{query_shape[0]}x{query_shape[1]}x{query_shape[2]}x{query_type.value}>" |
| ) |
| key_tensor_type = ( |
| f"tensor<{key_shape[0]}x{key_shape[1]}x{key_shape[2]}x{key_type.value}>" |
| ) |
| value_tensor_type = ( |
| f"tensor<{value_shape[0]}x{value_shape[1]}x{value_shape[2]}x{value_type.value}>" |
| ) |
| result_tensor_type = f"tensor<{result_shape[0]}x{result_shape[1]}x{result_shape[2]}x{value_type.value}>" |
| mask_tensor_type = f"tensor<{mask_shape[0]}x{mask_shape[1]}x{mask_shape[2]}xi1>" |
| F32 = "f32" |
| F16 = "f16" |
| op_name = "iree_linalg_ext.attention" |
| |
    # Build the function definition incrementally, starting from an empty string.
| func_definition = "" |
| |
| if use_mask: |
| # Function with mask parameter. |
| signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {mask_tensor_type}, {result_tensor_type}) -> {result_tensor_type}" |
| import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %mask: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view" |
| func_definition = func_definition + ( |
| f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %mask: {mask_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n" |
| f" %result0 = tensor.empty(): {result_tensor_type}\n" |
| f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n" |
| f" %result1 = {op_name} {{\n" |
| f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> ()>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, m, k2)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}" |
| f" ins(%query, %key, %value, %scale_f16, %mask: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16}, {mask_tensor_type})\n" |
| f" outs(%result0: {result_tensor_type}) {{\n" |
| f" ^bb0(%score: f32): \n" |
| f" iree_linalg_ext.yield %score : f32\n" |
| f" }} -> {result_tensor_type}\n" |
| f" return %result1: {result_tensor_type}\n" |
| f"}}\n" |
| ) |
| else: |
| # Function without mask. |
| signature = f"({query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {result_tensor_type}) -> {result_tensor_type}" |
| import_declaration = f"func.func private @module.{func_name}(%query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %scale: {F32}) -> !hal.buffer_view" |
| func_definition = func_definition + ( |
| f"func.func @{func_name}(%query: {query_tensor_type}, %key: {key_tensor_type}, %value: {value_tensor_type}, %scale: {F32}) -> {result_tensor_type} {{\n" |
| f" %result0 = tensor.empty(): {result_tensor_type}\n" |
| f" %scale_f16 = arith.truncf %scale : {F32} to {F16} \n" |
| f" %result1 = {op_name} {{\n" |
| f" indexing_maps = [affine_map<(batch, m, n, k1, k2) -> (batch, m, k1)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, k1)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, k2, n)>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> ()>,\n" |
| f" affine_map<(batch, m, n, k1, k2) -> (batch, m, n)>]\n}}" |
| f" ins(%query, %key, %value, %scale_f16: {query_tensor_type}, {key_tensor_type}, {value_tensor_type}, {F16})\n" |
| f" outs(%result0: {result_tensor_type}) {{\n" |
| f" ^bb0(%score: f32): \n" |
| f" iree_linalg_ext.yield %score : f32\n" |
| f" }} -> {result_tensor_type}\n" |
| f" return %result1: {result_tensor_type}\n" |
| f"}}\n" |
| ) |
| return MLIRFunction( |
| name=func_name, |
| signature=signature, |
| import_declaration=import_declaration, |
| definition=func_definition, |
| ) |
| |
| |
| # Represents a call to a generated test function. |
| @dataclasses.dataclass |
| class TestCall: |
| function: MLIRFunction |
| op: str |
| |
| |
| # Enumerates ways to initialize tensor buffer contents. |
| @enum.unique |
| class TensorGenerator(enum.Enum): |
| ZERO = "zero" # Fill with zeros |
| RANDOM = "random" # Fill with (deterministic) pseudorandom values. |
| |
| |
# Intentionally fixed seed! We want full reproducibility here, both across runs
# and across machines. The seed is incremented for each generated tensor, so
# the values are deterministic but distinct per tensor.
| pseudorandom_generator_seed = 1 |
| |
| |
| def contents_generator_tag(generator: TensorGenerator): |
| if generator == TensorGenerator.ZERO: |
| return "" |
| elif generator == TensorGenerator.RANDOM: |
| global pseudorandom_generator_seed |
| pseudorandom_generator_seed = pseudorandom_generator_seed + 1 |
| return f"!tag:iree:fully_specified_pseudorandom {pseudorandom_generator_seed}" |
| else: |
| raise ValueError(generator) |
| |
| |
# Generates MLIR that defines `%name` as a random 3D tensor of the given shape
# and element type, produced by the attention_test custom module.
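# For example, generate_random_3d_tensor("query", [2, 256, 64], QueryElemTypeId.F16)
# emits constants for the dimensions, element type, and seed, followed by a call
# to @attention_test.generate_random_tensor that yields %query.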
| def generate_random_3d_tensor( |
| name: str, |
| tensor_shape: list, |
| element_type: typing.Union[QueryElemTypeId, ResultElemTypeId], |
| ): |
| global pseudorandom_generator_seed |
| pseudorandom_generator_seed = pseudorandom_generator_seed + 1 |
| return ( |
| f" %{name}_dim0 = arith.constant {tensor_shape[0]} : i64\n" |
| f" %{name}_dim1 = arith.constant {tensor_shape[1]} : i64\n" |
| f" %{name}_dim2 = arith.constant {tensor_shape[2]} : i64\n" |
| f" %{name}_element_type = hal.element_type<{element_type.value}> : i32\n" |
| f" %{name}_seed = arith.constant {pseudorandom_generator_seed} : i32\n" |
| f" %{name} = call @attention_test.generate_random_tensor(%device, %{name}_dim0, %{name}_dim1, %{name}_dim2, %{name}_element_type, %{name}_seed) : (!hal.device, i64, i64, i64, i32, i32) -> !hal.buffer_view\n" |
| ) |
| |
| |
| call_id = 0 |
| |
| |
# Generates causal mask values (0 or 1) for shape [batch, m, k2], flattened in
# row-major order for embedding in an i8 dense constant.
# Causal pattern: mask[b, i, j] = 1 if j <= i, else 0.
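# For example, generate_causal_mask_values(batch=1, m=2, k2=3) returns
# [1, 0, 0, 1, 1, 0]: row 0 may attend only to position 0, row 1 to positions
# 0 and 1.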
| def generate_causal_mask_values(batch: int, m: int, k2: int) -> list: |
| mask = [] |
| for b in range(batch): |
| for i in range(m): |
| for j in range(k2): |
| mask.append(1 if j <= i else 0) |
| return mask |
| |
| |
# Generates all-ones mask values for shape [batch, m, k2], flattened in
# row-major order. Every query position can attend to every key position.
| def generate_all_ones_mask_values(batch: int, m: int, k2: int) -> list: |
| return [1] * (batch * m * k2) |
| |
| |
| def generate_call( |
| function: MLIRFunction, |
| query_type: QueryElemTypeId, |
| key_type: KeyElemTypeId, |
| value_type: ValueElemTypeId, |
| shapes_scale: TestShapeAndScale, |
| mask_type: MaskType, |
| ): |
| global call_id |
    func_name = f"{function.name}_{shapes_scale.batch}_{shapes_scale.m}_{shapes_scale.k1}_{shapes_scale.k2}_{shapes_scale.n}_{shapes_scale.scale}"
| func_name = f"{func_name}_{call_id}" |
| call_id = call_id + 1 |
| |
| use_mask = mask_type != MaskType.NONE |
    description = f"Attention shape (BATCHxMxK1xK2xN): {shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}x{shapes_scale.k2}x{shapes_scale.n}"
| if mask_type == MaskType.ALL_ONES: |
| description = description + " (all-ones masked)" |
| elif mask_type == MaskType.CAUSAL: |
| description = description + " (causal masked)" |
| op = ( |
| f"func.func @{func_name}() attributes {{\n" |
| f' iree.reflection = {{description = "{description}"}}\n' |
| "} {\n" |
| " %device_index = arith.constant 0 : index\n" |
| " %device = hal.devices.get %device_index : !hal.device\n" |
| ) |
| |
| query_shape, key_shape, value_shape, result_shape, mask_shape = get_tensor_shapes( |
| shapes_scale, |
| ) |
| |
| op = op + generate_random_3d_tensor("query", query_shape, query_type) |
| op = op + generate_random_3d_tensor("key", key_shape, key_type) |
| op = op + generate_random_3d_tensor("value", value_shape, value_type) |
| |
| if use_mask: |
| batch, m, k2 = mask_shape |
| mask_size = batch * m * k2 |
| # Generate mask as i8, then convert to i1 for attention op. |
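        # The flat i8 constant is reshaped to [batch, m, k2] and compared against a
        # zero-filled tensor (arith.cmpi ne) to obtain the i1 mask tensor, which is
        # then exported as a buffer view for the call below.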
| if mask_type == MaskType.ALL_ONES: |
| mask_values = generate_all_ones_mask_values(batch, m, k2) |
| elif mask_type == MaskType.CAUSAL: |
| mask_values = generate_causal_mask_values(batch, m, k2) |
| mask_values_str = ", ".join(str(v) for v in mask_values) |
| op = op + ( |
| f" %mask_i8 = arith.constant dense<[{mask_values_str}]> : tensor<{mask_size}xi8>\n" |
| f" %mask_i8_reshaped = tensor.expand_shape %mask_i8 [[0, 1, 2]] output_shape [{batch}, {m}, {k2}] : tensor<{mask_size}xi8> into tensor<{batch}x{m}x{k2}xi8>\n" |
| f" %c0_i8 = arith.constant 0 : i8\n" |
| f" %zeros_i8 = tensor.empty() : tensor<{batch}x{m}x{k2}xi8>\n" |
| f" %zeros_i8_filled = linalg.fill ins(%c0_i8 : i8) outs(%zeros_i8 : tensor<{batch}x{m}x{k2}xi8>) -> tensor<{batch}x{m}x{k2}xi8>\n" |
| f" %mask_i1 = arith.cmpi ne, %mask_i8_reshaped, %zeros_i8_filled : tensor<{batch}x{m}x{k2}xi8>\n" |
| f" %mask = hal.tensor.export %mask_i1 : tensor<{batch}x{m}x{k2}xi1> -> !hal.buffer_view\n" |
| ) |
| |
| global pseudorandom_generator_seed |
| pseudorandom_generator_seed = pseudorandom_generator_seed - 1 |
| |
| if use_mask: |
| op = op + ( |
| f" %scale = arith.constant {shapes_scale.scale} : f32\n" |
| f" %result = call @module.{function.name}(%query, %key, %value, %mask, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n" |
| ) |
| else: |
| op = op + ( |
| f" %scale = arith.constant {shapes_scale.scale} : f32\n" |
| f" %result = call @module.{function.name}(%query, %key, %value, %scale) : (!hal.buffer_view, !hal.buffer_view, !hal.buffer_view, f32) -> !hal.buffer_view\n" |
| ) |
| |
| op = op + ( |
| f" %batch = arith.constant {shapes_scale.batch} : i64 \n" |
| f" %m = arith.constant {shapes_scale.m} : i64 \n" |
| f" %k1 = arith.constant {shapes_scale.k1} : i64 \n" |
| f" %k2 = arith.constant {shapes_scale.k2} : i64 \n" |
| f" %n = arith.constant {shapes_scale.n} : i64 \n" |
| f" %queryTensor = hal.tensor.import %query : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> \n" |
| f" %keyTensor = hal.tensor.import %key : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> \n" |
| f" %valueTensor = hal.tensor.import %value : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> \n" |
| f" %resultTensor = hal.tensor.import %result : !hal.buffer_view -> tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> \n" |
| f" %queryExt = arith.extf %queryTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> \n" |
| f" %keyExt = arith.extf %keyTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> \n" |
| f" %valueExt = arith.extf %valueTensor : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> \n" |
| f" %resultExt = arith.extf %resultTensor : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf16> to tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> \n" |
| f" %queryExtBufferView = hal.tensor.export %queryExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n" |
| f" %keyExtBufferView = hal.tensor.export %keyExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.k1}xf32> -> !hal.buffer_view \n" |
| f" %valueExtBufferView = hal.tensor.export %valueExt : tensor<{shapes_scale.batch}x{shapes_scale.k2}x{shapes_scale.n}xf32> -> !hal.buffer_view \n" |
| f" %resultExtBufferView = hal.tensor.export %resultExt : tensor<{shapes_scale.batch}x{shapes_scale.m}x{shapes_scale.n}xf32> -> !hal.buffer_view \n" |
| ) |
| |
| if use_mask: |
| # Export mask as i8 for the reference implementation. |
| op = op + ( |
| f" %mask_i8_export = hal.tensor.export %mask_i8_reshaped : tensor<{batch}x{m}x{k2}xi8> -> !hal.buffer_view\n" |
| f" call @attention_test.check_attention_results_with_mask(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %mask_i8_export, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n" |
| ) |
| else: |
| op = op + ( |
| f" call @attention_test.check_attention_results(%device, %batch, %m, %k1, %k2, %n, %queryExtBufferView, %keyExtBufferView, %valueExtBufferView, %resultExtBufferView) : (!hal.device, i64, i64, i64, i64, i64, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view, !hal.buffer_view) -> ()\n" |
| ) |
| |
| op = op + " return\n" |
| op = op + "}\n" |
| |
| return TestCall(function=function, op=op) |
| |
| |
| # Generates all output files' contents as strings. |
| def generate( |
| query_type: QueryElemTypeId, |
| key_type: KeyElemTypeId, |
| value_type: ValueElemTypeId, |
| shapes_id: ShapesId, |
| mask_type: MaskType, |
| ): |
| functions = {} |
| calls = [] |
| |
| for shape in get_test_shapes(shapes_id): |
| function = generate_function( |
| query_type, |
| key_type, |
| value_type, |
| shape, |
| mask_type, |
| ) |
| if function.name not in functions: |
| functions[function.name] = function |
| calls.append( |
| generate_call( |
| function, |
| query_type, |
| key_type, |
| value_type, |
| shape, |
| mask_type, |
| ) |
| ) |
| |
| return (functions, calls) |
| |
| |
| def parse_arguments(): |
| parser = argparse.ArgumentParser(description="Generator of e2e Attention tests") |
| parser.add_argument( |
| "--output_attention_mlir", |
| type=str, |
| help="Path of output .mlir file containing the generated Attention functions", |
| required=True, |
| ) |
| parser.add_argument( |
| "--output_calls_mlir", |
| type=str, |
| help="Path of output .mlir file containing the calls", |
| required=True, |
| ) |
| parser.add_argument( |
| "--query_type", |
| type=str, |
| choices=["f16"], |
| help="Numeric type of query tensors ", |
| required=True, |
| ) |
| parser.add_argument( |
| "--key_type", |
| type=str, |
| choices=["f16"], |
| help="Numeric type of key tensors ", |
| required=True, |
| ) |
| parser.add_argument( |
| "--value_type", |
| type=str, |
| choices=["f16"], |
| help="Numeric type of value tensors ", |
| required=True, |
| ) |
| parser.add_argument( |
| "--shapes", |
| type=str, |
| choices=[s.value for s in ShapesId], |
| help="Collection of tensor shapes to test", |
| required=True, |
| ) |
| parser.add_argument( |
| "--requirements", |
| type=str, |
| help="Target requirements for this module. Comma-separated. As in -iree-llvmcpu-target-cpu-features. If the target device does not meet all of the requirements, the test will be skipped.", |
| required=False, |
| ) |
| parser.add_argument( |
| "--mask_type", |
| type=str, |
| choices=[m.value for m in MaskType], |
| default="none", |
| help="Type of attention mask to generate: none, all_ones (for decode), or causal (for prefill)", |
| ) |
| return parser.parse_args() |
| |
| |
| def write_code_file(functions, filename): |
| with open(filename, "w") as file: |
| for function in functions.values(): |
| file.write(function.definition + "\n") |
| |
| |
| def write_calls_file(functions, calls, filename, requirements): |
| # Module-level reflection information used to control the test tool. |
| reflection = "" |
| if requirements: |
| reflection = ( |
| "iree.reflection = {" |
| 'target_features = "' |
| + ",".join([req.lstrip("+") for req in requirements.split(",")]) |
| + '"' |
| "}" |
| ) |
| module_definition = ( |
| f"builtin.module @calls attributes {{\n" f" {reflection}\n" f"}} {{\n\n" |
| ) |
| |
| # Declare the custom module that generates arguments. |
| module_definition = module_definition + ( |
| "func.func private @attention_test.generate_random_tensor(%device: !hal.device, %dim0: i64, %dim1: i64, %dim2: i64, %element_type: i32, %seed: i32) -> !hal.buffer_view\n" |
| "func.func private @attention_test.check_attention_results(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %result: !hal.buffer_view)\n" |
| "func.func private @attention_test.check_attention_results_with_mask(%device: !hal.device, %batch: i64, %m: i64, %k1: i64, %k2: i64, %n: i64, %query: !hal.buffer_view, %key: !hal.buffer_view, %value: !hal.buffer_view, %mask: !hal.buffer_view, %result: !hal.buffer_view)\n" |
| "\n" |
| ) |
| |
| # Declare the functions that will be called. |
| for function in functions.values(): |
| module_definition = module_definition + function.import_declaration + "\n" |
| module_definition = module_definition + "\n" |
| |
| # Emit the test cases for each call. |
| for call in calls: |
| module_definition = module_definition + call.op + "\n" |
| |
| module_definition = module_definition + "\n}\n" |
| |
| with open(filename, "w") as file: |
| file.write(module_definition) |
| |
| |
| def main(args): |
| query_type = QueryElemTypeId(args.query_type) |
| key_type = KeyElemTypeId(args.key_type) |
| value_type = ValueElemTypeId(args.value_type) |
| shapes_id = ShapesId(args.shapes) |
| mask_type = MaskType(args.mask_type) |
| |
| (functions, calls) = generate( |
| query_type, |
| key_type, |
| value_type, |
| shapes_id, |
| mask_type, |
| ) |
| |
| write_code_file(functions, args.output_attention_mlir) |
| write_calls_file( |
| functions, |
| calls, |
| args.output_calls_mlir, |
| args.requirements, |
| ) |
| |
| |
| if __name__ == "__main__": |
| main(parse_arguments()) |