Source code for jaxoplanet.starry.surface_is_physical

from functools import partial

import jax
import jax.numpy as jnp

from jaxoplanet.starry.surface import Surface


# -------- Grid (equal-area-ish Fibonacci; tiny + deterministic) --------
[docs] def fibonacci_grid(oversample: int, lmax: int): n = oversample * (lmax**2) i = jnp.arange(n) phi = (1.0 + jnp.sqrt(5.0)) / 2.0 theta = jnp.arccos(1.0 - 2.0 * (i + 0.5) / n) # colatitude in [0, pi] lon = (2.0 * jnp.pi) * ((i / phi) % 1.0) # [0, 2pi) lat = (jnp.pi / 2.0) - theta # latitude in [-pi/2, pi/2] return jnp.stack([lat, lon], axis=-1) # (n, 2)
# -------- Helpers to keep angles in-range (differentiable) ------------- def _wrap_lon(lon): # Wrap to [0, 2pi) two_pi = 2.0 * jnp.pi return lon - two_pi * jnp.floor(lon / two_pi) def _clip_lat(lat, eps=1e-6): # Clip to (-pi/2, pi/2) by tiny epsilon to avoid singularities return jnp.clip(lat, -jnp.pi / 2 + eps, jnp.pi / 2 - eps) def _project_coord(x): lat = _clip_lat(x[0]) lon = _wrap_lon(x[1]) return jnp.array([lat, lon]) # -------- Fixed-iteration damped Newton in 2D (lat, lon) -------------- @partial( jax.jit, static_argnames=("oversample", "lmax", "newton_iters", "damping", "step") )
[docs] def surface_min_intensity( surface: Surface, oversample: int, lmax: int, newton_iters: int = 12, damping: float = 1e-3, step: float = 1.0, ): """ Fully JAX, end-to-end differentiable approximate global min: 1) seed from tiny equal-area grid 2) run fixed M Newton steps from each seed in parallel 3) take global min across seeds Args: surface: jaxoplanet Surface object with .intensity(lat, lon) oversample, lmax: define N = oversample * lmax^2 seeds newton_iters: fixed Newton iterations per seed (no line search) damping: Levenberg-Marquardt diagonal added to Hessian step: Newton step scaling (e.g., 1.0 or 0.5) Returns: (lat_min, lon_min), min_val """ # Scalar objective taking a 2-vector x = [lat, lon] def f(x): lat, lon = x return surface.intensity(lat, lon) # Grad & Hessian w.r.t. x grad_f = jax.grad(f) hess_f = jax.hessian(f) def newton_one_seed(x0): def body(_, x): g = grad_f(x) # (2,) H = hess_f(x) # (2,2) H_damped = H + damping * jnp.eye(2, dtype=x.dtype) # Solve H p = g (Newton step uses p = H^{-1} g) p = jnp.linalg.solve(H_damped, g) x_new = x - step * p # keep angles valid return _project_coord(x_new) xT = jax.lax.fori_loop(0, newton_iters, body, _project_coord(x0)) return xT, f(xT) # Seeds seeds = fibonacci_grid(oversample, lmax) # (N,2) # Run Newton from all seeds in parallel, no Python loops, fixed iteration count xs, vals = jax.vmap(newton_one_seed)(seeds) # xs: (N,2), vals: (N,) # Take global min i_min = jnp.argmin(vals) x_min = xs[i_min] v_min = vals[i_min] return x_min, v_min