Source code for jaxoplanet.light_curves.transforms

"""A module providing decorators to transform light curve functions"""

__all__ = ["integrate", "interpolate"]

from functools import wraps
from typing import Any

import jax
import jax.numpy as jnp

from jaxoplanet.light_curves.types import LightCurveFunc
from jaxoplanet.light_curves.utils import vectorize
from jaxoplanet.types import Array, Scalar

try:
    from jax.extend import linear_util as lu
except ImportError:
    from jax import linear_util as lu  # type: ignore


[docs] def integrate( func: LightCurveFunc, exposure_time: Scalar | None = None, order: int = 0, num_samples: int = 7, ) -> LightCurveFunc: """Transform a light curve function to apply exposure time integration This transformation applies a fixed stencil numerical integration scheme to the input function ``func`` to convolve the light curve with a top hat exposure time centered on the input time, with a full width of ``exposure_time``. The order of the integration scheme is set using the ``order`` parameter which must be ``0``, ``1``, or ``2``. The default (``0``) uses the "resampling" scheme discussed by `Kipping (2010) <https://arxiv.org/abs/1004.3741>`_. The higher order schemes ``1`` and ``2`` apply the trapezoid and Simpson's rules respectively, but won't necessarily provide higher accuracy results because of discontinuities at the contact points. In practice, the parameter ``num_samples`` which sets the number of function evaluations per integral has the most significant effect on the accuracy of this integral, trading off against higher computational cost. Args: func: A light curve function which takes a time ``Scalar`` as the first argument exposure_time (Scalar): The exposure time (in days, by default) order (int): The order of the integration scheme as discussed above num_samples (int): The number of function evaluations made per integral, controlling the accuracy of the numerics Returns: A new light curve function with the same signature as ``func``, computing the exposure time integrated flux """ if exposure_time is None: return func if jnp.ndim(exposure_time) != 0: raise ValueError( "The exposure time passed to 'integrate_exposure_time' has shape " f"{jnp.shape(exposure_time)}, but a scalar was expected; " "To use exposure time integration with different exposures at different " "times, manually 'vmap' or 'vectorize' the function" ) # Ensure 'num_samples' is an odd number num_samples = int(num_samples) num_samples += 1 - num_samples % 2 stencil = jnp.ones(num_samples) # Construct exposure time integration stencil if order == 0: dt = jnp.linspace(-0.5, 0.5, 2 * num_samples + 1)[1:-1:2] elif order == 1: dt = jnp.linspace(-0.5, 0.5, num_samples) stencil = 2 * stencil stencil = stencil.at[0].set(1) stencil = stencil.at[-1].set(1) elif order == 2: dt = jnp.linspace(-0.5, 0.5, num_samples) stencil = stencil.at[1:-1:2].set(4) stencil = stencil.at[2:-1:2].set(2) else: raise ValueError( "The parameter 'order' in 'integrate_exposure_time' must be 0, 1, or 2" ) dt = dt * exposure_time stencil /= jnp.sum(stencil) @wraps(func) @vectorize def wrapped(time: Scalar, *args: Any, **kwargs: Any) -> Array | Scalar: if jnp.ndim(time) != 0: raise ValueError( "The time passed to 'integrate_exposure_time' has shape " f"{jnp.shape(time)}, but a scalar was expected; " "this shouldn't typically happen so please open an issue " "on GitHub demonstrating the problem" ) vmapped_func = jax.vmap(func, in_axes=(0,) + (None,) * len(args)) debug_info = jax.api_util.debug_info( "integrate", vmapped_func, (time,) + tuple(args), kwargs ) f = lu.wrap_init(vmapped_func, debug_info=debug_info) f = apply_exposure_time_integration(f, stencil, dt) # type: ignore return f.call_wrapped(time, args, kwargs) # type: ignore return wrapped
@lu.transformation # type: ignore def apply_exposure_time_integration(stencil, dt, time, args, kwargs): result = yield (time + dt,) + args, kwargs yield jnp.dot(stencil, result)
[docs] def interpolate( func: LightCurveFunc, *, period: Scalar, time_transit: Scalar, num_samples: int, duration: Scalar | None = None, args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None, ) -> LightCurveFunc: """Transform a light curve function to pre-compute the model on a grid Sometimes it can be useful to precompute the light curve on a grid near a transit, and then interpolate those computations to the required phases when computing the full model. This can speed things up a lot when you have many transits, or a lot of out of transit data. This transform uses linear interpolation. .. note:: Unlike some other transforms, this function requires that any upstream ``*args`` and ``**kwargs`` be passed directly to the transform, rather than when calling the transformed function. This is necessary because the model is pre-computed when it is tranformed. Args: func: A light curve function which takes a time ``Scalar`` as the first argument period (Scalar): The period of the orbit. Used to wrap the input times into the domain of the pre-computed model time_transit (Scalar): The transit time of the orbit. Used to wrap the input times into the domain of the pre-computed model duration (Scalar): The duration centered on the transit to pre-compute. By default, the full period will be evaluated num_samples (int): The number of points in the time grid used for pre-computation args (tuple): Any extra positional arguments that should be passed to ``func`` kwargs (dict): Any extra keyword arguments that should be passed to ``func`` Returns: A new light curve function with the same signature as ``func``, computing the flux by interpolating a pre-computed model """ kwargs = kwargs or {} if duration is None: duration = period time_grid = time_transit + duration * jnp.linspace(-0.5, 0.5, num_samples) flux_grid = func(time_grid, *args, **kwargs) flux_magnitude = flux_grid @wraps(func) @vectorize def wrapped(time: Scalar, *args: Any, **kwargs: Any) -> Array | Scalar: del args, kwargs time_wrapped = ( jnp.mod(time - time_transit + 0.5 * period, period) + 0.5 * period + time_transit ) flux = jnp.interp( time_wrapped, time_grid, flux_magnitude, left=flux_magnitude[0], right=flux_magnitude[-1], period=period, ) return flux return wrapped