import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import collections

LAYER_SIZES = [200 * 200 * 3, 2048, 1024, 2]
PARAM_SCALE = 0.01

def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return (scale * random.normal(w_key, (n, m)),
            scale * random.normal(b_key, (n,)))

def init_network_params(sizes, key=random.key(0), scale=0.01):
    keys = random.split(key, len(sizes) - 1)
    return [random_layer_params(m, n, k, scale)
            for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

key = random.key(42)
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)
params

shapes = jax.tree.map(lambda p: p.shape, params)
for i, shape in enumerate(shapes):
    print(i, shape)
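Since params is just a pytree of arrays, the ordinary tree utilities work on it directly; for example, the total parameter count of this network (a quick sketch using the names above):

num_params = sum(p.size for p in jax.tree.leaves(params))
print(num_params)  # 120000*2048 + 2048 + 2048*1024 + 1024 + 1024*2 + 2 = 247,862,274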
Point = collections.namedtuple('Point', ['x', 'y'])

example_pytree = [
    {
        'a': [1, 2, 3],
        'b': jnp.array([1, 2, 3]),
        'c': np.array([1, 2, 3]),
    },
    [42, [44, 46], None],
    31337,
    (50, (60, 70)),
    Point(640, 480),
    collections.OrderedDict([('a', 100), ('b', 200)]),
    'some string'
]

jax.tree.leaves(example_pytree)
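Note that None is an empty pytree node rather than a leaf, while the string counts as a single leaf, so the call above should return 17 leaves; the tree structure itself can be inspected too (a quick check with the names above):

print(len(jax.tree.leaves(example_pytree)))  # 17
print(jax.tree.structure(example_pytree))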
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
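A quick way to confirm what the loaders produce; ToTensor yields float tensors with the channel dimension first and pixel values scaled to [0, 1]:

images, labels = next(iter(train_dataloader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])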
HEIGHT, WIDTH, CHANNELS = 28, 28, 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = 10
LAYER_SIZES = [28 * 28, 512, 10]
PARAM_SCALE = 0.1
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot

def init_network_params(sizes, key=random.key(0), scale=1e-2):
    """Initialize all layers."""

    def random_layer_params(m, n, key, scale=1e-2):
        """A helper function."""
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    keys = random.split(key, len(sizes))
    # Pass the outer scale through so that PARAM_SCALE is actually used.
    return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

init_params = init_network_params(LAYER_SIZES, random.key(0), scale=PARAM_SCALE)
def predict(params, image):
    """Function for per-example predictions."""
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits

batched_predict = vmap(predict, in_axes=(None, 0))
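To check that shapes flow correctly through vmap, batched_predict can be called on a random batch (a sketch using the names above; the batch size of 32 is arbitrary):

random_images = random.normal(random.key(1), (32, NUM_PIXELS))
print(batched_predict(init_params, random_images).shape)  # (32, 10): one row of logits per image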
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

def loss(params, images, targets):
    """Categorical cross entropy loss."""
    logits = batched_predict(params, images)
    # Normalize each example's logits separately (row-wise logsumexp).
    log_preds = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(targets * log_preds)
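The axis=1 ensures each row of logits is normalized on its own; as a quick sanity check, every row of predicted probabilities should sum to one (a sketch reusing init_params with random inputs):

batch_images = random.normal(random.key(2), (8, NUM_PIXELS))
logits = batched_predict(init_params, batch_images)
log_preds = logits - logsumexp(logits, axis=1, keepdims=True)
print(jnp.exp(log_preds).sum(axis=1))  # every entry is 1.0 up to float error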
@jax.jit
def update(params, x, y, epoch_number):
    # These prints execute only while the function is being traced, not on every call.
    print(f"Params shapes: {jax.tree.map(lambda p: p.shape, params)}")
    loss_value, grads = value_and_grad(loss)(params, x, y)
    print(f"Grads shapes: {jax.tree.map(lambda g: g.shape, grads)}")
    lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
    return [(w - lr * dw, b - lr * db)
            for (w, b), (dw, db) in zip(params, grads)], loss_value
x, y = next(iter(train_dataloader))
x = x.numpy().reshape(64, 28 * 28)
x = jnp.reshape(x, (len(x), NUM_PIXELS))
y = one_hot(y.numpy(), NUM_LABELS)
params, loss_value = update(init_params, x, y, 0)
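To see how these pieces fit together, here is a minimal sketch of one training epoch over train_dataloader using the update function above; logging and evaluation are left out:

params = init_params
for epoch in range(1):  # a single epoch, just to exercise the loop
    for x, y in train_dataloader:
        x = jnp.reshape(x.numpy(), (len(x), NUM_PIXELS))
        y = one_hot(y.numpy(), NUM_LABELS)
        params, loss_value = update(params, x, y, epoch)
    print(f"epoch {epoch}, last batch loss {loss_value:.4f}")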
params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)
scaled_params = jax.tree.map(lambda p: 10 * p, params)
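jax.tree.map also accepts several trees with the same structure, which gives a quick way to verify the scaling (names as above):

checks = jax.tree.map(lambda a, b: bool(jnp.allclose(10 * a, b)), params, scaled_params)
print(checks)  # every leaf is True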
some_pytree = [
    [1, 1, 1],
    [
        [10, 10, 10], [20, 20]
    ]
]

jax.tree.map(lambda p: p + 1, some_pytree)

leaves, struct = jax.tree.flatten(some_pytree)
print(leaves)
print(struct)

updated_leaves = map(lambda x: x + 1, leaves)
jax.tree.unflatten(struct, updated_leaves)
from jax.flatten_util import ravel_pytree

leaves, unflatten_func = ravel_pytree(some_pytree)
print(leaves)
print(unflatten_func)
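Unlike jax.tree.flatten, ravel_pytree concatenates every leaf into one flat array and returns a function that reverses the operation; a quick round trip with the names above:

restored = unflatten_func(leaves)
print(restored)  # same structure as some_pytree, with the leaves now JAX arrays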
Reducing a tree

jax.tree.reduce(lambda accumulator, value: accumulator + value, some_pytree, initializer=0)
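Since some_pytree holds the leaves 1, 1, 1, 10, 10, 10, 20, 20, this reduction sums to 73, the same result as summing the flattened leaves directly:

sum(jax.tree.leaves(some_pytree))  # 73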
import math
from collections import namedtuple

Transposing a pytree

Point = namedtuple('Point', ['x', 'y'])

points = [
    Point(0.0, 0.0),
    Point(3.0, 0.0),
    Point(0.0, 4.0)
]

def rotate_point(p, theta):
    x = p.x * math.cos(theta) - p.y * math.sin(theta)
    y = p.x * math.sin(theta) + p.y * math.cos(theta)
    return Point(x, y)

rotate_point(points[1], math.pi)

# This call fails: the leaves of points are scalars, so there is no axis 0 to map over.
jax.vmap(rotate_point, in_axes=(0, None))(points, math.pi)

jax.tree.structure(points)
jax.tree.structure(points[0])
points_t = jax.tree.transpose(
    outer_treedef=jax.tree.structure([0 for p in points]),
    inner_treedef=jax.tree.structure(points[0]),
    pytree_to_transpose=points
)
points_t

points_t_array = Point(jnp.array(points_t.x), jnp.array(points_t.y))
points_t_array

jax.vmap(rotate_point, in_axes=(0, None))(points_t_array, math.pi)
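As a sanity check with the names above, the vmapped rotation of the transposed structure should match rotating each original point one at a time:

rotated = jax.vmap(rotate_point, in_axes=(0, None))(points_t_array, math.pi)
for i, p in enumerate(points):
    single = rotate_point(p, math.pi)
    assert jnp.allclose(rotated.x[i], single.x, atol=1e-6)
    assert jnp.allclose(rotated.y[i], single.y, atol=1e-6)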
Creating custom pytree nodes

class Layer:
    def __init__(self, name, w, b):
        self.name = name
        self.w = w
        self.b = b

h1 = Layer('hidden1', jnp.zeros((100, 20)), jnp.zeros((20,)))

pt = [
    jnp.ones(50),
    h1
]

# This fails: Layer is not registered as a pytree node, so h1 is treated as a single
# leaf and the lambda tries to multiply a Layer object by 10.
jax.tree.map(lambda x: x * 10, pt)
def flatten_layer(container):
    """Return the children (arrays) and the auxiliary data (the name) of a Layer."""
    flat_contents = [container.w, container.b]
    aux_data = container.name
    return flat_contents, aux_data

def unflatten_layer(aux_data, flat_contents):
    """Rebuild a Layer from its auxiliary data and children."""
    return Layer(aux_data, *flat_contents)

jax.tree_util.register_pytree_node(
    Layer,
    flatten_layer,
    unflatten_layer
)

h1 = Layer('hidden1', jnp.zeros((100, 20)), jnp.zeros((20,)))

pt = [
    jnp.ones(50),
    h1
]

jax.tree.map(lambda x: x * 10, pt)

pt2 = jax.tree.map(lambda x: x + 1, pt)
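Once Layer is registered, jax.tree.map rebuilds Layer instances from the mapped leaves, and the name travels through the auxiliary data unchanged; a quick check of the names above:

print(pt2[1].name)      # 'hidden1', preserved via aux_data
print(pt2[1].w[0, :3])  # all 1.0: the zero weights after adding 1
print(pt2[0][:3])       # all 2.0: jnp.ones(50) after adding 1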