blob: 6de6de4a9360dd7f9aceb7931d637f80213c2f22 [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Generate e2e tests of math ops."""
import argparse
import dataclasses
import enum
import math
import json
import sys
import typing
def parse_arguments():
    """Parse the command line.

    Returns:
        argparse.Namespace whose `testcases` attribute is the path given via
        the mandatory --testcases flag.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate e2e tests of math ops"
    )
    arg_parser.add_argument(
        "--testcases",
        required=True,
        type=str,
        help="JSON file listing the ops to be tested.",
    )
    parsed = arg_parser.parse_args()
    return parsed
def evaluate_unary_float_op(op, x):
    """Return the reference (float64) result of unary math op `op` applied to `x`.

    Ops whose semantics differ from, or are missing in, Python's `math` module
    are special-cased; every other op is looked up by name in `math`.

    Raises:
        AttributeError: if `op` is neither special-cased nor a `math` function.
    """
    if op == "cbrt":
        # Only real for x >= 0: a negative base yields a complex number here.
        return x ** (1.0 / 3)
    if op == "exp2":
        return 2.0**x
    if op == "roundeven":
        # Python's round() rounds half-way cases to the even neighbor.
        return float(round(x))
    if op == "round":
        # When 2*x is integral (x is an integer or a half-way case), round the
        # magnitude up, i.e. half-way cases go away from zero.
        if 2.0 * x == round(2.0 * x):
            return math.copysign(math.ceil(math.fabs(x)), x)
        return float(round(x))
    if op == "rsqrt":
        return x**-0.5
    # Look the function up directly rather than eval()-ing formatted source:
    # eval breaks on values whose repr is not a Python expression (inf, nan).
    return float(getattr(math, op)(x))
def evaluate_binary_float_op(op, x, y):
    """Return the reference (float64) result of binary math op `op` on (x, y).

    "powf" and "fpowi" map to math.pow, with a zero base special-cased to 0.0
    (math.pow raises ValueError for a zero base with a negative exponent).
    Any other op is looked up by name in Python's `math` module.

    Raises:
        AttributeError: if `op` is neither special-cased nor a `math` function.
    """
    if op in ("powf", "fpowi"):
        # NOTE(review): 0.0 is not the IEEE result for a zero base with an
        # exponent <= 0 (pow(0, 0) == 1, pow(0, negative) == +/-inf); the
        # "powf" domain (x > 0) keeps such inputs out of generated tests.
        return 0.0 if x == 0 else math.pow(x, y)
    # Direct lookup instead of eval(): formatted inf/nan are not valid source.
    return float(getattr(math, op)(x, y))
def generate_unary_float_op(op, type, atol, rtol, input_values):
    """Print an MLIR test function exercising elementwise `math.<op>` on `type`.

    The emitted function feeds `input_values` through a 1-D linalg.generic
    that applies the op. It then checks the result against the reference
    values from evaluate_unary_float_op, within tolerances `atol`/`rtol`.
    """
    expected = [evaluate_unary_float_op(op, value) for value in input_values]
    count = len(input_values)
    tensor_type = f"tensor<{count}x{type}>"
    tolerances = f"atol {float(atol):.3e}, rtol {float(rtol):.3e}"
    mlir = f"""
func.func @test_{op}_{type}() -> () {{
%input = util.unfoldable_constant dense<{input_values}> : {tensor_type}
%result_empty = tensor.empty() : {tensor_type}
%result = linalg.generic {{indexing_maps = [
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>
], iterator_types = ["parallel"]}}
ins(%input : {tensor_type}) outs(%result_empty : {tensor_type}) {{
^bb0(%in: {type}, %out: {type}):
%result = math.{op} %in : {type}
linalg.yield %result : {type}
}} -> {tensor_type}
check.expect_almost_eq_const(%result,
dense<{expected}> : {tensor_type},
{tolerances}) : {tensor_type}
return
}}"""
    print(mlir)
def generate_binary_float_op(op, type, atol, rtol, input_values):
    """Print an MLIR test function exercising elementwise binary `math.<op>`.

    `input_values` is a sequence of (x, y) pairs. The two operand columns
    become separate constant tensors. The result is checked against
    evaluate_binary_float_op within tolerances `atol`/`rtol`.
    """
    lhs_values = [lhs for lhs, _unused in input_values]
    rhs_values = [rhs for _unused, rhs in input_values]
    expected = [evaluate_binary_float_op(op, lhs, rhs) for lhs, rhs in input_values]
    count = len(input_values)
    tensor_type = f"tensor<{count}x{type}>"
    tolerances = f"atol {float(atol):.3e}, rtol {float(rtol):.3e}"
    mlir = f"""
func.func @test_{op}_{type}() -> () {{
%input_x = util.unfoldable_constant dense<{lhs_values}> : {tensor_type}
%input_y = util.unfoldable_constant dense<{rhs_values}> : {tensor_type}
%result_empty = tensor.empty() : {tensor_type}
%result = linalg.generic {{indexing_maps = [
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>
], iterator_types = ["parallel"]}}
ins(%input_x, %input_y : {tensor_type}, {tensor_type}) outs(%result_empty : {tensor_type}) {{
^bb0(%in_x: {type}, %in_y : {type}, %out: {type}):
%result = math.{op} %in_x, %in_y : {type}
linalg.yield %result : {type}
}} -> {tensor_type}
check.expect_almost_eq_const(%result,
dense<{expected}> : {tensor_type},
{tolerances}) : {tensor_type}
return
}}"""
    print(mlir)
def generate_unary_float_input_values(
    predicate, *, unit_subdivisions=4, float_range_radius=16
):
    """Return grid points accepted by `predicate`, in increasing order.

    The grid covers the half-open interval
    [-float_range_radius, float_range_radius) with spacing
    1 / unit_subdivisions. The upper endpoint itself is not included.

    Args:
        predicate: one-argument callable; points for which it returns a falsy
            value are dropped.
        unit_subdivisions: grid points per unit interval. Some functions, like
            math.round and math.roundeven, care specifically about
            half-integral values. To ensure coverage of half-integral cases,
            this should be even.
        float_range_radius: half-width of the sampled interval.

    Returns:
        list of floats.

    Raises:
        ValueError: if unit_subdivisions is not positive.
    """
    if unit_subdivisions <= 0:
        raise ValueError(
            f"unit_subdivisions must be positive, got {unit_subdivisions}"
        )
    # The correctness of the logic here rests on exact representability of
    # small integer values scaled by small powers of two, so unit_subdivisions
    # should be a small power of two.
    index_range_radius = int(float_range_radius * unit_subdivisions)
    candidates = (
        float(i) / unit_subdivisions
        for i in range(-index_range_radius, index_range_radius)
    )
    return [val for val in candidates if predicate(val)]
def generate_binary_float_input_values(
    predicate, *, unit_subdivisions=4, float_range_radius=4
):
    """Return (x, y) grid pairs accepted by `predicate`.

    Each coordinate ranges over the half-open interval
    [-float_range_radius, float_range_radius) with spacing
    1 / unit_subdivisions. Pairs are ordered with x as the outer loop and y
    as the inner loop. The defaults are conservative because the number of
    pairs grows quadratically.

    Args:
        predicate: two-argument callable; pairs for which it returns a falsy
            value are dropped.
        unit_subdivisions: grid points per unit interval along each axis.
        float_range_radius: half-width of the sampled interval on each axis.

    Returns:
        list of (x, y) float tuples.

    Raises:
        ValueError: if unit_subdivisions is not positive.
    """
    if unit_subdivisions <= 0:
        raise ValueError(
            f"unit_subdivisions must be positive, got {unit_subdivisions}"
        )
    index_range_radius = int(float_range_radius * unit_subdivisions)
    # The 1-D grid is identical for both operands; build it once.
    grid = [
        float(i) / unit_subdivisions
        for i in range(-index_range_radius, index_range_radius)
    ]
    return [(x, y) for x in grid for y in grid if predicate(x, y)]
@enum.unique
class MathOpKind(enum.Enum):
    """Operand/result signature of a math op; selects which test generator runs."""

    UNARY_FLOAT = enum.auto()  # one float operand, float result
    BINARY_FLOAT = enum.auto()  # two float operands, float result
@dataclasses.dataclass
class MathOpInfo:
    """Metadata describing how to generate tests for one math op."""

    # Operand/result signature of the op (a MathOpKind member).
    kind: MathOpKind
    # Predicate over the op's operand(s); truthy for inputs that should be
    # sampled when generating tests.
    domain: typing.Callable[..., bool]
def get_math_op_info():
    """Return a fresh dict mapping each supported math op name to its MathOpInfo.

    Each `domain` predicate restricts the generated inputs to values where the
    op is defined. For some ops it also limits inputs to where the Python
    reference evaluation produces a real result.
    """

    def everywhere(x):
        return True

    def closed_unit_interval(x):
        return -1 <= x <= 1

    def positive(x):
        return x > 0

    def unary(domain):
        return MathOpInfo(kind=MathOpKind.UNARY_FLOAT, domain=domain)

    def binary(domain):
        return MathOpInfo(kind=MathOpKind.BINARY_FLOAT, domain=domain)

    return {
        "acos": unary(closed_unit_interval),
        "acosh": unary(lambda x: x >= 1),
        "asin": unary(closed_unit_interval),
        "asinh": unary(everywhere),
        "atan": unary(everywhere),
        # atan2 is undefined only at the origin.
        "atan2": binary(lambda x, y: x != 0 or y != 0),
        "atanh": unary(lambda x: -1 < x < 1),
        # The reference computes cbrt as x ** (1/3), which is real only for x >= 0.
        "cbrt": unary(lambda x: x >= 0),
        "ceil": unary(everywhere),
        "cos": unary(everywhere),
        "cosh": unary(everywhere),
        "erf": unary(everywhere),
        "erfc": unary(everywhere),
        "exp": unary(everywhere),
        "exp2": unary(everywhere),
        "expm1": unary(everywhere),
        "floor": unary(everywhere),
        "log": unary(positive),
        "log10": unary(positive),
        "log1p": unary(lambda x: x > -1),
        "log2": unary(positive),
        "powf": binary(lambda x, y: x > 0),
        "round": unary(everywhere),
        "roundeven": unary(everywhere),
        "rsqrt": unary(positive),
        "sin": unary(everywhere),
        "sinh": unary(everywhere),
        "sqrt": unary(positive),
        "tan": unary(everywhere),
        "tanh": unary(everywhere),
        "trunc": unary(everywhere),
    }
def _within_bounds(testcase, value, min_key, max_key):
    """Return whether `value` satisfies the optional bounds stored in `testcase`.

    A bound whose key is absent from `testcase` imposes no restriction.
    """
    if min_key in testcase and value < testcase[min_key]:
        return False
    if max_key in testcase and value > testcase[max_key]:
        return False
    return True


def _unary_domain(info, testcase):
    """Combine a unary op's inherent domain with the testcase's xmin/xmax."""

    def domain(x):
        return info.domain(x) and _within_bounds(testcase, x, "xmin", "xmax")

    return domain


def _binary_domain(info, testcase):
    """Combine a binary op's inherent domain with the testcase's x/y bounds."""

    def domain(x, y):
        return (
            info.domain(x, y)
            and _within_bounds(testcase, x, "xmin", "xmax")
            and _within_bounds(testcase, y, "ymin", "ymax")
        )

    return domain


def main(args):
    """Print an MLIR test function for every enabled testcase in args.testcases.

    args.testcases names a JSON file holding a list of testcase objects. Each
    object needs an "op" key. Unless it also has a "disabled" key, it needs
    "type", "atol" and "rtol" (a missing one raises KeyError). It may carry
    "xmin"/"xmax", plus "ymin"/"ymax" for binary ops, to narrow the sampled
    inputs. Ops known to get_math_op_info() that no testcase mentions are
    reported on stderr; disabled testcases count as mentions.

    Raises:
        ValueError: if an enabled testcase names an op that get_math_op_info()
            does not know, or the op has an unhandled kind.
    """
    with open(args.testcases, "r") as testcases_file:
        testcases = json.load(testcases_file)
    math_op_info = get_math_op_info()
    ops_not_yet_encountered = set(math_op_info)
    for testcase in testcases:
        op = testcase["op"]
        ops_not_yet_encountered.discard(op)
        if "disabled" in testcase:
            continue
        info = math_op_info.get(op)
        if info is None:
            raise ValueError(
                f"Unknown op {op!r} in {args.testcases}; expected one of: "
                f"{', '.join(sorted(math_op_info))}"
            )
        # Named elem_type to avoid shadowing the builtin `type`.
        elem_type = testcase["type"]
        atol = testcase["atol"]
        rtol = testcase["rtol"]
        if info.kind == MathOpKind.UNARY_FLOAT:
            input_values = generate_unary_float_input_values(
                _unary_domain(info, testcase)
            )
            generate_unary_float_op(op, elem_type, atol, rtol, input_values)
        elif info.kind == MathOpKind.BINARY_FLOAT:
            input_values = generate_binary_float_input_values(
                _binary_domain(info, testcase)
            )
            generate_binary_float_op(op, elem_type, atol, rtol, input_values)
        else:
            raise ValueError(f"Unhandled op kind: {info.kind}")
    if ops_not_yet_encountered:
        # Sorted so the message is stable across runs: set iteration order of
        # strings depends on hash randomization.
        missing = ", ".join(sorted(ops_not_yet_encountered))
        print(
            f"Warning: did not find testcases covering {missing} in {args.testcases}",
            file=sys.stderr,
        )
if __name__ == "__main__":
    # Script entry point: parse the command line, then print the generated tests.
    main(args=parse_arguments())