Source code for jaxoplanet.utils

__all__ = ["get_dtype_eps", "zero_safe_sqrt"]

import jax
import jax.numpy as jnp


[docs] def get_dtype_eps(x): return jnp.finfo(jax.dtypes.result_type(x)).eps
@jax.custom_jvp
[docs] def zero_safe_sqrt(x): return jnp.sqrt(x)
@zero_safe_sqrt.defjvp def zero_safe_sqrt_jvp(primals, tangents): (x,) = primals (x_dot,) = tangents primal_out = jnp.sqrt(x) cond = jnp.less(x, 10 * get_dtype_eps(x)) denom = jnp.where(cond, jnp.ones_like(x), x) tangent_out = 0.5 * x_dot * primal_out / denom return primal_out, tangent_out