# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from iree.compiler import ir
from iree.compiler.dialects import affine
from iree.compiler.dialects import iree_codegen


def run(fn):
    # Run the decorated test inside a fresh MLIR context with an empty module
    # as the insertion point, printing the test name so failures are easy to
    # attribute.
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            print("\nTEST:", fn.__name__)
            fn()
    return fn


@run
def root_op():
    # No op carries the `root_op` attribute, so no tuner root ops are found.
    module_str = """
        module {
          func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
            %cst = arith.constant 0.000000e+00 : f32
            %0 = tensor.empty() : tensor<4x4xf32>
            %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
            %2 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
            return %2 : tensor<4x4xf32>
          }
        }
    """
    input_module = ir.Module.parse(module_str)
    assert input_module is not None, "Failed to parse input MLIR module"
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 0

    # A single op annotated with the `root_op` unit attribute is returned.
    module_str = """
        module {
          func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
            %cst = arith.constant 0.000000e+00 : f32
            %0 = tensor.empty() : tensor<4x4xf32>
            %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
            %2 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
            return %2 : tensor<4x4xf32>
          }
        }
    """
    input_module = ir.Module.parse(module_str)
    assert input_module is not None, "Failed to parse input MLIR module"
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1
    assert root_op_list[0].name == "linalg.matmul"

    # All annotated ops are returned, here in program order.
    module_str = """
        module {
          func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
            %cst = arith.constant 0.000000e+00 : f32
            %0 = tensor.empty() : tensor<4x4xf32>
            %1 = linalg.fill { root_op } ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
            %2 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
            return %2 : tensor<4x4xf32>
          }
        }
    """
    input_module = ir.Module.parse(module_str)
    assert input_module is not None, "Failed to parse input MLIR module"
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 2
    assert root_op_list[0].name == "linalg.fill"
    assert root_op_list[1].name == "linalg.matmul"
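

# A minimal sketch, not from the original suite: the textual `{ root_op }`
# marker above is a unit attribute, so an op can also be tagged
# programmatically via the upstream `ir.UnitAttr` binding before querying
# get_tuner_root_ops. The function name and module below are illustrative.
@run
def root_op_tagged_programmatically():
    module_str = """
        module {
          func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
            %0 = tensor.empty() : tensor<4x4xf32>
            %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
            return %1 : tensor<4x4xf32>
          }
        }
    """
    input_module = ir.Module.parse(module_str)
    func_op = input_module.body.operations[0]
    for op in func_op.regions[0].blocks[0].operations:
        if op.name == "linalg.matmul":
            # Attach the unit attribute that marks a tuner root op.
            op.attributes["root_op"] = ir.UnitAttr.get()
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1
    assert root_op_list[0].name == "linalg.matmul"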


@run
def attention_op_detail():
    # Canonical attention indexing maps over the domain (d0, ..., d4):
    # Q = (batch, m, k1), K = (batch, k2, k1), V = (batch, k2, n),
    # O = (batch, m, n).
    dim_exprs = [affine.AffineDimExpr.get(i) for i in range(5)]
    q_map = affine.AffineMap.get(
        5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[2]]
    )  # (d0, d1, d2).
    k_map = affine.AffineMap.get(
        5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[2]]
    )  # (d0, d3, d2).
    v_map = affine.AffineMap.get(
        5, 0, [dim_exprs[0], dim_exprs[3], dim_exprs[4]]
    )  # (d0, d3, d4).
    o_map = affine.AffineMap.get(
        5, 0, [dim_exprs[0], dim_exprs[1], dim_exprs[4]]
    )  # (d0, d1, d4).
    result = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map)
    assert result.domain_rank == 5
    assert result.batch_dims == [0]
    assert result.m_dims == [1]
    assert result.k1_dims == [2]
    assert result.k2_dims == [3]
    assert result.n_dims == [4]
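
    # The classification can be read off operand usage in the maps above: the
    # batch dim d0 appears in all four maps, m (d1) only in Q and O, k1 (d2)
    # only in Q and K, k2 (d3) only in K and V, and n (d4) only in V and O.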

    # Affine maps that do not follow the expected attention pattern: Q and K
    # share no non-batch dimension, so there is no k1 and k1_dims is empty.
    dim_exprs = [affine.AffineDimExpr.get(i) for i in range(4)]
    q_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]])  # (d0, d1).
    k_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[2]])  # (d0, d2).
    v_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[3]])  # (d0, d3).
    o_map = affine.AffineMap.get(4, 0, [dim_exprs[0], dim_exprs[1]])  # (d0, d1).
    result = iree_codegen.get_attention_op_detail(q_map, k_map, v_map, o_map)
    assert result.domain_rank == 4
    assert result.batch_dims == [0]
    assert result.m_dims == [1]
    assert result.k1_dims == []
    assert result.k2_dims == [2]
    assert result.n_dims == [3]


@run
def test_isa_attention_op():
    module_str = """
        module {
          func.func @attention_20x4096x64x4096x64(
              %q : tensor<20x4096x64xf16>,
              %k : tensor<20x4096x64xf16>,
              %v : tensor<20x4096x64xf16>,
              %scale : f16,
              %output : tensor<20x4096x64xf16>
          ) -> tensor<20x4096x64xf16> {
            %result = iree_linalg_ext.attention { root_op,
                indexing_maps = [
                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                    affine_map<(d0, d1, d2, d3, d4) -> ()>,
                    affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
                ]
            } ins(%q, %k, %v, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
              outs(%output : tensor<20x4096x64xf16>) {
                ^bb0(%score: f32):
                  iree_linalg_ext.yield %score : f32
            } -> tensor<20x4096x64xf16>
            return %result : tensor<20x4096x64xf16>
          }
        }
    """
    input_module = ir.Module.parse(module_str)
    assert input_module is not None, "Failed to parse input MLIR module"
    print(input_module)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1
    assert root_op_list[0].name == "iree_linalg_ext.attention"
    assert iree_codegen.isa_attention_op(root_op_list[0])
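

# A minimal sketch of the negative case, assuming isa_attention_op returns
# False for any op other than iree_linalg_ext.attention: a root-tagged
# linalg.matmul should not be classified as attention. The function name and
# module below are illustrative.
@run
def isa_attention_op_negative():
    module_str = """
        module {
          func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
            %0 = tensor.empty() : tensor<4x4xf32>
            %1 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
            return %1 : tensor<4x4xf32>
          }
        }
    """
    input_module = ir.Module.parse(module_str)
    assert input_module is not None, "Failed to parse input MLIR module"
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1
    assert not iree_codegen.isa_attention_op(root_op_list[0])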