# Source code for jaxoplanet.core.kepler

"""This module provides the core functionality to solve Kepler's equation in JAX. For
more details, see the :ref:`core-from-scratch` tutorial.
"""

__all__ = ["kepler"]

import jax
import jax.numpy as jnp
from jax.interpreters import ad

from jaxoplanet.types import Array


@jax.jit
def kepler(M: Array, ecc: Array) -> tuple[Array, Array]:
    """Solve Kepler's equation and report the true anomaly as a (sin, cos) pair.

    This is the jit-compiled public entry point; the actual solve and its
    custom derivative rule live in ``_kepler``.

    Args:
        M: Mean anomaly (wrapped into ``[0, 2*pi)`` by the solver).
        ecc: Eccentricity.

    Returns:
        A tuple ``(sin(f), cos(f))`` where ``f`` is the true anomaly.
    """
    sin_f, cos_f = _kepler(M, ecc)
    return sin_f, cos_f
@jax.custom_jvp
def _kepler(M: Array, ecc: Array) -> tuple[Array, Array]:
    """Solve Kepler's equation and return ``(sin(f), cos(f))`` of the true anomaly.

    The mean anomaly is wrapped into ``[0, 2*pi)`` and then reflected onto
    ``[0, pi]``. The eccentric anomaly is estimated with :func:`starter`,
    improved by one :func:`refine` step, and reflected back. Finally it is
    converted to the true anomaly through the half-angle relation.

    Args:
        M: Mean anomaly (any real value; wrapped modulo ``2*pi``).
        ecc: Eccentricity. The conversion evaluates
            ``sqrt((1 + ecc) / (1 - ecc))``, so ``ecc`` must be below 1.

    Returns:
        A tuple ``(sinf, cosf)`` with the sine and cosine of the true anomaly.
    """
    # Wrap into the range [0, 2*pi)
    M = M % (2 * jnp.pi)

    # Since M(2*pi - E) = 2*pi - M(E), we can restrict to [0, pi] and reflect
    # the solution back afterwards.
    high = M > jnp.pi
    M = jnp.where(high, 2 * jnp.pi - M, M)

    # Solve for the eccentric anomaly
    ome = 1 - ecc
    E = starter(M, ecc, ome)
    E = refine(M, ecc, ome, E)

    # Re-wrap back into the full range
    E = jnp.where(high, 2 * jnp.pi - E, E)

    # Convert to true anomaly: tan(f/2) = sqrt((1 + e) / (1 - e)) * tan(E/2)
    tan_half_f = jnp.sqrt((1 + ecc) / ome) * jnp.tan(0.5 * E)
    tan2_half_f = jnp.square(tan_half_f)

    # Then we compute sin(f) and cos(f) using:
    # sin(f) = 2*tan(0.5*f)/(1 + tan(0.5*f)^2), and
    # cos(f) = (1 - tan(0.5*f)^2)/(1 + tan(0.5*f)^2)
    denom = 1 / (1 + tan2_half_f)
    sinf = 2 * tan_half_f * denom
    cosf = (1 - tan2_half_f) * denom

    return sinf, cosf


@_kepler.defjvp
def _(primals, tangents):
    """Forward-mode derivative rule for ``_kepler``.

    Uses the closed-form derivatives of the true anomaly ``f``::

        df/dM = (1 + e cos f)^2 / (1 - e^2)^(3/2)
        df/de = sin f (2 + e cos f) / (1 - e^2)

    and chains them through ``d(sin f) = cos f df`` and ``d(cos f) = -sin f df``.
    """
    M, e = primals
    M_dot, e_dot = tangents
    sinf, cosf = _kepler(M, e)

    # Pre-compute some things
    ecosf = e * cosf
    ome2 = 1 - e**2

    def as_float_tangent(tan, primal):
        # With the default ``defjvp`` (symbolic_zeros=False) JAX instantiates
        # zero tangents before calling this rule. However, integer-valued
        # primals (e.g. a Python ``ecc=0``) get ``float0`` tangents, which
        # cannot take part in arithmetic; swap those for real zeros.
        dtype = getattr(tan, "dtype", None)
        if dtype is not None and dtype == jax.dtypes.float0:
            return jnp.zeros(jnp.shape(primal), dtype=sinf.dtype)
        return tan

    M_dot = as_float_tangent(M_dot, M)
    e_dot = as_float_tangent(e_dot, e)

    # Propagate the derivatives
    f_dot = M_dot * (1 + ecosf) ** 2 / ome2**1.5
    f_dot += e_dot * (2 + ecosf) * sinf / ome2

    return (sinf, cosf), (cosf * f_dot, -sinf * f_dot)


def starter(M: Array, ecc: Array, ome: Array) -> Array:
    """Closed-form initial guess for the eccentric anomaly.

    The guess is the root of a cubic approximation to Kepler's equation,
    built from the auxiliary quantities ``alpha``, ``d``, ``r``, ``q`` and
    ``w``. The construction appears to follow Markley (1995) — TODO confirm.

    Args:
        M: Mean anomaly; ``_kepler`` passes values in ``[0, pi]``.
        ecc: Eccentricity.
        ome: ``1 - ecc``, precomputed by the caller.

    Returns:
        The starting estimate of the eccentric anomaly.
    """
    M2 = jnp.square(M)
    alpha = 3 * jnp.pi / (jnp.pi - 6 / jnp.pi)
    alpha += 1.6 / (jnp.pi - 6 / jnp.pi) * (jnp.pi - M) / (1 + ecc)
    d = 3 * ome + alpha * ecc
    alphad = alpha * d
    r = (3 * alphad * (d - ome) + M2) * M
    q = 2 * alphad * ome - M2
    q2 = jnp.square(q)
    w = jnp.square(jnp.cbrt(jnp.abs(r) + jnp.sqrt(q2 * q + r * r)))
    return (2 * r * w / (jnp.square(w) + w * q + q2) + M) / d


def refine(M: Array, ecc: Array, ome: Array, E: Array) -> Array:
    """Apply one high-order correction to an eccentric-anomaly estimate.

    Let ``g(E) = E - ecc*sin(E) - M``. The step uses ``g`` and its first three
    derivatives, nesting the corrections ``d_3`` (Halley-like) and ``d_4`` to
    form the final update ``dE``.

    Args:
        M: Mean anomaly.
        ecc: Eccentricity.
        ome: ``1 - ecc``, precomputed by the caller.
        E: Current estimate of the eccentric anomaly.

    Returns:
        The corrected eccentric anomaly ``E + dE``.
    """
    sE = E - jnp.sin(E)
    cE = 1 - jnp.cos(E)

    # g(E) and its derivatives: g' = 1 - e cos E, g'' = e sin E, g''' = e cos E
    f_0 = ecc * sE + E * ome - M
    f_1 = ecc * cE + ome
    f_2 = ecc * (E - sE)
    f_3 = 1 - f_1

    # Successively higher-order corrections, each reusing the previous one
    d_3 = -f_0 / (f_1 - 0.5 * f_0 * f_2 / f_1)
    d_4 = -f_0 / (f_1 + 0.5 * d_3 * f_2 + (d_3 * d_3) * f_3 / 6)
    d_42 = d_4 * d_4
    dE = -f_0 / (
        f_1 + 0.5 * d_4 * f_2 + d_4 * d_4 * f_3 / 6 - d_42 * d_4 * f_2 / 24
    )
    return E + dE