jaxoplanet.units.decorator
Module Contents
Functions
jaxoplanet.units.decorator.quantity_input(func: Callable[..., Any] | None = None, *, _strict: bool = False, **kwargs: Any) -> Any
A decorator to wrap functions that require quantities as inputs
Please note that this is similar to the decorator of the same name from astropy.units, but the behavior is slightly different, in ways that we’ll try to highlight here.

This decorator wraps a function and checks or converts the units of all inputs, so that the wrapped function can assume that all input units are correct. Note that all arguments must be specified by name in the decorator (even when they are positional), and variable *args and **kwargs arguments are not supported.

By default, if a non-Quantity is provided, it is assumed to be in the correct units and converted to a Quantity. If the _strict flag is instead set to True, passing a non-Quantity raises a ValueError.

Examples
The following function expects a length in meters and a time in seconds, and it returns a speed in meters per second:
from jaxoplanet import units
from jaxoplanet.units import unit_registry as ureg

@units.quantity_input(a=ureg.m, b=ureg.s)
def speed(a, b):
    return a / b
If we call this function with a length and a time, it will work as expected:
speed(1.5 * ureg.m, 0.5 * ureg.s)
And it will also handle unit conversions:
speed(1.5 * ureg.AU, 0.5 * ureg.day) # The result will still be in m/s
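As a sanity check on the conversion in the last call, the same arithmetic can be done by hand. This sketch assumes the IAU 2012 value of the astronomical unit (exactly 149,597,870,700 m) and an 86,400 s day; the exact definitions used by jaxoplanet's unit registry are not shown on this page.

```python
# Conversion factors assumed for this check (not read from jaxoplanet):
AU_IN_M = 149_597_870_700.0  # IAU 2012 definition of the astronomical unit
DAY_IN_S = 86_400.0  # 24 h * 3600 s

distance_m = 1.5 * AU_IN_M  # 1.5 AU expressed in meters
time_s = 0.5 * DAY_IN_S  # 0.5 day expressed in seconds
speed_m_per_s = distance_m / time_s

print(f"{speed_m_per_s:.6e} m/s")  # 5.194371e+06 m/s
```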
To skip validating specific inputs, you can set the unit to None, or omit it from the decorator arguments:

@units.quantity_input(x=ureg.m)  # optionally include flag=None
def condition(x, flag):
    if flag:
        return x
    else:
        return 0.0 * x
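The checking rules described above (units given by argument name, Quantity inputs converted, bare numbers assumed to already be in the target unit unless _strict=True, None or omitted units skipped, no *args or **kwargs) can be illustrated with a deliberately simplified, dependency-free sketch. Everything in it is hypothetical: Quantity, convert, the tiny unit table and toy_quantity_input are stand-ins invented for illustration, not jaxoplanet's implementation, which relies on a real unit registry (ureg above) and also handles Pytrees.

```python
import functools
import inspect
from dataclasses import dataclass


# Hypothetical stand-in for a real Quantity type (NOT jaxoplanet's or pint's):
# just a magnitude paired with a unit name.
@dataclass(frozen=True)
class Quantity:
    magnitude: float
    unit: str


# Hypothetical conversion table: unit -> (base unit, factor to base unit).
_TO_BASE = {"m": ("m", 1.0), "km": ("m", 1000.0), "s": ("s", 1.0), "min": ("s", 60.0)}


def convert(q, unit):
    """Convert a toy Quantity to `unit`, refusing incompatible dimensions."""
    base_from, factor_from = _TO_BASE[q.unit]
    base_to, factor_to = _TO_BASE[unit]
    if base_from != base_to:
        raise ValueError(f"cannot convert {q.unit} to {unit}")
    return Quantity(q.magnitude * factor_from / factor_to, unit)


def toy_quantity_input(_strict=False, **units):
    """Simplified sketch of the checking rules described in the text."""

    def decorator(func):
        sig = inspect.signature(func)
        # Variable *args / **kwargs are rejected when the function is decorated.
        for param in sig.parameters.values():
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                raise TypeError("*args and **kwargs are not supported")
        unknown = set(units) - set(sig.parameters)
        if unknown:
            raise TypeError(f"unknown argument names: {sorted(unknown)}")

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments onto parameter names.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, unit in units.items():
                if unit is None:
                    continue  # same as omitting the argument: no validation
                value = bound.arguments[name]
                if isinstance(value, Quantity):
                    bound.arguments[name] = convert(value, unit)
                elif _strict:
                    raise ValueError(f"{name!r} must be a Quantity")
                else:
                    # Bare numbers are assumed to already be in `unit`.
                    bound.arguments[name] = Quantity(value, unit)
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@toy_quantity_input(a="m", b="s")
def toy_speed(a, b):
    return Quantity(a.magnitude / b.magnitude, f"{a.unit}/{b.unit}")


@toy_quantity_input(_strict=True, x="m")
def strict_length(x):
    return x


print(toy_speed(Quantity(1.5, "km"), Quantity(1, "min")))  # Quantity(magnitude=25.0, unit='m/s')
print(toy_speed(1.5, 0.5))  # Quantity(magnitude=3.0, unit='m/s')
```

In this sketch, strict_length(2.0) raises a ValueError, while strict_length(Quantity(2.0, "km")) returns Quantity(2000.0, "m").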
JAX Pytree support
This decorator also supports JAX Pytrees, so you can use it to wrap functions with structured inputs. For example, we could rewrite the speed example from above as:

@units.quantity_input(params={"distance": ureg.m, "time": ureg.s})
def speed(params):
    return params["distance"] / params["time"]
This will work with arbitrary Pytrees, as long as the structure of the input fully matches the decorator argument. In other words, since the full Pytree structure must be specified, you’ll need to explicitly list None for any Pytree nodes that you want to skip during validation:

# Omitting `flag` from the decorator wouldn't work here
@units.quantity_input(params={"x": ureg.m, "flag": None})
def condition(params):
    if params["flag"]:
        return params["x"]
    else:
        return 0.0 * params["x"]
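The structure-matching rule can be illustrated with a small, dependency-free sketch. The helper match_units is hypothetical (it is not part of jaxoplanet) and only understands dicts, lists and tuples, whereas the real decorator works with general JAX Pytrees. It pairs each unit with the matching input, skips None nodes, and rejects inputs whose structure differs from the units tree, mirroring why omitting flag from the decorator wouldn't work.

```python
def match_units(units_tree, value_tree, path=()):
    """Hypothetical helper (NOT jaxoplanet API): yield (path, unit, value)
    for every unit leaf, requiring the two trees to have identical structure."""
    if units_tree is None:
        return  # None node: skip validation for this part of the input
    if isinstance(units_tree, dict):
        if not isinstance(value_tree, dict) or set(units_tree) != set(value_tree):
            raise ValueError(f"structure mismatch at {path}")
        for key in units_tree:
            yield from match_units(units_tree[key], value_tree[key], path + (key,))
    elif isinstance(units_tree, (list, tuple)):
        if type(value_tree) is not type(units_tree) or len(value_tree) != len(units_tree):
            raise ValueError(f"structure mismatch at {path}")
        for i, (unit, value) in enumerate(zip(units_tree, value_tree)):
            yield from match_units(unit, value, path + (i,))
    else:
        yield path, units_tree, value_tree  # a unit leaf paired with its input


# Listing None for "flag" keeps the structures identical:
print(list(match_units({"x": "m", "flag": None}, {"x": 2.0, "flag": True})))
# [(('x',), 'm', 2.0)]

# Omitting "flag" from the units tree makes the structures differ:
try:
    list(match_units({"x": "m"}, {"x": 2.0, "flag": True}))
except ValueError as err:
    print(err)  # structure mismatch at ()
```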