Source code for jaxoplanet.starry.system_observable
from collections.abc import Callable
from functools import partial
import jax
import jax.numpy as jnp
from jaxoplanet.starry.orbit import SurfaceSystem
from jaxoplanet.types import Array, Scalar
[docs]
def system_observable(surface_observable, **kwargs):
    """Lift a single-surface observable to a function of a whole system.

    Args:
        surface_observable: Callable evaluating an observable of one surface.
            It is called here in two forms:
            ``surface_observable(surface, r, x, y, z, theta, **kwargs)``, where
            ``r`` is the occultor-to-surface radius ratio and ``(x, y, z)`` is
            the occultor position in units of the surface's radius; and
            ``surface_observable(surface, theta=theta, **kwargs)``, presumably
            the unocculted case (TODO confirm against the observable's
            defaults).
        **kwargs: Extra keyword arguments bound (via ``functools.partial``) to
            every call of ``surface_observable``.

    Returns:
        A function that takes a ``SurfaceSystem`` and returns a function of
        time. For a scalar time that function returns a 1-D array whose first
        entry is the central body's observable and whose remaining entries
        come from the orbiting bodies (0.0 for a body without a surface).
        ``jnp.vectorize`` broadcasts it over arrays of times.
    """
    _surface_observable = partial(surface_observable, **kwargs)

    def observable_fun(
        system: SurfaceSystem,
    ) -> Callable[[Scalar], Array]:
        # Observable of the central surface with each body in turn in front
        # of it: vmapped over the bodies' radius ratios and positions, while
        # the central surface and its rotation angle are shared.
        central_bodies_observable = jax.vmap(
            _surface_observable, in_axes=(None, 0, 0, 0, 0, None)
        )

        # Observable of every body's surface, occulted by the central.
        # NOTE(review): system.surface_vmap presumably maps over the bodies and
        # supplies each body's surface as the first argument (the call site
        # below passes only radius, x, y, z, time) -- TODO confirm.
        @partial(system.surface_vmap, in_axes=(0, 0, 0, 0, None))
        def compute_body_observable(surface, radius, x, y, z, time):
            if surface is None:
                return 0.0
            theta = surface.rotational_phase(time)
            # The central is the occultor, so sizes and positions are
            # expressed in units of this body's radius.
            return _surface_observable(
                surface,
                (system.central.radius / radius),
                (x / radius),
                (y / radius),
                (z / radius),
                theta,
            )

        @partial(jnp.vectorize, signature="()->(n)")
        def observable_impl(time: Scalar) -> Array:
            # Array of observables for all bodies at one time, central first.
            has_bodies = len(system.bodies) > 0
            if has_bodies:
                # Positions of the bodies relative to the central.
                xos, yos, zos = system.relative_position(time)

            if system.central_surface is None:
                # A central without a surface contributes nothing, but the
                # bodies' own surfaces must still be evaluated below (this
                # case used to fall through and return None).
                central_light_curves = jnp.array([0.0])
            else:
                theta = system.central_surface.rotational_phase(time)
                central_radius = system.central.radius
                central_phase_curve = _surface_observable(
                    system.central_surface, theta=theta
                )
                if not has_bodies:
                    return jnp.array([central_phase_curve])

                n = len(xos)
                central_light_curves = central_bodies_observable(
                    system.central_surface,
                    (system.radius / central_radius),
                    (xos / central_radius),
                    (yos / central_radius),
                    (zos / central_radius),
                    theta,
                )
                if n > 1 and central_light_curves is not None:
                    # Assuming each of the n evaluations is the unocculted
                    # phase curve minus one body's occultation, subtracting the
                    # n - 1 extra phase-curve copies leaves one phase curve
                    # minus all occultations. NOTE(review): occultations are
                    # combined independently, so overlapping occultors would
                    # be double counted.
                    central_light_curves = central_light_curves.sum(
                        0
                    ) - central_phase_curve * (n - 1)
                    central_light_curves = jnp.expand_dims(central_light_curves, 0)

            if not has_bodies:
                return central_light_curves

            # Negating the relative positions gives the central's position
            # relative to each body, i.e. the occultor position for that body.
            body_light_curves = compute_body_observable(
                system.radius, -xos, -yos, -zos, time
            )
            return jnp.hstack([central_light_curves, body_light_curves])

        return observable_impl

    return observable_fun