Source code for jaxoplanet.light_curves.limb_dark

__all__ = ["light_curve"]

from collections.abc import Callable
from functools import partial

import jax.numpy as jnp
import jpu.numpy as jnpu

from jaxoplanet import units
from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve
from jaxoplanet.light_curves.utils import vectorize
from jaxoplanet.proto import LightCurveOrbit
from jaxoplanet.types import Array, Quantity
from jaxoplanet.units import unit_registry as ureg


[docs] def light_curve( orbit: LightCurveOrbit, *u: Array, order: int = 10 ) -> Callable[[Quantity], Array]: """Compute the light curve for arbitrary polynomial limb darkening See `Agol et al. (2020) <https://arxiv.org/abs/1908.03222>`_ and :func:`jaxoplanet.core.limb_dark.light_curve` for more technical details. Args: orbit (LightCurveOrbit): An orbit object that can be used to evaluate the relative positions of the transiting body with respect to the light source. u (Array): The coefficients of the polynomial limb darkening order (int): The order of the numerical integration used by the backend; see :func:`jaxoplanet.core.limb_dark.light_curve` Returns: A function which takes the time in days as input and returns the light curve flux """ if u: ld_u = jnp.concatenate([jnp.atleast_1d(jnp.asarray(u_)) for u_ in u], axis=0) else: ld_u = jnp.array([]) @units.quantity_input(time=ureg.d) @vectorize def light_curve_impl(time: Quantity) -> Array: if jnpu.ndim(time) != 0: raise ValueError( "The time passed to 'light_curve' has shape " f"{jnpu.shape(time)}, but a scalar was expected; " "this shouldn't typically happen so please open an issue " "on GitHub demonstrating the problem" ) # Evaluate the coordinates of the transiting body r_star = orbit.central_radius x, y, z = orbit.relative_position(time) b = jnpu.sqrt(x**2 + y**2) / r_star assert b.units == ureg.dimensionless r = orbit.radius / r_star assert r.units == ureg.dimensionless lc_func = partial(_limb_dark_light_curve, ld_u, order=order) lc = lc_func(b.magnitude, r.magnitude) lc = jnp.where(z > 0, lc, 0) return lc return light_curve_impl