Source code for solution

from collections.abc import Callable
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import roots_legendre

from jaxoplanet.core.limb_dark import kite_area
from jaxoplanet.types import Array
from jaxoplanet.utils import zero_safe_sqrt


def solution_vector(l_max: int, order: int = 20) -> Callable[[Array, Array], Array]:
    """Build the occultation solution-vector function for degree ``l_max``.

    Parameters
    ----------
    l_max : int
        Maximum spherical-harmonic degree; the returned vectors have
        ``(l_max + 1) ** 2`` entries.
    order : int, optional
        Gauss-Legendre quadrature order forwarded to ``p_integral``.

    Returns
    -------
    Callable[[Array, Array], Array]
        A jitted function of impact parameter ``b`` and occultor radius ``r``.
        It broadcasts over scalar inputs and appends a trailing axis of length
        ``(l_max + 1) ** 2`` holding ``Q - P``. The signs of ``b`` and ``r``
        are ignored (absolute values are taken).
    """
    n_terms = (l_max + 1) ** 2

    @jax.jit
    @partial(jnp.vectorize, signature=f"(),()->({n_terms})")
    def solve(b: Array, r: Array) -> Array:
        abs_b = jnp.abs(b)
        abs_r = jnp.abs(r)
        k0, k1 = kappas(abs_b, abs_r)
        p_part = p_integral(order, l_max, abs_b, abs_r, k0)
        # The Q integral is evaluated at lambda = pi / 2 - kappa1.
        q_part = q_integral(l_max, 0.5 * jnp.pi - k1)
        return q_part - p_part

    return solve
def kappas(b: Array, r: Array) -> tuple[Array, Array]:
    """Compute the angles ``(kappa0, kappa1)`` for impact parameter ``b`` and radius ``r``.

    Let ``A = kite_area(r, b, 1)`` when the unit circle and the circle of
    radius ``r`` at distance ``b`` intersect (``|1 - r| < b < 1 + r``), and
    ``A = 0`` otherwise. The result is
    ``(arctan2(A, b**2 + r**2 - 1), arctan2(A, b**2 - r**2 + 1))``.
    """
    r_sq_minus_one = (r - 1) * (r + 1)
    b_sq = jnp.square(b)
    intersecting = jnp.logical_and(
        jnp.greater(b, jnp.abs(1 - r)),
        jnp.less(b, 1 + r),
    )
    # Outside the intersection region kite_area is evaluated at b = 1 and its
    # value discarded -- presumably to keep values/gradients finite there.
    b_safe = jnp.where(intersecting, b, 1)
    kite = jnp.where(intersecting, kite_area(r, b_safe, 1), 0)
    kappa0 = jnp.arctan2(kite, b_sq + r_sq_minus_one)
    kappa1 = jnp.arctan2(kite, b_sq - r_sq_minus_one)
    return kappa0, kappa1
def q_integral(l_max: int, lam: Array) -> Array:
    """Evaluate the Q integral for every ``(l, m)`` term up to degree ``l_max``.

    Terms are ordered by ``n = l * l + l + m``. The result is stacked along a
    new leading axis of length ``(l_max + 1) ** 2``.

    - ``(l, m) == (1, 0)``: ``(pi + 2 * lam) / 3``.
    - ``mu = l - m`` divisible by 4: ``H(mu / 2 + 2, nu / 2)`` with
      ``nu = l + m``.
    - All other terms are zero.

    ``H`` uses the base cases ``H(0, 0) = 2 * lam + pi`` and
    ``H(0, 1) = -2 * cos(lam)``. It then applies one of two recurrences:

    - for ``u >= 2``:
      ``H(u, v) = (2 cos^(u-1) sin^(v+1) + (u - 1) H(u - 2, v)) / (u + v)``;
    - otherwise:
      ``H(u, v) = (-2 cos^(u+1) sin^(v-1) + (v - 1) H(u, v - 2)) / (u + v)``.
    """
    cos_lam = jnp.cos(lam)
    sin_lam = jnp.sin(lam)
    zeros = jnp.zeros_like(lam)
    memo = {(0, 0): 2 * lam + jnp.pi, (0, 1): -2 * cos_lam}

    def hval(u: int, v: int) -> Array:
        key = (u, v)
        if key not in memo:
            if u >= 2:
                total = 2 * cos_lam ** (u - 1) * sin_lam ** (v + 1) + (u - 1) * hval(u - 2, v)
            else:
                # u is always even here, so u < 2 means u == 0 and v >= 2.
                assert v >= 2
                total = -2 * cos_lam ** (u + 1) * sin_lam ** (v - 1) + (v - 1) * hval(u, v - 2)
            memo[key] = total / (u + v)
        return memo[key]

    def term(ell: int, m: int) -> Array:
        if ell == 1 and m == 0:
            return (np.pi + 2 * lam) / 3
        mu = ell - m
        nu = ell + m
        if mu % 4 != 0:
            return zeros
        return hval(mu // 2 + 2, nu // 2)

    return jnp.stack([term(ell, m) for ell in range(l_max + 1) for m in range(-ell, ell + 1)])
def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Array:
    """Numerical integration of the P integral using Gauss-Legendre quadrature.

    As described in Equation D32 of Luger et al. (2019), there are 6 cases to
    consider. Empirically, the numerical integration of the first case
    (mu/2 even) is precise at very low order, so ``low_order = min(order, 20)``
    is used for that case. For the other cases, the order specified by the
    user is used; it is called ``high_order`` in the function body.

    Outside the linear limb-darkening case (i.e. (l, m) = (1, 0), or n = 2)
    the integrand is symmetric in phi. The integral is therefore evaluated
    over half the range and multiplied by 2.

    Parameters
    ----------
    order : int
        The order of the Gauss-Legendre quadrature.
    l_max : int
        The maximum degree of the spherical harmonic expansion.
    b : Array
        Impact parameter.
    r : Array
        Occultor radius.
    kappa0 : Array
        The kappa0 angle (as computed by ``kappas``).

    Returns
    -------
    Array
        The P integral for each term ``n = l * l + l + m``, with shape
        ``((l_max + 1) ** 2,)``. Terms with no P contribution are zero.
        ``l_max = 0`` is supported.
    """
    b2 = jnp.square(b)
    r2 = jnp.square(r)

    # This is a hack for when r -> 0 or b -> 0, so k2 -> inf
    factor = 4 * b * r
    k2_cond = jnp.less(factor, 10 * jnp.finfo(factor.dtype).eps)
    factor = jnp.where(k2_cond, 1, factor)
    k2 = jnp.maximum(0, (1 - r2 - b2 + 2 * b * r) / factor)

    # And for when r -> 0
    r_cond = jnp.less(r, 10 * jnp.finfo(r.dtype).eps)
    delta = (b - r) / (2 * jnp.where(r_cond, 1, r))

    # Gauss-Legendre nodes x in [-1, 1] are mapped to phi = rng * (x + 1),
    # i.e. onto [0, kappa0 / 2], so every quadrature sum is scaled by rng.
    rng = 0.25 * kappa0

    # Low-order variables (used only for the mu/2-even case).
    low_order = min(order, 20)
    low_roots, low_weights = roots_legendre(low_order)
    phi = rng * (low_roots + 1)
    low_s2 = jnp.square(jnp.sin(phi))
    low_a1 = low_s2 - jnp.square(low_s2)
    low_a2 = jnp.where(r_cond, 0, delta + low_s2)

    # High-order variables (used for every other case).
    high_order = order
    high_roots, high_weights = roots_legendre(high_order)
    phi = rng * (high_roots + 1)
    high_s2 = jnp.square(jnp.sin(phi))
    high_f0 = (
        jnp.maximum(0, jnp.where(k2_cond, 1 - r2, factor * (k2 - high_s2))) ** 1.5
    )
    high_a1 = high_s2 - jnp.square(high_s2)
    high_a2 = jnp.where(r_cond, 0, delta + high_s2)
    high_a4 = 1 - 2 * high_s2

    low_indices = []
    low_integrand = []
    high_indices = []
    high_integrand = []
    for l in range(l_max + 1):  # noqa: E741
        high_fa3 = (2 * r) ** (l - 1) * high_f0
        for m in range(-l, l + 1):
            n = l * l + l + m
            mu = l - m
            nu = l + m
            if mu == 1 and l == 1:
                # Linear limb-darkening term (n = 2). The integrand is not
                # symmetric, so the nodes span [-kappa0 / 2, kappa0 / 2]; the
                # Jacobian is 2 * rng, hence the extra factor of 2 below.
                phi = 2 * rng * high_roots
                c = jnp.cos(phi + 0.5 * kappa0)
                omz2 = r2 + b2 - 2 * b * r * c
                # Avoid dividing by a vanishing omz2; those nodes contribute 0.
                cond = jnp.less(omz2, 10 * jnp.finfo(omz2.dtype).eps)
                omz2 = jnp.where(cond, 1, omz2)
                z2 = jnp.maximum(0, 1 - omz2)
                result = (
                    2 * r * (r - b * c) * (1 - z2 * zero_safe_sqrt(z2)) / (3 * omz2)
                )
                high_integrand.append(jnp.where(cond, 0, 2 * result))
                high_indices.append(n)
            elif mu % 2 == 0 and (mu // 2) % 2 == 0:
                f = (
                    2
                    * (2 * r) ** (l + 2)
                    * low_a1 ** (0.25 * (mu + 4))
                    * low_a2 ** (0.5 * nu)
                )
                low_integrand.append(2 * f)
                low_indices.append(n)
            elif mu == 1 and l % 2 == 0:
                f = high_fa3 * high_a1 ** (l // 2 - 1) * high_a4
                high_integrand.append(2 * f)
                high_indices.append(n)
            elif mu == 1:
                f = high_fa3 * high_a1 ** ((l - 3) // 2) * high_a2 * high_a4
                high_integrand.append(2 * f)
                high_indices.append(n)
            elif (mu - 1) % 2 == 0 and ((mu - 1) // 2) % 2 == 0:
                f = (
                    2
                    * high_fa3
                    * high_a1 ** ((mu - 1) // 4)
                    * high_a2 ** (0.5 * (nu - 1))
                )
                high_integrand.append(2 * f)
                high_indices.append(n)
            # Any remaining (l, m) term has no P contribution; P[n] stays 0.

    P = jnp.zeros(l_max**2 + 2 * l_max + 1)
    # For l_max == 0 only the n = 0 (mu/2 even) term exists, so the high-order
    # group is empty. np.stack / jnp.stack reject empty sequences, so empty
    # groups are skipped instead of stacked.
    if low_integrand:
        low_P0 = rng * jnp.sum(jnp.stack(low_integrand) * low_weights, axis=1)
        P = P.at[np.asarray(low_indices)].set(low_P0)
    if high_integrand:
        high_P0 = rng * jnp.sum(jnp.stack(high_integrand) * high_weights, axis=1)
        P = P.at[np.asarray(high_indices)].set(high_P0)
    return P
def rT(lmax: int) -> Array:
    """Return the closed-form vector r^T with ``(lmax + 1) ** 2`` entries.

    Entries are indexed by ``n = l * l + l + m``. Two families are filled by
    the same recurrence:

    - ``l % 4 == 0`` with ``m % 4 == 0``;
    - ``l % 4 == 2`` with ``m % 4 == 2``.

    Each family writes the ``(l, +m)`` and ``(l, -m)`` entries, and also the
    ``(l + 1, +m)`` and ``(l + 1, -m)`` entries when ``l < lmax``. Every other
    entry is zero. For low degrees the values appear to be unit-disk integrals
    of the polynomial basis terms (e.g. ``rT(2) == [pi, 0, 2*pi/3, 0, pi/4,
    0, 0, 0, pi/4]``). The result is a NumPy float array.
    """
    rt = [0.0] * ((lmax + 1) * (lmax + 1))

    def fill(first, amp0, lfac1, lfac2, amp0_step):
        # `first` is both the starting degree and starting order of a family.
        for ell in range(first, lmax + 1, 4):
            amp = amp0
            centre = ell * ell + ell
            upper = (ell + 1) * (ell + 1) + ell + 1
            for m in range(first, ell + 1, 4):
                mu = ell - m
                nu = ell + m
                rt[centre + m] = amp * lfac1
                rt[centre - m] = amp * lfac1
                if ell < lmax:
                    rt[upper + m] = amp * lfac2
                    rt[upper - m] = amp * lfac2
                amp *= (nu + 2.0) / (mu - 2.0)
            lfac1 /= (ell / 2 + 2) * (ell / 2 + 3)
            lfac2 /= (ell / 2 + 2.5) * (ell / 2 + 3.5)
            amp0 *= amp0_step(ell)

    fill(0, jnp.pi, 1.0, 2.0 / 3.0, lambda ell: 0.0625 * (ell + 2) * (ell + 2))
    fill(2, 0.5 * jnp.pi, 0.5, 4.0 / 15.0, lambda ell: 0.0625 * ell * (ell + 4))
    return np.array(rt)