blob: 2ff606e87c969f48c1d924abe01e0e84a4fc29c0 [file] [log] [blame]
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Util functions/classes for jax primitive test harnesses."""
import contextlib
import functools
from typing import Optional
import warnings
import zlib
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax.experimental import jax2tf
from jax.experimental.jax2tf.tests import jax2tf_limitations
from jax.experimental.jax2tf.tests import primitive_harness
import ml_dtypes
import numpy as np
import numpy.random as npr
import tensorflow as tf
# Dtypes for which harness runs are intended to be enabled; consumed by
# _has_only_supported_dtypes below.  NOTE(review): that filter currently
# short-circuits to True, so this list is effectively unused — confirm intent.
_SUPPORTED_DTYPES = [np.float32]
def _harness_matches(harness, group_name, dtype, params):
if harness.group_name != group_name:
return False
if dtype is not None and harness.dtype != dtype:
return False
for key, value in params.items():
if harness.params.get(key, None) != value:
return False
return True
# Harness (group_name, dtype, params) combinations known to crash; matching
# harnesses are filtered out in primitives_parameterized.  Each entry is a dict
# with keys "group_name", "dtype" and "params" (see _harness_matches).
_CRASH_LIST_PARMS = []
# Default numerical-comparison tolerance per dtype: exact (0) for bool/integer
# types, progressively looser for lower-precision floating-point types.
# Consumed by tolerance() below.
_DEFAULT_TOLERANCE = {
    jax.dtypes.float0: 0,
    np.dtype(np.bool_): 0,
    np.dtype(ml_dtypes.int4): 0,
    np.dtype(np.int8): 0,
    np.dtype(np.int16): 0,
    np.dtype(np.int32): 0,
    np.dtype(np.int64): 0,
    np.dtype(ml_dtypes.uint4): 0,
    np.dtype(np.uint8): 0,
    np.dtype(np.uint16): 0,
    np.dtype(np.uint32): 0,
    np.dtype(np.uint64): 0,
    np.dtype(ml_dtypes.float8_e4m3b11fnuz): 1e-1,
    np.dtype(ml_dtypes.float8_e4m3fn): 1e-1,
    np.dtype(ml_dtypes.float8_e5m2): 1e-1,
    np.dtype(ml_dtypes.bfloat16): 1e-2,
    np.dtype(np.float16): 1e-3,
    np.dtype(np.float32): 1e-6,
    np.dtype(np.float64): 1e-15,
    np.dtype(np.complex64): 1e-6,
    np.dtype(np.complex128): 1e-15,
}
def _dtype(x):
if hasattr(x, "dtype"):
return x.dtype
elif type(x) in jax.dtypes.python_scalar_dtypes:
return np.dtype(jax.dtypes.python_scalar_dtypes[type(x)])
else:
return np.asarray(x).dtype
def tolerance(dtype, tol=None):
  """Resolves `tol` to a scalar tolerance for `dtype`.

  `tol` may be None (use the per-dtype default), a scalar (returned
  unchanged), or a dict mapping dtypes to tolerances (falls back to
  _DEFAULT_TOLERANCE for dtypes not in the dict).
  """
  if tol is None:
    tol = {}
  if not isinstance(tol, dict):
    return tol
  per_dtype = {np.dtype(k): v for k, v in tol.items()}
  canonical = jax.dtypes.canonicalize_dtype(np.dtype(dtype))
  return per_dtype.get(canonical, _DEFAULT_TOLERANCE[canonical])
def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=""):
"""Checks if two numpy arrays are all close given tolerances.
Args:
a: The array to check.
b: The expected array.
atol: Absolute tolerance.
rtol: Relative tolerance.
err_msg: The error message to print in case of failure.
"""
if a.dtype == b.dtype == jax.dtypes.float0:
np.testing.assert_array_equal(a, b, err_msg=err_msg)
return
custom_dtypes = [
ml_dtypes.float8_e4m3b11fnuz,
ml_dtypes.float8_e4m3fn,
ml_dtypes.float8_e5m2,
ml_dtypes.bfloat16,
]
a = a.astype(np.float32) if a.dtype in custom_dtypes else a
b = b.astype(np.float32) if b.dtype in custom_dtypes else b
kw = {}
if atol:
kw["atol"] = atol
if rtol:
kw["rtol"] = rtol
with np.errstate(invalid="ignore"):
# TODO(phawkins): surprisingly, assert_allclose sometimes reports invalid
# value errors. It should not do that.
np.testing.assert_allclose(a, b, **kw, err_msg=err_msg)
@contextlib.contextmanager
def ignore_warning(**kw):
  """Context manager (usable as a decorator) suppressing matching warnings.

  Keyword arguments are forwarded to warnings.filterwarnings; the previous
  warning filters are restored on exit.
  """
  state = warnings.catch_warnings()
  state.__enter__()
  try:
    warnings.filterwarnings("ignore", **kw)
    yield
  finally:
    state.__exit__(None, None, None)
def _has_only_supported_dtypes(harness):
return True
if harness.dtype not in _SUPPORTED_DTYPES:
return False
for key, value in harness.params.items():
if "dtype" in key and value not in _SUPPORTED_DTYPES:
return False
return True
def is_sequence(x):
  """Returns True iff `x` is iterable (i.e. iter(x) succeeds)."""
  try:
    iter(x)
    return True
  except TypeError:
    return False
def primitives_parameterized(harnesses, *, one_containing: Optional[str] = None):
  """Decorator for tests. This is used to filter the tests.

  Args:
    harnesses: List of Harness objects to be filtered.
    one_containing: If set, only creates one test case for the provided name.

  Returns:
    A parameterized version of the test function with filtered set of harnesses.
  """

  def _keep(harness):
    # TODO(b/295369536) Put a limitations system in place so what's not covered
    # is explicit.
    if not harness.params.get("enable_xla", True):
      return False
    if one_containing is not None and one_containing not in harness.fullname:
      return False
    if not _has_only_supported_dtypes(harness):
      return False
    # Drop any harness matching an entry on the known-crash list.
    return not any(
        _harness_matches(
            harness, item["group_name"], item["dtype"], item["params"]
        )
        for item in _CRASH_LIST_PARMS
    )

  return primitive_harness.parameterized(
      filter(_keep, harnesses), include_jax_unimpl=False
  )
class JaxRunningTestCase(parameterized.TestCase):
  """A test case for JAX conversions."""

  def setUp(self):
    super().setUp()
    # adler32 (rather than hash()) because it is deterministic run to run and
    # yields values in the int32 range that RandomState requires.
    seed = zlib.adler32(self._testMethodName.encode())
    self._rng = npr.RandomState(seed)

  def rng(self):
    """Returns the per-test deterministic RandomState."""
    return self._rng

  def assert_all_close(
      self,
      x,
      y,
      *,
      check_dtypes=True,
      atol=None,
      rtol=None,
      canonicalize_dtypes=True,
      err_msg="",
  ):
    """Assert that x and y, either arrays or nested tuples/lists, are close."""
    # Recurse into containers with the same options.
    recurse = functools.partial(
        self.assert_all_close,
        check_dtypes=check_dtypes,
        atol=atol,
        rtol=rtol,
        canonicalize_dtypes=canonicalize_dtypes,
        err_msg=err_msg,
    )
    if isinstance(x, dict):
      self.assertIsInstance(y, dict)
      self.assertEqual(set(x.keys()), set(y.keys()))
      for key in x:
        recurse(x[key], y[key])
    elif is_sequence(x) and not hasattr(x, "__array__"):
      self.assertTrue(is_sequence(y) and not hasattr(y, "__array__"))
      self.assertEqual(len(x), len(y))
      for x_item, y_item in zip(x, y):
        recurse(x_item, y_item)
    elif hasattr(x, "__array__") or np.isscalar(x):
      self.assertTrue(hasattr(y, "__array__") or np.isscalar(y))
      if check_dtypes:
        self.assert_dtypes_match(x, y, canonicalize_dtypes=canonicalize_dtypes)
      self.assert_arrays_all_close(
          np.asarray(x),
          np.asarray(y),
          check_dtypes=False,
          atol=atol,
          rtol=rtol,
          err_msg=err_msg,
      )
    elif x == y:
      return
    else:
      raise TypeError((type(x), type(y)))

  def assert_arrays_all_close(
      self, x, y, *, check_dtypes=True, atol=None, rtol=None, err_msg=""
  ):
    """Assert that x and y are close (up to numerical tolerances)."""
    self.assertEqual(x.shape, y.shape)
    # Use the loosest per-dtype tolerance of the two arguments.
    eff_atol = max(tolerance(_dtype(x), atol), tolerance(_dtype(y), atol))
    eff_rtol = max(tolerance(_dtype(x), rtol), tolerance(_dtype(y), rtol))
    _assert_numpy_allclose(x, y, atol=eff_atol, rtol=eff_rtol, err_msg=err_msg)
    if check_dtypes:
      self.assert_dtypes_match(x, y)

  def assert_dtypes_match(self, x, y, *, canonicalize_dtypes=True):
    """Assert that x and y have matching (optionally canonicalized) dtypes."""
    dx, dy = _dtype(x), _dtype(y)
    if canonicalize_dtypes and not jax.config.x64_enabled:
      canon = functools.partial(
          jax.dtypes.canonicalize_dtype, allow_opaque_dtype=True
      )
      self.assertEqual(canon(dx), canon(dy))
    else:
      self.assertEqual(dx, dy)
class PrimitivesTest(JaxRunningTestCase):
  """Runs each primitive harness on available devices and cross-checks results.

  For every filtered harness, the harness function is executed on devices
  from the backends listed below; the first run provides baseline results
  and runs on devices of other platforms are compared against it.
  """

  @primitives_parameterized(
      primitive_harness.all_harnesses,
  )
  @ignore_warning(
      category=UserWarning, message="Using reduced precision for gradient.*"
  )
  def test_prim(self, harness: primitive_harness.Harness):
    # Keep only limitations applying to compiled CPU runs of this dtype.
    def _filter_limitation(limitation):
      return limitation.filter(device="cpu", dtype=harness.dtype, mode="compiled")

    # NOTE(review): `limitations` is computed but never used below — confirm
    # whether it was meant to gate or loosen the comparison.
    limitations = tuple(
        filter(
            _filter_limitation,
            jax2tf_limitations.Jax2TfLimitation.limitations_for_harness(harness),
        )
    )
    devices = []
    possible_backends = ["iree_cpu"]
    for backend in possible_backends:
      try:
        devices.extend(jax.devices(backend))
      except RuntimeError:
        # Backend not present in this build/runtime; skip it.
        continue
    func_jax = harness.dyn_fun
    args = harness.dyn_args_maker(self.rng())
    baseline_platform: Optional[str] = None
    baseline_results = None
    for d in devices:
      # Skip extra devices on the baseline's platform; devices of any other
      # platform are run and compared against the baseline results.  (The
      # `is None` check is redundant with the inequality for the first
      # iteration but kept for clarity.)
      if baseline_platform is None or baseline_platform != d.platform:
        device_args = jax.tree_util.tree_map(
            lambda x: jax.device_put(x, d), args
        )
        logging.info("Running harness on %s", d)
        with jax.jax2tf_associative_scan_reductions(True):
          res = func_jax(*device_args)
        if baseline_platform is None:
          # First successful run becomes the baseline.
          baseline_platform = d.platform
          baseline_results = res
        else:
          logging.info(
              "Comparing results for %s and %s", baseline_platform, d.platform
          )
          self.assert_all_close(baseline_results, res)
# Run all tests via absl's test runner when executed as a script.
if __name__ == "__main__":
  absltest.main()