Radial Velocities Fitting

import jax
import numpyro

numpyro.set_host_device_count(2)
jax.config.update("jax_enable_x64", True)

In this tutorial we will learn how to use jaxoplanet to compute the radial velocities of a star hosting a single exoplanet, and how to fit this dataset using numpyro.

Note

In addition to jaxoplanet, this tutorial requires the packages imported in the code below: numpyro, corner, and matplotlib.

Model and dataset

Let’s first generate our dataset, consisting of the radial velocities of a star orbited by a single exoplanet. We start by defining the system:

import numpy as np
import matplotlib.pyplot as plt
from jaxoplanet.orbits import keplerian

truth = dict(mass=0.02, period=3.0)
star = keplerian.Central(mass=1.0, radius=1.0)
system = keplerian.System(star).add_body(**truth)

Since we left many parameters at their default values, let’s check the parameters of the system:

system
System(
  central=Central(
    mass=<Quantity(1.0, 'M_sun')>,
    radius=<Quantity(1.0, 'R_sun')>,
    density=<Quantity(0.238732415, 'M_sun / R_sun ** 3')>
  ),
  _body_stack=ObjectStack(...)
)

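We can now compute the radial velocities of the star at 40 random times over 10 days and add Gaussian noise with a standard deviation of 0.05 (in the units of the velocities) to simulate our dataset: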

from jaxoplanet.units import unit_registry as ureg

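np.random.seed(10)
# dense time grid, used only to plot the noise-free model
over_time = np.linspace(0, 10, 1000)
over_rvs = system.radial_velocity(over_time)[0]
# 40 observation times drawn uniformly over 10 days
time = np.sort(np.random.uniform(0, 10, 40))
rv_obs = system.radial_velocity(time)[0]
# constant measurement uncertainty, in the same units as the velocities
rv_err = 0.05 * rv_obs.units
rv_obs += rv_err * np.random.normal(size=len(time))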


def plot_data():
    plt.errorbar(time, rv_obs.magnitude, yerr=rv_err.magnitude, fmt="+k")
    plt.xlabel("time (days)")
    plt.ylabel(f"radial velocity (${rv_obs.units:L}$)")


plt.plot(over_time, over_rvs.magnitude)
plot_data()
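To get a feeling for the size of the signal relative to the noise, the amplitude of the curve can be estimated by hand. The sketch below uses the standard semi-amplitude formula for a circular orbit, K = (2πG/P)^(1/3) m sin(i) / (M + m)^(2/3). It assumes a circular, edge-on orbit (sin(i) = 1) and that the velocities are expressed in solar radii per day; the constants are rounded values, and the helper name `rv_semi_amplitude` is hypothetical, not part of jaxoplanet.

```python
import math

# Rounded physical constants in SI units (assumptions of this sketch)
G = 6.674e-11     # gravitational constant, m^3 kg^-1 s^-2
M_SUN = 1.989e30  # solar mass, kg
R_SUN = 6.957e8   # solar radius, m
DAY = 86400.0     # seconds per day


def rv_semi_amplitude(m_planet, m_star, period_days, sin_incl=1.0):
    """Stellar radial-velocity semi-amplitude (m/s) for a circular orbit.

    Masses are given in solar masses and the period in days.
    """
    period = period_days * DAY
    m_p = m_planet * M_SUN
    m_tot = (m_planet + m_star) * M_SUN
    return (
        (2.0 * math.pi * G / period) ** (1.0 / 3.0)
        * m_p
        * sin_incl
        / m_tot ** (2.0 / 3.0)
    )


# Same mass and period as the `truth` dictionary of this tutorial
k_si = rv_semi_amplitude(0.02, 1.0, 3.0)
k_rsun_per_day = k_si * DAY / R_SUN
print(f"K = {k_si:.0f} m/s = {k_rsun_per_day:.3f} R_sun/day")
```

Under these assumptions the amplitude is several times larger than the 0.05 noise level, which is why the periodic signal is clearly visible in the simulated data.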

Inference

We will infer the values and associated uncertainties of the system’s orbital parameters using numpyro. To do that, we first define a callable model function:

from numpyro import distributions as dist, infer


def rv_model(time, params):
    system = keplerian.System(star).add_body(
        mass=params["mass"], period=params["period"]
    )
    return system.radial_velocity(time)[0].magnitude


def model(time, y=None):
    # uniform priors on the planet mass, the orbital period and the noise level
    mass = numpyro.sample("mass", dist.Uniform(0.01, 0.1))
    period = numpyro.sample("period", dist.Uniform(1.0, 10.0))
    error = numpyro.sample("error", dist.Uniform(0.01, 0.08))
    rv = rv_model(time, {"mass": mass, "period": period})

    # the likelihood function: independent Gaussian noise around the model
    numpyro.sample("y", dist.Normal(rv, error), obs=y)
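The last line of `model` conditions on the observations: for independent Gaussian noise, it contributes the sum of Normal log-densities of the data around the model. As a minimal sketch of that term, written in plain Python rather than with numpyro, and using made-up toy values and a hypothetical helper name:

```python
import math


def gaussian_log_likelihood(y_obs, y_model, sigma):
    """Sum of independent Normal log-densities with a shared standard deviation."""
    norm = -0.5 * math.log(2.0 * math.pi * sigma**2)
    return sum(
        norm - 0.5 * ((y - m) / sigma) ** 2 for y, m in zip(y_obs, y_model)
    )


# Toy values, made up for illustration only
y_obs = [0.10, -0.20, 0.05]
good_model = [0.12, -0.18, 0.02]
bad_model = [0.50, 0.50, 0.50]

ll_good = gaussian_log_likelihood(y_obs, good_model, sigma=0.05)
ll_bad = gaussian_log_likelihood(y_obs, bad_model, sigma=0.05)
print(ll_good, ll_bad)
```

A model that passes close to the data gets a higher log-likelihood than one that does not, and this is the quantity the sampler combines with the priors when exploring the parameters.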

Once the model is defined, we will sample the posterior distribution of the model parameters given the observed radial velocities. As we need to provide initial values for these parameters, it is a good idea to check that these values provide a good starting point for the model.

init_values = {"mass": 0.022, "period": 3.01, "error": rv_err.magnitude}
init_model = rv_model(over_time, init_values)

plt.plot(over_time, init_model, "C0")
plot_data()

Starting from this initial guess, we use numpyro’s No-U-Turn Sampler (NUTS), a Markov chain Monte Carlo (MCMC) method, to sample the posterior distribution of the parameters.

sampler = infer.MCMC(
    infer.NUTS(model, init_strategy=infer.init_to_value(values=init_values)),
    num_warmup=2000,
    num_samples=10000,
    progress_bar=True,
    num_chains=2,
)

sampler.run(jax.random.PRNGKey(6), time, y=rv_obs.magnitude)
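Running two chains (`num_chains=2`) makes it possible to check that they agree with each other. One common diagnostic is the Gelman-Rubin statistic, which compares the variance between chains to the variance within them and should be close to 1 for converged chains. The sketch below implements the basic (non-split) version on synthetic chains; the helper name and the toy chains are hypothetical. numpyro's `sampler.print_summary()` reports a related `r_hat` column for the real chains.

```python
import random
import statistics


def gelman_rubin(chains):
    """Basic (non-split) Gelman-Rubin statistic for equal-length chains."""
    n = len(chains[0])
    # mean of the within-chain sample variances
    within = statistics.mean([statistics.variance(c) for c in chains])
    # sample variance of the chain means (between-chain variance divided by n)
    between_over_n = statistics.variance([statistics.mean(c) for c in chains])
    pooled = (n - 1) / n * within + between_over_n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# Two synthetic chains that sample the same distribution
mixed = [[rng.gauss(0.02, 0.001) for _ in range(2000)] for _ in range(2)]
# Two synthetic chains stuck around different values
stuck = [
    [rng.gauss(0.02, 0.001) for _ in range(2000)],
    [rng.gauss(0.03, 0.001) for _ in range(2000)],
]

r_mixed = gelman_rubin(mixed)
r_stuck = gelman_rubin(stuck)
print(r_mixed, r_stuck)
```

Values well above 1, as for the `stuck` chains here, indicate that the chains have not converged to the same distribution and that the samples should not be trusted yet.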

We can plot the posterior predictive mean and its 16th to 84th percentile band with

samples = sampler.get_samples()
posterior_rvs = infer.Predictive(model, samples)(jax.random.PRNGKey(0), over_time)["y"]

plot_data()
plt.plot(over_time, posterior_rvs.mean(0), "C0")
_ = plt.fill_between(
    over_time,
    *np.percentile(posterior_rvs, [16, 84], axis=0),
    alpha=0.3,
    color="C0",
)

and check the inferred mass and period against their true values in a corner plot

import corner

_ = corner.corner(
    samples,
    var_names=["mass", "period"],
    truths=[truth["mass"], truth["period"]],
    show_titles=True,
    quantiles=[0.16, 0.5, 0.84],
    title_fmt=".4f",
)
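The titles of the corner plot are built from the 16th, 50th and 84th percentiles of each parameter, i.e. the median with lower and upper uncertainties. The same summary can be computed directly from a list of posterior draws. The sketch below uses only the standard library; the helper name `summarize` is hypothetical, and the draws are stand-in values rather than actual samples from the chain.

```python
import statistics


def summarize(draws):
    """Return (median, lower error, upper error) from the 16/50/84th percentiles."""
    # 99 cut points; index k - 1 holds the k-th percentile
    cuts = statistics.quantiles(draws, n=100, method="inclusive")
    q16, q50, q84 = cuts[15], cuts[49], cuts[83]
    return q50, q50 - q16, q84 - q50


# Stand-in draws; in the tutorial these could be the values of samples["mass"]
draws = [i / 100 for i in range(101)]
median, minus, plus = summarize(draws)
print(f"{median:.2f} -{minus:.2f} +{plus:.2f}")
```

For a roughly Gaussian posterior, the interval between the 16th and 84th percentiles corresponds to about one standard deviation on either side of the median.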