blob: 39e3b27fdf65b86d964fb850f30a6a767d3fe262 [file]
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from iree.compiler import ir
from iree.compiler.dialects import affine, iree_codegen, iree_gpu, smt
from iree.compiler.ir import AffineMap, AffineDimExpr
def run(fn):
    """Execute ``fn`` right away inside a fresh MLIR context and hand it back.

    Meant to be used as a decorator: decorating a test function runs it once,
    at definition time. Before the call a ``TEST:`` banner with the function's
    name is printed; the call itself happens with an unknown default location
    and an insertion point at the body of a newly created, empty module.
    """
    with ir.Context():
        with ir.Location.unknown():
            scratch_module = ir.Module.create()
            with ir.InsertionPoint(scratch_module.body):
                print("\nTEST:", fn.__name__)
                fn()
    return fn
@run
def root_op_attr():
    """Exercise construction, defaulting, parsing and printing of RootOpAttr."""
    # Explicitly requested set values are read back unchanged.
    zero_set = iree_codegen.RootOpAttr.get(set=0)
    assert isinstance(zero_set, iree_codegen.RootOpAttr)
    assert zero_set.set == 0
    assert iree_codegen.RootOpAttr.get(set=42).set == 42

    # Omitting `set` falls back to 0.
    assert iree_codegen.RootOpAttr.get().set == 0

    # The textual form parses into a RootOpAttr carrying the written value.
    from_text = ir.Attribute.parse("#iree_codegen.root_op<set = 7>")
    assert isinstance(from_text, iree_codegen.RootOpAttr)
    assert from_text.set == 7

    # Printing then re-parsing preserves both the type and the value.
    round_tripped = ir.Attribute.parse(str(iree_codegen.RootOpAttr.get(set=3)))
    assert isinstance(round_tripped, iree_codegen.RootOpAttr)
    assert round_tripped.set == 3
@run
def root_op():
    """Check that get_tuner_root_ops returns exactly the ops tagged `root_op`, in order."""
    untagged_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x4xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
"""
    matmul_tagged_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x4xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = linalg.matmul { root_op = #iree_codegen.root_op<set = 0> } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
"""
    both_tagged_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x4xf32>
%1 = linalg.fill { root_op = #iree_codegen.root_op<set = 0> } ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = linalg.matmul { root_op = #iree_codegen.root_op<set = 0> } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
"""
    # Each case pairs a module with the op names expected back, in order.
    cases = [
        (untagged_ir, []),
        (matmul_tagged_ir, ["linalg.matmul"]),
        (both_tagged_ir, ["linalg.fill", "linalg.matmul"]),
    ]
    for text, expected_names in cases:
        parsed = ir.Module.parse(text)
        assert parsed is not None, "Failed to parse input MLIR module"
        found = iree_codegen.get_tuner_root_ops(parsed)
        assert len(found) == len(expected_names)
        for op, expected_name in zip(found, expected_names):
            assert op.name == expected_name
@run
def attention_op_detail():
    """Check how get_attention_op_detail classifies dims from Q/K/V/O maps."""

    def make_maps(rank, results_per_operand):
        # One AffineMap per operand over `rank` dims, each selecting the listed
        # dim positions (in order) as its results.
        exprs = [affine.AffineDimExpr.get(pos) for pos in range(rank)]
        return [
            affine.AffineMap.get(rank, 0, [exprs[pos] for pos in positions])
            for positions in results_per_operand
        ]

    # Canonical 5-D attention: Q(b, m, k1), K(b, k2, k1), V(b, k2, n), O(b, m, n).
    q, k, v, o = make_maps(5, [(0, 1, 2), (0, 3, 2), (0, 3, 4), (0, 1, 4)])
    detail = iree_codegen.get_attention_op_detail(q, k, v, o)
    assert detail.domain_rank == 5
    assert detail.batch_dims == [0]
    assert detail.m_dims == [1]
    assert detail.k1_dims == [2]
    assert detail.k2_dims == [3]
    assert detail.n_dims == [4]

    # Degenerate 4-D maps that do not follow the usual attention pattern (Q and
    # K share only the batch dim). Classification still succeeds, but k1 ends
    # up empty.
    q, k, v, o = make_maps(4, [(0, 1), (0, 2), (0, 3), (0, 1)])
    detail = iree_codegen.get_attention_op_detail(q, k, v, o)
    assert detail.domain_rank == 4
    assert detail.batch_dims == [0]
    assert detail.m_dims == [1]
    assert detail.k1_dims == []
    assert detail.k2_dims == [2]
    assert detail.n_dims == [3]
@run
def test_isa_attention_op():
    """A root-tagged iree_linalg_ext.attention op is recognized by isa_attention_op."""
    attention_ir = """
module {
func.func @attention_20x4096x64x4096x64(
%q : tensor<20x4096x64xf16>,
%k : tensor<20x4096x64xf16>,
%v : tensor<20x4096x64xf16>,
%scale : f16,
%output : tensor<20x4096x64xf16>
) -> tensor<20x4096x64xf16> {
%result = iree_linalg_ext.attention { root_op = #iree_codegen.root_op<set = 0>,
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
]
} ins(%q, %k, %v, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
outs(%output : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x4096x64xf16>
return %result : tensor<20x4096x64xf16>
}
}
"""
    parsed = ir.Module.parse(attention_ir)
    assert parsed is not None, "Failed to parse input MLIR module"
    roots = iree_codegen.get_tuner_root_ops(parsed)
    assert len(roots) == 1
    attention = roots[0]
    assert attention.name == "iree_linalg_ext.attention"
    assert iree_codegen.isa_attention_op(attention)
@run
def test_igemm_conv_details():
    """Check IGEMM details derived from several convolution forms.

    Covers named NHWC and NCHW 2-D convolutions, a convolution written as
    ``linalg.generic`` (weight backward), and a non-convolution op for which
    no details must be produced.
    """
    parallel, reduction = '"parallel"', '"reduction"'
    conv2d_dim_map = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    def single_root_op(module):
        # Fail with a clear message (rather than an IndexError) when the
        # root_op tagging does not select exactly one op.
        roots = iree_codegen.get_tuner_root_ops(module)
        assert len(roots) == 1, f"Expected exactly one root op, got {len(roots)}"
        return roots[0]

    def contraction_maps(details):
        # Unwrap the AffineMapAttr list into plain AffineMaps for comparison.
        return [map_attr.value for map_attr in details.igemm_contraction_maps]

    def iterator_strs(details):
        # Iterator attrs print with quotes, e.g. '"parallel"'.
        return [str(attr) for attr in details.igemm_loop_iterators]

    d0, d1, d2, d3, d4 = [AffineDimExpr.get(i) for i in range(5)]

    # Test 1: conv_2d_nhwc_hwcf.
    module_str = """
module {
func.func @conv_2d_nhwc_hwcf(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_hwcf { root_op = #iree_codegen.root_op<set = 0>, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(single_root_op(input_module))
    assert details is not None, "IGEMM details should be valid for NHWC_HWCF conv"
    assert details.igemm_loop_bounds == [1, 14, 14, 16, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = contraction_maps(details)
    # For channel-last (NHWC): input (N,H,W,K), filter (K,OC), output (N,H,W,OC).
    assert maps[0] == AffineMap.get(
        5, 0, [d0, d1, d2, d4]
    ), f"Input map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(5, 0, [d4, d3]), f"Filter map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    assert iterator_strs(details) == [parallel] * 4 + [reduction]
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0, 1, 2], [3]]
    assert not details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == conv2d_dim_map

    # Test 2: conv_2d_nhwc_fhwc.
    module_str = """
module {
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_fhwc { root_op = #iree_codegen.root_op<set = 0>, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(single_root_op(input_module))
    assert details is not None, "IGEMM details should be valid for NHWC_FHWC conv"
    assert details.igemm_loop_bounds == [1, 14, 14, 16, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = contraction_maps(details)
    # Verify expected affine maps (NHWC_FHWC layout).
    assert maps[0] == AffineMap.get(
        5, 0, [d0, d1, d2, d4]
    ), f"Input map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(5, 0, [d3, d4]), f"Filter map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    assert iterator_strs(details) == [parallel] * 4 + [reduction]
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0], [1, 2, 3]]
    assert not details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == conv2d_dim_map

    # Test 3: conv_2d_nchw_fchw.
    module_str = """
module {
func.func @conv_2d_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
%0 = linalg.conv_2d_nchw_fchw { root_op = #iree_codegen.root_op<set = 0>, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1 : tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
outs(%arg2 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
return %0 : tensor<1x16x14x14xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(single_root_op(input_module))
    assert details is not None, "IGEMM details should be valid for NCHW conv"
    assert details.igemm_loop_bounds == [1, 16, 14, 14, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = contraction_maps(details)
    # Verify expected affine maps for NCHW with loop dims [N, OC, H, W, K].
    # Note: operands are swapped - filter first, then input.
    assert maps[0] == AffineMap.get(5, 0, [d1, d4]), f"Filter map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(
        5, 0, [d0, d4, d2, d3]
    ), f"Input map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    assert iterator_strs(details) == [parallel] * 4 + [reduction]
    assert details.im2col_output_perm == [
        0,
        3,
        1,
        2,
    ], f"im2col output perm mismatch: {details.im2col_output_perm}"
    assert details.filter_reassoc_indices == [[0], [1, 2, 3]]
    assert details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == conv2d_dim_map

    # Test 4: linalg.generic with convolution pattern (weight backward).
    module_str = """
module {
func.func @conv_generic_weight_backward(%arg0: tensor<16x98x64x96xf32>, %arg1: tensor<16x96x64x96xf32>, %arg2: tensor<96x3x96xf32>) -> tensor<96x3x96xf32> {
%0 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d1 + d4, d5, d2)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5, d0)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%arg0, %arg1 : tensor<16x98x64x96xf32>, tensor<16x96x64x96xf32>) outs(%arg2 : tensor<96x3x96xf32>) attrs = {root_op = #iree_codegen.root_op<set = 0>} {
^bb0(%in: f32, %in_1: f32, %out: f32):
%mul = arith.mulf %in, %in_1 : f32
%add = arith.addf %out, %mul : f32
linalg.yield %add : f32
} -> tensor<96x3x96xf32>
return %0 : tensor<96x3x96xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    details = iree_codegen.get_igemm_generic_conv_details(single_root_op(input_module))
    assert (
        details is not None
    ), "IGEMM details should be valid for generic 1D conv weight backward"
    assert details.igemm_loop_bounds == [96, 3, 96, 16, 6144]
    assert len(details.igemm_contraction_maps) == 3
    maps = contraction_maps(details)
    assert maps[0] == AffineMap.get(5, 0, [d3, d4, d0]), f"Map 0 mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(
        5, 0, [d3, d1, d4, d2]
    ), f"Map 1 mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(5, 0, [d0, d1, d2]), f"Map 2 mismatch: {maps[2]}"
    assert iterator_strs(details) == [parallel] * 3 + [reduction] * 2
    assert details.im2col_output_perm == [2, 0, 3, 1]
    assert details.filter_reassoc_indices == [[0], [1, 2], [3]]
    assert details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4}

    # Test 5: a non-conv operation yields no IGEMM details.
    module_str = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op = #iree_codegen.root_op<set = 0> } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    matmul_op = single_root_op(input_module)
    details = iree_codegen.get_igemm_generic_conv_details(matmul_op)
    assert details is None, "IGEMM details should be None for non-conv operation"
@run
def test_isa_scaled_contraction_op():
    """isa_scaled_contraction_op rejects plain ops and accepts a scaled-matmul generic."""
    # Case 1: an ordinary linalg.matmul is not a scaled contraction.
    plain_matmul_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op = #iree_codegen.root_op<set = 0> } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    plain_module = ir.Module.parse(plain_matmul_ir)
    assert plain_module is not None, "Failed to parse input MLIR module"
    plain_roots = iree_codegen.get_tuner_root_ops(plain_module)
    assert len(plain_roots) == 1
    assert not iree_codegen.isa_scaled_contraction_op(
        plain_roots[0]
    ), "Regular matmul should not be a scaled contraction"

    # Case 2: a linalg.fill is not a scaled contraction either.
    fill_ir = """
module {
func.func @fill(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill { root_op = #iree_codegen.root_op<set = 0> } ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    fill_module = ir.Module.parse(fill_ir)
    fill_roots = iree_codegen.get_tuner_root_ops(fill_module)
    assert len(fill_roots) == 1
    assert not iree_codegen.isa_scaled_contraction_op(
        fill_roots[0]
    ), "Fill op should not be a scaled contraction"

    # Case 3: a linalg.generic scaled matmul must be detected. It has five
    # indexing maps (lhs, rhs, lhs_scale, rhs_scale, output) and four loops
    # (parallel M and N, reduction Ko and Kb). Operands are f4E2M1FN and
    # scales are f8E8M0FNU.
    scaled_ir = """
module {
func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
%lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
root_op = #iree_codegen.root_op<set = 0>
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
outs(%out : tensor<16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<16x16xf32>
return %result : tensor<16x16xf32>
}
}
"""
    scaled_module = ir.Module.parse(scaled_ir)
    scaled_roots = iree_codegen.get_tuner_root_ops(scaled_module)
    assert len(scaled_roots) == 1, "Should have one root op"
    scaled_op = scaled_roots[0]
    detected = iree_codegen.isa_scaled_contraction_op(scaled_op)
    assert (
        detected
    ), "linalg.generic with scaled matmul pattern should be detected as scaled contraction"

    inferred = iree_codegen.infer_scaled_contraction_dimensions(scaled_op)
    assert (
        inferred is not None
    ), "Should be able to infer dimensions for scaled contraction"
    assert inferred.m == [0], f"Got {inferred.m}"
    assert inferred.n == [1], f"Got {inferred.n}"
    assert inferred.k == [2], f"Got {inferred.k}"
    assert inferred.kB == [3], f"Got {inferred.kB}"
    assert inferred.batch == [], f"Got {inferred.batch}"
@run
def test_infer_scaled_contraction_dimensions():
    """Check infer_scaled_contraction_dimensions on scaled, plain and batched matmuls."""
    # Test 1: Verify dimension inference on a scaled matmul operation.
    module_str = """
module {
func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
%lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
root_op = #iree_codegen.root_op<set = 0>
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
outs(%out : tensor<16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<16x16xf32>
return %result : tensor<16x16xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Should have exactly one root op"
    scaled_op = root_op_list[0]
    assert iree_codegen.isa_scaled_contraction_op(
        scaled_op
    ), "Operation should be recognized as scaled contraction"
    dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op)
    assert dims is not None, "Should successfully infer dimensions"
    assert dims.m == [0], f"Got {dims.m}"
    assert dims.n == [1], f"Got {dims.n}"
    assert dims.k == [2], f"Got {dims.k}"
    assert dims.kB == [3], f"Got {dims.kB}"
    assert dims.batch == [], f"Got {dims.batch}"

    # Test 2: A regular (non-scaled) matmul must not report scaled-contraction
    # dimensions. Either None or a result whose dimension lists are all empty
    # is accepted.
    module_str_regular = """
module {
func.func @regular_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op = #iree_codegen.root_op<set = 0> } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    input_module_regular = ir.Module.parse(module_str_regular)
    regular_ops = iree_codegen.get_tuner_root_ops(input_module_regular)
    assert len(regular_ops) == 1
    regular_matmul = regular_ops[0]
    dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul)
    if dims_regular is not None:
        dim_lists = (
            dims_regular.m,
            dims_regular.n,
            dims_regular.k,
            dims_regular.kB,
            dims_regular.batch,
        )
        assert all(
            len(dim_list) == 0 for dim_list in dim_lists
        ), "Regular matmul should not have valid scaled contraction dimensions"

    # Test 3: Batched scaled matmul.
    module_str_batched = """
module {
func.func @batched_scaled_matmul(%lhs: tensor<8x16x4x32xf4E2M1FN>, %rhs: tensor<8x16x4x32xf4E2M1FN>,
%lhs_scales: tensor<8x16x4xf8E8M0FNU>, %rhs_scales: tensor<8x16x4xf8E8M0FNU>,
%out: tensor<8x16x16xf32>) -> tensor<8x16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"],
root_op = #iree_codegen.root_op<set = 0>
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4xf8E8M0FNU>, tensor<8x16x4xf8E8M0FNU>)
outs(%out : tensor<8x16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<8x16x16xf32>
return %result : tensor<8x16x16xf32>
}
}
"""
    input_module_batched = ir.Module.parse(module_str_batched)
    batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched)
    assert len(batched_ops) == 1, "Batched op should be found"
    batched_op = batched_ops[0]
    assert iree_codegen.isa_scaled_contraction_op(
        batched_op
    ), "Batched scaled matmul should be recognized"
    dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op)
    assert (
        dims_batched is not None
    ), "Batch dimension must be present in batched scaled matmul"
    assert dims_batched.batch == [0], f"Got {dims_batched.batch}"
    assert dims_batched.m == [1], f"Got {dims_batched.m}"
    assert dims_batched.n == [2], f"Got {dims_batched.n}"
    assert dims_batched.k == [3], f"Got {dims_batched.k}"
    assert dims_batched.kB == [4], f"Got {dims_batched.kB}"
@run
def test_is_xor_shuffle_valid():
    """Check is_xor_shuffle_valid on accepted and rejected (row, access, tile) sizes."""
    # Accepted: row and access each divide evenly, access <= row <= tile.
    accepted = [
        (256, 32, 512),
        (512, 64, 512),
        (32, 8, 512),
    ]
    rejected = [
        (512, 32, 256),  # row exceeds tile
        (256, 512, 512),  # access exceeds row
        (300, 32, 512),  # row does not evenly divide tile
        (256, 33, 512),  # access does not evenly divide row
    ]
    for row_tile_args in accepted:
        assert iree_gpu.is_xor_shuffle_valid(*row_tile_args)
    for row_tile_args in rejected:
        assert not iree_gpu.is_xor_shuffle_valid(*row_tile_args)
@run
def test_get_xor_shuffle_bounds():
    """Check XOR shuffle bounds reported for an MFMA intrinsic's operands."""
    # MMAAttr implements the interface that getXorShuffleBounds relies on.
    mfma = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16)

    lhs_bounds = iree_gpu.get_xor_shuffle_bounds(mfma, operand_index=0)
    assert lhs_bounds is not None, "get_xor_shuffle_bounds should succeed for MMAAttr"
    min_access_elems, total_tile_elems = lhs_bounds
    assert (min_access_elems, total_tile_elems) == (4, 256)

    rhs_bounds = iree_gpu.get_xor_shuffle_bounds(mfma, operand_index=1)
    assert rhs_bounds is not None
@run
def test_int_knob_attr():
    """Parsed int_knob attributes surface as IntKnobAttr with their name."""
    knob = ir.Attribute.parse('#iree_codegen.smt.int_knob<"wg_m">')
    assert isinstance(knob, iree_codegen.IntKnobAttr)
    assert knob.name == "wg_m"
    # A plain integer attribute must not be mistaken for a knob.
    plain_int = ir.Attribute.parse("42 : i64")
    assert not isinstance(plain_int, iree_codegen.IntKnobAttr)
@run
def test_one_of_knob_attr():
    """Parsed one_of_knob attributes expose their name and ordered options."""
    knob = ir.Attribute.parse(
        '#iree_codegen.smt.one_of_knob<"mma_idx", ["opt_a", "opt_b", "opt_c"]>'
    )
    assert isinstance(knob, iree_codegen.OneOfKnobAttr)
    assert knob.name == "mma_idx"
    expected_options = ['"opt_a"', '"opt_b"', '"opt_c"']
    options = knob.options
    assert len(options) == len(expected_options)
    for idx, want in enumerate(expected_options):
        assert str(options[idx]) == want
@run
def test_get_iree_constraints_op():
    """ConstraintsOp is exposed as a Python op type and its `target` is readable.

    The module holds one top-level constraints op followed by one inside each
    of two functions; their targets are checked in that order.
    """
    module_str = """
module {
iree_codegen.smt.constraints
target = <set = 0>,
pipeline = #iree_gpu.pipeline<VectorDistribute>,
knobs = {}
dims() {
}
func.func @main() -> () {
iree_codegen.smt.constraints
target = #iree_codegen.root_op<set = 1>,
pipeline = #iree_gpu.pipeline<VectorDistribute>,
knobs = {}
dims() {
}
return
}
func.func @test() -> () {
iree_codegen.smt.constraints
target = #iree_codegen.root_op<set = 0>,
pipeline = #iree_gpu.pipeline<VectorDistribute>,
knobs = {}
dims() {
}
return
}
}
"""
    input_module = ir.Module.parse(module_str)
    # Test that IREE Codegen op types (e.g., `iree_codegen.ConstraintsOp`) are
    # exposed by the bindings.
    constraints_ops = ir.get_ops_of_type(input_module, iree_codegen.ConstraintsOp)
    assert (
        len(constraints_ops) == 3
    ), f"Should get 3 constraints ops, got {len(constraints_ops)}"
    for op in constraints_ops:
        assert isinstance(op, iree_codegen.ConstraintsOp)
    # The short form `target = <set = 0>` must resolve to the same RootOpAttr
    # as the fully-qualified `#iree_codegen.root_op<set = N>` form.
    expected_sets = [0, 1, 0]
    for op, expected_set in zip(constraints_ops, expected_sets):
        assert op.target == iree_codegen.RootOpAttr.get(set=expected_set)
@run
def test_convert_constraints_op_to_smtlib():
    """Check SMTLIB text produced from a constraints op, with and without `(reset)`."""
    module_str = """
module {
iree_codegen.smt.constraints
target = <set = 0>,
pipeline = #iree_gpu.pipeline<VectorDistribute>,
knobs = {wg_m = #iree_codegen.smt.int_knob<"wg_m">,
mma_idx = #iree_codegen.smt.int_knob<"mma_idx">}
dims() {
^bb0:
%wg_m = iree_codegen.smt.knob "wg_m" : !smt.int
%idx = iree_codegen.smt.knob "mma_idx" : !smt.int
%mma_m = iree_codegen.smt.lookup %idx [0, 1] -> [16, 32] : !smt.int
%cost_fn = smt.declare_fun "cost_fn" : !smt.func<(!smt.int) !smt.int>
%cost = smt.apply_func %cost_fn(%wg_m) : !smt.func<(!smt.int) !smt.int>
%cond = smt.int.cmp le %wg_m, %cost
%cond_mma = smt.int.cmp le %mma_m, %wg_m
iree_codegen.smt.assert %cond, "wg_m is positive" : !smt.bool
iree_codegen.smt.assert %cond_mma, "mma_m ({}) <= wg_m ({})", %mma_m, %wg_m : !smt.bool, !smt.int, !smt.int
}
}
"""
    input_module = ir.Module.parse(module_str)
    constraints_ops = ir.get_ops_of_type(input_module, iree_codegen.ConstraintsOp)
    assert (
        len(constraints_ops) == 1
    ), f"Should get 1 constraints op, got {len(constraints_ops)}"
    constraints_op = constraints_ops[0]
    smtlib = iree_codegen.convert_constraints_op_to_smtlib(
        constraints_op, emit_reset=False
    )
    assert smtlib is not None, "smtlib should be created"
    # Include the generated text in the failure message, like the checks below.
    assert (
        "; solver scope 0" in smtlib
    ), f"Missing solver scope header. SMTLIB:\n{smtlib}"
    assert "(reset)" not in smtlib, f"Unexpected reset in SMTLIB:\n{smtlib}"
    err_str = f"Knobs conversion failed. SMTLIB:\n{smtlib}"
    # Knobs become declare-const constants (0-ary smt.declare_fun).
    assert "(declare-const wg_m Int)" in smtlib, err_str
    assert "(declare-const mma_idx Int)" in smtlib, err_str
    # smt.declare_fun with a function type becomes declare-fun.
    assert "(declare-fun cost_fn (Int) Int)" in smtlib, err_str
    # lookup [0,1]->[16,32] lowers to an ite chain.
    assert "ite" in smtlib, f"Lookup conversion failed. SMTLIB:\n{smtlib}"
    # Assert ops become SMT assert commands.
    assert "(assert" in smtlib, f"Assert conversion failed. SMTLIB:\n{smtlib}"

    # Test emit_reset option.
    smtlib = iree_codegen.convert_constraints_op_to_smtlib(
        constraints_op, emit_reset=True
    )
    assert "(reset)" in smtlib, f"Missing reset in SMTLIB:\n{smtlib}"
# Knob template for a TileAndFuse constraints op, shared by the
# materialize_compilation_info tests via _get_materialize_constraints_op().
# Knobs referenced: "wg_0" (workgroup tile), "wg_size_x" (workgroup size),
# "sg_size" (subgroup size), "prefetch_num_stages", and the boolean one-of knob
# "use_igemm_convolution". The happy-path test assigns all of them; the
# error-diagnostic test omits "wg_0" and expects a RuntimeError.
_MATERIALIZE_CONSTRAINTS_MODULE = """
module {
iree_codegen.smt.constraints
target = <set = 0>,
pipeline = #iree_gpu.pipeline<TileAndFuse>,
knobs = {
workgroup = [#iree_codegen.smt.int_knob<"wg_0">, 1, 1],
workgroup_size = [#iree_codegen.smt.int_knob<"wg_size_x">, 1, 1],
subgroup_size = #iree_codegen.smt.int_knob<"sg_size">,
gpu_pipeline_options = {prefetch_num_stages = #iree_codegen.smt.int_knob<"prefetch_num_stages">,
no_reduce_shared_memory_bank_conflicts = false,
use_igemm_convolution = #iree_codegen.smt.one_of_knob<"use_igemm_convolution", [false, true]>}
}
dims() {
}
}
"""
def _get_materialize_constraints_op():
    """Parse the TileAndFuse knob template and return ``(module, constraints_op)``.

    The parsed module is returned alongside the op so that callers can hold on
    to it for as long as they use the op wrapper.
    """
    parsed = ir.Module.parse(_MATERIALIZE_CONSTRAINTS_MODULE)
    found = ir.get_ops_of_type(parsed, iree_codegen.ConstraintsOp)
    assert len(found) == 1
    (only_op,) = found
    return parsed, only_op
@run
def test_materialize_compilation_info_happy_path():
    """Assigning every knob yields the expected lowering and translation info."""
    # Hold on to the module: the op wrapper is only used while it is alive.
    owning_module, constraints_op = _get_materialize_constraints_op()
    knob_values = {
        "wg_0": 64,
        "wg_size_x": 128,
        "sg_size": 64,
        "prefetch_num_stages": 2,
        "use_igemm_convolution": False,
    }
    compilation_info = iree_codegen.materialize_compilation_info(
        constraints_op, knob_values
    )
    assert isinstance(compilation_info, iree_codegen.CompilationInfoAttr)

    expected_lowering = "#iree_gpu.lowering_config<{workgroup = [64, 1, 1]}>"
    assert str(compilation_info.lowering_config) == expected_lowering

    translation = iree_codegen.TranslationInfoAttr(compilation_info.translation_info)
    assert list(translation.workgroup_size) == [128, 1, 1]
    assert translation.subgroup_size == 64
    assert str(translation.pass_pipeline) == "#iree_gpu.pipeline<TileAndFuse>"

    expected_options = (
        "gpu_pipeline_options = #iree_gpu.pipeline_options<"
        "prefetch_num_stages = 2, "
        "no_reduce_shared_memory_bank_conflicts = false, "
        "use_igemm_convolution = false>"
    )
    assert expected_options in str(translation)
@run
def test_materialize_compilation_info_error_diagnostic():
    """Materialization fails with a clear diagnostic when a knob is unassigned.

    Every knob in the template except "wg_0" is given a value; the binding is
    expected to raise RuntimeError naming the missing knob.
    """
    # Keep the module alive while using the op wrapper below.
    input_module, constraints_op = _get_materialize_constraints_op()
    incomplete_assignments = {
        "wg_size_x": 128,
        "sg_size": 64,
        "prefetch_num_stages": 2,
        "use_igemm_convolution": False,
    }
    try:
        iree_codegen.materialize_compilation_info(
            constraints_op, incomplete_assignments
        )
    except RuntimeError as e:
        assert "missing assignment for knob 'wg_0'" in str(
            e
        ), f"Unexpected diagnostic: {e}"
    else:
        # An explicit raise (unlike `assert False`) still fires under `python -O`.
        raise AssertionError("expected missing wg_0 assignment to fail")
# Attention knob template: a VectorDistribute constraints op whose knobs carry
# a nested decomposition_config with typed per-matmul (qk / pv)
# lowering_config attrs. Mirrors the shape that LLVMGPUConstraintGenerator
# emits for OnlineAttentionOp under VectorDistribute.
#
# Knobs referenced: "m_tile", "n_tile", "red_k2", "qk_mma_idx", "pv_mma_idx",
# "sg_m_cnt", "sg_n_cnt", "wg_size_x" and "sg_size". The "sg_m_cnt" and
# "sg_n_cnt" knob names appear in both the qk and pv subgroup_basis entries,
# so presumably one assignment feeds both.
_MATERIALIZE_ATTENTION_CONSTRAINTS_MODULE = """
module {
iree_codegen.smt.constraints
target = <set = 0>,
pipeline = #iree_gpu.pipeline<VectorDistribute>,
knobs = {
workgroup = [1, #iree_codegen.smt.int_knob<"m_tile">, 0, 0, #iree_codegen.smt.int_knob<"n_tile">],
reduction = [0, 0, 0, #iree_codegen.smt.int_knob<"red_k2">, 0],
promote_operands = [0, 1, 2],
promotion_types = [#iree_gpu.derived_thread_config,
#iree_gpu.derived_thread_config,
#iree_gpu.derived_thread_config],
decomposition_config = {
qk_attrs = {lowering_config = #iree_gpu.lowering_config<{
mma_kind = #iree_codegen.smt.one_of_knob<"qk_mma_idx", [#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>]>,
promote_operands = [0, 1],
promotion_types = [#iree_gpu.derived_thread_config,
#iree_gpu.derived_thread_config],
subgroup_basis = [[1, #iree_codegen.smt.int_knob<"sg_m_cnt">, 1, 1, #iree_codegen.smt.int_knob<"sg_n_cnt">], [0, 1, 2, 3]]
}>},
pv_attrs = {lowering_config = #iree_gpu.lowering_config<{
mma_kind = #iree_codegen.smt.one_of_knob<"pv_mma_idx", [#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>]>,
promote_operands = [1],
promotion_types = [#iree_gpu.derived_thread_config],
subgroup_basis = [[1, #iree_codegen.smt.int_knob<"sg_m_cnt">, 1, 1, #iree_codegen.smt.int_knob<"sg_n_cnt">], [0, 1, 3, 4]]
}>}
},
workgroup_size = [#iree_codegen.smt.int_knob<"wg_size_x">, 1, 1],
subgroup_size = #iree_codegen.smt.int_knob<"sg_size">
}
dims() {
}
}
"""
def _get_materialize_attention_constraints_op():
    """Parse the attention constraints template and return its single op.

    Returns:
      A ``(module, constraints_op)`` pair. Callers should keep ``module``
      referenced for as long as they use ``constraints_op``.
    """
    parsed = ir.Module.parse(_MATERIALIZE_ATTENTION_CONSTRAINTS_MODULE)
    matches = ir.get_ops_of_type(parsed, iree_codegen.ConstraintsOp)
    assert len(matches) == 1
    only_op = matches[0]
    return parsed, only_op
# Complete knob assignment for _MATERIALIZE_ATTENTION_CONSTRAINTS_MODULE.
# Tests that need a variation copy this dict before mutating it.
_ATTENTION_ASSIGNMENTS = dict(
    m_tile=64,
    n_tile=64,
    red_k2=64,
    qk_mma_idx=0,
    pv_mma_idx=0,
    sg_m_cnt=4,
    sg_n_cnt=1,
    wg_size_x=256,
    sg_size=64,
)
@run
def test_materialize_compilation_info_attention():
    """Materialize full compilation info from the attention knob template."""
    # Holding the module keeps the constraints op wrapper valid.
    keepalive_module, constraints_op = _get_materialize_attention_constraints_op()
    info = iree_codegen.materialize_compilation_info(
        constraints_op, _ATTENTION_ASSIGNMENTS
    )
    assert isinstance(info, iree_codegen.CompilationInfoAttr)

    config_text = str(info.lowering_config)
    # Top-level attention lowering_config entries, with knobs substituted.
    for present in (
        "workgroup = [1, 64, 0, 0, 64]",
        "reduction = [0, 0, 0, 64, 0]",
        "promote_operands = [0, 1, 2]",
    ):
        assert present in config_text
    # Entries that must not leak into the top-level lowering_config.
    for absent in (
        "partial_reduction",
        "subgroup_m_count",
        "subgroup_n_count",
        "decomposition_config",
    ):
        assert absent not in config_text

    translation = iree_codegen.TranslationInfoAttr(info.translation_info)
    assert list(translation.workgroup_size) == [256, 1, 1]
    assert translation.subgroup_size == 64
@run
def test_materialize_configuration_attr_attention_decomposition_config():
    """Materialize only the nested decomposition_config of the attention template."""
    # Keep the module alive while the constraints op wrapper is in use.
    input_module, constraints_op = _get_materialize_attention_constraints_op()
    decomp = iree_codegen.materialize_configuration_attr(
        constraints_op, "decomposition_config", _ATTENTION_ASSIGNMENTS
    )
    decomp_str = str(decomp)
    assert "qk_attrs" in decomp_str
    assert "pv_attrs" in decomp_str
    assert "#iree_gpu.lowering_config" in decomp_str
    # Both per-matmul lowering_configs select the chosen MMA intrinsic.
    assert (
        decomp_str.count("mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>") == 2
    )
    # Every knob must be replaced by a concrete value. This includes the
    # int_knobs nested inside each subgroup_basis, which the checks above do
    # not cover.
    assert "#iree_codegen.smt." not in decomp_str
@run
def test_materialize_configuration_attr_picks_indexed_mma():
    """The qk and pv one_of_knobs resolve independently by their indices."""
    # Keep the module alive while the constraints op wrapper is in use.
    keepalive_module, constraints_op = _get_materialize_attention_constraints_op()
    # QK keeps index 0; PV moves to index 1 (MFMA_F32_32x32x8_F16).
    mixed_assignments = {**_ATTENTION_ASSIGNMENTS, "pv_mma_idx": 1}
    rendered = str(
        iree_codegen.materialize_configuration_attr(
            constraints_op, "decomposition_config", mixed_assignments
        )
    )
    for expected_prefix in (
        "qk_attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>",
        "pv_attrs = {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>",
    ):
        assert expected_prefix in rendered
@run
def test_materialize_configuration_attr_error_diagnostic():
    """A missing `qk_mma_idx` assignment must raise a RuntimeError naming it."""
    # Keep the module alive while the constraints op wrapper is in use.
    input_module, constraints_op = _get_materialize_attention_constraints_op()
    assignments = dict(_ATTENTION_ASSIGNMENTS)
    del assignments["qk_mma_idx"]
    try:
        iree_codegen.materialize_configuration_attr(
            constraints_op, "decomposition_config", assignments
        )
    except RuntimeError as e:
        assert "missing assignment for knob 'qk_mma_idx'" in str(e)
    else:
        # An explicit raise (not `assert False`) keeps this failure active
        # under `python -O` and keeps only the call under test in the `try`.
        raise AssertionError("expected missing qk_mma_idx assignment to fail")
@run
def test_materialize_configuration_attr_one_of_out_of_range():
    """An out-of-range one_of_knob index must raise a RuntimeError.

    The pv_mma_idx option list has two entries, so 99 is outside [0, 2).
    """
    # Keep the module alive while the constraints op wrapper is in use.
    input_module, constraints_op = _get_materialize_attention_constraints_op()
    assignments = dict(_ATTENTION_ASSIGNMENTS)
    assignments["pv_mma_idx"] = 99
    try:
        iree_codegen.materialize_configuration_attr(
            constraints_op, "decomposition_config", assignments
        )
    except RuntimeError as e:
        assert "assignment for knob 'pv_mma_idx' is out of range" in str(e)
        assert "99 is not in [0, 2)" in str(e)
    else:
        # An explicit raise (not `assert False`) keeps this failure active
        # under `python -O` and keeps only the call under test in the `try`.
        raise AssertionError("expected out-of-range pv_mma_idx assignment to fail")
@run
def test_materialize_configuration_attr_unknown_attr_name():
    """Requesting a knob entry the constraints op lacks must raise RuntimeError."""
    # Keep the module alive while the constraints op wrapper is in use.
    input_module, constraints_op = _get_materialize_attention_constraints_op()
    try:
        iree_codegen.materialize_configuration_attr(
            constraints_op, "missing_attr", _ATTENTION_ASSIGNMENTS
        )
    except RuntimeError as e:
        assert "constraints op has no 'missing_attr' entry" in str(e)
    else:
        # An explicit raise (not `assert False`) keeps this failure active
        # under `python -O` and keeps only the call under test in the `try`.
        raise AssertionError("expected missing configuration attr to fail")