from collections.abc import Iterable
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation
from jaxoplanet import starry
from jaxoplanet.starry.core.basis import A1, U, poly_basis
from jaxoplanet.starry.core.polynomials import Pijk
from jaxoplanet.starry.core.rotation import full_rotation_axis_angle, left_project
from jaxoplanet.starry.utils import ortho_grid
from jaxoplanet.starry.ylm import Ylm
from jaxoplanet.types import Array, Scalar
class Surface(eqx.Module):
    """Surface map object.

    Args:
        y (Optional(:py:class:`~jaxoplanet.starry.ylm.Ylm`)): Ylm object containing
            the spherical harmonic expansion of the map. Defaults to a uniform map
            with amplitude 1.0.
        inc (Optional[Scalar]): inclination of the map relative to line of sight.
            Defaults to pi/2 [angular unit].
        obl (Optional[Scalar]): obliquity of the map [angular unit]. Defaults to None.
        u (Optional[Array]): polynomial limb-darkening coefficients of the map.
            Defaults to ``()`` (no limb darkening).
        period (Optional[Scalar]): rotation period of the map [time unit]. Defaults to
            None.
        amplitude (Optional[float]): amplitude of the map; this Scalar is proportional
            to the luminosity of the map and multiplies all flux-related observables.
            Defaults to 1.0. Overridden by ``y[(0, 0)]`` when ``normalize`` is True.
        normalize (Optional(bool)): whether to normalize the coefficients of the
            spherical harmonics. If True, Ylm is normalized and the amplitude of the map
            is set to y[(0, 0)]. Defaults to True.
        phase (Optional[float]): initial phase of the map rotation around the polar
            axis. Defaults to 0.0.
        radius (Optional[float]): radius of the map [solar radii]. Defaults to 1.0.
        shear (Optional[float]): differential rotation shear of the map. Defaults to
            None (solid-body rotation).

    Example:

        .. code-block:: python

            import numpy as np
            import jax
            from jaxoplanet.starry.visualization import show_surface
            from jaxoplanet.starry.surface import Surface
            from jaxoplanet.starry.ylm import Ylm

            jax.config.update("jax_enable_x64", True)

            np.random.seed(30)
            y = Ylm.from_dense(np.random.rand(20))
            m = Surface(y=y, u=[0.5, 0.1], inc=0.9, obl=-0.3)
            show_surface(m)
    """

    y: Ylm
    """:py:class:`~starry.ylm.Ylm` object representing the spherical harmonic expansion
    of the map"""

    _inc: Array | None
    """Inclination of the map in radians. None if seen from the pole."""

    _obl: Array | None
    """Obliquity of the map in radians. None if no obliquity."""

    u: tuple
    """Tuple of limb darkening coefficients."""

    period: Array | None
    """Rotation period of the map in days (attribute subject to change). None if not
    rotating."""

    amplitude: Array
    """Amplitude of the map, a Scalar proportional to map luminosity."""

    normalize: bool
    """Boolean to specify whether the Ylm coefficients should be normalized"""

    phase: Array
    """Initial phase of the map rotation around polar axis"""

    radius: Array
    """Radius of the map in solar radii"""

    shear: Array | None
    """Differential rotation shear of the map"""

    def __init__(
        self,
        *,
        y: Ylm | None = None,
        inc: Scalar | None = 0.5 * jnp.pi,
        obl: Scalar | None = None,
        u: Iterable[Array] = (),
        period: Scalar | None = None,
        amplitude: Array = 1.0,
        normalize: bool = True,
        phase: Array = 0.0,
        radius: Array = 1.0,
        shear: Array | None = None,
    ):
        if y is None:
            y = Ylm()

        if normalize:
            # The (0, 0) coefficient becomes the map amplitude; the Ylm itself is
            # rescaled so that the amplitude is carried only by `self.amplitude`.
            amplitude = jnp.array(y[(0, 0)], float)
            y = Ylm(data=y.data).normalize()

        self.y = y
        self._inc = inc
        self._obl = obl
        self.u = tuple(u)
        self.period = period
        self.amplitude = amplitude
        self.normalize = normalize
        self.phase = phase
        self.radius = radius
        self.shear = shear

    @property
    def inc(self):
        """Inclination of the map; 0.0 when no inclination was set."""
        return self._inc if self._inc is not None else 0.0

    @inc.setter
    def inc(self, value):
        self._inc = value

    @property
    def obl(self):
        """Obliquity of the map; 0.0 when no obliquity was set."""
        return self._obl if self._obl is not None else 0.0

    @obl.setter
    def obl(self, value):
        self._obl = value

    @property
    def veq(self):
        """Equatorial velocity of the map in Rsun/day.

        Requires ``period`` to be set (``radius`` in solar radii, ``period`` in days).
        """
        return 2 * jnp.pi * self.radius / self.period

    def _poly_basis(self, rv=False):
        """Return a jitted polynomial basis of degree ``deg`` (+ ``vdeg`` if ``rv``)."""
        if rv:
            return jax.jit(poly_basis(self.deg + self.vdeg))
        else:
            return jax.jit(poly_basis(self.deg))

    @property
    def udeg(self) -> int:
        """Order of the polynomial limb darkening."""
        return len(self.u)

    @property
    def ydeg(self) -> int:
        """Degree of the spherical harmonic expansion ``y``."""
        return self.y.deg

    @property
    def vdeg(self) -> int:
        """Degree of the velocity-field expansion used for radial-velocity maps.

        Solid-body rotation gives a velocity field linear in the coordinates
        (degree 1). Following starry's convention, differential rotation (``shear``
        set) raises the velocity field to degree 3.
        """
        return 1 if self.shear is None else 3

    @property
    def deg(self):
        """Total degree of the spherical harmonic expansion (``udeg + ydeg``)."""
        return self.ydeg + self.udeg

    def _intensity(self, x, y, z, theta=None, rv=False):
        """Evaluate the intensity (optionally velocity-weighted) at Cartesian points.

        Args:
            x, y, z: Cartesian coordinates of the evaluation points.
            theta: rotation angle of the map, forwarded to ``left_project``.
            rv: if True, multiply the intensity by the velocity-field polynomial.

        Returns:
            Intensity at each point, scaled by ``amplitude``.
        """
        pT = self._poly_basis(rv)(x, y, z)

        # Apply the inc/obl/theta projection to the Ylm coefficients, then convert
        # them to the polynomial basis.
        Ry = left_project(self.ydeg, self.inc, self.obl, theta, 0.0, self.y.todense())
        A1Ry = A1(self.ydeg).todense() @ Ry
        p_y = Pijk.from_dense(A1Ry, degree=self.ydeg)

        # Limb-darkening polynomial; the leading 1 is the constant term.
        u = jnp.array([1, *self.u])
        p_u = Pijk.from_dense(u @ U(self.udeg), degree=self.udeg)

        p = p_y * p_u

        if rv:
            # The shear (if any) is forwarded as the differential-rotation parameter.
            y_rv = starry.doppler.rv_map_expansion(
                inc=self._inc, obl=self._obl, veq=self.veq, alpha=self.shear
            )
            p_rv = Pijk.from_dense(A1(self.vdeg).todense() @ y_rv, degree=self.vdeg)
            p = p * p_rv

        return pT @ p.todense() * self.amplitude

    def render(self, theta: float | None = None, res: int = 400, rv: bool = False):
        """Returns the intensity map projected onto the x-y plane (sky).

        Args:
            theta (float, optional): rotation angle of the map. Defaults to 0.0.
            res (int, optional): resolution of the render. Defaults to 400.
            rv (bool, optional): if True, render the velocity-weighted intensity.
                Defaults to False.

        Returns:
            ArrayLike: square 2D array representing the intensity map
            (with nans outside the map disk).
        """
        _, xyz = ortho_grid(res)
        intensity = self._intensity(*xyz, theta=theta, rv=rv)
        return jnp.reshape(intensity, (res, res))

    def intensity(self, lat: float, lon: float):
        """Returns the intensity of the map at a given latitude and longitude.

        Args:
            lat (float): latitude in the rest frame of the map
            lon (float): longitude in the rest frame of the map

        Returns:
            float: intensity of the map at the given latitude and longitude
        """
        lon = lon + jnp.pi / 2  # convention, 0 lon faces the observer
        lat = jnp.pi / 2 - lat  # convention, latitude 0 is equator

        x = jnp.sin(lat) * jnp.cos(lon)
        y = jnp.sin(lat) * jnp.sin(lon)
        z = jnp.cos(lat) * jnp.ones_like(x)

        # Rotate the rest-frame point by the map's inclination and obliquity.
        axis = full_rotation_axis_angle(self.inc - jnp.pi / 2, self.obl, 0.0, 0.0)
        axis = jnp.array(axis[0:3]) * axis[-1]
        rotation = Rotation.from_rotvec(axis)
        x, y, z = rotation.apply(jnp.array([x, y, z]).T).T

        return self._intensity(x, y, z)

    def rotational_phase(self, time: Array) -> Array | None:
        """Returns the rotational phase of the map at a given time.

        Args:
            time (ArrayLike): time in same units as the period

        Returns:
            ArrayLike: rotational phase of the map at the given time, or None if
            the map has no rotation period.
        """
        if self.period is None:
            return None
        else:
            return 2 * jnp.pi * time / self.period + self.phase