Source code for jaxoplanet.units.decorator

__all__ = ["quantity_input"]

import inspect
from collections.abc import Callable
from functools import partial, wraps
from typing import Any

import jax
from pint import DimensionalityError

from jaxoplanet.units.registry import unit_registry


[docs] def quantity_input( func: Callable[..., Any] | None = None, *, _strict: bool = False, **kwargs: Any, ) -> Any: """A decorator to wrap functions that require quantities as inputs Please note, this is similar to the decorator of the same name from ``astropy.units``, but the behavior is slightly different, in ways that we'll try to highlight here. This decorator will wrap a function and check or convert the units of all inputs, such that the wrapped function can assume that all input units are correct. Note that all arguments must be specified by name (even when they) are positional, and variable ``*args`` and ``**kwargs`` arguments are not supported. By default, if a non-``Quantity`` is provided, it will be assumed to be in the correct units, and converted to a ``Quantity``. If the ``_strict`` flag is instead set to ``True``, inputting a non-``Quantity`` will raise a ``ValueError``. **Examples** The following function expects a length in meters and a time in seconds, and it returns a speed in meters per second: .. code-block:: python from jaxoplanet.units import unit_registry as ureg @units.quantity_input(a=ureg.m, b=ureg.s) def speed(a, b): return a / b If we call this function with a length and a time, it will work as expected: .. code-block:: python speed(1.5 * ureg.m, 0.5 * ureg.s) And it will also handle unit conversions: .. code-block:: python speed(1.5 * ureg.AU, 0.5 * ureg.day) # The result will still be in m/s To skip validating specific inputs, you can set the unit to ``None``, or omit it from the decorator arguments: .. code-block:: python @units.quantity_input(x=ureg.m) # optionally include flag=None def condition(x, flag): if flag: return x else: return 0.0 * x **JAX Pytree support** This decorator also supports JAX Pytrees, so you can use it to wrap functions with structured inputs. For example, we could rewrite the ``speed`` example from above as: .. code-block:: python @units.quantity_input(params={"distance": ureg.m, "time": ureg.s}) def speed(params): return params["distance"] / params["time"] This will work with arbitrary Pytrees, as long as structure of the input fully matches the decorator argument. In other words, since the full Pytree structure must be specified, you'll need to explicitly list ``None`` for any Pytree nodes that you want to skip during validation: .. code-block:: python # Omitting `flag` from the decorator wouldn't work here @units.quantity_input(params={"x": ureg.m, "flag": None}) def condition(params): if params["flag"]: return params["x"] else: return 0.0 * params["x"] """ if func is None: return QuantityInput(_strict=_strict, **kwargs) else: if not callable(func): raise TypeError( "The first argument to 'quantity_input' must be a callable function, " "and all unit specifications must be passed as keyword arguments " "by name" ) return QuantityInput(_strict=_strict, **kwargs)(func)
class QuantityInput: """This helper class implements the logic for ``quantity_input`` Typically users should expect to interact primarily with the ``quantity_input`` decorator directly instead of this class, but this enables the use of ``quantity_input`` either as a decorator with arguments or as a single function call. """ def __init__( self, *, _strict: bool = False, **kwargs: Any, ): self.decorator_kwargs = kwargs self.strict = _strict def __call__(self, func: Callable) -> Callable: signature = inspect.signature(func) bound_input_units = signature.bind_partial(**self.decorator_kwargs) input_unit_map = {} for param in signature.parameters.values(): if param.kind in ( inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL, ): # We typically ignore *args and **kwargs, but if they were # explicitly specified, we raise an error since we don't support # that use case if param.name in bound_input_units.arguments: raise TypeError( "Units for general variable arguments and keyword " "arguments are not supported" ) else: input_unit_map[param.name] = bound_input_units.arguments.get( param.name, None ) @wraps(func) def wrapped(*args: Any, **kwargs: Any) -> Any: bound_args = signature.bind(*args, **kwargs) bound_args.apply_defaults() for name, value in bound_args.arguments.items(): # If the value and the default value is None, pass through # ignoring units if value is None and signature.parameters[name].default is None: continue unit = input_unit_map.get(name, None) if unit is not None: bound_args.arguments[name] = jax.tree_util.tree_map( partial(_apply_units, name=name, strict=self.strict), value, unit, is_leaf=_is_quantity, ) return func(*bound_args.args, **bound_args.kwargs) return wrapped def _apply_units( value: Any, units: Any, strict: bool = False, name: str | None = None ) -> Any: if units is None: return value if _is_quantity(value): try: return value.to(units) except DimensionalityError as e: raise DimensionalityError( e.units1, e.units2, e.dim1, e.dim2, "" if name is None else f" for input '{name}'", ) from None elif strict: raise ValueError("Arguments must be quantities for strict parsing") else: return unit_registry.Quantity(value, units) def _is_quantity(x: Any) -> bool: return hasattr(x, "_magnitude") and hasattr(x, "_units")