blob: 7387a4ea0a6c84f6594e305efd607b8bebb7a12c [file] [log] [blame]
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from iree.compiler import ir
from iree.compiler.dialects import iree_codegen
from iree.compiler.dialects import affine
from iree.compiler.ir import AffineMap, AffineDimExpr
def run(fn):
    """Execute ``fn`` immediately inside a fresh MLIR scope and return it.

    Intended for use as a decorator: the decorated function is called once,
    at definition time, with an ``ir.Context``, an unknown ``ir.Location``
    and an insertion point at the body of a newly created empty module all
    active. A ``TEST: <name>`` banner is printed before the call.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        ``fn`` itself, unchanged, so the decorated name stays bound to it.
    """
    with ir.Context():
        with ir.Location.unknown():
            scratch_module = ir.Module.create()
            with ir.InsertionPoint(scratch_module.body):
                print("\nTEST:", fn.__name__)
                fn()
    return fn
@run
def root_op():
    """Check that get_tuner_root_ops returns the ops tagged with ``root_op``.

    Three variants of the same fill + matmul function are parsed: one with no
    tagged op, one with only the matmul tagged, and one with both the fill
    and the matmul tagged. The returned ops are compared by name, in order.
    """
    # Variant 1: nothing carries the `root_op` attribute.
    untagged_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x4xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = linalg.matmul ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
"""
    # Variant 2: only the matmul is tagged.
    matmul_tagged_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x4xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
"""
    # Variant 3: both the fill and the matmul are tagged.
    fill_and_matmul_tagged_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<4x4xf32>
%1 = linalg.fill { root_op } ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
%2 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%1 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %2 : tensor<4x4xf32>
}
}
"""
    # Each entry pairs an input module with the op names expected back, in order.
    expectations = (
        (untagged_ir, []),
        (matmul_tagged_ir, ["linalg.matmul"]),
        (fill_and_matmul_tagged_ir, ["linalg.fill", "linalg.matmul"]),
    )
    for mlir_text, expected_names in expectations:
        parsed = ir.Module.parse(mlir_text)
        assert parsed is not None, "Failed to parse input MLIR module"
        tagged_ops = iree_codegen.get_tuner_root_ops(parsed)
        assert len(tagged_ops) == len(expected_names)
        for position, expected_name in enumerate(expected_names):
            assert tagged_ops[position].name == expected_name
@run
def attention_op_detail():
    """Check the dimension groups reported by get_attention_op_detail.

    The helper is called with Q, K, V and output affine maps (in that order)
    and the returned ``domain_rank``, ``batch_dims``, ``m_dims``, ``k1_dims``,
    ``k2_dims`` and ``n_dims`` are compared against fixed expectations.
    """
    # Case 1: five-dimensional domain with three-result maps.
    d0, d1, d2, d3, d4 = (affine.AffineDimExpr.get(pos) for pos in range(5))
    query_map = affine.AffineMap.get(5, 0, [d0, d1, d2])
    key_map = affine.AffineMap.get(5, 0, [d0, d3, d2])
    value_map = affine.AffineMap.get(5, 0, [d0, d3, d4])
    output_map = affine.AffineMap.get(5, 0, [d0, d1, d4])

    detail = iree_codegen.get_attention_op_detail(
        query_map, key_map, value_map, output_map
    )
    assert detail.domain_rank == 5
    assert detail.batch_dims == [0]
    assert detail.m_dims == [1]
    assert detail.k1_dims == [2]
    assert detail.k2_dims == [3]
    assert detail.n_dims == [4]

    # Case 2: input affine maps that do not follow the expected pattern for an
    # attention operation. A result is still returned, with an empty k1 group.
    d0, d1, d2, d3 = (affine.AffineDimExpr.get(pos) for pos in range(4))
    query_map = affine.AffineMap.get(4, 0, [d0, d1])
    key_map = affine.AffineMap.get(4, 0, [d0, d2])
    value_map = affine.AffineMap.get(4, 0, [d0, d3])
    output_map = affine.AffineMap.get(4, 0, [d0, d1])

    detail = iree_codegen.get_attention_op_detail(
        query_map, key_map, value_map, output_map
    )
    assert detail.domain_rank == 4
    assert detail.batch_dims == [0]
    assert detail.m_dims == [1]
    assert detail.k1_dims == []
    assert detail.k2_dims == [2]
    assert detail.n_dims == [3]
@run
def test_isa_attention_op():
    """Check that a tagged iree_linalg_ext.attention op is found and recognized.

    The module holds a single attention op carrying ``root_op``; it must be
    the only op returned by get_tuner_root_ops and must satisfy
    isa_attention_op.
    """
    attention_ir = """
module {
func.func @attention_20x4096x64x4096x64(
%q : tensor<20x4096x64xf16>,
%k : tensor<20x4096x64xf16>,
%v : tensor<20x4096x64xf16>,
%scale : f16,
%output : tensor<20x4096x64xf16>
) -> tensor<20x4096x64xf16> {
%result = iree_linalg_ext.attention { root_op,
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
]
} ins(%q, %k, %v, %scale : tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16)
outs(%output : tensor<20x4096x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x4096x64xf16>
return %result : tensor<20x4096x64xf16>
}
}
"""
    parsed = ir.Module.parse(attention_ir)
    assert parsed is not None, "Failed to parse input MLIR module"

    tagged_ops = iree_codegen.get_tuner_root_ops(parsed)
    assert len(tagged_ops) == 1

    attention_op = tagged_ops[0]
    assert attention_op.name == "iree_linalg_ext.attention"
    assert iree_codegen.isa_attention_op(attention_op)
@run
def test_igemm_conv_details():
    """Check get_igemm_generic_conv_details on conv and non-conv root ops.

    Covers three named 2-D convolutions (nhwc_hwcf, nhwc_fhwc, nchw_fchw), a
    linalg.generic with a convolution access pattern, and a plain matmul for
    which no details are expected. Every case first asserts that exactly one
    root op was found, so a lookup regression fails with a clear message
    instead of an IndexError or a check against the wrong op.
    """
    # Dim expressions shared by all expected maps; the 4-D case uses d0..d3.
    d0, d1, d2, d3, d4 = [AffineDimExpr.get(i) for i in range(5)]
    # All three named 2-D convs expect four parallel loops plus one reduction.
    conv2d_iterators = ['"parallel"'] * 4 + ['"reduction"']

    # Test 1: conv_2d_nhwc_hwcf.
    module_str = """
module {
func.func @conv_2d_nhwc_hwcf(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_hwcf { root_op, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Expected exactly one root op"
    details = iree_codegen.get_igemm_generic_conv_details(root_op_list[0])
    assert details is not None, "IGEMM details should be valid for NHWC_HWCF conv"
    assert details.igemm_loop_bounds == [1, 14, 14, 16, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # For channel-last (NHWC): input (N,H,W,K), filter (K,OC), output (N,H,W,OC).
    assert maps[0] == AffineMap.get(
        5, 0, [d0, d1, d2, d4]
    ), f"Input map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(5, 0, [d4, d3]), f"Filter map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == conv2d_iterators
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0, 1, 2], [3]]
    assert not details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    # Test 2: conv_2d_nhwc_fhwc.
    module_str = """
module {
func.func @conv_2d_nhwc_fhwc(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> {
%0 = linalg.conv_2d_nhwc_fhwc { root_op, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1 : tensor<1x16x16x4xf32>, tensor<16x3x3x4xf32>)
outs(%arg2 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Expected exactly one root op"
    details = iree_codegen.get_igemm_generic_conv_details(root_op_list[0])
    assert details is not None, "IGEMM details should be valid for NHWC_FHWC conv"
    assert details.igemm_loop_bounds == [1, 14, 14, 16, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # Verify expected affine maps (NHWC_FHWC layout).
    assert maps[0] == AffineMap.get(
        5, 0, [d0, d1, d2, d4]
    ), f"Input map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(5, 0, [d3, d4]), f"Filter map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == conv2d_iterators
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0], [1, 2, 3]]
    assert not details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    # Test 3: conv_2d_nchw_fchw.
    module_str = """
module {
func.func @conv_2d_nchw_fchw(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<16x4x3x3xf32>, %arg2: tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32> {
%0 = linalg.conv_2d_nchw_fchw { root_op, dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1 : tensor<1x4x16x16xf32>, tensor<16x4x3x3xf32>)
outs(%arg2 : tensor<1x16x14x14xf32>) -> tensor<1x16x14x14xf32>
return %0 : tensor<1x16x14x14xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Expected exactly one root op"
    details = iree_codegen.get_igemm_generic_conv_details(root_op_list[0])
    assert details is not None, "IGEMM details should be valid for NCHW conv"
    assert details.igemm_loop_bounds == [1, 16, 14, 14, 36]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # Verify expected affine maps for NCHW with loop dims [N, OC, H, W, K].
    # Note: operands are swapped - filter first, then input.
    assert maps[0] == AffineMap.get(5, 0, [d1, d4]), f"Filter map mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(
        5, 0, [d0, d2, d3, d4]
    ), f"Input map mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(
        5, 0, [d0, d1, d2, d3]
    ), f"Output map mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == conv2d_iterators
    assert details.im2col_output_perm == [0, 1, 2, 3]
    assert details.filter_reassoc_indices == [[0], [1, 2, 3]]
    assert details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 4, 6: 4}

    # Test 4: linalg.generic with convolution pattern (weight backward).
    module_str = """
module {
func.func @conv_generic_weight_backward(%arg0: tensor<16x98x64x96xf32>, %arg1: tensor<16x96x64x96xf32>, %arg2: tensor<96x3x96xf32>) -> tensor<96x3x96xf32> {
%0 = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d1 + d4, d5, d2)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5, d0)>,
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%arg0, %arg1 : tensor<16x98x64x96xf32>, tensor<16x96x64x96xf32>) outs(%arg2 : tensor<96x3x96xf32>) attrs = {root_op} {
^bb0(%in: f32, %in_1: f32, %out: f32):
%mul = arith.mulf %in, %in_1 : f32
%add = arith.addf %out, %mul : f32
linalg.yield %add : f32
} -> tensor<96x3x96xf32>
return %0 : tensor<96x3x96xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Expected exactly one root op"
    details = iree_codegen.get_igemm_generic_conv_details(root_op_list[0])
    assert (
        details is not None
    ), "IGEMM details should be valid for generic 1D conv weight backward"
    assert details.igemm_loop_bounds == [96, 3, 96, 98304]
    assert len(details.igemm_contraction_maps) == 3
    maps = [map_attr.value for map_attr in details.igemm_contraction_maps]
    # The expected maps here are over a 4-D domain, so only d0..d3 are used.
    assert maps[0] == AffineMap.get(4, 0, [d3, d0]), f"Map 0 mismatch: {maps[0]}"
    assert maps[1] == AffineMap.get(4, 0, [d1, d3, d2]), f"Map 1 mismatch: {maps[1]}"
    assert maps[2] == AffineMap.get(4, 0, [d0, d1, d2]), f"Map 2 mismatch: {maps[2]}"
    iter_types = [str(attr) for attr in details.igemm_loop_iterators]
    assert iter_types == ['"parallel"', '"parallel"', '"parallel"', '"reduction"']
    assert details.im2col_output_perm == [1, 2, 0]
    assert details.filter_reassoc_indices == [[0, 1, 2], [3]]
    assert details.is_output_channel_first
    assert details.conv_to_igemm_dim_map == {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 3}

    # Test with a non-conv operation.
    module_str = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Expected exactly one root op"
    matmul_op = root_op_list[0]
    details = iree_codegen.get_igemm_generic_conv_details(matmul_op)
    assert details is None, "IGEMM details should be None for non-conv operation"
@run
def test_isa_scaled_contraction_op():
    """Check isa_scaled_contraction_op on a matmul, a fill and a scaled matmul.

    The plain matmul and the fill must be rejected. The linalg.generic that
    encodes a scaled matmul must be accepted, and
    infer_scaled_contraction_dimensions must report its m/n/k/kB/batch groups.
    """
    # Test 1: Regular matmul is not a scaled contraction.
    plain_matmul_ir = """
module {
func.func @matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    parsed = ir.Module.parse(plain_matmul_ir)
    assert parsed is not None, "Failed to parse input MLIR module"
    tagged_ops = iree_codegen.get_tuner_root_ops(parsed)
    assert len(tagged_ops) == 1
    plain_matmul = tagged_ops[0]
    assert not iree_codegen.isa_scaled_contraction_op(
        plain_matmul
    ), "Regular matmul should not be a scaled contraction"

    # Test 2: Fill op is not a scaled contraction.
    fill_ir = """
module {
func.func @fill(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.fill { root_op } ins(%cst : f32) outs(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    parsed = ir.Module.parse(fill_ir)
    tagged_ops = iree_codegen.get_tuner_root_ops(parsed)
    assert len(tagged_ops) == 1
    fill = tagged_ops[0]
    assert not iree_codegen.isa_scaled_contraction_op(
        fill
    ), "Fill op should not be a scaled contraction"

    # Test 3: Scaled matmul as linalg.generic should be detected.
    # The generic has 5 indexing maps (lhs, rhs, lhs_scale, rhs_scale, output)
    # and 4 iterator types (2 parallel for M,N; 2 reduction for Ko,Kb).
    # Operands are f4E2M1FN and scales are f8E8M0FNU (matching real scaled
    # matmul pattern).
    scaled_matmul_ir = """
module {
func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
%lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
root_op
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
outs(%out : tensor<16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<16x16xf32>
return %result : tensor<16x16xf32>
}
}
"""
    parsed = ir.Module.parse(scaled_matmul_ir)
    tagged_ops = iree_codegen.get_tuner_root_ops(parsed)
    assert len(tagged_ops) == 1, "Should have one root op"
    scaled_generic = tagged_ops[0]
    assert iree_codegen.isa_scaled_contraction_op(
        scaled_generic
    ), "linalg.generic with scaled matmul pattern should be detected as scaled contraction"

    inferred = iree_codegen.infer_scaled_contraction_dimensions(scaled_generic)
    assert (
        inferred is not None
    ), "Should be able to infer dimensions for scaled contraction"
    assert inferred.m == [0], f"Got {inferred.m}"
    assert inferred.n == [1], f"Got {inferred.n}"
    assert inferred.k == [2], f"Got {inferred.k}"
    assert inferred.kB == [3], f"Got {inferred.kB}"
    assert inferred.batch == [], f"Got {inferred.batch}"
@run
def test_infer_scaled_contraction_dimensions():
    """Check infer_scaled_contraction_dimensions on three kinds of root op.

    Covers an unbatched scaled matmul (expects m/n/k/kB and an empty batch
    group), a regular matmul (expects no usable dimensions), and a batched
    scaled matmul (expects a batch group in addition to m/n/k/kB).
    """
    # Test 1: Verify dimension inference on a scaled matmul operation.
    module_str = """
module {
func.func @scaled_matmul(%lhs: tensor<16x4x32xf4E2M1FN>, %rhs: tensor<16x4x32xf4E2M1FN>,
%lhs_scales: tensor<16x4xf8E8M0FNU>, %rhs_scales: tensor<16x4xf8E8M0FNU>,
%out: tensor<16x16xf32>) -> tensor<16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2)>,
affine_map<(d0, d1, d2, d3) -> (d1, d2)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"],
root_op
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<16x4x32xf4E2M1FN>, tensor<16x4x32xf4E2M1FN>, tensor<16x4xf8E8M0FNU>, tensor<16x4xf8E8M0FNU>)
outs(%out : tensor<16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<16x16xf32>
return %result : tensor<16x16xf32>
}
}
"""
    input_module = ir.Module.parse(module_str)
    root_op_list = iree_codegen.get_tuner_root_ops(input_module)
    assert len(root_op_list) == 1, "Should have exactly one root op"
    scaled_op = root_op_list[0]
    assert iree_codegen.isa_scaled_contraction_op(
        scaled_op
    ), "Operation should be recognized as scaled contraction"
    dims = iree_codegen.infer_scaled_contraction_dimensions(scaled_op)
    assert dims is not None, "Should successfully infer dimensions"
    assert dims.m == [0], f"Got {dims.m}"
    assert dims.n == [1], f"Got {dims.n}"
    assert dims.k == [2], f"Got {dims.k}"
    assert dims.kB == [3], f"Got {dims.kB}"
    assert dims.batch == [], f"Got {dims.batch}"

    # Test 2: A non-scaled contraction must not yield usable dimensions.
    module_str_regular = """
module {
func.func @regular_matmul(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
%0 = linalg.matmul { root_op } ins(%arg0, %arg1 : tensor<4x4xf32>, tensor<4x4xf32>) outs(%arg2 : tensor<4x4xf32>) -> tensor<4x4xf32>
return %0 : tensor<4x4xf32>
}
}
"""
    input_module_regular = ir.Module.parse(module_str_regular)
    regular_ops = iree_codegen.get_tuner_root_ops(input_module_regular)
    assert len(regular_ops) == 1
    regular_matmul = regular_ops[0]
    # Two outcomes are accepted for a regular matmul: no result at all (None),
    # or a result whose dimension groups are all empty. Only the second needs
    # an explicit check; None passes by skipping the branch.
    dims_regular = iree_codegen.infer_scaled_contraction_dimensions(regular_matmul)
    if dims_regular is not None:
        dimension_groups = (
            dims_regular.m,
            dims_regular.n,
            dims_regular.k,
            dims_regular.kB,
            dims_regular.batch,
        )
        assert all(
            len(group) == 0 for group in dimension_groups
        ), "Regular matmul should not have valid scaled contraction dimensions"

    # Test 3: Batched scaled matmul.
    module_str_batched = """
module {
func.func @batched_scaled_matmul(%lhs: tensor<8x16x4x32xf4E2M1FN>, %rhs: tensor<8x16x4x32xf4E2M1FN>,
%lhs_scales: tensor<8x16x4xf8E8M0FNU>, %rhs_scales: tensor<8x16x4xf8E8M0FNU>,
%out: tensor<8x16x16xf32>) -> tensor<8x16x16xf32> {
%result = linalg.generic {
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d3)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"],
root_op
} ins(%lhs, %rhs, %lhs_scales, %rhs_scales : tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4x32xf4E2M1FN>, tensor<8x16x4xf8E8M0FNU>, tensor<8x16x4xf8E8M0FNU>)
outs(%out : tensor<8x16x16xf32>) {
^bb0(%a: f4E2M1FN, %b: f4E2M1FN, %a_scale: f8E8M0FNU, %b_scale: f8E8M0FNU, %acc: f32):
%a_scaled = arith.scaling_extf %a, %a_scale : f4E2M1FN, f8E8M0FNU to f32
%b_scaled = arith.scaling_extf %b, %b_scale : f4E2M1FN, f8E8M0FNU to f32
%prod = arith.mulf %a_scaled, %b_scaled : f32
%sum = arith.addf %acc, %prod : f32
linalg.yield %sum : f32
} -> tensor<8x16x16xf32>
return %result : tensor<8x16x16xf32>
}
}
"""
    input_module_batched = ir.Module.parse(module_str_batched)
    batched_ops = iree_codegen.get_tuner_root_ops(input_module_batched)
    assert len(batched_ops) == 1, "Batched op should be found"
    batched_op = batched_ops[0]
    assert iree_codegen.isa_scaled_contraction_op(
        batched_op
    ), "Batched scaled matmul should be recognized"
    dims_batched = iree_codegen.infer_scaled_contraction_dimensions(batched_op)
    assert (
        dims_batched is not None
    ), "Batch dimension must be present in batched scaled matmul"
    assert dims_batched.batch == [0], f"Got {dims_batched.batch}"
    assert dims_batched.m == [1], f"Got {dims_batched.m}"
    assert dims_batched.n == [2], f"Got {dims_batched.n}"
    assert dims_batched.k == [3], f"Got {dims_batched.k}"
    assert dims_batched.kB == [4], f"Got {dims_batched.kB}"