from collections import defaultdict
from functools import reduce
import numpy as np
from jaxoplanet.starry.core import basis
from jaxoplanet.starry.multiprecision import mp, utils
from jaxoplanet.starry.multiprecision.utils import (
fac as fac_function,
kron_delta,
)
[docs]
# Two-level cache of pre-computed matrices used by `get_A`:
# first key is the maximum degree, second key is the matrix name.
CACHED_MATRICES = defaultdict(dict)

# Memo table read and written by `fac`; it must exist before `fac` is first called.
FAC_CACHE = {}
[docs]
def get_A(A, l_max, cache=None):
    """Return ``A(l_max)``, computing it at most once per cache.

    Args:
        A: Callable taking ``l_max`` and returning the object to store. Its
            ``__name__`` with every underscore removed is the second-level
            cache key.
        l_max: Argument forwarded to ``A``; also the first-level cache key.
        cache: Mapping from ``l_max`` to a ``{name: value}`` dictionary.
            Defaults to the module-level ``CACHED_MATRICES``. Both a
            ``defaultdict`` and a plain ``dict`` are accepted.

    Returns:
        The cached value if present, otherwise the freshly computed
        ``A(l_max)`` (which is stored before being returned).
    """
    if cache is None:
        cache = CACHED_MATRICES
    name = A.__name__.replace("_", "")
    # setdefault creates the per-l_max table on demand, so the cache does not
    # have to be a defaultdict.
    per_degree = cache.setdefault(l_max, {})
    if name not in per_degree:
        print(f"pre-computing {name}...")
        per_degree[name] = A(l_max)
    return per_degree[name]
[docs]
def fac(n):
    """Memoized wrapper around ``fac_function``.

    The value for each distinct ``n`` is computed once with ``fac_function``
    and stored in the module-level ``FAC_CACHE``; later calls return the
    stored value.
    """
    if n in FAC_CACHE:
        return FAC_CACHE[n]
    value = fac_function(n)
    FAC_CACHE[n] = value
    return value
[docs]
def A(l, m):
    """A spherical harmonic normalization constant.

    Evaluates ``sqrt((2 - delta(m, 0)) * (2l + 1) * fac(l - m)
    / (4 * pi * fac(l + m)))`` using ``mp``.
    """
    numerator = (2 - kron_delta(m, 0)) * (2 * l + 1) * fac(l - m)
    denominator = 4 * mp.pi * fac(l + m)
    return mp.sqrt(numerator / denominator)
[docs]
def B(l, m, j, k):
    """Another spherical harmonic normalization constant.

    Evaluates ``2**l * fac(m) / (fac(j) * fac(k) * fac(m - j) * fac(l - m - k))``
    times the ratio ``fac((l + m + k - 1) / 2) / fac((-l + m + k - 1) / 2)``.
    """
    upper = fac((l + m + k - 1) / 2)
    lower = fac((-l + m + k - 1) / 2)
    ratio = upper / lower
    prefactor = 2**l * fac(m)
    divisor = fac(j) * fac(k) * fac(m - j) * fac(l - m - k)
    return prefactor / divisor * ratio
[docs]
def C(p, q, k):
    """Return the binomial theorem coefficient `C`.

    Evaluates ``fac(k / 2) / (fac(q / 2) * fac((k - p) / 2) * fac((p - q) / 2))``.
    """
    numerator = fac(k / 2)
    denominator = fac(q / 2) * fac((k - p) / 2) * fac((p - q) / 2)
    return numerator / denominator
[docs]
def Y(l, m):
    """Return the spherical harmonic of degree `l` and order `m`.

    The result is a ``defaultdict`` (default value ``0``) mapping exponent
    triples ``(x_order, y_order, z_order)`` to their accumulated coefficient,
    where ``z_order`` is either 0 or 1.
    """
    # Per-call memo of the A, B and C constants, keyed by function name and
    # argument tuple.
    memo = {}

    def cached(function, *indices):
        key = (function.__name__, indices)
        if key not in memo:
            memo[key] = function(*indices)
        return memo[key]

    res = defaultdict(lambda: 0)

    abs_m = abs(m)
    # Non-negative orders sum over even j with sign exponent (j + p) // 2;
    # negative orders sum over odd j with sign exponent (j + p - 1) // 2.
    if m >= 0:
        j_start, sign_shift = 0, 0
    else:
        j_start, sign_shift = 1, 1

    for j in range(j_start, abs_m + 1, 2):
        # z_order == 0: even k, p up to and including k, C evaluated at k.
        # z_order == 1: odd k, p strictly below k, C evaluated at k - 1.
        for z_order in (0, 1):
            for k in range(z_order, l - abs_m + 1, 2):
                for p in range(0, k + 1 - z_order, 2):
                    for q in range(0, p + 1, 2):
                        sign = (-1) ** ((j + p - sign_shift) // 2)
                        coeff = (
                            sign
                            * cached(A, l, abs_m)
                            * cached(B, l, abs_m, j, k)
                            * cached(C, p, q, k - z_order)
                        )
                        powers = (abs_m - j + p - q, j + q, z_order)
                        res[powers] += coeff
    return res
[docs]
def p_coeffs(n):
    """Return the exponent triple ``(i, j, k)`` associated with flat index ``n``.

    With ``l = floor(sqrt(n))``, ``m = n - l*l - l``, ``mu = l - m`` and
    ``nu = l + m``: ``k`` is the parity of ``nu`` (0 when even, 1 when odd),
    ``i = (mu - k) // 2`` and ``j = (nu - k) // 2``.
    """
    l = mp.floor(mp.sqrt(n))
    m = n - l * l - l
    mu, nu = int(l - m), int(l + m)
    k = 0 if nu % 2 == 0 else 1
    return ((mu - k) // 2, (nu - k) // 2, k)
def _A1(l_max):
    """Build the ``A1`` matrix for degrees up to ``l_max``.

    The matrix is square with ``(l_max + 1) ** 2`` rows. Columns follow the
    order ``l = 0..l_max``, ``m = -l..l``; each column holds the coefficients
    of ``Y(l, m)`` looked up at the exponent triples ``p_coeffs(0)``,
    ``p_coeffs(1)``, ... (0 where ``Y`` has no such term). The whole matrix is
    finally scaled by ``2 / sqrt(pi)``.
    """
    n = (l_max + 1) ** 2
    monomials = [p_coeffs(index) for index in range(n)]
    res = mp.zeros(n, n)
    harmonics = ((l, m) for l in range(l_max + 1) for m in range(-l, l + 1))
    for column, (l, m) in enumerate(harmonics):
        y = Y(l, m)
        res[:, column] = mp.matrix([y[ijk] for ijk in monomials])
    return res * 2 / mp.sqrt(mp.pi)
[docs]
def A1(l_max, cache=None):
    """Return the ``A1`` matrix for ``l_max``, computed by ``_A1`` on first use.

    ``cache`` is the store handed to ``get_A``; when omitted the module-level
    ``CACHED_MATRICES`` is used.
    """
    store = CACHED_MATRICES if cache is None else cache
    return get_A(_A1, l_max, cache=store)
[docs]
def gtilde(n):
    """Return the ``n``-th ``gtilde`` term as a ``{(i, j, k): coefficient}`` dict.

    The term has either one or three entries, selected from
    ``l = floor(sqrt(n))``, ``m = n - l*l - l``, ``mu = l - m`` and
    ``nu = l + m``.
    """
    l = mp.floor(mp.sqrt(n))
    m = n - l * l - l
    mu = l - m
    nu = l + m

    # Even nu: a single term with k = 0.
    if nu % 2 == 0:
        return {(mp.floor(mu / 2), mp.floor(nu / 2), 0): mp.floor((mu + 2) / 2)}

    # Special case l = 1, m = 0.
    if (l == 1) and (m == 0):
        return {(0, 0, 1): 1}

    if mu == 1:
        if l % 2 == 0:
            return {(l - 2, 1, 1): 3}
        return {
            (l - 3, 0, 1): -1,
            (l - 1, 0, 1): 1,
            (l - 3, 2, 1): 4,
        }

    # Generic odd-nu case: three terms, all with k = 1.
    i_low = mp.floor((mu - 5) / 2)
    i_high = mp.floor((mu - 1) / 2)
    j_low = mp.floor((nu - 1) / 2)
    j_high = mp.floor((nu + 3) / 2)
    return {
        (i_low, j_low, 1): mp.floor((mu - 3) / 2),
        (i_low, j_high, 1): -mp.floor((mu - 3) / 2),
        (i_high, j_low, 1): -mp.floor((mu + 3) / 2),
    }
[docs]
def A2_inv(l_max):
    """Build the inverse of the ``A2`` matrix for degrees up to ``l_max``.

    The matrix is square with ``(l_max + 1) ** 2`` rows. Column ``k`` holds the
    coefficients of ``gtilde(k)`` looked up at the exponent triples
    ``p_coeffs(0)``, ``p_coeffs(1)``, ..., with ``0.0`` where ``gtilde(k)`` has
    no such term.
    """
    n = (l_max + 1) ** 2
    monomials = [p_coeffs(index) for index in range(n)]
    res = mp.zeros(n, n)
    # One column per (l, m) pair; only the running column index is needed.
    harmonics = ((l, m) for l in range(l_max + 1) for m in range(-l, l + 1))
    for column, _ in enumerate(harmonics):
        term = gtilde(column)
        res[:, column] = mp.matrix([term.get(ijk, 0.0) for ijk in monomials])
    return res
def _A2(lmax):
    """Compute the A2 matrix directly.

    A2_inv is much faster to compute, but inverting it with mpmath is very
    slow (~4 min for lmax=20). This function instead solves, for every index
    ``k``, the linear system that expresses ``basis.ptilde(k)`` in terms of the
    ``basis.gtilde`` terms.

    For each ``k`` the system matrix has one row per power triple and one
    column per ``gtilde`` term involved. The tricky part is finding every term
    (and hence every power) that has to take part in the system; this is done
    recursively by the nested ``all_subsets_indices`` helper.
    """
    n = (lmax + 1) ** 2
    A2 = mp.zeros(n, n)
    gs = [basis.gtilde(index) for index in range(n)]

    def all_subsets_indices(s, sets, indices=None):
        # Collect, depth first, the positions in `sets` of every entry that
        # contains an element of `s`, then recurse on the elements of each
        # newly found entry.
        if indices is None:
            indices = []
        for element in s:
            for position, candidate in enumerate(sets):
                if position in indices or element not in candidate:
                    continue
                indices.append(position)
                all_subsets_indices(candidate, sets, indices)
        return indices

    for k in range(n):
        # Only the gtilde terms up to the degree of index k are candidates.
        n_available = int((np.floor(np.sqrt(k)) + 1) ** 2)
        available_sets = gs[0:n_available]
        target = basis.ptilde(k)
        indices = np.array(all_subsets_indices([target], available_sets))
        powers = list(reduce(set.union, [set(gs[index]) for index in indices]))

        system = mp.zeros(len(powers), len(indices))
        for row, ijk in enumerate(powers):
            for col, index in enumerate(indices):
                system[row, col] = gs[index].get(ijk, 0)

        # Right-hand side: 1 on the row of the target power, 0 elsewhere.
        rhs = np.array([target == power for power in powers]).astype(int)
        # in multi-precision
        solution = mp.lu_solve(system, utils.to_mp(rhs))

        for index, value in zip(indices, solution):
            A2[k, index] += value

    return A2.T
[docs]
def A2(lmax, cache=None):
    """Return the ``A2`` matrix for ``lmax``, computed by ``_A2`` on first use.

    ``cache`` is the store handed to ``get_A``; when omitted the module-level
    ``CACHED_MATRICES`` is used.
    """
    store = CACHED_MATRICES if cache is None else cache
    return get_A(_A2, lmax, cache=store)