"""Utility functions for jaxoplanet (module ``jaxoplanet.utils``)."""
# Names exported by ``from jaxoplanet.utils import *``.
__all__ = [
    "get_dtype_eps",
    "zero_safe_sqrt",
]
import jax
import jax.numpy as jnp
[docs]
def get_dtype_eps(x):
return jnp.finfo(jax.dtypes.result_type(x)).eps
@jax.custom_jvp
def zero_safe_sqrt(x):
    """Element-wise square root with a custom (zero-safe) derivative.

    The forward value is exactly ``jnp.sqrt(x)``. The function is wrapped in
    ``jax.custom_jvp``, so its derivative comes from the JVP rule registered
    with ``zero_safe_sqrt.defjvp`` rather than from autodiff of ``jnp.sqrt``.

    Args:
        x: Input value(s) to take the square root of.

    Returns:
        ``jnp.sqrt(x)``.
    """
    root = jnp.sqrt(x)
    return root
@zero_safe_sqrt.defjvp
def zero_safe_sqrt_jvp(primals, tangents):
    """JVP rule for ``zero_safe_sqrt`` that stays finite at ``x == 0``.

    The exact derivative ``0.5 / sqrt(x)`` is evaluated as
    ``0.5 * sqrt(x) / x``. Wherever ``x < 10 * eps`` (``eps`` from
    ``get_dtype_eps(x)``), the divisor ``x`` is swapped for ``1``. In that
    region the tangent is ``0.5 * x_dot * sqrt(x)``, which is ``0`` at
    ``x == 0``. The unguarded formula would instead give inf or NaN there.
    For negative ``x``, ``jnp.sqrt`` yields NaN, so both outputs are NaN.

    Args:
        primals: One-element tuple ``(x,)``.
        tangents: One-element tuple ``(x_dot,)`` matching ``primals``.

    Returns:
        Tuple ``(sqrt(x), tangent_out)``.
    """
    x, = primals
    x_dot, = tangents
    # Absolute cutoff below which x is treated as "numerically zero".
    threshold = 10 * get_dtype_eps(x)
    root = jnp.sqrt(x)
    # Replace tiny divisors with 1 so the division never blows up.
    safe_x = jnp.where(jnp.less(x, threshold), jnp.ones_like(x), x)
    return root, 0.5 * x_dot * root / safe_x