\n",
"ⓘ The **Equinox library** provides the functionality to register custom classes as PyTrees by inheritance from

\n",
"\n",
"Let's take a short stroll through the woods:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"example_trees = [{\"a\": 0, \"b\": (1, 2)}, jnp.array([3, 4, 5]), [\"a\", object()]]\n",
"for pytree in example_trees:\n",
" leaves = jax.tree_util.tree_leaves(pytree)\n",
" print(f\"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}\")"
]
},
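{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of the Equinox note above, a custom class inheriting from `equinox.Module` automatically becomes a PyTree (a sketch, assuming the `equinox` package is installed; the class name `Line` is purely illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import equinox as eqx\n",
"\n",
"\n",
"class Line(eqx.Module):\n",
"    slope: jax.Array\n",
"    offset: jax.Array\n",
"\n",
"\n",
"line = Line(slope=jnp.array(2.0), offset=jnp.array(10.0))\n",
"# The module's fields are the leaves of the PyTree:\n",
"print(jax.tree_util.tree_leaves(line))"
]
},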
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example: Model parameters\n",
"\n",
"In the following, we will optimise a linear model using gradient descent to illustrate how PyTrees can be used.\n",
"\n",
"We will start by generating some data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"N = 50\n",
"x = jnp.linspace(-1, 1, N)\n",
"noise = 0.3 * jax.random.normal(key=jax.random.PRNGKey(0), shape=(N,))\n",
"y = 2 * x + 10 + noise\n",
"\n",
"data = {\"x\": x, \"y\": y}\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 4))\n",
"ax.scatter(data[\"x\"], data[\"y\"], marker=\".\", color=\"black\")\n",
"ax.set(xlabel=\"x\", ylabel=\"y\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we define the linear model, the loss function, and a function that performs gradient descent:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def linear_model(params: dict, x: jax.Array) -> jax.Array:\n",
" return params[\"slope\"] * x + params[\"offset\"]\n",
"\n",
"\n",
"def mse_loss(params: dict, data: dict) -> jax.Array:\n",
" prediction = linear_model(params, data[\"x\"])\n",
" return jnp.mean((data[\"y\"] - prediction) ** 2)\n",
"\n",
"\n",
"@jax.jit\n",
"def gradient_descent_step(params: dict, data: dict, lr: float = 0.1) -> dict:\n",
" grads = jax.grad(mse_loss, argnums=0)(params, data)\n",
" return {key: params[key] - lr * grads[key] for key in params}\n",
"\n",
"\n",
"def gradient_descent(\n",
" initial_params: dict, data: dict, lr: float = 0.1, num_steps: int = 50\n",
") -> dict:\n",
" \"\"\"Performs multiple steps of gradient descent.\"\"\"\n",
" params = initial_params.copy()\n",
"\n",
" for _ in range(num_steps):\n",
" params = gradient_descent_step(params, data, lr)\n",
"\n",
" return params"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that we pass the model parameters, `params`, as a dictionary, but retain the ability to compute gradients with respect to this argument.\n",
"\n",
"Now, we can perform 50 steps of gradient descent and inspect the result:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"initial_params = {\"slope\": jnp.array([1.0]), \"offset\": jnp.array([0.0])}\n",
"\n",
"params = gradient_descent(initial_params, data, lr=0.1, num_steps=50)\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 4))\n",
"ax.scatter(data[\"x\"], data[\"y\"], marker=\".\", color=\"black\", label=\"Data\")\n",
"ax.plot(\n",
" data[\"x\"],\n",
" linear_model(params, data[\"x\"]),\n",
" ls=\"dashed\",\n",
" color=\"orange\",\n",
" label=\"Optimised model\",\n",
")\n",
"ax.set(xlabel=\"x\", ylabel=\"y\")\n",
"ax.legend(loc=\"upper left\", frameon=False)\n",
"\n",
"print(f\"slope: {params['slope'][0]:<.2f}, offset: {params['offset'][0]:<.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although a model with only two scalar parameters does not necessitate the usage of a dictionary,\n",
"it is not hard to imagine a more complex model with many parameters of various shapes where it\n",
"is very convenient to apply transformations across the entire structure while maintaining the nested hierarchy.\n"
]
},
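{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, the per-key update in `gradient_descent_step` can equivalently be written with `jax.tree_util.tree_map`, which applies a function leaf-by-leaf to arbitrarily nested PyTrees (a sketch; the nested structure and stand-in gradients below are purely illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nested_params = {\"linear\": {\"slope\": jnp.array([1.0]), \"offset\": jnp.array([0.0])}}\n",
"# Stand-in gradients with the same structure as the parameters:\n",
"nested_grads = jax.tree_util.tree_map(jnp.ones_like, nested_params)\n",
"# One gradient descent step applied across the whole nested structure:\n",
"updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, nested_params, nested_grads)\n",
"print(updated)"
]
},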
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sources and Further Resources\n",
"\n",
"The above tutorial is largely based on [JAX's comprehensive documentation](https://jax.readthedocs.io/en/latest/index.html). For those eager to delve deeper into the intricacies of JAX, we highly recommend exploring this further, as it offers detailed explanations, examples, and advanced techniques to enrich your understanding and mastery of JAX's capabilities.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.14"
}
},
"nbformat": 4,
"nbformat_minor": 4
}