"""Source code for ``jaxoplanet.test_utils``: floating-point comparison helpers for tests."""

# Names exported by ``from jaxoplanet.test_utils import *``.
# NOTE(review): assert_pytree_allclose is defined in this module but is not
# listed here — confirm whether that omission is intentional.
__all__ = ["assert_allclose"]

from jax import tree_util
from jax._src.public_test_util import check_close


def assert_allclose(calculated, expected, *args, **kwargs):
    """Check that two floating point arrays are equal within a dtype-dependent tolerance.

    Thin wrapper around JAX's ``check_close`` that supplies a default
    relative tolerance of ``5e-4`` for float32 and ``5e-7`` for float64
    whenever the caller does not provide ``rtol`` themselves.

    Args:
        calculated: The computed value.
        expected: The reference value to compare against.
        *args: Extra positional arguments forwarded to ``check_close``.
        **kwargs: Extra keyword arguments forwarded to ``check_close``.

    Raises:
        AssertionError: When the values are not close within tolerance
            (raised from inside ``check_close``).
    """
    # NOTE(review): assumes check_close's signature is
    # (xs, ys, atol=None, rtol=None, err_msg=""), so a second extra positional
    # argument is rtol. Injecting rtol as a keyword in that case would raise
    # "got multiple values for argument 'rtol'", so only fill in the default
    # when rtol was not already given positionally.
    if len(args) < 2:
        kwargs.setdefault("rtol", {"float32": 5e-4, "float64": 5e-7})
    check_close(calculated, expected, *args, **kwargs)
def assert_pytree_allclose(calculated, expected, *args, **kwargs):
    """Check that two pytrees with floating point or array leaves are equal within a dtype-dependent tolerance.

    Both inputs are flattened with ``jax.tree_util.tree_flatten``. Their tree
    structures must match exactly. Each pair of corresponding leaves is then
    compared with ``assert_allclose``.

    Args:
        calculated: The computed pytree.
        expected: The reference pytree, expected to have the same structure.
        *args: Extra positional arguments forwarded to ``assert_allclose``
            for every leaf.
        **kwargs: Extra keyword arguments forwarded to ``assert_allclose``
            for every leaf.

    Raises:
        AssertionError: If the tree structures differ, or (via
            ``assert_allclose``) if any pair of leaves is not close.
    """
    calc_leaves, calc_treedef = tree_util.tree_flatten(calculated)
    exp_leaves, exp_treedef = tree_util.tree_flatten(expected)
    # Raise explicitly instead of a bare ``assert``: the check then survives
    # ``python -O`` and the failure reports which structures disagreed.
    if calc_treedef != exp_treedef:
        raise AssertionError(
            "Pytree structures differ:\n"
            f"  calculated: {calc_treedef}\n"
            f"  expected:   {exp_treedef}"
        )
    # Equal treedefs imply equal leaf counts; strict=True enforces that
    # invariant rather than silently truncating.
    for calc_leaf, exp_leaf in zip(calc_leaves, exp_leaves, strict=True):
        assert_allclose(calc_leaf, exp_leaf, *args, **kwargs)