import math
from collections import defaultdict
from functools import partial
import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg
from jax.experimental.sparse import BCOO
from scipy.special import comb, gamma
try:
from scipy.sparse import csc_array
except ImportError:
# With older versions of scipy, the data structures were called "matrices"
# not "arrays"; this allows us to support either.
from scipy.sparse import csc_matrix as csc_array
[docs]
def basis(lmax):
"""Full change of basis matrix from spherical harmonics to Green's basis
Args:
lmax (int): maximum degree of the spherical harmonic basis
"""
matrix = scipy.sparse.linalg.spsolve(A2_inv(lmax), A1(lmax))
if lmax > 0:
return BCOO.from_scipy_sparse(matrix)
else:
return BCOO.fromdense(np.squeeze(matrix)[None, None])
[docs]
def A1(lmax):
"""Change of basis matrix from spherical harmonics to polynomial basis.
Args:
lmax (int): Maximum degree of the spherical harmonic basis.
Returns:
TODO: Description of the return value.
"""
return _A_impl(lmax, p_Y) * 2 / np.sqrt(np.pi)
[docs]
def A2_inv(lmax):
"""Change of basis matrix from polynomial basis to Green's basis.
Args:
lmax (int): Maximum degree of the spherical harmonic basis.
Returns:
TODO: Description of the return value.
"""
return _A_impl(lmax, p_G)
def _A_impl(lmax, func):
"""Return a sparse change of basis matrix given a function that maps
to the polynomial basis.
Args:
lmax (int): Maximum degree of the spherical harmonic basis.
func (callable): Function that maps to the polynomial basis (signature
irrelevant here and used for convenience). The output must be a tuple
of (indices, data) where indices is a list of indices of the polynomial basis
terms and data is a list of the coefficients of the polynomial basis terms
(see `p_Y` and `p_G`).
Returns:
_type_: _description_
"""
n = (lmax + 1) ** 2
data = []
row_ind = []
col_ind = []
p = {ptilde(m): m for m in range(n)}
n = 0
for l in range(lmax + 1):
for m in range(-l, l + 1):
idx, val = func(p, l, m, n)
data.extend(val)
row_ind.extend(idx)
col_ind.extend([n] * len(idx))
n += 1
return csc_array((np.array(data), (row_ind, col_ind)), shape=(n, n))
[docs]
def ptilde(n):
"""Compute the x, y, and z powers of the n-th polynomial basis term.
If the n-th term is x^i y^j z^k, return (i, j, k).
Args:
n (int): Index of the polynomial basis term.
Returns:
tuple: (i, j, k)
Example:
>>> ptilde(2) # z
(0, 0, 1)
>>> ptilde(3) # x + y
(1, 1, 0)
"""
l = math.floor(math.sqrt(n))
m = n - l * l - l
mu = l - m
nu = l + m
if nu % 2 == 0:
i = mu // 2
j = nu // 2
k = 0
else:
i = (mu - 1) // 2
j = (nu - 1) // 2
k = 1
return (i, j, k)
[docs]
def Alm(l, m):
return math.sqrt(
(2 - int(m == 0))
* (2 * l + 1)
* math.factorial(l - m)
/ (4 * math.pi * math.factorial(l + m))
)
[docs]
def Blmjk(l, m, j, k):
a = l + m + k - 1
b = -l + m + k - 1
if (b < 0) and (b % 2 == 0):
return 0
else:
ratio = gamma(0.5 * a + 1) / gamma(0.5 * b + 1)
return (
2**l
* math.factorial(m)
/ (
math.factorial(j)
* math.factorial(k)
* math.factorial(m - j)
* math.factorial(l - m - k)
)
* ratio
)
[docs]
def Cpqk(p, q, k):
return math.factorial(k // 2) / (
math.factorial(q // 2)
* math.factorial((k - p) // 2)
* math.factorial((p - q) // 2)
)
[docs]
def Ylm(l, m):
"""Compute the coefficients of the spherical harmonic Y_{l,m}.
Args:
l (int): Degree of the spherical harmonic.
m (int): Order of the spherical harmonic in the range [-l, l].
Returns:
dict: {(i, j, k): coeff} where i, j, k are the powers of x, y, z (see `ptilde`).
Example:
>>> Ylm(2, 0)
{(0, 0, 0): 0.6307831305050402,
(2, 0, 0): -0.9461746957575603,
(0, 2, 0): -0.9461746957575603}
"""
res = defaultdict(lambda: 0)
A = Alm(l, abs(m))
for j in range(int(m < 0), abs(m) + 1, 2):
for k in range(0, l - abs(m) + 1, 2):
B = Blmjk(l, abs(m), j, k)
if not B:
continue
factor = A * B
for p in range(0, k + 1, 2):
for q in range(0, p + 1, 2):
ind = (abs(m) - j + p - q, j + q, 0)
res[ind] += (
(-1) ** ((j + p - (m < 0)) // 2) * factor * Cpqk(p, q, k)
)
for k in range(1, l - abs(m) + 1, 2):
B = Blmjk(l, abs(m), j, k)
if not B:
continue
factor = A * B
for p in range(0, k, 2):
for q in range(0, p + 1, 2):
ind = (abs(m) - j + p - q, j + q, 1)
res[ind] += (
(-1) ** ((j + p - (m < 0)) // 2) * factor * Cpqk(p, q, k - 1)
)
return dict(res)
[docs]
def p_Y(p, l, m, n):
"""Return a representation of Y_{l, m} in the polynomial basis.
Args:
p (dict): Powers of xyz as returned by `ptilde`.
l (int): Degree of the spherical harmonic.
m (int): Order of the spherical harmonic in the range [-l, l].
n (None): Dummy variable.
Returns:
tuple: (indices, data) where indices is an np.array of indices of the polynomial
basis terms and data is an np.array of the coefficients of the polynomial
basis terms.
Example:
>>> p = {ptilde(m): m for m in range(9)}
>>> p_Y(p, 2, 0, 0)
(array([0, 4, 8]), array([ 0.63078313, -0.9461747 , -0.9461747 ]))
# see correspondence with `Ylm` example
"""
del n
indices = []
data = []
for k, v in Ylm(l, m).items():
if k not in p:
continue
indices.append(p[k])
data.append(v)
indices = np.array(indices, dtype=int)
data = np.array(data, dtype=float)
idx = np.argsort(indices)
return indices[idx], data[idx]
[docs]
def gtilde(n):
"""Compute the n-th term of the Green basis in the polynomial basis.
Args:
n (int): Index of the Green basis term.
Returns:
dict: {(i, j, k): coeff} where i, j, k are the powers of x, y, z (see `ptilde`).
Example:
>>> gtilde(50)
{(4, 0, 1): 5, (4, 2, 1): -5, (6, 0, 1): -8}
"""
l = math.floor(math.sqrt(n))
m = n - l * l - l
mu = l - m
nu = l + m
if nu % 2 == 0:
I = [mu // 2]
J = [nu // 2]
K = [0]
C = [(mu + 2) // 2]
elif (l == 1) and (m == 0):
I = [0]
J = [0]
K = [1]
C = [1]
elif (mu == 1) and (l % 2 == 0):
I = [l - 2]
J = [1]
K = [1]
C = [3]
elif mu == 1:
I = [l - 3, l - 1, l - 3]
J = [0, 0, 2]
K = [1, 1, 1]
C = [-1, 1, 4]
else:
I = [(mu - 5) // 2, (mu - 5) // 2, (mu - 1) // 2]
J = [(nu - 1) // 2, (nu + 3) // 2, (nu - 1) // 2]
K = [1, 1, 1]
C = [(mu - 3) // 2, -(mu - 3) // 2, -(mu + 3) // 2]
res = {}
for i, j, k, c in zip(I, J, K, C, strict=False):
if c != 0.0:
res[(i, j, k)] = c
return res
[docs]
def p_G(p, l, m, n):
"""Return a representation of Green's basis n-th term in the polynomial basis.
Args:
p (dict): Powers of xyz as returned by `ptilde`.
l (None): Dummy variable.
m (None): Dummy variable.
n (int): Index of the Green basis term.
Returns:
tuple: (indices, data) where indices is an np.array of indices of the polynomial
basis terms and data is an np.array of the coefficients of the polynomial
basis terms.
Example:
>>> p = {ptilde(n): n for n in range(100)}
>>> p_G(p, None, None, 50)
(array([26, 50, 54]), array([ 5., -8., -5.]))
"""
del l, m
indices = []
data = []
for k, v in gtilde(n).items():
if k not in p:
continue
indices.append(p[k])
data.append(v)
indices = np.array(indices, dtype=int)
data = np.array(data, dtype=float)
idx = np.argsort(indices)
return indices[idx], data[idx]
[docs]
def poly_basis(deg):
N = (deg + 1) * (deg + 1)
@partial(jnp.vectorize, signature=f"(),(),()->({N})")
def impl(x, y, z):
xarr = [None for _ in range(N)]
yarr = [None for _ in range(N)]
# Ensures we get `nan`s off the disk
xterm = 1.0 + 0.0 * z
yterm = 1.0 + 0.0 * z
i0 = 0
di0 = 3
j0 = 0
dj0 = 2
for n in range(deg + 1):
i = i0
di = di0
xarr[i] = xterm
j = j0
dj = dj0
yarr[j] = yterm
i = i0 + di - 1
j = j0 + dj - 1
while i + 1 < N:
xarr[i] = xterm
xarr[i + 1] = xterm
di += 2
i += di
yarr[j] = yterm
yarr[j + 1] = yterm
dj += 2
j += dj - 1
xterm *= x
i0 += 2 * n + 1
di0 += 2
yterm *= y
j0 += 2 * (n + 1) + 1
dj0 += 2
assert all(v is not None for v in xarr)
assert all(v is not None for v in yarr)
inds = []
n = 0
for ell in range(deg + 1):
for m in range(-ell, ell + 1):
if (ell + m) % 2 != 0:
inds.append(n)
n += 1
p = jnp.array(xarr) * jnp.array(yarr)
if len(inds):
return p.at[np.array(inds)].multiply(z, unique_indices=True)
else:
return p
return impl
[docs]
def utilde(n):
res = defaultdict(float)
if n == 0:
return {(0, 0, 0): 1.0}
for k in range(n + 1):
c1 = comb(n, k) * (-1) ** k
k2 = k // 2
for j in range(k2 + 1):
c2 = comb(k2, j) * (-1) ** j
for l in range(j + 1):
c3 = comb(j, l)
res[(2 * (j - l), 2 * l, k % 2)] += -c1 * c2 * c3
return res
[docs]
def u_p(p, l, m, n):
# this is very similar to gtilde, might be a more general
# way of doing the _A_impl function without dummy variables
del l, m
indices = []
data = []
for k, v in utilde(n).items():
if k not in p:
continue
indices.append(p[k])
data.append(v)
indices = np.array(indices, dtype=int)
data = np.array(data, dtype=float)
idx = np.argsort(indices)
return indices[idx], data[idx]
[docs]
def U(udeg: int):
"""Change of basis matrix from limb darkening basis to polynomial basis.
Args:
udeg (int): Degree of the limb darkening basis.
Returns:
TODO
"""
n = (udeg + 1) ** 2
p = {ptilde(m): m for m in range(n)}
P = np.zeros((udeg + 1, n))
for i in range(udeg + 1):
idxs, values = u_p(p, None, None, i)
for j, v in zip(idxs, values, strict=False):
P[i, j] += v
return P