Source code for jaxoplanet.starry.ylm

r"""A module to manipulate vectors in the spherical harmonic basis.

The spherical harmonics basis is a set of orthogonal functions defined on the
unit sphere. In jaxoplanet, this basis is used to represent the intensity at the surface
of a spherical body, such as a star or a planet. We say that :math:`y` represents the
intensity of a surface in the spherical harmonics basis if the specific intensity at the
:math:`(x,y)` on the surface can be written as:

.. math::

    I(x, y) = \mathbf{\tilde{y}_n^\mathsf{T}} (x, y) \, \mathbf{y}
    \quad,

where :math:`\tilde{y}_n` is the **spherical harmonic basis**,
arranged in increasing degree and order:

.. math::

    \mathbf{\tilde{y}_n} =
    \begin{pmatrix}
        Y_{0, 0} &
        Y_{1, -1} & Y_{1, 0} & Y_{1, 1} &
        Y_{2, -2} & Y_{2, -1} & Y_{2, 0} & Y_{2, 1} & Y_{2, 2} &
        \cdot\cdot\cdot
    \end{pmatrix}^\mathsf{T}
    \quad,

where :math:`Y_{l, m} = Y_{l, m}(x, y)` is the spherical harmonic of degree :math:`l`
and order :math:`m`. For reference, in this basis the coefficient of the spherical
harmonic :math:`Y_{l, m}` is located at the index

.. math::

    n = l^2 + l + m

"""

import math
from collections import defaultdict
from collections.abc import Mapping
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.sparse import BCOO
from scipy.special import legendre as LegendreP

from jaxoplanet.starry.core import basis, solution
from jaxoplanet.starry.core.polynomials import Pijk
from jaxoplanet.starry.core.rotation import dot_rotation_matrix
from jaxoplanet.starry.core.wigner3j import Wigner3jCalculator
from jaxoplanet.types import Array


[docs] class Ylm(eqx.Module): """Ylm object containing the spherical harmonic coefficients. Args: data (Mapping[tuple[int, int], Array], optional): dictionary of spherical harmonic coefficients. Defaults to {(0, 0): 1.0}. """
[docs] data: dict[tuple[int, int], Array]
"""coefficients of the spherical harmonic expansion of the map in the form `{(l, m): coefficient}`"""
[docs] deg: int = eqx.field(static=True)
"""The maximum degree of the spherical harmonic coefficients."""
[docs] diagonal: bool = eqx.field(static=True)
"""Whether are orders m of the spherical harmonic coefficients are zero. Diagonal if only the degrees "l" are non-zero.""" def __init__( self, data: Mapping[tuple[int, int], Array] | None = None, ): if data is None: data = {(0, 0): 1.0} self.data = dict(data) self.deg = max(ell for ell, _ in data.keys()) self.diagonal = all(m == 0 for _, m in data.keys()) @property
[docs] def shape(self) -> tuple[int, ...]: """The number of coefficients in the basis. This sets the shape of the output of `todense`.""" return (self.deg**2 + 2 * self.deg + 1,)
@property
[docs] def indices(self) -> list[tuple[int, int]]: """List of (l,m) indices of the spherical harmonic coefficients.""" return list(self.data.keys())
@staticmethod
[docs] def index(l: Array, m: Array) -> Array: """Convert the degree and order of the spherical harmonic to the corresponding index in the coefficient array.""" return l * (l + 1) + m
[docs] def normalize(self) -> "Ylm": """Return a new Ylm instance with coefficients normalized to :math:`Y_{0,0}`. Returns: Ylm instance with normalized coefficients. Raises: ValueError: if the (0, 0) coefficient is zero. """ data = {k: v / self.data[(0, 0)] for k, v in self.data.items()} return Ylm(data=data)
[docs] def tosparse(self) -> BCOO: """Return a sparse (jax.experimental.sparse.BCOO) spherical harmonic coefficients vector where the spherical harmonic :math:`Y_{l, m}` is located at the index :math:`n = l^2 + l + m`. """ indices, values = zip(*self.data.items(), strict=False) idx = jnp.array([Ylm.index(l, m) for l, m in indices])[:, None] return BCOO((jnp.asarray(values), idx), shape=self.shape)
[docs] def todense(self) -> Array: """Return a dense spherical harmonic coefficients vector where the spherical harmonic :math:`Y_{l, m}` is located at the index :math:`n = l^2 + l + m`. """ return self.tosparse().todense()
@classmethod
[docs] def from_dense(cls, y: Array, normalize: bool = True) -> "Ylm": """Create a Ylm object from a dense array of spherical harmonic coefficients where the spherical harmonic :math:`Y_{l, m}` is located at the index :math:`n = l^2 + l + m`. """ data = {} for i, ylm in enumerate(y): l = int(np.floor(np.sqrt(i))) m = i - l * (l + 1) data[(l, m)] = ylm ylm = cls(data) if normalize: return ylm.normalize() else: return ylm
[docs] def __mul__(self, other: Any) -> "Ylm": if isinstance(other, Ylm): return _mul(self, other) else: return jax.tree_util.tree_map(lambda x: x * other, self)
[docs] def __rmul__(self, other: Any) -> "Ylm": assert not isinstance(other, Ylm) return jax.tree_util.tree_map(lambda x: other * x, self)
[docs] def __getitem__(self, key) -> Array: assert isinstance(key, tuple) return self.todense()[Ylm.index(*key)]
@classmethod
[docs] def from_limb_darkening(cls, u: Array) -> "Ylm": """ Spherical harmonics coefficients from limb darkening coefficients. """ deg = len(u) _u = np.array([1, *u]) pu = _u @ basis.U(deg) yu = np.array(np.linalg.inv(basis.A1(deg).todense()) @ pu) yu = Ylm.from_dense(yu.flatten(), normalize=False) norm = 1 / (Pijk.from_dense(pu, degree=deg).tosparse() @ solution.rT(deg)) return yu * norm
def _mul(f: Ylm, g: Ylm) -> Ylm: """ Based closely on the implementation from the MIT-licensed spherical package: https://github.com/moble/spherical/blob/0aa81c309cac70b90f8dfb743ce35d2cc9ae6dee/ spherical/multiplication.py """ ellmax_f = f.deg ellmax_g = g.deg ellmax_fg = ellmax_f + ellmax_g fg = defaultdict(lambda *_: 0.0) m_calculator = Wigner3jCalculator(ellmax_f, ellmax_g) for ell1 in range(ellmax_f + 1): sqrt1 = math.sqrt((2 * ell1 + 1) / (4 * math.pi)) for m1 in range(-ell1, ell1 + 1): idx1 = (ell1, m1) if idx1 not in f.data: continue sum1 = sqrt1 * f.data[idx1] for ell2 in range(ellmax_g + 1): sqrt2 = math.sqrt(2 * ell2 + 1) # w3j_s = s_calculator.calculate(ell1, ell2, s_f, s_g) for m2 in range(-ell2, ell2 + 1): idx2 = (ell2, m2) if idx2 not in g.data: continue w3j_m = m_calculator.calculate(ell1, ell2, m1, m2) sum2 = sqrt2 * g.data[idx2] m3 = m1 + m2 for ell3 in range( max(abs(m3), abs(ell1 - ell2)), min(ell1 + ell2, ellmax_fg) + 1 ): # Could loop over same (ell3, m3) more than once, so add all # contributions together fg[(ell3, m3)] += ( ( math.pow(-1, ell1 + ell2 + ell3 + m3) * math.sqrt(2 * ell3 + 1) * w3j_m[ell3] # Wigner3j(ell1, ell2, ell3, m1, m2, -m3) ) * sum1 * sum2 ) return Ylm(fg)
[docs] def Bp(ydeg, npts: int = 1000, eps: float = 1e-9, smoothing=None): """ Return the matrix B+. This expands the spot profile `b` in Legendre polynomials. From https://github.com/rodluger/ mapping_stellar_surfaces/blob/paper2-arxiv/paper2/figures/spot_profile.py and _spot_setup in starry/_core/core.py. """ if smoothing is None: if ydeg < 4: smoothing = 0.5 else: smoothing = 2.0 / ydeg theta = jnp.linspace(0, jnp.pi, npts) cost = jnp.cos(theta) B = jnp.hstack( [ jnp.sqrt(2 * l + 1) * LegendreP(l)(cost).reshape(-1, 1) for l in range(ydeg + 1) ] ) _Bp = jnp.linalg.solve(B.T @ B + eps * jnp.eye(ydeg + 1), B.T) l = jnp.arange(ydeg + 1) indices = l * (l + 1) S = jnp.exp(-0.5 * indices * smoothing**2) return (S[:, None] * _Bp, theta, indices)
[docs] def spot_profile(theta, radius, spot_fac=300): """ The sigmoid spot profile. """ z = spot_fac * (theta - radius) return 1 / (1 + jnp.exp(-z)) - 1
[docs] def ylm_spot(ydeg: int, npts=300, spot_fac=300) -> callable: """spot expansion in the spherical harmonics basis. Args: ydeg (int): max degree of the spherical harmonics Returns: callable: function that returns the spherical harmonics coefficients of the spot """ B, theta, indices = Bp(ydeg, npts=npts) def func(contrast: float, r: float, lat: float = 0.0, lon: float = 0.0): """spot expansion in the spherical harmonics basis. Args: contrast (float): spot contrast, defined as (1-c) where c is the intensity of the center of the spot relative to the unspotted surface. A contrast of 1. means that the spot intensity drops to zero at the center, 0. means that the intensity at the center of the spot is the same as the intensity of the unspotted surface. r (float): radius of the spot. lat (float, optional): latitude of the spot, assuming that the center of a star with an inclination of pi/2 has a latitude of 0. Defaults to 0.0. lon (float, optional): longitude of the spot, assuming that the center of a star with an inclination of pi/2 has a longitude of 0. Defaults to 0.0. Returns: Ylm: Ylm object containing the spherical harmonics coefficients of the spot """ b = spot_profile(theta, r, spot_fac=spot_fac) y = jnp.zeros((ydeg + 1) * (ydeg + 1)) y = y.at[indices].set(B @ b * contrast) y = y.at[0].set(y[0] + 1.0) y = dot_rotation_matrix(ydeg, 1.0, 0.0, 0.0, lat)(y) y = dot_rotation_matrix(ydeg, 0.0, 1.0, 0.0, -lon)(y) return Ylm.from_dense(y, normalize=False) return func
# TODO
[docs] def ring_y(l_max, pts=1000, eps=1e-9, smoothing=None): Bp, theta, idxs = None, None, None n_max = l_max**2 + 2 * l_max + 1 def _y(contrast: float, width: float, latitude: float): b = 1 - jnp.array((theta > latitude - width) & (theta < latitude + width)) y = jnp.zeros(n_max) y = y.at[idxs].set(Bp @ b) return y * contrast return _y