Source code for jaxoplanet.light_curves.utils

__all__ = ["vectorize"]

from functools import wraps
from typing import Any

import jax

from jaxoplanet.light_curves.types import LightCurveFunc
from jaxoplanet.types import Array, Scalar


def vectorize(func: LightCurveFunc) -> LightCurveFunc:
    """Vectorize a scalar light curve function to work with array inputs

    Like ``jax.numpy.vectorize``, this automatically wraps a function which
    operates on a scalar to handle array inputs. Unlike that function, this
    handles ``Scalar`` inputs and outputs, but it only broadcasts the first
    input (``time``). All other positional and keyword arguments are forwarded
    to ``func`` unchanged on every evaluation.

    Args:
        func: A function which takes a scalar ``Scalar`` time as the first
            input

    Returns:
        An updated function which can operate on ``Scalar`` times of any
        shape. Inputs without a ``shape`` attribute (plain Python numbers or
        nested sequences of numbers) are converted with
        ``jax.numpy.asarray`` before being mapped over.
    """

    @wraps(func)
    def wrapped(time: Scalar, *args: Any, **kwargs: Any) -> Array | Scalar:
        # Plain Python scalars and sequences have no ``.shape``; promote them
        # to an array so the rank can be inspected below instead of failing
        # with an ``AttributeError``. Anything that already exposes ``.shape``
        # is passed through untouched.
        if not hasattr(time, "shape"):
            time = jax.numpy.asarray(time)

        def inner(time_magnitude: Array) -> Array | Scalar:
            return func(time_magnitude, *args, **kwargs)

        # Each ``jax.vmap`` layer maps over one leading axis, so stacking one
        # layer per dimension means ``func`` always receives a scalar time.
        for _ in time.shape:
            inner = jax.vmap(inner)

        return inner(time)

    return wrapped