Source code for jaxoplanet.starry.utils

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial.transform import Rotation


@partial(jax.jit, static_argnums=0)
def ortho_grid(res: int):
    """Build a ``res`` x ``res`` grid over the square [-1, 1] x [-1, 1].

    The square is sampled with ``res`` evenly spaced points per side and
    ``z = sqrt(1 - x**2 - y**2)`` is evaluated at every point. Where that
    square root is NaN, ``y``, ``lat`` and ``lon`` are NaN as well (``x`` is
    left untouched).

    Args:
        res: number of grid points along each side (static under ``jit``).

    Returns:
        A pair ``((lat, lon), (x, y, z))`` in which every array has shape
        ``(1, res * res)``.
    """
    ticks = jnp.linspace(-1, 1, res)
    x, y = jnp.meshgrid(ticks, ticks)
    z = jnp.sqrt(1.0 - x**2 - y**2)
    # Adding 0 * z leaves finite values of y unchanged but turns y into NaN
    # wherever z is NaN.
    y = y + 0.0 * z

    # Flatten each coordinate map into a single row.
    x, y, z = (jnp.ravel(coord)[None, :] for coord in (x, y, z))

    lat = 0.5 * jnp.pi - jnp.arccos(y)
    lon = jnp.arctan2(x, z)
    return (lat, lon), (x, y, z)
def lon_lat_lines(n: int = 6, pts: int = 100, radius: float = 1.0):
    """Compute the 3D coordinates of latitude and longitude lines on a sphere.

    Args:
        n: either a single integer, in which case ``n`` latitude intervals and
            ``2 * n`` longitude lines are used, or a pair ``(n_lat, n_lon)``.
        pts: number of points per latitude circle; each longitude half-circle
            uses ``pts // 2`` points.
        radius: radius of the sphere.

    Returns:
        A pair ``(lat, lon)`` of arrays. ``lat`` has shape
        ``(n_lat + 1, 3, pts)`` and ``lon`` has shape
        ``(n_lon, 3, pts // 2)``; the middle axis holds the x, y, z
        coordinates.
    """
    assert isinstance(n, int) or len(n) == 2
    n_lat, n_lon = (n, 2 * n) if isinstance(n, int) else n

    # Latitude lines: one full circle at each colatitude between the poles.
    azimuth = np.linspace(0, 2 * np.pi, pts)
    colatitude = np.linspace(0, np.pi, n_lat + 1)
    heights = radius * np.cos(colatitude)
    ring_radii = radius * np.sin(colatitude)
    circles = []
    for height, ring in zip(heights, ring_radii):
        circles.append(
            (
                ring * np.cos(azimuth),
                ring * np.sin(azimuth),
                np.ones_like(azimuth) * height,
            )
        )
    lat = np.array(circles)

    # Longitude lines: one pole-to-pole half circle per azimuth; the final
    # azimuth (2 * pi) is dropped because it duplicates the first.
    polar = np.linspace(0, np.pi, pts // 2)
    meridian_angles = np.linspace(0, 2 * np.pi, n_lon + 1)[0:-1]
    planar = radius * np.sin(polar)
    vertical = radius * np.cos(polar)
    lon = np.array(
        [
            (planar * np.cos(angle), planar * np.sin(angle), vertical)
            for angle in meridian_angles
        ]
    )

    return lat, lon
def rotation(inc, obl, theta):
    """Build the rotation for a given inclination, obliquity and phase.

    Four rotations are composed, left to right: a rotation by
    ``-(inc - pi / 2)`` about the axis ``(cos(obl), sin(obl), 0)``, a rotation
    by ``obl`` about z, a rotation by ``pi / 2`` about x, and a rotation by
    ``-theta`` about z.

    Args:
        inc: inclination angle, in radians.
        obl: obliquity angle, in radians.
        theta: rotation phase, in radians.

    Returns:
        The composed ``scipy.spatial.transform.Rotation``.
    """
    obl = np.array(obl)

    # Rotation vector for the inclination: unit axis scaled by the angle.
    tilt_axis = np.array([np.cos(obl), np.sin(obl), 0])
    tilt_axis = tilt_axis / np.linalg.norm(tilt_axis)
    tilt_vec = tilt_axis * -(inc - np.pi / 2)

    return (
        Rotation.from_rotvec(tilt_vec)
        * Rotation.from_rotvec([0, 0, obl])
        * Rotation.from_rotvec([np.pi / 2, 0, 0])
        * Rotation.from_rotvec([0, 0, -theta])
    )
def rotate_lines(lines, inc, obl, theta):
    """Rotate a set of 3D lines by the rotation built from the given angles.

    Args:
        lines: iterable of arrays; each is transposed before being passed to
            ``Rotation.apply`` (presumably each has shape ``(3, pts)``, as
            produced by ``lon_lat_lines`` -- confirm against callers).
        inc: inclination angle, forwarded to ``rotation``.
        obl: obliquity angle, forwarded to ``rotation``.
        theta: rotation phase, forwarded to ``rotation``.

    Returns:
        Array of the rotated lines with the last axis swapped with axis 1,
        so that the coordinate axis comes before the point axis again.
    """
    inc, obl, theta = (np.array(angle) for angle in (inc, obl, theta))
    R = rotation(inc, obl, theta)

    # Rotation.apply expects one vector per row, hence the transpose.
    rotated = np.array([R.apply(line.T) for line in lines])
    return np.swapaxes(rotated, -1, 1)
def plot_lines(lines, axis=(0, 1), ax=None, **kwargs):
    """Plot the visible part of a set of 3D lines projected onto a plane.

    Points whose coordinate along the remaining (non-plotted) axis is
    negative are treated as hidden and replaced by NaN before plotting. The
    input array is not modified.

    Args:
        lines: array indexed as ``[line, coordinate, point]``.
        axis: the two coordinate indices (among 0, 1, 2) to plot against each
            other.
        ax: object with a matplotlib-like ``plot`` method. When ``None``, the
            current matplotlib axes (``plt.gca()``) is used.
        **kwargs: forwarded to ``ax.plot`` for every line.
    """
    if ax is None:
        # Imported lazily so matplotlib is only required when no axes is
        # supplied by the caller.
        import matplotlib.pyplot as plt

        # plt.gca() always returns an axes (creating one if needed), so no
        # further fallback is necessary.
        ax = plt.gca()

    # hide lines behind
    other_axis = list(set(axis).symmetric_difference([0, 1, 2]))[0]
    behind = lines[:, other_axis, :] < 0
    # Put the coordinate axis last so the (line, point) mask can blank all
    # three coordinates of each hidden point at once.
    visible = lines.copy().swapaxes(1, 2)
    visible[behind, :] = np.nan
    visible = visible.swapaxes(1, 2)

    for first, second in visible[:, list(axis), :]:
        ax.plot(first, second, **kwargs)
def graticule(
    inc: float,
    obl: float,
    theta: float = 0.0,
    pts: int = 100,
    white_contour=True,
    radius: float = 1.0,
    n=6,
    ax=None,
    **kwargs,
):
    """Draw a latitude/longitude graticule and its circular contour.

    Args:
        inc: inclination angle, forwarded to ``rotate_lines``.
        obl: obliquity angle, forwarded to ``rotate_lines``.
        theta: rotation phase, forwarded to ``rotate_lines``.
        pts: number of points per line, forwarded to ``lon_lat_lines``; the
            contour circle uses ``2 * pts`` points.
        white_contour: when true, the contour is drawn a second time in white
            with a line width of 3.
        radius: radius of the sphere and of the contour circle.
        n: number of lines, forwarded to ``lon_lat_lines``.
        ax: object with a matplotlib-like ``plot`` method. When ``None``, the
            current matplotlib axes (``plt.gca()``) is used.
        **kwargs: forwarded to ``ax.plot``. ``color`` and ``linewidth`` are
            accepted as aliases of ``c`` and ``lw``; the defaults are
            ``c="k"``, ``lw=1`` and ``alpha=0.3``.
    """
    if ax is None:
        # Imported lazily so matplotlib is only required when no axes is
        # supplied by the caller.
        import matplotlib.pyplot as plt

        # plt.gca() always returns an axes (creating one if needed), so no
        # further fallback is necessary.
        ax = plt.gca()

    # Normalise the long keyword names to their short aliases and fill in
    # the defaults.
    kwargs.setdefault("c", kwargs.pop("color", "k"))
    kwargs.setdefault("lw", kwargs.pop("linewidth", 1))
    kwargs.setdefault("alpha", 0.3)

    # plot lines
    lat, lon = lon_lat_lines(pts=pts, radius=radius, n=n)
    lat = rotate_lines(lat, inc, obl, theta)
    plot_lines(lat, ax=ax, **kwargs)
    lon = rotate_lines(lon, inc, obl, theta)
    plot_lines(lon, ax=ax, **kwargs)

    # contour
    contour_angle = np.linspace(0, 2 * np.pi, 2 * pts)
    contour_x = radius * np.cos(contour_angle)
    contour_y = radius * np.sin(contour_angle)
    ax.plot(contour_x, contour_y, **kwargs)
    if white_contour:
        ax.plot(contour_x, contour_y, c="w", lw=3)
# s2fft has an equivalent function, but this one is jittable
def y1d_to_2d(ydeg: int, flm_1d: np.ndarray) -> np.ndarray:
    """1D starry Ylm to 2D s2fft

    The coefficient of degree ``l`` and order ``m`` is read from flat index
    ``l**2 + l + m`` and written to ``[l, m + ydeg]`` of the output; every
    other entry of the output is zero.

    Args:
        ydeg: maximum degree; must be a concrete Python integer.
        flm_1d: flat coefficient vector holding at least ``(ydeg + 1) ** 2``
            entries; only the first ``(ydeg + 1) ** 2`` are used.

    Returns:
        Array of shape ``(ydeg + 1, 2 * ydeg + 1)`` with the dtype of
        ``flm_1d``.
    """
    n_coeffs = (ydeg + 1) ** 2

    # Static (l, m) for every flat index, so that all coefficients are placed
    # with a single scatter instead of one update per coefficient.
    degrees = np.arange(ydeg + 1)
    l_idx = np.repeat(degrees, 2 * degrees + 1)
    m_idx = np.arange(n_coeffs) - l_idx**2 - l_idx

    new_flm = jnp.zeros((ydeg + 1, 2 * ydeg + 1), dtype=flm_1d.dtype)
    return new_flm.at[l_idx, m_idx + ydeg].set(flm_1d[:n_coeffs])
# s2fft has an equivalent function, but this one is jittable
def y2d_to_1d(ydeg: int, flm_2d: np.ndarray) -> np.ndarray:
    """2D starry Ylm to 1D s2fft

    The coefficient of degree ``l`` and order ``m`` is read from
    ``flm_2d[l, m + ydeg]`` and written to flat index ``l**2 + l + m`` of the
    output.

    Args:
        ydeg: maximum degree; must be a concrete Python integer.
        flm_2d: 2D coefficient array indexed as ``[l, m + ydeg]``.

    Returns:
        Array of shape ``((ydeg + 1) ** 2,)`` with the dtype of ``flm_2d``.
    """
    # Static (l, m) for every flat index, so that all coefficients are read
    # with a single gather instead of one update per coefficient.
    degrees = np.arange(ydeg + 1)
    l_idx = np.repeat(degrees, 2 * degrees + 1)
    m_idx = np.arange((ydeg + 1) ** 2) - l_idx**2 - l_idx

    return jnp.asarray(flm_2d)[l_idx, m_idx + ydeg]
def C(l):
    """Complex to real conversion matrix

    Args:
        l: degree; the returned matrix has shape ``(2 * l + 1, 2 * l + 1)``.

    Returns:
        Complex array built from ``l`` imaginary rows, one central row and
        ``l`` real rows, all divided by ``sqrt(2)``.
    """
    # See https://doi.org/10.1016/s0166-1280(97)00185-1 (Blanco 1997, Eq. 19)
    anti_identity = np.eye(l, l)[:, ::-1]
    zero_column = np.zeros(l)[:, None]
    alternating = np.diag((-1) ** np.arange(1, l + 1))

    real_rows = np.hstack([anti_identity, zero_column, alternating])
    imag_rows = np.hstack(
        [1j * anti_identity, zero_column, -1j * alternating]
    )[::-1, :]

    # The central row carries sqrt(2) so that it becomes 1 after the final
    # division.
    centre_row = np.zeros(2 * l + 1)
    centre_row[l] = np.sqrt(2)

    return np.vstack([imag_rows, centre_row, real_rows]) / np.sqrt(2)