
Functions for working with pytrees

import jax
import jax.numpy as jnp
from jax import random

import numpy as np
import collections
LAYER_SIZES = [200*200*3, 2048, 1024, 2]
PARAM_SCALE = 0.01
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return (scale * random.normal(w_key, (n, m)), 
        scale * random.normal(b_key, (n,)))

def init_network_params(sizes, key=random.key(0), scale=0.01):
    keys = random.split(key, len(sizes)-1)
    return [random_layer_params(m, n, k, scale) 
        for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
key = random.key(42)
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)
params
shapes = jax.tree.map(lambda p: p.shape, params)

for i, shape in enumerate(shapes):
    print(i, shape)
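
Each layer is a (weights, biases) tuple, so this three-layer network is a pytree with six leaves:

print(len(jax.tree.leaves(params)))  # 6 leaves: a (w, b) pair per layer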
Point = collections.namedtuple('Point', ['x', 'y'])

example_pytree = [
    {
        'a': [1, 2, 3],
        'b': jnp.array([1, 2, 3]),
        'c': np.array([1, 2, 3]),
    },
    [42, [44, 46], None],
    31337,
    (50, (60, 70)),
    Point(640, 480),
    collections.OrderedDict([('a', 100), ('b', 200)]),
    'some string'
]

jax.tree.leaves(example_pytree)
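
Note that None produces no leaves (JAX treats it as an empty subtree), while a string counts as a single leaf; a quick check:

jax.tree.leaves([None, 'some string'])  # ['some string']

An MNIST training example
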
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
HEIGHT, WIDTH, CHANNELS = 28, 28, 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = 10
LAYER_SIZES = [28 * 28, 512, 10]
PARAM_SCALE = 0.1
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot
def init_network_params(sizes, key=random.key(0), scale=1e-2):
    """Initialize all layers of a fully-connected network with the given sizes."""

    def random_layer_params(m, n, key, scale=1e-2):
        """A helper function to initialize weights and biases for a dense layer."""
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    # One key per layer; pass the scale through to the helper.
    keys = random.split(key, len(sizes) - 1)
    return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
init_params = init_network_params(LAYER_SIZES, random.key(0), scale=PARAM_SCALE)
def predict(params, image):
    """Function for per-example predictions."""
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)
    
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits
batched_predict = vmap(predict, in_axes=(None, 0))
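
A quick sanity check of the batched shapes (a sketch; dummy_batch is just an illustrative name):

dummy_batch = jnp.zeros((32, NUM_PIXELS))
print(batched_predict(init_params, dummy_batch).shape)  # (32, 10)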
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5
def loss(params, images, targets):
    """Categorical cross-entropy loss."""
    logits = batched_predict(params, images)
    # Normalize each example's logits separately, not across the whole batch.
    log_preds = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(targets * log_preds)

@jax.jit
def update(params, x, y, epoch_number):
    # These prints execute only while jit traces the function, not on later calls.
    print(f"Params shapes: {jax.tree.map(lambda p: p.shape, params)}")
    loss_value, grads = value_and_grad(loss)(params, x, y)
    print(f"Grads shapes: {jax.tree.map(lambda g: g.shape, grads)}")
    lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
    return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)], loss_value
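
Since params and grads are pytrees with identical structure, the same SGD step can also be written in a structure-agnostic way; a sketch of the body, reusing the lr computed inside update (new_params is just an illustrative name):

new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)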
x, y = next(iter(train_dataloader))
x = jnp.reshape(x.numpy(), (len(x), NUM_PIXELS))  # works for a final partial batch too
y = one_hot(y.numpy(), NUM_LABELS)
params, loss_value = update(init_params, x, y, 0)
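
Calling update again with inputs of the same shapes and dtypes reuses the cached compiled function, so the print statements don't fire a second time:

params, loss_value = update(params, x, y, 1)  # no prints: jit reuses the traced version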
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)
scaled_params = jax.tree.map(lambda p: 10 * p, params)
some_pytree = [
    [1, 1, 1],
    [
        [10, 10, 10], [20, 20]
    ]
]
jax.tree.map(lambda p: p+1, some_pytree)
leaves, struct = jax.tree.flatten(some_pytree)
print(leaves)
print(struct)
updated_leaves = [x + 1 for x in leaves]
jax.tree.unflatten(struct, updated_leaves)
from jax.flatten_util import ravel_pytree
flat, unflatten_func = ravel_pytree(some_pytree)  # a single flat 1D array, not a list of leaves
print(flat)
print(unflatten_func)
unflatten_func(flat)
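
Because ravel_pytree concatenates every leaf into one 1D array, counting a network's parameters becomes a one-liner (a sketch using the params defined above; flat_params is just an illustrative name):

flat_params, _ = ravel_pytree(params)
print(flat_params.shape)  # total number of scalar parameters in the network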

Reducing a tree

jax.tree.reduce(lambda accumulator, value: accumulator+value, some_pytree, initializer=0)
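
Without an initializer, tree.reduce simply folds the function over the leaves, so other reductions work the same way:

jax.tree.reduce(max, some_pytree)  # the largest leaf, 20 here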

Transposing a pytree

import math
from collections import namedtuple

Point = namedtuple('Point', ['x', 'y'])
points = [
    Point(0.0, 0.0),
    Point(3.0, 0.0),
    Point(0.0, 4.0)
]
def rotate_point(p, theta):
    x = p.x * math.cos(theta) - p.y * math.sin(theta)
    y = p.x * math.sin(theta) + p.y * math.cos(theta)
    return Point(x, y)
rotate_point(points[1], math.pi)
# This fails: the leaves of `points` are Python scalars with no axis 0 to map over.
jax.vmap(rotate_point, in_axes=(0, None))(points, math.pi)
jax.tree.structure(points)
jax.tree.structure(points[0])
points_t = jax.tree.transpose(
    outer_treedef=jax.tree.structure([0 for p in points]),
    inner_treedef=jax.tree.structure(points[0]),
    pytree_to_transpose=points
)
points_t
points_t_array = Point(jnp.array(points_t.x), jnp.array(points_t.y))
points_t_array
jax.vmap(rotate_point, in_axes=(0, None))(points_t_array, math.pi)
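
The same transposition can be written as a multi-tree jax.tree.map: unpacking the list passes each Point as a separate tree, and corresponding fields get stacked:

jax.tree.map(lambda *v: jnp.array(v), *points)  # Point(x=[0., 3., 0.], y=[0., 0., 4.])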

Creating custom pytree nodes

class Layer:
    def __init__(self, name, w, b):
        self.name = name
        self.w = w
        self.b = b
h1 = Layer('hidden1', jnp.zeros((100,20)), jnp.zeros((20,)))
pt = [
    jnp.ones(50),
    h1
]
jax.tree.leaves(pt)  # the unregistered Layer instance shows up as a single opaque leaf
# This raises a TypeError: JAX doesn't know how to multiply a Layer by 10.
jax.tree.map(lambda x: x*10, pt)
def flatten_layer(container):
    flat_contents = [container.w, container.b]
    aux_data = container.name
    return flat_contents, aux_data

def unflatten_layer(aux_data, flat_contents):
    return Layer(aux_data, *flat_contents)
jax.tree_util.register_pytree_node(
    Layer,
    flatten_layer,
    unflatten_layer
)
h1 = Layer('hidden1', jnp.zeros((100, 20)), jnp.zeros((20,)))
pt = [
    jnp.ones(50), 
    h1
]
jax.tree.leaves(pt)  # now the Layer contributes its w and b arrays as leaves
jax.tree.map(lambda x: x*10, pt)
jax.tree.leaves(pt)  # the original tree is untouched; tree.map returns a new tree
pt2 = jax.tree.map(lambda x: x+1, pt)
pt2
jax.tree.leaves(pt2)
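
An alternative to calling register_pytree_node explicitly is the jax.tree_util.register_pytree_node_class decorator, which keeps the flatten/unflatten logic on the class itself; a sketch (DecoratedLayer is just an illustrative name):

@jax.tree_util.register_pytree_node_class
class DecoratedLayer:
    def __init__(self, name, w, b):
        self.name = name
        self.w = w
        self.b = b

    def tree_flatten(self):
        # children (transformed by JAX) and aux_data (static metadata)
        return (self.w, self.b), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data, *children)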