"""A module to define Keplerian systems of bodies."""
from collections.abc import Callable, Iterable, Sequence
from typing import Any
import equinox as eqx
import jax.numpy as jnp
from jaxoplanet import constants
from jaxoplanet.core.kepler import kepler
from jaxoplanet.object_stack import ObjectStack
from jaxoplanet.types import Scalar
class Central(eqx.Module):
    """A central body in an orbital system.

    Exactly two of ``mass``, ``radius`` and ``density`` determine the body; the
    third is derived assuming a uniform sphere. As a convenience, when neither
    ``mass`` nor ``radius`` is given, ``radius`` defaults to 1.0, and if
    ``density`` is also missing then ``mass`` defaults to 1.0 as well (so
    ``Central()`` has unit mass and unit radius).

    Args:
        mass (Optional[Scalar]): Mass of central body [mass unit].
        radius (Optional[Scalar]): Radius of central body [length unit].
        density (Optional[Scalar]): Density of central body [mass/length**3 unit].

    Raises:
        ValueError: If any provided value is not a scalar, or if the provided
            values do not specify exactly two of mass, radius and density.
    """

    # Mass of the central body [mass unit].
    mass: Scalar
    # Radius of the central body [length unit].
    radius: Scalar
    # Density of the central body [mass/length**3 unit].
    density: Scalar

    def __init__(
        self,
        *,
        mass: Scalar | None = None,
        radius: Scalar | None = None,
        density: Scalar | None = None,
    ):
        # Defaults: unit radius when neither mass nor radius is given, and
        # additionally unit mass when density is missing too.
        if radius is None and mass is None:
            radius = 1.0
            if density is None:
                mass = 1.0

        # Array-valued inputs are not supported here; use jax.vmap for batches.
        if any(
            jnp.ndim(arg) != 0 for arg in (mass, radius, density) if arg is not None
        ):
            raise ValueError("All parameters of a KeplerianCentral must be scalars")

        # Derive the missing parameter from the two that were provided, using
        # density = 3 * mass / (4 * pi * radius**3).
        error_msg = (
            "Values must be provided for exactly two of mass, radius, and density"
        )
        if density is None:
            if mass is None or radius is None:
                raise ValueError(error_msg)
            density = 3 * mass / (4 * jnp.pi * radius**3)
        elif radius is None:
            # density is known here, so only mass can be missing
            if mass is None:
                raise ValueError(error_msg)
            radius = (3 * mass / (4 * jnp.pi * density)) ** (1 / 3)
        elif mass is None:
            mass = 4 * jnp.pi * radius**3 * density / 3.0
        else:
            # All three were given: over-specified (and possibly inconsistent).
            raise ValueError(error_msg)

        self.mass = mass
        self.radius = radius
        self.density = density

    @classmethod
    def from_orbital_properties(
        cls,
        *,
        period: Scalar,
        semimajor: Scalar,
        radius: Scalar | None = None,
        body_mass: Scalar | None = None,
    ) -> "Central":
        """Initialize the central body (e.g. a star) of an orbital system using
        orbital parameters to derive its mass.

        The total mass follows from Kepler's third law,
        ``M_total = 4 * pi**2 * semimajor**3 / (G * period**2)``; if
        ``body_mass`` is given it is subtracted to obtain the central mass.

        Args:
            period: The orbital period of the orbiting body [time unit].
            semimajor: The semi-major axis of the orbiting body [length unit].
            radius (Optional[Scalar]): Radius of central body [length unit].
                Defaults to 1.0.
            body_mass (Optional[Scalar]): Mass of orbiting body [mass unit].

        Returns:
            Central object

        Raises:
            ValueError: If ``period``, ``semimajor`` or ``body_mass`` is not a
                scalar.
        """
        if any(
            jnp.ndim(arg) != 0
            for arg in (semimajor, period, body_mass)
            if arg is not None
        ):
            raise ValueError(
                "All parameters of 'KeplerianCentral.from_orbital_properties' must be "
                "scalars; for multi-planet systems, use 'jax.vmap'"
            )
        radius = 1.0 if radius is None else radius
        mass = 4 * jnp.pi**2 * semimajor**3 / (constants.G * period**2)
        if body_mass is not None:
            mass -= body_mass
        return cls(mass=mass, radius=radius)

    @property
    def shape(self) -> tuple[int, ...]:
        """The shape of ``mass``.

        ``jnp.shape`` is used so this also works when ``mass`` is a plain Python
        float (e.g. the default ``Central()``).
        """
        return jnp.shape(self.mass)
class Body(eqx.Module):
    """An orbiting body (e.g. a planet), described by its orbital parameters.

    This is the user-facing description: most parameters are optional and
    several are alternatives to one another. The consistency rules are
    validated in ``__check_init__``. The body is resolved into a fully
    specified orbit by :class:`OrbitalBody` once a central body is known.

    See https://docs.exoplanet.codes/en/latest/tutorials/data-and-models/ for a
    description of the orbital geometry.

    Args:
        time_transit (Optional[Scalar]): The epoch of a reference transit
            [time unit]. Mutually exclusive with ``time_peri``.
        time_peri (Optional[Scalar]): The epoch of a reference periastron passage
            [time unit]. Mutually exclusive with ``time_transit``.
        period (Optional[Scalar]): Orbital period [time unit]. Exactly one of
            ``period`` and ``semimajor`` must be given.
        semimajor (Optional[Scalar]): Semi-major axis in [length unit].
        inclination (Optional[Scalar]): Inclination of orbital plane in
            [angular unit]. Mutually exclusive with ``impact_param``.
        impact_param (Optional[Scalar]): Impact parameter.
        eccentricity (Optional[Scalar]): Eccentricity, expected to satisfy
            ``0 <= eccentricity < 1`` where 0 = circular orbit (the range is not
            validated here). Must be given together with the argument of
            periastron (``omega_peri`` or its sine/cosine).
        omega_peri (Optional[Scalar]): Argument of periastron [angular unit].
        sin_omega_peri (Optional[Scalar]): sin(argument of periastron); must be
            paired with ``cos_omega_peri`` and excludes ``omega_peri``.
        cos_omega_peri (Optional[Scalar]): cos(argument of periastron).
        asc_node (Optional[Scalar]): Longitude of ascending node [angular unit].
        sin_asc_node (Optional[Scalar]): sin(longitude of ascending node); must
            be paired with ``cos_asc_node`` and excludes ``asc_node``.
        cos_asc_node (Optional[Scalar]): cos(longitude of ascending node).
        mass (Optional[Scalar]): Mass of orbiting body [mass unit].
        radius (Optional[Scalar]): Radius of orbiting body [length unit].
        radial_velocity_semiamplitude (Optional[Scalar]): The radial velocity
            semi-amplitude [length/time unit].
        parallax (Optional[Scalar]): Parallax (to convert position/velocity into
            arcsec). [length unit].

    Raises:
        ValueError: If any provided parameter is not a scalar, or if the
            provided parameters are inconsistent with the rules above.
    """

    time_transit: Scalar | None = None
    time_peri: Scalar | None = None
    period: Scalar | None = None
    semimajor: Scalar | None = None
    inclination: Scalar | None = None
    impact_param: Scalar | None = None
    eccentricity: Scalar | None = None
    omega_peri: Scalar | None = None
    sin_omega_peri: Scalar | None = None
    cos_omega_peri: Scalar | None = None
    asc_node: Scalar | None = None
    sin_asc_node: Scalar | None = None
    cos_asc_node: Scalar | None = None
    mass: Scalar | None = None
    radius: Scalar | None = None
    radial_velocity_semiamplitude: Scalar | None = None
    parallax: Scalar | None = None

    def __check_init__(self) -> None:
        """Validate that the provided parameters are scalar and consistent."""
        # Exactly one of period/semimajor; the other is derived later.
        if not ((self.period is None) ^ (self.semimajor is None)):
            raise ValueError("Exactly one of period or semimajor must be specified")

        # Every provided argument must be a scalar; batching happens in System.
        provided_input_arguments = [
            arg
            for arg in (
                self.time_transit,
                self.time_peri,
                self.period,
                self.semimajor,
                self.inclination,
                self.impact_param,
                self.eccentricity,
                self.omega_peri,
                self.sin_omega_peri,
                self.cos_omega_peri,
                self.asc_node,
                self.sin_asc_node,
                self.cos_asc_node,
                self.mass,
                self.radius,
                self.radial_velocity_semiamplitude,
                self.parallax,
            )
            if arg is not None
        ]
        if any(jnp.ndim(arg) != 0 for arg in provided_input_arguments):
            raise ValueError(
                "All input arguments to 'Body' must be scalars; "
                "for multi-planet systems, use a 'System'"
            )

        # The argument of periastron is given either as an angle or as a
        # (sin, cos) pair, never both.
        if self.omega_peri is not None and (
            self.sin_omega_peri is not None or self.cos_omega_peri is not None
        ):
            raise ValueError(
                "Cannot specify both omega_peri and sin_omega_peri or cos_omega_peri"
            )
        if (self.sin_omega_peri is not None) ^ (self.cos_omega_peri is not None):
            raise ValueError("Both sin_omega_peri and cos_omega_peri must be specified")

        # Same rule for the longitude of the ascending node.
        if self.asc_node is not None and (
            self.sin_asc_node is not None or self.cos_asc_node is not None
        ):
            raise ValueError(
                "Cannot specify both asc_node and sin_asc_node or cos_asc_node"
            )
        if (self.sin_asc_node is not None) ^ (self.cos_asc_node is not None):
            raise ValueError("Both sin_asc_node and cos_asc_node must be specified")

        # An eccentric orbit needs an orientation, and vice versa.
        has_omega_peri = (
            self.omega_peri is not None
            or self.sin_omega_peri is not None
            or self.cos_omega_peri is not None
        )
        if (self.eccentricity is not None) ^ has_omega_peri:
            raise ValueError(
                "Both or neither of eccentricity and omega_peri must be specified"
            )

        if self.impact_param is not None and self.inclination is not None:
            raise ValueError(
                "Only one of impact_param and inclination can be specified"
            )
        if self.time_transit is not None and self.time_peri is not None:
            raise ValueError("Only one of time_transit or time_peri can be specified")
class OrbitalBody(eqx.Module):
    """A :class:`Body` resolved against a :class:`Central` for computation.

    The optional inputs of a ``Body`` are converted into a complete,
    self-consistent parametrization so that positions and velocities can be
    evaluated directly:

    - period and semi-major axis are both available, related by Kepler's third
      law using the total mass;
    - angles are stored as sine/cosine pairs;
    - reference times are stored as ``time_transit`` and ``time_ref``.

    Args:
        central: The central body that this body orbits.
        body: The user-facing description of the orbiting body.
    """

    # The central body this body orbits.
    central: Central
    # Offset such that ``time_peri == time_transit + time_ref``.
    time_ref: Scalar
    # Epoch of a reference transit [time unit].
    time_transit: Scalar
    # Orbital period [time unit].
    period: Scalar
    # Semi-major axis [length unit].
    semimajor: Scalar
    sin_inclination: Scalar
    cos_inclination: Scalar
    impact_param: Scalar
    mass: Scalar | None
    radius: Scalar | None
    # None denotes a circular orbit.
    eccentricity: Scalar | None
    sin_omega_peri: Scalar | None
    cos_omega_peri: Scalar | None
    # None means no rotation by the ascending node is applied.
    sin_asc_node: Scalar | None
    cos_asc_node: Scalar | None
    radial_velocity_semiamplitude: Scalar | None
    parallax: Scalar | None

    def __init__(self, central: Central, body: Body):
        self.central = central

        # Save the input mass and radius (and other pass-through values)
        self.radius = body.radius
        self.mass = body.mass
        self.radial_velocity_semiamplitude = body.radial_velocity_semiamplitude
        self.parallax = body.parallax

        # Make period and semimajor axis consistent via Kepler's third law:
        # semimajor**3 = G * M_total * period**2 / (4 * pi**2).
        # Body.__check_init__ guarantees exactly one of them was provided.
        mass_factor = constants.G * self.total_mass
        if body.semimajor is None:
            assert body.period is not None
            self.semimajor = jnp.cbrt(mass_factor * body.period**2 / (4 * jnp.pi**2))
            self.period = body.period
        elif body.period is None:
            assert body.semimajor is not None
            self.semimajor = body.semimajor
            self.period = (
                2 * jnp.pi * body.semimajor * jnp.sqrt(body.semimajor / mass_factor)
            )

        # Handle treatment and normalization of angles
        if body.omega_peri is not None:
            self.sin_omega_peri = jnp.sin(body.omega_peri)
            self.cos_omega_peri = jnp.cos(body.omega_peri)
        else:
            self.sin_omega_peri = body.sin_omega_peri
            self.cos_omega_peri = body.cos_omega_peri
        if body.asc_node is not None:
            self.sin_asc_node = jnp.sin(body.asc_node)
            self.cos_asc_node = jnp.cos(body.asc_node)
        else:
            self.sin_asc_node = body.sin_asc_node
            self.cos_asc_node = body.cos_asc_node

        # Mean anomaly at the reference transit (M0) and the factor relating
        # impact parameter to cos(inclination).
        self.eccentricity = body.eccentricity
        if self.eccentricity is None:
            M0 = jnp.full_like(self.period, 0.5 * jnp.pi)  # type: ignore
            incl_factor = 1
        else:
            assert self.sin_omega_peri is not None
            assert self.cos_omega_peri is not None
            # Eccentric anomaly at transit, where the true anomaly is
            # pi/2 - omega; tan(f/2) = cos(w) / (1 + sin(w)).
            opsw = 1 + self.sin_omega_peri
            E0 = 2 * jnp.arctan2(
                jnp.sqrt(1 - self.eccentricity) * self.cos_omega_peri,
                jnp.sqrt(1 + self.eccentricity) * opsw,
            )
            M0 = E0 - self.eccentricity * jnp.sin(E0)
            ome2 = 1 - self.eccentricity**2
            incl_factor = (1 + self.eccentricity * self.sin_omega_peri) / ome2

        # Handle inclined orbits: cos(i) = dcosidb * impact_param
        dcosidb = incl_factor * central.radius / self.semimajor
        if body.impact_param is not None:
            self.impact_param = body.impact_param
            self.cos_inclination = dcosidb * body.impact_param
            self.sin_inclination = jnp.sqrt(1 - self.cos_inclination**2)
        elif body.inclination is not None:
            self.cos_inclination = jnp.cos(body.inclination)
            self.sin_inclination = jnp.sin(body.inclination)
            self.impact_param = self.cos_inclination / dcosidb
        else:
            # Neither given: edge-on orbit (impact parameter 0).
            z = jnp.zeros_like(self.period)
            self.impact_param = z
            self.cos_inclination = z
            self.sin_inclination = jnp.ones_like(self.period)

        # Work out all the relevant reference times
        self.time_ref = -M0 * self.period / (2 * jnp.pi)
        if body.time_transit is not None:
            self.time_transit = body.time_transit
        elif body.time_peri is not None:
            self.time_transit = body.time_peri - self.time_ref
        else:
            self.time_transit = jnp.zeros_like(self.time_ref)

    @property
    def central_radius(self) -> Scalar:
        """Radius of the central body [length unit]."""
        return self.central.radius

    @property
    def time_peri(self) -> Scalar:
        """Epoch of a reference periastron passage [time unit]."""
        return self.time_transit + self.time_ref  # type: ignore

    @property
    def inclination(self) -> Scalar:
        """Inclination of the orbital plane [radians]."""
        return jnp.arctan2(self.sin_inclination, self.cos_inclination)

    @property
    def omega_peri(self) -> Scalar | None:
        """Argument of periastron [radians], or None for a circular orbit."""
        if self.eccentricity is None:
            return None
        assert self.sin_omega_peri is not None
        assert self.cos_omega_peri is not None
        return jnp.arctan2(self.sin_omega_peri, self.cos_omega_peri)

    @property
    def total_mass(self) -> Scalar:
        """Mass of the central plus this body (an unset body mass counts as 0)."""
        return self.central.mass if self.mass is None else self.mass + self.central.mass

    def position(
        self, t: Scalar, parallax: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """This body's position in the barycentric frame

        Args:
            t: The times where the position should be evaluated.
            parallax: Optional parallax overriding ``self.parallax``.

        Returns:
            The components of the position vector at ``t`` in [length unit], or,
            if a parallax is provided or ``self.parallax`` is not None, scaled by
            ``parallax / constants.au`` (i.e. in arcseconds).
        """
        semimajor = -self.semimajor * self.central.mass / self.total_mass
        return self._get_position_and_velocity(
            t, semimajor=semimajor, parallax=parallax
        )[0]

    def central_position(
        self, t: Scalar, parallax: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """The central's position in the barycentric frame

        Args:
            t: The times where the position should be evaluated.
            parallax: Optional parallax overriding ``self.parallax``.

        Returns:
            The components of the position vector at ``t`` in [length unit], or,
            if a parallax is provided or ``self.parallax`` is not None, scaled by
            ``parallax / constants.au`` (i.e. in arcseconds).

        Raises:
            TypeError: If this body's ``mass`` is None.
        """
        semimajor = self.semimajor * self.mass / self.total_mass
        return self._get_position_and_velocity(
            t, semimajor=semimajor, parallax=parallax
        )[0]

    def relative_position(
        self, t: Scalar, parallax: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """This body's position relative to the central in the X,Y,Z frame

        Args:
            t: The times where the position should be evaluated.
            parallax: Optional parallax overriding ``self.parallax``.

        Returns:
            The components of the position vector at ``t`` in [length unit], or,
            if a parallax is provided or ``self.parallax`` is not None, scaled by
            ``parallax / constants.au`` (i.e. in arcseconds).
        """
        return self._get_position_and_velocity(
            t,
            semimajor=-self.semimajor,
            parallax=parallax,
        )[0]

    def relative_angles(
        self, t: Scalar, parallax: Scalar | None = None
    ) -> tuple[Scalar, Scalar]:
        """This body's relative position to the central in the sky plane, in
        separation, position angle coordinates

        Args:
            t: The times where the angles should be evaluated.
            parallax: Optional parallax overriding ``self.parallax``.

        Returns:
            The separation ``sqrt(X**2 + Y**2)`` (in arcseconds when a parallax
            is available, otherwise in [length unit]) and the position angle
            ``arctan2(Y, X)`` in radians, of the body relative to the central.
        """
        X, Y, _ = self.relative_position(t, parallax=parallax)
        rho = jnp.sqrt(X**2 + Y**2)
        theta = jnp.arctan2(Y, X)
        return rho, theta

    def velocity(
        self, t: Scalar, semiamplitude: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """This body's velocity in the barycentric frame

        Args:
            t: The times where the velocity should be evaluated.
            semiamplitude: The semi-amplitude of the orbit. By convention, this
                is half of the peak-to-peak amplitude of the central's velocity.
                If provided, the ``mass`` and ``inclination`` parameters will be
                ignored and this amplitude will be used instead (it is rescaled
                by ``central.mass / mass``, so ``mass`` must be set).

        Returns:
            The components of the velocity vector at ``t`` in
            [length/time unit], but if ``self.parallax`` is not None, scaled by
            ``parallax / constants.au`` (arcseconds per time unit).
        """
        if semiamplitude is None:
            mass: Scalar = -self.central.mass  # type: ignore
            return self._get_position_and_velocity(t, mass=mass)[1]
        k = -semiamplitude * self.central.mass / self.mass
        return self._get_position_and_velocity(t, semiamplitude=k)[1]

    def central_velocity(
        self, t: Scalar, semiamplitude: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """The central's velocity in the barycentric frame

        Args:
            t: The times where the velocity should be evaluated.
            semiamplitude: The semi-amplitude of the orbit. By convention, this
                is half of the peak-to-peak amplitude of the central's velocity.
                If provided, the ``mass`` and ``inclination`` parameters will be
                ignored and this amplitude will be used instead.

        Returns:
            The components of the velocity vector at ``t`` in
            [length/time unit], but if ``self.parallax`` is not None, scaled by
            ``parallax / constants.au`` (arcseconds per time unit).
        """
        if semiamplitude is None:
            # NOTE(review): when self.mass is None, _get_position_and_velocity
            # falls back to a unit mass (m = 1.0) here -- confirm intended.
            return self._get_position_and_velocity(t, mass=self.mass)[1]
        _, v = self._get_position_and_velocity(t, semiamplitude=semiamplitude)
        return v

    def relative_velocity(
        self, t: Scalar, semiamplitude: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """This body's velocity relative to the central

        Args:
            t: The times where the velocity should be evaluated.
            semiamplitude: The semi-amplitude of the orbit. By convention, this
                is half of the peak-to-peak amplitude of the central's velocity.
                If provided, the ``mass`` and ``inclination`` parameters will be
                ignored and this amplitude will be used instead (it is rescaled
                by ``total_mass / mass``, so ``mass`` must be set).

        Returns:
            The components of the velocity vector at ``t`` in
            [length/time unit], but if ``self.parallax`` is not None, scaled by
            ``parallax / constants.au`` (arcseconds per time unit).
        """
        if semiamplitude is None:
            mass: Scalar = -self.total_mass  # type: ignore
            return self._get_position_and_velocity(t, mass=mass)[1]
        k = -semiamplitude * self.total_mass / self.mass
        _, v = self._get_position_and_velocity(t, semiamplitude=k)
        return v

    def radial_velocity(self, t: Scalar, semiamplitude: Scalar | None = None) -> Scalar:
        """Get the radial velocity of the central

        .. note:: The convention is that positive `z` points *towards* the
            observer. However, for consistency with radial velocity literature
            this method returns values where positive radial velocity
            corresponds to a redshift as expected.

        Args:
            t: The times where the radial velocity should be evaluated.
            semiamplitude: The semi-amplitude of the orbit. If provided, the
                ``mass`` and ``inclination`` parameters will be ignored and
                this amplitude will be used instead.

        Returns:
            The reflex radial velocity evaluated at ``t`` in [length/time unit].
        """
        rv = -self.central_velocity(t, semiamplitude=semiamplitude)[2]
        # Undo the parallax scaling that _get_position_and_velocity applies to
        # mass-derived velocities, so RVs are always in physical units.
        if (
            self.parallax is not None
            and semiamplitude is None
            and self.radial_velocity_semiamplitude is None
        ):
            rv = rv / self.parallax * constants.au
        return rv

    def _warp_times(self, t: Scalar) -> Scalar:
        """Shift ``t`` so that it is measured from the reference transit."""
        return t - self.time_transit  # type: ignore

    def _get_true_anomaly(self, t: Scalar) -> tuple[Scalar, Scalar]:
        """Return ``(sin f, cos f)`` of the true anomaly at times ``t``.

        For a circular orbit the true anomaly equals the mean anomaly.
        """
        M = 2 * jnp.pi * (self._warp_times(t) - self.time_ref) / self.period
        if self.eccentricity is None:
            return jnp.sin(M), jnp.cos(M)
        return kepler(M, self.eccentricity)

    def _rotate_vector(
        self, x: Scalar, y: Scalar, *, include_inclination: bool = True
    ) -> tuple[Scalar, Scalar, Scalar]:
        """Apply the rotation matrices to go from orbit to observer frame

        In order,

        1. rotate about the z axis by an amount omega -> x1, y1, z1
        2. rotate about the x1 axis by an amount -incl -> x2, y2, z2
        3. rotate about the z2 axis by an amount Omega -> x3, y3, z3

        Steps 1 and 3 are skipped for circular orbits and when no ascending
        node is set, respectively. With ``include_inclination=False`` step 2
        uses ``cos(incl) = 1`` for ``y`` and ``sin(incl) = 1`` for ``Z``.

        Args:
            x: A tensor representing the x coordinate in the plane of the
                orbit.
            y: A tensor representing the y coordinate in the plane of the
                orbit.
            include_inclination: Whether to use this orbit's inclination.

        Returns:
            Three tensors representing ``(X, Y, Z)`` in the observer frame.
        """
        # 1) rotate about z0 axis by omega
        if self.eccentricity is None:
            x1 = x
            y1 = y
        else:
            x1 = self.cos_omega_peri * x - self.sin_omega_peri * y
            y1 = self.sin_omega_peri * x + self.cos_omega_peri * y

        # 2) rotate about x1 axis by -incl
        # z3 = z2, subsequent rotation by Omega doesn't affect it
        if include_inclination:
            x2 = x1
            y2 = self.cos_inclination * y1
            Z = -self.sin_inclination * y1
        else:
            x2 = x1
            y2 = y1
            Z = -y1

        # 3) rotate about z2 axis by Omega
        if self.cos_asc_node is None:
            return x2, y2, Z  # type: ignore
        X = self.cos_asc_node * x2 - self.sin_asc_node * y2
        Y = self.sin_asc_node * x2 + self.cos_asc_node * y2
        return X, Y, Z  # type: ignore

    def _get_position_and_velocity(
        self,
        t: Scalar,
        semimajor: Scalar | None = None,
        mass: Scalar | None = None,
        semiamplitude: Scalar | None = None,
        parallax: Scalar | None = None,
    ) -> tuple[tuple[Scalar, Scalar, Scalar], tuple[Scalar, Scalar, Scalar]]:
        """Compute position and velocity vectors in the observer frame.

        Args:
            t: The evaluation times.
            semimajor: Scale of the position vector; 1 if None.
            mass: Mass weighting of the velocity scale when it is derived from
                the orbit; 1.0 if None.
            semiamplitude: Explicit velocity scale; defaults to
                ``self.radial_velocity_semiamplitude``. When set, velocities are
                rotated without the orbital inclination.
            parallax: Parallax used to convert to angular units; defaults to
                ``self.parallax``.

        Returns:
            ``((x, y, z), (vx, vy, vz))``.
        """
        if semiamplitude is None:
            semiamplitude = self.radial_velocity_semiamplitude
        if parallax is None:
            parallax = self.parallax

        if semiamplitude is None:
            # No semi-amplitude anywhere: derive the velocity scale from the
            # orbit (at this point self.radial_velocity_semiamplitude is None).
            m = 1.0 if mass is None else mass
            k0 = 2 * jnp.pi * self.semimajor * m / (self.total_mass * self.period)
            if self.eccentricity is not None:
                k0 /= jnp.sqrt(1 - self.eccentricity**2)
            if parallax is not None:
                k0 = k0 * parallax / constants.au
        else:
            k0 = semiamplitude

        r0 = 1
        if semimajor is not None:
            if parallax is None:
                r0 = semimajor
            else:
                r0 = semimajor * parallax / constants.au

        sinf, cosf = self._get_true_anomaly(t)
        if self.eccentricity is None:
            v1, v2 = -k0 * sinf, k0 * cosf
        else:
            v1, v2 = -k0 * sinf, k0 * (cosf + self.eccentricity)
            # Orbital radius r = a (1 - e^2) / (1 + e cos f)
            r0 *= (1 - self.eccentricity**2) / (1 + self.eccentricity * cosf)

        x, y, z = self._rotate_vector(r0 * cosf, r0 * sinf)
        vx, vy, vz = self._rotate_vector(
            v1, v2, include_inclination=semiamplitude is None
        )
        return (x, y, z), (vx, vy, vz)
class System(eqx.Module):
    """A Keplerian orbital system: a central body and the bodies orbiting it.

    Args:
        central (Optional[Central]): The central body. If None, ``Central()`` is
            used.
        bodies (Iterable[Body | OrbitalBody]): The orbiting bodies. Plain
            :class:`Body` objects are resolved against ``central``; instances
            of :class:`OrbitalBody` are used unchanged (and so keep the central
            body they were built with).
    """

    # The system's central body.
    central: Central
    # The resolved bodies, stored so that they can be vmapped over.
    _body_stack: ObjectStack[OrbitalBody]

    def __init__(
        self,
        central: Central | None = None,
        *,
        bodies: Iterable[Body | OrbitalBody] = (),
    ):
        self.central = Central() if central is None else central
        self._body_stack = ObjectStack(
            *(
                b if isinstance(b, OrbitalBody) else OrbitalBody(self.central, b)
                for b in bodies
            )
        )

    def __repr__(self) -> str:
        # Don't expand the (potentially large) stacked body pytree.
        return eqx.tree_pformat(
            self, truncate_leaf=lambda obj: isinstance(obj, ObjectStack)
        )

    @property
    def shape(self) -> tuple[int, ...]:
        """``(number_of_bodies,)``."""
        return (len(self._body_stack),)

    @property
    def bodies(self) -> tuple[OrbitalBody, ...]:
        """The resolved bodies of this system."""
        return self._body_stack.objects

    @property
    def radius(self) -> Scalar:
        """The radii of all bodies, stacked along the leading axis."""
        return self.body_vmap(lambda body: body.radius)()

    @property
    def central_radius(self) -> Scalar:
        """The central radius seen by each body, stacked along the leading axis."""
        return self.body_vmap(lambda body: body.central_radius)()

    def add_body(
        self,
        body: Body | None = None,
        central: Central | None = None,
        **kwargs: Any,
    ) -> "System":
        """Add a body to the system and return a new system

        Args:
            body (Body | None, optional): body to add. If None, a new
                :class:`Body` is constructed from ``**kwargs``. Defaults to None.
            central (Central | None, optional): If given, the body is resolved
                against this central body instead of the system's own
                ``central``. Defaults to None.
            **kwargs: Keyword arguments forwarded to :class:`Body` when
                ``body`` is None.

        Returns:
            System: :py:class:`~jaxoplanet.orbits.keplerian.System` with the added body

        Raises:
            ValueError: If both ``body`` and ``**kwargs`` are given (the
                keyword arguments would otherwise be silently ignored).
        """
        if body is not None and kwargs:
            raise ValueError(
                "Pass either a 'body' or keyword arguments to construct one, not both"
            )
        body_: Body | OrbitalBody = Body(**kwargs) if body is None else body
        if central is not None:
            body_ = OrbitalBody(central, body_)
        return System(central=self.central, bodies=self.bodies + (body_,))

    def body_vmap(
        self,
        func: Callable,
        in_axes: int | None | Sequence[Any] = 0,
        out_axes: Any = 0,
    ) -> Callable:
        """Map a function over the bodies of this system

        If possible, this method will apply the appropriate ``jax.vmap`` to the input
        function, but if the Pytree structure of the bodies don't match, this requires
        a loop over bodies, applying the function separately to each body, and stacking
        the results.

        Args:
            func: The function to map. Its first positional argument must accept a
                Keplerian :class:`OrbitalBody` object.
            in_axes: The input axis specifications for all arguments after the first.
                The semantics should match ``jax.vmap``.
            out_axes: The output axis specifications, matching ``jax.vmap``.

        Returns:
            The vectorized version of ``func`` mapped over bodies in this system.

        For example, if (for some reason) we wanted to compute the $x$ positions of all
        the bodies in a system at a particular time, in units of the body radius, we
        could use the following:

        >>> from jaxoplanet.orbits.keplerian import Central, System
        >>> sys = System(Central())
        >>> sys = sys.add_body(period=1.0, radius=0.1)
        >>> sys = sys.add_body(period=2.0, radius=0.2)
        >>> pos = sys.body_vmap(
        ...     lambda body, t: body.position(t)[0] / body.radius,
        ...     in_axes=None,
        ... )
        >>> pos(0.2).shape
        (2,)
        """
        return self._body_stack.vmap(func, in_axes=in_axes, out_axes=out_axes)

    def position(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]:
        """Barycentric positions of all bodies (see :meth:`OrbitalBody.position`)."""
        return self.body_vmap(OrbitalBody.position, in_axes=None)(t)

    def central_position(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]:
        """Central's barycentric position induced by each body."""
        return self.body_vmap(OrbitalBody.central_position, in_axes=None)(t)

    def relative_position(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]:
        """Positions of all bodies relative to the central."""
        return self.body_vmap(OrbitalBody.relative_position, in_axes=None)(t)

    def velocity(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]:
        """Barycentric velocities of all bodies."""
        return self.body_vmap(OrbitalBody.velocity, in_axes=None)(t)

    def central_velocity(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]:
        """Central's barycentric velocity induced by each body."""
        return self.body_vmap(OrbitalBody.central_velocity, in_axes=None)(t)

    def relative_velocity(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]:
        """Velocities of all bodies relative to the central."""
        return self.body_vmap(OrbitalBody.relative_velocity, in_axes=None)(t)

    def radial_velocity(self, t: Scalar) -> Scalar:
        """Central's reflex radial velocity induced by each body."""
        return self.body_vmap(OrbitalBody.radial_velocity, in_axes=None)(t)