Source code for jaxoplanet.orbits.ttv

import jax
import jax.numpy as jnp

from jaxoplanet.orbits import TransitOrbit
from jaxoplanet.types import Scalar


@jax.jit
def _compute_linear_ephemeris_single(transit_times: jnp.ndarray, indices: jnp.ndarray):
    """Fit ``transit_times ~ intercept + slope * indices`` by ordinary least squares.

    The fit is performed on mean-centred data. This is algebraically identical
    to solving the 2x2 normal equations built from the uncentred design matrix,
    but it avoids the catastrophic cancellation those equations suffer. The
    cancellation happens when transit times are large (e.g. BJD-scale values
    around 2.45e6 days) and JAX runs in its default float32 precision.

    Args:
        transit_times: observed transit times for one planet.
        indices: transit numbers labelling each entry of ``transit_times``.
            At least two distinct values are needed for a finite slope.

    Returns:
        ``(intercept, slope, ttvs)``, where:

        - ``intercept`` is the fitted transit time at index 0.
        - ``slope`` is the fitted period.
        - ``ttvs`` are the residuals
          ``transit_times - (intercept + slope * indices)``.
    """
    # Work in a floating dtype even if integer inputs were supplied; the weak
    # Python float keeps float32 inputs in float32 (and float64 under x64).
    dtype = jnp.result_type(transit_times, indices, 1.0)
    t = jnp.asarray(transit_times, dtype=dtype)
    x = jnp.asarray(indices, dtype=dtype)

    x_mean = jnp.mean(x)
    t_mean = jnp.mean(t)
    dx = x - x_mean
    # Centred normal equation: slope = cov(x, t) / var(x).
    slope = jnp.sum(dx * (t - t_mean)) / jnp.sum(dx * dx)
    intercept = t_mean - slope * x_mean
    ttvs = t - (intercept + slope * x)
    return intercept, slope, ttvs


@jax.jit
def _compute_bin_edges_values_single(tt: jnp.ndarray, period: float):
    """Build the nearest-transit lookup table for one planet.

    The edges are placed half a period before the first transit, at the
    midpoints between consecutive transits, and half a period after the last
    transit. The values are the transit times themselves, with the first and
    last entries repeated at either end. Those extra values are returned for
    query times that fall before the first edge or after the last one.
    """
    half_period = 0.5 * period
    lower_edge = jnp.expand_dims(tt[0] - half_period, 0)
    upper_edge = jnp.expand_dims(tt[-1] + half_period, 0)
    interior_edges = 0.5 * (tt[:-1] + tt[1:])
    edges = jnp.concatenate([lower_edge, interior_edges, upper_edge])

    # Edge-padding duplicates tt[0] at the front and tt[-1] at the back.
    values = jnp.pad(tt, 1, mode="edge")
    return edges, values


@jax.jit
def _process_planet_dt_single(
    bin_edges: jnp.ndarray, bin_values: jnp.ndarray, t_mag: jnp.ndarray
):
    """Map each query time to the value of the bin it falls in, for one planet.

    ``jnp.searchsorted`` (left side) finds the bin position of each entry of
    ``t_mag`` within ``bin_edges``. The matching entries of ``bin_values`` are
    then gathered along axis 0 and returned with the shape of ``t_mag``.
    Out-of-range positions are clamped, matching plain indexing.
    """
    bin_position = jnp.searchsorted(bin_edges, t_mag, side="left")
    return jnp.take(bin_values, bin_position, axis=0, mode="clip")


class TTVOrbit(TransitOrbit):
    """An extension of TransitOrbit that allows for transit timing variations (TTVs).

    The TTVs can be specified in one of two ways:

    - Provide a tuple (or list) of TTV offset arrays via the argument ``ttvs``.
    - Provide a tuple (or list) of observed transit time arrays via
      ``transit_times``, optionally with ``transit_inds`` to label each
      transit. In this case a least-squares linear fit is performed to infer a
      reference transit time (t0) and period; the residuals define the TTVs.

    Exactly one of these two options must be provided.

    Args added on to TransitOrbit:
        ttvs: tuple (or list) of arrays, each giving the "observed minus
            computed" transit time offsets (in days) for one planet. Each
            array is re-centred to zero mean before use. Requires ``period``;
            ``time_transit`` defaults to 0.0. A single (scalar) ``period`` or
            ``time_transit`` is shared by all planets.
        transit_times: tuple (or list) of arrays giving the observed transit
            times (in days), one array per planet.
        transit_inds: tuple (or list) of integer arrays (one per planet)
            labelling the transit numbers. Defaults to ``arange`` over each
            planet's transits.
        delta_log_period: optional; only used with ``transit_times`` when
            ``period`` is not given. The period passed to TransitOrbit, which
            governs the transit shape, is then
            ``exp(log(ttv_period) + delta_log_period)``. The fitted
            ``ttv_period`` is still stored and used for the TTV binning.

    Raises:
        ValueError: if both or neither of ``ttvs`` / ``transit_times`` are
            given; if ``ttvs`` are given without ``period``; or if the
            per-planet inputs have mismatched lengths.
    """

    transit_times: tuple[jnp.ndarray, ...]
    # e.g. (jnp.arange(n_transits0), jnp.arange(n_transits1), ...)
    transit_inds: tuple[jnp.ndarray, ...]
    # Residuals (observed minus linear model) for each planet
    ttvs: tuple[jnp.ndarray, ...]
    # Reference transit times (one per planet, or a single shared value in ttvs mode)
    t0: jnp.ndarray
    # Inferred effective periods (one per planet, or a single shared value in ttvs mode)
    ttv_period: jnp.ndarray
    # Per-planet lookup tables from _compute_bin_edges_values_single,
    # consumed by _process_planet_dt_single.
    _bin_edges: tuple[jnp.ndarray, ...]
    _bin_values: tuple[jnp.ndarray, ...]

    def __init__(
        self,
        *,
        period: Scalar | None = None,
        duration: Scalar | None = None,
        speed: Scalar | None = None,
        time_transit: Scalar | None = None,
        impact_param: Scalar | None = None,
        radius_ratio: Scalar | None = None,
        transit_times: tuple[jnp.ndarray, ...] | None = None,
        transit_inds: tuple[jnp.ndarray, ...] | None = None,
        ttvs: tuple[Scalar, ...] | None = None,
        delta_log_period: float | None = None,
    ):
        if ttvs is not None and transit_times is not None:
            raise ValueError("Supply either ttvs or transit_times, not both.")
        if ttvs is None and transit_times is None:
            raise ValueError("You must supply either transit_times or ttvs.")

        if transit_times is not None:
            # CASE 1: observed transit times -> fit a linear ephemeris per planet.
            self.transit_times = tuple(
                jnp.atleast_1d(jnp.asarray(tt)) for tt in transit_times
            )
            if transit_inds is None:
                self.transit_inds = tuple(
                    jnp.arange(tt.shape[0]) for tt in self.transit_times
                )
            else:
                self.transit_inds = tuple(transit_inds)

            fits = [
                _compute_linear_ephemeris_single(tt, inds)
                for tt, inds in zip(self.transit_times, self.transit_inds, strict=True)
            ]
            self.t0 = jnp.array([t0_i for t0_i, _, _ in fits])
            self.ttv_period = jnp.array([period_i for _, period_i, _ in fits])
            self.ttvs = tuple(ttv_i for _, _, ttv_i in fits)

            if time_transit is None:
                time_transit = self.t0
            if period is None:
                if delta_log_period is None:
                    period = self.ttv_period
                else:
                    # The transit-shape period may be offset (in log space)
                    # from the fitted timing period.
                    period = jnp.exp(jnp.log(self.ttv_period) + delta_log_period)
        else:
            # CASE 2: TTV offsets supplied; the user must give the period.
            if period is None:
                raise ValueError("When supplying ttvs, period must be provided.")
            if time_transit is None:
                # If time_transit is not provided, assume t0 = 0
                time_transit = 0.0

            # Each planet's offsets are re-centred to zero mean.
            self.ttvs = tuple(
                ttv - jnp.mean(ttv)
                for ttv in (jnp.atleast_1d(jnp.asarray(raw)) for raw in ttvs)
            )
            if transit_inds is None:
                self.transit_inds = tuple(jnp.arange(ttv.shape[0]) for ttv in self.ttvs)
            else:
                self.transit_inds = tuple(transit_inds)

            self.t0 = jnp.atleast_1d(time_transit)
            self.ttv_period = jnp.atleast_1d(period)

            # Reconstruct transit_times = t0 + period * inds + ttvs.
            # Broadcasting (instead of .item()) lets one shared t0/period serve
            # several planets, allows a per-planet period with a shared t0, and
            # keeps the constructor traceable under jax.jit, where .item()
            # fails on tracers.
            n_planets = len(self.ttvs)
            t0_each = jnp.broadcast_to(self.t0, (n_planets,))
            period_each = jnp.broadcast_to(self.ttv_period, (n_planets,))
            self.transit_times = tuple(
                t0_i + period_i * inds + ttv
                for t0_i, period_i, inds, ttv in zip(
                    t0_each, period_each, self.transit_inds, self.ttvs, strict=True
                )
            )

        # Initialize the base TransitOrbit (period is never None at this point).
        super().__init__(
            period=period,
            duration=duration,
            speed=speed,
            time_transit=time_transit,
            impact_param=impact_param,
            radius_ratio=radius_ratio,
        )

        # Build one bin table per planet. ttv_period is broadcast so that a
        # single shared period cannot make zip() silently drop planets.
        bin_periods = jnp.broadcast_to(self.ttv_period, (len(self.transit_times),))
        tables = [
            _compute_bin_edges_values_single(tt, bin_period)
            for tt, bin_period in zip(self.transit_times, bin_periods, strict=True)
        ]
        self._bin_edges = tuple(edges for edges, _ in tables)
        self._bin_values = tuple(values for _, values in tables)

    @jax.jit
    def _get_model_dt(self, t: Scalar) -> jnp.ndarray:
        """Return, per planet, the binned observed transit time for ``t``.

        The result stacks one entry per planet along axis 0, so its shape is
        ``(n_planets, *jnp.shape(t))``.
        """
        t = jnp.asarray(t)
        return jnp.stack(
            [
                _process_planet_dt_single(edges, values, t)
                for edges, values in zip(self._bin_edges, self._bin_values, strict=True)
            ]
        )

    @jax.jit
    def _warp_times(self, t: Scalar) -> Scalar:
        """Warp times based on transit timing variations.

        Each time is shifted by ``dt - t0``, where ``dt`` is the binned observed
        transit time for that planet. The result has shape
        ``(n_planets, *jnp.shape(t))``.
        """
        dt = self._get_model_dt(t)
        # Give t0 trailing singleton axes so it lines up with dt's
        # (n_planets, *t.shape) layout instead of broadcasting against t's axes.
        t0 = jnp.reshape(self.t0, self.t0.shape + (1,) * jnp.ndim(t))
        return t - (dt - t0)

    def relative_position(
        self, t: Scalar, parallax: Scalar | None = None
    ) -> tuple[Scalar, Scalar, Scalar]:
        """Compute relative position of the planet(s) at TTV-warped times.

        Length-1 axes are squeezed out of each coordinate.
        """
        x, y, z = super().relative_position(self._warp_times(t), parallax=parallax)
        return jnp.squeeze(x), jnp.squeeze(y), jnp.squeeze(z)

    @property
    def linear_t0(self):
        """Return the linear reference transit time(s) as a 1-D array."""
        return jnp.atleast_1d(self.t0)

    @property
    def linear_period(self):
        """Return the linear period(s) as a 1-D array."""
        return jnp.atleast_1d(self.ttv_period)
[docs] def compute_expected_transit_times(min_time, max_time, period, t0): """ Compute expected transit times for each planet and return them as a tuple of 1D arrays. Args: min_time (float): Start time (in days). max_time (float): End time (in days). period (array-like): Orbital period for each planet (in days). Should be convertible to a 1D JAX array. t0 (array-like): Reference transit times for each planet (in days). Should be convertible to a 1D JAX array. Returns: transit_times: tuple of JAX arrays, one per planet. """ period = jnp.atleast_1d(period) t0 = jnp.atleast_1d(t0) transit_times_list = [] for p, t0_val in zip(period, t0, strict=False): i_min = int(jnp.ceil((min_time - t0_val) / p)) i_max = int(jnp.floor((max_time - t0_val) / p)) indices = jnp.arange(i_min, i_max + 1) times = t0_val + p * indices times = times[(times >= min_time) & (times <= max_time)] transit_times_list.append(times) return tuple(transit_times_list)