Source code for jaxoplanet.starry.light_curves
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import scipy
from jaxoplanet.core.limb_dark import light_curve as _limb_dark_light_curve
from jaxoplanet.starry.core.basis import A1, A2_inv, U
from jaxoplanet.starry.core.polynomials import Pijk
from jaxoplanet.starry.core.rotation import left_project
from jaxoplanet.starry.core.solution import rT, solution_vector
from jaxoplanet.starry.surface import Surface
from jaxoplanet.starry.system_observable import system_observable
[docs]
def surface_light_curve(
surface: Surface,
r: float | None = None,
x: float | None = None,
y: float | None = None,
z: float | None = None,
theta: float | None = None,
order: int = 20,
higher_precision: bool = False,
):
"""Light curve of an occulted surface.
Args:
surface (Surface): Surface object
r (float or None): radius of the occulting body, relative to the current map
body
x (float or None): x coordinate of the occulting body relative to the surface
center. By default (None) 0.0
y (float or None): y coordinate of the occulting body relative to the surface
center. By default (None) 0.0
z (float or None): z coordinate of the occulting body relative to the surface
center. By default (None) 0.0
theta (float):
rotation angle of the map, in radians. By default 0.0
order (int):
order of the P integral numerical approximation. By default 20
higher_precision (bool): whether to compute change of basis matrix as hight
precision. By default False (only used to testing).
Returns:
ArrayLike: flux
"""
if higher_precision:
try:
from jaxoplanet.starry.multiprecision import (
basis as basis_mp,
utils as utils_mp,
)
except ImportError as e:
raise ImportError(
"The `mpmath` Python package is required for higher_precision=True."
) from e
total_deg = surface.deg
rT_deg = rT(total_deg)
x = 0.0 if x is None else x
y = 0.0 if y is None else y
z = 0.0 if z is None else z
# no occulting body
if r is None:
b_rot = True
theta_z = 0.0
design_matrix_p = rT_deg
# occulting body
else:
b = jnp.sqrt(jnp.square(x) + jnp.square(y))
b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + r), jnp.less_equal(z, 0.0))
b_occ = jnp.logical_not(b_rot)
# trick to avoid nan `x=jnp.where...` grad caused by nan sT
r = jnp.where(b_rot, 1.0, r)
b = jnp.where(b_rot, 1.0, b)
if surface.ydeg == 0:
if surface.udeg == 0:
ld_u = jnp.array([])
else:
ld_u = jnp.concatenate(
[jnp.atleast_1d(jnp.asarray(u_)) for u_ in surface.u], axis=0
)
lc_func = partial(_limb_dark_light_curve, ld_u, order=order)
lc = lc_func(b, r)
return surface.amplitude * (1.0 + jnp.where(b_occ, lc, 0))
else:
theta_z = jnp.arctan2(x, y)
sT = solution_vector(total_deg, order=order)(b, r)
if total_deg > 0:
if higher_precision:
A2 = np.atleast_2d(utils_mp.to_numpy(basis_mp.A2(total_deg)))
else:
A2 = scipy.sparse.linalg.inv(A2_inv(total_deg))
A2 = jax.experimental.sparse.BCOO.from_scipy_sparse(A2)
else:
A2 = jnp.array([[1]])
design_matrix_p = jnp.where(b_occ, sT @ A2, rT_deg)
if surface.ydeg == 0:
rotated_y = surface.y.todense()
else:
rotated_y = left_project(
surface.ydeg,
surface._inc,
surface._obl,
theta,
theta_z,
surface.y.todense(),
)
# limb darkening
if surface.udeg == 0:
p_u = Pijk.from_dense(jnp.array([1]))
else:
u = jnp.array([1, *surface.u])
p_u = Pijk.from_dense(u @ U(surface.udeg), degree=surface.udeg)
# surface map * limb darkening map
if higher_precision:
A1_val = np.atleast_2d(utils_mp.to_numpy(basis_mp.A1(surface.ydeg)))
else:
A1_val = jax.experimental.sparse.BCOO.from_scipy_sparse(A1(surface.ydeg))
p_y = Pijk.from_dense(A1_val @ rotated_y, degree=surface.ydeg)
p_yu = p_y * p_u
norm = np.pi / (p_u.tosparse() @ rT(surface.udeg))
return surface.amplitude * (p_yu.tosparse() @ design_matrix_p) * norm
[docs]
def light_curve(system, order=20):
return system_observable(surface_light_curve, order=order)(system)