Source code for jaxoplanet.starry.multiprecision.flux

from collections import defaultdict

from jaxoplanet.starry.core.basis import U
from jaxoplanet.starry.core.polynomials import Pijk
from jaxoplanet.starry.multiprecision import mp, utils
from jaxoplanet.starry.multiprecision.basis import A1, A2
from jaxoplanet.starry.multiprecision.rotation import (
    dot_rotation_matrix,
    dot_rz,
    get_R,
)
from jaxoplanet.starry.multiprecision.solution import get_sT, rT

# Module-level default cache, used by `flux` whenever the caller does not
# supply its own `cache`.  Looking up an unseen outer key lazily creates a
# fresh entry holding one independent, empty sub-cache per matrix family.
# NOTE(review): the meaning of the outer key is decided by the helpers that
# receive this cache (not visible here) -- confirm before relying on it.
CACHED_MATRICES = defaultdict(
    lambda: {family: {} for family in ("R_obl", "R_inc", "sT")}
)
def flux(ydeg=0, udeg=0, inc=mp.pi / 2, obl=0.0, cache=None):
    """Build a multiprecision flux function for fixed map degrees.

    All matrices that depend only on the degrees and on the orientation
    (``rT``, ``A1``, ``A2`` and the rotation matrices) are evaluated once
    here; the returned closure reuses them on every call.

    Args:
        ydeg: Degree of the map coefficients ``y``; the returned function
            expects ``(ydeg + 1) ** 2`` of them when ``ydeg > 0``.
        udeg: Number of ``u`` coefficients the returned function expects
            (limb-darkening coefficients, judging by the names used here).
        inc: Forwarded to ``get_R`` for the ``"R_obl"`` and ``"R_inc"``
            matrices (presumably the map inclination -- TODO confirm).
        obl: Forwarded to ``get_R`` alongside ``inc`` (presumably the map
            obliquity -- TODO confirm).
        cache: Cache handed to ``A1``, ``A2``, ``get_R`` and ``get_sT``.
            Defaults to the module-level ``CACHED_MATRICES``.

    Returns:
        A function ``impl(b, r, y=(), phi=0.0, u=(), rotate=True)``
        returning a single flux value.
    """
    if cache is None:
        cache = CACHED_MATRICES

    deg = ydeg + udeg
    _rT = rT(deg)
    _A2 = A2(deg, cache=cache)

    if ydeg > 0:
        _A1 = A1(ydeg, cache=cache)
        R_px = get_R("R_px", ydeg, cache=cache)
        R_mx = get_R("R_mx", ydeg, cache=cache)
        R_obl = get_R("R_obl", ydeg, obl=obl, inc=inc, cache=cache)
        R_inc = get_R("R_inc", ydeg, obl=obl, inc=inc, cache=cache)
    else:
        # No rotation matrices are needed (or created) for a degree-0 map.
        _A1 = A1(0, cache=cache)

    if udeg > 0:
        rT_udeg = rT(udeg)

    def rotate_y(y, phi):
        """Apply R_px, a z-rotation by ``phi``, R_mx, R_obl and R_inc to ``y``.

        Only valid when ``ydeg > 0``: the rotation matrices do not exist
        otherwise.
        """
        y_rotated = dot_rotation_matrix(ydeg, None, None, R_px)(y)
        y_rotated = dot_rz(ydeg, phi)(y_rotated)
        y_rotated = dot_rotation_matrix(ydeg, None, None, R_mx)(y_rotated)
        y_rotated = dot_rotation_matrix(ydeg, None, None, R_obl)(y_rotated)
        y_rotated = dot_rotation_matrix(ydeg, None, None, R_inc)(y_rotated)
        return y_rotated

    def _oriented_y(y, phi, rotate, theta_z=None):
        """Return the coefficient column vector fed to ``A1``.

        A degree-0 map is the constant ``[1]`` and ``y`` is ignored.  With
        ``rotate=False`` the coefficients are used exactly as given.
        Otherwise ``rotate_y`` is applied, followed by an extra z-rotation
        by ``theta_z`` when one is supplied.
        """
        if ydeg == 0:
            return mp.matrix([1])
        y = mp.matrix(y.tolist())
        if not rotate:
            return y
        y_rotated = rotate_y(y, phi)
        if theta_z is not None:
            y_rotated = dot_rz(ydeg, theta_z)(y_rotated)
        return y_rotated

    def _limb_darkened(py, u):
        """Multiply the map polynomial ``py`` by the ``u`` polynomial.

        Returns ``(norm, row)`` where ``row`` has ``(degree + 1) ** 2``
        entries for the product polynomial's degree.  With ``udeg == 0`` the
        polynomial is returned untouched and ``norm`` is 1.
        """
        if udeg == 0:
            return 1, py
        _U = utils.to_mp(U(udeg))
        u = mp.matrix([1, *u])
        pu = u.T @ _U
        norm = mp.pi / (pu @ rT_udeg)[0]
        product = Pijk.from_dense(pu.tolist()[0]) * Pijk.from_dense(py)
        # Densify: coefficients absent from the sparse product are zero.
        row = mp.matrix(
            [
                product.data.get(Pijk.n2ijk(i), 0.0)
                for i in range((product.degree + 1) ** 2)
            ]
        ).T
        return norm, row

    def rot_flux(y, phi, u=None, rotate=True):
        """Flux without an occultor: the map polynomial dotted with ``rT``.

        Uses the same frame handling and limb-darkening product as
        ``occ_flux`` so both branches agree for every ``ydeg``/``udeg``.
        """
        py = (_A1 @ _oriented_y(y, phi, rotate)).T
        norm, py = _limb_darkened(py, u)
        return (norm * py @ _rT)[0]

    def occ_flux(y, b, r, phi=0.0, u=None, rotate=True):
        """Flux during an occultation; returns a one-element matrix."""
        xo = 0.0
        yo = b
        theta_z = mp.atan2(xo, yo)
        _sT = get_sT(deg, b, r, cache=cache)
        x = _sT.T @ _A2

        py = (_A1 @ _oriented_y(y, phi, rotate, theta_z)).T
        norm, py = _limb_darkened(py, u)
        return norm * py @ x.T

    def impl(b, r, y=(), phi=0.0, u=(), rotate=True):
        """Evaluate the flux for one configuration.

        Args:
            b: Compared against ``r + 1`` and ``r - 1`` to pick the branch
                (presumably the impact parameter in units of the map
                radius -- TODO confirm).
            r: Occultor size on the same scale as ``b``.
            y: Map coefficients, length ``(ydeg + 1) ** 2``; must provide
                ``tolist()``.  Ignored when ``ydeg == 0``.
            phi: Angle of the z-rotation applied by ``rotate_y``.
            u: ``udeg`` coefficients of the ``u`` polynomial.
            rotate: When false, ``y`` is used as given, with no rotation
                applied in either branch.

        Returns:
            The unocculted flux when ``abs(b) >= r + 1``, ``mp.mpf(0.0)``
            when ``r > 1`` and ``abs(b) <= r - 1``, and the occulted flux
            otherwise.

        Raises:
            AssertionError: If ``u`` or ``y`` has the wrong length.
        """
        # Kept as assertions so callers that expect AssertionError still
        # see it; note these checks vanish under ``python -O``.
        assert len(u) == udeg, "u must have length udeg"
        if ydeg > 0:
            assert len(y) == (ydeg + 1) ** 2, "y must have length (ydeg + 1) ** 2"
        if abs(b) >= (r + 1):
            return rot_flux(y, phi, u=u, rotate=rotate)
        elif r > 1 and abs(b) <= (r - 1):
            return mp.mpf(0.0)
        else:
            return occ_flux(y, b, r, phi, u=u, rotate=rotate)[0]

    return impl