{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"(autodiff)=\n",
"\n",
"# Automatic differentiation\n",
"\n",
"The major selling point of `jaxoplanet` compared to other similar libraries is that it builds on top of `JAX`, which extends numerical computing tools like `numpy` and `scipy` to support automatic [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) and hardware acceleration.\n",
"In this tutorial, we present an introduction to the AD capabilities of `JAX` and `jaxoplanet`, but we won't go too deep into the technical weeds of how automatic differentiation works.\n",
"\n",
"It's beyond the scope of this tutorial to go into too many details about AD and most users of `jaxoplanet` shouldn't need to interact with these features directly very often, but this should at least give you a little taste of the kinds of things AD can do for you and demonstrate how this translates into efficient inference with probabilistic models.\n",
"The main thing that I want to emphasize here is that AD is not the same as *symbolic differentiation* (it's not going to provide you with a mathematical expression for your gradients), but it's also not the same as numerical methods like [finite difference](https://en.wikipedia.org/wiki/Finite_difference).\n",
"Using AD to evaluate the gradients of your model will generally be faster, more efficient, and more numerically stable than alternatives, but there are always exceptions to any rule.\n",
"There are times when providing your AD framework with a custom implementation and/or differentation rule for a particular function is beneficial in terms of cost and stability.\n",
"`jaxoplanet` is designed to provide these custom implementations only where it is useful (e.g. solving Kepler's equation or evaluating limb-darkened light curves) and then rely on the existing AD toolkit elsewhere.\n",
"\n",
"## Automatic differentiation in JAX\n",
"\n",
"One of the core features of `JAX` is its support for automatic differentiation (AD; that's what the \"A\" in `JAX` stands for).\n",
"To differentiate a Python function using `JAX`, we start by writing the function using `JAX`'s `numpy` interface.\n",
"In this case, let's use a made up function that isn't meant to be particularly meaningful:\n",
"\n",
"$$\n",
"y = \\exp\\left[\\sin\\left(\\frac{2\\,\\pi\\,x}{3}\\right)\\right]\n",
"$$\n",
"\n",
"and then calculate the derivative using AD.\n",
"For comparison, the symbolic derivative is:\n",
"\n",
"$$\n",
"\\frac{\\mathrm{d}y}{\\mathrm{d}x} = \\frac{2\\,\\pi}{3}\\,\\exp\\left[\\sin\\left(\\frac{2\\,\\pi\\,x}{3}\\right)\\right]\\,\\cos\\left(\\frac{2\\,\\pi\\,x}{3}\\right)\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def func(x):\n",
" arg = jnp.sin(2 * jnp.pi * x / 3)\n",
" return jnp.exp(arg)\n",
"\n",
"\n",
"def symbolic_grad(x):\n",
" arg = 2 * jnp.pi * x / 3\n",
" return 2 * jnp.pi / 3 * jnp.exp(jnp.sin(arg)) * jnp.cos(arg)\n",
"\n",
"\n",
"x = jnp.linspace(-3, 3, 100)\n",
"plt.plot(x, func(x))\n",
"plt.xlabel(r\"$x$\")\n",
"plt.ylabel(r\"$f(x)$\")\n",
"plt.xlim(x.min(), x.max());"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we can differentiate this function using the `jax.grad` function.\n",
"The interface provided by the `jax.grad` function may seem a little strange at first, but the key point is that it takes a function (like the one we defined above) as input, and it returns a new function that can evaluate the gradient of that input.\n",
"For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"grad_func = jax.grad(func)\n",
"print(grad_func(0.5))\n",
"np.testing.assert_allclose(grad_func(0.5), symbolic_grad(0.5))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"One subtlety here is that we can only compute the gradient of a _scalar_ function.\n",
"In other words, the output of the function must just be a number, not an array.\n",
"But, we can combine this `jax.grad` interface the `jax.vmap` function to evaluate the derivative of our entire function above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(x, jax.vmap(grad_func)(x), label=\"AD\")\n",
"plt.plot(x, symbolic_grad(x), \"--\", label=\"symbolic\")\n",
"plt.xlabel(r\"$x$\")\n",
"plt.ylabel(r\"$\\mathrm{d} f(x) / \\mathrm{d} x$\")\n",
"plt.xlim(x.min(), x.max())\n",
"plt.legend();"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"This example is pretty artificial, but I think that you can imagine how something like this would start to come in handy when your models get more complicated.\n",
"In particular, I think that you'll regularly find yourself experimenting with different choices of parameters, and it would be a real pain to be required to re-write all your derivative code for every new choice of parameterization.\n",
"\n",
"## Some more realistic examples\n",
"\n",
"Straightforward AD with `JAX` works well as long as everything you're doing can be easily and efficiently computed using `jax.numpy`.\n",
"However, in many exoplanet and other astrophysics applications, we need to evaluate physical models that are frequently computed numerically, and things are less simple.\n",
"A major driver of `jaxoplanet` is to provide some required custom operations to enable the use of `JAX` for exoplanet data analysis, including tasks like solving Kepler's equation, or computing the light curve for a limb-darkened exoplanet transit.\n",
"Most users shouldn't expect to typically interface with these custom operations directly, but they are exposed through the `jaxoplanet.core` module.\n",
"\n",
"### Solving Kepler's equation\n",
"\n",
"To start, let's solve for the true anomaly for a Keplerian orbit, and its derivative using the `jaxoplanet.core.kepler` function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxoplanet.core import kepler\n",
"\n",
"# The `kepler` function returns the sine and cosine of the true anomaly, so we\n",
"# need to take an `arctan` to get the value directly:\n",
"get_true_anomaly = lambda *args: jnp.arctan2(*kepler(*args))\n",
"\n",
"# The following functions compute the partial derivatives of the true anomaly as\n",
"# a function of mean anomaly and eccentricity, respectively:\n",
"d_true_d_mean = jax.vmap(jax.grad(get_true_anomaly, argnums=0), in_axes=(0, None))\n",
"d_true_d_ecc = jax.vmap(jax.grad(get_true_anomaly, argnums=1), in_axes=(0, None))\n",
"\n",
"mean_anomaly = jnp.linspace(-jnp.pi, jnp.pi, 1000)[:-1]\n",
"ecc = 0.5\n",
"true_anomaly = get_true_anomaly(mean_anomaly, ecc)\n",
"plt.plot(mean_anomaly, true_anomaly, label=f\"$f(M,e={ecc:.1f})$\")\n",
"plt.plot(\n",
" mean_anomaly,\n",
" d_true_d_mean(mean_anomaly, ecc),\n",
" label=r\"$\\mathrm{d}f(M,e)/\\mathrm{d}M$\",\n",
")\n",
"plt.plot(\n",
" mean_anomaly,\n",
" d_true_d_ecc(mean_anomaly, ecc),\n",
" label=r\"$\\mathrm{d}f(M,e)/\\mathrm{d}e$\",\n",
")\n",
"plt.legend()\n",
"plt.xlabel(\"mean anomaly\")\n",
"plt.ylabel(\"true anomaly, $f$; partial derivatives\")\n",
"plt.xlim(-jnp.pi, jnp.pi);"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Of note, the Kepler solver provided by `jaxoplanet` does not use an iterative method like those commonly used for exoplanet fitting tasks.\n",
"Instead, it uses a two-step solver which can be more efficiently parallelized using hardware acceleration like SIMD or a GPU.\n",
"Even so, it is not computationally efficient or numerically stable to directly apply AD to this solver function.\n",
"Instead, `jaxoplanet` uses the `jax.custom_jvp` interface to provide custom partial derivatives for this operation.\n",
"\n",
"### Limb-darkened transit light curves\n",
"\n",
"`jaxoplanet` also provides a custom operation for evaluating the light curve of an exoplanet transiting a limb-darkened star, with arbitrary order polynomial limb darkening laws.\n",
"This operation uses a re-implementation of the algorithms developed for the `starry` library in `JAX`.\n",
"As above, we can use AD to evaluate the derivatives of this light curve model.\n",
"For example, here's a quadratically limb-darkened model and its partial derivatives:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jaxoplanet.core.limb_dark import light_curve\n",
"\n",
"lc = lambda u1, u2, b, r: light_curve(jnp.array([u1, u2]), b, r)\n",
"b = jnp.linspace(-1.2, 1.2, 1001)\n",
"r = 0.1\n",
"u1, u2 = 0.2, 0.3\n",
"\n",
"_, axes = plt.subplots(5, 1, figsize=(6, 10), sharex=True)\n",
"\n",
"axes[0].plot(b, lc(u1, u2, b, r))\n",
"axes[0].set_ylabel(\"flux\")\n",
"axes[0].yaxis.set_label_coords(-0.15, 0.5)\n",
"for n, name in enumerate([\"$u_1$\", \"$u_2$\", \"$b$\", \"$r$\"]):\n",
" axes[n + 1].plot(\n",
" b,\n",
" jax.vmap(jax.grad(lc, argnums=n), in_axes=(None, None, 0, None))(u1, u2, b, r),\n",
" )\n",
" axes[n + 1].set_ylabel(f\"d flux / d {name}\")\n",
" axes[n + 1].yaxis.set_label_coords(-0.15, 0.5)\n",
"axes[-1].set_xlabel(\"impact parameter\")\n",
"axes[-1].set_xlim(-1.2, 1.2);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}