import math
from functools import partial
import jax
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation
from jaxoplanet.starry import utils
from jaxoplanet.starry.core.s2fft_rotation import (
compute_rotation_matrices as compute_rotation_matrices_s2fft,
rotate_flms,
)
from jaxoplanet.types import Array
from jaxoplanet.utils import get_dtype_eps
[docs]
def dot_rotation_matrix(ydeg, x, y, z, theta):
    """Build a callable that rotates spherical harmonic coefficient vectors

    The rotation is given in axis-angle form. Axis components passed as
    ``None`` are treated as zero. Two shortcuts exist: when ``theta`` is
    ``None`` the returned callable is the identity, and when both ``x`` and
    ``y`` are ``None`` the dedicated z-axis rotation ``dot_rz(ydeg, theta)``
    is returned (the value of ``z`` itself is not used on that path).

    Args:
        ydeg (int): The degree of the spherical harmonic map
        x (float or None): The x component of the rotation axis
        y (float or None): The y component of the rotation axis
        z (float or None): The z component of the rotation axis
        theta (float or None): The rotation angle in radians

    Returns:
        Callable: maps an array whose last dimension has length
        ``(ydeg + 1) ** 2`` to the rotated array of the same shape

    Raises:
        TypeError: if ``ydeg`` cannot be converted with ``int``
        ValueError: if ``x``, ``y`` and ``z`` are all ``None`` while ``theta``
            is given, or if any of ``x``, ``y``, ``z``, ``theta`` is not a
            scalar
    """
    try:
        ydeg = int(ydeg)
    except TypeError as e:
        raise TypeError(f"ydeg must be an integer; got {ydeg}") from e

    # No angle at all: nothing to rotate
    if theta is None:

        def do_dot(M):
            return M

        return do_dot

    # Only a z component (or nothing) was supplied
    z_axis_only = x is None and y is None
    if z_axis_only and z is None:
        raise ValueError("Either x, y, or z must be specified")
    if z_axis_only:
        return dot_rz(ydeg, theta)

    # Unspecified components of a general axis default to zero
    x, y, z = (0.0 if component is None else component for component in (x, y, z))

    if jnp.shape(x) != ():
        raise ValueError(f"x must be a scalar; got {jnp.shape(x)}")
    if jnp.shape(y) != ():
        raise ValueError(f"y must be a scalar; got {jnp.shape(y)}")
    if jnp.shape(z) != ():
        raise ValueError(f"z must be a scalar; got {jnp.shape(z)}")
    if jnp.shape(theta) != ():
        raise ValueError(f"theta must be a scalar; got {jnp.shape(theta)}")

    rotation_matrices = compute_rotation_matrices_s2fft(ydeg, x, y, z, theta)
    n_max = (ydeg + 1) ** 2

    @jax.jit
    @partial(jnp.vectorize, signature=f"({n_max})->({n_max})")
    def do_dot(M):
        """Rotate a spherical harmonic map

        Args:
            M (Array[..., n_max]): The spherical harmonic map to rotate
        """
        if M.shape[-1] != n_max:
            raise ValueError(
                f"Dimension mismatch: Input array must have shape (..., {n_max}); "
                f"got {M.shape}"
            )
        # Degree ell owns the 2 * ell + 1 coefficients starting at ell ** 2;
        # each slice is multiplied by the matrix for its own degree.
        rotated_blocks = [
            M[ell * ell : (ell + 1) ** 2] @ rotation_matrices[ell]
            for ell in range(ydeg + 1)
        ]
        return jnp.concatenate(rotated_blocks, axis=0)

    return do_dot
[docs]
def full_rotation_axis_angle(
    inc: float | None, obl: float | None, theta: float | None, theta_z: float | None
):
    """Return the axis-angle representation of the full rotation of the
    spherical harmonic map

    Any argument passed as ``None`` is treated as ``0.0``.

    Args:
        inc (float or None): map inclination
        obl (float or None): map obliquity
        theta (float or None): rotation angle about the map z-axis
        theta_z (float or None): additional rotation angle; it enters the
            formulas only through the sums ``obl + theta + theta_z`` and
            ``obl - theta + theta_z``

    Returns:
        tuple: x, y, z, angle. When the angle is detected as zero, or the
        normalisation would require the square root of a non-positive
        number, the axis falls back to ``(1.0, 0.0, 0.0)``.
    """
    inc = 0.0 if inc is None else inc
    obl = 0.0 if obl is None else obl
    theta = 0.0 if theta is None else theta
    theta_z = 0.0 if theta_z is None else theta_z

    f = 0.5 * math.sqrt(2)
    si = jnp.sin(inc / 2)
    ci = jnp.cos(inc / 2)

    # theta and theta_z can no longer be None here (defaults applied above),
    # so the half-angle sums are always the ones used.
    sp = jnp.sin(0.5 * (obl + theta + theta_z))
    cp = jnp.cos(0.5 * (obl + theta + theta_z))
    sm = jnp.sin(0.5 * (obl - theta + theta_z))
    cm = jnp.cos(0.5 * (obl - theta + theta_z))

    numerator1 = f * (-si * cm + ci * cp)
    numerator2 = f * (-si * sm + sp * ci)
    numerator3 = f * (si * sm + sp * ci)
    arg = si * cm + ci * cp
    denominator = jnp.sqrt(1 - 0.5 * arg**2)

    farg = f * arg
    zero_angle = jnp.allclose(farg, 1.0, atol=get_dtype_eps(farg))
    angle = jnp.where(zero_angle, 0.0, 2 * jnp.arccos(farg))
    non_zero_angle = jnp.logical_not(zero_angle)

    # this is mostly useful for the float32 case, where
    # (1 - 0.5 * arg**2) in denominator can be negative due to numerical error
    positive_arg = arg**2 < 2.0

    # Only trust the normalised axis when it is numerically well defined
    use_computed_axis = positive_arg & non_zero_angle
    axis_x = jnp.where(use_computed_axis, numerator1 / denominator, 1.0)
    axis_y = jnp.where(use_computed_axis, numerator2 / denominator, 0.0)
    axis_z = jnp.where(use_computed_axis, numerator3 / denominator, 0.0)

    return axis_x, axis_y, axis_z, angle
[docs]
def sky_projection_axis_angle(inc: float | None, obl: float | None):
"""Return the axis-angle representation of the partial rotation of the
map due to inclination and obliquity
Args:
inc (float or None): map inclination
obl (float or None): map obliquity
Returns:
tuple: x, y, z, angle
"""
if obl is None and inc is None:
return 1.0, None, None, None
elif obl is None:
return 1.0, None, None, inc
elif inc is None:
return None, None, 1.0, obl
else:
co = jnp.cos(obl / 2)
so = jnp.sin(obl / 2)
ci = jnp.cos(inc / 2)
si = jnp.sin(inc / 2)
denominator = jnp.sqrt(1 - ci**2 * co**2)
# to avoid nans for the case where ci * co == 1
denominator = jnp.where(denominator == 0.0, 1.0, denominator)
axis_x = si * co
axis_y = si * so
axis_z = -so * ci
angle = 2 * jnp.arccos(ci * co)
arg = jnp.linalg.norm(jnp.array([axis_x, axis_y, axis_z]))
axis_x = jnp.where(arg > 0, axis_x / denominator, 1.0)
axis_y = jnp.where(arg > 0, axis_y / denominator, 0.0)
axis_z = jnp.where(arg > 0, axis_z / denominator, 0.0)
return axis_x, axis_y, axis_z, angle
[docs]
def left_project(
    ydeg: int,
    inc: float | None,
    obl: float | None,
    theta: float | None,
    theta_z: float | None,
    y: Array,
):
    """R @ y

    Applies, in order: a fixed rotation of ``-pi / 2`` about the x-axis, a
    rotation of ``-theta`` about the z-axis, the inclination/obliquity
    rotation from ``sky_projection_axis_angle``, and a rotation of
    ``-theta_z`` about the z-axis. Angles given as ``None`` skip the
    corresponding rotation.

    Args:
        ydeg (int): degree of the spherical harmonic map
        inc (float or None): map inclination
        obl (float or None): map obliquity
        theta (float or None): rotation angle about the map z-axis
        theta_z (float or None): angle of the final rotation about the z-axis
        y (Array): spherical harmonic map coefficients

    Returns:
        Array: rotated spherical harmonic map coefficients
    """
    sky_x, sky_y, sky_z, sky_angle = sky_projection_axis_angle(inc, obl)

    # (axis_x, axis_y, axis_z, angle) of each rotation, in application order
    rotations = (
        (1.0, None, None, -0.5 * jnp.pi),
        (None, None, 1.0, None if theta is None else -theta),
        (sky_x, sky_y, sky_z, sky_angle),
        (None, None, 1.0, None if theta_z is None else -theta_z),
    )
    for axis_x, axis_y, axis_z, angle in rotations:
        y = dot_rotation_matrix(ydeg, axis_x, axis_y, axis_z, angle)(y)
    return y
[docs]
def right_project(
    ydeg: int,
    inc: float | None,
    obl: float | None,
    theta: float | None,
    theta_z: float | None,
    y: Array,
):
    """y @ R

    Applies, in order: a rotation of ``theta_z`` about the z-axis, the
    inclination/obliquity rotation from ``sky_projection_axis_angle`` with
    its axis negated, a rotation of ``theta`` about the z-axis, and a fixed
    rotation of ``pi / 2`` about the x-axis. Angles given as ``None`` skip
    the corresponding rotation.

    Args:
        ydeg (int): degree of the spherical harmonic map
        inc (float or None): map inclination
        obl (float or None): map obliquity
        theta (float or None): rotation angle about the map z-axis
        theta_z (float or None): angle of the first rotation about the z-axis
        y (Array): spherical harmonic map coefficients

    Returns:
        Array: rotated spherical harmonic map coefficients
    """
    sky_axis_and_angle = sky_projection_axis_angle(inc, obl)
    sky_axis, sky_angle = sky_axis_and_angle[:3], sky_axis_and_angle[3]
    # Negate the axis, leaving "unspecified" (None) components untouched
    flipped_axis = tuple(
        None if component is None else -component for component in sky_axis
    )

    # (axis, angle) of each rotation, in application order
    rotations = (
        ((None, None, 1.0), theta_z),
        (flipped_axis, sky_angle),
        ((None, None, 1.0), theta),
        ((1.0, None, None), 0.5 * jnp.pi),
    )
    for (axis_x, axis_y, axis_z), angle in rotations:
        y = dot_rotation_matrix(ydeg, axis_x, axis_y, axis_z, angle)(y)
    return y
@partial(jax.jit, static_argnums=(0,))
[docs]
def compute_rotation_matrices(ydeg, x, y, z, theta):
    """Compute per-degree rotation matrices for an axis-angle rotation

    The axis is normalised, selected entries of the corresponding 3x3
    rotation matrix are formed, and from them the cosines and sines of three
    angles (named alpha, beta, gamma here) are derived and handed to
    ``rotar``.

    Args:
        ydeg (int): maximum degree, forwarded to ``rotar``
        x: x component of the rotation axis
        y: y component of the rotation axis
        z: z component of the rotation axis
        theta: rotation angle in radians

    Returns:
        The second element of the pair returned by ``rotar``
    """
    # we need the axis to be a unit vector - enforce that here,
    # mapping a (0, 0, 0) axis to itself rather than dividing by zero
    length = jnp.sqrt(x * x + y * y + z * z)
    null_axis = length == 0.0
    x = jnp.where(null_axis, 0.0, x / length)
    y = jnp.where(null_axis, 0.0, y / length)
    z = jnp.where(null_axis, 0.0, z / length)

    sin_t = jnp.sin(theta)
    cos_t = jnp.cos(theta)
    one_minus_cos = 1 - cos_t

    # Entries of the 3x3 rotation matrix that are needed below
    r01 = x * y * one_minus_cos - z * sin_t
    r02 = x * z * one_minus_cos + y * sin_t
    r11 = cos_t + y * y * one_minus_cos
    r12 = y * z * one_minus_cos - x * sin_t
    r20 = z * x * one_minus_cos - y * sin_t
    r21 = z * y * one_minus_cos + x * sin_t
    r22 = cos_t + z * z * one_minus_cos

    # r22 within tolerance of +/-1 is handled as a special case
    tol = 10 * get_dtype_eps(r22)
    near_minus_one = jnp.less(jnp.abs(r22 + 1.0), tol)
    near_plus_one = jnp.less(jnp.abs(r22 - 1.0), tol)
    special = jnp.logical_or(near_plus_one, near_minus_one)
    sign = near_minus_one.astype(int) - near_plus_one.astype(int)

    # In the special case the norms are replaced by 1 so the divisions
    # below stay finite
    norm1 = jnp.sqrt(jnp.where(special, 1, r20 * r20 + r21 * r21))
    norm2 = jnp.sqrt(jnp.where(special, 1, r02 * r02 + r12 * r12))

    special_term = 1 + sign * r22
    r22_safe = jnp.where(special, 0.0, r22)

    cos_beta = r22
    sin_beta = jnp.where(
        special,
        special_term,
        jnp.sqrt(1 - r22_safe * r22_safe),  # type: ignore
    )
    cos_gamma = jnp.where(special, r11, -r20 / norm1)
    sin_gamma = jnp.where(special, sign * r01, r21 / norm1)
    cos_alpha = jnp.where(special, -sign * r22, r02 / norm2)
    sin_alpha = jnp.where(special, special_term, r12 / norm2)

    _, R = rotar(
        ydeg,
        cos_alpha,
        sin_alpha,
        cos_beta,
        sin_beta,
        cos_gamma,
        sin_gamma,
    )
    return R
[docs]
def rotar(ydeg, c1, s1, c2, s2, c3, s3):
    """Build the lists of D and R matrices for degrees 0 through ``ydeg``

    Degrees 0 and 1 are written out explicitly; every higher degree is
    obtained from ``dlmn`` using the matrices already computed.

    Args:
        ydeg (int): maximum degree
        c1, s1: cosine and sine of the first angle
        c2, s2: cosine and sine of the second angle
        c3, s3: cosine and sine of the third angle

    Returns:
        tuple: ``(D, R)``, two lists with one ``(2 * ell + 1, 2 * ell + 1)``
        array per degree ``ell``
    """
    sqrt_2 = jnp.sqrt(2.0)

    # Degree 0: both matrices are the 1x1 identity
    D = [jnp.ones((1, 1))]
    R = [jnp.ones((1, 1))]
    if ydeg == 0:
        return D, R

    # Degree 1, D matrix: five distinct values fill the nine entries
    d_22 = 0.5 * (1 + c2)
    d_21 = -s2 / sqrt_2
    d_20 = 0.5 * (1 - c2)
    d_12 = -d_21
    d_11 = d_22 - d_20
    D.append(
        jnp.array(
            [
                [d_22, d_12, d_20],
                [d_21, d_11, d_12],
                [d_20, d_21, d_22],
            ]
        )
    )

    # Degree 1, R matrix: cosines/sines of the sum and difference of the
    # first and third angles
    cos_sum = c1 * c3 - s1 * s3
    cos_diff = c1 * c3 + s1 * s3
    sin_sum = s1 * c3 + c1 * s3
    sin_diff = s1 * c3 - c1 * s3
    R.append(
        jnp.array(
            [
                [
                    d_22 * cos_sum + d_20 * cos_diff,
                    sqrt_2 * d_12 * s1,
                    d_22 * sin_sum - d_20 * sin_diff,
                ],
                [
                    -sqrt_2 * d_21 * s3,
                    d_11,
                    sqrt_2 * d_21 * c3,
                ],
                [
                    -d_22 * sin_sum - d_20 * sin_diff,
                    sqrt_2 * d_12 * c1,
                    d_22 * cos_sum - d_20 * cos_diff,
                ],
            ]
        )
    )

    # (1 - c2) / s2, guarded against division by a vanishing s2: below the
    # tolerance s2 itself is used instead
    tol = 10 * jnp.finfo(jnp.dtype(s2)).eps
    s2_is_tiny = jnp.less(jnp.abs(s2), tol)
    tgbet2 = jnp.where(s2_is_tiny, s2, (1 - c2) / jnp.where(s2_is_tiny, 1.0, s2))

    # Higher degrees by recurrence on the matrices accumulated so far
    for ell in range(2, ydeg + 1):
        next_D, next_R = dlmn(ell, s1, c1, c2, tgbet2, s3, c3, D)
        D.append(next_D)
        R.append(next_R)

    return D, R
[docs]
def dlmn(ell, s1, c1, c2, tgbet2, s3, c3, D):
    """Compute the degree-``ell`` D and R matrices by recurrence

    The new D matrix is filled from the matrices of the two previous degrees
    (``D[-1]`` and ``D[-2]``), completed with symmetry relations, and then
    combined with the sines/cosines of the first and third angles to form
    the R matrix.

    Args:
        ell (int): degree to compute. The code divides by ``ell - 1``, so
            ``ell`` must be at least 2 (``rotar`` only calls it that way).
        s1, c1: sine and cosine of the first angle
        c2: cosine of the second angle
        tgbet2: factor used in the last-row recurrence; ``rotar`` passes
            ``(1 - c2) / s2`` (guarded for small ``s2``)
        s3, c3: sine and cosine of the third angle
        D (list): D matrices of all lower degrees, indexed by degree.
            NOTE(review): ``D[-1]`` and ``D[ell - 1]`` are both used for the
            previous degree, which assumes ``len(D) == ell`` - true for the
            call in ``rotar``.

    Returns:
        tuple: ``(D_, R_)`` as arrays of shape ``(2 * ell + 1, 2 * ell + 1)``
    """
    # Running bounds on m; they shrink by one on each side per row below
    iinf = 1 - ell
    isup = -iinf

    # Last row by recurrence (Eq. 19 and 20 in Alvarez Collado et al.)
    D_ = [[0 for _ in range(2 * ell + 1)] for _ in range(2 * ell + 1)]
    D_[2 * ell][2 * ell] = 0.5 * D[-1][isup + ell - 1, isup + ell - 1] * (1 + c2)
    for m in range(isup, iinf - 1, -1):
        # each entry is obtained from its right-hand neighbour in the same row
        D_[2 * ell][m + ell] = (
            -tgbet2 * jnp.sqrt((ell + m + 1) / (ell - m)) * D_[2 * ell][m + 1 + ell]
        )
    D_[2 * ell][0] = 0.5 * D[ell - 1][isup + ell - 1, -isup + ell - 1] * (1 - c2)

    # The rows of the upper quarter triangle of the D[l;m',m) matrix
    # (Eq. 21 in Alvarez Collado et al.)
    al = ell
    al1 = al - 1
    tal1 = al + al1
    ali = 1.0 / al1  # division by zero if ell == 1
    cosaux = c2 * al * al1
    for mp in range(ell - 1, -1, -1):
        amp = mp
        laux = ell + mp
        lbux = ell - mp
        aux = ali / jnp.sqrt(laux * lbux)
        cux = jnp.sqrt((laux - 1) * (lbux - 1)) * al
        for m in range(isup, iinf - 1, -1):
            am = m
            lauz = ell + m
            lbuz = ell - m
            auz = 1.0 / jnp.sqrt(lauz * lbuz)
            fact = aux * auz
            term = tal1 * (cosaux - am * amp) * D[-1][mp + ell - 1, m + ell - 1]
            # the degree ell - 2 contribution is skipped when either index
            # sits at ell - 1
            if lbuz != 1 and lbux != 1:
                cuz = jnp.sqrt((lauz - 1) * (lbuz - 1))
                term = term - D[-2][mp + ell - 2, m + ell - 2] * cux * cuz
            D_[mp + ell][m + ell] = fact * term
        # narrow the m range for the next (lower) row
        iinf += 1
        isup -= 1

    # The remaining elements of the D[l;m',m) matrix are calculated
    # using the corresponding symmetry relations:
    # reflection ---> ((-1)**(m-m')) D[l;m,m') = D[l;m',m), m'<=m
    # inversion ---> ((-1)**(m-m')) D[l;-m',-m) = D[l;m',m)
    sign = 1
    iinf = -ell
    isup = ell - 1
    for m in range(ell, 0, -1):
        for mp in range(iinf, isup + 1):
            # copy the transposed entry, alternating the sign with mp
            D_[mp + ell][m + ell] = sign * D_[m + ell][mp + ell]
            sign *= -1
        iinf += 1
        isup -= 1

    # Inversion
    iinf = -ell
    isup = iinf
    for m in range(ell - 1, -(ell + 1), -1):
        # the alternating sign restarts at -1 for every column
        sign = -1
        for mp in range(isup, iinf - 1, -1):
            D_[mp + ell][m + ell] = sign * D_[-mp + ell][-m + ell]
            sign *= -1
        # iinf += 1
        isup += 1

    # Compute the real rotation matrices R from the complex ones D
    R_ = [[0 for _ in range(2 * ell + 1)] for _ in range(2 * ell + 1)]
    R_[ell][ell] = D_[ell][ell]
    # cosmal / sinmal start at (c1, s1) and are advanced once per mp by the
    # angle-addition update at the bottom of the outer loop
    cosmal = c1
    sinmal = s1
    sign = -1
    root_two = jnp.sqrt(2.0)
    for mp in range(1, ell + 1):
        # cosmga / sinmga restart at (c3, s3) for every mp and are advanced
        # once per m inside the inner loop
        cosmga = c3
        sinmga = s3
        aux = root_two * D_[ell][mp + ell]
        R_[mp + ell][ell] = aux * cosmal
        R_[-mp + ell][ell] = aux * sinmal
        for m in range(1, ell + 1):
            aux = root_two * D_[m + ell][ell]
            R_[ell][m + ell] = aux * cosmga
            R_[ell][-m + ell] = -aux * sinmga
            d1 = D_[-mp + ell][-m + ell]
            d2 = sign * D_[mp + ell][-m + ell]
            cosag = cosmal * cosmga - sinmal * sinmga
            cosagm = cosmal * cosmga + sinmal * sinmga
            sinag = sinmal * cosmga + cosmal * sinmga
            sinagm = sinmal * cosmga - cosmal * sinmga
            R_[mp + ell][m + ell] = d1 * cosag + d2 * cosagm
            R_[mp + ell][-m + ell] = -d1 * sinag + d2 * sinagm
            R_[-mp + ell][m + ell] = d1 * sinag + d2 * sinagm
            R_[-mp + ell][-m + ell] = d1 * cosag - d2 * cosagm
            # angle-addition step for the third angle
            aux = cosmga * c3 - sinmga * s3
            sinmga = sinmga * c3 + cosmga * s3
            cosmga = aux
        # the sign applied to d2 alternates with mp
        sign *= -1
        # angle-addition step for the first angle
        aux = cosmal * c1 - sinmal * s1
        sinmal = sinmal * c1 + cosmal * s1
        cosmal = aux

    return jnp.asarray(D_), jnp.asarray(R_)
[docs]
def dot_rz(deg, theta):
    """Special case for rotation only around z axis

    Args:
        deg (int): degree of the spherical harmonic map
        theta (float): rotation angle in radians

    Returns:
        Callable: maps an array whose last dimension has length
        ``(deg + 1) ** 2`` to the rotated array of the same shape
    """
    cos_theta = jnp.cos(theta)
    sin_theta = jnp.sin(theta)

    # cos(n * theta) and sin(n * theta) for n = 0..deg via the three-term
    # recurrence f(n) = 2 * cos(theta) * f(n - 1) - f(n - 2)
    cosnt = [1.0, cos_theta]
    sinnt = [0.0, sin_theta]
    for n in range(2, deg + 1):
        cosnt.append(2.0 * cosnt[n - 1] * cos_theta - cosnt[n - 2])
        sinnt.append(2.0 * sinnt[n - 1] * cos_theta - sinnt[n - 2])

    # One cos/sin entry per coefficient, ordered by degree and then by
    # m = -ell..ell; the sine is negated for negative m
    cosmt = []
    sinmt = []
    for ell in range(deg + 1):
        orders = range(-ell, ell + 1)
        cosmt.extend(cosnt[abs(m)] for m in orders)
        sinmt.extend(-sinnt[-m] if m < 0 else sinnt[m] for m in orders)

    n_max = (deg + 1) ** 2

    @jax.jit
    @partial(jnp.vectorize, signature=f"({n_max})->({n_max})")
    def impl(M):
        rotated = []
        for ell in range(deg + 1):
            start = ell * ell
            last = start + 2 * ell
            for j in range(2 * ell + 1):
                here = start + j
                # each coefficient mixes with its mirror (m <-> -m) in the
                # same degree
                rotated.append(M[here] * cosmt[here] + M[last - j] * sinmt[here])
        return jnp.array(rotated, dtype=jnp.dtype(M))

    return impl
[docs]
def fast_direct_left_project(ydeg, inc, obl, theta, theta_z, y):
    """Rotate map coefficients with a single rotation via ``rotate_flms``

    The full rotation from ``full_rotation_axis_angle`` is converted to
    "zyz" Euler angles, the coefficients are transformed with ``utils.C``
    (presumably the change of basis between real and complex harmonics -
    confirm against ``utils``), rotated by ``rotate_flms``, and transformed
    back.

    Args:
        ydeg (int): degree of the spherical harmonic map
        inc (float or None): map inclination
        obl (float or None): map obliquity
        theta (float or None): rotation angle about the map z-axis
        theta_z (float or None): additional rotation angle, forwarded to
            ``full_rotation_axis_angle``
        y (Array): spherical harmonic map coefficients

    Returns:
        Array: real part of the rotated coefficients
    """

    def euler(x, y, z, theta):
        """axis-angle to euler angles"""
        no_rotation = theta == 0.0
        # the substitutions for theta == 0 are to avoid nans when computing grad
        axis = jnp.array([jnp.where(no_rotation, 1.0, x), y, z])
        safe_theta = jnp.where(no_rotation, 1.0, theta)
        unit_axis = axis / jnp.linalg.norm(axis)
        rotation = Rotation.from_rotvec(unit_axis * safe_theta)
        return jnp.where(
            no_rotation, jnp.array([0.0, 0.0, 0.0]), rotation.as_euler("zyz")
        )

    axis_x, axis_y, axis_z, angle = full_rotation_axis_angle(inc, obl, theta, theta_z)
    # x and y components of the axis and the angle are negated before the
    # conversion
    flipped_axis = jnp.array([-axis_x, -axis_y, axis_z])
    alpha, beta, gamma = euler(*flipped_axis, -angle)

    basis = utils.C(ydeg)
    basis_dagger = jnp.conj(basis.T)

    # flat coefficients -> 2D layout -> other basis
    coeffs_2d = utils.y1d_to_2d(ydeg, jnp.array(y, dtype=jnp.complex128))
    coeffs_2d_complex = (basis_dagger @ coeffs_2d.T).T

    rotated_complex = rotate_flms(coeffs_2d_complex, ydeg + 1, (alpha, beta, gamma))

    # back to the original basis and to the flat layout
    rotated_real = basis @ rotated_complex.T
    return utils.y2d_to_1d(ydeg, rotated_real.T).real