$$
% Define your custom commands here
\newcommand{\bmat}[1]{\begin{bmatrix}#1\end{bmatrix}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\P}{\mathbb{P}}
\newcommand{\S}{\mathbb{S}}
\newcommand{\R}{\mathbb{R}}
\newcommand{\norm}[2]{\|{#1}\|_{{}_{#2}}}
\newcommand{\pd}[2]{\frac{\partial #1}{\partial #2}}
\newcommand{\pdd}[2]{\frac{\partial^2 #1}{\partial #2^2}}
\newcommand{\vectornorm}[1]{\left|\left|#1\right|\right|}
\newcommand{\abs}[1]{\left|{#1}\right|}
\newcommand{\mbf}[1]{\mathbf{#1}}
\newcommand{\mc}[1]{\mathcal{#1}}
\newcommand{\bm}[1]{\boldsymbol{#1}}
\newcommand{\nicefrac}[2]{{}^{#1}\!/_{\!#2}}
\newcommand{\argmin}{\operatorname*{arg\,min}}
\newcommand{\argmax}{\operatorname*{arg\,max}}
$$

Adding Gaussian noise and a horizontal flip

import os

import jax 
from jax import numpy as jnp
from jax import random as jr

from matplotlib import pyplot as plt
from skimage.io import imread, imsave
# Load the example image. Later cells index it as (rows, columns, channels)
# and divide it by 255, so it is presumably 8-bit RGB — confirm for other files.
img = imread('./The_Cat.jpg')
plt.figure(figsize=(6, 8))
plt.imshow(img)

# A fixed seed gives a fixed key, and hence the same noise on every run.
seed = 42
key = jr.key(seed)
# Standard-normal samples (mean 0, std 1), one per pixel and channel.
std_noise = jr.normal(key, img.shape)
# Sanity check: mean should be ~0 and std ~1 (see the output below).
std_noise.min(), std_noise.max(), std_noise.mean(), std_noise.std()
(Array(-5.012798, dtype=float32),
 Array(4.638076, dtype=float32),
 Array(5.15859e-05, dtype=float32),
 Array(1.0000969, dtype=float32))
# Shift and scale to N(0.5, 0.1^2). Rare samples can still fall outside [0, 1]
# (the minimum below is slightly negative), which is why imshow warns and clips.
noise = 0.5 + 0.1 * std_noise
noise.min(), noise.max(), noise.mean(), noise.std()
(Array(-0.00127977, dtype=float32),
 Array(0.9638076, dtype=float32),
 Array(0.5000051, dtype=float32),
 Array(0.1000097, dtype=float32))
plt.imshow(noise)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.0012797713..0.9638076].

# Scale pixels to [0, 1] before adding noise (assumes 8-bit input values).
image = img / 255.0
# The sum is not renormalized, so values exceed 1 (max ~1.87 below) and imshow
# clips them; add_noise further down min-max rescales to [0, 1] instead.
new_image = image + noise
new_image.min(), new_image.max(), new_image.mean(), new_image.std()
(Array(0.11410853, dtype=float32),
 Array(1.8716927, dtype=float32),
 Array(0.95421576, dtype=float32),
 Array(0.26887897, dtype=float32))
plt.imshow(new_image)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [0.11410853..1.8716927].

# Horizontal flip by reversing the column (width) axis with a slice...
image_flipped = image[:, ::-1, :]
plt.imshow(image_flipped)

# ...or equivalently with jnp.fliplr, which reverses axis 1.
image_flipped = jnp.fliplr(image)
plt.imshow(image_flipped)

Performing random augmentations

key
Array((), dtype=key<fry>) overlaying:
[ 0 42]
# split derives two new keys from `key`. JAX's recommended practice is to use
# the children for sampling and not to reuse the parent key afterwards.
key1, key2 = jr.split(key)
key1, key2
(Array((), dtype=key<fry>) overlaying:
 [1832780943  270669613],
 Array((), dtype=key<fry>) overlaying:
 [  64467757 2916123636])
def add_noise(image, rng_key, mean=0.5, std=0.1):
    """Add Gaussian noise to ``image`` and min-max rescale the result to [0, 1].

    Args:
        image: array of pixel values; the noise is drawn with the same shape.
        rng_key: JAX PRNG key used to draw the noise.
        mean: mean of the Gaussian noise. Because the result is min-max
            rescaled, a constant offset cancels out and does not change the
            output (up to float rounding); the default keeps the original formula.
        std: standard deviation of the Gaussian noise.

    Returns:
        Array with ``image``'s shape whose minimum is 0 and maximum is 1. If the
        noisy image is constant (e.g. it has a single element), every value is
        0 instead of NaN.
    """
    noise = mean + std * jr.normal(rng_key, image.shape)
    new_image = image + noise
    lo = new_image.min()
    span = new_image.max() - lo
    # Guard the zero-range case with jnp.where rather than a Python `if`, so the
    # function still traces when used as a jax.lax.switch branch or under jit.
    span = jnp.where(span > 0, span, 1.0)
    return (new_image - lo) / span
# Augmentations selectable by `random_augmentation`. Every entry is called as
# f(image, key); the flip is deterministic, so it ignores its key argument.
augmentations = [add_noise, lambda arr, _key: jnp.fliplr(arr)]
def random_augmentation(image, augmentations, rng_key):
    """Apply one augmentation, chosen uniformly at random, to ``image``.

    Args:
        image: array passed as the first argument to the chosen augmentation.
        augmentations: sequence of callables ``f(image, key) -> array``. All of
            them must return arrays of the same shape and dtype, because
            ``jax.lax.switch`` requires every branch to produce matching outputs.
        rng_key: JAX PRNG key. It is split in two: the first child picks the
            augmentation index and the second is handed to the augmentation.

    Returns:
        The output of the selected augmentation.
    """
    choice_key, branch_key = jr.split(rng_key)
    # randint's upper bound is exclusive, so every index 0..n-1 can be drawn.
    n_branches = len(augmentations)
    branch_index = jr.randint(choice_key, shape=(), minval=0, maxval=n_branches)
    return jax.lax.switch(branch_index, augmentations, image, branch_key)
# A fresh seed for the demo: rerunning with the same key picks the same
# augmentation (and the same noise, if add_noise is the one chosen).
key = jr.key(4242)
image_aug = random_augmentation(image, augmentations, key)
plt.imshow(image_aug)

NumPy PRNG

import numpy as np
from numpy import random
from numpy.random import default_rng
# A NumPy Generator carries hidden state that every call advances, so the two
# draws below differ. With no seed, default_rng() takes fresh entropy from the
# OS, so these values also change on every run.
rng = default_rng()
vals = rng.normal(loc=0.5, scale=0.1, size=(3, 5))
more_vals = rng.normal(loc=0.5, scale=0.1, size=(3, 5))
vals, more_vals
(array([[0.45879781, 0.43006299, 0.51154182, 0.50356447, 0.36374889],
        [0.47998945, 0.72376906, 0.52692486, 0.63513868, 0.48407395],
        [0.34634252, 0.44499043, 0.50929666, 0.52215354, 0.64615907]]),
 array([[0.41411929, 0.61371454, 0.53851603, 0.52663525, 0.57510088],
        [0.32308464, 0.48213926, 0.68905328, 0.40801101, 0.58335347],
        [0.44979266, 0.4755453 , 0.53730226, 0.56361181, 0.57094472]]))

Controlling the seed

# Legacy global-state API: random.seed resets the single global generator, so
# each draw below starts from the same state and yields the same numbers.
random.seed(42)
vals = random.normal(loc=0.5, scale=0.1, size=(3, 5))
random.seed(42)
more_vals = random.normal(loc=0.5, scale=0.1, size=(3, 5))
random.seed(42)
# Drawing 15 scalars one at a time gives the same values, in the same order, as
# a single (3, 5) draw (the three outputs below are identical).
even_more_vals = np.array(
    [random.normal(loc=0.5, scale=0.1) for _ in range(3*5)]
).reshape((3, 5))
vals
array([[0.54967142, 0.48617357, 0.56476885, 0.65230299, 0.47658466],
       [0.4765863 , 0.65792128, 0.57674347, 0.45305256, 0.554256  ],
       [0.45365823, 0.45342702, 0.52419623, 0.30867198, 0.32750822]])
more_vals
array([[0.54967142, 0.48617357, 0.56476885, 0.65230299, 0.47658466],
       [0.4765863 , 0.65792128, 0.57674347, 0.45305256, 0.554256  ],
       [0.45365823, 0.45342702, 0.52419623, 0.30867198, 0.32750822]])
even_more_vals
array([[0.54967142, 0.48617357, 0.56476885, 0.65230299, 0.47658466],
       [0.4765863 , 0.65792128, 0.57674347, 0.45305256, 0.554256  ],
       [0.45365823, 0.45342702, 0.52419623, 0.30867198, 0.32750822]])

Looking at the state:

# get_state() returns the legacy generator's full state as a tuple:
# (algorithm name, 624-word key array, position, has_gauss, cached_gaussian).
random.seed(42)
state = random.get_state()
type(state), state[0]
(tuple, 'MT19937')
state[1].shape
(624,)
state[2], state[3], state[4]
(624, 0, 0.0)
state
('MT19937',
 array([        42, 3107752595, 1895908407, 3900362577, 3030691166,
        4081230161, 2732361568, 1361238961, 3961642104,  867618704,
        2837705690, 3281374275, 3928479052, 3691474744, 3088217429,
        1769265762, 3769508895, 2731227933, 2930436685,  486258750,
        1452990090, 3321835500, 3520974945, 2343938241,  928051207,
        2811458012, 3391994544, 3688461242, 1372039449, 3706424981,
        1717012300, 1728812672, 1688496645, 1203107765, 1648758310,
         440890502, 1396092674,  626042708, 3853121610,  669844980,
        2992565612,  310741647, 3820958101, 3474052697,  305511342,
        2053450195,  705225224, 3836704087, 3293527636, 1140926340,
        2738734251,  574359520, 1493564308,  269614846,  427919468,
        2903547603, 2957214125,  181522756, 4137743374, 2557886044,
        3399018834, 1348953650, 1575066973, 3837612427,  705360616,
        4138204617, 1604205300, 1605197804,  590851525, 2371419134,
        2530821810, 4183626679, 2872056396, 3895467791, 1156426758,
         184917518, 2502875602, 2730245981, 3251099593, 2228829441,
        2591075711, 3048691618, 3030004338, 1726207619,  993866654,
         823585707,  936803789, 3180156728, 1191670842,  348221088,
         988038522, 3281236861, 1153842962, 4152167900,   98291801,
         816305276,  575746380, 1719541597, 2584648622, 1791391551,
        3234806234,  413529090,  219961136, 4180088407, 1135264652,
        3923811338, 2304598263,  762142228, 1980420688, 1225347938,
        3657621885, 3762382117, 1157119598, 2556627260, 2276905960,
        3857700293, 1903185298, 4258743924, 2078637161, 4160077183,
        3569294948, 2138906140, 1346725611, 1473959117, 2798330104,
        3785346335, 4103334026, 3448442764, 1142532843, 4278036691,
        3071994514, 3474299731, 1121195796, 1536841934, 2132070705,
        1064908919, 2840327803,  992870214, 2041326888, 2906112696,
        4182466030, 1031463950,  703166484,  854266995, 4157971695,
        4071962029, 2600094776, 2770410869, 3776335751, 2599879593,
        2451043853, 2223709058, 2098813464, 4008111478, 2959232195,
        3072496064, 2498909222, 4020139729,  785990520,  958060279,
        4183949075, 2392404465,  533774465, 4092066952, 3967420027,
        1726137853, 2907699474, 3158758391, 1460845905, 1323598137,
        2446717890, 3004885867, 3447263769, 1378488047, 3172418196,
         652839901, 1695052769,  226007057,  778836071, 1216725078,
         655651335, 1850195064,  427367795,  800074262, 2241880422,
        1713434925,  339981078, 1730571881,  672610244, 1952245009,
        2729177102, 3516932475, 4032720152, 3177283432,  411893652,
        2440235559, 3587427933,   43170267,   39225133, 3904203400,
        1935961247, 3843123487, 1625453782, 1337993374, 2095455879,
        3402219947,  634671126,   70868861, 3072823841,  851862432,
        1828056818, 2794213810, 1222863684, 2164539406, 4249334162,
        1380362252, 1512719097, 2773165233, 4063118969, 3041859837,
         529421431,  563872464, 2478730478, 3168749051, 4132953373,
        3922807735, 1124217574, 1970058502, 1744120743, 1906315107,
        1074758800, 1611130652, 2878846041,  886823888, 1175456250,
        1669874674, 2428820171, 1044308794, 3841962192,  138850094,
        1239727126, 1753711876, 2194286827,  872797664, 4276240980,
         690338888, 4087206238, 2279169960, 1117436170, 3344885072,
        3127829945,  315537090, 3802787206, 4157203318, 1637047079,
        3774106877, 3230158646, 1855823338, 1931415993,  667252379,
        4288528171, 1587598285, 1096793218, 1916566454,  101891899,
        2354644560, 3351208292, 1467125166, 2177732119, 4122299478,
        3904084887, 2653591155, 4201043109, 2867379343, 2660555187,
        3641744616, 4126452939,  326579197, 2697259239, 3365236848,
        3007834487, 4118919490, 3306741951, 2285455175, 1956645973,
        1879691841,  891565150, 1843460149, 2013381028,  819311674,
         123282948, 1436558519, 1154343666,  206804484, 1650349242,
        2142011886,  304163699, 2608574600, 2500624796, 2996744833,
        2344192475, 3152512202,  165571606,  691170269, 1806226529,
         568535825, 1243813863, 3068953841, 3843784723, 1540495237,
        4246006858, 1303595780, 3288680241,  864868851,  819595545,
        3230857496, 3574119395, 1545404573, 2970139338, 4292786727,
        1803072884, 1374565738, 1736333177, 1978645403, 3962597126,
        1068006206, 3458125500,  168085922, 1597587506, 2052497512,
        1323596727, 2421372441, 1468386547, 3574947527, 3363915938,
         860279252, 1309097460, 3065417722, 1490716202, 3476091722,
        1669402145,  895071221, 1432690175, 3353592973,  149850974,
        2789493615,  826939483,  666980418,  755367270, 3988951195,
          21783894, 1924727373, 1699517788, 1152431122, 2593798113,
        3522529522, 2797535609, 4018366956, 2350035889, 3010507270,
        2832621820,  627979167,  997422629,  365587204, 2302500352,
        1720920631,  689999548, 3713985947, 3267499624, 1971264680,
        1981530399, 1662926921, 1833821660, 1422522022, 3141447769,
        2727954526, 4172728772, 1787436028, 1902276939, 3145551277,
        4207627911, 2497093521, 4111966589, 3929089589, 2253454030,
        1069424637, 2165048659, 2848813944, 2435898022, 2546206777,
        3864777677, 3107311565, 3776562483, 1040285049, 3171631943,
        2404677828, 2522848682, 2930777301, 2831905121, 1436989598,
         602730315,  664177960, 3959954010, 3116042160, 2881899726,
         233404945, 4058465099, 1781994751,  485046222, 2776777695,
         432082123, 1989128370,   86344507, 2510576356, 2194076764,
        1742125237, 3715839140,  895100548,  147445686,  705462897,
        2245325113, 1052295404, 1956014786, 2916055958, 1829369612,
        2541711050, 1594343058, 3708804266,  150438233,  323857098,
         294681952,  783931535,  606075163, 2427042904,  121207604,
        3943199031, 1196785464, 1818211378, 1788241109, 3138862427,
        2037307093, 2306750301, 1644605749,  165986111,  542190743,
         486828112, 1757411662,  894543082, 4108143634, 1232805238,
        3801632949, 3863166865,  713767006, 2091486427, 3174776264,
        1157004409,  623072544, 1667151721, 3361539538,  696723008,
        3247069452,  682044344, 1382136166, 1385645682, 4219951151,
        2747881261, 2489355869,  786564174, 2040230554, 2967874556,
        1414286092, 2677969656, 1393412218, 2216095072,  935533444,
        3662643439, 3285199608, 3103672804,  522796956, 3952383595,
        1928659176, 3397717710, 4278554051, 1984736931, 3559102926,
        1878353094,  875578217, 2398931796, 2313634006, 1606027661,
        2790634022, 2334166559, 1857067101,  666458681, 1626872683,
        2155121857,  715449823, 1865157100, 2938814835, 4084911240,
          45488075, 3474982924, 1750873825, 2246019159,  125388929,
        1110287838,  652200437, 4212247716, 2702974687, 2963764270,
         208692058, 3170393729, 1378248367,  752591527,  591629541,
        2253399388, 2402291226, 3089656189, 3202324513, 3818308310,
        2828131601, 2690672008, 3676629884, 1007739430, 4072247562,
        3574795162,  518485611, 1889402182, 3687902739, 3410263649,
        2790674620,  779455241, 3573984673, 3053204735, 4089925351,
         789980683,  476440431, 3843536868, 2400661309, 3139919094,
        1643266656,  113318754,  428163528, 2386492935, 3807242009,
         574560611, 3174039857, 3774465602, 1164640969,  455942925,
        1374407495, 2562304709, 1024844203,  521375136,  417432138,
        1203241821, 2900988280, 2841030991, 2301700751,  369508560,
        2396447808, 1891459643, 4225682708, 3930667846, 1518293357,
        2697063889, 3113075061, 2411136298, 2836361984, 4105335811,
         914081338, 2675982621, 1816939127, 1596754123, 1464603632,
        1598478676, 1318403529, 4016663081, 2106416852, 2757323084,
        2042842122, 1175184796, 2212339255, 1334626864, 3994484893,
        3938045599, 2166620630, 3036360431,  397499085,  975931950,
        1868702836, 3530424696, 3532548823, 2770836469, 3537418693,
        3344319345, 3208552526, 1771170897, 4097379814, 3761572528,
        2794194423,  706836738, 2953105956, 3446096217,  220984542,
         309619699,  223913021, 3985142640, 1757616575, 2582763607,
        4018329835, 1393278443, 4121569718, 2087146446, 4282833425,
         807775617, 1396604749, 3571181413,   90301352, 2618014643,
        2783561793, 1329389532,  836540831,   26719530], dtype=uint32),
 624,
 0,
 0.0)

JAX PRNG

key = jr.key(42)
type(key)
jax._src.prng.PRNGKeyArray
key
Array((), dtype=key<fry>) overlaying:
[ 0 42]
vals = jr.normal(key, shape=(3,5))
vals
Array([[-0.02830462,  0.46713185,  0.29570296,  0.15354592, -0.12403282],
       [ 0.21692315, -1.4408789 ,  0.7558599 ,  0.52140963,  0.9101704 ],
       [-0.3844966 ,  1.1398233 ,  1.4457862 ,  1.0809066 , -0.05629321]],      dtype=float32)
# Reusing the same key reproduces exactly the same samples: sampling does not
# advance a JAX key, so new randomness needs a new key (split or fold_in).
more_vals = jr.normal(key, shape=(3,5))
more_vals
Array([[-0.02830462,  0.46713185,  0.29570296,  0.15354592, -0.12403282],
       [ 0.21692315, -1.4408789 ,  0.7558599 ,  0.52140963,  0.9101704 ],
       [-0.3844966 ,  1.1398233 ,  1.4457862 ,  1.0809066 , -0.05629321]],      dtype=float32)

Splitting the key to produce multiple keys

# split returns new child keys and leaves `key` unchanged (see the output).
key = jr.key(42)
key1, key2 = jr.split(key, num=2)
key
Array((), dtype=key<fry>) overlaying:
[ 0 42]
# Each child key produces its own, different samples.
vals = jr.normal(key1, shape=(3, 5))
more_vals = jr.normal(key2, shape=(3, 5))
key1, key2
(Array((), dtype=key<fry>) overlaying:
 [1832780943  270669613],
 Array((), dtype=key<fry>) overlaying:
 [  64467757 2916123636])
vals, more_vals
(Array([[ 0.07592554, -0.48634264,  1.2903206 ,  0.5196119 ,  0.30040437],
        [ 0.31034866,  0.5761609 , -0.8074621 , -1.9883217 ,  0.6395295 ],
        [ 0.21763174,  0.00247425,  1.6645706 ,  0.20313536, -0.02138225]],      dtype=float32),
 Array([[ 0.60576403,  0.7990441 , -0.908927  , -0.63525754, -1.2226585 ],
        [-0.83226097, -0.47417238, -1.2504351 , -0.17678244, -0.04917514],
        [-0.41177532, -0.39363015,  1.3116323 ,  0.21555556,  0.41164538]],      dtype=float32))

Generating many keys at once

# Split into 100 keys at once: keep the first as the new carry key and unpack
# the remaining 99 into `subkeys` (a plain Python list).
key = jr.key(42)
key, *subkeys = jr.split(key, num=100)
key
Array((), dtype=key<fry>) overlaying:
[1832780943  270669613]
len(subkeys)
99
type(subkeys)
list

Generating new keys using fold_in()

# fold_in deterministically derives a new key from (key, integer), e.g. a loop
# index, so each iteration gets its own key without threading split keys along.
key = jr.key(42)
for i in range(5):
    new_key = jr.fold_in(key, i)
    print(new_key)
    vals = jr.normal(new_key, shape=(3, 5))
    # do something with values
Array((), dtype=key<fry>) overlaying:
[1832780943  270669613]
Array((), dtype=key<fry>) overlaying:
[  64467757 2916123636]
Array((), dtype=key<fry>) overlaying:
[2465931498  255383827]
Array((), dtype=key<fry>) overlaying:
[3134548294  894150801]
Array((), dtype=key<fry>) overlaying:
[2954079971 3276725750]

Using a string to fold into a new key

import hashlib

def my_hash(s):
    """Map a string to a deterministic unsigned 32-bit integer.

    The result is the first four bytes of the SHA-1 digest of ``s`` (UTF-8
    encoded), read as a big-endian integer. It always lies in ``[0, 2**32)``
    and is stable across runs, unlike the built-in ``hash`` for strings, which
    is randomized per interpreter process.
    """
    digest = hashlib.sha1(s.encode()).digest()
    return int.from_bytes(digest[:4], byteorder="big")

# Hash a human-readable name (e.g. a layer name) into an integer so it can be
# folded into a key: the same name always derives the same key. my_hash keeps
# only 32 bits, matching the 32-bit integer data that fold_in expects.
some_string = 'layer7_2'
some_int = my_hash(some_string)
some_int
2649017889
key = jr.key(42)
new_key = jr.fold_in(key, some_int)
new_key
Array((), dtype=key<fry>) overlaying:
[3110527424 3716265121]