Source code for jaxoplanet.starry.orbit
from collections.abc import Callable, Iterable, Sequence
from typing import Any
from jaxoplanet.object_stack import ObjectStack
from jaxoplanet.orbits.keplerian import Body, Central, OrbitalBody, System
from jaxoplanet.starry.surface import Surface
[docs]
class SurfaceBody(Body):
    """A ``Body`` that can additionally carry a ``Surface``."""

    # Optional surface attached to this body; ``None`` means no surface was given.
    surface: Surface | None = None
[docs]
class SurfaceSystem(System):
    """A ``System`` that tracks a surface for the central and for each body.

    Bodies and their surfaces are kept in two parallel ``ObjectStack``
    containers that are filled in lockstep by ``__init__``, so the i-th
    surface belongs to the i-th body.
    """

    # Surface of the central object; ``__init__`` substitutes ``Surface()``
    # when the caller passes ``None``.
    central_surface: Surface | None
    # Per-body surfaces, in the same order as the bodies. Individual entries
    # may be ``None`` when no surface was supplied for that body.
    _body_surface_stack: ObjectStack[Surface]

    def __init__(
        self,
        central: Central | None = None,
        central_surface: Surface | None = None,
        *,
        bodies: Iterable[tuple[Body | OrbitalBody | SurfaceBody, Surface | None]] = (),
    ):
        """Build the system from a central object and ``(body, surface)`` pairs.

        Args:
            central: The central object; a default ``Central()`` is created
                when ``None``.
            central_surface: Surface of the central object; a default
                ``Surface()`` is created when ``None``.
            bodies: Iterable of ``(body, surface)`` pairs. A body that is
                already an ``OrbitalBody`` is stored unchanged together with
                the paired surface exactly as given (including ``None``).
                Any other body is wrapped as ``OrbitalBody(self.central, body)``
                and, when the paired surface is ``None``, the body's own
                ``surface`` attribute is used if it has one.
        """
        self.central = central if central is not None else Central()
        self.central_surface = (
            central_surface if central_surface is not None else Surface()
        )

        wrapped_bodies = []
        surfaces = []
        for body, surface in bodies:
            if isinstance(body, OrbitalBody):
                # Already bound to a central: keep the pair untouched, with
                # no fallback lookup on the body.
                wrapped_bodies.append(body)
                surfaces.append(surface)
                continue

            wrapped_bodies.append(OrbitalBody(self.central, body))
            if surface is None:
                # Fall back to a surface carried by the body itself
                # (e.g. a ``SurfaceBody``); stays ``None`` if it has none.
                surface = getattr(body, "surface", None)
            surfaces.append(surface)

        self._body_stack = ObjectStack(*wrapped_bodies)
        self._body_surface_stack = ObjectStack(*surfaces)

    @property
    def body_surfaces(self) -> tuple[Surface, ...]:
        """The per-body surfaces held by the internal surface stack.

        Entries may be ``None`` for bodies that were added without a surface.
        """
        stack = self._body_surface_stack
        return stack.objects

    def add_body(
        self,
        body: Body | SurfaceBody | None = None,
        surface: Surface | None = None,
        **kwargs: Any,
    ) -> "SurfaceSystem":
        """Return a new ``SurfaceSystem`` with one more body appended.

        This system is not modified; a fresh ``SurfaceSystem`` is constructed
        from the existing ``(body, surface)`` pairs plus the new one.

        Args:
            body: The body to add. When ``None``, one is created as
                ``Body(**kwargs)``.
            surface: Surface for the new body. When ``None``, the body's own
                ``surface`` attribute is used if it has one.
            **kwargs: Forwarded to ``Body`` only when ``body`` is ``None``;
                otherwise they are not used.

        Returns:
            A new ``SurfaceSystem`` sharing this system's central and
            central surface.
        """
        new_body = Body(**kwargs) if body is None else body
        new_surface = getattr(new_body, "surface", None) if surface is None else surface

        # Both sequences are filled in lockstep by ``__init__``, so a
        # non-strict zip is not expected to drop entries here.
        # NOTE(review): assumes ``self.bodies`` mirrors ``_body_stack`` - confirm.
        pairs = [
            *zip(self.bodies, self.body_surfaces, strict=False),
            (new_body, new_surface),
        ]
        return SurfaceSystem(
            central=self.central,
            central_surface=self.central_surface,
            bodies=pairs,
        )

    def surface_vmap(
        self,
        func: Callable,
        in_axes: int | None | Sequence[Any] = 0,
        out_axes: Any = 0,
    ) -> Callable:
        """Wrap ``func`` using the ``vmap`` method of the body-surface stack.

        Args:
            func: The callable to wrap.
            in_axes: Passed through unchanged as ``in_axes``.
            out_axes: Passed through unchanged as ``out_axes``.

        Returns:
            Whatever ``ObjectStack.vmap`` returns for these arguments.
        """
        surface_stack = self._body_surface_stack
        return surface_stack.vmap(func, in_axes=in_axes, out_axes=out_axes)