Source code for jaxoplanet.experimental.starry.multiprecision.solution

from collections import defaultdict

from jaxoplanet.experimental.starry.multiprecision import mp

# Module-level memo consulted by ``get_sT`` when no explicit cache is given.
# Layout: CACHED_MATRICES[l_max]["sT"][(b, r)] -> sT(l_max, b, r).
# Looking up an unseen l_max lazily creates a fresh, empty "sT" sub-dictionary.
# NOTE(review): entries are never evicted, so the memo grows with every distinct (b, r).
CACHED_MATRICES = defaultdict(lambda: dict(sT={}))
def get_sT(l_max, b, r, cache=None):
    """Return ``sT(l_max, b, r)``, computing it at most once per cache entry.

    Results are stored under ``cache[l_max]["sT"][(b, r)]``. When ``cache`` is
    ``None`` the module-level ``CACHED_MATRICES`` is used. A caller-supplied
    cache must yield a mapping with an ``"sT"`` dict for every ``l_max`` that is
    looked up (for example a ``defaultdict`` shaped like ``CACHED_MATRICES``).

    The key is the raw ``(b, r)`` pair. ``sT`` takes absolute values of both,
    so e.g. ``(-b, r)`` and ``(b, r)`` are cached as separate entries.
    """
    table = (CACHED_MATRICES if cache is None else cache)[l_max]["sT"]
    key = (b, r)
    try:
        return table[key]
    except KeyError:
        pass
    table[key] = solution = sT(l_max, b, r)
    return solution
def check_occultation(b, r):
    """Raise if the occultor misses the unit disk or covers it completely.

    The occultor of radius ``r`` sits at impact parameter ``b`` from the
    centre of the unit disk. Only ``abs(b)`` is used, matching ``sT``, which
    also discards the sign of ``b``.

    Args:
        b: Impact parameter; its sign is ignored.
        r: Occultor radius.

    Raises:
        ValueError: if ``abs(b) >= r + 1`` (no overlap; also raised when either
            value is NaN, since the comparison is then False), or if ``r >= 1``
            and ``abs(b) <= abs(1 - r)`` (the unit disk is fully covered).
    """
    b_abs = abs(b)
    # Written as ``not (... < ...)`` so that NaN inputs are rejected too.
    if not (b_abs < r + 1.0):
        raise ValueError(f"No occultation with b = {b} and r = {r}.")
    # Previously compared the signed ``b`` here, so every negative b with
    # r >= 1 was wrongly reported as a full occultation.
    if r >= 1 and b_abs <= abs(1 - r):
        raise ValueError(f"occultation is full with b = {b} and r = {r}.")
def kappas(b, r):
    """Return the angle pair ``(kappa0, kappa1)`` for impact parameter ``b`` and radius ``r``.

    With ``K = kite_area(r, b, 1.0)``:

        kappa0 = atan2(K, b**2 + (r**2 - 1))
        kappa1 = atan2(K, b**2 - (r**2 - 1))

    ``r**2 - 1`` is formed as ``(r - 1) * (r + 1)``.
    """
    r_sq_minus_one = (r - 1) * (r + 1)
    b_sq = b**2
    heron = kite_area(r, b, 1.0)
    kappa0 = mp.atan2(heron, b_sq + r_sq_minus_one)
    kappa1 = mp.atan2(heron, b_sq - r_sq_minus_one)
    return kappa0, kappa1
def kite_area(a, b, c):
    """Return ``sqrt((a+b+c)(-a+b+c)(a-b+c)(a+b-c))`` for side lengths ``a, b, c``.

    By Heron's formula this is 4x the area of the triangle with those sides.
    The sides are first ordered ``lo <= mid <= hi``, and the product is taken
    in a parenthesised form that avoids some cancellation. A negative product
    (from rounding or an impossible triangle) is clamped to 0 before the
    square root.
    """
    # Three compare-exchange steps order the sides ascending.
    lo, mid = min(a, b), max(a, b)
    mid, hi = min(mid, c), max(mid, c)
    lo, mid = min(lo, mid), max(lo, mid)

    f1 = lo + (mid + hi)
    f2 = hi - (lo - mid)
    f3 = hi + (lo - mid)
    f4 = lo + (mid - hi)
    return mp.sqrt(max(0, f1 * f2 * f3 * f4))
def P(l, m, b, r):
    """Compute the P integral numerically from its new parametrization.

    The integrand is selected by ``mu = l - m`` (and ``l``) and expressed in
    ``s = sin(x)``. It is integrated with ``mp.quad`` over
    ``[-kappa / 2, kappa / 2]``. Terms with ``mu % 4`` equal to 2 or 3 are
    identically zero and return ``0`` without integrating.

    Args:
        l: Degree of the term.
        m: Order of the term; ``mu = l - m`` and ``nu = l + m`` pick the
            integrand (presumably ``-l <= m <= l``, so ``mu >= 0`` — confirm
            with callers).
        b: Impact parameter. Must be non-zero: ``k2`` divides by ``4 * b * r``.
        r: Occultor radius. Must be non-zero: ``delta`` divides by ``2 * r``.

    Returns:
        The ``mp.quad`` result, or the integer ``0`` for vanishing terms.
    """
    mu = l - m
    nu = l + m

    if (abs(1 - r) < b) and (b < 1 + r):
        # The two circles intersect: the integration arc is set by kappa0.
        # (Previously kappa was shifted by -pi/2 into an unused ``phi`` and
        # back by +pi/2, which could only lose precision.)
        kappa = kappas(b, r)[0]
    else:
        # No crossing of the limbs: integrate over [-pi/2, pi/2].
        kappa = mp.pi

    delta = (b - r) / (2 * r)
    k2 = (1 - r**2 - b**2 + 2 * b * r) / (4 * b * r)

    if mu % 4 == 0:

        def func(x):
            s = mp.sin(x)
            return (
                2
                * (2 * r) ** (l + 2)
                * (s**2 - s**4) ** (0.25 * (mu + 4))
                * (delta + s**2) ** (0.5 * nu)
            )

    elif mu == 1 and l % 2 == 0:

        def func(x):
            s = mp.sin(x)
            return (
                (2 * r) ** (l - 1)
                * (4 * b * r) ** (3.0 / 2.0)
                * (s**2 - s**4) ** (0.5 * (l - 2))
                * (k2 - s**2) ** (3.0 / 2.0)
                * (1 - 2 * s**2)
            )

    elif mu == 1 and l != 1 and l % 2 != 0:

        def func(x):
            s = mp.sin(x)
            return (
                (2 * r) ** (l - 1)
                * (4 * b * r) ** (3.0 / 2.0)
                * (s**2 - s**4) ** (0.5 * (l - 3))
                * (delta + s**2)
                * (k2 - s**2) ** (3.0 / 2.0)
                * (1 - 2 * s**2)
            )

    elif mu % 4 == 1 and l != 1:
        # mu == 1 never reaches here: every mu == 1 case is handled by the
        # branches above or the (mu == 1, l == 1) branch below.

        def func(x):
            s = mp.sin(x)
            return (
                2
                * (2 * r) ** (l - 1)
                * (4 * b * r) ** (3.0 / 2.0)
                * (s**2 - s**4) ** (0.25 * (mu - 1))
                * (delta + s**2) ** (0.5 * (nu - 1))
                * (k2 - s**2) ** (3.0 / 2.0)
            )

    elif mu == 1 and l == 1:

        def func(x):
            # Here the integrand is written in the angle x + kappa / 2.
            c = mp.cos(x + 0.5 * kappa)
            omz2 = r**2 + b**2 - 2 * b * r * c
            if omz2 <= mp.mpf(0.0):
                return 0.0
            else:
                z2 = max(0, 1 - omz2)
                return 2 * r * (r - b * c) * (1 - z2 * mp.sqrt(z2)) / (3 * omz2)

    else:
        # mu % 4 in {2, 3}: the P integral vanishes.
        return 0

    res = mp.quad(func, [-kappa / 2, kappa / 2])
    return res
def p_numerical(l_max, b, r):
    """Compute the P integral numerically for every term up to degree ``l_max``.

    Terms are ordered by ``l`` ascending, then ``m`` from ``-l`` to ``l``. The
    result is an ``mp.matrix`` of length ``(l_max + 1) ** 2`` that keeps only
    the real part of each value.
    """
    values = [
        P(degree, order, b, r)
        for degree in range(l_max + 1)
        for order in range(-degree, degree + 1)
    ]
    return mp.matrix(values).apply(mp.re)
def q_numerical(l_max, b, r):
    """Compute the Q integral for every term up to degree ``l_max``.

    The angle ``lam`` is ``pi/2 - kappa1`` when the circles intersect
    (``abs(1 - r) < b < 1 + r``), and ``pi/2`` otherwise.

    Non-zero entries come from the memoised ``H(u, v)`` recursion seeded with
    ``H(0, 0) = 2*lam + pi`` and ``H(0, 1) = -2*cos(lam)``. The ``(1, 0)``
    term is the special value ``(pi + 2*lam) / 3``, and every term with
    ``(l - m) % 4 != 0`` is ``0.0``. Ordering matches ``p_numerical``: ``l``
    ascending, then ``m`` from ``-l`` to ``l``.
    """
    intersecting = (abs(1 - r) < b) and (b < 1 + r)
    lam = 0.5 * mp.pi - kappas(b, r)[1] if intersecting else mp.pi / 2

    cos_lam = mp.cos(lam)
    sin_lam = mp.sin(lam)
    memo = {
        (0, 0): 2 * lam + mp.pi,
        (0, 1): -2 * cos_lam,
    }

    def H(u: int, v: int):
        # Lower u by 2 while u >= 2, then lower v by 2, until a seed is hit.
        key = (u, v)
        if key in memo:
            return memo[key]
        if u >= 2:
            value = 2 * cos_lam ** (u - 1) * sin_lam ** (v + 1) + (u - 1) * H(u - 2, v)
        else:
            assert v >= 2
            value = -2 * cos_lam ** (u + 1) * sin_lam ** (v - 1) + (v - 1) * H(u, v - 2)
        value /= u + v
        memo[key] = value
        return value

    entries = []
    for l in range(l_max + 1):  # noqa
        for m in range(-l, l + 1):
            mu, nu = l - m, l + m
            if l == 1 and m == 0:
                entries.append((mp.pi + 2 * lam) / 3)
            elif mu % 4 == 0:
                u, v = mu // 2 + 2, nu // 2
                assert u % 2 == 0
                entries.append(H(u, v))
            else:
                entries.append(0.0)
    return mp.matrix(entries)
def sT(l_max, b, r):
    """Return the ``sT`` vector ``Q - P`` for all terms up to degree ``l_max``.

    Both ``b`` and ``r`` are replaced by their absolute values before the Q
    and P integrals are evaluated (in that order).
    """
    b, r = abs(b), abs(r)
    q_vec = q_numerical(l_max, b, r)
    p_vec = p_numerical(l_max, b, r)
    return q_vec - p_vec
def _rT_fill(rt, lmax, ell_start, amp0, lfac1, lfac2, amp0_step):
    """Fill, in place, the non-zero ``rt`` entries for ``ell = ell_start, ell_start + 4, ...``.

    For each such ``ell``, orders ``m = ell_start, ell_start + 4, ..., <= ell``
    are written at ``(ell, +/-m)``. When ``ell < lmax`` their companions at
    ``(ell + 1, +/-(m + 1))`` are written too. ``amp0_step(ell)`` is the factor
    applied to ``amp0`` after each ``ell``.
    """
    for ell in range(ell_start, lmax + 1, 4):
        amp = amp0
        row = ell * ell + ell  # flat index of (ell, m = 0)
        next_row = (ell + 1) * (ell + 1) + ell + 1  # flat index of (ell + 1, m = 0)
        for m in range(ell_start, ell + 1, 4):
            mu = ell - m
            nu = ell + m
            rt[row + m] = amp * lfac1
            rt[row - m] = amp * lfac1
            if ell < lmax:
                rt[next_row + m] = amp * lfac2
                rt[next_row - m] = amp * lfac2
            # Exact rational ratio evaluated at working precision. mu is a
            # multiple of 4 here, so mu - 2 is never zero.
            amp *= mp.mpf(nu + 2) / (mu - 2)
        half = ell // 2  # ell is even, so this is exact
        lfac1 /= (half + 2) * (half + 3)
        # (ell/2 + 2.5) * (ell/2 + 3.5) == (ell + 5) * (ell + 7) / 4
        lfac2 = lfac2 * 4 / ((ell + 5) * (ell + 7))
        amp0 *= amp0_step(ell)


def rT(lmax: int):
    """Build the ``rT`` vector for degree ``lmax`` as an ``mp.matrix``.

    The vector has ``(lmax + 1) ** 2`` entries, indexed by ``ell * ell + ell + m``.
    Non-zero entries are set for ``ell % 4 == 0`` (with ``m % 4 == 0``) and
    ``ell % 4 == 2`` (with ``m % 4 == 2``), plus their ``(ell + 1, m + 1)``
    companions. Everything else is zero.
    NOTE(review): this mirrors starry's closed-form rT recurrence — confirm
    against upstream if it is changed.

    All recurrence factors are kept as ``mp.mpf``. The previous version seeded
    them with Python floats (``1.0``, ``2.0 / 3.0``, ``4.0 / 15.0``) and formed
    ``(nu + 2.0) / (mu - 2.0)`` in double precision, which capped this
    multiprecision reference at roughly 1e-16 relative accuracy.
    """
    rt = [0.0] * ((lmax + 1) * (lmax + 1))
    # ell = 0, 4, 8, ...
    _rT_fill(
        rt,
        lmax,
        0,
        mp.pi,
        mp.mpf(1),
        mp.mpf(2) / 3,
        lambda ell: mp.mpf((ell + 2) * (ell + 2)) / 16,
    )
    # ell = 2, 6, 10, ...
    _rT_fill(
        rt,
        lmax,
        2,
        mp.pi / 2,
        mp.mpf(1) / 2,
        mp.mpf(4) / 15,
        lambda ell: mp.mpf(ell * (ell + 4)) / 16,
    )
    return mp.matrix(rt)