Source code for jaxoplanet.object_stack

__all__ = ["ObjectStack"]

from collections.abc import Callable, Sequence
from functools import wraps
from typing import Any, Generic, TypeVar

import equinox as eqx
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten

try:
    from jax.extend import linear_util as lu
except ImportError:
    from jax import linear_util as lu  # type: ignore

from jax._src.interpreters.batching import is_vmappable

# Type variable for the kind of object held by an ``ObjectStack``; it ties the
# element type of ``objects`` to the type of the stacked representation.
Obj = TypeVar("Obj")


class ObjectStack(eqx.Module, Generic[Obj]):
    """A stack of objects supporting vmapping even with different Pytree structure

    By default, functions can only be vmapped over a set of JAX objects when their
    Pytree structure matches, but this object generalizes that behavior to support
    a consistent interface that uses ``vmap`` whenever possible, falling back on a
    Python loop for variable Pytree structure.

    Args:
        objects: A set of Pytree objects
    """

    # The objects exactly as passed to the constructor, in order.
    objects: tuple[Obj, ...]
    # Leaf-wise stacked copy of ``objects`` (new leading axis), or ``None`` when
    # there are no objects or their Pytree structures differ.
    stack: Obj | None

    def __init__(self, *objects: Obj):
        self.objects = objects

        # If all the objects have matching Pytree structure then we save a
        # stacked version that we can use for vmaps below. This allows for more
        # efficient evaluations in the case of multiple objects.
        self.stack = None
        if self.objects:
            specs = [jax.tree_util.tree_structure(obj) for obj in self.objects]
            if all(spec == specs[0] for spec in specs):
                self.stack = jax.tree_util.tree_map(
                    lambda *x: jnp.stack(x, axis=0), *self.objects
                )

    def __len__(self) -> int:
        """Return the number of objects in this stack."""
        return len(self.objects)

    def vmap(
        self,
        func: Callable,
        in_axes: int | None | Sequence[Any] = 0,
        out_axes: Any = 0,
    ) -> Callable:
        """Map a function over the objects in this stack

        If possible, this method will apply the appropriate ``jax.vmap`` to the
        input function, but if the Pytree structure of the objects don't match,
        this requires a loop over objects, applying the function separately to
        each object, and stacking the results.

        Args:
            func: The function to map. Its first positional argument must accept
                an object of the type ``Obj``.
            in_axes: The input axis specifications for all arguments after the
                first. The semantics should match ``jax.vmap``.
            out_axes: The output axis specifications, matching ``jax.vmap``.

        Returns:
            The vectorized version of ``func`` mapped over objects in this stack.

        Raises:
            ValueError: When the returned function is called on an empty stack,
                or when ``func`` returns a different Pytree structure for
                different objects in the loop fallback.
        """

        @wraps(func)
        def impl(*args):
            # First, normalize the "in_axes" argument so we always have an iterable
            if isinstance(in_axes, Sequence):
                in_axes_ = tuple(in_axes)
            else:
                in_axes_ = tuple(in_axes for _ in args)

            # If we have a "body_stack" we can just vmap and be done
            if self.stack is not None:
                return jax.vmap(func, in_axes=(0,) + in_axes_, out_axes=out_axes)(
                    self.stack, *args
                )

            # With no objects there is nothing to evaluate and therefore no
            # output structure to stack; fail early with a clear message rather
            # than passing ``None`` as the output tree further down.
            if not self.objects:
                raise ValueError("Cannot vmap a function over an empty ObjectStack")

            # Otherwise we need to loop over the bodies and apply the function once
            # for each body

            # Here we flatten the input arguments and `in_axes` so that we don't
            # have to deal with Pytree logic for the `in_axes` ourselves below.
            args_flat, in_tree = tree_flatten(args, is_leaf=is_vmappable)
            in_axes_flat = jax.api_util.flatten_axes(  # type: ignore
                "body_vmap in_axes", in_tree, in_axes_
            )

            # Then loop over the bodies and accumulate the function results
            results = []
            out_tree = None
            for n, body in enumerate(self.objects):
                debug_info = jax.api_util.debug_info(
                    "ObjectStack.vmap", func, (body,) + tuple(args), {}
                )
                f = lu.wrap_init(func, debug_info=debug_info)
                f, out_tree_ = flatten_func_for_object_vmap(f, in_tree, in_axes_flat, n)
                results.append(f.call_wrapped(body, *args_flat))  # type: ignore
                # The auxiliary output is only available once the wrapped
                # function has been called.
                out_tree_ = out_tree_()  # type: ignore
                if out_tree is not None and out_tree_ != out_tree:
                    raise ValueError(
                        "Input function does not return consistent Pytree structure;\n"
                        f"expected: {out_tree}\n"
                        f"found: {out_tree_}\n"
                    )
                out_tree = out_tree_

            out_axes_flat = jax.api_util.flatten_axes(  # type: ignore
                "body_vmap out_axes", out_tree, out_axes
            )
            # For every output leaf, gather that leaf from each object's result and
            # stack along the requested axis; an axis of ``None`` means the leaf is
            # not mapped, so the first object's value is used as-is.
            return out_tree.unflatten(  # type: ignore
                parts[0] if a is None else jnp.stack(parts, axis=a)
                for a, *parts in zip(out_axes_flat, *results, strict=False)  # type: ignore
            )

        return impl
def index_helper(n, arg, axis):
    """Select entry ``n`` of ``arg`` along ``axis``, dropping that axis.

    Args:
        n: The index to select along ``axis``.
        arg: The value to index. It is returned untouched when ``axis`` is
            ``None``; otherwise it must support tuple indexing.
        axis: The axis to index along, or ``None`` if ``arg`` is not mapped.
            Negative values count from the last axis.

    Returns:
        ``arg`` itself when ``axis`` is ``None``, otherwise ``arg`` indexed at
        position ``n`` along ``axis``.
    """
    if axis is None:
        return arg
    if axis < 0:
        # ``(slice(None),) * axis`` is empty for a negative axis, which would
        # silently index axis 0. Anchor the index from the end instead:
        # axis=-1 -> [..., n], axis=-2 -> [..., n, :], and so on.
        idx = (Ellipsis, n) + (slice(None),) * (-axis - 1)
    else:
        idx = (slice(None),) * axis + (n,)
    return arg[idx]


@lu.transformation_with_aux  # type: ignore
def flatten_func_for_object_vmap(in_tree, in_axes_flat, index, body, *args_flat):
    """Adapt a wrapped function to take flat arguments for one object of a stack.

    This is a generator-style ``linear_util`` transformation. It slices each
    flat argument at ``index`` along its corresponding axis, rebuilds the
    original positional arguments, calls the wrapped function with ``body``
    prepended, and flattens the result.

    Args:
        in_tree: Tree structure used to rebuild the positional arguments from
            the indexed flat arguments.
        in_axes_flat: One axis (or ``None``) per entry of ``args_flat``.
        index: Position to select along each mapped axis.
        body: The object passed as the first positional argument.
        *args_flat: The flattened remaining arguments.

    Yields:
        First the ``(args, kwargs)`` pair for the wrapped function, then the
        ``(leaves, treedef)`` pair produced by flattening its result.
    """
    args_indexed = (
        index_helper(index, *args)
        for args in zip(args_flat, in_axes_flat, strict=False)
    )
    # The first yield hands over the call arguments and receives the result back
    ans = yield (body,) + in_tree.unflatten(args_indexed), {}
    # The second yield provides the flat outputs together with their tree
    # structure, which the caller retrieves as the auxiliary value
    yield tree_flatten(ans, is_leaf=is_vmappable)