jaxoplanet.object_stack
Module Contents
Classes
ObjectStack: A stack of objects supporting vmapping, even with different Pytree structures
- class jaxoplanet.object_stack.ObjectStack(*objects: Obj)
Bases: equinox.Module, Generic[Obj]
A stack of objects supporting vmapping, even with different Pytree structures.
By default, a function can only be vmapped over a set of JAX objects when their Pytree structures match. This class generalizes that behavior behind a consistent interface: it uses vmap whenever possible and falls back on a Python loop when the Pytree structures differ.
- Parameters:
objects – A set of Pytree objects.
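To illustrate the distinction the class draws, here is a minimal sketch in plain JAX (not jaxoplanet's implementation) that checks whether a list of Pytrees shares one structure using `jax.tree_util.tree_structure` and, if so, stacks them leaf by leaf into a single batched Pytree that `jax.vmap` could map over. The helpers `same_structure` and `stack_objects` and the dict-based "objects" are hypothetical stand-ins. Matching structure is necessary but not sufficient: the leaves must also have compatible shapes for `jnp.stack` to succeed.

```python
import jax
import jax.numpy as jnp


def same_structure(objects):
    """Return True if every object has the same Pytree structure as the first.

    Only the tree structure is compared; leaf shapes are not checked here.
    """
    treedefs = [jax.tree_util.tree_structure(obj) for obj in objects]
    return all(t == treedefs[0] for t in treedefs[1:])


def stack_objects(objects):
    """Stack structurally identical Pytrees leaf by leaf along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *objects)


# Two hypothetical "objects" with the same keys, so the same Pytree structure.
a = {"radius": jnp.array(1.0), "period": jnp.array(3.0)}
b = {"radius": jnp.array(2.0), "period": jnp.array(5.0)}
# A third object with an extra key, so its Pytree structure differs.
c = {"radius": jnp.array(0.5), "period": jnp.array(7.0), "mass": jnp.array(0.1)}

print(same_structure([a, b]))     # True: can be stacked and vmapped
print(same_structure([a, b, c]))  # False: a Python loop is needed instead

# Batched Pytree: each leaf now has a leading axis of length 2.
stacked = stack_objects([a, b])
print(stacked["radius"].shape)    # (2,)
```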
- vmap(func: collections.abc.Callable, in_axes: int | None | collections.abc.Sequence[Any] = 0, out_axes: Any = 0) -> collections.abc.Callable
Map a function over the objects in this stack.
If possible, this method applies the appropriate jax.vmap to the input function. If the Pytree structures of the objects don't match, it instead loops over the objects, applies the function separately to each one, and stacks the results.
- Parameters:
func – The function to map. Its first positional argument must accept an object of type Obj.
in_axes – The input axis specifications for all arguments after the first. The semantics should match jax.vmap.
out_axes – The output axis specifications, matching jax.vmap.
- Returns:
The vectorized version of func, mapped over the objects in this stack.
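The dispatch described for vmap can be sketched as a free function. This is a hypothetical illustration of the technique, not jaxoplanet's code: `stack_vmap` and `scaled_radius_squared` are invented names, and the sketch assumes the extra arguments are arrays, each entry of `in_axes` is an int or None, and `out_axes` is an int. As in the documented method, the objects themselves are always mapped along axis 0, and `in_axes` only covers the arguments after the first.

```python
import jax
import jax.numpy as jnp


def stack_vmap(objects, func, in_axes=0, out_axes=0):
    """Hypothetical sketch of vmap-with-loop-fallback over a list of Pytrees.

    Assumptions: extra arguments are arrays, in_axes entries are int or None,
    and out_axes is an int.
    """
    treedefs = [jax.tree_util.tree_structure(obj) for obj in objects]
    same = all(t == treedefs[0] for t in treedefs[1:])

    def axes_for(args):
        # in_axes applies to the arguments after the object, as in jax.vmap.
        if isinstance(in_axes, (list, tuple)):
            return tuple(in_axes)
        return (in_axes,) * len(args)

    def vectorized(*args):
        axes = axes_for(args)
        if same:
            # Fast path: one batched Pytree, objects mapped along axis 0.
            stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *objects)
            return jax.vmap(func, in_axes=(0, *axes), out_axes=out_axes)(
                stacked, *args
            )
        # Fallback: call func once per object, slicing mapped arguments by hand.
        results = []
        for i, obj in enumerate(objects):
            sliced = [
                arg if ax is None else jnp.take(arg, i, axis=ax)
                for arg, ax in zip(args, axes)
            ]
            results.append(func(obj, *sliced))
        # Stack the per-object outputs leaf by leaf along out_axes.
        return jax.tree_util.tree_map(
            lambda *xs: jnp.stack(xs, axis=out_axes), *results
        )

    return vectorized


def scaled_radius_squared(obj, scale):
    # Only reads "radius", so it works for any object that has that key.
    return scale * obj["radius"] ** 2


a = {"radius": jnp.array(1.0)}
b = {"radius": jnp.array(2.0)}
c = {"radius": jnp.array(3.0), "mass": jnp.array(0.1)}

# Matching structures: uses jax.vmap, with scale shared (in_axes=None).
print(stack_vmap([a, b], scaled_radius_squared, in_axes=None)(2.0))
# Mismatched structures: same interface, but a Python loop runs underneath.
print(stack_vmap([a, b, c], scaled_radius_squared, in_axes=(0,))(jnp.ones(3)))
```

The loop fallback trades speed for generality: each object is traced separately, so it is slower than the vmap path, but the caller sees the same interface and the same stacked output either way.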