Transit Fitting

Like exoplanet, jaxoplanet includes methods for computing the light curves of transiting exoplanets. In this tutorial, we introduce these methods and use them alongside the NumPyro probabilistic programming library to fit a transit. Parts of this tutorial follow the Transit Fitting tutorial for the exoplanet package.

In addition to jaxoplanet (and NumPy, Matplotlib), you’ll need to install the following packages to run this tutorial: NumPyro, numpyro-ext, ArviZ, and corner.

Let’s import the necessary packages and configure the setup.

import jaxoplanet
from jaxoplanet.light_curves import limb_dark_light_curve
from jaxoplanet.orbits import TransitOrbit
import numpy as np
import matplotlib.pyplot as plt
import numpyro
import numpyro_ext.distributions, numpyro_ext.optim
import jax
import jax.numpy as jnp
import corner
import arviz as az
import copy

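numpyro.set_host_device_count(
    2  # Assumed value: match it to the number of chains you plan to run
)  # For multi-core parallelism (useful when running multiple MCMC chains in parallel)
numpyro.set_platform("cpu")  # For CPU (use "gpu" for GPU)
jax.config.update(
    "jax_enable_x64", True
)  # For 64-bit precision since JAX defaults to 32-bit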

print(f"jaxoplanet.__version__ = {jaxoplanet.__version__}")
print(f"numpy.__version__ = {np.__version__}")
print(f"matplotlib.__version__ = {plt.matplotlib.__version__}")
print(f"numpyro.__version__ = {numpyro.__version__}")
print(f"numpyro_ext.__version__ = {numpyro_ext.__version__}")
print(f"jax.__version__ = {jax.__version__}")
print(f"corner.__version__ = {corner.__version__}")
print(f"arviz.__version__ = {az.__version__}")
jaxoplanet.__version__ = 0.0.3.dev17+g2fd9026
numpy.__version__ = 1.26.4
matplotlib.__version__ = 3.8.4
numpyro.__version__ = 0.14.0
numpyro_ext.__version__ = 0.0.4
jax.__version__ = 0.4.26
corner.__version__ = 2.2.2
arviz.__version__ = 0.18.0

Let’s first compute a simple light curve.

# The light curve calculation requires an orbit object.
# We'll use TransitOrbit (similar to SimpleTransitOrbit in the exoplanet package),
# which is an orbit parameterized by the observables of a transiting system:
# period, speed/duration, time of transit, impact parameter, and radius ratio.
orbit = TransitOrbit(
    period=3.456, duration=0.12, time_transit=0.0, impact_param=0.0, radius=0.1
)  # TODO: Is it actually the radius ratio?

# Compute a limb-darkened light curve for this orbit
t = np.linspace(-0.1, 0.1, 1000)
u = [0.1, 0.06]  # Quadratic limb-darkening coefficients
light_curve = limb_dark_light_curve(orbit, u)(t)

# Plot the light curve
plt.plot(t, light_curve, lw=2)
plt.xlabel("time [days]")
plt.ylabel("relative flux")
plt.xlim(t.min(), t.max());
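
As a rough sanity check on the plotted curve, the central transit depth can be estimated by hand. The sketch below is not part of jaxoplanet; it assumes the small-planet approximation and the standard quadratic limb-darkening law, and reuses the radius ratio and coefficients from the cell above.

```python
# Small-planet estimate of the central transit depth (impact parameter b = 0).
r = 0.1  # planet-to-star radius ratio, as in the orbit above
u1, u2 = 0.1, 0.06  # quadratic limb-darkening coefficients

# Without limb darkening the planet blocks a fraction r**2 of the stellar disk.
depth_uniform = r**2

# With I(mu) = 1 - u1*(1 - mu) - u2*(1 - mu)**2, the disk-averaged intensity
# relative to disk center is 1 - u1/3 - u2/6, so a small planet sitting at
# disk center blocks proportionally more light than on a uniform disk.
depth_center = r**2 / (1 - u1 / 3 - u2 / 6)

print(f"uniform-disk depth: {depth_uniform:.5f}")
print(f"limb-darkened central depth (approx.): {depth_center:.5f}")
```

The minimum of the plotted light curve should sit close to one minus the second number; a large mismatch would suggest the radius argument is not being interpreted as a radius ratio.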

Transit model in NumPyro

We’ll construct a transit model using NumPyro and fit it to some simulated data. NumPyro is a probabilistic programming library (PPL), like PyMC, that allows us to succinctly build models and perform (gradient-based) inference with them. NumPyro models must be written in JAX!

Let’s start off by choosing the transit properties of our simulated data. These will be the “true” values that we would like to recover with our inference.

# Simulate some data with Gaussian noise
random = np.random.default_rng(42)
PERIOD = random.uniform(2, 5)  # day
T0 = PERIOD * random.uniform()  # day
DURATION = 0.5  # day
B = 0.5  # impact parameter
ROR = 0.08  # planet radius / star radius
U = np.array([0.1, 0.06])  # limb darkening coefficients
yerr = 5e-4  # flux uncertainty
t = np.arange(0, 17, 0.05)  # day

orbit = TransitOrbit(
    period=PERIOD, duration=DURATION, time_transit=T0, impact_param=B, radius=ROR
)
y_true = limb_dark_light_curve(orbit, U)(t)
y = y_true + yerr * random.normal(size=len(t))

# Let's see what the light curve looks like
plt.plot(t, y_true, "-k", lw=1.0, label="truth")
plt.plot(t, y, ".k", ms=2, label="data")
plt.xlabel("time [days]")
plt.ylabel("relative flux")
plt.legend(fontsize=10)
plt.xlim(t.min(), t.max());

Defining the model

Let’s define our NumPyro model. The syntax might be a bit unfamiliar at first. We’re sampling the period and duration in log space to constrain them to positive values, and we’re sampling the quadratic limb-darkening coefficients using the custom distribution QuadLDParams from the numpyro_ext package.

def model(t, yerr, y=None):
    # Priors for the parameters we're fitting for

    # The time of reference transit
    t0 = numpyro.sample("t0", numpyro.distributions.Normal(T0, 1))

    # The period
    logP = numpyro.sample("logP", numpyro.distributions.Normal(jnp.log(PERIOD), 0.1))
    period = numpyro.deterministic("period", jnp.exp(logP))

    # The duration
    logD = numpyro.sample("logD", numpyro.distributions.Normal(jnp.log(DURATION), 0.1))
    duration = numpyro.deterministic("duration", jnp.exp(logD))

    # The radius ratio
    # logR = numpyro.sample("logR", numpyro.distributions.Normal(jnp.log(ROR), 0.1))
    r = numpyro.sample("r", numpyro.distributions.Uniform(0.01, 0.2))
    # r = numpyro.deterministic("r", jnp.exp(logR))

    # The impact parameter
    # b = numpyro.sample("b", numpyro.distributions.Uniform(0, 1.0))
    _b = numpyro.sample("_b", numpyro.distributions.Uniform(0, 1.0))
    b = numpyro.deterministic("b", _b * (1 + r))

    # The limb darkening coefficients
    u = numpyro.sample("u", numpyro_ext.distributions.QuadLDParams())

    # The orbit and light curve
    orbit = TransitOrbit(
        period=period, duration=duration, time_transit=t0, impact_param=b, radius=r
    )
    y_pred = limb_dark_light_curve(orbit, u)(t)

    # Let's track the light curve
    numpyro.deterministic("light_curve", y_pred)

    # The likelihood function assuming Gaussian uncertainty
    numpyro.sample("obs", numpyro.distributions.Normal(y_pred, yerr), obs=y)
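
Two of the reparameterizations in the model can be illustrated without NumPyro. The following is a minimal standard-library sketch, not the model itself: `period_guess` is a hypothetical value, and the draws only mimic the `logP`, `r`, and `_b` priors to show what ranges they imply.

```python
import math
import random

rng = random.Random(0)

# Sampling log(P) from a normal and exponentiating always gives P > 0,
# and the prior median of P is exp(mean of log P).
period_guess = 3.0  # hypothetical value, for illustration only
log_p = [rng.gauss(math.log(period_guess), 0.1) for _ in range(10_000)]
period = sorted(math.exp(x) for x in log_p)
median_period = period[len(period) // 2]

# Scaling a unit-interval variable by (1 + r) lets the impact parameter
# cover grazing configurations, 0 <= b <= 1 + r, for every sampled r.
r = [rng.uniform(0.01, 0.2) for _ in range(10_000)]
b = [rng.uniform(0.0, 1.0) * (1 + ri) for ri in r]

print(f"min period = {period[0]:.3f}, median period = {median_period:.3f}")
print(f"max b = {max(b):.3f} (upper bound is 1 + r for each draw)")
```

The point of the `_b * (1 + r)` construction is that the upper bound on the impact parameter depends on the radius ratio, which a fixed Uniform(0, 1) prior on `b` cannot express.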

Checking the priors

It can be a good idea to see whether the priors we defined are reasonable by sampling and plotting them. Let’s do that now using the numpyro.infer submodule’s Predictive functionality to draw some samples from the priors.

n_prior_samples = 3000
prior_samples = numpyro.infer.Predictive(model, num_samples=n_prior_samples)(
    jax.random.PRNGKey(0), t, yerr
)

# Let's make it into an arviz InferenceData object.
# To do so we'll first need to reshape the samples to be of shape (chains, draws, *shape)
converted_prior_samples = {
    f"{p}": np.expand_dims(prior_samples[p], axis=0) for p in prior_samples
}
prior_samples_inf_data = az.from_dict(converted_prior_samples)

# Plot the corner plot
fig = plt.figure(figsize=(12, 12))
_ = corner.corner(
    prior_samples_inf_data,
    fig=fig,
    var_names=["t0", "period", "duration", "r", "b", "u"],
    truths=[T0, PERIOD, DURATION, ROR, B, U[0], U[1]],
    show_titles=True,
    title_kwargs={"fontsize": 10},
    label_kwargs={"fontsize": 12},
)

These priors seem sensible enough, and the true values (blue lines) are within their bounds. Before we start sampling, let’s find the maximum a posteriori (MAP) solution. This is a good starting point for the sampling we’ll perform later and also a good check that things are working. We’ll use the optimize function defined in the numpyro_ext package.

We have a choice for the initial value of the optimization. Some potential options include:

  1. Manually setting them to a specific set of values. This approach might make sense for real data when it’s a system that’s been studied before and there’s a good guess for the parameters. As an example, if we were fitting some follow-up ground-based transit data it might make sense to use the parameters from a Kepler/TESS discovery paper as the initial values.

  2. The median values of the priors. This might be a good idea when we don’t have a good guess for the parameters. Similarly, we could also use the mean values of the priors.

Let’s do the former and set the initial values to the true values.

init_param_method = "true_values"  # "prior_median" or "true_values"

if init_param_method == "prior_median":
    print("Starting from the prior medians")
    run_optim = numpyro_ext.optim.optimize(
        model, init_strategy=numpyro.infer.init_to_median()
    )
elif init_param_method == "true_values":
    print("Starting from the true values")
    init_params = {
        "t0": T0,
        "logP": jnp.log(PERIOD),
        "logD": jnp.log(DURATION),
        "r": ROR,  # the model samples `r` directly, not `logR`
        "_b": B / (1 + ROR),
        "u": U,
    }
    run_optim = numpyro_ext.optim.optimize(
        model,
        init_strategy=numpyro.infer.init_to_value(values=init_params),
    )

opt_params = run_optim(jax.random.PRNGKey(5), t, yerr, y=y)
Starting from the true values
for k, v in opt_params.items():
    # Skip the large arrays and the auxiliary variable
    if k in ["light_curve", "obs", "_b"]:
        continue
    print(f"{k}: {v}")
t0: 1.9000778724033895
logP: 1.4632457135160681
logD: -0.6800678396181157
r: 0.0805869891771916
u: [ 0.24748829 -0.12374414]
b: 0.5938904995145325
duration: 0.50658262482805
period: 4.319958144846175
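
To make concrete what a MAP solution is, here is a toy example that is independent of the transit model: a single mean with a Gaussian prior and Gaussian noise, where the MAP value has a closed form that a brute-force search can confirm. All numbers are hypothetical, and this is only a sketch of the idea, not of how numpyro_ext performs the optimization.

```python
import random

rng = random.Random(1)

# Toy problem: infer a mean `theta` with prior N(mu0, s0^2) from data with
# known noise sigma. All numbers here are hypothetical.
mu0, s0, sigma, theta_true = 0.0, 1.0, 0.5, 0.8
data = [rng.gauss(theta_true, sigma) for _ in range(20)]


def neg_log_posterior(theta):
    """Negative log posterior up to an additive constant."""
    prior = 0.5 * ((theta - mu0) / s0) ** 2
    likelihood = sum(0.5 * ((d - theta) / sigma) ** 2 for d in data)
    return prior + likelihood


# Closed-form MAP for this conjugate Gaussian model
precision = 1 / s0**2 + len(data) / sigma**2
theta_map = (mu0 / s0**2 + sum(data) / sigma**2) / precision

# Brute-force check on a grid from -2 to 2 in steps of 0.001
grid = [i / 1000 for i in range(-2000, 2001)]
theta_grid = min(grid, key=neg_log_posterior)

print(f"closed-form MAP = {theta_map:.3f}, grid MAP = {theta_grid:.3f}")
```

For the transit model there is no closed form and the parameter space has several dimensions, which is why a gradient-based optimizer is used instead of a grid.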

Now let’s plot the MAP model against the simulated data.

plt.plot(t, y, ".k", ms=2, label="data")
plt.plot(t, y_true, "-k", lw=1.0, label="truth")
plt.plot(t, opt_params["light_curve"], "--C0", lw=1.0, label="MAP model")
plt.xlabel("time [days]")
plt.ylabel("relative flux")
plt.legend(fontsize=10, loc=4)
plt.xlim(t.min(), t.max());

Great. Not surprisingly, the MAP model is a good fit to the data. Let’s use these MAP values as the initial values for our sampling.


Let’s sample from the posterior defined by this model and data. We’ll use the No-U-Turn Sampler (NUTS) algorithm, which is a variant of the Hamiltonian Monte Carlo (HMC) algorithm that automatically tunes some of the sampling parameters.

This cell takes about a minute to run on my laptop. Don’t worry if it doesn’t seem like anything is happening for a while at the beginning; compiling the code and running the first 100-200 iterations are the most computationally demanding and the subsequent sampling runs much faster!

Below, we’re setting the regularize_mass_matrix keyword to False. This is because we realized that the default of setting it to True causes the sampling to be slower (roughly 4 times slower). We’re investigating why but it seems (at least for this case) the posteriors are pretty much the same whether you regularize the mass matrix or not.

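sampler = numpyro.infer.MCMC(
    numpyro.infer.NUTS(
        model,
        regularize_mass_matrix=False,
        # Start from the MAP solution found above
        init_strategy=numpyro.infer.init_to_value(values=opt_params),
    ),
    # The warmup, sample, and chain counts below are assumed values
    num_warmup=1000,
    num_samples=2000,
    num_chains=2,
)
sampler.run(jax.random.PRNGKey(1), t, yerr, y=y)  # the seed is an assumed value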

Checking our posterior samples

We should check the convergence of our sampler. Determining whether a sampler has converged is not trivial, and there is a lot of literature on the subject. Here we’ll check for convergence by looking at the Gelman-Rubin \(\hat{R}\) statistic and the bulk effective sample size (ESS) of each parameter.

  • The \(\hat{R}\) statistic is a diagnostic of convergence based on the ratio of the variance between chains to the variance within chains. We would like for it to be close to 1.00 for each parameter.

  • The ESS is a measure of the number of independent samples in the chains and is inversely correlated with the autocorrelation in a chain. Larger estimates for the ESS are better as they indicate less autocorrelation in the chains.
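
The variance-ratio idea behind \(\hat{R}\) can be sketched in a few lines. This is the classic, non-split form of the statistic written with the standard library only; ArviZ reports a more robust rank-normalized split variant, so its numbers will not match this function exactly. The function name `gelman_rubin` and the synthetic chains are illustrative.

```python
import random
import statistics


def gelman_rubin(chains):
    """Classic (non-split) R-hat for a list of equal-length chains."""
    n = len(chains[0])
    # W: average of the within-chain sample variances
    within = statistics.mean(statistics.variance(c) for c in chains)
    # B / n: sample variance of the chain means
    between_over_n = statistics.variance([statistics.mean(c) for c in chains])
    # Pooled estimate of the posterior variance
    var_hat = (n - 1) / n * within + between_over_n
    return (var_hat / within) ** 0.5


rng = random.Random(0)

# Four chains drawing from the same distribution: R-hat should be near 1
mixed = [[rng.gauss(0, 1) for _ in range(1000)] for _ in range(4)]

# Two chains stuck around different means: R-hat should be well above 1
stuck = [[rng.gauss(mu, 1) for _ in range(1000)] for mu in (0, 3)]

print(f"well-mixed chains: R-hat = {gelman_rubin(mixed):.3f}")
print(f"stuck chains:      R-hat = {gelman_rubin(stuck):.3f}")
```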

We can get both of these values using the summary function in the ArviZ package. Let’s do that now.

inf_data = az.from_numpyro(sampler)
samples = sampler.get_samples()
az.summary(inf_data, var_names=["t0", "period", "duration", "r", "b", "u"])
           mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
t0        1.901  0.003   1.895    1.906      0.000    0.000    2026.0    2195.0   1.00
period    4.320  0.002   4.317    4.323      0.000    0.000    1720.0    1336.0   1.00
duration  0.504  0.006   0.493    0.516      0.000    0.000    1024.0    1133.0   1.01
r         0.079  0.001   0.077    0.081      0.000    0.000     637.0    1148.0   1.01
b         0.417  0.185   0.063    0.671      0.009    0.007     463.0     665.0   1.01
u[0]      0.151  0.111   0.000    0.349      0.002    0.002    1958.0    1932.0   1.00
u[1]      0.163  0.207  -0.169    0.555      0.006    0.004    1216.0    1878.0   1.00

The bulk ESS (ess_bulk) isn’t great for some of the parameters, notably the impact parameter \(b\) (463) and the radius ratio \(r\) (637), but since the \(\hat{R}\) values are all within 0.01 of 1, let’s go ahead with these samples.
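
To see how autocorrelation lowers the ESS, here is a crude estimator applied to synthetic chains. It is a simplified sketch that truncates the autocorrelation sum at the first negative estimate; it is not the estimator ArviZ uses for ess_bulk, and the function name and test chains are illustrative.

```python
import random


def effective_sample_size(chain):
    """Crude ESS: N / (1 + 2 * sum of leading positive autocorrelations)."""
    n = len(chain)
    mean = sum(chain) / n
    centered = [x - mean for x in chain]
    var = sum(c * c for c in centered) / n
    tau = 1.0
    for lag in range(1, n):
        rho = sum(centered[i] * centered[i + lag] for i in range(n - lag)) / (n * var)
        if rho < 0:  # truncate at the first negative estimate
            break
        tau += 2 * rho
    return n / tau


rng = random.Random(0)

# Independent draws: the ESS should be close to the chain length
independent = [rng.gauss(0, 1) for _ in range(2000)]

# AR(1) chain with lag-1 correlation 0.9: the ESS should be far smaller
correlated = [0.0]
for _ in range(1999):
    correlated.append(0.9 * correlated[-1] + rng.gauss(0, 1))

ess_independent = effective_sample_size(independent)
ess_correlated = effective_sample_size(correlated)
print(f"independent chain: ESS ~ {ess_independent:.0f} of 2000")
print(f"correlated chain:  ESS ~ {ess_correlated:.0f} of 2000")
```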

# There's also a method to obtain similar results to `az.summary` but directly
# as a method with the MCMC sampler. It also gives us the number of divergences.
sampler.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
        _b      0.39      0.17      0.44      0.07      0.59    395.88      1.01
      logD     -0.68      0.01     -0.68     -0.71     -0.67    985.29      1.01
      logP      1.46      0.00      1.46      1.46      1.46   1500.16      1.00
         r      0.08      0.00      0.08      0.08      0.08    592.50      1.01
        t0      1.90      0.00      1.90      1.90      1.91   1943.96      1.00
      u[0]      0.15      0.11      0.13      0.00      0.31   2030.01      1.00
      u[1]      0.16      0.21      0.14     -0.16      0.49   1176.15      1.00

Number of divergences: 0

Let’s get a different view of the chains by making some trace plots. We can do this using the plot_trace function in the ArviZ package.

az.plot_trace(
    inf_data,
    var_names=["t0", "period", "duration", "r", "b", "u"],
    backend_kwargs={"constrained_layout": True},
);

The different line styles (not colors!) above indicate the different chains. There are two colors for \(u\) since there are two limb-darkening coefficients (i.e., \(u_1, u_2\)).

Let’s now make a corner plot of the posterior samples to see the pairwise joint distributions of the parameters and see if there are any correlations.

_ = corner.corner(
    inf_data,
    var_names=["t0", "period", "duration", "r", "b", "u"],
    truths=[T0, PERIOD, DURATION, ROR, B, U[0], U[1]],
    show_titles=True,
    quantiles=[0.16, 0.5, 0.84],
    title_kwargs={"fontsize": 12},
    label_kwargs={"fontsize": 15},
)

The blue lines indicate the true values. All the true values are within 1-sigma of the marginalized posterior distributions, which is good!

Phase plots

Let’s make the phase plot that is commonly shown in exoplanet papers.

inferred_t0 = np.median(samples["t0"])
inferred_period = np.median(samples["period"])
inferred_duration = np.median(samples["duration"])
inferred_r = np.median(samples["r"])
inferred_b = np.median(samples["b"])
inferred_u = np.median(samples["u"], axis=0)

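orbit = TransitOrbit(
    period=inferred_period,
    duration=inferred_duration,
    time_transit=inferred_t0,
    impact_param=inferred_b,
    radius=inferred_r,
)
y_model = limb_dark_light_curve(orbit, inferred_u)(t)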

fig, ax = plt.subplots(dpi=150)

# Plot the folded data
t_fold = (
    t - inferred_t0 + 0.5 * inferred_period
) % inferred_period - 0.5 * inferred_period
ax.errorbar(t_fold, y, yerr=yerr, fmt=".k", label="data")

# Plot the folded model
inds = np.argsort(t_fold)
ax.plot(t_fold[inds], y_model[inds], color="C1", lw=1.5, label="model")
ax.set_xlabel("time since transit [days]")
ax.set_ylabel("relative flux")
ax.legend(fontsize=10, loc=4)
ax.set_xlim(-inferred_duration, inferred_duration);
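
The folding expression used above maps every timestamp onto a window centered on the transit. A minimal standard-library sketch makes that explicit; the ephemeris values below are hypothetical, and `fold` is an illustrative helper rather than a jaxoplanet function.

```python
def fold(times, t0, period):
    """Map times to the interval [-period/2, period/2) centered on transit."""
    return [(t - t0 + 0.5 * period) % period - 0.5 * period for t in times]


t0, period = 1.9, 4.32  # hypothetical ephemeris, for illustration only

# Every mid-transit time t0 + k * period should fold to (numerically) zero
transit_times = [t0 + k * period for k in range(4)]
folded_transits = fold(transit_times, t0, period)

# A regular time grid like the one in this tutorial stays inside the window
grid = [i * 0.05 for i in range(340)]
folded_grid = fold(grid, t0, period)

print("folded transit times:", [round(x, 12) for x in folded_transits])
print(f"folded grid range: {min(folded_grid):.3f} to {max(folded_grid):.3f}")
```

Shifting by half a period before taking the modulus, then shifting back, is what places the transit at zero in the middle of the window instead of at its edge.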