jaxoplanet.units.decorator

Module Contents

Functions

quantity_input(func=None, *, _strict=False, **kwargs) → Any

A decorator to wrap functions that require quantities as inputs

jaxoplanet.units.decorator.quantity_input(func: Callable[..., Any] | None = None, *, _strict: bool = False, **kwargs: Any) → Any

A decorator to wrap functions that require quantities as inputs.

Note that this is similar to the astropy.units decorator of the same name, but its behavior differs slightly, in ways that we’ll try to highlight here.

This decorator wraps a function and checks or converts the units of all its inputs, so that the wrapped function can assume that all input units are correct. Note that the units for all arguments must be specified by name in the decorator (even for positional arguments), and variadic *args and **kwargs arguments are not supported.
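The name-based matching can be illustrated with Python's standard `inspect` module. The sketch below is hypothetical (the `by_name` decorator and `ratio` function are illustrations, not part of jaxoplanet, and plain conversion functions stand in for units): it binds each call's arguments to parameter names, so positional and keyword calls both reach the same per-name rule. It also rejects variadic signatures when decorating, since `*args` and `**kwargs` have no fixed parameter names to attach a rule to, which is one plausible reason they are unsupported.

```python
import functools
import inspect
from typing import Any, Callable


def by_name(**rules: Callable[[Any], Any]) -> Callable:
    """Hypothetical decorator: apply rules[name] to the argument called `name`."""

    def decorator(func: Callable) -> Callable:
        sig = inspect.signature(func)
        # Refuse *args / **kwargs: they have no fixed parameter names.
        for param in sig.parameters.values():
            if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                raise TypeError(f"variadic parameter {param.name!r} is not supported")
        # Every rule must name a real parameter of func.
        unknown = set(rules) - set(sig.parameters)
        if unknown:
            raise TypeError(f"unknown argument(s): {sorted(unknown)}")

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Map positional and keyword arguments onto parameter names.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, rule in rules.items():
                bound.arguments[name] = rule(bound.arguments[name])
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@by_name(a=float, b=float)
def ratio(a, b):
    return a / b


print(ratio("3", b="2"))  # 1.5 (positional "3" and keyword "2" both converted)
```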

By default, a non-Quantity input is assumed to already be in the correct units and is converted to a Quantity. If the _strict flag is set to True instead, passing a non-Quantity raises a ValueError.
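The default and strict branches can be sketched with a toy unit system. Everything here (the `Quantity` class, the `SCALE` table, and the `coerce` helper) is a hypothetical stand-in for a real unit registry; only the branching logic follows the description above.

```python
from dataclasses import dataclass

# Hypothetical unit table: unit name -> (dimension, factor to the base unit).
SCALE = {
    "m": ("length", 1.0),
    "km": ("length", 1000.0),
    "s": ("time", 1.0),
    "day": ("time", 86400.0),
}


@dataclass(frozen=True)
class Quantity:
    magnitude: float
    unit: str

    def to(self, unit: str) -> "Quantity":
        dim_from, f_from = SCALE[self.unit]
        dim_to, f_to = SCALE[unit]
        if dim_from != dim_to:
            raise ValueError(f"cannot convert {self.unit} to {unit}")
        return Quantity(self.magnitude * f_from / f_to, unit)


def coerce(value, unit: str, strict: bool = False) -> Quantity:
    if isinstance(value, Quantity):
        return value.to(unit)  # convert, or fail on a dimension mismatch
    if strict:
        raise ValueError(f"expected a Quantity in {unit}, got {value!r}")
    return Quantity(value, unit)  # assume plain numbers are already in `unit`


print(coerce(2.5, "m"))                  # Quantity(magnitude=2.5, unit='m')
print(coerce(Quantity(1.5, "km"), "m"))  # Quantity(magnitude=1500.0, unit='m')
```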

Examples

The following function expects a length in meters and a time in seconds, and it returns a speed in meters per second:

from jaxoplanet import units
from jaxoplanet.units import unit_registry as ureg

@units.quantity_input(a=ureg.m, b=ureg.s)
def speed(a, b):
    return a / b

If we call this function with a length and a time, it will work as expected:

speed(1.5 * ureg.m, 0.5 * ureg.s)

And it will also handle unit conversions:

speed(1.5 * ureg.AU, 0.5 * ureg.day)  # The result will still be in m/s
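As a hand check of that last conversion, the expected magnitude can be computed directly. This assumes the registry uses the IAU 2012 definition of the astronomical unit (149 597 870 700 m) and a day of 86 400 s; the plain-Python arithmetic below is only a sanity check, not a call into jaxoplanet.

```python
AU_IN_M = 149_597_870_700  # IAU 2012 definition (assumed to match the registry)
DAY_IN_S = 86_400

distance_m = 1.5 * AU_IN_M  # 1.5 AU in meters
time_s = 0.5 * DAY_IN_S     # 0.5 day in seconds
speed_m_per_s = distance_m / time_s
print(f"{speed_m_per_s:.4e} m/s")  # 5.1944e+06 m/s
```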

To skip validating specific inputs, you can set the unit to None, or omit it from the decorator arguments:

@units.quantity_input(x=ureg.m)  # optionally include flag=None
def condition(x, flag):
    if flag:
        return x
    else:
        return 0.0 * x
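One way to picture why the two ways of skipping an argument are equivalent: in the hypothetical helper below (not jaxoplanet's code), an explicit `None` rule and a missing rule both fall through to "leave the value untouched", while any other rule is applied. Plain conversion functions stand in for units.

```python
def apply_rules(rules: dict, arguments: dict) -> dict:
    """Hypothetical helper: transform named arguments according to `rules`."""
    out = {}
    for name, value in arguments.items():
        rule = rules.get(name)  # a missing key and an explicit None look the same
        out[name] = value if rule is None else rule(value)
    return out


args = {"x": "2.0", "flag": 1}
omitted = apply_rules({"x": float}, args)
explicit = apply_rules({"x": float, "flag": None}, args)
print(omitted == explicit == {"x": 2.0, "flag": 1})  # True
```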

JAX Pytree support

This decorator also supports JAX Pytrees, so you can use it to wrap functions with structured inputs. For example, we could rewrite the speed example from above as:

@units.quantity_input(params={"distance": ureg.m, "time": ureg.s})
def speed(params):
    return params["distance"] / params["time"]

This will work with arbitrary Pytrees, as long as the structure of the input exactly matches the decorator argument. In other words, since the full Pytree structure must be specified, you’ll need to explicitly list None for any Pytree nodes that you want to skip during validation:

# Omitting `flag` from the decorator wouldn't work here
@units.quantity_input(params={"x": ureg.m, "flag": None})
def condition(params):
    if params["flag"]:
        return params["x"]
    else:
        return 0.0 * params["x"]
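The structure-matching rule can be sketched in plain Python. The hypothetical `convert_tree` below (not jaxoplanet's implementation) walks the decorator's spec and the input side by side, handling only dicts, lists, and tuples, and uses `(magnitude, unit)` pairs with a toy `TO_SI` table in place of real quantities. A `None` leaf in the spec passes its input through unchanged, while any mismatch in dict keys raises, which shows why `flag` has to be listed explicitly.

```python
# Hypothetical toy scales: unit -> factor to the SI base unit.
TO_SI = {"m": 1.0, "km": 1000.0, "s": 1.0, "min": 60.0}


def convert_tree(spec, value):
    """Walk `spec` and `value` together; convert (magnitude, unit) leaves."""
    if isinstance(spec, dict):
        if set(spec) != set(value):
            raise ValueError(f"keys differ: {sorted(spec)} vs {sorted(value)}")
        return {k: convert_tree(spec[k], value[k]) for k in spec}
    if isinstance(spec, (list, tuple)):
        if len(spec) != len(value):
            raise ValueError("sequence lengths differ")
        return type(spec)(convert_tree(s, v) for s, v in zip(spec, value))
    if spec is None:
        return value  # skip validation for this leaf
    magnitude, unit = value
    return magnitude * TO_SI[unit] / TO_SI[spec], spec


speed_params = {"distance": (3.0, "km"), "time": (2.0, "min")}
print(convert_tree({"distance": "m", "time": "s"}, speed_params))
# {'distance': (3000.0, 'm'), 'time': (120.0, 's')}

params = {"x": (2.0, "km"), "flag": True}
print(convert_tree({"x": "m", "flag": None}, params))
# {'x': (2000.0, 'm'), 'flag': True}

try:
    convert_tree({"x": "m"}, params)  # `flag` omitted from the spec
except ValueError as err:
    print(err)  # keys differ: ['x'] vs ['flag', 'x']
```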