jaxoplanet.object_stack

Module Contents

Classes

ObjectStack

A stack of objects supporting vmapping even with different Pytree structure

class jaxoplanet.object_stack.ObjectStack(*objects: Obj)

Bases: equinox.Module, Generic[Obj]

A stack of objects supporting vmapping even with different Pytree structure

By default, jax.vmap can only map a function over a set of JAX objects when their Pytree structures match. This class generalizes that behavior behind a consistent interface: it uses jax.vmap whenever possible and falls back on a Python loop when the Pytree structures differ.
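The constraint that jax.vmap imposes can be seen by trying to stack objects leaf by leaf. The sketch below uses a hypothetical helper, `try_stack`, which is not part of jaxoplanet: it compares Pytree structures with `jax.tree_util.tree_structure` and stacks the leaves only when every structure matches, returning `None` otherwise. As an assumption, this mirrors what the `stack` attribute (typed `Obj | None`) appears to hold.

```python
import jax
import jax.numpy as jnp


def try_stack(*objects):
    """Hypothetical helper: stack Pytrees leaf-wise if their structures match.

    Returns a single Pytree whose leaves gain a new leading axis, or None
    when the objects do not all share one Pytree structure.
    """
    structures = [jax.tree_util.tree_structure(obj) for obj in objects]
    if not all(s == structures[0] for s in structures):
        return None
    # Stack corresponding leaves along a new leading axis of size len(objects).
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *objects)


# Two dicts with the same keys share a structure and can be stacked...
same = try_stack({"a": jnp.array(1.0)}, {"a": jnp.array(2.0)})
# ...but an extra key changes the structure, so leaf-wise stacking is impossible.
different = try_stack(
    {"a": jnp.array(1.0)},
    {"a": jnp.array(2.0), "b": jnp.array(3.0)},
)
```

Here `same["a"]` holds both values along a leading axis of size 2, while `different` is `None`, which is exactly the case where a Python loop is needed instead of jax.vmap.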

Parameters:

objects – A set of Pytree objects

objects: tuple[Obj, ...]
stack: Obj | None
__len__() → int
vmap(func: collections.abc.Callable, in_axes: int | None | collections.abc.Sequence[Any] = 0, out_axes: Any = 0) → collections.abc.Callable

Map a function over the objects in this stack.

If possible, this method applies the appropriate jax.vmap to the input function. If the Pytree structures of the objects don't match, it instead loops over the objects, applies the function to each one separately, and stacks the results.
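The dispatch described above can be sketched with plain JAX. `map_over_objects` and `evaluate` below are hypothetical and are not jaxoplanet's implementation. The sketch assumes every extra argument is mapped along axis 0 (like `in_axes=0`) and that results are stacked along axis 0 (like `out_axes=0`). When all objects share a Pytree structure it stacks them and calls jax.vmap; otherwise it loops, pairing each object with its slice of the extra arguments, and stacks the per-object results.

```python
import jax
import jax.numpy as jnp


def map_over_objects(func, objects, *args):
    """Hypothetical sketch: compute func(objects[i], *(arg[i] for arg in args)).

    Every extra argument is assumed to be mapped along axis 0, and results
    are stacked along a new axis 0.
    """
    structures = [jax.tree_util.tree_structure(obj) for obj in objects]
    if all(s == structures[0] for s in structures):
        # Matching structures: stack the leaves and let jax.vmap vectorize.
        stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *objects)
        return jax.vmap(func)(stacked, *args)
    # Mismatched structures: fall back on a Python loop over the objects.
    results = [
        func(obj, *(arg[i] for arg in args)) for i, obj in enumerate(objects)
    ]
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *results)


def evaluate(obj, x):
    # Sum every leaf of the object, then scale by x; works for any structure.
    return x * sum(jax.tree_util.tree_leaves(obj))


xs = jnp.array([10.0, 100.0])
matched = map_over_objects(
    evaluate, [{"a": jnp.array(1.0)}, {"a": jnp.array(2.0)}], xs
)
mixed = map_over_objects(
    evaluate, [{"a": jnp.array(1.0)}, {"a": jnp.array(2.0), "b": jnp.array(3.0)}], xs
)
```

Both calls return an array of shape (2,): `matched` goes through jax.vmap, while `mixed` takes the loop path because the second object has an extra leaf. The caller sees the same interface either way, which is the point of the fallback.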

Parameters:
  • func – The function to map. Its first positional argument must accept an object of type Obj.

  • in_axes – The input axis specifications for all arguments after the first. The semantics should match jax.vmap.

  • out_axes – The output axis specifications, matching jax.vmap.

Returns:

The vectorized version of func mapped over the objects in this stack.
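Because in_axes and out_axes are documented as following jax.vmap, plain jax.vmap shows what they control. In this illustration the object is already stacked (a dict whose leaves have a leading axis of size 2), so its axis is written explicitly as the first 0 in `in_axes`. Going by the parameter description above, ObjectStack.vmap presumably handles the object axis itself and takes specifications only for the remaining arguments. `model`, `t`, and `scale` are hypothetical.

```python
import jax
import jax.numpy as jnp


def model(obj, t, scale):
    # Hypothetical model: a line evaluated at times t, multiplied by scale.
    return scale * (obj["slope"] * t + obj["offset"])


# Two objects with identical structure, stacked along a leading axis of size 2.
stacked = {"slope": jnp.array([1.0, 2.0]), "offset": jnp.array([0.0, 1.0])}
t = jnp.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]])  # one row of times per object
scale = 3.0  # shared by every object

# in_axes=(0, 0, None): map over the objects and the rows of t, broadcast scale.
# out_axes=1 places the object axis second, giving shape (3, 2) instead of (2, 3).
mapped = jax.vmap(model, in_axes=(0, 0, None), out_axes=1)
result = mapped(stacked, t, scale)
```

With the default out_axes=0 the result would instead have shape (2, 3), with one row per object.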