Source code for jaxoplanet.starry.doppler

import jax
import jax.numpy as jnp
import numpy as np
import scipy

from jaxoplanet.starry.core import basis
from jaxoplanet.starry.core.basis import A1, A2_inv, U
from jaxoplanet.starry.core.polynomials import Pijk
from jaxoplanet.starry.core.rotation import dot_rz, left_project
from jaxoplanet.starry.core.solution import rT, solution_vector
from jaxoplanet.starry.surface import Surface
from jaxoplanet.starry.system_observable import system_observable


def rv_map_expansion(veq=None, inc=None, obl=None, alpha=None):
    """Return the coefficient vector describing the radial-velocity map.

    Args:
        veq: Overall velocity scale multiplying every non-zero coefficient
            (presumably the equatorial rotation velocity -- confirm with the
            caller). When ``None`` a zero vector of length 4 is returned.
        inc: Angle passed to ``jnp.cos``/``jnp.sin`` (radians). When ``None``
            the values ``cos = 0`` and ``sin = 1`` are used.
        obl: Angle passed to ``jnp.cos``/``jnp.sin`` (radians). When ``None``
            the values ``cos = 1`` and ``sin = 0`` are used.
        alpha: Optional coefficient enabling the 16-term expansion
            (presumably a differential-rotation shear -- confirm).

    Returns:
        A 1-D JAX array: 4 entries when ``veq`` or ``alpha`` is ``None``,
        16 entries otherwise. All non-zero branches are scaled by ``pi``.
    """
    if inc is None:
        cos_inc, sin_inc = 0.0, 1.0
    else:
        cos_inc, sin_inc = jnp.cos(inc), jnp.sin(inc)

    if obl is None:
        cos_obl, sin_obl = 1.0, 0.0
    else:
        cos_obl, sin_obl = jnp.cos(obl), jnp.sin(obl)

    a = sin_inc * cos_obl
    b = sin_inc * sin_obl
    c = cos_inc
    root3 = jnp.sqrt(3)

    if veq is None:
        return jnp.array([0.0, 0.0, 0.0, 0.0])

    if alpha is None:
        coeffs = [0.0, veq * root3 * b / 3, 0, veq * root3 * a / 3]
        return jnp.array(coeffs) * jnp.pi

    a2, b2, c2 = a**2, b**2, c**2
    # Factors shared by several coefficients of the 16-term expansion.
    low_order = -(a2) * alpha - b2 * alpha - c2 * alpha + 5
    mixed = a2 + b2 - 4 * c2
    scale = veq * alpha

    coeffs = [
        0,
        veq * root3 * b * low_order / 15,
        0,
        veq * root3 * a * low_order / 15,
        0,
        0,
        0,
        0,
        0,
        scale * jnp.sqrt(70) * b * (3 * a2 - b2) / 70,
        scale * 2 * jnp.sqrt(105) * c * (-(a2) + b2) / 105,
        scale * jnp.sqrt(42) * b * mixed / 210,
        0,
        scale * jnp.sqrt(42) * a * mixed / 210,
        scale * 4 * jnp.sqrt(105) * a * b * c / 105,
        scale * jnp.sqrt(70) * a * (a2 - 3 * b2) / 70,
    ]
    return jnp.array(coeffs) * jnp.pi
def surface_radial_velocity(
    surface: Surface,
    r: float | None = None,
    x: float | None = None,
    y: float | None = None,
    z: float | None = None,
    theta: float | None = None,
    order: int = 20,
    higher_precision: bool = False,
):
    """Compute the radial-velocity observable of ``surface``.

    The result is ``surface.amplitude * rv * norm / intensity``, where ``rv``
    and ``intensity`` are obtained by contracting polynomial products with the
    same design vector.

    Args:
        surface: Surface providing ``deg``, ``ydeg``, ``udeg``, ``vdeg``,
            ``y``, ``u``, ``veq``, ``amplitude``, ``_inc`` and ``_obl``.
        r: Radius of the occulting body; ``None`` means no occulting body.
        x: Occultor coordinate, treated as ``0.0`` when ``None``.
        y: Occultor coordinate, treated as ``0.0`` when ``None``.
        z: Occultor coordinate, treated as ``0.0`` when ``None``; the
            occultation is ignored wherever ``z <= 0``.
        theta: Angle forwarded to ``left_project`` (only used when
            ``surface.ydeg != 0``).
        order: Forwarded to ``solution_vector``.
        higher_precision: Use the ``jaxoplanet.starry.multiprecision`` basis
            matrices instead of the sparse ones.

    Returns:
        The value ``surface.amplitude * rv * norm / intensity``.

    Raises:
        ImportError: If ``higher_precision`` is true and the multiprecision
            helpers cannot be imported.
    """
    if higher_precision:
        try:
            from jaxoplanet.starry.multiprecision import (
                basis as basis_mp,
                utils as utils_mp,
            )
        except ImportError as err:
            raise ImportError(
                "The `mpmath` Python package is required for higher_precision=True."
            ) from err

    full_deg = surface.deg + surface.vdeg
    unocculted_design = rT(full_deg)

    x, y, z = (0.0 if coord is None else coord for coord in (x, y, z))

    if r is not None:
        # occulting body
        impact = jnp.sqrt(jnp.square(x) + jnp.square(y))
        separated = jnp.greater_equal(impact, 1.0 + r)
        z_nonpositive = jnp.less_equal(z, 0.0)
        not_occulted = jnp.logical_or(separated, z_nonpositive)
        occulted = jnp.logical_not(not_occulted)
        theta_z = jnp.arctan2(x, y)

        # Substitute 1.0 for r and the impact parameter wherever the
        # occultation is ignored: otherwise NaNs in the solution vector would
        # leak into the gradient through the `jnp.where` below.
        r = jnp.where(not_occulted, 1.0, r)
        impact = jnp.where(not_occulted, 1.0, impact)

        occulted_solution = solution_vector(full_deg, order=order)(impact, r)

        if full_deg > 0 and higher_precision:
            a2_matrix = np.atleast_2d(utils_mp.to_numpy(basis_mp.A2(full_deg)))
        elif full_deg > 0:
            a2_scipy = scipy.sparse.linalg.inv(A2_inv(full_deg))
            a2_matrix = jax.experimental.sparse.BCOO.from_scipy_sparse(a2_scipy)
        else:
            a2_matrix = jnp.array([[1]])

        design_matrix_p = jnp.where(
            occulted, occulted_solution @ a2_matrix, unocculted_design
        )
    else:
        # no occulting body
        theta_z = 0.0
        design_matrix_p = unocculted_design

    dense_y = surface.y.todense()
    if surface.ydeg == 0:
        rotated_y = dense_y
    else:
        rotated_y = left_project(
            surface.ydeg,
            surface._inc,
            surface._obl,
            theta,
            theta_z,
            dense_y,
        )

    # limb darkening
    if surface.udeg == 0:
        p_u = Pijk.from_dense(jnp.array([1]))
    else:
        u_coeffs = jnp.array([1, *surface.u])
        p_u = Pijk.from_dense(u_coeffs @ U(surface.udeg), degree=surface.udeg)

    # surface map in the polynomial basis
    if higher_precision:
        a1_y = np.atleast_2d(utils_mp.to_numpy(basis_mp.A1(surface.ydeg)))
    else:
        a1_y = jax.experimental.sparse.BCOO.from_scipy_sparse(A1(surface.ydeg))
    p_y = Pijk.from_dense(a1_y @ rotated_y, degree=surface.ydeg)

    norm = np.pi / (p_u.tosparse() @ rT(surface.udeg))

    # the radial velocity map, rotated to account for the occultor
    y_rv = rv_map_expansion(
        inc=surface._inc, obl=surface._obl, veq=surface.veq, alpha=None
    )
    rotated_rv_y = dot_rz(surface.vdeg, -theta_z)(y_rv)
    a1_rv = jax.experimental.sparse.BCOO.from_scipy_sparse(
        basis.A1(surface.vdeg)
    ).todense()
    p_rv = Pijk.from_dense(a1_rv @ rotated_rv_y)

    rv_integral = (p_y * p_rv * p_u).todense() @ design_matrix_p

    # The intensity only involves the map and limb darkening, so it uses the
    # leading (surface.deg + 1)**2 entries of the design vector.
    n_coeffs = (surface.deg + 1) ** 2
    intensity = ((p_y * p_u).tosparse() @ design_matrix_p[:n_coeffs]) * norm

    return surface.amplitude * rv_integral * norm / intensity
def radial_velocity(system, order=20):
    """Build a function of time returning the radial velocities of ``system``.

    Args:
        system: Object passed to ``system_observable`` and providing a
            ``radial_velocity(time)`` method.
        order: Forwarded to ``surface_radial_velocity`` through
            ``system_observable``.

    Returns:
        A callable ``impl(time)`` returning the per-body output of
        ``surface_radial_velocity`` with ``system.radial_velocity(time)[0]``
        added to its first column (presumably the central body's orbital
        motion -- confirm against ``system_observable``'s output layout).
    """
    per_body_rv = system_observable(surface_radial_velocity, order=order)(system)

    def impl(time):
        rvs = per_body_rv(time)
        first_column = rvs[:, 0] + system.radial_velocity(time)[0]
        return rvs.at[:, 0].set(first_column)

    return impl